BLY0608 commited on
Commit
106f076
·
verified ·
1 Parent(s): 0038202

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -18
app.py CHANGED
@@ -1,32 +1,53 @@
 
1
  from fastapi import FastAPI
2
- from pydantic import BaseModel
3
  from sentence_transformers import SentenceTransformer
4
  from threading import Thread
 
5
 
6
- app = FastAPI()
 
 
 
 
7
 
8
- # Load model in background (avoids startup timeout)
9
- model = None
10
  def load_model():
 
11
  global model
12
- model = SentenceTransformer("BAAI/bge-small-zh")
 
 
 
 
13
 
14
- Thread(target=load_model).start()
 
15
 
16
- class Input(BaseModel):
17
- texts: list[str]
 
 
 
 
 
18
 
19
  @app.post("/embed")
20
- def embed(input: Input):
 
21
  if model is None:
22
- return {"error": "Model not loaded yet"}, 503
23
- vectors = model.encode(input.texts, normalize_embeddings=True).tolist()
24
- return {"embeddings": vectors}
25
-
26
- @app.get("/")
27
- def health():
28
- return {"status": "OK"} # Critical for health checks
29
 
30
  if __name__ == "__main__":
31
- import uvicorn
32
- uvicorn.run(app, host="0.0.0.0", port=8080)
 
 
 
 
 
 
1
+ import os
2
  from fastapi import FastAPI
3
+ import uvicorn
4
  from sentence_transformers import SentenceTransformer
5
  from threading import Thread
6
+ from typing import Optional
7
 
8
+ # Initialize FastAPI
9
+ app = FastAPI(title="Sentence Transformer API")
10
+
11
+ # Global model variable
12
+ model: Optional[SentenceTransformer] = None
13
 
 
 
14
  def load_model():
15
+ """Background thread for model loading"""
16
  global model
17
+ try:
18
+ model = SentenceTransformer('all-MiniLM-L6-v2')
19
+ print("Model loaded successfully")
20
+ except Exception as e:
21
+ print(f"Model loading failed: {str(e)}")
22
 
23
+ # Start model loading immediately
24
+ Thread(target=load_model, daemon=True).start()
25
 
26
+ @app.get("/")
27
+ def health_check():
28
+ """Required health check endpoint"""
29
+ return {
30
+ "status": "ready",
31
+ "model_loaded": model is not None
32
+ }
33
 
34
  @app.post("/embed")
35
+ async def embed_text(text: str):
36
+ """Endpoint for text embeddings"""
37
  if model is None:
38
+ return {"error": "Model still loading"}, 503
39
+ embeddings = model.encode(text)
40
+ return {
41
+ "text": text,
42
+ "embedding": embeddings.tolist(),
43
+ "dimension": len(embeddings)
44
+ }
45
 
46
  if __name__ == "__main__":
47
+ port = int(os.environ.get("PORT", 7860)) # Must be 7860 for Spaces
48
+ uvicorn.run(
49
+ "app:app",
50
+ host="0.0.0.0",
51
+ port=port,
52
+ reload=False # Disable auto-reload in production
53
+ )