Add in .env.example for setting ports, fix upload limit, fix bounding box, can now dismiss previous image, change markdown expectation to HTML - not MD. updated README with nvidia driver/container instructions

This commit is contained in:
Ray Dumasia
2025-10-21 21:35:17 +01:00
parent e02338436b
commit 3efc4da7ff
9 changed files with 399 additions and 101 deletions

View File

@@ -12,6 +12,7 @@ import torch
from transformers import AutoModel, AutoTokenizer
from PIL import Image
import uvicorn
from decouple import config as env_config
# -----------------------------
# Lifespan context for model loading
@@ -26,8 +27,8 @@ async def lifespan(app: FastAPI):
# Environment setup
os.environ.pop("TRANSFORMERS_CACHE", None)
MODEL_NAME = os.environ.get("MODEL_NAME", "deepseek-ai/DeepSeek-OCR")
HF_HOME = os.environ.get("HF_HOME", "/models")
MODEL_NAME = env_config("MODEL_NAME", default="deepseek-ai/DeepSeek-OCR")
HF_HOME = env_config("HF_HOME", default="/models")
os.makedirs(HF_HOME, exist_ok=True)
# Load model
@@ -138,7 +139,7 @@ def build_prompt(
elif mode == "multilingual":
instruction = "Free OCR. Detect the language automatically and output in the same script."
elif mode == "describe":
instruction = "Describe this image concisely in 2-3 sentences. Focus on visible key elements."
instruction = "Describe this image. Focus on visible key elements."
elif mode == "freeform":
instruction = user_prompt.strip() if user_prompt else "OCR this image."
else:
@@ -153,36 +154,82 @@ def build_prompt(
# -----------------------------
# Grounding parser
# -----------------------------
# Match a full detection block and capture the coordinates as the entire list expression
# Examples of captured coords (including outer brackets):
# - [[312, 339, 480, 681]]
# - [[504, 700, 625, 910], [771, 570, 996, 996]]
# - [[110, 310, 255, 800], [312, 343, 479, 680], ...]
# Using a greedy bracket capture ensures we include all inner lists up to the last ']' before </|det|>
DET_BLOCK = re.compile(
r"<\|ref\|>(?P<label>.*?)<\|/ref\|>\s*<\|det\|>\s*\[\s*\[\s*(?P<coords>[^\]]+?)\s*\]\s*\]\s*<\|/det\|>",
r"<\|ref\|>(?P<label>.*?)<\|/ref\|>\s*<\|det\|>\s*(?P<coords>\[.*\])\s*<\|/det\|>",
re.DOTALL,
)
def clean_grounding_text(text: str) -> str:
"""Remove grounding tags from text for display, keeping labels"""
# Replace <|ref|>label<|/ref|><|det|>[[...]]<|/det|> with just "label"
# Replace <|ref|>label<|/ref|><|det|>[...any nested lists...]<|/det|> with just the label
cleaned = re.sub(
r"<\|ref\|>(.*?)<\|/ref\|>\s*<\|det\|>\s*\[\s*\[[^\]]+\]\s*\]\s*<\|/det\|>",
r"<\|ref\|>(.*?)<\|/ref\|>\s*<\|det\|>\s*\[.*\]\s*<\|/det\|>",
r"\1",
text,
flags=re.DOTALL
flags=re.DOTALL,
)
# Also remove any standalone grounding tags
cleaned = re.sub(r"<\|grounding\|>", "", cleaned)
return cleaned.strip()
def parse_detections(text: str) -> List[Dict[str, Any]]:
"""Parse grounding boxes from text"""
def parse_detections(text: str, image_width: int, image_height: int) -> List[Dict[str, Any]]:
"""Parse grounding boxes from text and scale from 0-999 normalized coords to actual image dimensions
Handles both single and multiple bounding boxes:
- Single: <|ref|>label<|/ref|><|det|>[[x1,y1,x2,y2]]<|/det|>
- Multiple: <|ref|>label<|/ref|><|det|>[[x1,y1,x2,y2], [x1,y1,x2,y2], ...]<|/det|>
"""
boxes: List[Dict[str, Any]] = []
for m in DET_BLOCK.finditer(text or ""):
label = m.group("label").strip()
coords = [c.strip() for c in m.group("coords").split(",")]
coords_str = m.group("coords").strip()
print(f"🔍 DEBUG: Found detection for '{label}'")
print(f"📦 Raw coords string (with brackets): {coords_str}")
try:
nums = list(map(float, coords[:4]))
except Exception:
import ast
# Parse the full bracket expression directly (handles single and multiple)
parsed = ast.literal_eval(coords_str)
# Normalize to a list of lists
if (
isinstance(parsed, list)
and len(parsed) == 4
and all(isinstance(n, (int, float)) for n in parsed)
):
# Single box provided as [x1,y1,x2,y2]
box_coords = [parsed]
print("📦 Single box (flat list) detected")
elif isinstance(parsed, list):
box_coords = parsed
print(f"📦 Boxes detected: {len(box_coords)}")
else:
raise ValueError("Unsupported coords structure")
# Process each box
for idx, box in enumerate(box_coords):
if isinstance(box, (list, tuple)) and len(box) >= 4:
x1 = int(float(box[0]) / 999 * image_width)
y1 = int(float(box[1]) / 999 * image_height)
x2 = int(float(box[2]) / 999 * image_width)
y2 = int(float(box[3]) / 999 * image_height)
print(f" Box {idx+1}: {box} → [{x1}, {y1}, {x2}, {y2}]")
boxes.append({"label": label, "box": [x1, y1, x2, y2]})
else:
print(f" ⚠️ Skipping invalid box: {box}")
except Exception as e:
print(f"❌ Parsing failed: {e}")
continue
if len(nums) == 4:
boxes.append({"label": label, "box": nums})
print(f"🎯 Total boxes parsed: {len(boxes)}")
return boxes
# -----------------------------
@@ -289,8 +336,8 @@ async def ocr_inference(
if not text:
text = "No text returned by model."
# Parse grounding boxes
boxes = parse_detections(text) if ("<|det|>" in text or "<|ref|>" in text) else []
# Parse grounding boxes with proper coordinate scaling
boxes = parse_detections(text, orig_w or 1, orig_h or 1) if ("<|det|>" in text or "<|ref|>" in text) else []
# Clean grounding tags from display text, but keep the labels
display_text = clean_grounding_text(text) if ("<|ref|>" in text or "<|grounding|>" in text) else text
@@ -302,6 +349,7 @@ async def ocr_inference(
return JSONResponse({
"success": True,
"text": display_text,
"raw_text": text, # Include raw model output for debugging
"boxes": boxes,
"image_dims": {"w": orig_w, "h": orig_h},
"metadata": {
@@ -326,4 +374,6 @@ async def ocr_inference(
shutil.rmtree(out_dir, ignore_errors=True)
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
host = env_config("API_HOST", default="0.0.0.0")
port = env_config("API_PORT", default=8000, cast=int)
uvicorn.run(app, host=host, port=port)