File size: 9,715 Bytes
3974f08
 
 
8bd818c
3974f08
 
1eb8d7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3974f08
 
 
 
 
 
 
 
 
 
1eb8d7b
 
3974f08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1eb8d7b
 
 
 
 
 
 
 
 
 
 
 
3974f08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1eb8d7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3974f08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1eb8d7b
 
3974f08
 
 
1eb8d7b
 
 
 
 
 
 
 
 
 
 
3974f08
1eb8d7b
 
 
 
 
 
 
 
 
3974f08
1eb8d7b
3974f08
 
 
 
 
 
 
 
 
 
 
 
 
1eb8d7b
 
 
3974f08
 
 
 
1eb8d7b
3974f08
 
 
 
1eb8d7b
 
 
3974f08
 
 
 
1eb8d7b
 
 
 
 
3974f08
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import os
import json
import gradio as gr
import openai
from sentence_transformers import SentenceTransformer
import chromadb
from datasets import load_dataset

class DatasetCleaner:
    """Thin wrapper that dumps the Context/Response columns of a dataset to JSON."""

    def __init__(self, dataset):
        # Mapping-like object indexed by split name; each split yields
        # dict-like rows (only ``.get`` is used on them).
        self.dataset = dataset

    def export_to_json(self, split="train", output_file="cleaned_train.json"):
        """Write the chosen split to *output_file* as a JSON list.

        Only the 'Context' and 'Response' fields of each row are kept; a
        missing field is written as an empty string. The file is written as
        UTF-8 with non-ASCII characters left unescaped.
        """
        records = [
            {
                'Context': row.get('Context', ''),
                'Response': row.get('Response', ''),
            }
            for row in self.dataset[split]
        ]

        with open(output_file, 'w', encoding='utf-8') as out:
            json.dump(records, out, ensure_ascii=False, indent=2)

