# Spaces: Runtime error
import json | |
import os | |
import openai | |
from langchain.prompts import PromptTemplate | |
from config import TIMEOUT_STREAM | |
from vector_db import upload_file | |
from callback import StreamingGradioCallbackHandler | |
from queue import SimpleQueue, Empty | |
from threading import Thread | |
from utils import history_file_path, load_lasted_file_username, add_source_numbers, add_details | |
from chains.custom_chain import CustomConversationalRetrievalChain | |
from langchain.chains import LLMChain | |
from chains.azure_openai import CustomAzureOpenAI | |
from config import OPENAI_API_TYPE, OPENAI_API_VERSION, OPENAI_API_KEY, OPENAI_API_BASE, API_KEY, \ | |
DEPLOYMENT_ID, MODEL_ID, EMBEDDING_API_KEY, EMBEDDING_API_BASE | |
class OpenAIModel:
    """Conversational wrapper that answers from web search or a vector store.

    The instance keeps a short rolling ``history`` of ``(inputs, outputs)``
    turns, can persist it (together with the UI ``chatbot`` list) to a JSON
    file per user, and exposes ``inference`` as a generator that yields
    ``(chatbot, status_text)`` updates while an answer is produced.
    """

    # Seconds to wait on each HTTP request made while fetching search results,
    # so that one unresponsive site cannot stall the whole generator.
    _WEB_REQUEST_TIMEOUT = 10

    def __init__(
        self,
        llm_model_name,
        condense_model_name,
        prompt_template="",
        temperature=0.0,
        top_p=1.0,
        n_choices=1,
        stop=None,
        presence_penalty=0,
        frequency_penalty=0,
        user=None
    ):
        """Store the generation settings and start with an empty history.

        Only ``condense_model_name``, ``temperature`` and ``top_p`` are read
        by the methods of this class; the remaining settings are stored as
        attributes without being used here.

        Args:
            llm_model_name: Name of the answering model (stored only).
            condense_model_name: Deployment name used for the LLM that
                condenses the follow-up question.
            prompt_template: Prompt template text (stored only).
            temperature: Sampling temperature passed to both LLMs.
            top_p: Nucleus-sampling value passed to the answering LLM.
            n_choices: Stored only.
            stop: Stored only.
            presence_penalty: Stored only.
            frequency_penalty: Stored only.
            user: Identifier used to locate the history file, or ``None``
                to disable automatic saving.
        """
        self.llm_model_name = llm_model_name
        self.condense_model_name = condense_model_name
        self.prompt_template = prompt_template
        self.temperature = temperature
        self.top_p = top_p
        self.n_choices = n_choices
        self.stop = stop
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.history = []
        self.user_identifier = user

    def set_user_indentifier(self, new_user_indentifier):
        """Replace the identifier used for saving and loading history."""
        self.user_identifier = new_user_indentifier

    def format_prompt(self, qa_prompt_template, condense_prompt_template):
        """Wrap the two template strings in LangChain ``PromptTemplate`` objects.

        Args:
            qa_prompt_template: Template using ``question``, ``chat_history``
                and ``context``.
            condense_prompt_template: Template using ``question`` and
                ``chat_history``.

        Returns:
            Tuple ``(qa_prompt, condense_prompt)``.
        """
        qa_prompt = PromptTemplate(
            template=qa_prompt_template,
            input_variables=["question", "chat_history", "context"],
        )
        condense_prompt = PromptTemplate(
            template=condense_prompt_template,
            input_variables=["question", "chat_history"],
        )
        return qa_prompt, condense_prompt

    def memory(self, inputs, outputs, last_k=3):
        """Record one exchange, keeping at most ``last_k`` of the latest turns.

        The new turn is always kept, so at least one entry remains even when
        ``last_k`` is smaller than one.

        Args:
            inputs: The user message.
            outputs: The model answer.
            last_k: Maximum number of turns to retain.
        """
        # Drop as many of the oldest turns as needed (not just one): a history
        # restored from disk, or built with a larger ``last_k``, may already
        # exceed the limit. The list is trimmed in place.
        overflow = len(self.history) - last_k + 1
        if overflow > 0:
            del self.history[:overflow]
        self.history.append((inputs, outputs))

    def reset_conversation(self):
        """Clear the history and return an empty list."""
        self.history = []
        return []

    def delete_first_conversation(self):
        """Remove the oldest turn, if there is one."""
        if self.history:
            self.history.pop(0)

    def delete_last_conversation(self):
        """Remove the most recent turn, if there is one."""
        if self.history:
            self.history.pop()

    def auto_save_history(self, chatbot):
        """Write ``history`` and ``chatbot`` to the current user's JSON file.

        Does nothing when no user identifier is set.
        """
        if self.user_identifier is None:
            return
        file_path = history_file_path(self.user_identifier)
        json_s = {"history": self.history, "chatbot": chatbot}
        with open(file_path, "w", encoding='utf-8') as f:
            json.dump(json_s, f, ensure_ascii=False)

    def load_history(self):
        """Restore ``history`` from the user's latest file and return the chatbot.

        Returns:
            The saved ``chatbot`` value, or an empty list when
            ``load_lasted_file_username`` reports no file (``history`` is
            then left untouched).
        """
        lasted_file = load_lasted_file_username(self.user_identifier)
        if lasted_file is None:
            # Nothing saved yet: return an empty chatbot rather than failing
            # on an unbound local.
            return []
        with open(f"{lasted_file}.json", "r", encoding="utf-8") as f:
            json_s = json.load(f)
        self.history = json_s["history"]
        return json_s["chatbot"]

    def audio_response(self, audio):
        """Transcribe the audio file at path ``audio`` and return the text."""
        # The context manager closes the handle even if the request fails.
        with open(audio, 'rb') as media_file:
            response = openai.Audio.transcribe(
                api_key=API_KEY,
                model=MODEL_ID,
                file=media_file
            )
        return response["text"]

    def inference(self, inputs, chatbot, streaming=False, use_websearch=False, custom_websearch=False, **kwargs):
        """Answer ``inputs`` and yield ``(chatbot, status_text)`` progress updates.

        Args:
            inputs: The user question.
            chatbot: List of ``(question, answer)`` pairs; mutated in place.
            streaming: When true, tokens are pushed to the UI as they arrive
                (retrieval path only).
            use_websearch: Answer from Google search with the default keys.
            custom_websearch: Answer from Google search with the custom keys
                (used when ``use_websearch`` is false).
            **kwargs: Forwarded to ``format_prompt`` on the retrieval path, so
                they must be ``qa_prompt_template`` and
                ``condense_prompt_template``.

        Yields:
            ``(chatbot, status_text)`` tuples.

        Raises:
            Exception: Whatever the retrieval chain raised in its worker
                thread is re-raised here.
        """
        if use_websearch or custom_websearch:
            yield from self._answer_from_web(inputs, chatbot, use_websearch)
        else:
            yield from self._answer_from_documents(inputs, chatbot, streaming, **kwargs)

    def _answer_from_web(self, inputs, chatbot, use_websearch):
        """Answer ``inputs`` from summaries of Google search results."""
        from html import escape

        import requests
        from bs4 import BeautifulSoup
        from langchain.utilities.google_search import GoogleSearchAPIWrapper
        from chains.web_search import GoogleWebSearch
        from config import GOOGLE_API_KEY, GOOGLE_CSE_ID, CUSTOM_API_KEY, CUSTOM_CSE_ID
        from chains.summary import WebSummary
        from chains.multi_queries import MultiQueries

        status_text = "Retrieving information from the web"
        yield chatbot, status_text

        if use_websearch:
            google_api_key = GOOGLE_API_KEY
            google_cse_id = GOOGLE_CSE_ID
        else:
            google_api_key = CUSTOM_API_KEY
            google_cse_id = CUSTOM_CSE_ID
        search = GoogleSearchAPIWrapper(google_api_key=google_api_key, google_cse_id=google_cse_id)

        # The chain output is split on blank lines and the text after the
        # last ': ' of each piece is used as a search query.
        queries_chain = MultiQueries()
        out = queries_chain.predict(question=inputs)
        queries = [piece.split(': ')[-1] for piece in out.split('\n\n')]
        print(queries)

        results = []
        for query in queries:
            results.extend(search.results(query, 2))

        reference_results = []
        display_append = []
        for idx, result in enumerate(results):
            # Best effort: any result that cannot be fetched or lacks an
            # expected key is skipped instead of aborting the answer.
            try:
                link = result['link']
                head = requests.head(link, timeout=self._WEB_REQUEST_TIMEOUT)
                if "text/html" not in head.headers.get('Content-Type', ''):
                    continue
                html_response = requests.get(link, timeout=self._WEB_REQUEST_TIMEOUT)
                soup = BeautifulSoup(html_response.content, "html.parser")
                try:
                    web_summary = WebSummary()
                    summary = web_summary.predict(question=inputs, doc=soup.get_text())
                    print("Can access", link)
                except Exception:
                    # Fall back to the snippet supplied by the search API.
                    print("Cannot access ", link)
                    summary = result['snippet']
                # Titles and links come from the web, so escape them before
                # embedding them in the HTML shown to the user.
                # NOTE(review): idx counts skipped results too, so the shown
                # numbers can have gaps -- confirm against add_source_numbers.
                anchor = (
                    f"<a href=\"{escape(link)}\" target=\"_blank\">"
                    f"{idx+1}. {escape(result['title'])}</a>"
                )
                reference_results.append([summary, link])
                display_append.append(anchor)
            except Exception as exc:
                print("Skipping search result:", exc)
                continue

        reference_results = add_source_numbers(reference_results)
        display_append = '<div class = "source-a">' + "\n".join(display_append) + '</div>'
        status_text = "Request URL: " + OPENAI_API_BASE
        yield chatbot, status_text

        chatbot.append((inputs, ""))
        web_search = GoogleWebSearch()
        ai_response = web_search.predict(
            context="\n\n".join(reference_results),
            question=inputs,
            chat_history=self.history,
        )
        chatbot[-1] = (chatbot[-1][0], ai_response+display_append)
        self.memory(inputs, ai_response)
        self.auto_save_history(chatbot)
        yield chatbot, status_text

    def _answer_from_documents(self, inputs, chatbot, streaming, **kwargs):
        """Answer ``inputs`` with the retrieval chain over the uploaded files."""
        status_text = "Indexing files to vector database"
        yield chatbot, status_text
        vectorstore = upload_file()
        status_text = "OpenAI version: " + OPENAI_API_VERSION
        yield chatbot, status_text
        qa_prompt, condense_prompt = self.format_prompt(**kwargs)

        job_done = object()  # sentinel: the worker has finished
        q = SimpleQueue()  # carries streamed tokens, then the sentinel

        # The callback and request timeout only apply when streaming; they are
        # passed conditionally so the non-streaming path never refers to
        # names that were not assigned.
        streaming_kwargs = {}
        if streaming:
            streaming_kwargs["callbacks"] = [StreamingGradioCallbackHandler(q)]
            streaming_kwargs["request_timeout"] = TIMEOUT_STREAM

        llm = CustomAzureOpenAI(deployment_name=DEPLOYMENT_ID,
                                openai_api_type=OPENAI_API_TYPE,
                                openai_api_base=OPENAI_API_BASE,
                                openai_api_version=OPENAI_API_VERSION,
                                openai_api_key=OPENAI_API_KEY,
                                temperature=self.temperature,
                                model_kwargs={"top_p": self.top_p},
                                streaming=streaming,
                                **streaming_kwargs)
        condense_llm = CustomAzureOpenAI(deployment_name=self.condense_model_name,
                                         openai_api_type=OPENAI_API_TYPE,
                                         openai_api_base=OPENAI_API_BASE,
                                         openai_api_version=OPENAI_API_VERSION,
                                         openai_api_key=OPENAI_API_KEY,
                                         temperature=self.temperature)
        status_text = "Request URL: " + OPENAI_API_BASE
        yield chatbot, status_text

        # Holds exactly one item: the chain result, or the exception it raised.
        response_queue = SimpleQueue()

        def task():
            """Run the chain in a worker thread and always signal completion."""
            try:
                # NOTE(review): ``k=5`` is passed straight to as_retriever;
                # confirm the vector store accepts it in that form.
                qa = CustomConversationalRetrievalChain.from_llm(
                    llm, vectorstore.as_retriever(k=5),
                    condense_question_llm=condense_llm, verbose=True,
                    condense_question_prompt=condense_prompt,
                    combine_docs_chain_kwargs={"prompt": qa_prompt},
                    return_source_documents=True)
                response_queue.put(qa({"question": inputs, "chat_history": self.history}))
            except Exception as exc:
                # Hand the failure to the consumer instead of losing it.
                response_queue.put(exc)
            finally:
                # Without this the consumer loop below would block forever
                # whenever the chain fails.
                q.put(job_done)

        thread = Thread(target=task)
        thread.start()

        chatbot.append((inputs, ""))
        content = ""
        while True:
            next_token = q.get()
            if next_token is job_done:
                break
            content += next_token
            chatbot[-1] = (chatbot[-1][0], content)
            yield chatbot, status_text
        thread.join()

        response = response_queue.get()
        if isinstance(response, Exception):
            raise response

        if not content:
            # Nothing was streamed (e.g. streaming disabled), so take the text
            # from the chain result.
            # NOTE(review): assumes the custom chain keeps LangChain's
            # "answer" output key -- confirm.
            content = response.get("answer", "")

        # Append the retrieved passages as a collapsible citation section.
        relevant_docs = response["source_documents"]
        reference_results = [d.page_content for d in relevant_docs]
        display_append = add_details(reference_results)
        display_append = "\n\n" + "<details><summary><b>Citation</b></summary>" + "".join(display_append) + "</details>"
        chatbot[-1] = (chatbot[-1][0], content+display_append)
        yield chatbot, status_text
        self.memory(inputs, content)
        self.auto_save_history(chatbot)
