face/app/routes/embed.py

183 lines
5.3 KiB
Python

"""Face embedding endpoints."""
import logging
import numpy as np
from fastapi import APIRouter, HTTPException
from app.face import (
FaceServiceError,
face_area,
fallback_avatar_embedding,
get_faces_async,
load_face_app,
to_pixel_bbox,
validate_embedding,
)
from app.image import download_image
from app.models import (
EmbedAvatarResponse,
EmbedImageResponse,
EmbedRequest,
FaceEmbedding,
)
from app.resources import http_client, inference_executor
# Expected embedding dimension from buffalo_l model
EXPECTED_EMBEDDING_DIM = 512
# Logger is looked up under the fixed name "face_service" rather than this
# module's __name__.
logger = logging.getLogger("face_service")
# Router on which this module's endpoints are registered via @router.post.
router = APIRouter()
def validate_face_embedding(emb: np.ndarray, context: str) -> tuple[bool, str | None]:
    """Check an embedding's length and run ``validate_embedding`` on it.

    Args:
        emb: Embedding vector to check.
        context: Label prefixed to any error message (e.g. the endpoint name).

    Returns:
        ``(True, None)`` when both checks pass, otherwise ``(False, message)``
        describing the first check that failed.
    """
    actual_dim = len(emb)
    if actual_dim != EXPECTED_EMBEDDING_DIM:
        message = (
            f"{context}: unexpected embedding dimension {actual_dim}, "
            f"expected {EXPECTED_EMBEDDING_DIM}"
        )
        return False, message

    # A failed validate_embedding check is reported as NaN/Inf values.
    if validate_embedding(emb):
        return True, None
    return False, f"{context}: embedding contains NaN/Inf values"
@router.post("/embed-avatar", response_model=EmbedAvatarResponse)
async def embed_avatar(req: EmbedRequest):
    """Extract a single face embedding from an avatar image.

    Downloads the image at ``req.image_url``, runs face detection and returns
    the embedding of the largest detected face. If no face is detected,
    falls back to center crop embedding with score=0.0 (via
    ``fallback_avatar_embedding``).

    Args:
        req: Request carrying the ``image_url`` of the avatar.

    Returns:
        EmbedAvatarResponse with the embedding, its bounding box and score,
        and the width/height of the image that was processed.

    Raises:
        HTTPException: 503 when face detection raises ``FaceServiceError``;
            422 when no face is found and the fallback yields nothing, or
            when the selected face's embedding fails validation.
    """
    logger.info("embed_avatar: image_url=%s", req.image_url)
    img = await download_image(str(req.image_url), http_client, inference_executor)
    h, w = img.shape[:2]

    try:
        faces = await get_faces_async(img, inference_executor)
    except FaceServiceError as e:
        logger.error("embed_avatar: face service error: %s", e)
        # Chain the original error so the root cause stays attached.
        raise HTTPException(
            status_code=503,
            detail="Face service unavailable",
        ) from e

    if not faces:
        logger.warning(
            "embed_avatar: no faces detected image_url=%s size=%dx%d, using fallback",
            req.image_url,
            w,
            h,
        )
        # NOTE(review): unlike get_faces_async, the two calls below run
        # synchronously inside this coroutine -- confirm they are cheap or
        # move them onto inference_executor.
        fa = load_face_app()  # Need face_app for recognition model
        fallback = fallback_avatar_embedding(fa, img, w, h)
        if fallback is None:
            raise HTTPException(
                status_code=422,
                detail="No face detected in avatar image",
            )
        # NOTE(review): the fallback embedding is returned without going
        # through validate_face_embedding -- confirm that is intentional.
        emb, bbox, score = fallback
        logger.info(
            "embed_avatar: using fallback bbox=%s score=%.4f embedding_len=%d",
            bbox,
            score,
            len(emb),
        )
        return EmbedAvatarResponse(
            embedding=emb,
            bbox=bbox,
            score=score,
            processed_width=w,
            processed_height=h,
        )

    # Only the largest face is needed; max() picks the same element a
    # descending stable sort would put first, without reordering `faces`.
    face = max(faces, key=face_area)

    emb = face.normed_embedding.astype(np.float32)
    is_valid, error_msg = validate_face_embedding(emb, "embed_avatar")
    if not is_valid:
        logger.error(error_msg)
        raise HTTPException(
            status_code=422,
            detail="Failed to generate valid face embedding",
        )

    emb_list = emb.tolist()
    bbox = to_pixel_bbox(face.bbox, w, h)
    # Faces without a det_score attribute are reported with score 1.0.
    score = float(getattr(face, "det_score", 1.0))
    logger.info(
        "embed_avatar: using face bbox=%s score=%.4f embedding_len=%d",
        face.bbox,
        score,
        len(emb_list),
    )
    return EmbedAvatarResponse(
        embedding=emb_list,
        bbox=bbox,
        score=score,
        processed_width=w,
        processed_height=h,
    )
@router.post("/embed-image", response_model=EmbedImageResponse)
async def embed_image(req: EmbedRequest):
    """Extract face embeddings from all faces in an image.

    Downloads the image at ``req.image_url``, runs face detection and returns
    one entry per detected face, sorted by detection score (highest first).
    Faces whose embedding fails validation are logged and skipped.

    Args:
        req: Request carrying the ``image_url`` of the image.

    Returns:
        EmbedImageResponse with the list of face embeddings (empty when no
        faces are detected) and the width/height of the processed image.

    Raises:
        HTTPException: 503 when face detection raises ``FaceServiceError``.
    """
    img = await download_image(str(req.image_url), http_client, inference_executor)
    h, w = img.shape[:2]

    try:
        faces = await get_faces_async(img, inference_executor)
    except FaceServiceError as e:
        logger.error("embed_image: face service error: %s", e)
        # Chain the original error so the root cause stays attached.
        raise HTTPException(
            status_code=503,
            detail="Face service unavailable",
        ) from e

    if not faces:
        logger.warning(
            "embed_image: no faces detected image_url=%s size=%dx%d",
            req.image_url,
            w,
            h,
        )
        return EmbedImageResponse(faces=[], processed_width=w, processed_height=h)

    logger.info(
        "embed_image: detected %d faces image_url=%s size=%dx%d",
        len(faces),
        req.image_url,
        w,
        h,
    )

    def det_score(face) -> float:
        # Faces without a det_score attribute are treated as score 1.0.
        return float(getattr(face, "det_score", 1.0))

    # Sort by detection score (highest first)
    faces.sort(key=det_score, reverse=True)

    result: list[FaceEmbedding] = []
    for f in faces:
        emb = f.normed_embedding.astype(np.float32)
        is_valid, error_msg = validate_face_embedding(emb, "embed_image")
        if not is_valid:
            # One bad embedding should not fail the whole request.
            logger.warning(error_msg)
            continue
        result.append(
            FaceEmbedding(
                bbox=to_pixel_bbox(f.bbox, w, h),
                score=det_score(f),
                embedding=emb.tolist(),
            )
        )

    return EmbedImageResponse(faces=result, processed_width=w, processed_height=h)