from operator import itemgetter
from typing import Optional, List, Tuple

from fastapi import FastAPI
from pydantic import BaseModel

from reranker import RankLLM, RankListwiseOSLLM, Result, RankingExecInfo
# Build the listwise reranker once at import time so every request reuses the
# same instance instead of constructing it per call.
reranker = RankListwiseOSLLM("Salesforce/SweRankLLM-small")
class RerankRequest(BaseModel):
    """Request payload for ``rerank``: a query and its candidate hits."""

    # Query text forwarded to the reranker.
    query: str
    # (int, text) pairs; ``rerank`` sorts by the int and uses the text as the
    # hit content.
    hits: List[Tuple[int, str]]
class RerankResponse(BaseModel):
    """Schema holding a query and a list of ``(int, str)`` hits.

    NOTE(review): nothing in the visible code references this model, and
    ``rerank`` returns its list under the key "reranked_hits" rather than
    "hits" -- confirm which shape is intended before wiring it up as a
    response model.
    """

    # Query text.
    query: str
    # (int, text) pairs.
    hits: List[Tuple[int, str]]
# FastAPI application object for this module.
# NOTE(review): no route decorators are visible in this file view, so the
# handlers below are not registered on `app` here -- confirm the routing.
app = FastAPI()
def hello_world():
    """Return a fixed success payload.

    NOTE(review): no route decorator is visible on this function in this
    view; confirm how it is exposed.
    """
    return {"msg": "Success"}
def rerank(request: RerankRequest):
    """Rerank the request's hits against its query with the module-level reranker.

    The hits are first ordered by their integer (first tuple element), wrapped
    in a ``Result`` as ``{"content": text}`` dicts, and passed to
    ``reranker.permutation_pipeline`` together with ``0`` and the hit count
    (presumably the start/end of the range to rerank -- verify against the
    reranker's signature). The hits on the returned result are then
    renumbered from 1 in the order the reranker produced them.

    Args:
        request: Query string plus ``(int, str)`` hit tuples.

    Returns:
        A dict with the original ``"query"`` and ``"reranked_hits"``, a list
        of ``(position, content)`` tuples with positions starting at 1.
    """
    # Nothing to rank: answer directly instead of invoking the model on an
    # empty candidate list.
    if not request.hits:
        return {"query": request.query, "reranked_hits": []}

    # Re-sort by the supplied integer so the reranker sees the hits in their
    # stated order even if the payload arrived shuffled.
    ordered_hits = sorted(request.hits, key=itemgetter(0))
    candidates = Result(
        query=request.query,
        hits=[{"content": hit[1]} for hit in ordered_hits],
    )
    reranked_result = reranker.permutation_pipeline(
        candidates,
        0,
        len(ordered_hits),
        logging=True,
    )
    reranked_hits = [
        (position, item["content"])
        for position, item in enumerate(reranked_result.hits, start=1)
    ]
    return {
        "query": request.query,
        "reranked_hits": reranked_hits,
    }