"""Image download, decoding, and validation utilities."""
|
|
|
|
import asyncio
|
|
import logging
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from io import BytesIO
|
|
|
|
import cv2
|
|
import httpx
|
|
import numpy as np
|
|
from fastapi import HTTPException
|
|
from PIL import Image, ImageOps
|
|
from tenacity import (
|
|
retry,
|
|
retry_if_exception_type,
|
|
stop_after_attempt,
|
|
wait_exponential,
|
|
)
|
|
|
|
from app.config import (
|
|
DOWNLOAD_TIMEOUT,
|
|
MAX_DOWNLOAD_SIZE,
|
|
MAX_IMAGE_DIMENSION,
|
|
MAX_RETRIES,
|
|
MIN_IMAGE_DIMENSION,
|
|
TARGET_MAX_DIMENSION,
|
|
)
|
|
|
|
# Shared logger for the face service; configured by the application entry point.
logger = logging.getLogger("face_service")


# Set Pillow's decompression bomb limit so Image.open rejects images larger
# than anything this service would accept anyway (dimension validation in
# _decode_image_bytes enforces the same MAX_IMAGE_DIMENSION bound per side).
Image.MAX_IMAGE_PIXELS = MAX_IMAGE_DIMENSION * MAX_IMAGE_DIMENSION
|
|
|
|
|
|
class ImageDownloadError(Exception):
    """Raised when fetching the image bytes fails (HTTP error or oversize body)."""
|
|
|
|
|
|
class ImageDecodeError(Exception):
    """Raised when the downloaded/uploaded bytes cannot be decoded as an image."""
|
|
|
|
|
|
class ImageValidationError(Exception):
    """Raised when a decoded image fails validation (e.g. dimension limits)."""
|
|
|
|
|
|
def _decode_image_bytes(data: bytes, source: str) -> np.ndarray:
    """
    Decode raw image bytes to a BGR uint8 numpy array (OpenCV convention).

    Handles:
    - EXIF orientation correction (best-effort)
    - All common color modes (RGB, RGBA, L, LA, PA, CMYK, I, F, palette, ...)
    - Truncated/corrupted image detection via a forced full decode
    - Dimension validation against MIN/MAX_IMAGE_DIMENSION

    Args:
        data: Encoded image bytes (e.g. JPEG/PNG file contents).
        source: URL or filename, used only for logging.

    Returns:
        HxWx3 BGR array, possibly downscaled by _maybe_downscale.

    Raises:
        ImageDecodeError: If the bytes cannot be opened or are truncated.
        ImageValidationError: If dimensions fall outside the configured bounds.
    """
    try:
        pil_image = Image.open(BytesIO(data))
    except Exception as e:
        logger.exception("Could not open image from %s", source)
        raise ImageDecodeError(f"Could not decode image: {e}") from e

    # Force a full decode now so truncated/corrupted data fails here rather
    # than deep inside downstream face-detection code.
    try:
        pil_image.load()
    except Exception as e:
        logger.exception("Image data is corrupted or truncated from %s", source)
        raise ImageDecodeError(f"Image data is corrupted or truncated: {e}") from e

    # Apply EXIF orientation; a broken EXIF block should not reject an
    # otherwise valid image, so failures only warn.
    try:
        pil_image = ImageOps.exif_transpose(pil_image)
    except Exception:
        logger.warning("Failed to apply EXIF orientation for %s", source)

    # Validate dimensions (after orientation, which may swap width/height).
    width, height = pil_image.size
    if width < MIN_IMAGE_DIMENSION or height < MIN_IMAGE_DIMENSION:
        raise ImageValidationError(
            f"Image too small: {width}x{height}, minimum is {MIN_IMAGE_DIMENSION}x{MIN_IMAGE_DIMENSION}"
        )
    if width > MAX_IMAGE_DIMENSION or height > MAX_IMAGE_DIMENSION:
        raise ImageValidationError(
            f"Image too large: {width}x{height}, maximum is {MAX_IMAGE_DIMENSION}x{MAX_IMAGE_DIMENSION}"
        )

    # Normalize every color mode to RGB.
    mode = pil_image.mode
    if mode in ("RGBA", "LA", "PA"):
        # Has alpha channel - composite on white so transparent regions do
        # not collapse to black during the RGB conversion.
        if mode != "RGBA":
            pil_image = pil_image.convert("RGBA")
        background = Image.new("RGB", pil_image.size, (255, 255, 255))
        background.paste(pil_image, mask=pil_image.split()[-1])
        pil_image = background
    elif mode in ("I", "F"):
        # 32-bit integer or floating point - normalize to 8-bit grayscale.
        arr = np.array(pil_image)
        if mode == "F":
            # Assumes float pixels lie in [0, 1] - TODO confirm upstream.
            arr = (arr * 255).clip(0, 255).astype(np.uint8)
        else:
            # Assumes a 16-bit value range; keep the high byte.
            arr = (arr / 256).clip(0, 255).astype(np.uint8)
        pil_image = Image.fromarray(arr, mode="L").convert("RGB")
    elif mode != "RGB":
        # Covers L, CMYK, P, 1, YCbCr, etc. - PIL handles each directly.
        pil_image = pil_image.convert("RGB")

    # Convert to BGR for OpenCV/InsightFace.
    img = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

    # Downscale large images for faster processing.
    img = _maybe_downscale(img)

    logger.info(
        "decode_image_bytes: source=%s shape=%s mode=%s",
        source,
        img.shape,
        mode,
    )
    return img
