Adding missed files
This commit is contained in:
489
backend/providers.py
Normal file
489
backend/providers.py
Normal file
@@ -0,0 +1,489 @@
|
||||
"""
|
||||
OCR provider abstraction.
|
||||
|
||||
Each provider knows how to turn an image + a semantic OCR request (mode, prompt,
|
||||
options) into raw model text. DeepSeek-specific prompt tokens and grounding-box
|
||||
parsing live here too so the FastAPI routes stay model-agnostic.
|
||||
|
||||
Two providers ship today:
|
||||
- DeepSeekLocalProvider -> the local HF transformers DeepSeek-OCR model (GPU)
|
||||
- OllamaProvider -> any vision model served by an external Ollama host
|
||||
|
||||
The registry is built from environment variables at startup (see build_registry()).
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import base64
|
||||
import tempfile
|
||||
import shutil
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from decouple import config as env_config
|
||||
|
||||
# httpx is only needed when an Ollama model is actually used; the import is guarded
# (not deferred) so the backend can run DeepSeek-only without the dependency installed.
|
||||
try:
|
||||
import httpx
|
||||
except Exception: # pragma: no cover - exercised only when httpx is missing
|
||||
httpx = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Prompt builders
|
||||
# =============================================================================
|
||||
def build_prompt(
    mode: str,
    user_prompt: str,
    grounding: bool,
    find_term: Optional[str],
    schema: Optional[str],
    include_caption: bool,
) -> str:
    """Assemble the full DeepSeek-OCR prompt, special tokens included.

    The prompt is a newline-joined sequence: the ``<image>`` token, then
    ``<|grounding|>`` when grounding was requested or the mode is one of
    ``find_ref`` / ``layout_map`` / ``pii_redact``, and finally the
    instruction text produced by ``_instruction_for_mode``.
    """
    forced_by_mode = mode in {"find_ref", "layout_map", "pii_redact"}
    if grounding or forced_by_mode:
        prefix_tokens = ["<image>", "<|grounding|>"]
    else:
        prefix_tokens = ["<image>"]

    instruction = _instruction_for_mode(mode, user_prompt, find_term, schema, include_caption)
    return "\n".join(prefix_tokens + [instruction])
|
||||
|
||||
|
||||
# Instructions that do not depend on any request parameter, keyed by mode.
_OLLAMA_FIXED_INSTRUCTIONS = {
    "plain_ocr": (
        "Transcribe all of the text in this image exactly as it appears, "
        "preserving line breaks and reading order. Output only the transcribed "
        "text with no commentary."
    ),
    "markdown": (
        "Convert this document image to clean GitHub-flavored Markdown, "
        "preserving headings, lists, and tables. Output only the Markdown."
    ),
    "tables_csv": (
        "Extract every table in this image and output CSV only. Use commas with "
        "minimal quoting. If there are multiple tables, separate them with a line "
        "containing '---'. Output only the CSV."
    ),
    "tables_md": (
        "Extract every table in this image as GitHub-flavored Markdown tables. "
        "Output only the tables."
    ),
    "figure_chart": (
        "Parse the figure in this image. First extract any numeric series as a "
        "two-column table (x,y). Then add a line containing '---' followed by a "
        "two-sentence summary of the chart."
    ),
    "layout_map": (
        'Identify the layout blocks in this image and return a JSON array of '
        'objects {"type": one of ["title","paragraph","table","figure"]}. '
        "Do not include the text content."
    ),
    "pii_redact": (
        "Find all emails, phone numbers, postal addresses, and IBANs in this image. "
        'Return a JSON array of objects {"label", "text"}.'
    ),
    "multilingual": (
        "Transcribe all of the text in this image exactly, detecting the language "
        "automatically and preserving the original script. Output only the text."
    ),
    "describe": "Describe this image, focusing on the key visible elements.",
}

# Used for unknown modes and for "freeform" requests without a user prompt.
_OLLAMA_FALLBACK_INSTRUCTION = "Transcribe the text in this image."


def build_ollama_prompt(
    mode: str,
    user_prompt: str,
    find_term: Optional[str],
    schema: Optional[str],
    include_caption: bool,
) -> str:
    """Return a plain natural-language instruction for a generic vision model.

    Unlike the DeepSeek prompt, no ``<image>`` / ``<|grounding|>`` tokens are
    emitted. Three modes are parameterised: ``kv_json`` embeds the stripped
    ``schema`` (``{}`` when empty), ``find_ref`` embeds the stripped
    ``find_term`` (``Total`` when blank), and ``freeform`` returns the stripped
    ``user_prompt``. Every other known mode maps to a fixed sentence; anything
    else falls back to a generic transcription request.

    When ``include_caption`` is true and the mode is not ``describe``, a second
    line asking for a one-paragraph description is appended.
    """
    if mode == "kv_json":
        wanted_schema = schema.strip() if schema else "{}"
        text = (
            "Extract the key fields from this image and return strict JSON only "
            f"(no prose). Use this schema, filling in the values: {wanted_schema}"
        )
    elif mode == "find_ref":
        needle = (find_term or "").strip() or "Total"
        text = (
            f"Find every occurrence of '{needle}' in this image and quote the surrounding "
            "text for each match. If it does not appear, say so."
        )
    elif mode == "freeform" and user_prompt:
        text = user_prompt.strip()
    else:
        text = _OLLAMA_FIXED_INSTRUCTIONS.get(mode, _OLLAMA_FALLBACK_INSTRUCTION)

    if include_caption and mode != "describe":
        return text + "\nThen add a one-paragraph description of the image."
    return text
|
||||
|
||||
|
||||
# DeepSeek instructions that do not depend on any request parameter, keyed by mode.
_DEEPSEEK_FIXED_INSTRUCTIONS = {
    "plain_ocr": "Free OCR.",
    "markdown": "Convert the document to markdown.",
    "tables_csv": (
        "Extract every table and output CSV only. "
        "Use commas, minimal quoting. If multiple tables, separate with a line containing '---'."
    ),
    "tables_md": "Extract every table as GitHub-flavored Markdown tables. Output only the tables.",
    "figure_chart": (
        "Parse the figure. First extract any numeric series as a two-column table (x,y). "
        "Then summarize the chart in 2 sentences. Output the table, then a line '---', then the summary."
    ),
    "layout_map": (
        'Return a JSON array of blocks with fields {"type":["title","paragraph","table","figure"],'
        '"box":[x1,y1,x2,y2]}. Do not include any text content.'
    ),
    "pii_redact": (
        'Find all occurrences of emails, phone numbers, postal addresses, and IBANs. '
        'Return a JSON array of objects {label, text, box:[x1,y1,x2,y2]}.'
    ),
    "multilingual": "Free OCR. Detect the language automatically and output in the same script.",
    "describe": "Describe this image. Focus on visible key elements.",
}

# Used for unknown modes and for "freeform" requests without a user prompt.
_DEEPSEEK_FALLBACK_INSTRUCTION = "OCR this image."


def _instruction_for_mode(
    mode: str,
    user_prompt: str,
    find_term: Optional[str],
    schema: Optional[str],
    include_caption: bool,
) -> str:
    """Return the DeepSeek instruction text, without the leading prompt tokens.

    ``kv_json`` embeds the stripped ``schema`` (``{}`` when empty), ``find_ref``
    wraps the stripped ``find_term`` (``Total`` when blank) in
    ``<|ref|>…<|/ref|>``, and ``freeform`` returns the stripped ``user_prompt``.
    Other known modes map to fixed sentences; unknown modes get a generic OCR
    request. A caption request is appended on a new line when
    ``include_caption`` is true and the mode is not ``describe``.
    """
    if mode == "kv_json":
        wanted_schema = schema.strip() if schema else "{}"
        text = (
            "Extract key fields and return strict JSON only. "
            f"Use this schema (fill the values): {wanted_schema}"
        )
    elif mode == "find_ref":
        needle = (find_term or "").strip() or "Total"
        text = f"Locate <|ref|>{needle}<|/ref|> in the image."
    elif mode == "freeform" and user_prompt:
        text = user_prompt.strip()
    else:
        text = _DEEPSEEK_FIXED_INSTRUCTIONS.get(mode, _DEEPSEEK_FALLBACK_INSTRUCTION)

    if include_caption and mode != "describe":
        return text + "\nThen add a one-paragraph description of the image."
    return text
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Grounding parser (DeepSeek-specific; no-op on plain text)
|
||||
# =============================================================================
|
||||
# Matches one grounded span: <|ref|>label<|/ref|><|det|>[...]<|/det|>.
# The coords group is lazy on purpose: with DOTALL a greedy `\[.*\]` runs from
# the first "[" to the last "]<|/det|>" in the whole text, merging every
# grounded span into a single match whose coords cannot be parsed. The lazy
# form still captures nested lists such as [[x1,y1,x2,y2],[...]] because the
# match can only end at a "]" that is directly followed by <|/det|>.
DET_BLOCK = re.compile(
    r"<\|ref\|>(?P<label>.*?)<\|/ref\|>\s*<\|det\|>\s*(?P<coords>\[.*?\])\s*<\|/det\|>",
    re.DOTALL,
)
|
||||
|
||||
|
||||
def clean_grounding_text(text: str) -> str:
    """Strip grounding markup from model output for display, keeping labels.

    Each ``<|ref|>label<|/ref|><|det|>[...]<|/det|>`` span is replaced by its
    label, any ``<|grounding|>`` token is removed, and the result is stripped
    of surrounding whitespace. Text without grounding tags passes through
    unchanged (apart from the strip). ``None`` or empty input yields ``""``.
    """
    # The coords part must be lazy: a greedy `\[.*\]` under DOTALL would span
    # from the first detection to the last one and delete all text in between.
    without_boxes = re.sub(
        r"<\|ref\|>(.*?)<\|/ref\|>\s*<\|det\|>\s*\[.*?\]\s*<\|/det\|>",
        r"\1",
        text or "",
        flags=re.DOTALL,
    )
    without_tokens = re.sub(r"<\|grounding\|>", "", without_boxes)
    return without_tokens.strip()
|
||||
|
||||
|
||||
def parse_detections(text: str, image_width: int, image_height: int) -> List[Dict[str, Any]]:
    """Extract grounding boxes from model output and scale them to pixels.

    Every ``DET_BLOCK`` match contributes its label plus one entry per box.
    The coords literal may be a single flat ``[x1, y1, x2, y2]`` list or a
    list of such boxes. Coordinates are treated as 0-999 normalized values
    and scaled by the image width (x) or height (y), truncating to int.

    A match whose coords cannot be evaluated or converted is reported on
    stdout and skipped; boxes already collected are kept.

    Returns:
        A list of ``{"label": str, "box": [x1, y1, x2, y2]}`` dicts.
    """
    import ast

    # Axis extent for each of the four coordinates, in x1, y1, x2, y2 order.
    extents = (image_width, image_height, image_width, image_height)
    detections: List[Dict[str, Any]] = []

    for match in DET_BLOCK.finditer(text or ""):
        label = match.group("label").strip()
        raw_coords = match.group("coords").strip()

        try:
            value = ast.literal_eval(raw_coords)
            if not isinstance(value, list):
                raise ValueError("Unsupported coords structure")

            # A bare [x1, y1, x2, y2] is wrapped so both shapes share one loop.
            is_flat_box = len(value) == 4 and all(
                isinstance(n, (int, float)) for n in value
            )
            candidates = [value] if is_flat_box else value

            for candidate in candidates:
                if not isinstance(candidate, (list, tuple)) or len(candidate) < 4:
                    continue
                scaled = [
                    int(float(coord) / 999 * extent)
                    for coord, extent in zip(candidate, extents)
                ]
                detections.append({"label": label, "box": scaled})
        except Exception as e:
            print(f"❌ Grounding parse failed: {e}")
            continue

    return detections
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Providers
|
||||
# =============================================================================
|
||||
# Modes tied to grounding output. Same members as the set build_prompt() uses
# to force the <|grounding|> token.
# NOTE(review): not referenced elsewhere in this module — presumably imported
# by the routes; confirm before changing.
GROUNDING_MODES = {"find_ref", "layout_map", "pii_redact"}
|
||||
|
||||
|
||||
class ProviderError(Exception):
    """Signals that a provider could not fulfil a request.

    Raised, for example, when a backend is unreachable or an unknown model id
    is requested from the registry.
    """
|
||||
|
||||
|
||||
class OCRProvider(ABC):
    """Abstract base for anything that turns an image + OCR request into text.

    Concrete providers set ``id``, ``label`` and ``capabilities`` and
    implement ``run``.
    """

    # Identifier used to select this provider.
    id: str
    # Human-readable name.
    label: str
    # Feature flags describing what the provider supports.
    capabilities: Dict[str, Any]

    @abstractmethod
    def run(
        self,
        image_path: str,
        *,
        mode: str,
        prompt: str,
        grounding: bool,
        find_term: Optional[str],
        schema: Optional[str],
        include_caption: bool,
        options: Dict[str, Any],
    ) -> str:
        """Run the model on ``image_path`` and return its raw text output."""

    def info(self) -> Dict[str, Any]:
        """Return a dict with this provider's id, label and capabilities."""
        summary: Dict[str, Any] = {}
        summary["id"] = self.id
        summary["label"] = self.label
        summary["capabilities"] = self.capabilities
        return summary
