Spaces:
Sleeping
Sleeping
import os | |
import streamlit as st | |
import faiss | |
import pickle | |
from groq import Groq | |
from datasets import load_dataset | |
from transformers import pipeline | |
# Groq's base exception type, so API failures can be shown in the UI instead of
# crashing the whole script run.
from groq import APIError

# File names used for chat-history persistence.
# NOTE(review): these are fixed relative paths, so every user session of this app
# reads and writes the same files on the server -- confirm that sharing is intended.
CHAT_HISTORY_TXT = "chat_history.txt"
CHAT_HISTORY_PKL = "chat_history.pkl"

# Streamlit UI Setup -- set_page_config must be the first Streamlit command of a run.
st.set_page_config(page_title="AI Chatbot", layout="wide")
# NOTE(review): the emoji in this and the other UI labels look mis-encoded
# ("π€", "π", "π¬"); they are left unchanged -- confirm the intended characters.
st.title("π€ AI Chatbot (Healthcare, Education & Finance)")
st.sidebar.title("π Chat History")


@st.cache_resource(show_spinner="Loading datasets...")
def _load_datasets():
    """Load the healthcare, education and finance datasets once.

    Streamlit re-executes this script on every widget interaction; caching the
    result avoids reloading all three datasets on each rerun.

    Returns:
        Tuple ``(healthcare_ds, education_ds, finance_ds)`` as returned by
        ``datasets.load_dataset``.
    """
    return (
        load_dataset("harishnair04/mtsamples"),
        load_dataset("ehovy/race", "all"),
        load_dataset("warwickai/financial_phrasebank_mirror"),
    )


def _select_dataset(query: str):
    """Pick a dataset by keyword (basic "CAG" routing).

    "health" in the query selects the healthcare dataset, "education" selects the
    education dataset, and anything else falls back to the finance dataset.
    Matching is case-insensitive substring search.
    """
    lowered = query.lower()
    if "health" in lowered:
        return healthcare_ds
    if "education" in lowered:
        return education_ds
    return finance_ds


def _ask_llm(prompt: str) -> str:
    """Send *prompt* to Groq's chat-completions API as a single user message.

    Returns:
        The reply text, or "" if the API returned no content.

    Raises:
        groq.APIError: on any Groq API/connection failure (propagated to caller).
    """
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama-3.3-70b-versatile",
    )
    return chat_completion.choices[0].message.content or ""


def save_chat_history() -> None:
    """Persist the current chat history to ``chat_history.pkl``.

    The list is written to a temporary file first and then moved over the target
    with ``os.replace``, so a failure mid-write cannot leave a truncated pickle.

    Raises:
        OSError: if the file cannot be written or replaced.
    """
    tmp_path = CHAT_HISTORY_PKL + ".tmp"
    with open(tmp_path, "wb") as file:
        pickle.dump(chat_history, file)
    os.replace(tmp_path, CHAT_HISTORY_PKL)


def load_chat_history() -> None:
    """Replace the in-memory chat history with the contents of ``chat_history.pkl``.

    Does nothing if the file does not exist. An unreadable, truncated or
    non-list file is reported in the sidebar and otherwise ignored rather than
    crashing the app. The list is updated in place so the session-state binding
    of ``chat_history`` is preserved.
    """
    if not os.path.exists(CHAT_HISTORY_PKL):
        return
    # NOTE(review): pickle.load can execute arbitrary code from the file; this is
    # only acceptable while the file is written exclusively by save_chat_history().
    # A plain-text/JSON format would be safer.
    try:
        with open(CHAT_HISTORY_PKL, "rb") as file:
            loaded = pickle.load(file)
    except (OSError, EOFError, pickle.UnpicklingError) as err:
        st.sidebar.warning(f"Could not load saved chat history: {err}")
        return
    if not isinstance(loaded, list):
        st.sidebar.warning("Saved chat history has an unexpected format and was ignored.")
        return
    chat_history[:] = loaded


# Initialize Groq API -- fail with a readable message instead of a traceback
# when the key is missing (Groq() raises if no api_key is available).
_groq_api_key = os.environ.get("GROQ_API_KEY")
if not _groq_api_key:
    st.error("GROQ_API_KEY is not set. Add it as an environment variable or Space secret.")
    st.stop()
client = Groq(api_key=_groq_api_key)

# Load datasets (cached across reruns by _load_datasets).
healthcare_ds, education_ds, finance_ds = _load_datasets()

# FAISS Index Setup.
# NOTE(review): in this file the index is created but never populated or searched;
# 768 is presumably the intended embedding size -- confirm before wiring it up.
index = faiss.IndexFlatL2(768)

# Chat history lives in session state: Streamlit re-runs the script on every
# interaction, so a plain module-level list would be reset to [] each time.
# The saved pickle is loaded only once, when the session starts; loading it on
# every rerun would overwrite messages added during the session.
_new_session = "chat_history" not in st.session_state
if _new_session:
    st.session_state["chat_history"] = []
chat_history = st.session_state["chat_history"]
if _new_session:
    load_chat_history()

# Chat Interface
user_input = st.text_input("π¬ Ask me anything:", placeholder="Type your query here...")
# st.button is evaluated first so the widget is always rendered; empty input is ignored.
if st.button("Send") and user_input:
    dataset = _select_dataset(user_input)
    # RAG placeholder: always the first training row, independent of the query.
    retrieved_data = dataset["train"][0]
    try:
        response = _ask_llm(f"{user_input} {retrieved_data}")
    except APIError as err:
        # Report the failure; nothing is added to the history for this turn.
        st.error(f"Groq API request failed: {err}")
    else:
        chat_history.append(f"User: {user_input}\nBot: {response}")
        st.text_area("π€ AI Response:", value=response, height=200)

# Sidebar widgets are rendered after the chat turn so the download and the
# listing already include the message generated in this run.
st.sidebar.download_button(
    "Download Chat History",
    data="\n".join(chat_history),
    file_name=CHAT_HISTORY_TXT,
    mime="text/plain",
)
# Display past chats
st.sidebar.write("\n".join(chat_history))
if st.sidebar.button("Save Chat History"):
    try:
        save_chat_history()
    except OSError as err:
        st.sidebar.error(f"Could not save chat history: {err}")
    else:
        st.sidebar.success("Chat history saved permanently!")