|
|
|
|
|
|
def _maybe_downscale(img: np.ndarray, max_dim: int = TARGET_MAX_DIMENSION) -> np.ndarray:
    """Downscale an image if its longest side exceeds max_dim.

    Aspect ratio is preserved and each side is clamped to at least 1 px so
    extremely elongated images (e.g. 1 x 20000) cannot produce a zero-size
    target dimension, which cv2.resize would reject.

    Args:
        img: Image array (HxW or HxWxC).
        max_dim: Maximum allowed size of the longest side, in pixels.

    Returns:
        The original array if already small enough, otherwise a resized copy.
    """
    h, w = img.shape[:2]
    if max(h, w) <= max_dim:
        return img

    scale = max_dim / max(h, w)
    # Clamp to >= 1: int() truncation on an extreme aspect ratio could
    # otherwise yield a 0-width or 0-height target.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))

    logger.info(
        "downscaling image from %dx%d to %dx%d (scale=%.2f)",
        w, h, new_w, new_h, scale,
    )

    # Use INTER_AREA for downscaling (best quality)
    return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
|
|
|
|
|
@retry(
    retry=retry_if_exception_type((httpx.TimeoutException, httpx.NetworkError)),
    stop=stop_after_attempt(MAX_RETRIES),
    wait=wait_exponential(multiplier=1, min=1, max=10),
    reraise=True,
)
async def _download_with_retry(client: httpx.AsyncClient, url: str) -> bytes:
    """Fetch the URL and return its body, retrying transient network failures.

    Tenacity retries timeouts and network errors with exponential backoff and
    re-raises the final failure. The body size is checked both against the
    advertised Content-Length header and the actual downloaded byte count.
    """
    # TODO [PROD]: Add URL validation for SSRF protection
    # - Block internal IPs (10.x, 172.16-31.x, 192.168.x, 127.x, 169.254.x)
    # - Block cloud metadata endpoints
    # - Validate against allowlist if configured

    logger.info("download_image: url=%s", url)  # TODO [PROD]: Redact query params

    response = await client.get(url, follow_redirects=True)
    response.raise_for_status()

    # Reject early when the server advertises an oversized body.
    declared = response.headers.get("content-length")
    if declared and int(declared) > MAX_DOWNLOAD_SIZE:
        raise ImageDownloadError(
            f"Image too large: {int(declared)} bytes, maximum is {MAX_DOWNLOAD_SIZE} bytes"
        )

    # The header may be absent or wrong, so verify the real size as well.
    body = response.content
    if len(body) > MAX_DOWNLOAD_SIZE:
        raise ImageDownloadError(
            f"Image too large: {len(body)} bytes, maximum is {MAX_DOWNLOAD_SIZE} bytes"
        )

    return body
