# Chat_QnA_v2/chains/openai_model.py
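"""Chat model wrapper for the Chat_QnA_v2 app.

Wraps Azure OpenAI chat completion behind a conversational retrieval chain,
with optional Google web search, audio transcription, streaming token output
to a Gradio chatbot, and per-user conversation history persistence.
"""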
import json
import os
import openai
from langchain.prompts import PromptTemplate
from config import TIMEOUT_STREAM
from vector_db import upload_file
from callback import StreamingGradioCallbackHandler
from queue import SimpleQueue, Empty, Queue
from threading import Thread
from utils import history_file_path, load_lasted_file_username, add_source_numbers, add_details
from chains.custom_chain import CustomConversationalRetrievalChain
from langchain.chains import LLMChain
from chains.azure_openai import CustomAzureOpenAI
from config import OPENAI_API_TYPE, OPENAI_API_VERSION, OPENAI_API_KEY, OPENAI_API_BASE, API_KEY, \
DEPLOYMENT_ID, MODEL_ID, EMBEDDING_API_KEY, EMBEDDING_API_BASE
class OpenAIModel:
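    """Conversational question-answering model.

    Combines Azure OpenAI chat models with a retrieval chain over the uploaded
    document vector store, keeps a short rolling conversation history, persists
    it per user identifier, and can optionally answer from Google web search
    results instead of the vector store.
    """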
def __init__(
self,
llm_model_name,
condense_model_name,
prompt_template="",
temperature=0.0,
top_p=1.0,
n_choices=1,
        stop=None,
        presence_penalty=0,
        frequency_penalty=0,
        user=None
):
self.llm_model_name = llm_model_name
self.condense_model_name = condense_model_name
self.prompt_template = prompt_template
self.temperature = temperature
self.top_p = top_p
self.n_choices = n_choices
self.stop = stop
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.history = []
self.user_identifier = user
    def set_user_indentifier(self, new_user_identifier):
        self.user_identifier = new_user_identifier
def format_prompt(self, qa_prompt_template, condense_prompt_template):
        # Build LangChain prompt templates for the QA and question-condensing steps
qa_prompt = PromptTemplate(template=qa_prompt_template, input_variables=["question", "chat_history", "context"])
condense_prompt = PromptTemplate(template=condense_prompt_template, input_variables=["question", "chat_history"])
return qa_prompt, condense_prompt
    def memory(self, inputs, outputs, last_k=3):
        # Keep at most the last_k most recent (question, answer) pairs
        if len(self.history) >= last_k:
            self.history.pop(0)
        self.history.append((inputs, outputs))
def reset_conversation(self):
self.history = []
return []
def delete_first_conversation(self):
if self.history:
self.history.pop(0)
def delete_last_conversation(self):
if len(self.history) > 0:
self.history.pop()
def auto_save_history(self, chatbot):
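        # Persist the chat history and rendered chatbot to this user's history file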
if self.user_identifier is not None:
file_path = history_file_path(self.user_identifier)
json_s = {"history": self.history, "chatbot": chatbot}
with open(file_path, "w", encoding='utf-8') as f:
json.dump(json_s, f, ensure_ascii=False)
    def load_history(self):
        # Restore the latest saved conversation for this user, if one exists
        chatbot = []
        lasted_file = load_lasted_file_username(self.user_identifier)
        if lasted_file is not None:
            with open(f"{lasted_file}.json", "r", encoding="utf-8") as f:
                json_s = json.load(f)
            self.history = json_s["history"]
            chatbot = json_s["chatbot"]
        return chatbot
    def audio_response(self, audio):
        # Transcribe the recorded audio file and return the recognised text
        with open(audio, 'rb') as media_file:
            response = openai.Audio.transcribe(
                api_key=API_KEY,
                model=MODEL_ID,
                file=media_file
            )
        return response["text"]
def inference(self, inputs, chatbot, streaming=False, use_websearch=False, custom_websearch=False, **kwargs):
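        """Answer `inputs` and yield (chatbot, status_text) updates for the UI.

        With web search enabled, the answer is generated from summarised Google
        results; otherwise a conversational retrieval chain runs over the
        uploaded-document vector store, optionally streaming tokens as they
        are produced.
        """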
if use_websearch or custom_websearch:
import requests
from bs4 import BeautifulSoup
from langchain.utilities.google_search import GoogleSearchAPIWrapper
from chains.web_search import GoogleWebSearch
from config import GOOGLE_API_KEY, GOOGLE_CSE_ID, CUSTOM_API_KEY, CUSTOM_CSE_ID
from chains.summary import WebSummary
status_text = "Retrieving information from the web"
yield chatbot, status_text
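            # Pick Google Custom Search credentials based on which search mode was requested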
if use_websearch:
google_api_key = GOOGLE_API_KEY
google_cse_id = GOOGLE_CSE_ID
else:
google_api_key = CUSTOM_API_KEY
google_cse_id = CUSTOM_CSE_ID
search = GoogleSearchAPIWrapper(google_api_key=google_api_key, google_cse_id=google_cse_id)
            results = search.results(inputs, 4)
reference_results = []
display_append = []
            for idx, result in enumerate(results):
                print(result['link'])
                try:
                    # Fetch and summarise the page; fall back to the search snippet on any failure
                    # (a request timeout is added here as a safeguard)
                    response = requests.get(result['link'], timeout=10)
                    soup = BeautifulSoup(response.content, "html.parser")
                    summary = WebSummary.predict(question=inputs, doc=soup.get_text())
                    print("Can access", result['link'])
                except Exception:
                    print("Cannot access", result['link'])
                    summary = result['snippet']
                reference_results.append([summary, result['link']])
display_append.append(
f"<a href=\"{result['link']}\" target=\"_blank\">{idx+1}.&nbsp;{result['title']}</a>"
)
reference_results = add_source_numbers(reference_results)
            display_append = '<div class="source-a">' + "\n".join(display_append) + '</div>'
status_text = "Request URL: " + OPENAI_API_BASE
yield chatbot, status_text
chatbot.append((inputs, ""))
web_search = GoogleWebSearch()
ai_response = web_search.predict(context="\n\n".join(reference_results), question=inputs, chat_history=self.history)
chatbot[-1] = (chatbot[-1][0], ai_response+display_append)
self.memory(inputs, ai_response)
self.auto_save_history(chatbot)
yield chatbot, status_text
else:
status_text = "Indexing files to vector database"
yield chatbot, status_text
vectorstore = upload_file()
status_text = "OpenAI version: " + OPENAI_API_VERSION
yield chatbot, status_text
qa_prompt, condense_prompt = self.format_prompt(**kwargs)
job_done = object() # signals the processing is done
q = SimpleQueue()
            # Defaults for the non-streaming path; overridden when streaming is enabled
            timeout = None
            streaming_callback = []
            if streaming:
                timeout = TIMEOUT_STREAM
                streaming_callback = [StreamingGradioCallbackHandler(q)]
            # Answering LLM; streams tokens through the Gradio callback when enabled
            llm = CustomAzureOpenAI(
                deployment_name=DEPLOYMENT_ID,
                openai_api_type=OPENAI_API_TYPE,
                openai_api_base=OPENAI_API_BASE,
                openai_api_version=OPENAI_API_VERSION,
                openai_api_key=OPENAI_API_KEY,
                temperature=self.temperature,
                model_kwargs={"top_p": self.top_p},
                streaming=streaming,
                callbacks=streaming_callback,
                request_timeout=timeout)
            # Condensing LLM rewrites the follow-up question as a standalone question
            condense_llm = CustomAzureOpenAI(
                deployment_name=self.condense_model_name,
                openai_api_type=OPENAI_API_TYPE,
                openai_api_base=OPENAI_API_BASE,
                openai_api_version=OPENAI_API_VERSION,
                openai_api_key=OPENAI_API_KEY,
                temperature=self.temperature)
status_text = "Request URL: " + OPENAI_API_BASE
yield chatbot, status_text
            # Create a function to call - this will run in a thread
response_queue = Queue()
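            # The worker thread puts the full chain result (answer + source documents) here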
def task():
                # Conversational + retrieval chain
                qa = CustomConversationalRetrievalChain.from_llm(
                    llm, vectorstore.as_retriever(k=5),
                    condense_question_llm=condense_llm, verbose=True,
                    condense_question_prompt=condense_prompt,
                    combine_docs_chain_kwargs={"prompt": qa_prompt},
                    return_source_documents=True)
# query with input and chat history
response = qa({"question": inputs, "chat_history": self.history})
response_queue.put(response)
q.put(job_done)
thread = Thread(target=task)
thread.start()
chatbot.append((inputs, ""))
content = ""
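            # Stream tokens from the callback queue into the chatbot until the job_done sentinel arrives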
while True:
try:
next_token = q.get(block=True)
if next_token is job_done:
break
content += next_token
chatbot[-1] = (chatbot[-1][0], content)
yield chatbot, status_text
except Empty:
continue
# add citation info to response
            response = response_queue.get()
            if not content:
                # Non-streaming path: take the answer from the chain result ("answer" key assumed)
                content = response.get("answer", "")
            relevant_docs = response["source_documents"]
            reference_results = [d.page_content for d in relevant_docs]
display_append = add_details(reference_results)
display_append = "\n\n" + "<details><summary><b>Citation</b></summary>"+ "".join(display_append) + "</details>"
chatbot[-1] = (chatbot[-1][0], content+display_append)
yield chatbot, status_text
self.memory(inputs, content)
self.auto_save_history(chatbot)
thread.join()
if __name__ == '__main__':
    from config import OPENAI_API_KEY
from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate)
    SYSTEM_PROMPT_TEMPLATE = "You're a helpful assistant."
HUMAN_PROMPT_TEMPLATE = "Human: {question}\n AI answer:"
prompt = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(SYSTEM_PROMPT_TEMPLATE),
HumanMessagePromptTemplate.from_template(HUMAN_PROMPT_TEMPLATE)
]
)
print("-===============")
    llm = CustomAzureOpenAI(
        deployment_name="binh-gpt",
        openai_api_key=OPENAI_API_KEY,
        openai_api_base=OPENAI_API_BASE,
        openai_api_version=OPENAI_API_VERSION,
        temperature=0,
        model_kwargs={"top_p": 1.0})
llm_chain = LLMChain(
llm=llm,
prompt=prompt
)
results = llm_chain.predict(question="Hello")
print(results)