# InsuHelp / app.py
from openai import OpenAI
import google.generativeai as genai
from crawler import extract_data
import time
import os
from dotenv import load_dotenv
import gradio as gr
# from together import Together
# from transformers import AutoModel, AutoTokenizer
# from sklearn.metrics.pairwise import cosine_similarity
# import torch
#
# load_dotenv("../.env")
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
# together_client = Together(
# api_key=os.getenv("TOGETHER_API_KEY"),
# )
# Load environment variables from a local .env file, if one is present (no-op otherwise).
load_dotenv()

# Two Gemini models: one condenses long questions into short search queries,
# the other summarizes the collected documents into the final answer.
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
gemini_query = genai.GenerativeModel('gemini-2.0-flash-exp')
gemini_summarizer = genai.GenerativeModel('gemini-1.5-flash')
# Perplexity exposes an OpenAI-compatible API, so the OpenAI client is pointed at its base URL.
perplexity_client = OpenAI(api_key=os.getenv("PERPLEXITY_API_KEY"), base_url="https://api.perplexity.ai")
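# The API keys are read from the environment; a local .env file (picked up by
# load_dotenv above) would look roughly like the sketch below. The key names are
# taken from the os.getenv calls in this file; the values are placeholders.
#
# GEMINI_API_KEY=<your Google AI Studio key>
# PERPLEXITY_API_KEY=<your Perplexity API key>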
# gpt_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# with torch.no_grad():
# model = AutoModel.from_pretrained('BM-K/KoSimCSE-roberta')
# tokenizer = AutoTokenizer.from_pretrained('BM-K/KoSimCSE-roberta')
# def cal_score(input_data):
# similarity_scores = []
# # Initialize model and tokenizer inside the function
# with torch.no_grad():
# inputs = tokenizer(input_data, padding=True, truncation=True, return_tensors="pt")
# outputs = model.get_input_embeddings()(inputs["input_ids"])
# for ind in range(1, outputs.size(0)):
# a, b = outputs[0], outputs[ind]
# a = a.reshape(1, -1)
# b = b.reshape(1, -1)
# a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
# b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
# similarity_scores.append(cosine_similarity(a_norm, b_norm)) # Scalar value
# return similarity_scores
def get_answers(query: str):
    # Fetch Naver Q&A context for the query via the crawler module.
    context = extract_data(query, 1)
# if len(context) > 1:
# scores = cal_score( [query] + [answer['questionDetails'] for answer in context] )
# context = [context for _, context in sorted(zip(scores, context), key=lambda x: x[0], reverse=True)]
# mean_score = sum(scores) / len(scores)
# context = [ctx for score, ctx in zip(scores, context) if score >= mean_score]
return context
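# Downstream formatting (get_naver_answers) assumes each crawled entry is a dict
# with 'questionDetails' and 'answers' keys, roughly like the sketch below; the
# exact structure is defined in crawler.extract_data.
#
# context = [
#     {"questionDetails": "<์งˆ๋ฌธ ๋ณธ๋ฌธ>", "answers": ["<๋‹ต๋ณ€ 1>", "<๋‹ต๋ณ€ 2>"]},
#     ...
# ]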
def get_gemini_query(message: str):
    # Use Gemini to condense the given text into a short Korean query/title.
    print(">>> Starting gemini query generation...")
response = gemini_query.generate_content(message)
print("Finished gemini query generation: ", response.text)
return response.text
def get_naver_answers(message: str):
    print(">>> Starting naver extraction...")
    print("Question: ", message)
    if len(message) > 300:
        # Long questions are condensed by Gemini into a short Korean title before crawling
        # (the prompt asks for the title only, in Korean, without answering the question).
        message = get_gemini_query(f"{message}\n ์œ„์˜ ๋‚ด์šฉ์„ ์งง์€ ์ œ๋ชฉ์œผ๋กœ ์š”์•ฝํ•ฉ๋‹ˆ๋‹ค. ์ œ๋ชฉ๋งŒ ๋ณด์—ฌ์ฃผ์„ธ์š”. ๋Œ€๋‹ตํ•˜์ง€ ๋งˆ์„ธ์š”. ํ•œ๊ตญ์–ด๋กœ๋งŒ ๋‹ต๋ณ€ํ•ด์ฃผ์„ธ์š”!!!")
    print("Query: ", message)
    context = get_answers(message)
    # Format each crawled Q&A pair as a numbered "์งˆ๋ฌธ / ๋‹ต๋ณ€" block.
    sorted_answers = [
        f"{index}. ์งˆ๋ฌธ: {answer['questionDetails']}" + '\n' + f" ๋‹ต๋ณ€: {'. '.join(answer['answers'])} " + '\n'
        for (index, answer) in enumerate(context)
    ]
    document = '\n'.join(sorted_answers)
    return document
def get_perplexity_answer(message: str):
    # Query Perplexity's online model (via its OpenAI-compatible API) for a concise, search-grounded answer.
    print(">>> Starting perplexity extraction...")
messages = [
{
"role": "system",
"content": (
"You are an artificial intelligence assistant and you need to "
"engage in a helpful, CONCISE, polite question-answer conversation with a user."
),
},
{
"role": "user",
"content": (
message
),
},
]
response = perplexity_client.chat.completions.create(
model="llama-3.1-sonar-small-128k-online",
messages=messages
)
return response.choices[0].message.content
def chatFunction(history):
    # Answer the latest user message: crawl Naver Q&A documents, get an online
    # answer from Perplexity, then stream a Gemini summary of both back to the caller.
    # MAX_TOKEN_LIMIT = 58000
    start_time = time.time()
    message = history[-1][0]
# content = f' ์งˆ๋ฌธ๊ณผ ๋‹ต๋ณ€์œผ๋กœ ๊ตฌ์„ฑ๋œ ๋ฌธ์„œ๋ฅผ ๋“œ๋ฆฌ๊ฒ ์Šต๋‹ˆ๋‹ค. \
# ์•„๋ž˜์— ์ œ๊ณต๋œ ์งˆ๋ฌธ์— ๋‹ตํ•˜๊ธฐ ์œ„ํ•ด ์ค‘์š”ํ•œ ์ •๋ณด๋ฅผ ์ถ”์ถœํ•˜์„ธ์š”. \
# ํ•œ๊ตญ์–ด๋กœ๋งŒ ๋‹ต๋ณ€ํ•˜์„ธ์š”. ๊ตฌ์ฒด์ ์ด์ง€๋งŒ ๊ฐ„๊ฒฐํ•˜๊ฒŒ ์ž‘์„ฑํ•˜์„ธ์š”. \
# ์‹ค์ œ ๋ณดํ—˜์ƒ๋‹ด์‚ฌ๊ฐ€ ๋‹ต๋ณ€์„ ํ•˜๋“ฏ์ด ์นœ์ ˆํ•œ ๋‹ต๋ณ€์„ ํ•ด ์ฃผ์„ธ์š”. \n ์งˆ๋ฌธ: {message}\n ๋ฌธ์„œ: '
    # Korean instructions for the summarizer: answer as an insurance planner, rewrite
    # third-party mentions in the documents into first person (e.g. replace
    # "KB์†ํ•ด๋ณดํ—˜ ์„ค๊ณ„์‚ฌ OOO์ž…๋‹ˆ๋‹ค" with "๋ณดํ—˜๊ธฐ๊ด€์ž…๋‹ˆ๋‹ค"), and keep the tone friendly so the
    # questioner trusts the planner and asks follow-up questions.
    content = f' ๋ณดํ—˜์„ค๊ณ„์‚ฌ๊ฐ€ ๋‹ต์„ ์ค˜์„œ, ๋” ๋งŽ์€ ์งˆ๋ฌธ์ด๋‚˜ ํ•ฉ๋‹นํ•œ ๋ณดํ—˜์— ๊ฐ€์ž…ํ•  ์ˆ˜ ์žˆ๋„๋ก ๋‹ต๋ณ€์„ ํ•˜๋ ค๊ณ  ํ•ฉ๋‹ˆ๋‹ค. \
    ๋ฌธ์„œ์— ์žˆ๋Š” ์ œ3์ž ์–ธ๊ธ‰์„ 1์ธ์นญ์œผ๋กœ ๋ฐ”๊พธ์„ธ์š”. ์˜ˆ๋ฅผ ๋“ค์–ด "KB์†ํ•ด๋ณดํ—˜ ์„ค๊ณ„์‚ฌ OOO์ž…๋‹ˆ๋‹ค" ๋“ฑ ์ œ3์ž๊ฐ€ ์–ธ๊ธ‰๋œ ๊ฒฝ์šฐ "๋ณดํ—˜๊ธฐ๊ด€์ž…๋‹ˆ๋‹ค"๋กœ ๋Œ€์ฒดํ•ฉ๋‹ˆ๋‹ค. \
    ์ด๋Ÿฌํ•œ ๋‹ต๋ณ€์„ ํ†ตํ•ด์„œ ์งˆ๋ฌธ์ž๊ฐ€ ์ด ๋‹ต๋ณ€์„ ๋ณด๊ณ  ๋ณดํ—˜์„ค๊ณ„์‚ฌ์—๊ฒŒ ๋” ์‹ ๋ขฐ๋ฅผ ๊ฐ–๊ณ  ์ถ”๊ฐ€ ์งˆ๋ฌธ์ด ์žˆ์œผ๋ฉด ๋ฌผ์–ด๋ณผ ์ˆ˜ ์žˆ๋„๋ก ํ•˜๋ ค๊ณ  ํ•ฉ๋‹ˆ๋‹ค. \
    ์‹ค์ œ ๋ณดํ—˜์ƒ๋‹ด์‚ฌ๊ฐ€ ๋‹ต๋ณ€์„ ํ•˜๋“ฏ์ด ์นœ์ ˆํ•œ ๋‹ต๋ณ€์„ ํ•ด ์ฃผ์„ธ์š”. \n ์งˆ๋ฌธ: {message}\n ๋ฌธ์„œ: '
    naver_docs = get_naver_answers(message)
    print("Naver document length: ", len(naver_docs))
# if len(naver_docs) > MAX_TOKEN_LIMIT:
# print("HERE")
# start_tmp = time.time()
# overlap = 200
# answers = []
# split_len = len(naver_docs) // ( ( len(naver_docs) - MAX_TOKEN_LIMIT ) // MAX_TOKEN_LIMIT + 2 ) + 1
# print(len(naver_docs) // split_len)
# for i in range( len(naver_docs) // split_len ):
# print("HERE: ", i)
# if i == 0:
# split = naver_docs[:split_len]
# else:
# split = naver_docs[i * split_len - overlap: (i + 1) * split_len]
# answer, _ = get_qwen_small_answer(f"Summarize important points in a paragraph, given the information below, using only Korean language. Give me only the summary!!! \n {split}")
# answers.append(answer)
# print("Answers: ", answers)
# naver_docs = '\n'.join(answers)
# naver_time_taken += time.time() - start_tmp
# print("Post chunking length: ", len(naver_docs) )
content += "\n Naver ๋ฌธ์„œ: " + naver_docs
### Extracting from Perplexity ###
perplexity_resp = get_perplexity_answer( message )
content += "\n Perplexity ๋ฌธ์„œ: " + perplexity_resp
print(">>> Starting Gemini summarization...")
response = gemini_summarizer.generate_content( content, stream=True )
history[-1][1] = ''
ans = ""
for chunk in response:
ans += chunk.text.replace("*", "")
yield ans.strip() + "\n"
time.sleep(0.05)
print("Finished Gemini summarization")
print("Time taken: ", time.time() - start_time)
def set_user_response(message: str, history: list):
    # Append the user's message to the chat history (Gradio-style [user, bot] pairs)
    # and clear the input textbox.
    history.append([message, None])
    return '', history
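# gradio is imported above but no UI is wired up in this file. Below is a minimal,
# commented-out sketch of how set_user_response and chatFunction could be chained
# in a Blocks app; the component names, placeholder text, and launch call are
# assumptions, not part of the original app.
#
# def _bot_turn(history):
#     # chatFunction yields the growing answer string; mirror it into the
#     # chatbot history so Gradio re-renders on every chunk.
#     for partial in chatFunction(history):
#         history[-1][1] = partial
#         yield history
#
# with gr.Blocks() as demo:
#     chatbot = gr.Chatbot()
#     textbox = gr.Textbox(placeholder="๋ณดํ—˜ ๊ด€๋ จ ์งˆ๋ฌธ์„ ์ž…๋ ฅํ•˜์„ธ์š”...")
#     textbox.submit(set_user_response, [textbox, chatbot], [textbox, chatbot]).then(
#         _bot_turn, chatbot, chatbot
#     )
#
# demo.launch()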
### Server-side code ###
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=['*'],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def root():
return {"message": "Hello World"}
class Message(BaseModel):
message: str
@app.post("/chat")
async def chat(message: Message):
    # Wrap the single message in a Gradio-style history pair and stream the generated answer.
    history = [[message.message, None]]
return StreamingResponse(
chatFunction(history),
media_type='text/event-stream'
)
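# Example client for the streaming /chat endpoint, kept commented out so it does not
# run on import. It assumes the app is served locally, e.g. with
# `uvicorn app:app --port 7860`; the host, port, and sample question are assumptions.
# Note that chatFunction yields the accumulated answer so far on every chunk, so each
# received chunk is a longer version of the previous one.
#
# import requests
#
# with requests.post(
#     "http://localhost:7860/chat",
#     json={"message": "์‹ค์†๋ณดํ—˜ ์ฒญ๊ตฌ ๋ฐฉ๋ฒ•์ด ๊ถ๊ธˆํ•ฉ๋‹ˆ๋‹ค"},
#     stream=True,
# ) as resp:
#     for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#         # Print each partial answer as it arrives.
#         print(chunk)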