# app/routers/search.py
from fastapi import APIRouter, Request, Form
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.templating import Jinja2Templates
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from app.modules.evaluation import Evaluation  # evaluation module
import os

router = APIRouter()

# Load a local summarization model (HuggingFace KoBART); currently commented out
# tokenizer = AutoTokenizer.from_pretrained("gogamza/kobart-base-v2")
# model = AutoModelForSeq2SeqLM.from_pretrained("gogamza/kobart-base-v2")
# def summarize_text(text: str, max_length=150):
#     inputs = tokenizer.encode("summarize: " + text, return_tensors="pt", max_length=512, truncation=True)
#     outputs = model.generate(inputs, max_length=512, min_length=50, length_penalty=2.0, num_beams=4, early_stopping=True)
#     return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Template configuration
templates = Jinja2Templates(directory="app/templates")

# ์—…๋กœ๋“œ๋œ ํŒŒ์ผ๋กœ๋ถ€ํ„ฐ ์ƒ์„ฑ๋œ ๋ฒกํ„ฐ ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ๊ฒฝ๋กœ ์„ค์ •
UPLOAD_DIRECTORY = "./uploaded_files"
db_path = os.path.join(UPLOAD_DIRECTORY, "faiss_index")

# Embedding model used to query the vector database
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
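
# Note: this router only reads the FAISS index at db_path; it is assumed to have
# been built elsewhere (e.g. during file upload) with the same embedding model,
# roughly like the sketch below (split_uploaded_files_into_documents is a
# hypothetical helper):
#   docs = split_uploaded_files_into_documents()
#   FAISS.from_documents(docs, embeddings).save_local(db_path)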


@router.get("/search", response_class=HTMLResponse)
async def search_page(request: Request):
    return templates.TemplateResponse("search.html", {"request": request})


@router.post("/search/stream")
async def search_stream(query: str = Form(...), model: str = Form(...)):
    # Load the vector database and retrieve the top-5 most similar chunks
    vector_db = FAISS.load_local(db_path, embeddings, allow_dangerous_deserialization=True)
    results = vector_db.similarity_search(query, k=5)

    # Concatenate the retrieved result texts
    full_text = "\n\n".join([result.page_content for result in results])

    # Initialize the evaluation module
    evaluator = Evaluation(model=model)

    # Instruction prompt (Korean: "Based on the following text, {query}"),
    # followed by the retrieved context
    instruction = f"๋‹ค์Œ ํ…์ŠคํŠธ๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ {query}\n\n{full_text}"

    # Stream the evaluator's output to the client as plain text
    async def stream():
        async for data in evaluator.evaluate_stream(instruction):
            yield data

    return StreamingResponse(stream(), media_type="text/plain")
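
# Example client call (sketch; assumes the router is mounted without a prefix,
# the form values are illustrative placeholders, and the accepted "model" names
# depend on app.modules.evaluation.Evaluation):
#   curl -N -X POST http://localhost:8000/search/stream \
#        -F "query=Summarize the uploaded documents" \
#        -F "model=<model-name>"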