# Hugging Face Spaces status: Runtime error
import os | |
import gradio as gr | |
import openai | |
import re | |
import numpy as np | |
from sklearn.neighbors import NearestNeighbors | |
import tensorflow_hub as hub | |
import fitz | |
def add_source_numbers(lst, source_name="Source", use_source=True):
    """Prefix each entry with a 1-based ``[n]`` tag and a tab, quoting its text.

    When ``use_source`` is true, every item is indexed as ``item[0]`` (text) and
    ``item[1]`` (source), and the source is appended on its own line as
    ``"<source_name>: <source>"``. Otherwise each item is quoted as-is.

    NOTE(review): ``add_source_numbers`` is defined again further down in this
    module with a different, single-argument signature. That later definition
    rebinds the name, so this version is not the one ``predict`` ends up calling.
    """
    numbered = []
    for number, item in enumerate(lst, start=1):
        if not use_source:
            numbered.append(f'[{number}]\t "{item}"')
            continue
        text, source = item[0], item[1]
        numbered.append(f'[{number}]\t "{text}"\n{source_name}: {source}')
    return numbered
def add_details(lst):
    """Wrap each string in a collapsible HTML ``<details>`` element.

    The ``<summary>`` shows the first 25 characters of the string (with
    newlines removed) followed by ``...``. The full string is the ``<p>`` body.

    NOTE(review): an identical ``add_details`` is defined again later in this
    module; that later definition is the one bound at runtime.
    """
    return [
        "<details><summary>" + text[:25].replace("\n", "") + "...</summary>"
        + "<p>" + text + "</p></details>"
        for text in lst
    ]
# Instruction preamble placed after the retrieved search results. It contains a
# literal "{question}" placeholder slot.
prompt_template = (
    "Instructions: Compose a comprehensive reply to the query using the search results given. "
    "If the search results mention multiple subjects "
    "with the same name, create separate answers for each. Only include information found in the results and "
    "don't add any additional information. Make sure the answer is correct and don't output false content. "
    "Ignore outlier search results which has nothing to do with the question. Only answer what is asked. "
    "The answer should be short and concise. \n\nQuery: {question}\nAnswer: "
)

# Completion engines offered in the model dropdown; the first is the default.
# NOTE(review): OpenAI has retired the text-davinci-* completion engines —
# confirm these names still resolve for the deployed account.
MODELS = ["text-davinci-001", "text-davinci-002", "text-davinci-003"]

# Reply languages offered in the language dropdown; the first is the default.
LANGUAGES = ["English", "简体中文", "日本語", "Deutsch", "Vietnamese"]
def set_openai_api_key(my_api_key):
    """Store the submitted key on the ``openai`` module and reveal a hidden component.

    Returns a Gradio update that makes the bound output component visible.
    """
    openai.api_key = my_api_key
    reveal = gr.update(visible=True)
    return reveal
def add_source_numbers(lst):
    """Insert a tab character right after the leading ``[page]`` tag of each chunk.

    Chunks produced by ``text_to_chunks`` look like ``'[12] "some text"'``. The
    tab is placed after the first ``]`` so that multi-digit page numbers keep
    their tag intact. The previous fixed ``item[:3]`` split broke any page >= 10.
    Strings that contain no ``]`` fall back to that old three-character split.

    Args:
        lst: list of chunk strings.

    Returns:
        A new list with one tab-separated string per input chunk.
    """
    numbered = []
    for item in lst:
        close = item.find("]")
        cut = close + 1 if close != -1 else 3
        numbered.append(item[:cut] + "\t" + item[cut:])
    return numbered
def add_details(lst):
    """Render each string as a collapsible HTML ``<details>`` block.

    The ``<summary>`` is the first 25 characters with newlines stripped,
    followed by ``...``. The full string is placed in a ``<p>`` body.
    """
    def _wrap(text):
        # Build one collapsible node for a single string.
        summary = text[:25].replace("\n", "")
        return f"<details><summary>{summary}...</summary><p>{text}</p></details>"

    return list(map(_wrap, lst))