|
||||
|
||||
|
||||
class DeepSeekLocalProvider(OCRProvider):
    """Local HF transformers DeepSeek-OCR model, loaded lazily on first use.

    Construction is cheap: the tokenizer and model are only loaded (and the
    model moved to CUDA) the first time ``run`` is called.
    """

    def __init__(self):
        self.id = "deepseek-local"
        self.label = "DeepSeek-OCR (local GPU)"
        self.capabilities = {"grounding": True, "advanced_settings": True}
        self._model = None
        self._tokenizer = None

    @property
    def loaded(self) -> bool:
        """True once both the model and the tokenizer are available."""
        return self._model is not None and self._tokenizer is not None

    def _ensure_loaded(self):
        """Load the tokenizer and model on first call; later calls are no-ops.

        Reads ``MODEL_NAME`` (default ``deepseek-ai/DeepSeek-OCR``) and
        ``HF_HOME`` (default ``/models``) through ``env_config``.
        """
        if self.loaded:
            return

        # Heavy imports kept local so an Ollama-only deployment never needs torch.
        import torch
        from transformers import AutoModel, AutoTokenizer

        os.environ.pop("TRANSFORMERS_CACHE", None)
        model_name = env_config("MODEL_NAME", default="deepseek-ai/DeepSeek-OCR")
        # NOTE(review): the directory is created here but not passed to
        # from_pretrained(); presumably transformers picks it up from the
        # HF_HOME process environment variable — confirm.
        hf_home = env_config("HF_HOME", default="/models")
        os.makedirs(hf_home, exist_ok=True)

        print(f"🚀 Loading {model_name}...")
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_safetensors=True,
            attn_implementation="eager",
            torch_dtype=torch.bfloat16,
        ).eval().to("cuda")

        # Best-effort pad-token setup: fall back to the EOS token when the
        # tokenizer has no pad token, and mirror the id onto the model config.
        # A failure here must not block loading, but it is reported instead of
        # being silently discarded.
        try:
            if getattr(tokenizer, "pad_token_id", None) is None and getattr(tokenizer, "eos_token_id", None) is not None:
                tokenizer.pad_token = tokenizer.eos_token
            if getattr(model.config, "pad_token_id", None) is None and getattr(tokenizer, "pad_token_id", None) is not None:
                model.config.pad_token_id = tokenizer.pad_token_id
        except Exception as exc:
            print(f"⚠️ Could not configure pad token: {exc}")

        self._model = model
        self._tokenizer = tokenizer
        print("✅ DeepSeek-OCR loaded and ready!")

    def run(self, image_path, *, mode, prompt, grounding, find_term, schema, include_caption, options):
        """Run DeepSeek-OCR on ``image_path`` and return the raw model text.

        ``options`` may carry ``base_size`` (default 1024), ``image_size``
        (default 640), ``crop_mode`` (default True) and ``test_compress``
        (default False); ``None`` is treated as an empty dict.

        The model's ``infer`` result is normalised to a stripped string. When
        it yields no text, ``result.mmd`` in the temporary output directory is
        used if present. The temporary directory is always removed.
        """
        self._ensure_loaded()

        # Tolerate callers that pass no options at all.
        options = options or {}

        prompt_text = build_prompt(
            mode=mode,
            user_prompt=prompt,
            grounding=grounding,
            find_term=find_term,
            schema=schema,
            include_caption=include_caption,
        )

        out_dir = tempfile.mkdtemp(prefix="dsocr_")
        try:
            res = self._model.infer(
                self._tokenizer,
                prompt=prompt_text,
                image_file=image_path,
                output_path=out_dir,
                base_size=int(options.get("base_size", 1024)),
                image_size=int(options.get("image_size", 640)),
                crop_mode=bool(options.get("crop_mode", True)),
                save_results=False,
                test_compress=bool(options.get("test_compress", False)),
                eval_mode=True,
            )

            # infer() comes from the model's remote code, so several result
            # shapes are accepted.
            if isinstance(res, str):
                text = res.strip()
            elif isinstance(res, dict) and "text" in res:
                text = str(res["text"]).strip()
            elif isinstance(res, (list, tuple)):
                text = "\n".join(map(str, res)).strip()
            else:
                text = ""

            # Fall back to the file the model may have written.
            if not text:
                mmd = os.path.join(out_dir, "result.mmd")
                if os.path.exists(mmd):
                    with open(mmd, "r", encoding="utf-8") as fh:
                        text = fh.read().strip()
            return text
        finally:
            shutil.rmtree(out_dir, ignore_errors=True)
