"""plant-vision-node1: HTTP wrapper around an external identification CLI.

The command line is read from the ``NATURE_ID_CMD`` environment variable; the
literal placeholder ``{image_path}`` in it is replaced with the path of a
temporary image file. The command's stdout is parsed either as JSON or as
``"<score>% <name>"`` text lines.
"""

import json
import os
import re
import shlex
import subprocess
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional

import httpx
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

app = FastAPI(title="plant-vision-node1", version="0.1.1")

# "97.6% Persicaria amphibia" -> ("97.6", "Persicaria amphibia")
_SCORE_LINE_RE = re.compile(r"^\s*(\d+(?:\.\d+)?)%\s+(.+)$")
# "Canada Goldenrod (Solidago canadensis)" -> ("Canada Goldenrod", "Solidago canadensis")
_PAREN_NAME_RE = re.compile(r"^(.*?)\s*\(([^()]+)\)\s*$")
# "... took 1.23 secs ..." -> "1.23"
_INFERENCE_TIME_RE = re.compile(r"took\s+(\d+(?:\.\d+)?)\s+secs")

# Placeholder in NATURE_ID_CMD that is replaced by the temp image path.
_IMAGE_PATH_PLACEHOLDER = "{image_path}"

# Default cap on a downloaded image body; override with MAX_IMAGE_BYTES
# (a value <= 0 disables the cap).
_DEFAULT_MAX_IMAGE_BYTES = 25 * 1024 * 1024


class IdentifyRequest(BaseModel):
    """JSON body accepted by ``POST /identify``."""

    image_url: Optional[str] = None
    top_k: int = Field(default=3, ge=1, le=10)


def _env_float(name: str, default: float) -> float:
    """Read environment variable ``name`` as a float.

    An unset or blank variable yields ``default``; a non-numeric value raises
    ``ValueError``.
    """
    raw = (os.getenv(name) or "").strip()
    if not raw:
        return default
    return float(raw)


def _coerce_score(value: Any) -> float:
    """Return ``value`` as a float, or ``0.0`` when it cannot be converted."""
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        return 0.0


def _normalize_predictions(raw: Any, top_k: int) -> List[Dict[str, Any]]:
    """Convert JSON CLI output into a list of prediction dicts.

    ``raw`` may be a list of prediction objects, or a dict that wraps such a
    list under ``predictions``, ``results`` or ``candidates``. Several key
    spellings are accepted per field (e.g. ``scientific_name`` /
    ``scientificName`` / ``label`` / ``name``).

    Non-dict entries are skipped *before* truncating, so up to ``top_k``
    valid predictions are returned. Results are ordered by descending score
    (stable, so ties keep the CLI's order), matching ``_parse_text_output``.

    Returns:
        List of ``{"scientific_name": str, "common_name": str, "score": float}``
        with at most ``top_k`` entries; ``[]`` for unrecognized input.
    """
    if isinstance(raw, dict):
        for key in ("predictions", "results", "candidates"):
            if isinstance(raw.get(key), list):
                raw = raw[key]
                break
    if not isinstance(raw, list):
        return []

    preds: List[Dict[str, Any]] = []
    for item in raw:
        if not isinstance(item, dict):
            continue
        name = (
            item.get("scientific_name")
            or item.get("scientificName")
            or item.get("label")
            or item.get("name")
            or "unknown"
        )
        common = item.get("common_name") or item.get("commonName") or item.get("common") or "-"
        score = item.get("score", item.get("confidence", 0.0))
        preds.append(
            {
                "scientific_name": str(name),
                "common_name": str(common),
                "score": _coerce_score(score),
            }
        )

    preds.sort(key=lambda p: p["score"], reverse=True)
    return preds[:top_k]


def _parse_text_output(text: str, top_k: int) -> List[Dict[str, Any]]:
    """Parse only model score lines, e.g.:

        97.6% Persicaria amphibia
        86.1% Canada Goldenrod (Solidago canadensis)

    Service lines like "Read ..." or "Classification of ..." are ignored.

    The score is the percentage number as printed (e.g. ``97.6``). A name of
    the form "Common Name (Scientific name)" is split into both fields;
    otherwise the whole name is the scientific name and the common name is
    ``"-"``.

    Returns:
        At most ``top_k`` prediction dicts, sorted by descending score.
    """
    preds: List[Dict[str, Any]] = []
    for raw_line in (text or "").splitlines():
        line = raw_line.strip()
        if "%" not in line:
            continue
        m = _SCORE_LINE_RE.match(line)
        if not m:
            continue
        score_str, name_part = m.groups()
        try:
            score = float(score_str)
        except ValueError:
            continue
        name = name_part.strip()
        if not name:
            continue

        common_name = "-"
        scientific_name = name
        # If output is "Common Name (Scientific name)", preserve both.
        paren = _PAREN_NAME_RE.match(name)
        if paren:
            common, scientific = (part.strip() for part in paren.groups())
            if common:
                common_name = common
            if scientific:
                scientific_name = scientific

        preds.append(
            {
                "scientific_name": scientific_name,
                "common_name": common_name,
                "score": score,
            }
        )

    preds.sort(key=lambda x: float(x.get("score", 0.0)), reverse=True)
    return preds[:top_k]


