# face/app/image.py
"""Image download, decoding, and validation utilities."""
import logging
from io import BytesIO
import cv2
import httpx
import numpy as np
from fastapi import HTTPException
from PIL import Image, ImageOps
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from app.config import (
DOWNLOAD_TIMEOUT,
MAX_DOWNLOAD_SIZE,
MAX_IMAGE_DIMENSION,
MAX_RETRIES,
MIN_IMAGE_DIMENSION,
)
# Shared service logger; name matches the rest of the face service.
logger = logging.getLogger("face_service")
# Set Pillow's decompression bomb limit so Image.open rejects absurdly
# large images before decoding allocates their pixel buffer (Pillow warns
# above this pixel count and raises DecompressionBombError at 2x it).
Image.MAX_IMAGE_PIXELS = MAX_IMAGE_DIMENSION * MAX_IMAGE_DIMENSION
class ImageDownloadError(Exception):
    """Raised when an image cannot be downloaded (e.g. it exceeds the size limit)."""
class ImageDecodeError(Exception):
    """Raised when image bytes cannot be decoded (corrupted or unsupported data)."""
class ImageValidationError(Exception):
    """Raised when a decoded image fails validation (e.g. dimension limits)."""
def _decode_image_bytes(data: bytes, source: str) -> np.ndarray:
    """
    Decode image bytes to a BGR numpy array.

    Handles:
    - EXIF orientation correction
    - All color modes (RGB, RGBA, L, LA, PA, CMYK, I, F, P, ...)
    - Truncated/corrupted image detection
    - Dimension validation

    Args:
        data: Raw encoded image bytes.
        source: Origin label (URL or filename) used only for logging.

    Returns:
        HxWx3 uint8 numpy array in BGR channel order (OpenCV convention).

    Raises:
        ImageDecodeError: If the bytes cannot be parsed or are truncated.
        ImageValidationError: If either dimension falls outside
            [MIN_IMAGE_DIMENSION, MAX_IMAGE_DIMENSION].
    """
    try:
        pil_image = Image.open(BytesIO(data))
    except Exception as e:
        logger.exception("Could not open image from %s", source)
        raise ImageDecodeError(f"Could not decode image: {e}") from e
    # Force a full pixel load now so truncated/corrupted data fails here
    # instead of deep inside a later conversion.
    try:
        pil_image.load()
    except Exception as e:
        logger.exception("Image data is corrupted or truncated from %s", source)
        raise ImageDecodeError(f"Image data is corrupted or truncated: {e}") from e
    # Apply EXIF orientation; best-effort since some files carry bad EXIF.
    try:
        pil_image = ImageOps.exif_transpose(pil_image)
    except Exception:
        logger.warning("Failed to apply EXIF orientation for %s", source)
    # Validate dimensions (post-transpose, so rotated portraits are measured
    # as the viewer would see them).
    width, height = pil_image.size
    if width < MIN_IMAGE_DIMENSION or height < MIN_IMAGE_DIMENSION:
        raise ImageValidationError(
            f"Image too small: {width}x{height}, minimum is {MIN_IMAGE_DIMENSION}x{MIN_IMAGE_DIMENSION}"
        )
    if width > MAX_IMAGE_DIMENSION or height > MAX_IMAGE_DIMENSION:
        raise ImageValidationError(
            f"Image too large: {width}x{height}, maximum is {MAX_IMAGE_DIMENSION}x{MAX_IMAGE_DIMENSION}"
        )
    # Normalize every color mode to 8-bit RGB.
    mode = pil_image.mode
    if mode in ("RGBA", "LA", "PA"):
        # Alpha channel present: composite on a white background so
        # transparent regions become white rather than arbitrary values.
        if mode != "RGBA":
            pil_image = pil_image.convert("RGBA")
        background = Image.new("RGB", pil_image.size, (255, 255, 255))
        background.paste(pil_image, mask=pil_image.split()[-1])
        pil_image = background
    elif mode in ("I", "F"):
        # 32-bit integer or floating-point pixels - rescale to 8-bit grayscale.
        arr = np.array(pil_image)
        if mode == "F":
            arr = (arr * 255).clip(0, 255).astype(np.uint8)
        else:
            # Assumes a 16-bit source range packed into mode "I" - TODO confirm.
            arr = (arr / 256).clip(0, 255).astype(np.uint8)
        pil_image = Image.fromarray(arr, mode="L").convert("RGB")
    elif mode != "RGB":
        # Covers L, P, CMYK, and anything else Pillow converts directly.
        pil_image = pil_image.convert("RGB")
    # OpenCV/InsightFace expect BGR channel order.
    img = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
    logger.info(
        "decode_image_bytes: source=%s shape=%s mode=%s",
        source,
        img.shape,
        mode,
    )
    return img
