Update app.py
Browse files
app.py
CHANGED
@@ -1,32 +1,53 @@
|
|
|
|
1 |
from fastapi import FastAPI
|
2 |
-
|
3 |
from sentence_transformers import SentenceTransformer
|
4 |
from threading import Thread
|
|
|
5 |
|
6 |
-
|
|
|
|
|
|
|
|
|
7 |
|
8 |
-
# Load model in background (avoids startup timeout)
|
9 |
-
model = None
|
10 |
def load_model():
|
|
|
11 |
global model
|
12 |
-
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
|
|
|
15 |
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
@app.post("/embed")
|
20 |
-
def
|
|
|
21 |
if model is None:
|
22 |
-
return {"error": "Model
|
23 |
-
|
24 |
-
return {
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
|
30 |
if __name__ == "__main__":
|
31 |
-
|
32 |
-
uvicorn.run(
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
from fastapi import FastAPI
|
3 |
+
import uvicorn
|
4 |
from sentence_transformers import SentenceTransformer
|
5 |
from threading import Thread
|
6 |
+
from typing import Optional
|
7 |
|
8 |
+
# Initialize the FastAPI application that the route decorators below attach to.
app = FastAPI(title="Sentence Transformer API")

# Global model handle. Stays None until load_model() assigns the loaded
# model, so request handlers must check for None before using it.
model: Optional[SentenceTransformer] = None
|
13 |
|
|
|
|
|
14 |
def load_model():
    """Load the sentence-embedding model into the module-level ``model``.

    Meant to be used as a background-thread target. Any exception raised
    while loading is caught and reported on stdout rather than propagated,
    in which case ``model`` keeps its previous value.
    """
    global model
    model_name = 'all-MiniLM-L6-v2'
    try:
        loaded = SentenceTransformer(model_name)
        # Publish only a fully constructed model to the request handlers.
        model = loaded
        print("Model loaded successfully")
    except Exception as exc:
        print(f"Model loading failed: {exc!s}")
|
22 |
|
23 |
+
# Start model loading immediately at import time. It runs on a daemon thread
# so the load does not block the rest of module import and the thread does
# not keep the process alive on shutdown.
Thread(target=load_model, daemon=True).start()
|
25 |
|
26 |
+
@app.get("/")
def health_check():
    """Report that the service is up and whether the model has loaded.

    ``status`` is always ``"ready"``; inspect ``model_loaded`` to find out
    whether the embedding model is available yet.
    """
    model_loaded = model is not None
    return {"status": "ready", "model_loaded": model_loaded}
|
33 |
|
34 |
@app.post("/embed")
async def embed_text(text: str):
    """Return the embedding vector for ``text``.

    Args:
        text: The string to embed. It is declared as a plain ``str``
            parameter, so FastAPI reads it from the query string rather
            than from the request body.

    Returns:
        A dict with the input ``text``, its ``embedding`` converted with
        ``.tolist()``, and ``dimension`` (the length of the embedding).
        While the model has not been loaded yet, an HTTP 503 JSON response
        with an ``error`` message is returned instead.
    """
    # Local imports keep this handler self-contained; both modules ship
    # with FastAPI, which the file already depends on.
    from fastapi.concurrency import run_in_threadpool
    from fastapi.responses import JSONResponse

    # Read the global once so the None check and the call use the same object.
    encoder = model
    if encoder is None:
        # FastAPI does not interpret a ``(body, status)`` tuple the way
        # Flask does: it would be serialized as a JSON array with HTTP 200.
        # An explicit response object is needed to actually send a 503.
        return JSONResponse(
            status_code=503,
            content={"error": "Model still loading"},
        )

    # encode() is blocking, CPU-bound work; running it in the threadpool
    # keeps the event loop free to serve other requests meanwhile.
    embeddings = await run_in_threadpool(encoder.encode, text)
    return {
        "text": text,
        "embedding": embeddings.tolist(),
        "dimension": len(embeddings),
    }
|
45 |
|
46 |
if __name__ == "__main__":
    # Must be 7860 for Spaces; the PORT environment variable overrides it.
    listen_port = int(os.environ.get("PORT", 7860))
    server_options = {
        "host": "0.0.0.0",
        "port": listen_port,
        "reload": False,  # Disable auto-reload in production
    }
    uvicorn.run("app:app", **server_options)
|