Add .env.example for setting ports; fix the upload limit; fix bounding boxes; allow dismissing the previous image; expect HTML output from the model instead of Markdown. Update README with NVIDIA driver/container instructions.
This commit is contained in:
@@ -12,6 +12,7 @@ import torch
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
from PIL import Image
|
||||
import uvicorn
|
||||
from decouple import config as env_config
|
||||
|
||||
# -----------------------------
|
||||
# Lifespan context for model loading
|
||||
@@ -26,8 +27,8 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
# Environment setup
|
||||
os.environ.pop("TRANSFORMERS_CACHE", None)
|
||||
MODEL_NAME = os.environ.get("MODEL_NAME", "deepseek-ai/DeepSeek-OCR")
|
||||
HF_HOME = os.environ.get("HF_HOME", "/models")
|
||||
MODEL_NAME = env_config("MODEL_NAME", default="deepseek-ai/DeepSeek-OCR")
|
||||
HF_HOME = env_config("HF_HOME", default="/models")
|
||||
os.makedirs(HF_HOME, exist_ok=True)
|
||||
|
||||
# Load model
|
||||
@@ -138,7 +139,7 @@ def build_prompt(
|
||||
elif mode == "multilingual":
|
||||
instruction = "Free OCR. Detect the language automatically and output in the same script."
|
||||
elif mode == "describe":
|
||||
instruction = "Describe this image concisely in 2-3 sentences. Focus on visible key elements."
|
||||
instruction = "Describe this image. Focus on visible key elements."
|
||||
elif mode == "freeform":
|
||||
instruction = user_prompt.strip() if user_prompt else "OCR this image."
|
||||
else:
|
||||
@@ -153,36 +154,82 @@ def build_prompt(
|
||||
# -----------------------------
|
||||
# Grounding parser
|
||||
# -----------------------------
|
||||
# Match a full detection block and capture the coordinates as the entire list expression
|
||||
# Examples of captured coords (including outer brackets):
|
||||
# - [[312, 339, 480, 681]]
|
||||
# - [[504, 700, 625, 910], [771, 570, 996, 996]]
|
||||
# - [[110, 310, 255, 800], [312, 343, 479, 680], ...]
|
||||
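# A greedy bracket capture includes all inner lists up to the last ']' before <|/det|>; note that with re.DOTALL the greedy match runs to the final <|/det|> in the text, so it can span several detection blocks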
|
||||
DET_BLOCK = re.compile(
|
||||
r"<\|ref\|>(?P<label>.*?)<\|/ref\|>\s*<\|det\|>\s*\[\s*\[\s*(?P<coords>[^\]]+?)\s*\]\s*\]\s*<\|/det\|>",
|
||||
r"<\|ref\|>(?P<label>.*?)<\|/ref\|>\s*<\|det\|>\s*(?P<coords>\[.*\])\s*<\|/det\|>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
def clean_grounding_text(text: str) -> str:
|
||||
"""Remove grounding tags from text for display, keeping labels"""
|
||||
# Replace <|ref|>label<|/ref|><|det|>[[...]]<|/det|> with just "label"
|
||||
# Replace <|ref|>label<|/ref|><|det|>[...any nested lists...]<|/det|> with just the label
|
||||
cleaned = re.sub(
|
||||
r"<\|ref\|>(.*?)<\|/ref\|>\s*<\|det\|>\s*\[\s*\[[^\]]+\]\s*\]\s*<\|/det\|>",
|
||||
r"<\|ref\|>(.*?)<\|/ref\|>\s*<\|det\|>\s*\[.*\]\s*<\|/det\|>",
|
||||
r"\1",
|
||||
text,
|
||||
flags=re.DOTALL
|
||||
flags=re.DOTALL,
|
||||
)
|
||||
# Also remove any standalone grounding tags
|
||||
cleaned = re.sub(r"<\|grounding\|>", "", cleaned)
|
||||
return cleaned.strip()
|
||||
|
||||
def parse_detections(text: str) -> List[Dict[str, Any]]:
|
||||
"""Parse grounding boxes from text"""
|
||||
def parse_detections(text: str, image_width: int, image_height: int) -> List[Dict[str, Any]]:
|
||||
"""Parse grounding boxes from text and scale from 0-999 normalized coords to actual image dimensions
|
||||
|
||||
Handles both single and multiple bounding boxes:
|
||||
- Single: <|ref|>label<|/ref|><|det|>[[x1,y1,x2,y2]]<|/det|>
|
||||
- Multiple: <|ref|>label<|/ref|><|det|>[[x1,y1,x2,y2], [x1,y1,x2,y2], ...]<|/det|>
|
||||
"""
|
||||
boxes: List[Dict[str, Any]] = []
|
||||
for m in DET_BLOCK.finditer(text or ""):
|
||||
label = m.group("label").strip()
|
||||
coords = [c.strip() for c in m.group("coords").split(",")]
|
||||
coords_str = m.group("coords").strip()
|
||||
|
||||
print(f"🔍 DEBUG: Found detection for '{label}'")
|
||||
print(f"📦 Raw coords string (with brackets): {coords_str}")
|
||||
|
||||
try:
|
||||
nums = list(map(float, coords[:4]))
|
||||
except Exception:
|
||||
import ast
|
||||
|
||||
# Parse the full bracket expression directly (handles single and multiple)
|
||||
parsed = ast.literal_eval(coords_str)
|
||||
|
||||
# Normalize to a list of lists
|
||||
if (
|
||||
isinstance(parsed, list)
|
||||
and len(parsed) == 4
|
||||
and all(isinstance(n, (int, float)) for n in parsed)
|
||||
):
|
||||
# Single box provided as [x1,y1,x2,y2]
|
||||
box_coords = [parsed]
|
||||
print("📦 Single box (flat list) detected")
|
||||
elif isinstance(parsed, list):
|
||||
box_coords = parsed
|
||||
print(f"📦 Boxes detected: {len(box_coords)}")
|
||||
else:
|
||||
raise ValueError("Unsupported coords structure")
|
||||
|
||||
# Process each box
|
||||
for idx, box in enumerate(box_coords):
|
||||
if isinstance(box, (list, tuple)) and len(box) >= 4:
|
||||
x1 = int(float(box[0]) / 999 * image_width)
|
||||
y1 = int(float(box[1]) / 999 * image_height)
|
||||
x2 = int(float(box[2]) / 999 * image_width)
|
||||
y2 = int(float(box[3]) / 999 * image_height)
|
||||
print(f" Box {idx+1}: {box} → [{x1}, {y1}, {x2}, {y2}]")
|
||||
boxes.append({"label": label, "box": [x1, y1, x2, y2]})
|
||||
else:
|
||||
print(f" ⚠️ Skipping invalid box: {box}")
|
||||
except Exception as e:
|
||||
print(f"❌ Parsing failed: {e}")
|
||||
continue
|
||||
if len(nums) == 4:
|
||||
boxes.append({"label": label, "box": nums})
|
||||
|
||||
print(f"🎯 Total boxes parsed: {len(boxes)}")
|
||||
return boxes
|
||||
|
||||
# -----------------------------
|
||||
@@ -289,8 +336,8 @@ async def ocr_inference(
|
||||
if not text:
|
||||
text = "No text returned by model."
|
||||
|
||||
# Parse grounding boxes
|
||||
boxes = parse_detections(text) if ("<|det|>" in text or "<|ref|>" in text) else []
|
||||
# Parse grounding boxes with proper coordinate scaling
|
||||
boxes = parse_detections(text, orig_w or 1, orig_h or 1) if ("<|det|>" in text or "<|ref|>" in text) else []
|
||||
|
||||
# Clean grounding tags from display text, but keep the labels
|
||||
display_text = clean_grounding_text(text) if ("<|ref|>" in text or "<|grounding|>" in text) else text
|
||||
@@ -302,6 +349,7 @@ async def ocr_inference(
|
||||
return JSONResponse({
|
||||
"success": True,
|
||||
"text": display_text,
|
||||
"raw_text": text, # Include raw model output for debugging
|
||||
"boxes": boxes,
|
||||
"image_dims": {"w": orig_w, "h": orig_h},
|
||||
"metadata": {
|
||||
@@ -326,4 +374,6 @@ async def ocr_inference(
|
||||
shutil.rmtree(out_dir, ignore_errors=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
host = env_config("API_HOST", default="0.0.0.0")
|
||||
port = env_config("API_PORT", default=8000, cast=int)
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
Reference in New Issue
Block a user