from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
from typing import Dict, List
import joblib
import uuid

# Load the fitted clustering model and vectorizer from disk
model = joblib.load("./model.pkl")
vectorizer = joblib.load("./vectorizer.pkl")

# For each cluster, order the centroid's feature indices from most to least significant
order_centroids = model.cluster_centers_.argsort()[:, ::-1]
terms = vectorizer.get_feature_names_out()


# Predict the cluster index for a product description
def show_recommendations(product: str) -> int:
    Y = vectorizer.transform([product])
    prediction = model.predict(Y)
    return int(prediction[0])  # Ensure the prediction is a native Python int


# Get the top 10 terms associated with a cluster
def get_cluster_terms(cluster_index: int) -> List[str]:
    return [terms[ind] for ind in order_centroids[cluster_index, :10]]


app = FastAPI()

# In-memory store for inference batches
inferences: Dict[str, Dict] = {}


class BatchRequest(BaseModel):
    products: List[str]


class BatchResponse(BaseModel):
    inferenceId: str


class BatchInferenceResponse(BaseModel):
    inferenceId: str
    status: str
    results: List[Dict]


# Run a batch of predictions and record the results in the in-memory store
def process_batch(inferenceId: str, products: List[str]) -> None:
    results = []
    for product in products:
        cluster_index = show_recommendations(product)
        cluster_terms = get_cluster_terms(cluster_index)
        results.append(
            {"product": product, "cluster": cluster_index, "top_terms": cluster_terms}
        )
    inferences[inferenceId]["status"] = "completed"
    inferences[inferenceId]["result"] = results


# Start a batch inference in the background and return its ID immediately
@app.post("/inference/batch", response_model=BatchResponse)
async def start_get_recommendations_batch(
    batch_request: BatchRequest, background_tasks: BackgroundTasks
):
    inferenceId = str(uuid.uuid4())
    inferences[inferenceId] = {"status": "in_progress", "result": []}
    background_tasks.add_task(process_batch, inferenceId, batch_request.products)
    return BatchResponse(inferenceId=inferenceId)


# Poll a batch inference for its status and (once completed) its results
@app.get("/inference/batch/{inferenceId}", response_model=BatchInferenceResponse)
async def get_recommendations_batch(inferenceId: str):
    if inferenceId not in inferences:
        raise HTTPException(status_code=404, detail="Inference ID not found")
    inference = inferences[inferenceId]
    return BatchInferenceResponse(
        inferenceId=inferenceId, status=inference["status"], results=inference["result"]
    )


class InferenceRequest(BaseModel):
    product: str


class InferenceResponse(BaseModel):
    cluster: int
    top_terms: List[str]


# Synchronous endpoint for single inferences
@app.post("/inference", response_model=InferenceResponse)
def get_recommendations(inference_request: InferenceRequest):
    cluster_index = show_recommendations(inference_request.product)
    cluster_terms = get_cluster_terms(cluster_index)
    return InferenceResponse(cluster=cluster_index, top_terms=cluster_terms)
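

# A minimal sketch for running and exercising the service locally. uvicorn is
# an assumption (any ASGI server works), as are the module name "main", the
# port, and the sample product strings in the requests below.
if __name__ == "__main__":
    import uvicorn

    # Equivalent to running `uvicorn main:app` from the command line
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example requests against a running instance:
#   curl -X POST http://localhost:8000/inference \
#        -H "Content-Type: application/json" \
#        -d '{"product": "wireless headphones"}'
#
#   curl -X POST http://localhost:8000/inference/batch \
#        -H "Content-Type: application/json" \
#        -d '{"products": ["usb cable", "mechanical keyboard"]}'
#
#   # Poll with the inferenceId returned by the previous call:
#   curl http://localhost:8000/inference/batch/<inferenceId>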