def preprocess(text):
    """Collapse every run of whitespace (newlines included) into a single space.

    Uses a raw-string pattern: the former ``'\\s+'`` literal is an invalid
    escape sequence, which warns on Python 3.12+ and is slated to become an
    error. ``\\s`` already matches newlines, so no separate newline
    replacement is needed.
    """
    return re.sub(r"\s+", " ", text)
def pdf_to_text(files_src, start_page=1, end_page=None):
    """Extract whitespace-normalised text from every page of the uploaded PDFs.

    Args:
        files_src: iterable of uploaded-file objects exposing a ``.name`` path.
            Files whose extension is not ``.pdf`` (case-insensitive) are skipped.
        start_page: 1-based first page to read from each PDF.
        end_page: 1-based last page (inclusive) to read from each PDF. ``None``
            means "through the last page". Values past a document's length are
            clamped to that document's page count.

    Returns:
        A flat list with one ``preprocess``-ed string per page, in file order
        and then page order.
    """
    text_list = []
    for file in files_src:
        path = file.name
        if os.path.splitext(path)[1].lower() != ".pdf":
            continue
        # Context manager closes the document even if page extraction raises.
        with fitz.open(path) as doc:
            total_pages = doc.page_count
            # Resolve the last page per document. Overwriting ``end_page``
            # itself would carry the first PDF's page count into the next file.
            last_page = total_pages if end_page is None else min(end_page, total_pages)
            for i in range(start_page - 1, last_page):
                text = doc.load_page(i).get_text("text")
                text_list.append(preprocess(text))
    return text_list
def text_to_chunks(texts, word_length=150, start_page=1):
    """Split page texts into word chunks tagged with their page number.

    Each page is split on single spaces and cut into slices of ``word_length``
    words. A short trailing slice is carried over and prepended to the next
    page, unless it is on the last page. Each emitted chunk has the form
    ``'[<page>] "<words>"'``, where ``<page>`` is the page's position in
    *texts* plus ``start_page``.
    """
    pages = [page.split(' ') for page in texts]
    last_index = len(pages) - 1
    out = []
    for page_no, words in enumerate(pages):
        total = len(words)
        for start in range(0, total, word_length):
            piece = words[start:start + word_length]
            is_short_tail = start + word_length > total and len(piece) < word_length
            if is_short_tail and page_no != last_index:
                # enumerate() reads pages lazily, so the next iteration sees this update.
                pages[page_no + 1] = piece + pages[page_no + 1]
                continue
            body = ' '.join(piece).strip()
            out.append(f'[{page_no + start_page}] "{body}"')
    return out
def embedding(model, files_src, batch=1000):
    """Return ``(embeddings, chunks)`` for the text of the uploaded PDF files.

    The PDFs are converted to page texts with ``pdf_to_text`` and split into
    chunks of up to 150 words with ``text_to_chunks``. The chunks are embedded
    with *model*, ``batch`` chunks per call, and the stacked result is cached
    as a ``.npy`` file in the current working directory.

    The cache file name combines the upload basenames with a hash of the
    extracted chunk text. Keying on the basename alone let two different files
    with the same name (e.g. two users' ``paper.pdf``) silently reuse each
    other's stale embeddings.

    Args:
        model: callable taking a list of strings and returning one embedding
            row per string (presumably the TF-Hub sentence encoder used by
            ``predict`` — confirm for other callers).
        files_src: uploaded-file objects exposing a ``.name`` path.
        batch: maximum number of chunks passed to *model* per call.

    Returns:
        ``(embeddings, chunks)``, where row ``i`` of ``embeddings`` belongs to
        ``chunks[i]``.
    """
    import hashlib  # local import: only this function needs it

    texts = pdf_to_text(files_src)
    chunks = text_to_chunks(texts)

    name_file = '_'.join(os.path.basename(file.name).split('.')[0] for file in files_src)
    digest = hashlib.sha256("\n".join(chunks).encode("utf-8")).hexdigest()[:16]
    embeddings_file = f"{name_file}_{digest}.npy"

    if os.path.isfile(embeddings_file):
        return np.load(embeddings_file), chunks

    batches = [model(chunks[i:i + batch]) for i in range(0, len(chunks), batch)]
    embeddings = np.vstack(batches)
    np.save(embeddings_file, embeddings)
    return embeddings, chunks
