# Chat_QnA_v2 / chains/openai_model.py
import json
import re
import openai
from langchain.prompts import PromptTemplate
from config import TIMEOUT_STREAM
from vector_db import upload_file
from callback import StreamingGradioCallbackHandler
from queue import SimpleQueue, Empty
from threading import Thread
from utils import add_source_numbers, add_details, web_citation, get_history_names
from chains.custom_chain import CustomConversationalRetrievalChain
from chains.azure_openai import CustomAzureOpenAI
from config import OPENAI_API_TYPE, OPENAI_API_VERSION, OPENAI_API_KEY, OPENAI_API_BASE, API_KEY, \
DEPLOYMENT_ID, MODEL_ID
from cosmos_db import upsert_item, read_item, delete_items


class OpenAIModel:
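    """Azure OpenAI chat wrapper with document retrieval, web search, and
    Cosmos DB-backed conversation history."""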
def __init__(
self,
llm_model_name,
condense_model_name,
prompt_template="",
temperature=0.0,
top_p=1.0,
n_choices=1,
stop=None,
presence_penalty=0,
frequency_penalty=0,
user=None
):
self.llm_model_name = llm_model_name
self.condense_model_name = condense_model_name
self.prompt_template = prompt_template
self.temperature = temperature
self.top_p = top_p
self.n_choices = n_choices
self.stop = stop
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.history = []
self.user_identifier = user
def set_user_identifier(self, new_user_identifier):
self.user_identifier = new_user_identifier
def format_prompt(self, qa_prompt_template, condense_prompt_template):
        # Build LangChain PromptTemplates for the QA answer and for condensing follow-up questions
qa_prompt = PromptTemplate(template=qa_prompt_template, input_variables=["question", "chat_history", "context"])
condense_prompt = PromptTemplate(template=condense_prompt_template,
input_variables=["question", "chat_history"])
return qa_prompt, condense_prompt
    def memory(self, inputs, outputs, last_k=3):
        # Keep only the last_k most recent (input, output) exchanges
        if len(self.history) >= last_k:
            self.history.pop(0)
        self.history.append((inputs, outputs))
def reset_conversation(self):
self.history = []
return []
def delete_first_conversation(self):
if self.history:
self.history.pop(0)
def delete_last_conversation(self):
if len(self.history) > 0:
self.history.pop()
def save_history(self, chatbot, file_name):
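        # Persist the in-memory history and chatbot transcript to Cosmos DB.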
message = upsert_item(self.user_identifier, file_name, self.history, chatbot)
return message
def load_history(self, file_name):
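        # Load a saved conversation from Cosmos DB; returns (id, chatbot).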
items = read_item(self.user_identifier, file_name)
return items['id'], items['chatbot']
def delete_history(self, file_name):
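        # Delete a saved conversation; returns a status message, the
        # refreshed history-name list, and an empty chatbot.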
message = delete_items(self.user_identifier, file_name)
return message, get_history_names(False, self.user_identifier), []
    def audio_response(self, audio):
        # Transcribe the recorded audio file with OpenAI's speech-to-text API;
        # the context manager closes the file handle after the request.
        with open(audio, 'rb') as media_file:
            response = openai.Audio.transcribe(
                api_key=API_KEY,
                model=MODEL_ID,
                file=media_file
            )
        return response["text"], None
    def inference(self, inputs, chatbot, streaming=False, upload_files_btn=False,
                  custom_websearch=False, local_db=False, **kwargs):
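        """Answer `inputs` and stream updates to the Gradio chatbot.

        A generator yielding (chatbot, status_text) tuples. Routes to one of
        three paths: retrieval over uploaded/local documents, a custom
        fptsoftware.com site search, or a decision maker that chooses between
        the plain LLM and Google search. On the retrieval path, `kwargs` must
        supply qa_prompt_template and condense_prompt_template.
        """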
if upload_files_btn or local_db:
status_text = "Indexing files to vector database"
yield chatbot, status_text
vectorstore = upload_file(upload_files_btn)
qa_prompt, condense_prompt = self.format_prompt(**kwargs)
        job_done = object()  # sentinel that signals processing is done
        q = SimpleQueue()
        # Bind these unconditionally: both are passed to the LLM below
        # whether or not streaming is enabled.
        timeout = TIMEOUT_STREAM if streaming else None
        streaming_callback = [StreamingGradioCallbackHandler(q)] if streaming else []
# Define llm model
llm = CustomAzureOpenAI(deployment_name=DEPLOYMENT_ID,
openai_api_type=OPENAI_API_TYPE,
openai_api_base=OPENAI_API_BASE,
openai_api_version=OPENAI_API_VERSION,
openai_api_key=OPENAI_API_KEY,
temperature=self.temperature,
model_kwargs={"top_p": self.top_p},
                                streaming=streaming,
callbacks=streaming_callback,
request_timeout=timeout)
condense_llm = CustomAzureOpenAI(deployment_name=self.condense_model_name,
openai_api_type=OPENAI_API_TYPE,
openai_api_base=OPENAI_API_BASE,
openai_api_version=OPENAI_API_VERSION,
openai_api_key=OPENAI_API_KEY,
temperature=self.temperature)
status_text = "Request URL: " + OPENAI_API_BASE
yield chatbot, status_text
        # The chain runs in a worker thread; tokens stream through `q` while
        # the finished response is handed back through this queue.
        response_queue = SimpleQueue()
def task():
            # Conversational retrieval chain: condense the follow-up question, then answer over retrieved docs
qa = CustomConversationalRetrievalChain.from_llm(llm, vectorstore.as_retriever(
search_type="similarity_score_threshold", search_kwargs={"score_threshold": 0.75}),
condense_question_llm=condense_llm, verbose=True,
condense_question_prompt=condense_prompt,
combine_docs_chain_kwargs={"prompt": qa_prompt},
return_source_documents=True)
# query with input and chat history
response = qa({"question": inputs, "chat_history": self.history})
response_queue.put(response)
q.put(job_done)
thread = Thread(target=task)
thread.start()
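        # Show the user's message immediately; the reply fills in as tokens arrive.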
chatbot.append((inputs, ""))
content = ""
while True:
try:
                next_token = q.get(block=True, timeout=1)  # timeout keeps the Empty handler reachable
if next_token is job_done:
break
content += next_token
chatbot[-1] = (chatbot[-1][0], content)
yield chatbot, status_text
except Empty:
continue
# add citation info to response
response = response_queue.get()
relevant_docs = response["source_documents"]
if len(relevant_docs) == 0:
display_append = ""
else:
if upload_files_btn:
reference_results = [d.page_content for d in relevant_docs]
reference_sources = [d.metadata["source"] for d in relevant_docs]
display_append = add_details(reference_results, reference_sources)
display_append = '<div class = "source-a">' + "\n".join(display_append) + '</div>'
else:
display_append = []
for idx, d in enumerate(relevant_docs):
link = d.metadata["source"]
title = d.page_content.split("\n")[0]
                    # Strip leading non-word characters and whitespace from the title
                    title = re.sub(r"^\W+", "", title).strip()
display_append.append(
f'<a href=\"{link}\" target=\"_blank\">[{idx + 1}] {title}</a>'
)
display_append = '<div class = "source-a">' + "\n".join(display_append) + '</div>'
chatbot[-1] = (chatbot[-1][0], content + display_append)
yield chatbot, status_text
self.memory(inputs, content)
# self.auto_save_history(chatbot)
thread.join()
else:
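            # Lazy imports: these are only needed on the web-search paths.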
import requests
from langchain.utilities.google_search import GoogleSearchAPIWrapper
from chains.web_search import GoogleWebSearch
from config import GOOGLE_API_KEY, GOOGLE_CSE_ID
top_k = 4
if custom_websearch:
status_text = "Retrieving information from website FPTSoftware.com"
yield chatbot, status_text
                # Sitecore /sxa/search parameters; the GUIDs appear to be fixed
                # identifiers from the fptsoftware.com search configuration.
                params = {
                    "q": inputs,
                    "v": "{539C9DC1-663A-418D-82A4-662D34EE34BC}",
                    "p": 10,
                    "l": "en",
                    "s": "{EACE8DB5-668F-4357-9782-405070D28D11}",
                    "itemid": "{91F4101E-B1F3-4905-A832-96F703D3FBB1}",
                }
req = requests.get(
"https://fptsoftware.com//sxa/search/results/?",
params=params
)
res = json.loads(req.text)
results = []
for r in res["Results"][:top_k]:
link = "https://fptsoftware.com" + r["Url"]
results.append({"link": link})
reference_results, display_append = web_citation(inputs, results, True)
reference_results = add_source_numbers(reference_results)
display_append = '<div class = "source-a">' + "\n".join(display_append) + '</div>'
status_text = "Request URL: " + OPENAI_API_BASE
yield chatbot, status_text
chatbot.append((inputs, ""))
web_search = GoogleWebSearch()
ai_response = web_search.predict(context="\n\n".join(reference_results), question=inputs,
chat_history=self.history)
chatbot[-1] = (chatbot[-1][0], ai_response + display_append)
self.memory(inputs, ai_response)
# self.auto_save_history(chatbot)
yield chatbot, status_text
else:
from chains.decision_maker import DecisionMaker
from chains.simple_chain import SimpleChain
decision_maker = DecisionMaker()
simple_chain = SimpleChain()
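                # Route the query: answer with the plain LLM or fall back to Google search.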
decision = decision_maker.predict(question=inputs)
if "LLM Model" in decision:
status_text = "Request URL: " + OPENAI_API_BASE
yield chatbot, status_text
chatbot.append((inputs, ""))
ai_response = simple_chain.predict(question=inputs)
chatbot[-1] = (chatbot[-1][0], ai_response)
self.memory(inputs, ai_response)
# self.auto_save_history(chatbot)
yield chatbot, status_text
else:
status_text = "Retrieving information from Google"
yield chatbot, status_text
search = GoogleSearchAPIWrapper(google_api_key=GOOGLE_API_KEY, google_cse_id=GOOGLE_CSE_ID)
results = search.results(inputs, num_results=top_k)
reference_results, display_append = web_citation(inputs, results, False)
reference_results = add_source_numbers(reference_results)
display_append = '<div class = "source-a">' + "\n".join(display_append) + '</div>'
status_text = "Request URL: " + OPENAI_API_BASE
yield chatbot, status_text
chatbot.append((inputs, ""))
web_search = GoogleWebSearch()
ai_response = web_search.predict(context="\n\n".join(reference_results), question=inputs,
chat_history=self.history)
chatbot[-1] = (chatbot[-1][0], ai_response + display_append)
self.memory(inputs, ai_response)
# self.auto_save_history(chatbot)
yield chatbot, status_text


if __name__ == '__main__':
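    # Quick manual smoke test of the Azure OpenAI chat setup.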
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate)
SYSTEM_PROMPT_TEMPLATE = "You're a helpful assistant."
HUMAN_PROMPT_TEMPLATE = "Human: {question}\n AI answer:"
prompt = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT_TEMPLATE),
HumanMessagePromptTemplate.from_template(HUMAN_PROMPT_TEMPLATE)
]
)
llm = CustomAzureOpenAI(deployment_name="binh-gpt",
openai_api_key=OPENAI_API_KEY,
openai_api_base=OPENAI_API_BASE,
openai_api_version=OPENAI_API_VERSION,
temperature=0,
model_kwargs={"top_p": 1.0}, )
llm_chain = LLMChain(
llm=llm,
prompt=prompt
)
results = llm_chain.predict(question="Hello")
print(results)
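
    # A minimal usage sketch for the retrieval path (commented out; the prompt
    # strings are placeholders, and files must first be indexed via upload_file):
    # model = OpenAIModel(llm_model_name=DEPLOYMENT_ID, condense_model_name=DEPLOYMENT_ID)
    # for chat, status in model.inference(
    #         "Hello", [], streaming=True, upload_files_btn=True,
    #         qa_prompt_template="{context}\n{chat_history}\nQ: {question}\nA:",
    #         condense_prompt_template="{chat_history}\nFollow-up: {question}\nStandalone question:"):
    #     print(status)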