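"""Nexus AI TCM: a Streamlit chatbot for Traditional Chinese Medicine questions.

Each query is routed by zero-shot intent classification to one of three paths:
a Folium map of nearby clinics, a Shopee purchase link, or a FAISS-backed
RetrievalQA answer generated by a local quantized Llama 2 model.

Run locally with (assuming this file is saved as app.py):
    streamlit run app.py
"""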
import streamlit as st
from urllib.parse import quote_plus
from langchain import PromptTemplate
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import CTransformers
from langchain.chains import RetrievalQA
import geocoder
from geopy.distance import geodesic
import pandas as pd
import folium
from streamlit_folium import folium_static
from transformers import pipeline
import logging
#-----------------
# Path to the prebuilt FAISS vector store
#-----------------
DB_FAISS_PATH = 'vectorstores/db_faiss'
#-----------------
# Zero-shot classifier used to detect whether a query calls for a normal
# text chat, the nearest-clinic map, or a shopping link
#-----------------
classifier = pipeline("zero-shot-classification")
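# NOTE: with no `model=` argument, transformers falls back to its default
# zero-shot checkpoint (facebook/bart-large-mnli at the time of writing);
# pinning it explicitly keeps behaviour reproducible, e.g.:
#   classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")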
#-----------------
# Set up logging (mostly for debugging purposes)
#-----------------
logging.basicConfig(filename='app.log', level=logging.DEBUG, format='%(asctime)s %(message)s')
custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Question: {question}
Only return the helpful answer below and nothing else.
Helpful answer:
"""
def set_custom_prompt():
    prompt = PromptTemplate(template=custom_prompt_template,
                            input_variables=['context', 'question'])
    return prompt
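# Illustrative usage (hypothetical values):
#   set_custom_prompt().format(context="Ginger warms the stomach.",
#                              question="What does ginger do?")
# returns custom_prompt_template with both variables filled in.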
def retrieval_qa_chain(llm, prompt, db):
    qa_chain = RetrievalQA.from_chain_type(llm=llm,
                                           chain_type='stuff',
                                           retriever=db.as_retriever(search_kwargs={'k': 2}),
                                           return_source_documents=True,
                                           chain_type_kwargs={'prompt': prompt}
                                           )
    return qa_chain
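# The 'stuff' chain type stuffs the k=2 retrieved chunks verbatim into the
# prompt's {context}; raising k widens recall at the cost of a longer prompt.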
#-----------------
# Load the quantized Llama 2 chat LLM from Hugging Face
#-----------------
def load_llm():
    llm = CTransformers(
        model="TheBloke/Llama-2-7B-Chat-GGML",
        model_type="llama",
        max_new_tokens=512,
        temperature=0.5
    )
    return llm
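# NOTE: CTransformers runs the quantized GGML weights on CPU; on first use it
# downloads them from the Hugging Face Hub, so the first query is slow.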
#-----------------
# qa_bot does one of three things, depending on the inferred context:
# 1. renders a Folium map of nearby clinics (loaded from a CSV dataset)
# 2. returns a Shopee link when the user wants to buy a product
# 3. otherwise answers via the RetrievalQA chain as a normal chat bubble
#-----------------
def qa_bot(query, context=""):
    logging.info(f"Received query: {query}, Context: {context}")
    if context in ["nearest clinic", "nearest TCM clinic", "nearest TCM doctor", "near me", "nearest to me"]:
        #-----------
        # Loads map
        #-----------
        logging.info("Context matched for nearest TCM clinic.")
        # Get the user's current location from their IP address
        g = geocoder.ip('me')
        user_lat, user_lon = g.latlng
        # Load clinic locations from the CSV file
        locations_df = pd.read_csv("dataset/locations.csv")
        # Keep only locations within 5km of the user's current position
        filtered_locations_df = locations_df[locations_df.apply(
            lambda row: geodesic((user_lat, user_lon),
                                 (row['latitude'], row['longitude'])).kilometers <= 5,
            axis=1)]
        # Create a map centered at the user's location
        my_map = folium.Map(location=[user_lat, user_lon], zoom_start=12)
        # Add markers with custom tooltips for the filtered locations
        for index, location in filtered_locations_df.iterrows():
            folium.Marker(
                location=[location['latitude'], location['longitude']],
                tooltip=(f"{location['name']}<br>Reviews: {location['Stars_review']}"
                         f"<br>Avg Price $: {location['Price']}"
                         f"<br>Contact No: {location['Contact']}")
            ).add_to(my_map)
        # Display the map
        folium_static(my_map)
        return "[Map of clinic locations within 5km of your current location]"
    elif context in ["buy", "Ointment", "Hong You", "Feng You", "Fengyou", "Po chai pills"]:
        #-----------
        # Loads shopee link
        #-----------
        logging.info("Context matched for buying.")
        # Build a shopee.sg search link from the inferred product name
        shopee_link = f"<a href='https://shopee.sg/search?keyword={quote_plus(context)}'>this Shopee link</a>"
        return f"You may purchase {context} at {shopee_link}!"
    else:
        #-----------
        # Loads normal chat bubble
        #-----------
        logging.info("Context did not match nearest TCM clinic or buying; using RetrievalQA.")
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
                                           model_kwargs={'device': 'cpu'})
        db = FAISS.load_local(DB_FAISS_PATH, embeddings)
        llm = load_llm()
        qa_prompt = set_custom_prompt()
        qa = retrieval_qa_chain(llm, qa_prompt, db)
        # Run the retrieval QA chain on the user's query
        response = qa({'query': query})
        return response['result']
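# Illustrative calls (hypothetical queries and pre-inferred contexts):
#   qa_bot("What herbs help a sore throat?", "herbs")  -> RetrievalQA answer
#   qa_bot("Where can I buy Feng You?", "Feng You")    -> Shopee search link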
def add_vertical_space(spaces=1):
    for _ in range(spaces):
        st.markdown("---")
def main():
    st.set_page_config(page_title="Ask me anything about TCM")
    with st.sidebar:
        st.title('Welcome to Nexus AI TCM!')
        st.markdown('''
<style>
[data-testid=stSidebar] {
    background-color: #ffffff;
}
</style>
<img src="https://huggingface.co/spaces/mathslearn/chatbot_test_streamlit/resolve/main/logo.jpeg" width=200>
''', unsafe_allow_html=True)
        add_vertical_space(1)  # Adjust the number of spaces as needed
    st.title("Nexus AI TCM")
    st.markdown(
        """
<style>
.chat-container {
    display: flex;
    flex-direction: column;
    height: 400px;
    overflow-y: auto;
    padding: 10px;
    color: white; /* Font color */
}
.user-bubble {
    background-color: #007bff; /* Blue color for user */
    align-self: flex-end;
    border-radius: 10px;
    padding: 8px;
    margin: 5px;
    max-width: 70%;
    word-wrap: break-word;
}
.bot-bubble {
    background-color: #363636; /* Slightly lighter background color */
    align-self: flex-start;
    border-radius: 10px;
    padding: 8px;
    margin: 5px;
    max-width: 70%;
    word-wrap: break-word;
}
</style>
""",
        unsafe_allow_html=True)
    conversation = st.session_state.get("conversation", [])
    if "my_text" not in st.session_state:
        st.session_state.my_text = ""
    st.text_input("Enter text here", key="widget", on_change=submit)
    query = st.session_state.my_text
    if st.button("Ask"):
        if query:
            with st.spinner("Processing your question..."):  # Display the processing message
                conversation.append({"role": "user", "message": query})
                # Infer the context, then route the query through qa_bot
                answer = qa_bot(query, infer_context(query))
                conversation.append({"role": "bot", "message": answer})
                st.session_state.conversation = conversation
        else:
            st.warning("Please input a question.")
    # Display the conversation history
    chat_container = st.empty()
    chat_bubbles = ''.join([f'<div class="{c["role"]}-bubble">{c["message"]}</div>' for c in conversation])
    chat_container.markdown(f'<div class="chat-container">{chat_bubbles}</div>', unsafe_allow_html=True)
def submit():
    # Streamlit only allows a widget's session state to be changed inside its
    # own callback, so copy the text out and clear the input box here.
    st.session_state.my_text = st.session_state.widget
    st.session_state.widget = ""
#-----------
# Inferring the context
#-----------
def infer_context(query):
    """
    Infer the context from the user's query via zero-shot classification.
    Modify the label list to suit your context-detection needs.
    """
    labels = ["TCM", "sick", "herbs", "traditional", "nearest clinic", "nearest TCM clinic",
              "nearest TCM doctor", "near me", "nearest to me", "Ointment", "Hong You",
              "Feng You", "Fengyou", "Po chai pills"]
    result = classifier(query, labels)
    predicted_label = result["labels"][0]
    return predicted_label
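# Illustrative output shape of the zero-shot pipeline (values made up):
#   {"sequence": "...", "labels": ["nearest clinic", ...], "scores": [0.62, ...]}
# Labels come back sorted by score, so result["labels"][0] is the top intent.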
if __name__ == "__main__":
    main()