from fastapi import FastAPI, HTTPException, BackgroundTasks
import joblib
import uuid
from pydantic import BaseModel
from typing import Dict, List

# Load the serialized model and vectorizer (assumed, from the attributes used
# below, to be a fitted scikit-learn KMeans and a fitted text vectorizer)
model = joblib.load("./model.pkl")
vectorizer = joblib.load("./vectorizer.pkl")

# For each cluster, vocabulary indices ranked by descending centroid weight
order_centroids = model.cluster_centers_.argsort()[:, ::-1]
terms = vectorizer.get_feature_names_out()
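
# A hypothetical sketch of how the artifacts above could have been produced
# (assumes scikit-learn; the real training pipeline may differ). Kept as
# comments because it needs a `corpus` of product descriptions to run:
#
#   from sklearn.feature_extraction.text import TfidfVectorizer
#   from sklearn.cluster import KMeans
#
#   vec = TfidfVectorizer(stop_words="english")
#   X = vec.fit_transform(corpus)                       # corpus: list[str]
#   km = KMeans(n_clusters=10, random_state=42).fit(X)  # n_clusters is illustrative
#   joblib.dump(km, "./model.pkl")
#   joblib.dump(vec, "./vectorizer.pkl")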

# Predict the cluster index for a single product description
def show_recommendations(product: str) -> int:
    Y = vectorizer.transform([product])
    prediction = model.predict(Y)
    return int(prediction[0])  # Ensure the prediction is a native Python int

# Return the top 10 terms closest to the centroid of a cluster
def get_cluster_terms(cluster_index: int) -> List[str]:
    cluster_terms = [terms[ind] for ind in order_centroids[cluster_index, :10]]
    return cluster_terms

app = FastAPI()

# In-memory store for inference batches (single-process only; lost on restart)
inferences: Dict[str, Dict] = {}

class BatchRequest(BaseModel):
    products: List[str]

class BatchResponse(BaseModel):
    inferenceId: str

class BatchInferenceResponse(BaseModel):
    inferenceId: str
    status: str
    results: List[Dict]

def process_batch(inferenceId: str, products: List[str]) -> None:
    # Plain (non-async) functions passed to BackgroundTasks run in the
    # threadpool, so model inference does not block the event loop.
    results = []
    try:
        for product in products:
            cluster_index = show_recommendations(product)
            cluster_terms = get_cluster_terms(cluster_index)
            results.append({"product": product, "cluster": cluster_index, "top_terms": cluster_terms})
        inferences[inferenceId]["status"] = "completed"
        inferences[inferenceId]["results"] = results
    except Exception:
        # Surface failures instead of leaving the batch stuck on "in_progress"
        inferences[inferenceId]["status"] = "failed"

@app.post("/inference/batch", response_model=BatchResponse, status_code=202)
async def start_get_recommendations_batch(batch_request: BatchRequest, background_tasks: BackgroundTasks):
    inferenceId = str(uuid.uuid4())
    inferences[inferenceId] = {"status": "in_progress", "results": []}
    background_tasks.add_task(process_batch, inferenceId, batch_request.products)
    return BatchResponse(inferenceId=inferenceId)

@app.get("/inference/batch/{inferenceId}", response_model=BatchInferenceResponse)
async def get_recommendations_batch(inferenceId: str):
    if inferenceId not in inferences:
        raise HTTPException(status_code=404, detail="Inference ID not found")
    inference = inferences[inferenceId]
    return BatchInferenceResponse(inferenceId=inferenceId, status=inference["status"], results=inference["results"])

class InferenceRequest(BaseModel):
    product: str

class InferenceResponse(BaseModel):
    cluster: int
    top_terms: List[str]

# Synchronous endpoint for single-product inference
@app.post("/inference", response_model=InferenceResponse)
def get_recommendations(inference_request: InferenceRequest):
    cluster_index = show_recommendations(inference_request.product)
    cluster_terms = get_cluster_terms(cluster_index)
    return InferenceResponse(cluster=cluster_index, top_terms=cluster_terms)
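
# A minimal sketch for running the service locally; assumes uvicorn is
# installed and this file is saved as main.py (names and port are illustrative):
#
#   uvicorn main:app --reload
#
#   # Start a batch job, then poll its status with the returned inferenceId
#   curl -X POST http://127.0.0.1:8000/inference/batch \
#        -H "Content-Type: application/json" \
#        -d '{"products": ["wireless mouse", "usb-c cable"]}'
#   curl http://127.0.0.1:8000/inference/batch/<inferenceId>
#
#   # Single synchronous inference
#   curl -X POST http://127.0.0.1:8000/inference \
#        -H "Content-Type: application/json" \
#        -d '{"product": "wireless mouse"}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)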