239 lines
7.4 KiB
Python
239 lines
7.4 KiB
Python
import json
|
|
import os
|
|
import re
|
|
import shlex
|
|
import subprocess
|
|
import tempfile
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import httpx
|
|
from fastapi import FastAPI, File, HTTPException, UploadFile
|
|
from fastapi.exceptions import RequestValidationError
|
|
from fastapi.responses import JSONResponse
|
|
from pydantic import BaseModel, Field
|
|
|
|
# ASGI application object; every route and exception handler below is registered on it.
app = FastAPI(version="0.1.1", title="plant-vision-node1")
|
|
|
|
|
|
class IdentifyRequest(BaseModel):
    """JSON body accepted by ``POST /identify``.

    ``image_url`` is declared optional at the schema level, but the endpoint
    itself rejects a missing/empty value with HTTP 400. ``top_k`` is validated
    to lie in the inclusive range 1..10.
    """

    image_url: Optional[str] = Field(default=None)
    top_k: int = Field(3, ge=1, le=10)
|
|
|
|
|
|
def _normalize_predictions(raw: Any, top_k: int) -> List[Dict[str, Any]]:
    """Convert a parsed JSON CLI result into a ranked list of prediction dicts.

    ``raw`` may be a list of prediction objects, or a dict wrapping such a list
    under ``predictions``, ``results`` or ``candidates`` (the first list-valued
    key in that order wins). Each object is mapped to
    ``{"scientific_name", "common_name", "score"}`` using several alternative
    key spellings. Entries that are not dicts are skipped.

    Args:
        raw: Value returned by ``json.loads`` on the CLI output.
        top_k: Maximum number of predictions to return; ``<= 0`` yields ``[]``.

    Returns:
        Up to ``top_k`` predictions sorted by descending score (ties keep their
        input order). Scores that are missing, unparseable, NaN or infinite
        become ``0.0``.
    """
    import math  # local: only needed for the finiteness check below

    if top_k <= 0:
        return []

    if isinstance(raw, dict):
        for key in ("predictions", "results", "candidates"):
            if isinstance(raw.get(key), list):
                raw = raw[key]
                break

    if not isinstance(raw, list):
        return []

    preds: List[Dict[str, Any]] = []
    # Filter the whole list first so junk entries cannot eat into top_k.
    for item in raw:
        if not isinstance(item, dict):
            continue
        name = (
            item.get("scientific_name")
            or item.get("scientificName")
            or item.get("label")
            or item.get("name")
            or "unknown"
        )
        common = item.get("common_name") or item.get("commonName") or item.get("common") or "-"
        score = item.get("score", item.get("confidence", 0.0))
        try:
            score_f = float(score)
        except (TypeError, ValueError, OverflowError):
            score_f = 0.0
        # json.loads accepts NaN/Infinity, which would break JSON serialization
        # of the response and make the ranking below ill-defined.
        if not math.isfinite(score_f):
            score_f = 0.0
        preds.append({"scientific_name": str(name), "common_name": str(common), "score": score_f})

    # Rank by score (as _parse_text_output does) so "top_k" means the highest
    # scores rather than merely the first entries the CLI emitted.
    preds.sort(key=lambda p: p["score"], reverse=True)
    return preds[:top_k]
|
|
|
|
|
|
def _parse_text_output(text: str, top_k: int) -> List[Dict[str, Any]]:
    """Parse human-readable CLI output made of ``<percent>% <name>`` lines.

    Recognised lines look like::

        97.6% Persicaria amphibia
        86.1% Canada Goldenrod (Solidago canadensis)

    Every other line (e.g. "Read ..." or "Classification of ..." banners) is
    ignored. When the name ends in a parenthesised part, that part becomes the
    scientific name and the text before it the common name; otherwise the
    whole name is the scientific name and the common name is ``"-"``.
    Scores are kept on the 0-100 percentage scale exactly as printed.

    Returns:
        At most ``top_k`` predictions, highest score first (stable for ties).
    """
    score_line = re.compile(r"^\s*(\d+(?:\.\d+)?)%\s+(.+)$")
    common_then_sci = re.compile(r"^(.*?)\s*\(([^()]+)\)\s*$")

    found: List[Dict[str, Any]] = []
    for candidate in map(str.strip, (text or "").splitlines()):
        # Cheap pre-filter (also drops empty lines) before running the regex.
        if "%" not in candidate:
            continue
        hit = score_line.match(candidate)
        if hit is None:
            continue

        pct_text, label = hit.group(1), hit.group(2).strip()
        try:
            pct = float(pct_text)
        except ValueError:
            continue
        if not label:
            continue

        sci, com = label, "-"
        split = common_then_sci.match(label)
        if split is not None:
            left, inner = (part.strip() for part in split.groups())
            if left:
                com = left
            if inner:
                sci = inner

        found.append({"scientific_name": sci, "common_name": com, "score": pct})

    ranked = sorted(found, key=lambda row: float(row.get("score", 0.0)), reverse=True)
    return ranked[:top_k]
