Files

239 lines
7.4 KiB
Python

import asyncio
import json
import os
import re
import shlex
import subprocess
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional

import httpx
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
# FastAPI application object; the route and exception-handler decorators in
# this module register against it.
app = FastAPI(title="plant-vision-node1", version="0.1.1")
class IdentifyRequest(BaseModel):
    """JSON request body accepted by ``POST /identify``."""

    # URL of the image to classify. Optional in the schema; the /identify
    # handler itself rejects a missing or empty value with HTTP 400.
    image_url: Optional[str] = None
    # Maximum number of predictions to return; validated to the range 1..10.
    top_k: int = Field(default=3, ge=1, le=10)
def _normalize_predictions(raw: Any, top_k: int) -> List[Dict[str, Any]]:
    """Convert decoded JSON classifier output into a uniform prediction list.

    ``raw`` may be a list of prediction objects, or a dict that wraps such a
    list under ``"predictions"``, ``"results"`` or ``"candidates"`` (the first
    key holding a list wins). Any other shape yields an empty list.

    For each prediction object the scientific name is taken from the first
    truthy value among ``scientific_name``, ``scientificName``, ``label`` and
    ``name`` (default ``"unknown"``); the common name from ``common_name``,
    ``commonName`` or ``common`` (default ``"-"``); the score from ``score``
    or, if that key is absent, ``confidence`` (default ``0.0``).

    Args:
        raw: Value produced by ``json.loads`` on the classifier output.
        top_k: Maximum number of predictions to return; a non-positive value
            yields an empty list.

    Returns:
        Up to ``top_k`` dicts with keys ``scientific_name``, ``common_name``
        and ``score``, in input order.
    """
    if isinstance(raw, dict):
        for key in ("predictions", "results", "candidates"):
            if isinstance(raw.get(key), list):
                raw = raw[key]
                break
    if not isinstance(raw, list):
        return []

    preds: List[Dict[str, Any]] = []
    for item in raw:
        # Count only usable entries towards top_k, so that malformed items at
        # the front of the list do not eat into the result quota.
        if len(preds) >= top_k:
            break
        if not isinstance(item, dict):
            continue
        name = (
            item.get("scientific_name")
            or item.get("scientificName")
            or item.get("label")
            or item.get("name")
            or "unknown"
        )
        common = (
            item.get("common_name")
            or item.get("commonName")
            or item.get("common")
            or "-"
        )
        score = item.get("score", item.get("confidence", 0.0))
        try:
            score_f = float(score)
        except (TypeError, ValueError, OverflowError):
            # Missing, null or non-numeric scores degrade to 0.0 rather than
            # discarding the prediction.
            score_f = 0.0
        preds.append(
            {
                "scientific_name": str(name),
                "common_name": str(common),
                "score": score_f,
            }
        )
    return preds
def _parse_text_output(text: str, top_k: int) -> List[Dict[str, Any]]:
    """
    Extract predictions from plain-text classifier output.

    Only lines shaped like "<number>% <name>" are considered, e.g.:
        97.6% Persicaria amphibia
        86.1% Canada Goldenrod (Solidago canadensis)
    Anything else (such as "Read ..." or "Classification of ..." lines) is
    skipped. A name written as "Common Name (Scientific name)" is split into
    both fields; otherwise the whole name is used as the scientific name and
    the common name is "-". Results are ordered by descending score and cut
    to ``top_k`` entries.
    """
    score_line = re.compile(r"^\s*(\d+(?:\.\d+)?)%\s+(.+)$")
    common_then_scientific = re.compile(r"^(.*?)\s*\(([^()]+)\)\s*$")

    found: List[Dict[str, Any]] = []
    for candidate in (text or "").splitlines():
        stripped = candidate.strip()
        # An empty line contains no "%" either, so one check covers both.
        if "%" not in stripped:
            continue
        hit = score_line.match(stripped)
        if hit is None:
            continue
        try:
            percent = float(hit.group(1))
        except ValueError:
            continue
        label = hit.group(2).strip()
        if not label:
            continue

        entry: Dict[str, Any] = {
            "scientific_name": label,
            "common_name": "-",
            "score": percent,
        }
        # "Common Name (Scientific name)" -> keep both parts when non-empty.
        split = common_then_scientific.match(label)
        if split is not None:
            common_part = split.group(1).strip()
            scientific_part = split.group(2).strip()
            if common_part:
                entry["common_name"] = common_part
            if scientific_part:
                entry["scientific_name"] = scientific_part
        found.append(entry)

    ranked = sorted(
        found,
        key=lambda entry: float(entry.get("score", 0.0)),
        reverse=True,
    )
    return ranked[:top_k]
