Spaces:
Running
Running
Commit
Β·
317f434
1
Parent(s):
8515a17
Update app.py
Browse files
app.py
CHANGED
@@ -12,11 +12,7 @@ from streamlit_chat import message
|
|
12 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
|
13 |
import torch
|
14 |
|
15 |
-
|
16 |
-
CHECKPOINT = "MBZUAI/LaMini-T5-738M"
|
17 |
-
TOKENIZER = AutoTokenizer.from_pretrained(CHECKPOINT)
|
18 |
-
BASE_MODEL = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT, device_map=torch.device('cpu'), torch_dtype=torch.float32)
|
19 |
-
|
20 |
|
21 |
def process_answer(instruction, qa_chain):
|
22 |
response = ''
|
@@ -50,7 +46,11 @@ def data_ingestion():
|
|
50 |
|
51 |
|
52 |
@st.cache_resource
|
53 |
-
def initialize_qa_chain():
|
|
|
|
|
|
|
|
|
54 |
pipe = pipeline(
|
55 |
'text2text-generation',
|
56 |
model=BASE_MODEL,
|
@@ -101,7 +101,10 @@ def display_conversation(history):
|
|
101 |
|
102 |
|
103 |
def main():
|
104 |
-
|
|
|
|
|
|
|
105 |
st.markdown("<h1 style='text-align: center; color: blue;'>Custom PDF Chatbot π¦π </h1>", unsafe_allow_html=True)
|
106 |
st.markdown("<h2 style='text-align: center; color:red;'>Upload your PDF, and Ask Questions π</h2>", unsafe_allow_html=True)
|
107 |
|
@@ -125,6 +128,7 @@ def main():
|
|
125 |
pdf_view = display_pdf(filepath)
|
126 |
|
127 |
with col2:
|
|
|
128 |
with st.spinner('Embeddings are in process...'):
|
129 |
ingested_data = data_ingestion()
|
130 |
st.success('Embeddings are created successfully!')
|
@@ -140,7 +144,7 @@ def main():
|
|
140 |
|
141 |
# Search the database for a response based on user input and update session state
|
142 |
if user_input:
|
143 |
-
answer = process_answer({'query': user_input}, initialize_qa_chain())
|
144 |
st.session_state["past"].append(user_input)
|
145 |
response = answer
|
146 |
st.session_state["generated"].append(response)
|
|
|
12 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
|
13 |
import torch
|
14 |
|
15 |
+
st.set_page_config(layout="wide")
|
|
|
|
|
|
|
|
|
16 |
|
17 |
def process_answer(instruction, qa_chain):
|
18 |
response = ''
|
|
|
46 |
|
47 |
|
48 |
@st.cache_resource
|
49 |
+
def initialize_qa_chain(selected_model):
|
50 |
+
# Constants
|
51 |
+
CHECKPOINT = selected_model
|
52 |
+
TOKENIZER = AutoTokenizer.from_pretrained(CHECKPOINT)
|
53 |
+
BASE_MODEL = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT, device_map=torch.device('cpu'), torch_dtype=torch.float32)
|
54 |
pipe = pipeline(
|
55 |
'text2text-generation',
|
56 |
model=BASE_MODEL,
|
|
|
101 |
|
102 |
|
103 |
def main():
|
104 |
+
# Add a sidebar for model selection
|
105 |
+
model_options = ["MBZUAI/LaMini-T5-738M", "google/flan-t5-base", "google/flan-t5-small"]
|
106 |
+
selected_model = st.sidebar.selectbox("Select Model", model_options)
|
107 |
+
|
108 |
st.markdown("<h1 style='text-align: center; color: blue;'>Custom PDF Chatbot π¦π </h1>", unsafe_allow_html=True)
|
109 |
st.markdown("<h2 style='text-align: center; color:red;'>Upload your PDF, and Ask Questions π</h2>", unsafe_allow_html=True)
|
110 |
|
|
|
128 |
pdf_view = display_pdf(filepath)
|
129 |
|
130 |
with col2:
|
131 |
+
st.success(f'model selected successfully: {selected_model}')
|
132 |
with st.spinner('Embeddings are in process...'):
|
133 |
ingested_data = data_ingestion()
|
134 |
st.success('Embeddings are created successfully!')
|
|
|
144 |
|
145 |
# Search the database for a response based on user input and update session state
|
146 |
if user_input:
|
147 |
+
answer = process_answer({'query': user_input}, initialize_qa_chain(selected_model))
|
148 |
st.session_state["past"].append(user_input)
|
149 |
response = answer
|
150 |
st.session_state["generated"].append(response)
|