Ishaan Shah commited on
Commit
6e7c9db
·
1 Parent(s): 0e57bfb

live inference api

Browse files
Files changed (1) hide show
  1. main.py +16 -3
main.py CHANGED
@@ -34,7 +34,7 @@ class BatchRequest(BaseModel):
34
  class BatchResponse(BaseModel):
35
  inferenceId: str
36
 
37
- class InferenceResponse(BaseModel):
38
  inferenceId: str
39
  status: str
40
  results: List[Dict]
@@ -55,11 +55,24 @@ async def start_get_recommendations_batch(batch_request: BatchRequest, backgroun
55
  background_tasks.add_task(process_batch, inferenceId, batch_request.products)
56
  return BatchResponse(inferenceId=inferenceId)
57
 
58
- @app.get("/inference/batch/{inferenceId}", response_model=InferenceResponse)
59
  async def get_recommendations_batch(inferenceId: str):
60
  if inferenceId not in inferences:
61
  raise HTTPException(status_code=404, detail="Inference ID not found")
62
  inference = inferences[inferenceId]
63
- return InferenceResponse(inferenceId=inferenceId, status=inference["status"], results=inference["result"])
 
 
 
 
 
 
 
64
 
 
 
 
 
 
 
65
 
 
34
  class BatchResponse(BaseModel):
35
  inferenceId: str
36
 
37
+ class BatchInferenceResponse(BaseModel):
38
  inferenceId: str
39
  status: str
40
  results: List[Dict]
 
55
  background_tasks.add_task(process_batch, inferenceId, batch_request.products)
56
  return BatchResponse(inferenceId=inferenceId)
57
 
58
+ @app.get("/inference/batch/{inferenceId}", response_model=BatchInferenceResponse)
59
  async def get_recommendations_batch(inferenceId: str):
60
  if inferenceId not in inferences:
61
  raise HTTPException(status_code=404, detail="Inference ID not found")
62
  inference = inferences[inferenceId]
63
+ return BatchInferenceResponse(inferenceId=inferenceId, status=inference["status"], results=inference["result"])
64
+
65
+ class InferenceRequest(BaseModel):
66
+ product: str
67
+
68
+ class InferenceResponse(BaseModel):
69
+ cluster: int
70
+ top_terms: List[str]
71
 
72
+ # Add a new endpoint for single inferences
73
+ @app.post("/inference", response_model=InferenceResponse)
74
+ def get_recommendations(inference_request: InferenceRequest):
75
+ cluster_index = show_recommendations(inference_request.product)
76
+ cluster_terms = get_cluster_terms(cluster_index)
77
+ return InferenceResponse(cluster=cluster_index, top_terms=cluster_terms)
78