|
|
|
|
|
|
async def download_image(
    image_url: str,
    client: httpx.AsyncClient | None = None,
    executor: ThreadPoolExecutor | None = None,
) -> np.ndarray:
    """
    Download and decode an image from a URL into a BGR numpy array.

    Features:
    - Async HTTP with connection pooling (uses shared client if provided)
    - Retry with exponential backoff for transient failures
    - Size validation before and after download
    - Async image decoding in thread pool

    Args:
        image_url: URL to download image from
        client: Shared httpx client (falls back to creating new one if None)
        executor: Thread pool for blocking decode (runs sync if None)
    """
    # Prefer the caller's client, then the app-wide shared one.
    if client is None:
        from app.resources import http_client

        client = http_client

    if client is not None:
        return await _download_and_decode(client, image_url, executor)

    # No shared client exists (e.g. during tests): fall back to a
    # short-lived client scoped to this single download.
    async with httpx.AsyncClient(timeout=DOWNLOAD_TIMEOUT) as temp_client:
        return await _download_and_decode(temp_client, image_url, executor)
|
|
|
|
|
|
async def _download_and_decode(
    client: httpx.AsyncClient,
    image_url: str,
    executor: ThreadPoolExecutor | None,
) -> np.ndarray:
    """Download an image and decode it, translating failures to HTTPException.

    Error mapping for FastAPI callers:
    - timeout -> 408
    - any other download or decode failure -> 400

    Args:
        client: httpx client used for the (retried) download.
        image_url: URL of the image to fetch.
        executor: Optional thread pool; decoding runs there to avoid
            blocking the event loop, or synchronously when None.

    Returns:
        BGR numpy array from _decode_image_bytes.

    Raises:
        HTTPException: On any download or decode failure.
    """
    try:
        data = await _download_with_retry(client, image_url)
    except httpx.TimeoutException as e:
        logger.exception("Timeout downloading image")
        # `from e` preserves the cause chain for debugging/error reporting.
        raise HTTPException(status_code=408, detail="Timeout downloading image") from e
    except httpx.HTTPStatusError as e:
        logger.exception("HTTP error downloading image")
        raise HTTPException(
            status_code=400,
            detail=f"Failed to download image: HTTP {e.response.status_code}"
        ) from e
    except ImageDownloadError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Failed to download image")
        raise HTTPException(status_code=400, detail=f"Failed to download image: {e}") from e

    # Decode in thread pool to avoid blocking the event loop with
    # CPU-bound PIL/cv2 work.
    try:
        if executor is not None:
            loop = asyncio.get_running_loop()
            img = await loop.run_in_executor(
                executor, _decode_image_bytes, data, image_url
            )
        else:
            img = _decode_image_bytes(data, image_url)
    except (ImageDecodeError, ImageValidationError) as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    logger.info(
        "download_image: success url=%s shape=%s",
        image_url,
        img.shape,
    )
    return img
|
|
|
|
|
|
def read_upload_image(data: bytes, filename: str) -> np.ndarray:
    """
    Decode an uploaded image file into a BGR numpy array.

    Args:
        data: Raw image bytes
        filename: Original filename for logging (empty falls back to "<upload>")

    Returns:
        BGR numpy array

    Raises:
        HTTPException: 400 when the bytes cannot be decoded or the image
            fails dimension validation.
    """
    try:
        # _decode_image_bytes either returns an array or raises, so the
        # previous `if img is None` guard was dead code and is removed.
        return _decode_image_bytes(data, filename or "<upload>")
    except (ImageDecodeError, ImageValidationError) as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
|
|
|