# Hugging Face Spaces page header (scrape residue): status "Sleeping"
import os | |
import json | |
import gradio as gr | |
import openai | |
from sentence_transformers import SentenceTransformer | |
import chromadb | |
from datasets import load_dataset | |
class DatasetCleaner:
    """Export the Context/Response fields of a dataset to a JSON file.

    The wrapped ``dataset`` is indexed by split name and each item of a
    split is read with ``.get``, so any mapping of split name to an
    iterable of dict-like rows works.
    """

    def __init__(self, dataset):
        """Store the dataset whose splits will be exported."""
        self.dataset = dataset

    @staticmethod
    def _text(value):
        """Return ``value`` unchanged, or '' when it is None."""
        return '' if value is None else value

    def export_to_json(self, split="train", output_file="cleaned_train.json"):
        """Write the rows of ``split`` to ``output_file`` as a JSON list.

        Each exported record holds exactly the keys 'Context' and
        'Response'. A missing key or an explicit null value is written as
        an empty string, so downstream consumers never receive null.

        Args:
            split: Name of the dataset split to export.
            output_file: Path of the JSON file to (over)write.

        Raises:
            KeyError: If ``split`` is not present in the dataset (raised
                by the dataset lookup itself for dict-like datasets).
        """
        # `item.get(key, '')` would still return None when the key exists
        # with a null value, hence the explicit None check in `_text`.
        records = [
            {
                'Context': self._text(item.get('Context')),
                'Response': self._text(item.get('Response')),
            }
            for item in self.dataset[split]
        ]
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(records, f, ensure_ascii=False, indent=2)
class MentalHealthRAGSystem:
    """A RAG system for mental health support conversations.

    Counseling conversations are embedded with a SentenceTransformer model
    and stored in a persistent ChromaDB collection. For each question the
    closest stored conversations are retrieved and passed to an OpenAI chat
    model as context.

    Note: ``max_token_length`` is applied as a *character* limit (see
    ``truncate_text``), not a tokenizer-based limit.
    """

    def __init__(self):
        """Initialize the RAG system.

        Side effects: sets the module-global ``openai.api_key`` from the
        ``OPENAI_API_KEY`` environment variable, opens (or creates) the
        ChromaDB store under ``self.db_path``, loads the embedding model,
        and builds the collection if it does not exist yet.
        """
        self.db_path = 'health_care_db'
        self.max_token_length = 2000
        self.collection_name = "health"
        # Set OpenAI API key
        openai.api_key = os.getenv("OPENAI_API_KEY")
        # Initialize ChromaDB
        self.chroma_client = chromadb.PersistentClient(path=self.db_path)
        # Initialize sentence transformer
        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        # Try to get existing collection or create new one
        self.collection = self.get_or_create_collection()

    def get_or_create_collection(self):
        """Return the populated collection, building it when necessary.

        An existing collection that holds no documents (for example after
        an interrupted first build) is repopulated instead of being reused
        empty.
        """
        try:
            collection = self.chroma_client.get_collection(self.collection_name)
        except Exception:
            # NOTE(review): the exception type raised for a missing
            # collection appears to differ between chromadb releases, so
            # this stays broad; unlike a bare `except:` it no longer
            # swallows KeyboardInterrupt / SystemExit.
            print(f"Creating new collection: {self.collection_name}")
            return self.create_new_collection()
        if collection.count() == 0:
            print(f"Collection {self.collection_name} is empty; populating it")
            return self.create_new_collection()
        print(f"Using existing collection: {self.collection_name}")
        return collection

    def create_new_collection(self):
        """Create (or reuse) the collection and fill it with embedded documents.

        Returns:
            The ChromaDB collection, populated unless no documents could
            be prepared.
        """
        # get_or_create keeps this safe to call for a collection that
        # already exists but is empty.
        collection = self.chroma_client.get_or_create_collection(self.collection_name)
        documents = self.prepare_documents(self._load_content())
        if not documents:
            # Adding empty lists would fail; leave the collection empty.
            print("No documents available; collection left empty")
            return collection
        # Generate embeddings and add to collection
        embeddings = self.model.encode(documents, batch_size=32, show_progress_bar=True)
        ids = [f'health-{i}' for i in range(len(documents))]
        metadatas = [{"source": "mental_health_dataset"} for _ in documents]
        collection.add(
            documents=documents,
            # Plain Python lists avoid depending on chromadb's numpy handling.
            embeddings=embeddings.tolist(),
            ids=ids,
            metadatas=metadatas,
        )
        print(f"Added {len(documents)} documents to collection")
        return collection

    def _load_content(self):
        """Load Context/Response records from cache, the hub, or sample data."""
        cache_file = 'cleaned_train.json'
        if os.path.exists(cache_file):
            with open(cache_file, "r", encoding="utf-8") as f:
                return json.load(f)
        # If cleaned data doesn't exist, process raw data
        try:
            dataset = load_dataset(
                "Amod/mental_health_counseling_conversations",
                data_files="combined_dataset.json"
            )
            cleaner = DatasetCleaner(dataset)
            cleaner.export_to_json(split="train", output_file=cache_file)
            with open(cache_file, "r", encoding="utf-8") as f:
                return json.load(f)
        except Exception as exc:
            # Fallback to sample data if dataset loading fails; say so
            # instead of failing silently.
            print(f"Dataset loading failed ({exc}); using sample data")
            return self.create_sample_data()

    def create_sample_data(self):
        """Create sample mental health data for demonstration."""
        return [
            {
                "Context": "I've been feeling really anxious lately and can't seem to calm down.",
                "Response": "I understand how overwhelming anxiety can feel. It's important to remember that you're not alone in this experience. Some techniques that many people find helpful include deep breathing exercises, grounding techniques like the 5-4-3-2-1 method, and gentle physical activity. However, if your anxiety is persistent or interfering with your daily life, I'd encourage you to reach out to a mental health professional who can provide personalized support."
            },
            {
                "Context": "I'm having trouble sleeping and my mind keeps racing at night.",
                "Response": "Racing thoughts at bedtime can be really frustrating and exhausting. Creating a calming bedtime routine can be helpful - this might include putting away screens an hour before bed, trying relaxation techniques like progressive muscle relaxation, or writing down your thoughts in a journal to help clear your mind. If sleep difficulties persist, it's worth discussing with a healthcare provider as they can help identify underlying causes and suggest appropriate treatments."
            },
            {
                "Context": "I feel like I'm not good enough and keep comparing myself to others.",
                "Response": "Those feelings of inadequacy and comparison can be really painful. Remember that social media and what we see of others' lives often shows only the highlights, not the full picture. Practicing self-compassion and focusing on your own growth and achievements, no matter how small, can be helpful. Consider keeping a gratitude journal or practicing mindfulness to stay present with your own experience. If these feelings are significantly impacting your well-being, talking to a counselor can provide valuable support and tools."
            }
        ]

    def prepare_documents(self, content):
        """Prepare documents for embedding.

        Each record becomes ``'context:<Context>\\nresponse:<Response>'``,
        truncated with ``truncate_text``. A missing or null field is
        treated as an empty string rather than raising ``KeyError`` or
        being rendered as the text "None".

        Args:
            content: Iterable of dict-like records.

        Returns:
            A list of strings, one per record.
        """
        def field(record, key):
            value = record.get(key)
            return '' if value is None else value

        return [
            self.truncate_text(
                f"context:{field(doc, 'Context')}\nresponse:{field(doc, 'Response')}"
            )
            for doc in content
        ]

    def truncate_text(self, text, max_len=None):
        """Truncate text to at most ``max_len`` characters.

        Args:
            text: The string to shorten.
            max_len: Character limit; defaults to ``self.max_token_length``.
        """
        if max_len is None:
            max_len = self.max_token_length
        # Slicing already returns the text unchanged when it is short enough.
        return text[:max_len]

    def query_database(self, question, n_results=3):
        """Query the vector database for relevant documents.

        The question is embedded with ``self.model``, the same model used
        to embed the stored documents, rather than relying on the
        collection's own embedding function.

        Returns:
            The retrieved documents joined by blank lines.
        """
        query_embeddings = self.model.encode([question]).tolist()
        results = self.collection.query(
            query_embeddings=query_embeddings,
            n_results=n_results
        )
        return "\n\n".join(results["documents"][0])

    @staticmethod
    def _build_system_prompt(context):
        """Return the system prompt with the retrieved context embedded."""
        return (
            "\n"
            "You are a compassionate mental health support assistant.\n"
            "Provide helpful, empathetic, and supportive answers to the following question using only the context provided below.\n"
            "If the answer is not in the context, say \"I'm not sure based on the available information.\"\n"
            "Do not give medical advice, diagnoses, or instructions to self-harm. Instead, encourage seeking help from qualified professionals if needed.\n"
            "Context:\n"
            f"{context}\n"
        )

    @staticmethod
    def _history_to_messages(history):
        """Convert Gradio chat history into OpenAI chat messages.

        Accepts both shapes of history: a list of ``(user, assistant)``
        pairs, or a list of ``{"role": ..., "content": ...}`` dicts.
        Turns whose content is not a non-empty string (for example a
        pending ``None`` assistant reply) are skipped.
        """
        messages = []
        for turn in history or []:
            if isinstance(turn, dict):
                role, content = turn.get("role"), turn.get("content")
                if role in ("user", "assistant") and isinstance(content, str) and content:
                    messages.append({"role": role, "content": content})
                continue
            human_msg, assistant_msg = turn
            if isinstance(human_msg, str) and human_msg:
                messages.append({"role": "user", "content": human_msg})
            if isinstance(assistant_msg, str) and assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        return messages

    def call_openai(self, question, history):
        """Generate answer using OpenAI with retrieved context and chat history.

        Args:
            question: The current user message.
            history: Previous turns as provided by the Gradio chat UI.

        Returns:
            The model's reply, or an ``"Error: ..."`` string when retrieval
            or the API call fails (exceptions are not propagated).
        """
        try:
            context = self.query_database(question)
            conversation_messages = [
                {"role": "system", "content": self._build_system_prompt(context)}
            ]
            # Add conversation history to messages
            conversation_messages.extend(self._history_to_messages(history))
            # Add current question
            conversation_messages.append({"role": "user", "content": question})
            request = {
                "model": "gpt-3.5-turbo",
                "messages": conversation_messages,
                "temperature": 0.2,
            }
            if hasattr(openai, "OpenAI"):
                # openai>=1.0 removed ChatCompletion; use the module-level client.
                response = openai.chat.completions.create(**request)
            else:
                response = openai.ChatCompletion.create(**request)
            return response.choices[0].message.content.strip()
        except Exception as e:
            return f"Error: {str(e)}. Please make sure your OpenAI API key is set correctly."