class MentalHealthRAGSystem:
    """A RAG system for mental health support conversations.

    Counseling Context/Response pairs are embedded with a sentence-transformer
    model and stored in a persistent ChromaDB collection. At question time the
    closest documents are retrieved and passed to an OpenAI chat model as
    grounding context.
    """

    def __init__(self):
        """Initialize the RAG system.

        Opens (or creates) the persistent ChromaDB store, loads the embedding
        model and makes sure the document collection exists.
        """
        self.db_path = 'health_care_db'
        # NOTE: despite the name, this limit is applied to characters
        # (see truncate_text), not to model tokens.
        self.max_token_length = 2000
        self.collection_name = "health"

        # Set OpenAI API key
        openai.api_key = os.getenv("OPENAI_API_KEY")

        # Initialize ChromaDB
        self.chroma_client = chromadb.PersistentClient(path=self.db_path)

        # Initialize sentence transformer
        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

        # Try to get existing collection or create new one
        self.collection = self.get_or_create_collection()

    def get_or_create_collection(self):
        """Return the existing collection, or build and populate a new one."""
        try:
            collection = self.chroma_client.get_collection(self.collection_name)
        except Exception:
            # The exception type raised for a missing collection differs
            # between chromadb releases, so stay broad here -- but not bare,
            # so KeyboardInterrupt/SystemExit still propagate.
            print(f"Creating new collection: {self.collection_name}")
            return self.create_new_collection()

        print(f"Using existing collection: {self.collection_name}")
        return collection

    def _load_content(self):
        """Return the Context/Response records used to seed the collection.

        Order of preference: the cached 'cleaned_train.json', a fresh download
        of the Hugging Face dataset (which also writes the cache), and finally
        the built-in sample data when the download fails.
        """
        cache_file = 'cleaned_train.json'
        if os.path.exists(cache_file):
            with open(cache_file, "r", encoding="utf-8") as f:
                return json.load(f)

        try:
            dataset = load_dataset(
                "Amod/mental_health_counseling_conversations",
                data_files="combined_dataset.json"
            )
            cleaner = DatasetCleaner(dataset)
            cleaner.export_to_json(split="train", output_file=cache_file)
            with open(cache_file, "r", encoding="utf-8") as f:
                return json.load(f)
        except Exception as exc:
            # Deliberate best-effort: the demo must still start offline.
            print(f"Dataset loading failed ({exc}); falling back to sample data")
            return self.create_sample_data()

    def create_new_collection(self):
        """Create the collection and fill it with embedded documents.

        All data loading and embedding happens *before* the collection is
        created, and a failed insert removes the collection again. Otherwise a
        crash half-way would leave an empty persisted collection that the next
        start-up would happily reuse.
        """
        # An empty dataset cannot be inserted; use the sample data instead.
        content = self._load_content() or self.create_sample_data()

        # Prepare documents
        documents = self.prepare_documents(content)

        # Generate embeddings (plain lists are accepted by every chromadb release)
        embeddings = self.model.encode(
            documents, batch_size=32, show_progress_bar=True
        ).tolist()

        ids = [f'health-{i}' for i in range(len(documents))]
        metadatas = [{"source": "mental_health_dataset"} for _ in documents]

        collection = self.chroma_client.create_collection(self.collection_name)
        try:
            collection.add(
                documents=documents,
                embeddings=embeddings,
                ids=ids,
                metadatas=metadatas
            )
        except Exception:
            # Do not leave a half-initialised collection behind.
            self.chroma_client.delete_collection(self.collection_name)
            raise

        print(f"Added {len(documents)} documents to collection")
        return collection

    def create_sample_data(self):
        """Create sample mental health data for demonstration."""
        return [
            {
                "Context": "I've been feeling really anxious lately and can't seem to calm down.",
                "Response": "I understand how overwhelming anxiety can feel. It's important to remember that you're not alone in this experience. Some techniques that many people find helpful include deep breathing exercises, grounding techniques like the 5-4-3-2-1 method, and gentle physical activity. However, if your anxiety is persistent or interfering with your daily life, I'd encourage you to reach out to a mental health professional who can provide personalized support."
            },
            {
                "Context": "I'm having trouble sleeping and my mind keeps racing at night.",
                "Response": "Racing thoughts at bedtime can be really frustrating and exhausting. Creating a calming bedtime routine can be helpful - this might include putting away screens an hour before bed, trying relaxation techniques like progressive muscle relaxation, or writing down your thoughts in a journal to help clear your mind. If sleep difficulties persist, it's worth discussing with a healthcare provider as they can help identify underlying causes and suggest appropriate treatments."
            },
            {
                "Context": "I feel like I'm not good enough and keep comparing myself to others.",
                "Response": "Those feelings of inadequacy and comparison can be really painful. Remember that social media and what we see of others' lives often shows only the highlights, not the full picture. Practicing self-compassion and focusing on your own growth and achievements, no matter how small, can be helpful. Consider keeping a gratitude journal or practicing mindfulness to stay present with your own experience. If these feelings are significantly impacting your well-being, talking to a counselor can provide valuable support and tools."
            }
        ]

    def prepare_documents(self, content):
        """Turn Context/Response records into truncated text documents.

        Each record becomes ``"context:<Context>\\nresponse:<Response>"``. A
        missing or ``None`` field is treated as an empty string so that the
        literal text "None" never ends up in the index.
        """
        return [
            self.truncate_text(
                f"context:{doc.get('Context') or ''}\nresponse:{doc.get('Response') or ''}"
            )
            for doc in content
        ]

    def truncate_text(self, text, max_len=None):
        """Return *text* cut to at most *max_len* characters.

        *max_len* defaults to ``self.max_token_length``.
        """
        if max_len is None:
            max_len = self.max_token_length
        return text[:max_len] if len(text) > max_len else text

    def query_database(self, question, n_results=3):
        """Return the *n_results* closest documents, joined by blank lines.

        The question is embedded with the same model that embedded the stored
        documents, rather than relying on the collection's default embedding
        function.
        """
        query_embedding = self.model.encode([question]).tolist()
        results = self.collection.query(
            query_embeddings=query_embedding,
            n_results=n_results
        )
        return "\n\n".join(results["documents"][0])

    @staticmethod
    def _history_to_messages(history):
        """Convert Gradio chat history into OpenAI chat messages.

        Both history layouts are supported: a list of ``(user, assistant)``
        pairs and a list of ``{"role": ..., "content": ...}`` dicts. Entries
        that are empty or not plain text are skipped.
        """
        messages = []
        for turn in history or []:
            if isinstance(turn, dict):
                role = turn.get("role")
                content = turn.get("content")
                if role in ("user", "assistant") and isinstance(content, str) and content:
                    messages.append({"role": role, "content": content})
                continue

            human_msg, assistant_msg = turn
            if isinstance(human_msg, str) and human_msg:
                messages.append({"role": "user", "content": human_msg})
            if isinstance(assistant_msg, str) and assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        return messages

    def call_openai(self, question, history):
        """Generate answer using OpenAI with retrieved context and chat history.

        Returns the model's reply, or an ``"Error: ..."`` string if anything
        goes wrong (this method never raises, so the chat UI keeps working).
        """
        try:
            context = self.query_database(question)

            # Build conversation history for context
            conversation_messages = [
                {"role": "system", "content": f"""
                You are a compassionate mental health support assistant.
                Provide helpful, empathetic, and supportive answers to the following question using only the context provided below.
                If the answer is not in the context, say "I'm not sure based on the available information."
                Do not give medical advice, diagnoses, or instructions to self-harm. Instead, encourage seeking help from qualified professionals if needed.
                Context:
                {context}
                """}
            ]

            # Add conversation history to messages
            conversation_messages.extend(self._history_to_messages(history))

            # Add current question
            conversation_messages.append({"role": "user", "content": question})

            request = {
                "model": "gpt-3.5-turbo",
                "messages": conversation_messages,
                "temperature": 0.2,
            }
            if hasattr(openai, "OpenAI"):
                # openai>=1.0 removed ChatCompletion; use the module-level client.
                response = openai.chat.completions.create(**request)
            else:
                response = openai.ChatCompletion.create(**request)

            # content may be None in some responses; treat that as empty text
            return (response.choices[0].message.content or "").strip()

        except Exception as e:
            return f"Error: {str(e)}. Please make sure your OpenAI API key is set correctly."