def _extract_inference_time(stdout: str) -> Optional[float]:
    """Return the duration reported as "took <number> secs" in ``stdout``.

    Gives ``None`` when ``stdout`` is empty/None, when no such phrase is
    present, or when the captured number cannot be converted to a float.
    """
    found = re.search(r"took\s+(\d+(?:\.\d+)?)\s+secs", stdout or "")
    if found is not None:
        try:
            return float(found.group(1))
        except Exception:
            pass
    return None
def _run_nature_id_cli(image_path: str, top_k: int) -> Dict[str, Any]:
    """Run the configured nature-id command on ``image_path`` and parse its output.

    The command template is read from the ``NATURE_ID_CMD`` environment
    variable; every ``{image_path}`` placeholder in it is replaced with
    ``image_path``. ``NATURE_ID_TIMEOUT`` (integer seconds, default 40) bounds
    the run. Standard output is first tried as JSON and, if it is not valid
    JSON, parsed as "<score>% <name>" text lines.

    Args:
        image_path: Path of the image file handed to the command.
        top_k: Maximum number of predictions to return.

    Returns:
        ``{"predictions": [...], "inference_time_sec": <float or None>}``.

    Raises:
        RuntimeError: If ``NATURE_ID_CMD`` is unset/blank or the command exits
            with a nonzero status.
        subprocess.TimeoutExpired: If the command exceeds the timeout.
    """
    cmd_tmpl = (os.getenv("NATURE_ID_CMD") or "").strip()
    timeout_s = int(os.getenv("NATURE_ID_TIMEOUT", "40"))
    if not cmd_tmpl:
        raise RuntimeError("NATURE_ID_CMD is not configured")

    # Tokenise the template first and substitute per argument afterwards, so
    # that a path containing whitespace or quote characters stays a single
    # argument instead of being re-split by shlex.
    argv = [
        part.replace("{image_path}", image_path)
        for part in shlex.split(cmd_tmpl)
    ]
    proc = subprocess.run(
        argv,
        capture_output=True,
        text=True,
        timeout=timeout_s,
        check=False,
    )
    if proc.returncode != 0:
        raise RuntimeError(f"nature-id cli failed rc={proc.returncode}: {proc.stderr.strip()[:240]}")

    out = (proc.stdout or "").strip()
    inference_time_sec = _extract_inference_time(out)
    if not out:
        return {"predictions": [], "inference_time_sec": inference_time_sec}

    try:
        parsed = json.loads(out)
    except ValueError:
        # Not JSON (json.JSONDecodeError is a ValueError): treat as text lines.
        preds = _parse_text_output(out, top_k)
    else:
        preds = _normalize_predictions(parsed, top_k)
    return {"predictions": preds, "inference_time_sec": inference_time_sec}
async def _download_image(image_url: str) -> str:
    """Download ``image_url`` into a temporary ``.jpg`` file and return its path.

    The HTTP timeout is read from the ``DOWNLOAD_TIMEOUT`` environment
    variable (seconds, default 20). The file is created with ``delete=False``,
    so the caller owns the returned path and must remove it.

    Raises:
        Whatever ``httpx`` raises for a transport failure or from
        ``raise_for_status()``, and ``OSError`` if the file cannot be written.
    """
    timeout_s = float(os.getenv("DOWNLOAD_TIMEOUT", "20"))
    async with httpx.AsyncClient(timeout=timeout_s) as client:
        resp = await client.get(image_url)
        resp.raise_for_status()
        data = resp.content

    f = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
    try:
        with f:
            f.write(data)
    except Exception:
        # The path is never handed to the caller on failure, so nobody else
        # would clean this file up.
        Path(f.name).unlink(missing_ok=True)
        raise
    return f.name
