Spaces:
Build error
Build error
init commit
Browse files- .gitignore +3 -0
- code/.chainlit/config.toml +84 -0
- code/chainlit.md +8 -0
- code/config.yml +26 -0
- code/main.py +109 -0
- code/modules/__init__.py +0 -0
- code/modules/chat_model_loader.py +25 -0
- code/modules/constants.py +33 -0
- code/modules/data_loader.py +248 -0
- code/modules/embedding_model_loader.py +23 -0
- code/modules/llm_tutor.py +84 -0
- code/modules/vector_db.py +107 -0
- code/vectorstores/db_FAISS_sentence-transformers/all-MiniLM-L6-v2/index.faiss +0 -0
- code/vectorstores/db_FAISS_sentence-transformers/all-MiniLM-L6-v2/index.pkl +0 -0
- code/vectorstores/db_FAISS_text-embedding-ada-002/index.faiss +0 -0
- code/vectorstores/db_FAISS_text-embedding-ada-002/index.pkl +0 -0
- data/webpage.pdf +0 -0
- requirements.txt +13 -0
.gitignore
CHANGED
|
@@ -158,3 +158,6 @@ cython_debug/
|
|
| 158 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 159 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 160 |
#.idea/
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 159 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 160 |
#.idea/
|
| 161 |
+
|
| 162 |
+
# log files
|
| 163 |
+
*.log
|
code/.chainlit/config.toml
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
# Whether to enable telemetry (default: true). No personal data is collected.
|
| 3 |
+
enable_telemetry = true
|
| 4 |
+
|
| 5 |
+
# List of environment variables to be provided by each user to use the app.
|
| 6 |
+
user_env = []
|
| 7 |
+
|
| 8 |
+
# Duration (in seconds) during which the session is saved when the connection is lost
|
| 9 |
+
session_timeout = 3600
|
| 10 |
+
|
| 11 |
+
# Enable third parties caching (e.g LangChain cache)
|
| 12 |
+
cache = false
|
| 13 |
+
|
| 14 |
+
# Follow symlink for asset mount (see https://github.com/Chainlit/chainlit/issues/317)
|
| 15 |
+
# follow_symlink = false
|
| 16 |
+
|
| 17 |
+
[features]
|
| 18 |
+
# Show the prompt playground
|
| 19 |
+
prompt_playground = true
|
| 20 |
+
|
| 21 |
+
# Process and display HTML in messages. This can be a security risk (see https://stackoverflow.com/questions/19603097/why-is-it-dangerous-to-render-user-generated-html-or-javascript)
|
| 22 |
+
unsafe_allow_html = false
|
| 23 |
+
|
| 24 |
+
# Process and display mathematical expressions. This can clash with "$" characters in messages.
|
| 25 |
+
latex = false
|
| 26 |
+
|
| 27 |
+
# Authorize users to upload files with messages
|
| 28 |
+
multi_modal = true
|
| 29 |
+
|
| 30 |
+
# Allows user to use speech to text
|
| 31 |
+
[features.speech_to_text]
|
| 32 |
+
enabled = false
|
| 33 |
+
# See all languages here https://github.com/JamesBrill/react-speech-recognition/blob/HEAD/docs/API.md#language-string
|
| 34 |
+
# language = "en-US"
|
| 35 |
+
|
| 36 |
+
[UI]
|
| 37 |
+
# Name of the app and chatbot.
|
| 38 |
+
name = "LLM Tutor"
|
| 39 |
+
|
| 40 |
+
# Show the readme while the conversation is empty.
|
| 41 |
+
show_readme_as_default = true
|
| 42 |
+
|
| 43 |
+
# Description of the app and chatbot. This is used for HTML tags.
|
| 44 |
+
# description = ""
|
| 45 |
+
|
| 46 |
+
# Large size content are by default collapsed for a cleaner ui
|
| 47 |
+
default_collapse_content = true
|
| 48 |
+
|
| 49 |
+
# The default value for the expand messages settings.
|
| 50 |
+
default_expand_messages = false
|
| 51 |
+
|
| 52 |
+
# Hide the chain of thought details from the user in the UI.
|
| 53 |
+
hide_cot = false
|
| 54 |
+
|
| 55 |
+
# Link to your github repo. This will add a github button in the UI's header.
|
| 56 |
+
# github = "https://github.com/DL4DS/dl4ds_tutor"
|
| 57 |
+
|
| 58 |
+
# Specify a CSS file that can be used to customize the user interface.
|
| 59 |
+
# The CSS file can be served from the public directory or via an external link.
|
| 60 |
+
# custom_css = "/public/test.css"
|
| 61 |
+
|
| 62 |
+
# Override default MUI light theme. (Check theme.ts)
|
| 63 |
+
[UI.theme.light]
|
| 64 |
+
#background = "#FAFAFA"
|
| 65 |
+
#paper = "#FFFFFF"
|
| 66 |
+
|
| 67 |
+
[UI.theme.light.primary]
|
| 68 |
+
#main = "#F80061"
|
| 69 |
+
#dark = "#980039"
|
| 70 |
+
#light = "#FFE7EB"
|
| 71 |
+
|
| 72 |
+
# Override default MUI dark theme. (Check theme.ts)
|
| 73 |
+
[UI.theme.dark]
|
| 74 |
+
#background = "#FAFAFA"
|
| 75 |
+
#paper = "#FFFFFF"
|
| 76 |
+
|
| 77 |
+
[UI.theme.dark.primary]
|
| 78 |
+
#main = "#F80061"
|
| 79 |
+
#dark = "#980039"
|
| 80 |
+
#light = "#FFE7EB"
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
[meta]
|
| 84 |
+
generated_by = "0.7.700"
|
code/chainlit.md
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Welcome to DL4DS Tutor! 🚀🤖
|
| 2 |
+
|
| 3 |
+
Hi there, this is an LLM chatbot designed to help answer questions on the course content, built using Langchain and Chainlit.
|
| 4 |
+
This is still very much a Work in Progress.
|
| 5 |
+
|
| 6 |
+
## Useful Links 🔗
|
| 7 |
+
|
| 8 |
+
- **Documentation:** [Chainlit Documentation](https://docs.chainlit.io) 📚
|
code/config.yml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
embedding_options:
|
| 2 |
+
embedd_files: True # bool
|
| 3 |
+
persist_directory: null # str or None
|
| 4 |
+
data_path: '../data' # str
|
| 5 |
+
db_option : 'FAISS' # str
|
| 6 |
+
db_path : 'vectorstores' # str
|
| 7 |
+
model : 'sentence-transformers/all-MiniLM-L6-v2' # str ['sentence-transformers/all-MiniLM-L6-v2', 'text-embedding-ada-002']
|
| 8 |
+
llm_params:
|
| 9 |
+
use_history: False # bool
|
| 10 |
+
llm_loader: 'openai' # str [ctransformers, openai]
|
| 11 |
+
openai_params:
|
| 12 |
+
model: 'gpt-4' # str [gpt-3.5-turbo-1106, gpt-4]
|
| 13 |
+
ctransformers_params:
|
| 14 |
+
model: "TheBloke/Llama-2-7B-Chat-GGML"
|
| 15 |
+
model_type: "llama"
|
| 16 |
+
splitter_options:
|
| 17 |
+
use_splitter: True # bool
|
| 18 |
+
split_by_token : True # bool
|
| 19 |
+
remove_leftover_delimiters: True # bool
|
| 20 |
+
remove_chunks: False # bool
|
| 21 |
+
chunk_size : 800 # int
|
| 22 |
+
chunk_overlap : 80 # int
|
| 23 |
+
chunk_separators : ["\n\n", "\n", " ", ""] # list of strings
|
| 24 |
+
front_chunks_to_remove : null # int or None
|
| 25 |
+
last_chunks_to_remove : null # int or None
|
| 26 |
+
delimiters_to_remove : ['\t', '\n', ' ', ' '] # list of strings
|
code/main.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain.document_loaders import PyPDFLoader, DirectoryLoader
|
| 2 |
+
from langchain import PromptTemplate
|
| 3 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
| 4 |
+
from langchain.vectorstores import FAISS
|
| 5 |
+
from langchain.chains import RetrievalQA
|
| 6 |
+
from langchain.llms import CTransformers
|
| 7 |
+
import chainlit as cl
|
| 8 |
+
from langchain_community.chat_models import ChatOpenAI
|
| 9 |
+
from langchain_community.embeddings import OpenAIEmbeddings
|
| 10 |
+
import yaml
|
| 11 |
+
import logging
|
| 12 |
+
from dotenv import load_dotenv
|
| 13 |
+
|
| 14 |
+
from modules.llm_tutor import LLMTutor
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Module-level logger that reports to both the console and a log file.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

# Console Handler
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

# File Handler
log_file_path = "log_file.log"  # Change this to your desired log file path
file_handler = logging.FileHandler(log_file_path)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

# Load the application configuration and build the tutor once at import time,
# so every chat session shares the same LLMTutor instance.
with open("config.yml", "r") as f:
    config = yaml.safe_load(f)
print(config)
logger.info("Config file loaded")
logger.info(f"Config: {config}")
logger.info("Creating llm_tutor instance")
llm_tutor = LLMTutor(config, logger=logger)
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# chainlit code
|
| 44 |
+
@cl.on_chat_start
async def start():
    """Set up a chat session: build the QA chain and greet the user."""
    qa_chain = llm_tutor.qa_bot()

    # Send a placeholder message, then replace its text once the chain is ready.
    greeting = cl.Message(content="Starting the bot...")
    await greeting.send()
    greeting.content = "Hey, What Can I Help You With?"
    await greeting.update()

    # Stash the chain so the on_message handler can retrieve it per session.
    cl.user_session.set("chain", qa_chain)
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@cl.on_message
async def main(message):
    """Answer a user message via the retrieval chain and attach source PDFs.

    Args:
        message: incoming chainlit message; only ``message.content`` is used.
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    cb.answer_reached = True
    res = await chain.acall(message.content, callbacks=[cb])

    # ConversationalRetrievalChain returns "answer"; RetrievalQA returns
    # "result". Support both explicitly instead of the previous bare
    # `except:`, which also swallowed unrelated errors.
    answer = res["answer"] if "answer" in res else res["result"]
    print(f"answer: {answer}")

    # Group retrieved chunks by their source document so each document is
    # attached once, with every matched page recorded.
    source_elements_dict = {}
    source_elements = []
    found_sources = []

    for source in res["source_documents"]:
        title = source.metadata["source"]
        if title not in source_elements_dict:
            source_elements_dict[title] = {
                "page_number": [source.metadata["page"]],
                "url": source.metadata["source"],
                "content": source.page_content,
            }
        else:
            source_elements_dict[title]["page_number"].append(source.metadata["page"])
            source_elements_dict[title][
                "content_" + str(source.metadata["page"])
            ] = source.page_content

    for title in source_elements_dict:
        # One PDF attachment per distinct source document; the stored path is
        # the document title (the file path recorded by the loader).
        source_elements.append(cl.Pdf(name="File", path=title))
        found_sources.append("File")

    if found_sources:
        answer += f"\nSource:{', '.join(found_sources)}"
    else:
        # fixed: was an f-string with no placeholders
        answer += "\nNo source found."

    await cl.Message(content=answer, elements=source_elements).send()
code/modules/__init__.py
ADDED
|
File without changes
|
code/modules/chat_model_loader.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_community.chat_models import ChatOpenAI
|
| 2 |
+
from langchain.llms import CTransformers
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ChatModelLoader:
    """Factory for the chat LLM selected via ``config["llm_params"]["llm_loader"]``."""

    def __init__(self, config):
        # config: parsed config.yml dict; only the "llm_params" section is read.
        self.config = config

    def load_chat_model(self):
        """Instantiate and return the configured chat model.

        Returns:
            A ``ChatOpenAI`` or ``CTransformers`` instance.

        Raises:
            ValueError: if ``llm_loader`` names an unknown backend.
        """
        llm_params = self.config["llm_params"]
        # Compare case-insensitively: config.yml documents the options as
        # lowercase ("ctransformers", "openai"), but this check previously
        # required "Ctransformers" and therefore rejected the documented value.
        loader = llm_params["llm_loader"].lower()
        if loader == "openai":
            llm = ChatOpenAI(model_name=llm_params["openai_params"]["model"])
        elif loader == "ctransformers":
            llm = CTransformers(
                model=llm_params["ctransformers_params"]["model"],
                model_type=llm_params["ctransformers_params"]["model_type"],
                max_new_tokens=512,
                temperature=0.5,
            )
        else:
            raise ValueError("Invalid LLM Loader")
        return llm
code/modules/constants.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dotenv import load_dotenv
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
# Populate os.environ from the project .env file before any key lookups.
load_dotenv()

# API Keys - Loaded from the .env file
# None when the variable is absent; consumers (e.g. OpenAIEmbeddings) receive
# it via the openai_api_key argument.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")


# Prompt Templates
# Placeholders must match the input_variables passed to PromptTemplate:
# {context}/{question} below, plus {chat_history} in the history variant.

prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context: {context}
Question: {question}

Only return the helpful answer below and nothing else.
Helpful answer:
"""

# Variant used when llm_params.use_history is enabled; identical contract but
# also interpolates the running conversation as {chat_history}.
prompt_template_with_history = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use the history to answer the question if you can.
Chat History:
{chat_history}
Context: {context}
Question: {question}

Only return the helpful answer below and nothing else.
Helpful answer:
"""
code/modules/data_loader.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import pysrt
|
| 3 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 4 |
+
from langchain.document_loaders import (
|
| 5 |
+
PyMuPDFLoader,
|
| 6 |
+
Docx2txtLoader,
|
| 7 |
+
YoutubeLoader,
|
| 8 |
+
WebBaseLoader,
|
| 9 |
+
TextLoader,
|
| 10 |
+
)
|
| 11 |
+
from langchain.schema import Document
|
| 12 |
+
from tempfile import NamedTemporaryFile
|
| 13 |
+
import logging
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DataLoader:
    def __init__(self, config):
        """
        Class for handling all data extraction and chunking
        Inputs:
            config - dictionary from yaml file, containing all important parameters
        """
        self.config = config
        self.remove_leftover_delimiters = config["splitter_options"][
            "remove_leftover_delimiters"
        ]

        # Main list of all documents
        self.document_chunks_full = []
        self.document_names = []

        splitter_options = config["splitter_options"]
        if splitter_options["use_splitter"]:
            if splitter_options["split_by_token"]:
                self.splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
                    chunk_size=splitter_options["chunk_size"],
                    chunk_overlap=splitter_options["chunk_overlap"],
                    separators=splitter_options["chunk_separators"],
                )
            else:
                self.splitter = RecursiveCharacterTextSplitter(
                    chunk_size=splitter_options["chunk_size"],
                    chunk_overlap=splitter_options["chunk_overlap"],
                    separators=splitter_options["chunk_separators"],
                )
        else:
            self.splitter = None
        # fixed: message previously said "InfoLoader"
        logger.info("DataLoader instance created")

    def get_chunks(self, uploaded_files, weblinks):
        """Extract and split all uploaded files and weblinks into chunks.

        Inputs:
            uploaded_files - list of local file paths (pdf/txt/docx/srt)
            weblinks - list of URLs; youtube links get transcript handling
        Returns:
            (document_chunks_full, document_names)
        """
        # Reset the accumulators so repeated calls don't duplicate documents.
        self.document_chunks_full = []
        self.document_names = []

        def remove_delimiters(document_chunks: list):
            """
            Helper function to remove remaining delimiters in document chunks
            """
            for chunk in document_chunks:
                for delimiter in self.config["splitter_options"][
                    "delimiters_to_remove"
                ]:
                    chunk.page_content = re.sub(delimiter, " ", chunk.page_content)
            return document_chunks

        def remove_chunks(document_chunks: list):
            """
            Helper function to remove any unwanted document chunks after splitting
            """
            # fixed: key was misspelled "front_chunk_to_remove" (config.yml
            # defines "front_chunks_to_remove"); config documents these as
            # "int or None", so treat None as 0.
            front = self.config["splitter_options"]["front_chunks_to_remove"] or 0
            end = self.config["splitter_options"]["last_chunks_to_remove"] or 0
            # Remove pages
            for _ in range(front):
                del document_chunks[0]
            for _ in range(end):
                document_chunks.pop()
            logger.info(f"\tNumber of pages after skipping: {len(document_chunks)}")
            return document_chunks

        def get_pdf(temp_file_path: str, title: str):
            """
            Function to process PDF files
            """
            loader = PyMuPDFLoader(
                temp_file_path
            )  # This loader preserves more metadata

            if self.splitter:
                document_chunks = self.splitter.split_documents(loader.load())
            else:
                document_chunks = loader.load()

            # Prefer the embedded PDF title over the file name when present.
            if "title" in document_chunks[0].metadata.keys():
                title = document_chunks[0].metadata["title"]

            logger.info(
                f"\t\tOriginal no. of pages: {document_chunks[0].metadata['total_pages']}"
            )

            return title, document_chunks

        def get_txt(temp_file_path: str, title: str):
            """
            Function to process TXT files
            """
            loader = TextLoader(temp_file_path, autodetect_encoding=True)

            if self.splitter:
                document_chunks = self.splitter.split_documents(loader.load())
            else:
                document_chunks = loader.load()

            # Update the metadata
            for chunk in document_chunks:
                chunk.metadata["source"] = title
                chunk.metadata["page"] = "N/A"

            return title, document_chunks

        def get_srt(temp_file_path: str, title: str):
            """
            Function to process SRT files
            """
            subs = pysrt.open(temp_file_path)

            text = ""
            for sub in subs:
                text += sub.text
            document_chunks = [Document(page_content=text)]

            if self.splitter:
                document_chunks = self.splitter.split_documents(document_chunks)

            # Update the metadata
            for chunk in document_chunks:
                chunk.metadata["source"] = title
                chunk.metadata["page"] = "N/A"

            return title, document_chunks

        def get_docx(temp_file_path: str, title: str):
            """
            Function to process DOCX files
            """
            loader = Docx2txtLoader(temp_file_path)

            if self.splitter:
                document_chunks = self.splitter.split_documents(loader.load())
            else:
                document_chunks = loader.load()

            # Update the metadata
            for chunk in document_chunks:
                chunk.metadata["source"] = title
                chunk.metadata["page"] = "N/A"

            return title, document_chunks

        def get_youtube_transcript(url: str):
            """
            Function to retrieve youtube transcript and process text
            """
            loader = YoutubeLoader.from_youtube_url(
                url, add_video_info=True, language=["en"], translation="en"
            )

            if self.splitter:
                document_chunks = self.splitter.split_documents(loader.load())
            else:
                document_chunks = loader.load_and_split()

            # Replace the source with title (for display in st UI later)
            for chunk in document_chunks:
                chunk.metadata["source"] = chunk.metadata["title"]
                logger.info(chunk.metadata["title"])

            # fixed: previously returned `title` from the enclosing scope,
            # which is unbound when no file was processed first. Use the
            # video title from the loaded metadata instead.
            title = document_chunks[0].metadata["title"] if document_chunks else url
            return title, document_chunks

        def get_html(url: str):
            """
            Function to process websites via HTML files
            """
            loader = WebBaseLoader(url)

            if self.splitter:
                document_chunks = self.splitter.split_documents(loader.load())
            else:
                document_chunks = loader.load_and_split()

            title = document_chunks[0].metadata["title"]
            logger.info(document_chunks[0].metadata)

            return title, document_chunks

        # Handle file by file
        for file_path in uploaded_files:
            file_name = file_path.split("/")[-1]
            file_type = file_name.split(".")[-1]

            # Handle different file types
            if file_type == "pdf":
                title, document_chunks = get_pdf(file_path, file_name)
            elif file_type == "txt":
                title, document_chunks = get_txt(file_path, file_name)
            elif file_type == "docx":
                title, document_chunks = get_docx(file_path, file_name)
            elif file_type == "srt":
                title, document_chunks = get_srt(file_path, file_name)
            else:
                # fixed: unknown extensions previously fell through, reusing
                # the previous iteration's chunks (or raising NameError on
                # the first file). Skip them with a warning instead.
                logger.warning(f"\tUnsupported file type skipped: {file_name}")
                continue

            # Additional wrangling - Remove leftover delimiters and any specified chunks
            if self.remove_leftover_delimiters:
                document_chunks = remove_delimiters(document_chunks)
            if self.config["splitter_options"]["remove_chunks"]:
                document_chunks = remove_chunks(document_chunks)

            logger.info(f"\t\tExtracted no. of chunks: {len(document_chunks)}")
            self.document_names.append(title)
            self.document_chunks_full.extend(document_chunks)

        # Handle youtube links:
        # (guard the empty list as well as the single-empty-string sentinel)
        if weblinks and weblinks[0] != "":
            logger.info(f"Splitting weblinks: total of {len(weblinks)}")

            # Handle link by link
            for link_index, link in enumerate(weblinks):
                logger.info(f"\tSplitting link {link_index+1} : {link}")
                if "youtube" in link:
                    title, document_chunks = get_youtube_transcript(link)
                else:
                    title, document_chunks = get_html(link)

                # Additional wrangling - Remove leftover delimiters and any specified chunks
                if self.remove_leftover_delimiters:
                    document_chunks = remove_delimiters(document_chunks)
                if self.config["splitter_options"]["remove_chunks"]:
                    document_chunks = remove_chunks(document_chunks)

                # consistency: use the logger, like the file loop above
                logger.info(f"\t\tExtracted no. of chunks: {len(document_chunks)}")
                self.document_names.append(title)
                self.document_chunks_full.extend(document_chunks)

        logger.info(
            f"\tNumber of document chunks extracted in total: {len(self.document_chunks_full)}\n\n"
        )

        return self.document_chunks_full, self.document_names
code/modules/embedding_model_loader.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_community.embeddings import OpenAIEmbeddings
|
| 2 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
| 3 |
+
from modules.constants import *
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class EmbeddingModelLoader:
    """Builds the embedding backend named in ``config["embedding_options"]["model"]``."""

    def __init__(self, config):
        # Full config dict; only the "embedding_options" section is consulted.
        self.config = config

    def load_embedding_model(self):
        """Return an embeddings object: OpenAI for ada-002, HuggingFace otherwise."""
        model_name = self.config["embedding_options"]["model"]
        if model_name in ["text-embedding-ada-002"]:
            return OpenAIEmbeddings(
                deployment="SL-document_embedder",
                model=model_name,
                show_progress_bar=True,
                openai_api_key=OPENAI_API_KEY,
            )
        # Fallback: local sentence-transformers model on CPU.
        return HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            model_kwargs={"device": "cpu"},
        )
code/modules/llm_tutor.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain import PromptTemplate
|
| 2 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
| 3 |
+
from langchain_community.chat_models import ChatOpenAI
|
| 4 |
+
from langchain_community.embeddings import OpenAIEmbeddings
|
| 5 |
+
from langchain.vectorstores import FAISS
|
| 6 |
+
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
|
| 7 |
+
from langchain.llms import CTransformers
|
| 8 |
+
from langchain.memory import ConversationBufferMemory
|
| 9 |
+
from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
from modules.constants import *
|
| 13 |
+
from modules.chat_model_loader import ChatModelLoader
|
| 14 |
+
from modules.vector_db import VectorDB
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class LLMTutor:
    """Wires the vector store, chat model and prompt into a retrieval-QA chain."""

    def __init__(self, config, logger=None):
        # config: parsed config.yml dict; logger: optional shared logger.
        self.config = config
        self.vector_db = VectorDB(config, logger=logger)
        # Re-embed source documents only when the config requests it.
        if self.config["embedding_options"]["embedd_files"]:
            self.vector_db.create_database()
            self.vector_db.save_database()

    def set_custom_prompt(self):
        """
        Prompt template for QA retrieval for each vectorstore
        """
        if self.config["llm_params"]["use_history"]:
            custom_prompt_template = prompt_template_with_history
            input_variables = ["context", "chat_history", "question"]
        else:
            custom_prompt_template = prompt_template
            # fixed: the history-free template has no {chat_history}
            # placeholder, so listing it as an input variable fails
            # PromptTemplate validation.
            input_variables = ["context", "question"]
        prompt = PromptTemplate(
            template=custom_prompt_template,
            input_variables=input_variables,
        )
        # prompt = QA_PROMPT

        return prompt

    # Retrieval QA Chain
    def retrieval_qa_chain(self, llm, prompt, db):
        """Build a conversational or plain retrieval chain over ``db``."""
        if self.config["llm_params"]["use_history"]:
            memory = ConversationBufferMemory(
                memory_key="chat_history", return_messages=True, output_key="answer"
            )
            qa_chain = ConversationalRetrievalChain.from_llm(
                llm=llm,
                chain_type="stuff",
                retriever=db.as_retriever(search_kwargs={"k": 3}),
                return_source_documents=True,
                memory=memory,
                combine_docs_chain_kwargs={"prompt": prompt},
            )
        else:
            qa_chain = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=db.as_retriever(search_kwargs={"k": 3}),
                return_source_documents=True,
                chain_type_kwargs={"prompt": prompt},
            )
        return qa_chain

    # Loading the model
    def load_llm(self):
        """Instantiate the chat model configured under llm_params."""
        chat_model_loader = ChatModelLoader(self.config)
        llm = chat_model_loader.load_chat_model()
        return llm

    # QA Model Function
    def qa_bot(self):
        """Assemble and return the full retrieval-QA chain."""
        db = self.vector_db.load_database()
        self.llm = self.load_llm()
        qa_prompt = self.set_custom_prompt()
        qa = self.retrieval_qa_chain(self.llm, qa_prompt, db)

        return qa

    # output function
    def final_result(self, query):
        """Run a single query through a freshly built chain and return the response.

        fixed: was defined without ``self`` and called ``qa_bot()``
        unqualified, which raised NameError when invoked as a method.
        """
        qa_result = self.qa_bot()
        response = qa_result({"query": query})
        return response
code/modules/vector_db.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import yaml
|
| 4 |
+
|
| 5 |
+
from modules.embedding_model_loader import EmbeddingModelLoader
|
| 6 |
+
from langchain.vectorstores import FAISS
|
| 7 |
+
from modules.data_loader import DataLoader
|
| 8 |
+
from modules.constants import *
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class VectorDB:
    """Build, persist, and reload a FAISS vector store from documents on disk.

    Configuration is read from ``config["embedding_options"]``:

    - ``data_path``: directory containing the source documents
    - ``db_path``: directory where the store is saved/loaded
    - ``db_option``: vector-store backend (only ``"FAISS"`` is supported)
    - ``model``: embedding model name (part of the on-disk directory name)
    """

    def __init__(self, config, logger=None):
        self.config = config
        self.db_option = config["embedding_options"]["db_option"]
        self.document_names = None

        # Set up logging to both console and a file
        if logger is None:
            self.logger = logging.getLogger(__name__)
            self.logger.setLevel(logging.INFO)

            # Guard against duplicate handlers: logging.getLogger(__name__)
            # returns a process-wide shared logger, so instantiating VectorDB
            # more than once would otherwise attach a fresh console+file
            # handler pair each time and every message would be emitted
            # multiple times.
            if not self.logger.handlers:
                formatter = logging.Formatter(
                    "%(asctime)s - %(levelname)s - %(message)s"
                )

                # Console Handler
                console_handler = logging.StreamHandler()
                console_handler.setLevel(logging.INFO)
                console_handler.setFormatter(formatter)
                self.logger.addHandler(console_handler)

                # File Handler (mode="w" overwrites the log on every run)
                log_file_path = "vector_db.log"  # Change this to your desired log file path
                file_handler = logging.FileHandler(log_file_path, mode="w")
                file_handler.setLevel(logging.INFO)
                file_handler.setFormatter(formatter)
                self.logger.addHandler(file_handler)
        else:
            self.logger = logger

        self.logger.info("VectorDB instance instantiated")

    def _db_dir(self):
        """Return the on-disk directory for this (db_option, model) pair."""
        opts = self.config["embedding_options"]
        return os.path.join(
            opts["db_path"],
            "db_" + opts["db_option"] + "_" + opts["model"],
        )

    def load_files(self):
        """Return full paths of every entry in the configured data directory.

        NOTE(review): no filtering is done — sub-directories and files of any
        extension in ``data_path`` are all handed to the data loader.
        """
        data_path = self.config["embedding_options"]["data_path"]
        return [os.path.join(data_path, file) for file in os.listdir(data_path)]

    def create_embedding_model(self):
        """Instantiate the embedding model declared in the config."""
        self.logger.info("Creating embedding function")
        self.embedding_model_loader = EmbeddingModelLoader(self.config)
        self.embedding_model = self.embedding_model_loader.load_embedding_model()

    def initialize_database(self, document_chunks: list, document_names: list):
        """Create the in-memory vector store from pre-chunked documents.

        Args:
            document_chunks: chunked ``Document`` objects to embed and index.
            document_names: names of the source documents (kept for
                provenance; the original code accepted but discarded these).

        Raises:
            ValueError: if ``db_option`` is not a supported backend.
        """
        self.logger.info("Initializing vector_db")
        self.logger.info("\tUsing {} as db_option".format(self.db_option))
        self.document_names = document_names
        if self.db_option == "FAISS":
            self.vector_db = FAISS.from_documents(
                documents=document_chunks, embedding=self.embedding_model
            )
        else:
            # Previously an unknown option fell through silently, leaving
            # self.vector_db unset and causing an AttributeError only later
            # at save time. Fail fast with a clear message instead.
            raise ValueError("Unsupported db_option: {}".format(self.db_option))
        self.logger.info("Completed initializing vector_db")

    def create_database(self):
        """End-to-end build: load files, chunk them, embed, and index."""
        data_loader = DataLoader(self.config)
        self.logger.info("Loading data")
        files = self.load_files()
        document_chunks, document_names = data_loader.get_chunks(files, [""])
        self.logger.info("Completed loading data")

        self.create_embedding_model()
        self.initialize_database(document_chunks, document_names)

    def save_database(self):
        """Persist the vector store to disk under ``db_path``."""
        self.vector_db.save_local(self._db_dir())
        self.logger.info("Saved database")

    def load_database(self):
        """Load a previously saved vector store from disk and return it.

        The configured embedding model must match the one the store was
        built with, or similarity search results will be meaningless.
        """
        self.create_embedding_model()
        self.vector_db = FAISS.load_local(self._db_dir(), self.embedding_model)
        self.logger.info("Loaded database")
        return self.vector_db
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
if __name__ == "__main__":
    # Build the vector store from the configured data directory and persist it.
    with open("config.yml", "r") as config_file:
        cfg = yaml.safe_load(config_file)
    print(cfg)
    db_builder = VectorDB(cfg)
    db_builder.create_database()
    db_builder.save_database()
|
code/vectorstores/db_FAISS_sentence-transformers/all-MiniLM-L6-v2/index.faiss
ADDED
|
Binary file (6.19 kB). View file
|
|
|
code/vectorstores/db_FAISS_sentence-transformers/all-MiniLM-L6-v2/index.pkl
ADDED
|
Binary file (9.21 kB). View file
|
|
|
code/vectorstores/db_FAISS_text-embedding-ada-002/index.faiss
ADDED
|
Binary file (24.6 kB). View file
|
|
|
code/vectorstores/db_FAISS_text-embedding-ada-002/index.pkl
ADDED
|
Binary file (9.21 kB). View file
|
|
|
data/webpage.pdf
ADDED
|
Binary file (51.3 kB). View file
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit==1.29.0
|
| 2 |
+
PyYAML==6.0.1
|
| 3 |
+
pysrt==1.1.2
|
| 4 |
+
langchain==0.0.353
|
| 5 |
+
tiktoken==0.5.2
|
| 6 |
+
streamlit-chat==0.1.1
|
| 7 |
+
pypdf==3.17.4
|
| 8 |
+
sentence-transformers==2.2.2
|
| 9 |
+
faiss-cpu==1.7.4
|
| 10 |
+
ctransformers==0.2.27
|
| 11 |
+
python-dotenv==1.0.0
|
| 12 |
+
openai==1.6.1
|
| 13 |
+
pymupdf==1.23.8
|