# Initialize the RAG system
# NOTE: this runs at import time. The constructor opens the persistent
# ChromaDB store, loads the sentence-transformer model and, when the
# collection does not exist yet, builds it -- so the first start can be slow.
print("Initializing Mental Health Support RAG System...")
rag_system = MentalHealthRAGSystem()  # module-level singleton used by the chat handler
print("System ready!")

# Chat interface function that handles history
def chat_interface_function(user_question, history):
    """Gradio chat callback: answer *user_question* given the chat *history*.

    Blank input gets a short prompt asking for a question. Otherwise the
    request is delegated to the module-level ``rag_system``; any exception is
    turned into an ``"Error: ..."`` reply instead of propagating to the UI.
    """
    stripped = user_question.strip()
    if not stripped:
        return "Please enter your question or concern."

    try:
        return rag_system.call_openai(user_question, history)
    except Exception as err:
        return f"Error: {str(err)}"

# Create Gradio chat interface
demo = gr.ChatInterface(
    fn=chat_interface_function,
    # The title previously contained mojibake (the UTF-8 bytes of the green
    # heart emoji decoded with the wrong codec). The escape below is U+1F49A
    # GREEN HEART and keeps the source file ASCII-safe.
    title="\U0001F49A RAG Mental Health Support Assistant",
    description=(
        "Get supportive, empathetic responses to mental health-related questions "
        "based on provided context. Not a substitute for professional help."
    ),
    textbox=gr.Textbox(
        placeholder="Ask a question about mental health support...",
        container=False,
        scale=7
    )
)

# Only start the web server when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()