@retry(
    retry=retry_if_exception_type((httpx.TimeoutException, httpx.NetworkError)),
    stop=stop_after_attempt(MAX_RETRIES),
    wait=wait_exponential(multiplier=1, min=1, max=10),
    reraise=True,
)
async def _download_with_retry(client: httpx.AsyncClient, url: str) -> bytes:
    """
    Download image bytes with retry logic for transient failures.

    Streams the response body so the download is aborted as soon as it
    exceeds MAX_DOWNLOAD_SIZE, instead of buffering a potentially huge
    payload in full and only then checking its size.

    Args:
        client: Shared async HTTP client (owns timeouts/connection pool).
        url: Image URL to fetch; redirects are followed.

    Returns:
        The raw response body, at most MAX_DOWNLOAD_SIZE bytes.

    Raises:
        ImageDownloadError: Declared Content-Length or actual body exceeds
            MAX_DOWNLOAD_SIZE (not retried).
        httpx.HTTPStatusError: Non-2xx response (not retried).
        httpx.TimeoutException, httpx.NetworkError: Transient failures,
            retried up to MAX_RETRIES with exponential backoff.
    """
    # TODO [PROD]: Add URL validation for SSRF protection
    # - Block internal IPs (10.x, 172.16-31.x, 192.168.x, 127.x, 169.254.x)
    # - Block cloud metadata endpoints
    # - Validate against allowlist if configured
    logger.info("download_image: url=%s", url)  # TODO [PROD]: Redact query params
    async with client.stream("GET", url, follow_redirects=True) as response:
        response.raise_for_status()
        # Fast reject when the server declares an oversized body up front.
        content_length = response.headers.get("content-length")
        if content_length and int(content_length) > MAX_DOWNLOAD_SIZE:
            raise ImageDownloadError(
                f"Image too large: {int(content_length)} bytes, maximum is {MAX_DOWNLOAD_SIZE} bytes"
            )
        # Accumulate chunks, enforcing the limit on the actual byte count
        # (Content-Length may be absent or lie).
        chunks = []
        received = 0
        async for chunk in response.aiter_bytes():
            received += len(chunk)
            if received > MAX_DOWNLOAD_SIZE:
                raise ImageDownloadError(
                    f"Image too large: {received} bytes, maximum is {MAX_DOWNLOAD_SIZE} bytes"
                )
            chunks.append(chunk)
    return b"".join(chunks)
async def download_image(image_url: str) -> np.ndarray:
    """
    Download and decode an image from URL.

    Features:
    - Async HTTP with connection pooling
    - Retry with exponential backoff for transient failures
    - Size validation before and after download
    - Comprehensive image decoding

    Args:
        image_url: HTTP(S) URL of the image to fetch.

    Returns:
        BGR numpy array (OpenCV convention).

    Raises:
        HTTPException: 408 on timeout; 400 for any other download, decode,
            or validation failure. All lower-level errors are translated so
            callers only ever see HTTPException.
    """
    try:
        async with httpx.AsyncClient(timeout=DOWNLOAD_TIMEOUT) as client:
            data = await _download_with_retry(client, image_url)
    except httpx.TimeoutException as e:
        # Retries exhausted (reraise=True in the retry decorator).
        logger.exception("Timeout downloading image")
        raise HTTPException(status_code=408, detail="Timeout downloading image") from e
    except httpx.HTTPStatusError as e:
        logger.exception("HTTP error downloading image")
        raise HTTPException(
            status_code=400,
            detail=f"Failed to download image: HTTP {e.response.status_code}"
        ) from e
    except ImageDownloadError as e:
        # Size-limit violation; message is already user-presentable.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Failed to download image")
        raise HTTPException(status_code=400, detail=f"Failed to download image: {e}") from e
    try:
        img = _decode_image_bytes(data, image_url)
    except (ImageDecodeError, ImageValidationError) as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    logger.info(
        "download_image: success url=%s shape=%s",
        image_url,
        img.shape,
    )
    return img
def read_upload_image(data: bytes, filename: str) -> np.ndarray:
    """
    Decode an uploaded image file.

    Args:
        data: Raw image bytes.
        filename: Original filename, used only for logging; may be empty.

    Returns:
        BGR numpy array (OpenCV convention).

    Raises:
        HTTPException: 400 if the bytes cannot be decoded or fail
            dimension validation.
    """
    try:
        # _decode_image_bytes returns an ndarray or raises - it never
        # returns None, so no post-hoc None check is needed.
        return _decode_image_bytes(data, filename or "<upload>")
    except (ImageDecodeError, ImageValidationError) as e:
        raise HTTPException(status_code=400, detail=str(e)) from e