# Build the single shared RAG backend at import time; the chat callback
# below reads it through the module-level name `rag_system`.
print("Initializing Mental Health Support RAG System...")
rag_system = MentalHealthRAGSystem()
print("System ready!")
# Chat interface function that handles history
def chat_interface_function(user_question, history):
    """Main interface function for Gradio Chat.

    Args:
        user_question: The text the user submitted.
        history: Previous chat turns, forwarded unchanged to the RAG system.

    Returns:
        The assistant's reply, a prompt to enter text when the input is
        missing, blank, or not a string, or an ``"Error: ..."`` string if
        answering fails.
    """
    # Guard before calling .strip(): None or non-string input used to raise
    # AttributeError outside the try block below.
    if not isinstance(user_question, str) or not user_question.strip():
        return "Please enter your question or concern."
    try:
        return rag_system.call_openai(user_question, history)
    except Exception as e:
        return f"Error: {str(e)}"
# Input box shown under the conversation.
_chat_textbox = gr.Textbox(
    placeholder="Ask a question about mental health support...",
    container=False,
    scale=7
)

# Gradio chat UI wired to the RAG-backed callback.
# NOTE(review): the leading "π" in the title looks like a mis-decoded
# emoji; kept verbatim here, confirm against the original file.
demo = gr.ChatInterface(
    fn=chat_interface_function,
    title="π RAG Mental Health Support Assistant",
    description=(
        "Get supportive, empathetic responses to mental health-related questions "
        "based on provided context. Not a substitute for professional help."
    ),
    textbox=_chat_textbox
)

# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()