def get_top_chunks(inp_emb, data, n_neighbors=5):
    """Return the indices of the rows of *data* nearest to the first query in *inp_emb*.

    At most ``n_neighbors`` indices are returned, and never more than there
    are rows in *data*. Distances are not returned.
    """
    k = min(n_neighbors, len(data))
    index = NearestNeighbors(n_neighbors=k).fit(data)
    hits = index.kneighbors(inp_emb, return_distance=False)
    return hits[0]
# Universal Sentence Encoder used to embed both document chunks and queries.
_EMBEDDING_MODEL_URL = 'https://tfhub.dev/google/universal-sentence-encoder/4'
_embedding_model = None


def _get_embedding_model():
    """Return the TF-Hub sentence encoder, loading it on first use only."""
    global _embedding_model
    if _embedding_model is None:
        _embedding_model = hub.load(_EMBEDDING_MODEL_URL)
    return _embedding_model


def _build_retrieval_prompt(question, lang, files):
    """Retrieve the chunks of *files* closest to *question* and build the prompt.

    Returns:
        ``(prompt, display_reference)``. ``prompt`` is the search results
        followed by the filled-in ``prompt_template``. ``display_reference``
        is the collapsible HTML list of the retrieved chunks that is appended
        to the bot reply.
    """
    emb_model = _get_embedding_model()
    vector_emb, chunks = embedding(emb_model, files)
    input_emb = emb_model([question])
    topn_chunks = [chunks[i] for i in get_top_chunks(input_emb, vector_emb)]

    prompt = 'search results:\n\n' + ''.join(c + '\n\n' for c in topn_chunks)
    # Fill the template's {question} slot. Previously the placeholder was left
    # literal and a second "Query:" line was appended after it. rstrip() keeps
    # the prompt ending in "Answer:" with no trailing whitespace, as before.
    prompt += prompt_template.format(question=f"{question}. Reply in {lang}").rstrip()

    display_reference = "\n\n" + "".join(add_details(add_source_numbers(topn_chunks)))
    return prompt, display_reference


def predict(
    my_api_key,
    history,
    chatbot,
    inputs,
    temperature,
    lang=LANGUAGES[0],
    selected_model=MODELS[0],
    files=None,
):
    """Answer one chat turn, optionally grounded in uploaded PDF files.

    Args:
        my_api_key: OpenAI API key sent with this completion request.
        history: list of previous user messages; this turn's text is appended.
        chatbot: list of ``(user, bot)`` pairs; a new pair is appended and its
            reply is filled in.
        inputs: the user's message.
        temperature: sampling temperature for the completion.
        lang: reply language requested in the retrieval prompt (only used when
            *files* is given).
        selected_model: completion engine name.
        files: uploaded files; when truthy, the answer is built from the
            chunks of these PDFs that are nearest to *inputs*.

    Returns:
        ``(chatbot, history)``, the same list objects after being updated.

    NOTE(review): ``openai.Completion`` is the legacy (openai<1.0) API —
    confirm the pinned library version.
    """
    user_text = inputs
    if files:
        prompt, display_reference = _build_retrieval_prompt(user_text, lang, files)
    else:
        prompt, display_reference = user_text, ""

    # history and chatbot store only what the user typed, never the
    # (possibly very long) retrieval prompt.
    history.append(user_text)
    chatbot.append((user_text, ""))

    # Pass the key per request instead of mutating the module-global
    # openai.api_key, which concurrent queued requests would race on.
    completions = openai.Completion.create(
        engine=selected_model,
        prompt=prompt,
        max_tokens=256,
        stop=None,
        temperature=temperature,
        api_key=my_api_key,
    )
    message = completions.choices[0].text
    chatbot[-1] = (chatbot[-1][0], message + display_reference)
    return chatbot, history
# Create theme: load the custom stylesheet and build a green "Soft" theme.
with open("custom.css", "r", encoding="utf-8") as css_file:
    customCSS = css_file.read()