def _extract_inference_time(stdout: str) -> Optional[float]:
    """Return N from a "took N secs" phrase in ``stdout``, or ``None`` if absent."""
    m = _INFERENCE_TIME_RE.search(stdout or "")
    if m is None:
        return None
    try:
        return float(m.group(1))
    except ValueError:
        return None


def _build_command(cmd_tmpl: str, image_path: str) -> List[str]:
    """Split ``cmd_tmpl`` into argv and substitute ``image_path`` into each argument.

    Splitting happens *before* substitution so that an image path containing
    spaces or quote characters remains a single, intact argument.
    """
    return [arg.replace(_IMAGE_PATH_PLACEHOLDER, image_path) for arg in shlex.split(cmd_tmpl)]


def _run_nature_id_cli(image_path: str, top_k: int) -> Dict[str, Any]:
    """Run the configured CLI on ``image_path`` and parse its predictions.

    This call blocks until the process exits; async callers should run it in
    a worker thread.

    Environment:
        NATURE_ID_CMD: command template containing ``{image_path}`` (required).
        NATURE_ID_TIMEOUT: seconds before the process is killed (default 40).

    Returns:
        ``{"predictions": [...], "inference_time_sec": float | None}``.

    Raises:
        RuntimeError: if NATURE_ID_CMD is unset or the process exits non-zero.
        subprocess.TimeoutExpired: if the process exceeds the timeout.
    """
    cmd_tmpl = (os.getenv("NATURE_ID_CMD") or "").strip()
    if not cmd_tmpl:
        raise RuntimeError("NATURE_ID_CMD is not configured")
    timeout_s = _env_float("NATURE_ID_TIMEOUT", 40.0)

    proc = subprocess.run(
        _build_command(cmd_tmpl, image_path),
        capture_output=True,
        text=True,
        # Decode explicitly so non-ASCII species names don't depend on the locale.
        encoding="utf-8",
        errors="replace",
        timeout=timeout_s,
        check=False,
    )
    if proc.returncode != 0:
        stderr = (proc.stderr or "").strip()
        raise RuntimeError(f"nature-id cli failed rc={proc.returncode}: {stderr[:240]}")

    out = (proc.stdout or "").strip()
    inference_time_sec = _extract_inference_time(out)
    if not out:
        return {"predictions": [], "inference_time_sec": inference_time_sec}

    # Prefer JSON output; fall back to "<score>% <name>" text lines.
    try:
        parsed = json.loads(out)
    except ValueError:
        preds = _parse_text_output(out, top_k)
    else:
        preds = _normalize_predictions(parsed, top_k)
    return {"predictions": preds, "inference_time_sec": inference_time_sec}


async def _download_image(image_url: str) -> str:
    """Download ``image_url`` into a temporary ``.jpg`` file and return its path.

    The caller owns the returned file and must delete it. The body is streamed
    to disk and rejected once it exceeds MAX_IMAGE_BYTES (default 25 MiB; a
    value <= 0 disables the cap). On any failure, including cancellation, the
    partial file is removed before the exception propagates.

    NOTE(review): the URL comes straight from the client, so this can be
    pointed at internal hosts (SSRF). Restrict schemes/hosts if exposed.

    Environment:
        DOWNLOAD_TIMEOUT: httpx timeout in seconds (default 20).
        MAX_IMAGE_BYTES: maximum accepted body size.

    Raises:
        httpx.HTTPError: on transport errors or a non-success status.
        ValueError: if the body exceeds the size cap.
    """
    timeout_s = _env_float("DOWNLOAD_TIMEOUT", 20.0)
    max_bytes = int(_env_float("MAX_IMAGE_BYTES", float(_DEFAULT_MAX_IMAGE_BYTES)))

    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
    try:
        with tmp:
            async with httpx.AsyncClient(timeout=timeout_s) as client:
                async with client.stream("GET", image_url) as resp:
                    resp.raise_for_status()
                    received = 0
                    async for chunk in resp.aiter_bytes():
                        received += len(chunk)
                        if 0 < max_bytes < received:
                            raise ValueError(f"image exceeds {max_bytes} bytes")
                        tmp.write(chunk)
    except BaseException:
        # File is closed by `with tmp` before we get here, so unlink is safe.
        Path(tmp.name).unlink(missing_ok=True)
        raise
    return tmp.name


