|
import os

from fastapi import APIRouter, Form, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.templating import Jinja2Templates
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from app.modules.evaluation import Evaluation
|
|
|
# --- Module-level wiring ----------------------------------------------------

# Routes declared in this module are registered on this router.
router = APIRouter()

# Relative path, so it resolves against the process's current working directory.
templates = Jinja2Templates(directory="app/templates")

# Base folder for uploaded files; the FAISS index lives in a sub-folder of it.
UPLOAD_DIRECTORY = "./uploaded_files"
db_path = os.path.join(UPLOAD_DIRECTORY, "faiss_index")

# Embedding model, created once when this module is imported and shared by
# every request handled below (the model name indicates a multilingual
# paraphrase model).
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
)
|
|
|
@router.get("/search", response_class=HTMLResponse)
async def search_page(request: Request):
    """Render the search form (``search.html``).

    Jinja2Templates requires the incoming request to be part of the
    template context, so it is the only key supplied.
    """
    context = {"request": request}
    return templates.TemplateResponse("search.html", context)
|
|
|
|
|
def _retrieve_context(query: str, k: int = 5) -> str:
    """Return the ``k`` indexed chunks most similar to ``query``, joined by blank lines.

    This is synchronous and blocking: it reads the index from disk and embeds
    the query. Async callers should run it in a worker thread.
    """
    # The index is loaded on every call (not cached), so anything written to
    # db_path after startup is seen by the next search.
    # NOTE(review): allow_dangerous_deserialization=True unpickles the stored
    # docstore. That is only safe while nothing but this application can write
    # to db_path; confirm uploads can never place files there.
    vector_db = FAISS.load_local(db_path, embeddings, allow_dangerous_deserialization=True)
    results = vector_db.similarity_search(query, k=k)
    return "\n\n".join(doc.page_content for doc in results)


@router.post("/search/stream")
async def search_stream(query: str = Form(...), model: str = Form(...)):
    """Answer ``query`` using model ``model`` and stream the output as plain text.

    The five most similar indexed chunks are appended to the prompt. The
    prompt is then passed to ``Evaluation.evaluate_stream``, and each chunk it
    yields is forwarded to the client as it arrives.

    Args:
        query: Question / search text submitted by the form.
        model: Model identifier, passed through to ``Evaluation``.

    Returns:
        A ``text/plain`` ``StreamingResponse``.
    """
    # Loading the index and embedding the query block. Running them in the
    # threadpool keeps the event loop free to serve other requests meanwhile.
    full_text = await run_in_threadpool(_retrieve_context, query)

    # NOTE(review): `model` comes straight from the client; confirm that
    # Evaluation restricts it to a known set of models.
    evaluator = Evaluation(model=model)

    # Korean prompt prefix meaning "Based on the following text".
    instruction = f"다음 텍스트를 기반으로 {query}\n\n{full_text}"

    async def stream():
        # evaluate_stream is only invoked once the response starts iterating.
        async for data in evaluator.evaluate_stream(instruction):
            yield data

    return StreamingResponse(stream(), media_type="text/plain")