Kathirsci committed on
Commit
9b35aa3
·
verified ·
1 Parent(s): e6f242e

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

import gradio as gr
import wandb
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Chroma
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough
15
+
16
# Initialize the chatbot: load the knowledge base Data/0.txt .. Data/11.txt.
folder_path = "Data"
loaders = [TextLoader(os.path.join(folder_path, f"{i}.txt")) for i in range(12)]
# Flatten every loader's documents into a single list for the vector store.
docs = []
for loader in loaders:
    docs.extend(loader.load())
25
# Hugging Face Inference API token, read from the environment
# (typically configured as a Space secret named HF_TOKEN).
HF_TOKEN = os.getenv("HF_TOKEN")
# Remote sentence-embedding model used to vectorise the loaded documents.
embeddings = HuggingFaceInferenceAPIEmbeddings(
    api_key=HF_TOKEN,
    model_name="sentence-transformers/all-mpnet-base-v2"
)
# Chroma vector store built in memory from the documents loaded above.
vectordb = Chroma.from_documents(
    documents=docs,
    embedding=embeddings
)
34
# Hosted chat LLM used both to rewrite follow-up questions and to answer.
# NOTE(review): HuggingFaceHub is deprecated in newer langchain releases in
# favour of HuggingFaceEndpoint — confirm the pinned version still ships it.
llm = HuggingFaceHub(
    repo_id="google/gemma-1.1-7b-it",
    task="text-generation",
    model_kwargs={
        "max_new_tokens": 512,      # cap on generated tokens per reply
        "top_k": 5,                 # sample only from the 5 most likely tokens
        "temperature": 0.1,         # near-deterministic output
        "repetition_penalty": 1.03, # mild discouragement of repeated text
    },
    huggingfacehub_api_token=HF_TOKEN
)
45
# Answering prompt: the retrieved context is stuffed into {context} and the
# (possibly rewritten) user question into {question}.
template = """
You are a Mental Health Chatbot. Help the user with their mental health concerns.
Use the context below to answer the questions {context}
Question: {question}
Helpful Answer:"""

# from_template infers the same input variables ("context", "question")
# that the original constructor listed explicitly.
QA_CHAIN_PROMPT = PromptTemplate.from_template(template)
52
# Conversation memory exposing past turns under the "chat_history" key.
# NOTE(review): `memory` and `qa` below are never used by the Gradio handler —
# predict() drives `rag_chain` with its own module-level `chat_history` list —
# so this ConversationalRetrievalChain looks like dead code; confirm before
# removing.
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True
)
# Similarity-search retriever over the Chroma vector store.
retriever = vectordb.as_retriever()
qa = ConversationalRetrievalChain.from_llm(
    llm,
    retriever=retriever,
    memory=memory,
)
62
# Instruction that turns a history-dependent follow-up into a question that
# stands on its own (no answering, only reformulation).
contextualize_q_system_prompt = """
Given a chat history and the latest user question
which might reference context in the chat history,
formulate a standalone question
which can be understood without the chat history.
Do NOT answer the question, just reformulate it if needed and otherwise return it as is."""

# system instruction + prior turns + the latest user question.
_contextualize_messages = [
    ("system", contextualize_q_system_prompt),
    MessagesPlaceholder(variable_name="chat_history"),
    ("human", "{question}"),
]
contextualize_q_prompt = ChatPromptTemplate.from_messages(_contextualize_messages)

# prompt -> LLM -> plain string: produces the standalone question text.
contextualize_q_chain = contextualize_q_prompt | llm | StrOutputParser()
76
def contextualized_question(input: dict):
    """Routing step for the RAG chain.

    Returns the rewrite chain when prior chat history exists (so the
    question gets reformulated standalone), otherwise passes the raw
    question string straight through.
    """
    has_history = bool(input.get("chat_history"))
    return contextualize_q_chain if has_history else input["question"]
81
# Full RAG pipeline:
#  1. `context` is filled by piping the (optionally rewritten) question into
#     the retriever; `question` passes through unchanged.
#  2. QA_CHAIN_PROMPT formats {context} and {question}.
#  3. The prompt is sent to the LLM; the chain's output is its raw completion.
rag_chain = (
    RunnablePassthrough.assign(
        context=contextualized_question | retriever
    )
    | QA_CHAIN_PROMPT
    | llm
)
88
# Weights & Biases run tracking; the API key is read from the "key" env var.
wandb.login(key=os.getenv("key"))
# NOTE(review): these tracing variables are set after the chains are built —
# confirm the LangChain wandb tracer still picks them up at invoke time.
os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
os.environ["WANDB_PROJECT"] = "Mental_Health_ChatBot"
print("Welcome to the Mental Health Chatbot. How can I help you today?")
# Model-side conversation memory, shared across all predict() calls.
chat_history = []
93
def predict(message, history):
    """Gradio chat callback: answer `message` with the RAG chain.

    Args:
        message: Latest user utterance from the chat box.
        history: Gradio-managed UI history (unused; the module-level
            `chat_history` list is the model-side memory instead).

    Returns:
        The model's reply, trimmed to start at the "Answer" marker when
        the model emits one, otherwise the full reply.
    """
    ai_msg = rag_chain.invoke({"question": message, "chat_history": chat_history})
    # Record the turn as proper message objects so MessagesPlaceholder gets a
    # homogeneous list (the original appended a bare string for the AI turn,
    # and referenced HumanMessage without importing it — NameError at runtime).
    chat_history.extend([HumanMessage(content=message), AIMessage(content=ai_msg)])
    # str.find returns -1 when "Answer" is absent; slicing with -1 would keep
    # only the last character, so fall back to the full reply in that case.
    idx = ai_msg.find("Answer")
    return ai_msg[idx:] if idx != -1 else ai_msg
98
+ gr.ChatInterface(predict).launch()