|
||||
|
||||
|
||||
class OllamaProvider(OCRProvider):
    """A single vision model served by an external Ollama host."""

    def __init__(self, tag: str, base_url: str, label: Optional[str] = None):
        """Bind the provider to one Ollama model.

        Args:
            tag: Ollama model tag; also used to build the id ``ollama:<tag>``.
            base_url: Base URL of the Ollama host; trailing slashes are removed.
            label: Optional display name; defaults to ``"<tag> (Ollama)"``.
        """
        self.tag = tag
        self.base_url = base_url.rstrip("/")
        self.id = f"ollama:{tag}"
        self.label = label or f"{tag} (Ollama)"
        # Generic vision models don't emit DeepSeek grounding tokens.
        self.capabilities = {"grounding": False, "advanced_settings": False}

    def run(self, image_path, *, mode, prompt, grounding, find_term, schema, include_caption, options):
        """Send the image and a plain-text prompt to Ollama's ``/api/generate``.

        ``grounding`` and ``options`` are accepted for interface compatibility
        but are not used. The request timeout comes from ``OLLAMA_TIMEOUT``
        (seconds, default 300).

        Returns:
            The stripped ``response`` field of the reply, or ``""`` when absent.

        Raises:
            ProviderError: httpx is missing, the host is unreachable, Ollama
                answers with an error status, or the reply body is not a JSON
                object.
        """
        if httpx is None:
            raise ProviderError("httpx is not installed; cannot reach Ollama.")

        prompt_text = build_ollama_prompt(
            mode=mode,
            user_prompt=prompt,
            find_term=find_term,
            schema=schema,
            include_caption=include_caption,
        )

        with open(image_path, "rb") as f:
            img_b64 = base64.b64encode(f.read()).decode("utf-8")

        payload = {
            "model": self.tag,
            "prompt": prompt_text,
            "images": [img_b64],
            "stream": False,
        }

        timeout = env_config("OLLAMA_TIMEOUT", default=300.0, cast=float)
        try:
            resp = httpx.post(f"{self.base_url}/api/generate", json=payload, timeout=timeout)
            resp.raise_for_status()
        except httpx.HTTPStatusError as e:
            # Prefer Ollama's own "error" field; fall back to a body excerpt.
            detail = ""
            try:
                detail = e.response.json().get("error", "")
            except Exception:
                detail = e.response.text[:200]
            raise ProviderError(f"Ollama returned {e.response.status_code}: {detail}") from e
        except httpx.HTTPError as e:
            raise ProviderError(f"Could not reach Ollama at {self.base_url}: {e}") from e

        # A 2xx reply can still carry a non-JSON body (e.g. a proxy page);
        # surface that as ProviderError rather than a raw decode error.
        try:
            data = resp.json()
        except ValueError as e:
            raise ProviderError(
                f"Ollama returned a non-JSON response: {resp.text[:200]}"
            ) from e
        if not isinstance(data, dict):
            raise ProviderError("Ollama returned an unexpected response shape.")

        return (data.get("response") or "").strip()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Registry