def _response_payload(result: Dict[str, Any]) -> Dict[str, Any]:
    """Build the success response body from a classifier result.

    ``result`` is expected to carry ``"predictions"`` (list of dicts with
    ``scientific_name``, ``common_name`` and ``score``) and optionally
    ``"inference_time_sec"``. The response repeats the predictions as-is and
    adds a ``top_k`` summary whose ``name`` prefers the common name unless it
    is missing, empty or the ``"-"`` placeholder.
    """
    predictions = result.get("predictions") or []

    summary = []
    for pred in predictions:
        common = pred.get("common_name")
        if common not in (None, "", "-"):
            shown = common
        else:
            shown = pred.get("scientific_name")
        summary.append(
            {
                "confidence": float(pred.get("score", 0.0)),
                "name": str(shown or "unknown"),
                "scientific_name": str(pred.get("scientific_name") or "unknown"),
            }
        )

    return {
        "status": "success",
        "model": "aiy_plants_V1",
        "source": "nature-id-cli",
        "count": len(predictions),
        "inference_time_sec": result.get("inference_time_sec"),
        "predictions": predictions,
        "top_k": summary,
    }
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc: RequestValidationError):
    """Answer request-validation failures with a trimmed HTTP 422 body.

    Only the ``loc``, ``msg`` and ``type`` fields of each error are returned,
    so raw multipart bytes are not leaked in validation responses.
    """
    sanitized = [
        {"loc": err.get("loc"), "msg": err.get("msg"), "type": err.get("type")}
        for err in (exc.errors() or [])
    ]
    return JSONResponse(status_code=422, content={"detail": sanitized})
@app.get("/health")
def health() -> Dict[str, Any]:
    """Report liveness plus the classifier command currently configured.

    The command is read from ``NATURE_ID_CMD`` on every call and echoed back
    together with a flag telling whether it is non-empty.
    """
    configured_cmd = os.getenv("NATURE_ID_CMD", "").strip()
    report: Dict[str, Any] = {"status": "healthy"}
    report["nature_id_cmd_configured"] = bool(configured_cmd)
    report["nature_id_cmd"] = configured_cmd
    return report
@app.post("/identify")
async def identify(payload: IdentifyRequest) -> Dict[str, Any]:
    """Identify the plant in the image at ``payload.image_url``.

    Downloads the image to a temporary file, runs the classifier command on
    it and returns the response payload. The temporary file is removed
    afterwards regardless of outcome.

    Raises:
        HTTPException: 400 if ``image_url`` is missing or empty; 503 with
            ``identify_failed: <reason>`` for any download or classifier
            failure.
    """
    if not payload.image_url:
        raise HTTPException(status_code=400, detail="image_url is required")
    tmp_path = ""
    try:
        tmp_path = await _download_image(payload.image_url)
        # The classifier is a blocking subprocess call; run it in the default
        # executor so the event loop keeps serving other requests meanwhile.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None, _run_nature_id_cli, tmp_path, payload.top_k
        )
        return _response_payload(result)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"identify_failed: {e}") from e
    finally:
        if tmp_path:
            try:
                Path(tmp_path).unlink(missing_ok=True)
            except OSError:
                # Best-effort cleanup; a leftover temp file must not mask the
                # real response or error.
                pass
@app.post("/identify-file")
async def identify_file(file: UploadFile = File(...), top_k: int = 3) -> Dict[str, Any]:
    """Identify the plant in an uploaded image file.

    The upload is written to a temporary ``.jpg`` file, the classifier
    command is run on it and the response payload is returned. ``top_k`` is
    clamped to the range 1..10. The temporary file is removed afterwards
    regardless of outcome.

    Raises:
        HTTPException: 503 with ``identify_failed: <reason>`` for any failure
            while storing the upload or running the classifier.
    """
    top_k = max(1, min(top_k, 10))
    tmp_path = ""
    try:
        data = await file.read()
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as f:
            # Record the path before writing so the finally block can remove
            # the file even if the write itself fails.
            tmp_path = f.name
            f.write(data)
        # The classifier is a blocking subprocess call; run it in the default
        # executor so the event loop keeps serving other requests meanwhile.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None, _run_nature_id_cli, tmp_path, top_k
        )
        return _response_payload(result)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"identify_failed: {e}") from e
    finally:
        if tmp_path:
            try:
                Path(tmp_path).unlink(missing_ok=True)
            except OSError:
                # Best-effort cleanup; a leftover temp file must not mask the
                # real response or error.
                pass