from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from typing import List
import torch
import uvicorn

from models.schemas import EmbeddingRequest, EmbeddingResponse, ModelInfo
from utils.helpers import load_models, get_embeddings, cleanup_memory
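
# The request/response schemas are defined in models/schemas.py (not shown in
# this file). From how they are used below, EmbeddingRequest is expected to
# carry texts, model, normalize and max_length fields, EmbeddingResponse to
# carry embeddings, model_used, dimensions and num_texts, and ModelInfo to hold
# the per-model metadata returned by the model listing endpoint.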

app = FastAPI(
    title="Multilingual & Legal Embedding API",
    description="Multi-model embedding API for Spanish, Catalan, English and legal texts",
    version="3.0.0"
)

# Add CORS middleware to allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify actual domains
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global model cache - loaded on demand
models_cache = {}


def ensure_models_loaded():
    """Load models on first request if not already loaded"""
    global models_cache
    if not models_cache:
        try:
            print("Loading models on demand...")
            models_cache = load_models()
            print("All models loaded successfully!")
        except Exception as e:
            print(f"Failed to load models: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}")


@app.get("/")  # route path assumed
async def root():
    return {
        "message": "Multilingual & Legal Embedding API",
        "models": ["jina", "robertalex", "jina-v3", "legal-bert", "roberta-ca"],
        "status": "running",
        "docs": "/docs",
        "total_models": 5
    }


@app.post("/embed", response_model=EmbeddingResponse)  # route path assumed
async def create_embeddings(request: EmbeddingRequest):
    """Generate embeddings for input texts"""
    try:
        # Load models on first request
        ensure_models_loaded()

        if not request.texts:
            raise HTTPException(status_code=400, detail="No texts provided")
        if len(request.texts) > 50:  # Rate limiting
            raise HTTPException(status_code=400, detail="Maximum 50 texts per request")

        embeddings = get_embeddings(
            request.texts,
            request.model,
            models_cache,
            request.normalize,
            request.max_length
        )

        # Cleanup memory after large batches
        if len(request.texts) > 20:
            cleanup_memory()

        return EmbeddingResponse(
            embeddings=embeddings,
            model_used=request.model,
            dimensions=len(embeddings[0]) if embeddings else 0,
            num_texts=len(request.texts)
        )
    except HTTPException:
        # Re-raise HTTP errors (e.g. the 400s above) instead of masking them as 500s
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
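
# Illustrative request body for the endpoint above (field names inferred from
# how EmbeddingRequest is used; the exact schema lives in models/schemas.py):
#   {
#     "texts": ["Hola mundo", "Contracte de lloguer"],
#     "model": "jina",
#     "normalize": true,
#     "max_length": 512
#   }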


@app.get("/models", response_model=List[ModelInfo])  # route path assumed
async def list_models():
    """List available models and their specifications"""
    return [
        ModelInfo(
            model_id="jina",
            name="jinaai/jina-embeddings-v2-base-es",
            dimensions=768,
            max_sequence_length=8192,
            languages=["Spanish", "English"],
            model_type="bilingual",
            description="Bilingual Spanish-English embeddings with long context support"
        ),
        ModelInfo(
            model_id="robertalex",
            name="PlanTL-GOB-ES/RoBERTalex",
            dimensions=768,
            max_sequence_length=512,
            languages=["Spanish"],
            model_type="legal domain",
            description="Spanish legal domain specialized embeddings"
        ),
        ModelInfo(
            model_id="jina-v3",
            name="jinaai/jina-embeddings-v3",
            dimensions=1024,
            max_sequence_length=8192,
            languages=["Multilingual"],
            model_type="multilingual",
            description="Latest Jina v3 with superior multilingual performance"
        ),
        ModelInfo(
            model_id="legal-bert",
            name="nlpaueb/legal-bert-base-uncased",
            dimensions=768,
            max_sequence_length=512,
            languages=["English"],
            model_type="legal domain",
            description="English legal domain BERT model"
        ),
        ModelInfo(
            model_id="roberta-ca",
            name="projecte-aina/roberta-large-ca-v2",
            dimensions=1024,
            max_sequence_length=512,
            languages=["Catalan"],
            model_type="general",
            description="Catalan RoBERTa-large model trained on a large corpus"
        )
    ]


@app.get("/health")  # route path assumed
async def health_check():
    """Health check endpoint"""
    models_loaded = len(models_cache) == 5
    return {
        "status": "healthy" if models_loaded else "ready",
        "models_loaded": models_loaded,
        "available_models": list(models_cache.keys()),
        "expected_models": ["jina", "robertalex", "jina-v3", "legal-bert", "roberta-ca"],
        "models_count": len(models_cache),
        "note": "Models load on first embedding request" if not models_loaded else "All models ready"
    }


if __name__ == "__main__":
    # Set multi-threading for CPU
    torch.set_num_threads(8)
    torch.set_num_interop_threads(1)
    uvicorn.run(app, host="0.0.0.0", port=7860)
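
# Illustrative client call (a sketch, assuming the /embed route above and the
# port configured in uvicorn.run):
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/embed",
#       json={"texts": ["Hola mundo"], "model": "jina"},
#   )
#   print(resp.json()["dimensions"])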