|
||||
# =============================================================================
|
||||
class ModelRegistry:
    """Lookup table of OCR providers keyed by id, with a default selection."""

    def __init__(self, providers: List[OCRProvider], default_id: str):
        """Index ``providers`` by id and settle on a default.

        If ``default_id`` is not among the registered ids, the first
        registered provider becomes the default (``None`` when there are no
        providers at all).
        """
        by_id: Dict[str, OCRProvider] = {}
        for provider in providers:
            by_id[provider.id] = provider
        self._providers = by_id

        if default_id in by_id:
            self.default_id = default_id
        else:
            self.default_id = next(iter(by_id), None)

    def get(self, model_id: Optional[str]) -> OCRProvider:
        """Return the provider for ``model_id``, or the default when it is falsy.

        Raises:
            ProviderError: no provider is registered under the resolved id.
        """
        wanted = model_id or self.default_id
        found = self._providers.get(wanted)
        if found is not None:
            return found
        raise ProviderError(f"Unknown model '{wanted}'.")

    def list_models(self) -> List[Dict[str, Any]]:
        """Return each provider's ``info()`` dict plus a boolean ``default`` flag."""

        def describe(provider: OCRProvider) -> Dict[str, Any]:
            summary = provider.info()
            summary["default"] = (provider.id == self.default_id)
            return summary

        return [describe(provider) for provider in self._providers.values()]
