import os import re import uuid import tempfile import shutil import base64 from typing import List, Dict, Any, Optional from contextlib import asynccontextmanager from datetime import datetime, timezone from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Query from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse, FileResponse from pydantic import BaseModel import torch from transformers import AutoModel, AutoTokenizer from PIL import Image import uvicorn from decouple import config as env_config # Import PDF and document conversion utilities from pdf_utils import ( pdf_to_images_high_quality, images_to_pdf, extract_ref_patterns, crop_images_from_refs, clean_markdown_content ) from format_converter import DocumentConverter from database import init_db, get_db OCR_IMAGES_DIR = env_config("OCR_IMAGES_DIR", default="/data/ocr_images") # ----------------------------- # Lifespan context for model loading # ----------------------------- model = None tokenizer = None @asynccontextmanager async def lifespan(app: FastAPI): """Load model on startup, cleanup on shutdown""" global model, tokenizer # Image storage directory os.makedirs(OCR_IMAGES_DIR, exist_ok=True) # Database try: init_db() except Exception as exc: print(f"Warning: database initialization failed: {exc}") # Environment setup os.environ.pop("TRANSFORMERS_CACHE", None) MODEL_NAME = env_config("MODEL_NAME", default="deepseek-ai/DeepSeek-OCR") HF_HOME = env_config("HF_HOME", default="/models") os.makedirs(HF_HOME, exist_ok=True) # Load model print(f"🚀 Loading {MODEL_NAME}...") torch_dtype = torch.bfloat16 tokenizer = AutoTokenizer.from_pretrained( MODEL_NAME, trust_remote_code=True, ) model = AutoModel.from_pretrained( MODEL_NAME, trust_remote_code=True, use_safetensors=True, attn_implementation="eager", torch_dtype=torch_dtype, ).eval().to("cuda") # Pad token setup try: if getattr(tokenizer, "pad_token_id", None) is None and getattr(tokenizer, "eos_token_id", None) is not None: tokenizer.pad_token = tokenizer.eos_token if getattr(model.config, "pad_token_id", None) is None and getattr(tokenizer, "pad_token_id", None) is not None: model.config.pad_token_id = tokenizer.pad_token_id except Exception: pass print("✅ Model loaded and ready!") yield # Cleanup print("🛑 Shutting down...") # ----------------------------- # FastAPI app # ----------------------------- app = FastAPI( title="DeepSeek-OCR API", description="Blazing fast OCR with DeepSeek-OCR model 🔥", version="2.0.0", lifespan=lifespan ) # CORS middleware for React frontend CORS_ORIGINS = env_config("CORS_ORIGINS", default="").split(",") CORS_ORIGINS = [o.strip() for o in CORS_ORIGINS if o.strip()] app.add_middleware( CORSMiddleware, allow_origins=CORS_ORIGINS if CORS_ORIGINS else ["http://localhost:3000"], allow_credentials=True, allow_methods=["GET", "POST"], allow_headers=["*"], ) # ----------------------------- # Prompt builder # ----------------------------- def build_prompt( mode: str, user_prompt: str, grounding: bool, find_term: Optional[str], schema: Optional[str], include_caption: bool, ) -> str: """Build the prompt based on mode""" parts: List[str] = [""] mode_requires_grounding = mode in {"find_ref", "layout_map", "pii_redact"} if grounding or mode_requires_grounding: parts.append("<|grounding|>") instruction = "" if mode == "plain_ocr": instruction = "Free OCR." elif mode == "markdown": instruction = "Convert the document to markdown." elif mode == "tables_csv": instruction = ( "Extract every table and output CSV only. " "Use commas, minimal quoting. If multiple tables, separate with a line containing '---'." ) elif mode == "tables_md": instruction = "Extract every table as GitHub-flavored Markdown tables. Output only the tables." elif mode == "kv_json": schema_text = schema.strip() if schema else "{}" instruction = ( "Extract key fields and return strict JSON only. " f"Use this schema (fill the values): {schema_text}" ) elif mode == "figure_chart": instruction = ( "Parse the figure. First extract any numeric series as a two-column table (x,y). " "Then summarize the chart in 2 sentences. Output the table, then a line '---', then the summary." ) elif mode == "find_ref": key = (find_term or "").strip() or "Total" instruction = f"Locate <|ref|>{key}<|/ref|> in the image." elif mode == "layout_map": instruction = ( 'Return a JSON array of blocks with fields {"type":["title","paragraph","table","figure"],' '"box":[x1,y1,x2,y2]}. Do not include any text content.' ) elif mode == "pii_redact": instruction = ( 'Find all occurrences of emails, phone numbers, postal addresses, and IBANs. ' 'Return a JSON array of objects {label, text, box:[x1,y1,x2,y2]}.' ) elif mode == "multilingual": instruction = "Free OCR. Detect the language automatically and output in the same script." elif mode == "describe": instruction = "Describe this image. Focus on visible key elements." elif mode == "freeform": instruction = user_prompt.strip() if user_prompt else "OCR this image." else: instruction = "OCR this image." if include_caption and mode not in {"describe"}: instruction = instruction + "\nThen add a one-paragraph description of the image." parts.append(instruction) return "\n".join(parts) # ----------------------------- # Grounding parser # ----------------------------- # Match a full detection block and capture the coordinates as the entire list expression # Examples of captured coords (including outer brackets): # - [[312, 339, 480, 681]] # - [[504, 700, 625, 910], [771, 570, 996, 996]] # - [[110, 310, 255, 800], [312, 343, 479, 680], ...] # Using a greedy bracket capture ensures we include all inner lists up to the last ']' before DET_BLOCK = re.compile( r"<\|ref\|>(?P