SeaLLM / app.py
nurqoneah's picture
Update app.py
e708fb6 verified
raw
history blame
6.62 kB
import streamlit as st
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
import warnings
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import os
from dotenv import load_dotenv
from langchain_huggingface import HuggingFacePipeline
warnings.filterwarnings("ignore")
load_dotenv()
# Constants and configurations
APP_TITLE = "πŸ’Š Asisten Kesehatan Feminacare"
INITIAL_MESSAGE = """Halo! πŸ‘‹ Saya adalah asisten kesehatan feminacare yang siap membantu Anda dengan informasi seputar kesehatan wanita.
Silakan ajukan pertanyaan apa saja dan saya akan membantu Anda dengan informasi yang akurat."""
MODEL_NAME = "SeaLLMs/SeaLLM-13B-Chat"
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
TOP_K_DOCS = 5
def initialize_models():
"""Initialize the embedding model and vector store"""
data_directory = os.path.join(os.path.dirname(__file__), "vector_db_dir")
embedding_model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
vector_store = Chroma(
embedding_function=embedding_model,
persist_directory=data_directory
)
return vector_store
def create_llm():
"""Initialize the language model with auto device mapping"""
# model = AutoModelForCausalLM.from_pretrained(
# MODEL_NAME,
# device_map="auto",
# trust_remote_code=True
# )
# tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# # Get terminators for the model
# terminators = [tokenizer.eos_token_id]
# if hasattr(tokenizer, 'convert_tokens_to_ids'):
# try:
# terminators.append(tokenizer.convert_tokens_to_ids("<|eot_id|>"))
# except:
# pass
# text_generation_pipeline = pipeline(
# model=model,
# tokenizer=tokenizer,
# task="text-generation",
# temperature=0.2,
# do_sample=True,
# repetition_penalty=1.1,
# return_full_text=False,
# max_new_tokens=200,
# eos_token_id=terminators,
# )
# return HuggingFacePipeline(pipeline=text_generation_pipeline)
return HuggingFaceHub(
repo_id=MODEL_NAME,
task="text-generation",
model_kwargs={
"temperature": 0.7, # Balanced between creativity and accuracy
"max_new_tokens": 1024,
"top_p": 0.9,
"frequency_penalty": 0.5
}
)
PROMPT_TEMPLATE = """
Anda adalah asisten kesehatan profesional dengan nama Feminacare.
Berikan informasi yang akurat, jelas, dan bermanfaat berdasarkan konteks yang tersedia.
Context yang tersedia:
{context}
Chat historyt:
{chat_history}
Question: {question}
Instruksi untuk menjawab:
1. Berikan jawaban yang LENGKAP dan TERSTRUKTUR
2. Selalu sertakan SUMBER informasi dari konteks yang diberikan
3. Jika informasi tidak tersedia dalam konteks, katakan: "Maaf, saya tidak memiliki informasi yang cukup untuk menjawab pertanyaan tersebut secara akurat. Silakan konsultasi dengan tenaga kesehatan untuk informasi lebih lanjut."
4. Gunakan bahasa yang mudah dipahami
5. Jika relevan, berikan poin-poin penting menggunakan format yang rapi
6. Akhiri dengan anjuran untuk konsultasi dengan tenaga kesehatan jika diperlukan
Answer:
"""
def setup_qa_chain(vector_store):
"""Set up the QA chain with improved configuration"""
memory = ConversationBufferMemory(
memory_key="chat_history",
return_messages=True,
output_key='answer'
)
custom_prompt = PromptTemplate(
template=PROMPT_TEMPLATE,
input_variables=["context", "question", "chat_history"]
)
return ConversationalRetrievalChain.from_llm(
llm=create_llm(),
retriever=vector_store.as_retriever(),
memory=memory,
combine_docs_chain_kwargs={"prompt": custom_prompt},
return_source_documents=True,
)
def initialize_session_state():
"""Initialize Streamlit session state"""
if "messages" not in st.session_state:
st.session_state.messages = [
{"role": "assistant", "content": INITIAL_MESSAGE}
]
if "qa_chain" not in st.session_state:
vector_store = initialize_models()
st.session_state.qa_chain = setup_qa_chain(vector_store)
def clear_chat():
"""Clear chat history and memory"""
st.session_state.messages = [
{"role": "assistant", "content": INITIAL_MESSAGE}
]
st.session_state.qa_chain.memory.clear()
def create_ui():
"""Create the Streamlit UI"""
st.set_page_config(page_title=APP_TITLE, page_icon="πŸ’Š")
# Custom CSS for better UI
st.markdown("""
<style>
.stApp {
max-width: 1200px;
margin: 0 auto;
}
.stChat {
border-radius: 10px;
padding: 20px;
margin: 10px 0;
}
</style>
""", unsafe_allow_html=True)
st.title(APP_TITLE)
# Sidebar
with st.sidebar:
st.title("ℹ️ Tentang Aplikasi")
st.markdown("""
Asisten digital ini dirancang untuk membantu Anda untuk berkonsultasi tentang kesehatan wanita.
_Catatan: Informasi yang diberikan bersifat umum. Selalu konsultasikan dengan tenaga kesehatan untuk saran yang lebih spesifik._
""")
st.button('πŸ—‘οΈ Hapus Riwayat Chat', on_click=clear_chat)
def handle_user_input(prompt):
"""Handle user input and generate response"""
with st.spinner("Sedang menyiapkan jawaban..."):
response = st.session_state.qa_chain({"question": prompt})
return response["answer"]
def main():
initialize_session_state()
create_ui()
# Display chat messages
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# Handle user input
if prompt := st.chat_input("Ketik pertanyaan Anda di sini..."):
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.markdown(prompt)
response = handle_user_input(prompt)
if response:
with st.chat_message("assistant"):
st.markdown(response)
st.session_state.messages.append({"role": "assistant", "content": response})
if __name__ == "__main__":
main()