from openai import OpenAI
import google.generativeai as genai
from crawler import extract_data
import time
import os
from dotenv import load_dotenv
import gradio as gr
# from together import Together
# from transformers import AutoModel, AutoTokenizer
# from sklearn.metrics.pairwise import cosine_similarity
# import torch
#
# load_dotenv("../.env")
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
# together_client = Together(
#     api_key=os.getenv("TOGETHER_API_KEY"),
# )

genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
gemini_query = genai.GenerativeModel('gemini-2.0-flash-exp')
gemini_summarizer = genai.GenerativeModel('gemini-1.5-flash')
perplexity_client = OpenAI(api_key=os.getenv("PERPLEXITY_API_KEY"), base_url="https://api.perplexity.ai")
# gpt_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
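# The clients above read GEMINI_API_KEY and PERPLEXITY_API_KEY from the
# environment; the commented-out load_dotenv("../.env") call can populate
# them from a .env file if re-enabled.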
# Legacy KoSimCSE similarity scoring, kept for reference:
# with torch.no_grad():
#     model = AutoModel.from_pretrained('BM-K/KoSimCSE-roberta')
#     tokenizer = AutoTokenizer.from_pretrained('BM-K/KoSimCSE-roberta')
#
# def cal_score(input_data):
#     # Embed the query (first item) and all candidates in one batch,
#     # then score each candidate against the query.
#     similarity_scores = []
#     with torch.no_grad():
#         inputs = tokenizer(input_data, padding=True, truncation=True, return_tensors="pt")
#         outputs = model.get_input_embeddings()(inputs["input_ids"])
#         for ind in range(1, outputs.size(0)):
#             a, b = outputs[0], outputs[ind]
#             a = a.reshape(1, -1)
#             b = b.reshape(1, -1)
#             a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
#             b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
#             similarity_scores.append(cosine_similarity(a_norm, b_norm))  # 1x1 similarity per candidate
#     return similarity_scores
def get_answers(query: str):
    """Crawl Naver for Q&A posts matching `query`; each returned item is
    expected to carry 'questionDetails' and 'answers' keys (see usage below)."""
    context = extract_data(query, 1)
    # Legacy KoSimCSE re-ranking: sort by similarity to the query and keep
    # only answers scoring at or above the mean.
    # if len(context) > 1:
    #     scores = cal_score([query] + [answer['questionDetails'] for answer in context])
    #     context = [ctx for _, ctx in sorted(zip(scores, context), key=lambda x: x[0], reverse=True)]
    #     mean_score = sum(scores) / len(scores)
    #     context = [ctx for score, ctx in zip(scores, context) if score >= mean_score]
    return context
def get_gemini_query(message: str):
    print(">>> Starting gemini query generation...")
    response = gemini_query.generate_content(message)
    print("Finished gemini query generation: ", response.text)
    return response.text
def get_naver_answers(message: str):
    print(">>> Starting naver extraction...")
    print("Question: ", message)
    if len(message) > 300:
        # Long questions are first condensed into a short search title.
        # Prompt (Korean): "Summarize the above into a short title. Show only
        # the title. Do not answer it. Respond in Korean only!!!"
        message = get_gemini_query(f"{message}\n ์์ ๋ด์ฉ์ ์งง์ ์ ๋ชฉ์ผ๋ก ์์ฝํฉ๋๋ค. ์ ๋ชฉ๋ง ๋ณด์ฌ์ฃผ์ธ์. ๋๋ตํ์ง ๋ง์ธ์. ํ๊ตญ์ด๋ก๋ง ๋ต๋ณํด์ฃผ์ธ์!!!")
        print("Query: ", message)
    context = get_answers(message)
    sorted_answers = [
        f"{index}. ์ง๋ฌธ: {answer['questionDetails']}\n ๋ต๋ณ: {'. '.join(answer['answers'])} \n"
        for (index, answer) in enumerate(context)
    ]
    document = '\n'.join(sorted_answers)
    return document
def get_perplexity_answer(message: str):
    print(">>> Starting perplexity extraction...")
    messages = [
        {
            "role": "system",
            "content": (
                "You are an artificial intelligence assistant and you need to "
                "engage in a helpful, CONCISE, polite question-answer conversation with a user."
            ),
        },
        {
            "role": "user",
            "content": message,
        },
    ]
    response = perplexity_client.chat.completions.create(
        model="llama-3.1-sonar-small-128k-online",
        messages=messages,
    )
    return response.choices[0].message.content
def chatFunction(history):
    # MAX_TOKEN_LIMIT = 58000
    start_time = time.time()
    message = history[-1][0]
    # Earlier prompt variant, kept for reference. Gloss: "I will give you a
    # document of questions and answers. Extract the information needed to
    # answer the question below. Answer only in Korean, specific but concise,
    # as kindly as a real insurance consultant would."
    # content = f' ์ง๋ฌธ๊ณผ ๋ต๋ณ์ผ๋ก ๊ตฌ์ฑ๋ ๋ฌธ์๋ฅผ ๋๋ฆฌ๊ฒ ์ต๋๋ค. \
    #     ์๋์ ์ ๊ณต๋ ์ง๋ฌธ์ ๋ตํ๊ธฐ ์ํด ์ค์ํ ์ ๋ณด๋ฅผ ์ถ์ถํ์ธ์. \
    #     ํ๊ตญ์ด๋ก๋ง ๋ต๋ณํ์ธ์. ๊ตฌ์ฒด์ ์ด์ง๋ง ๊ฐ๊ฒฐํ๊ฒ ์์ฑํ์ธ์. \
    #     ์ค์  ๋ณดํ์๋ด์ฌ๊ฐ ๋ต๋ณ์ ํ๋ฏ์ด ์น์ ํ ๋ต๋ณ์ ํด ์ฃผ์ธ์. \n ์ง๋ฌธ: {message}\n ๋ฌธ์: '
    # Current prompt gloss: "An insurance planner is answering; reply so the
    # asker will ask more questions or sign up for a suitable policy. Rewrite
    # third-party mentions in the documents in the first person, e.g. replace
    # 'I am KB Insurance planner OOO' with 'I am the insurance institution'.
    # Answer as kindly as a real consultant so the asker gains trust and asks
    # follow-ups."
    content = f' ๋ณดํ์ค๊ณ์ฌ๊ฐ ๋ต์ ์ค์, ๋ ๋ง์ ์ง๋ฌธ์ด๋ ํฉ๋นํ ๋ณดํ์ ๊ฐ์ํ  ์ ์๋๋ก ๋ต๋ณ์ ํ๋ ค๊ณ  ํฉ๋๋ค. \
        ๋ฌธ์์ ์๋ ์ 3์ ์ธ๊ธ์ 1์ธ์นญ์ผ๋ก ๋ฐ๊พธ์ธ์. ์๋ฅผ ๋ค์ด "KB์ํด๋ณดํ ์ค๊ณ์ฌ OOO์๋๋ค" ๋ฑ ์ 3์๊ฐ ์ธ๊ธ๋ ๊ฒฝ์ฐ "๋ณดํ๊ธฐ๊ด์๋๋ค"๋ก ๋์ฒดํฉ๋๋ค. \
        ์ด๋ฌํ ๋ต๋ณ์ ํตํด์ ์ง๋ฌธ์๊ฐ ์ด ๋ต๋ณ์ ๋ณด๊ณ  ๋ณดํ์ค๊ณ์ฌ์๊ฒ ๋ ์ ๋ขฐ๋ฅผ ๊ฐ๊ณ  ์ถ๊ฐ ์ง๋ฌธ์ด ์์ผ๋ฉด ๋ฌผ์ด๋ณผ ์ ์๋๋ก ํ๋ ค๊ณ  ํฉ๋๋ค. \
        ์ค์  ๋ณดํ์๋ด์ฌ๊ฐ ๋ต๋ณ์ ํ๋ฏ์ด ์น์ ํ ๋ต๋ณ์ ํด ์ฃผ์ธ์. \n ์ง๋ฌธ: {message}\n ๋ฌธ์: '
    naver_docs = get_naver_answers(message)
    print(len(naver_docs))
    # Legacy chunked summarization for oversized contexts, kept for reference
    # (it calls get_qwen_small_answer and naver_time_taken, which are not
    # defined in this file):
    # if len(naver_docs) > MAX_TOKEN_LIMIT:
    #     print("HERE")
    #     start_tmp = time.time()
    #     overlap = 200
    #     answers = []
    #     split_len = len(naver_docs) // ((len(naver_docs) - MAX_TOKEN_LIMIT) // MAX_TOKEN_LIMIT + 2) + 1
    #     print(len(naver_docs) // split_len)
    #     for i in range(len(naver_docs) // split_len):
    #         print("HERE: ", i)
    #         if i == 0:
    #             split = naver_docs[:split_len]
    #         else:
    #             split = naver_docs[i * split_len - overlap: (i + 1) * split_len]
    #         answer, _ = get_qwen_small_answer(f"Summarize important points in a paragraph, given the information below, using only Korean language. Give me only the summary!!! \n {split}")
    #         answers.append(answer)
    #     print("Answers: ", answers)
    #     naver_docs = '\n'.join(answers)
    #     naver_time_taken += time.time() - start_tmp
    #     print("Post chunking length: ", len(naver_docs))
    content += "\n Naver ๋ฌธ์: " + naver_docs

    ### Extracting from Perplexity ###
    perplexity_resp = get_perplexity_answer(message)
    content += "\n Perplexity ๋ฌธ์: " + perplexity_resp

    print(">>> Starting Gemini summarization...")
    response = gemini_summarizer.generate_content(content, stream=True)
    history[-1][1] = ''  # reset the assistant slot (used when wired to a Gradio chatbot)
    ans = ""
    for chunk in response:
        ans += chunk.text.replace("*", "")  # strip markdown emphasis
        yield ans.strip() + "\n"  # yields the full answer accumulated so far
        time.sleep(0.05)
    print("Finished Gemini summarization")
    print("Time taken: ", time.time() - start_time)
def set_user_response(message: str, history: list):
    history.append([message, None])
    return '', history
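# `gradio` is imported above but no UI is wired in this file. A minimal
# Blocks sketch (hypothetical, not part of the running FastAPI server):
# chatFunction yields the accumulated plain-text answer, so an adapter must
# write each partial answer back into the chatbot history before yielding.
#
# def stream_to_history(history):
#     for partial in chatFunction(history):
#         history[-1][1] = partial
#         yield history
#
# with gr.Blocks() as demo:
#     chatbot = gr.Chatbot()
#     textbox = gr.Textbox()
#     textbox.submit(set_user_response, [textbox, chatbot], [textbox, chatbot]).then(
#         stream_to_history, chatbot, chatbot,
#     )
# demo.launch()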
### Server-side code ###
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")  # health-check route
async def root():
    return {"message": "Hello World"}

class Message(BaseModel):
    message: str

@app.post("/chat")  # route path is an assumption; the handler streams chatFunction output
async def chat(message: Message):
    history = [[message.message, None]]
    return StreamingResponse(
        chatFunction(history),
        media_type='text/event-stream'
    )
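# Minimal local entry point (a sketch: host and port are assumptions, 7860
# being the Hugging Face Spaces convention; /chat matches the route assumed
# above). Example call:
#   curl -N -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "๋ณดํ ์ถ์ฒํด ์ฃผ์ธ์"}'
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)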