if __name__ == '__main__':
    # Manual smoke test: send a single greeting through a plain LLMChain
    # backed by the Azure deployment and print the reply.
    import os
    from config import OPENAI_API_KEY
    from langchain.chains.llm import LLMChain
    from langchain.prompts.chat import (
        ChatPromptTemplate,
        SystemMessagePromptTemplate,
        HumanMessagePromptTemplate,
    )

    # Two-message chat prompt: a fixed system instruction followed by the
    # human turn carrying the ``question`` variable.
    system_message = SystemMessagePromptTemplate.from_template(
        "You're a helpfull assistant."
    )
    human_message = HumanMessagePromptTemplate.from_template(
        "Human: {question}\n AI answer:"
    )
    chat_prompt = ChatPromptTemplate.from_messages([system_message, human_message])

    print("-===============")

    azure_llm = CustomAzureOpenAI(
        deployment_name="binh-gpt",
        openai_api_key=OPENAI_API_KEY,
        openai_api_base=OPENAI_API_BASE,
        openai_api_version=OPENAI_API_VERSION,
        temperature=0,
        model_kwargs={"top_p": 1.0},
    )
    greeting_chain = LLMChain(llm=azure_llm, prompt=chat_prompt)
    reply = greeting_chain.predict(question="Hello")
    print(reply)