# Green primary palette, shades c50 through c950.
_primary_palette = gr.themes.Color(
    c50="#02C160",
    c100="rgba(2, 193, 96, 0.2)",
    c200="#02C160",
    c300="rgba(2, 193, 96, 0.32)",
    c400="rgba(2, 193, 96, 0.32)",
    c500="rgba(2, 193, 96, 1.0)",
    c600="rgba(2, 193, 96, 1.0)",
    c700="rgba(2, 193, 96, 0.32)",
    c800="rgba(2, 193, 96, 0.32)",
    c900="#02C160",
    c950="#02C160",
)

# Theme variable overrides: primary buttons, block titles and input fill.
_theme_overrides = dict(
    button_primary_background_fill="#06AE56",
    button_primary_background_fill_dark="#06AE56",
    button_primary_background_fill_hover="#07C863",
    button_primary_border_color="#06AE56",
    button_primary_border_color_dark="#06AE56",
    button_primary_text_color="#FFFFFF",
    button_primary_text_color_dark="#FFFFFF",
    block_title_text_color="*primary_500",
    block_title_background_fill="*primary_100",
    input_background_fill="#F6F6F6",
)

beautiful_theme = gr.themes.Soft(
    primary_hue=_primary_palette,
    radius_size=gr.themes.sizes.radius_sm,
).set(**_theme_overrides)
# Gradio app
title = """<h1 align="left" style="min-width:200px; margin-top:6px; white-space: nowrap;">ChatGPT 🚀</h1>"""

with gr.Blocks(css=customCSS, theme=beautiful_theme) as demo:
    history = gr.State([])
    user_question = gr.State("")
    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML(title)
    with gr.Row().style(equal_height=True):
        with gr.Column(scale=5):
            with gr.Row():
                chatbot = gr.Chatbot(elem_id="chatbot").style(height="100%")
            # Hidden until an API key is submitted (set_openai_api_key reveals it).
            with gr.Row(visible=False) as input_raws:
                with gr.Column(scale=12):
                    user_input = gr.Textbox(
                        show_label=False, placeholder="Enter here"
                    ).style(container=False)
                with gr.Column(min_width=70, scale=1):
                    submitBtn = gr.Button("Send", variant="primary")
        with gr.Column():
            with gr.Column(min_width=50, scale=1):
                with gr.Tab(label="ChatGPT"):
                    gr.Markdown(f'<p style="text-align:center">Get your Open AI API key <a '
                                f'href="https://platform.openai.com/account/api-keys">here</a></p>')
                    openAI_key = gr.Textbox(label='Enter your OpenAI API key here and press Enter')
                    model_select_dropdown = gr.Dropdown(
                        label="Select model", choices=MODELS, multiselect=False, value=MODELS[0]
                    )
                    language_select_dropdown = gr.Dropdown(
                        label="Select reply language", choices=LANGUAGES, multiselect=False, value=LANGUAGES[0]
                    )
                    index_files = gr.Files(label="Files", type="file", multiple=True)
                with gr.Tab(label="Advanced"):
                    gr.Markdown(
                        "⚠️Be careful to change ⚠️\n\nIf you can't use it, please restore the default settings")
                    with gr.Accordion("Parameter", open=False):
                        temperature = gr.Slider(
                            minimum=0,
                            maximum=1.0,
                            value=0.0,
                            step=0.1,
                            interactive=True,
                            label="Temperature",
                        )

    # Gradio passes these values to predict positionally, so the list must
    # mirror its signature:
    # (my_api_key, history, chatbot, inputs, temperature, lang, selected_model, files).
    # Omitting openAI_key shifted every argument by one.
    predict_inputs = [
        openAI_key,
        history,
        chatbot,
        user_input,
        temperature,
        language_select_dropdown,
        model_select_dropdown,
        index_files,
    ]
    openAI_key.submit(set_openai_api_key, [openAI_key], [input_raws])
    user_input.submit(predict, inputs=predict_inputs, outputs=[chatbot, history])
    user_input.submit(lambda: "", None, user_input)
    submitBtn.click(predict, inputs=predict_inputs, outputs=[chatbot, history])
    submitBtn.click(lambda: "", None, user_input)

demo.queue(concurrency_count=10).launch(server_name="0.0.0.0", server_port=7862)