disable conversational memory with zephyr
streamlit_app.py  CHANGED  +12 -7
@@ -54,7 +54,7 @@ if 'uploaded' not in st.session_state:
     st.session_state['uploaded'] = False
 
 if 'memory' not in st.session_state:
-    st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
+    st.session_state['memory'] = None
 
 if 'binary' not in st.session_state:
     st.session_state['binary'] = None
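Note: these if ... not in st.session_state guards run on every Streamlit rerun, so they set one-time defaults rather than per-rerun resets. The default for 'memory' is now None; the actual memory object is created per model inside init_qa() (see the next hunks). A minimal standalone illustration of the idiom:

    import streamlit as st

    # Streamlit re-executes the whole script on every interaction; the guard
    # turns this assignment into a one-time default instead of a reset.
    if 'memory' not in st.session_state:
        st.session_state['memory'] = None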
@@ -117,12 +117,14 @@ def clear_memory():
 def init_qa(model, api_key=None):
     ## For debug add: callbacks=[PromptLayerCallbackHandler(pl_tags=["langchain", "chatgpt", "document-qa"])])
     if model == 'chatgpt-3.5-turbo':
+        st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
         if api_key:
             chat = ChatOpenAI(model_name="gpt-3.5-turbo",
                               temperature=0,
                               openai_api_key=api_key,
                               frequency_penalty=0.1)
             embeddings = OpenAIEmbeddings(openai_api_key=api_key)
+
         else:
             chat = ChatOpenAI(model_name="gpt-3.5-turbo",
                               temperature=0,
@@ -134,11 +136,13 @@ def init_qa(model, api_key=None):
                               model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048})
         embeddings = HuggingFaceEmbeddings(
             model_name="all-MiniLM-L6-v2")
+        st.session_state['memory'] = ConversationBufferWindowMemory(k=4)
 
     elif model == 'zephyr-7b-beta':
         chat = HuggingFaceHub(repo_id="HuggingFaceH4/zephyr-7b-beta",
                               model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048})
         embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+        st.session_state['memory'] = None
     else:
         st.error("The model was not loaded properly. Try reloading. ")
         st.stop()
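For context: ConversationBufferWindowMemory(k=4) is LangChain's sliding-window chat memory, matching the "4 most recent messages" wording in the reset button's help text below; setting it to None in the zephyr-7b-beta branch is what actually disables memory for that model. A standalone sketch of the window behaviour (assuming the classic langchain.memory import this API comes from):

    from langchain.memory import ConversationBufferWindowMemory

    memory = ConversationBufferWindowMemory(k=4)
    for i in range(6):
        # each save_context() call records one human/AI exchange
        memory.save_context({"input": f"question {i}"}, {"output": f"answer {i}"})

    # Only the last k=4 exchanges survive: exchanges 0 and 1 have already
    # dropped out of the history that would be injected into the next prompt.
    print(memory.load_memory_variables({})["history"])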
@@ -255,7 +259,8 @@ with st.sidebar:
         'Reset chat memory.',
         key="reset-memory-button",
         on_click=clear_memory,
-        help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.")
+        help="Clear the conversational memory. Currently implemented to retrain the 4 most recent messages.",
+        disabled=model in st.session_state['rqa'] and st.session_state['rqa'][model].memory is None)
 
     left_column, right_column = st.columns([1, 1])
 
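The new disabled= expression relies on short-circuit evaluation: the membership test guards the st.session_state['rqa'][model] lookup, so a model that has not been initialised yet can never raise a KeyError (incidentally, "retrain" in the help string looks like a typo for "retain"). A runnable sketch of the guard, with a stand-in Engine class for the app's QA engine:

    class Engine:
        # stand-in for the app's QA engine; only .memory matters here
        def __init__(self, memory):
            self.memory = memory

    rqa = {'zephyr-7b-beta': Engine(memory=None)}

    # Unknown model: the membership test fails first, the lookup never runs.
    assert not ('chatgpt-3.5-turbo' in rqa and rqa['chatgpt-3.5-turbo'].memory is None)

    # Zephyr runs without memory, so the reset button would be greyed out.
    assert 'zephyr-7b-beta' in rqa and rqa['zephyr-7b-beta'].memory is None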
@@ -267,8 +272,8 @@ with right_column:
     ":warning: Do not upload sensitive data. We **temporarily** store text from the uploaded PDF documents solely for the purpose of processing your request, and we **do not assume responsibility** for any subsequent use or handling of the data submitted to third parties LLMs.")
 
     uploaded_file = st.file_uploader("Upload an article",
-
-
+                                     type=("pdf", "txt"),
+                                     on_change=new_file,
                                      disabled=st.session_state['model'] is not None and st.session_state['model'] not in
                                               st.session_state['api_keys'],
                                      help="The full-text is extracted using Grobid. ")
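The uploader now restricts files to PDF/TXT and registers an on_change callback. In Streamlit, on_change callbacks run before the script reruns, which makes them the natural place to reset per-document state; a minimal sketch (the body of new_file below is illustrative, the app defines its own new_file elsewhere in streamlit_app.py):

    import streamlit as st

    def new_file():
        # Illustrative reset only; the app's real new_file() is defined
        # elsewhere. on_change runs before the rerun, so the rest of the
        # script already sees the cleared state.
        st.session_state['loaded_embeddings'] = False
        st.session_state['doc_id'] = None

    uploaded_file = st.file_uploader("Upload an article",
                                     type=("pdf", "txt"),
                                     on_change=new_file,
                                     help="The full-text is extracted using Grobid.")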
@@ -335,8 +340,8 @@ if uploaded_file and not st.session_state.loaded_embeddings:
 
     st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(tmp_file.name,
                                                                                                 chunk_size=chunk_size,
-
-
+                                                                                                perc_overlap=0.1,
+                                                                                                include_biblio=True)
     st.session_state['loaded_embeddings'] = True
     st.session_state.messages = []
 
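create_memory_embeddings() belongs to the app's own engine, so the exact semantics of perc_overlap are not visible in this diff. Read as the usual fractional overlap (a hypothetical illustration, not the app's implementation), it would set the stride between consecutive chunks:

    # Hypothetical illustration: a fractional overlap typically translates
    # into the stride between chunk starts. NOT the app's actual code.
    def chunk_text(text: str, chunk_size: int, perc_overlap: float) -> list[str]:
        step = max(1, int(chunk_size * (1 - perc_overlap)))
        return [text[i:i + chunk_size] for i in range(0, len(text), step)]

    chunks = chunk_text("a" * 1000, chunk_size=250, perc_overlap=0.1)
    # step = 225, so each chunk shares 25 characters (10%) with the next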
@@ -389,7 +394,7 @@ with right_column:
     elif mode == "LLM":
         with st.spinner("Generating response..."):
             _, text_response = st.session_state['rqa'][model].query_document(question, st.session_state.doc_id,
-
+                                                                             context_size=context_size)
 
         if not text_response:
             st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
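context_size is likewise app-specific; in retrieval-augmented QA it typically caps how many retrieved chunks are placed in the prompt. A generic, hypothetical sketch of that pattern (not the app's query_document; retriever and llm are assumed to follow the classic LangChain interfaces):

    # Hypothetical RAG sketch: context_size caps the number of best-matching
    # chunks that reach the LLM prompt.
    def query_document(question, retriever, llm, context_size=4):
        docs = retriever.get_relevant_documents(question)[:context_size]
        context = "\n\n".join(doc.page_content for doc in docs)
        prompt = f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
        return llm.predict(prompt)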