def _remove_temp_file(path: str) -> None:
    """Best-effort deletion of a temporary file; an empty path is a no-op."""
    if not path:
        return
    try:
        Path(path).unlink(missing_ok=True)
    except OSError:
        # Cleanup must never mask the request's real outcome.
        pass


def _display_name(pred: Dict[str, Any]) -> str:
    """Common name when meaningful, else scientific name, else "unknown"."""
    common = pred.get("common_name")
    name = common if common not in (None, "", "-") else pred.get("scientific_name")
    return str(name or "unknown")


def _response_payload(result: Dict[str, Any]) -> Dict[str, Any]:
    """Build the JSON response body from a ``_run_nature_id_cli`` result.

    ``predictions`` is passed through unchanged. ``top_k`` is a simplified view
    of the same entries: ``confidence`` (the prediction's ``score``), a display
    ``name`` and ``scientific_name``.

    NOTE(review): ``score`` is a printed percentage for text output but comes
    verbatim from the CLI for JSON output — units may differ; confirm.
    """
    preds = result.get("predictions") or []
    top_k = [
        {
            "confidence": float(p.get("score", 0.0)),
            "name": _display_name(p),
            "scientific_name": str(p.get("scientific_name") or "unknown"),
        }
        for p in preds
    ]
    return {
        "status": "success",
        "model": "aiy_plants_V1",
        "source": "nature-id-cli",
        "count": len(preds),
        "inference_time_sec": result.get("inference_time_sec"),
        "predictions": preds,
        "top_k": top_k,
    }


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc: RequestValidationError):
    """Return 422 with only ``loc``/``msg``/``type`` of each validation error.

    Other error fields are dropped to avoid leaking raw multipart bytes in
    validation responses.
    """
    errs: List[Dict[str, Any]] = [
        {"loc": e.get("loc"), "msg": e.get("msg"), "type": e.get("type")}
        for e in exc.errors() or []
    ]
    return JSONResponse(status_code=422, content={"detail": errs})


@app.get("/health")
def health() -> Dict[str, Any]:
    """Report liveness and whether NATURE_ID_CMD is configured (echoing its value)."""
    cmd = (os.getenv("NATURE_ID_CMD") or "").strip()
    return {
        "status": "healthy",
        "nature_id_cmd_configured": bool(cmd),
        "nature_id_cmd": cmd,
    }


@app.post("/identify")
async def identify(payload: IdentifyRequest) -> Dict[str, Any]:
    """Download ``payload.image_url`` and classify it.

    Returns 400 if ``image_url`` is missing. Any download or CLI failure is
    reported as 503 with ``detail="identify_failed: ..."``.
    """
    if not payload.image_url:
        raise HTTPException(status_code=400, detail="image_url is required")
    tmp_path = ""
    try:
        tmp_path = await _download_image(payload.image_url)
        # subprocess.run blocks; keep it off the event loop.
        result = await run_in_threadpool(_run_nature_id_cli, tmp_path, payload.top_k)
        return _response_payload(result)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"identify_failed: {e}") from e
    finally:
        _remove_temp_file(tmp_path)


@app.post("/identify-file")
async def identify_file(file: UploadFile = File(...), top_k: int = 3) -> Dict[str, Any]:
    """Classify an uploaded image file.

    ``top_k`` is clamped to [1, 10]. Any failure is reported as 503 with
    ``detail="identify_failed: ..."``.
    """
    top_k = max(1, min(top_k, 10))
    tmp_path = ""
    try:
        data = await file.read()
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as f:
            # Record the path before writing so a failed write is still cleaned up.
            tmp_path = f.name
            f.write(data)
        # subprocess.run blocks; keep it off the event loop.
        result = await run_in_threadpool(_run_nature_id_cli, tmp_path, top_k)
        return _response_payload(result)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"identify_failed: {e}") from e
    finally:
        _remove_temp_file(tmp_path)