|
|
|
|
|
|
def _extract_inference_time(stdout: str) -> Optional[float]:
    """Pull the elapsed seconds out of a ``took <N> secs`` fragment.

    Args:
        stdout: Raw CLI output; ``None`` or an empty string is tolerated.

    Returns:
        The first matched duration as a float, or ``None`` when no such
        fragment is present.
    """
    hit = re.search(r"took\s+(\d+(?:\.\d+)?)\s+secs", stdout or "")
    seconds_text = hit.group(1) if hit else None
    if seconds_text is None:
        return None
    try:
        return float(seconds_text)
    except ValueError:
        return None
|
|
|
|
|
|
def _build_cli_command(cmd_tmpl: str, image_path: str) -> List[str]:
    """Split ``cmd_tmpl`` into argv, then substitute ``{image_path}`` per token.

    Splitting before substituting keeps the image path a single, verbatim
    argument even when it contains spaces, quotes or backslashes.
    """
    return [token.replace("{image_path}", image_path) for token in shlex.split(cmd_tmpl)]


def _run_nature_id_cli(image_path: str, top_k: int) -> Dict[str, Any]:
    """Run the command configured in ``NATURE_ID_CMD`` on one image and parse it.

    Environment:
        NATURE_ID_CMD: command template; ``{image_path}`` is replaced by the
            image path inside whichever argument contains it.
        NATURE_ID_TIMEOUT: seconds before the subprocess is killed (default 40;
            fractional values are accepted).

    Returns:
        ``{"predictions": [...], "inference_time_sec": float | None}``. Stdout
        that parses as JSON goes through ``_normalize_predictions``; anything
        else is parsed as ``NN.N% Name`` lines by ``_parse_text_output``.

    Raises:
        RuntimeError: if the command is not configured or exits non-zero
            (message includes the return code and up to 240 chars of stderr).
        subprocess.TimeoutExpired: if the command exceeds the timeout.
        OSError: if the executable cannot be started.
        ValueError: if NATURE_ID_TIMEOUT is not a number or NATURE_ID_CMD has
            unbalanced quotes.
    """
    cmd_tmpl = (os.getenv("NATURE_ID_CMD") or "").strip()
    timeout_s = float(os.getenv("NATURE_ID_TIMEOUT", "40"))

    if not cmd_tmpl:
        raise RuntimeError("NATURE_ID_CMD is not configured")

    proc = subprocess.run(
        _build_cli_command(cmd_tmpl, image_path),
        capture_output=True,
        text=True,
        errors="replace",  # tolerate non-UTF-8 bytes from the CLI instead of crashing
        timeout=timeout_s,
        check=False,
    )
    if proc.returncode != 0:
        raise RuntimeError(f"nature-id cli failed rc={proc.returncode}: {proc.stderr.strip()[:240]}")

    out = (proc.stdout or "").strip()
    inference_time_sec = _extract_inference_time(out)
    if not out:
        return {"predictions": [], "inference_time_sec": inference_time_sec}

    try:
        parsed = json.loads(out)
    except ValueError:
        # Not JSON: the CLI printed its human-readable "NN.N% Name" lines.
        preds = _parse_text_output(out, top_k)
    else:
        # Kept outside the try so a normalizer bug surfaces instead of being
        # silently replaced by the (empty) text-parse fallback.
        preds = _normalize_predictions(parsed, top_k)

    return {"predictions": preds, "inference_time_sec": inference_time_sec}
|
|
|
|
|
|
async def _download_image(image_url: str) -> str:
    """Fetch ``image_url`` and store the body in a new temporary ``.jpg`` file.

    Reads ``DOWNLOAD_TIMEOUT`` (seconds, default 20) for the HTTP client and an
    optional ``MAX_IMAGE_BYTES`` cap on the body size (default 0 = no limit).
    The temp file is created with ``delete=False``; the caller owns the returned
    path and must delete it.

    NOTE(review): ``image_url`` comes straight from the request, so this fetches
    any URL the host can reach (including internal addresses). Consider an
    allow-list if the service is exposed to untrusted clients.

    Raises:
        httpx exceptions: on transport errors or a non-success HTTP status
            (via ``raise_for_status``).
        ValueError: if the body exceeds ``MAX_IMAGE_BYTES`` or an env value is
            not a number.
        OSError: if the temporary file cannot be written.
    """
    timeout_s = float(os.getenv("DOWNLOAD_TIMEOUT", "20"))
    max_bytes = int(os.getenv("MAX_IMAGE_BYTES", "0"))

    chunks: List[bytes] = []
    received = 0
    async with httpx.AsyncClient(timeout=timeout_s) as client:
        # Stream so an oversized body is rejected without buffering all of it.
        async with client.stream("GET", image_url) as resp:
            resp.raise_for_status()
            async for chunk in resp.aiter_bytes():
                received += len(chunk)
                if max_bytes > 0 and received > max_bytes:
                    raise ValueError(f"image exceeds MAX_IMAGE_BYTES ({max_bytes} bytes)")
                chunks.append(chunk)
    data = b"".join(chunks)

    f = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
    try:
        with f:
            f.write(data)
    except BaseException:
        # delete=False: nobody else knows this path yet, so remove it here.
        Path(f.name).unlink(missing_ok=True)
        raise
    return f.name
