SeaLLM / app.py
nurqoneah's picture
Update app.py
87375d5 verified
raw
history blame
6.87 kB
import streamlit as st
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
import warnings
from transformers import pipeline
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import os
from dotenv import load_dotenv
warnings.filterwarnings("ignore")
load_dotenv()
# Constants and configurations
APP_TITLE = "πŸ’Š Asisten Kesehatan Feminacare"
INITIAL_MESSAGE = """Halo! πŸ‘‹ Saya adalah asisten kesehatan feminacare yang siap membantu Anda dengan informasi seputar kesehatan wanita.
Silakan ajukan pertanyaan apa saja dan saya akan membantu Anda dengan informasi yang akurat."""
# Model configurations
MODEL_NAME = "SeaLLMs/SeaLLMs-v3-7B-Chat"
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
TOP_K_DOCS = 5
def initialize_models():
"""Initialize the embedding model and vector store"""
data_directory = os.path.join(os.path.dirname(__file__), "vector_db_dir")
embedding_model = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
vector_store = Chroma(
embedding_function=embedding_model,
persist_directory=data_directory
)
return vector_store
def create_llm():
"""Initialize the language model with optimized parameters"""
bnb_config = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, quantization_config=bnb_config)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
text_generation_pipeline = pipeline(
model=model,
tokenizer=tokenizer,
task="text-generation",
temperature=0.2,
do_sample=True,
repetition_penalty=1.1,
return_full_text=False,
max_new_tokens=200,
eos_token_id=terminators,
)
llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
# return HuggingFaceHub(
# repo_id=MODEL_NAME,
# model_kwargs={
# "temperature": 0.7, # Balanced between creativity and accuracy
# "max_new_tokens": 1024,
# "top_p": 0.9,
# "frequency_penalty": 0.5
# }
# )
return llm
# Improved prompt template with better context handling and response structure
PROMPT_TEMPLATE = """
Anda adalah asisten kesehatan profesional dengan nama Feminacare.
Berikan informasi yang akurat, jelas, dan bermanfaat berdasarkan konteks yang tersedia.
Context yang tersedia:
{context}
Chat historyt:
{chat_history}
Question: {question}
Instruksi untuk menjawab:
1. Berikan jawaban yang LENGKAP dan TERSTRUKTUR
2. Selalu sertakan SUMBER informasi dari konteks yang diberikan
3. Jika informasi tidak tersedia dalam konteks, katakan: "Maaf, saya tidak memiliki informasi yang cukup untuk menjawab pertanyaan tersebut secara akurat. Silakan konsultasi dengan tenaga kesehatan untuk informasi lebih lanjut."
4. Gunakan bahasa yang mudah dipahami
5. Jika relevan, berikan poin-poin penting menggunakan format yang rapi
6. Akhiri dengan anjuran untuk konsultasi dengan tenaga kesehatan jika diperlukan
Answer:
"""
def setup_qa_chain(vector_store):
"""Set up the QA chain with improved configuration"""
memory = ConversationBufferMemory(
memory_key="chat_history",
return_messages=True,
output_key='answer'
)
custom_prompt = PromptTemplate(
template=PROMPT_TEMPLATE,
input_variables=["context", "question", "chat_history"]
)
return ConversationalRetrievalChain.from_llm(
llm=create_llm(),
retriever=vector_store.as_retriever(
# search_type="mmr", # Maximum Marginal Relevance for better diversity
# search_kwargs={"k": TOP_K_DOCS}
),
memory=memory,
# combine_docs_chain_kwargs={"prompt": custom_prompt},
return_source_documents=True,
# return_generated_question=True,
)
def initialize_session_state():
"""Initialize Streamlit session state"""
if "messages" not in st.session_state:
st.session_state.messages = [
{"role": "assistant", "content": INITIAL_MESSAGE}
]
if "qa_chain" not in st.session_state:
vector_store = initialize_models()
st.session_state.qa_chain = setup_qa_chain(vector_store)
def clear_chat():
"""Clear chat history and memory"""
st.session_state.messages = [
{"role": "assistant", "content": INITIAL_MESSAGE}
]
st.session_state.qa_chain.memory.clear()
def create_ui():
"""Create the Streamlit UI"""
st.set_page_config(page_title=APP_TITLE, page_icon="πŸ’Š")
# Custom CSS for better UI
st.markdown("""
<style>
.stApp {
max-width: 1200px;
margin: 0 auto;
}
.stChat {
border-radius: 10px;
padding: 20px;
margin: 10px 0;
}
</style>
""", unsafe_allow_html=True)
st.title(APP_TITLE)
# Sidebar
with st.sidebar:
st.title("ℹ️ Tentang Aplikasi")
st.markdown("""
Asisten digital ini dirancang untuk membantu Anda untuk berkonsultasi tentang kesehatan wanita.
_Catatan: Informasi yang diberikan bersifat umum. Selalu konsultasikan dengan tenaga kesehatan untuk saran yang lebih spesifik._
""")
st.button('πŸ—‘οΈ Hapus Riwayat Chat', on_click=clear_chat)
def handle_user_input(prompt):
"""Handle user input and generate response"""
with st.spinner("Sedang menyiapkan jawaban..."):
response = st.session_state.qa_chain({"question": prompt})
return response["answer"]
def main():
initialize_session_state()
create_ui()
# Display chat messages
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# Handle user input
if prompt := st.chat_input("Ketik pertanyaan Anda di sini..."):
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.markdown(prompt)
response = handle_user_input(prompt)
if response:
with st.chat_message("assistant"):
st.markdown(response)
st.session_state.messages.append({"role": "assistant", "content": response})
if __name__ == "__main__":
main()