|
||||
|
||||
|
||||
def build_registry() -> ModelRegistry:
    """Create the provider registry from environment configuration.

    Env:
        ENABLE_DEEPSEEK_LOCAL - register the local DeepSeek-OCR model (default: true)
        OLLAMA_BASE_URL       - Ollama host (default: http://host.docker.internal:11434)
        OLLAMA_MODELS         - comma-separated tags to surface (e.g. "glm-ocr,llama3.2-vision")
        DEFAULT_OCR_MODEL     - id to select by default (default: deepseek-local)

    If nothing ends up registered, the local DeepSeek provider is added anyway
    and made the default.
    """
    deepseek_flag = env_config("ENABLE_DEEPSEEK_LOCAL", default="true")
    deepseek_enabled = deepseek_flag.strip().lower() in {"1", "true", "yes"}
    providers: List[OCRProvider] = [DeepSeekLocalProvider()] if deepseek_enabled else []

    ollama_url = env_config("OLLAMA_BASE_URL", default="http://host.docker.internal:11434")
    tag_csv = env_config("OLLAMA_MODELS", default="")
    stripped_tags = (piece.strip() for piece in tag_csv.split(","))
    providers.extend(
        OllamaProvider(tag=name, base_url=ollama_url) for name in stripped_tags if name
    )

    default_id = env_config("DEFAULT_OCR_MODEL", default="deepseek-local")
    if len(providers) == 0:
        # Defensive: nothing configured. Register DeepSeek so the app still starts.
        providers.append(DeepSeekLocalProvider())
        default_id = "deepseek-local"

    registry = ModelRegistry(providers, default_id)
    print(f"🧠 OCR models registered: {[p.id for p in providers]} (default: {registry.default_id})")
    return registry
|
||||
Reference in New Issue
Block a user