|
|
|
|
|
|
def _response_payload(result: Dict[str, Any]) -> Dict[str, Any]:
    """Wrap CLI predictions in the public success envelope.

    The raw ``predictions`` list is passed through unchanged. A parallel
    ``top_k`` list gives each entry a display ``name``: the common name when it
    is present and not ``"-"``, otherwise the scientific name, falling back to
    ``"unknown"``.
    """
    predictions = result.get("predictions") or []

    def _summarise(pred: Dict[str, Any]) -> Dict[str, Any]:
        common = pred.get("common_name")
        preferred = pred.get("scientific_name") if common in (None, "", "-") else common
        return {
            "confidence": float(pred.get("score", 0.0)),
            "name": str(preferred or "unknown"),
            "scientific_name": str(pred.get("scientific_name") or "unknown"),
        }

    summary = list(map(_summarise, predictions))

    return dict(
        status="success",
        model="aiy_plants_V1",
        source="nature-id-cli",
        count=len(predictions),
        inference_time_sec=result.get("inference_time_sec"),
        predictions=predictions,
        top_k=summary,
    )
|
|
|
|
|
|
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc: RequestValidationError):
    """Return a trimmed HTTP 422 body for request-validation failures.

    Only ``loc``, ``msg`` and ``type`` of each error are echoed back; every
    other field is dropped so raw multipart bytes are never leaked to the
    client.
    """
    kept_fields = ("loc", "msg", "type")
    detail = [{field: err.get(field) for field in kept_fields} for err in (exc.errors() or [])]
    return JSONResponse(content={"detail": detail}, status_code=422)
|
|
|
|
|
|
@app.get("/health")
def health() -> Dict[str, Any]:
    """Report service status and whether ``NATURE_ID_CMD`` is set.

    Always answers ``"healthy"``; it does not execute the CLI.

    NOTE(review): the raw command template is echoed to any caller. Consider
    redacting it if this endpoint is reachable from untrusted networks.
    """
    configured_cmd = os.getenv("NATURE_ID_CMD") or ""
    configured_cmd = configured_cmd.strip()
    body: Dict[str, Any] = {"status": "healthy"}
    body["nature_id_cmd_configured"] = configured_cmd != ""
    body["nature_id_cmd"] = configured_cmd
    return body
|
|
|
|
|
|
@app.post("/identify")
async def identify(payload: IdentifyRequest) -> Dict[str, Any]:
    """Download ``payload.image_url`` and classify it with the nature-id CLI.

    The downloaded temp file is always removed afterwards (best effort).

    Raises:
        HTTPException: 400 when ``image_url`` is missing or empty; 503 when the
            download or the CLI run fails (the error text is put in ``detail``).
    """
    import asyncio  # local import: only the endpoints need an event-loop handle

    if not payload.image_url:
        raise HTTPException(status_code=400, detail="image_url is required")

    tmp_path = ""
    try:
        tmp_path = await _download_image(payload.image_url)
        # _run_nature_id_cli blocks in subprocess.run for up to NATURE_ID_TIMEOUT
        # seconds; run it in the default executor so the event loop keeps
        # serving other requests in the meantime.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, _run_nature_id_cli, tmp_path, payload.top_k)
        return _response_payload(result)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"identify_failed: {e}") from e
    finally:
        if tmp_path:
            try:
                Path(tmp_path).unlink(missing_ok=True)
            except OSError:
                # Best-effort cleanup; a leftover temp file must not mask the response.
                pass
|
|
|
|
|
|
@app.post("/identify-file")
async def identify_file(file: UploadFile = File(...), top_k: int = 3) -> Dict[str, Any]:
    """Classify an uploaded image with the nature-id CLI.

    Args:
        file: Multipart upload; its bytes are written to a temporary ``.jpg``.
        top_k: Number of predictions wanted; clamped into 1..10.

    Raises:
        HTTPException: 400 for an empty upload; 503 when writing the temp file
            or running the CLI fails (the error text is put in ``detail``).
    """
    import asyncio  # local import: only the endpoints need an event-loop handle

    top_k = max(1, min(top_k, 10))
    tmp_path = ""
    try:
        data = await file.read()
        if not data:
            raise HTTPException(status_code=400, detail="uploaded file is empty")
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as f:
            # Record the path before writing so cleanup also runs if the write fails.
            tmp_path = f.name
            f.write(data)
        # _run_nature_id_cli blocks in subprocess.run; keep it off the event loop.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, _run_nature_id_cli, tmp_path, top_k)
        return _response_payload(result)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"identify_failed: {e}") from e
    finally:
        if tmp_path:
            try:
                Path(tmp_path).unlink(missing_ok=True)
            except OSError:
                # Best-effort cleanup; a leftover temp file must not mask the response.
                pass
|