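"""QnA Chatbot.

Gradio app that splits an uploaded text file into chunks, embeds them into a
persisted Chroma vector store, and answers questions about the document with a
LangChain ConversationalRetrievalChain backed by a Hugging Face Hub LLM.

Environment variables used below:
    HUGGINGFACEHUB_EMBEDDINGS_MODEL_NAME  name of the instructor embeddings model
    CHROMADB_PERSIST_DIRECTORY            root directory for persisted Chroma collections
    HUGGINGFACEHUB_LLM_REPO_ID            repo id of the Hub model used as the LLM
    HUGGINGFACEHUB_API_TOKEN              presumably also required for HuggingFaceHub inference calls
"""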
import os
import gradio as gr
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.document_loaders import TextLoader
from langchain.memory import ConversationBufferMemory
from langchain.llms import HuggingFaceHub
from langchain.chains import ConversationalRetrievalChain
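# Lazily initialized globals shared across Gradio callbacks.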
embeddings = None
qa_chain = None
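# Load the instructor embeddings model once and cache it in the module-level global.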
def load_embeddings():
    global embeddings
    if not embeddings:
        print("loading embeddings...")
        model_name = os.environ['HUGGINGFACEHUB_EMBEDDINGS_MODEL_NAME']
        embeddings = HuggingFaceInstructEmbeddings(model_name=model_name)
    return embeddings
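# Load the uploaded text file and split it into overlapping chunks for embedding.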
def split_file(file, chunk_size, chunk_overlap):
    print('splitting file...', file.name, chunk_size, chunk_overlap)
    loader = TextLoader(file.name)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return text_splitter.split_documents(documents)
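# Each processed file gets its own Chroma persist directory, named after the file.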
def get_persist_directory(file_name):
    return os.path.join(os.environ['CHROMADB_PERSIST_DIRECTORY'], file_name)
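# "Process" button handler: chunk the file, embed it, and persist a new Chroma collection.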
def process_file(file, chunk_size, chunk_overlap):
    docs = split_file(file, chunk_size, chunk_overlap)
    embeddings = load_embeddings()
    file_name, _ = os.path.splitext(os.path.basename(file.name))
    persist_directory = get_persist_directory(file_name)
    print("initializing vector store...", persist_directory)
    vectordb = Chroma.from_documents(documents=docs, embedding=embeddings,
                                     collection_name=file_name, persist_directory=persist_directory)
    print("persisting...", vectordb._client.list_collections())
    vectordb.persist()
    return 'Done!', gr.Dropdown.update(choices=get_vector_dbs(), value=file_name)
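# Helpers for listing the persisted collections (one sub-directory per processed file).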
def is_dir(root, name):
    path = os.path.join(root, name)
    return os.path.isdir(path)
def get_vector_dbs():
    root = os.environ['CHROMADB_PERSIST_DIRECTORY']
    if not os.path.exists(root):
        return []
    print('get vector dbs...', root)
    files = os.listdir(root)
    dirs = list(filter(lambda x: is_dir(root, x), files))
    print(dirs)
    return dirs
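# Re-open a persisted Chroma collection using the cached embeddings.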
def load_vectordb(file_name):
    embeddings = load_embeddings()
    persist_directory = get_persist_directory(file_name)
    print(persist_directory)
    vectordb = Chroma(collection_name=file_name,
                      embedding_function=embeddings, persist_directory=persist_directory)
    print(vectordb._client.list_collections())
    return vectordb
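# (Re)build the conversational retrieval chain whenever the document or LLM settings change.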
def create_qa_chain(collection_name, temperature, max_length):
    print('creating qa chain...', collection_name, temperature, max_length)
    if not collection_name:
        return
    global qa_chain
    memory = ConversationBufferMemory(
        memory_key="chat_history", return_messages=True)
    llm = HuggingFaceHub(
        repo_id=os.environ["HUGGINGFACEHUB_LLM_REPO_ID"],
        model_kwargs={"temperature": temperature, "max_length": max_length}
    )
    vectordb = load_vectordb(collection_name)
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm, retriever=vectordb.as_retriever(), memory=memory)
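# "Refresh" button handler: re-scan the persist directory and update the dropdown choices.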
def refresh_collection():
    choices = get_vector_dbs()
    return gr.Dropdown.update(choices=choices, value=choices[0] if choices else None)
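# Chat callbacks: append the user message, then fill in the bot's answer from the QA chain.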
def submit_message(bot_history, text):
    bot_history = bot_history + [(text, None)]
    return bot_history, ""
def bot(bot_history):
    global qa_chain
    if qa_chain is None:
        # No document has been processed or selected yet, so there is no chain to query.
        bot_history[-1][1] = "Please process and select a document first."
        return bot_history
    print(qa_chain, bot_history[-1][0])
    result = qa_chain.run(bot_history[-1][0])
    print(result)
    bot_history[-1][1] = result
    return bot_history
def clear_bot():
    return None
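# ----- Gradio UI -----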
title = "QnA Chatbot"
with gr.Blocks() as demo:
    gr.Markdown(f"# {title}")
    with gr.Tab("File"):
        upload = gr.File(file_types=["text"], label="Upload File")
        chunk_size = gr.Slider(
            500, 5000, value=1000, step=100, label="Chunk Size")
        chunk_overlap = gr.Slider(0, 30, value=20, label="Chunk Overlap")
        process = gr.Button("Process")
        result = gr.Label()
    with gr.Tab("Bot"):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column(scale=3):
                        choices = get_vector_dbs()
                        collection = gr.Dropdown(
                            choices, value=choices[0] if choices else None,
                            label="Document", allow_custom_value=True)
                    with gr.Column():
                        refresh = gr.Button("Refresh")
                temperature = gr.Slider(
                    0.0, 1.0, value=0.5, step=0.05, label="Temperature")
                max_length = gr.Slider(
                    20, 1000, value=100, step=10, label="Max Length")
            with gr.Column():
                chatbot = gr.Chatbot([], elem_id="chatbot").style(height=550)
                message = gr.Textbox(
                    show_label=False, placeholder="Ask me anything!")
                clear = gr.Button("Clear")

    process.click(
        process_file,
        [upload, chunk_size, chunk_overlap],
        [result, collection]
    )
    create_qa_chain(collection.value, temperature.value, max_length.value)
    collection.change(create_qa_chain, [collection, temperature, max_length])
    temperature.change(create_qa_chain, [collection, temperature, max_length])
    max_length.change(create_qa_chain, [collection, temperature, max_length])
    refresh.click(refresh_collection, None, collection)
    message.submit(submit_message, [chatbot, message], [chatbot, message]).then(
        bot, chatbot, chatbot
    )
    clear.click(clear_bot, None, chatbot)

demo.title = title
demo.launch()