File size: 4,221 Bytes
bfaa008
19cb684
413cde5
2a52487
e9a8554
e3851ba
413cde5
7b284cd
19cb684
7b284cd
8edf7ae
e9a8554
 
 
e3851ba
 
e9a8554
 
 
e3851ba
 
 
 
 
 
 
 
19cb684
2a52487
19cb684
 
d22d1d9
2a52487
 
 
 
 
 
 
 
 
 
bfaa008
19cb684
 
 
920e657
19cb684
 
 
 
2a52487
bfaa008
19cb684
 
 
f45cb95
19cb684
 
2a52487
19cb684
54fb593
bfdd8ad
fd009ff
2a52487
 
bfdd8ad
2a52487
19cb684
bfdd8ad
 
 
 
 
 
 
 
 
e3851ba
bfdd8ad
 
 
 
 
 
 
 
 
 
36bb34e
 
 
 
bfdd8ad
2a52487
bfdd8ad
 
4bc65f5
bfdd8ad
bfaa008
6305dd7
19cb684
 
2a52487
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import gradio as gr
from groq import Groq
import os
import time
import whisper
from gtts import gTTS

# Read the Groq API key from the environment. NOTE(review): Groq() does not
# validate the key at construction time — a missing/invalid key only surfaces
# as an error when the first chat-completion request is made.
api_key = os.getenv('GROQ_API_KEY')
# Initialize Groq client
client = Groq(api_key=api_key)

# Load Whisper model once at import time (downloads model weights on first
# run; larger models are slower but more accurate).
whisper_model = whisper.load_model("base")  # You can use "tiny", "base", "small", "medium", or "large"

# Speech-to-text helper built on the preloaded Whisper model
def audio_to_text(audio_file):
    """Transcribe the audio file at `audio_file` and return the text."""
    transcription = whisper_model.transcribe(audio_file)
    return transcription['text']

# Text-to-speech helper built on gTTS
def text_to_audio(text):
    """Synthesize `text` to speech, save it as an MP3, and return the path."""
    output_path = "output_audio.mp3"
    gTTS(text).save(output_path)
    return output_path

# Function to generate responses with error handling
def generate_response(user_input, chat_history: list):
    """Request a reply to `user_input` from the LLaMA model via Groq.

    Prior turns from `chat_history` (dicts with 'role' and 'content') are
    forwarded to the model after a fixed system prompt. Returns a tuple of
    (response_text, chat_history); on any failure a generic apology string
    is returned and the history is passed back untouched.
    """
    system_message = {
        "role": "system",
        "content": "You are a mental health assistant. Your responses should be empathetic, non-judgmental, and provide helpful advice based on mental health principles. Always encourage seeking professional help when needed. Your responses should look human as well.",
    }
    try:
        # Rebuild the conversation: system prompt, prior turns, current input.
        messages = [system_message]
        for entry in chat_history:
            # Forward only well-formed turns; anything missing a role or
            # content key is skipped (and logged) rather than sent along.
            if 'role' in entry and 'content' in entry:
                messages.append({"role": entry["role"], "content": entry["content"]})
            else:
                print(f"Skipping invalid message: {entry}")
        messages.append({"role": "user", "content": user_input})

        # Call Groq API to get a response from LLaMA
        completion = client.chat.completions.create(
            messages=messages,
            model='llama-3.1-70b-versatile'
        )
        return completion.choices[0].message.content, chat_history

    except Exception as e:
        # Best-effort demo app: log and degrade to a canned apology.
        print(f"Error occurred: {e}")
        return "An error occurred while generating the response. Please try again.", chat_history

def gradio_interface():
    """Build and launch the Gradio chat UI (text and voice input)."""
    with gr.Blocks() as demo:
        # UI components. The chatbot component itself carries the message
        # history (list of {"role", "content"} dicts) between callbacks.
        gr.Markdown("## A Mental Health Chatbot")
        chatbot = gr.Chatbot(type="messages")
        msg = gr.Textbox(placeholder="Type your message here...")
        audio_input = gr.Audio(type="filepath", label="Speak your message")
        clear = gr.Button("Clear")

        def user(user_message, history: list):
            """Append the typed message as a user turn and clear the textbox."""
            history.append({"role": "user", "content": user_message})
            return "", history

        def bot(history: list):
            """Stream the assistant's reply into the chat, one character at a time."""
            if not history:
                return
            user_input = history[-1]["content"]
            # BUGFIX: pass the history *without* the final user turn.
            # generate_response appends user_input itself, so including the
            # last turn here sent the current user message to the model twice.
            response, _ = generate_response(user_input, history[:-1])
            history.append({"role": "assistant", "content": ""})
            for character in response:
                history[-1]['content'] += character
                time.sleep(0.02)  # small delay for a typing effect
                yield history

        def process_audio(audio_file, history):
            """Transcribe a recorded/uploaded clip and append it as a user turn."""
            if audio_file:
                transcription = audio_to_text(audio_file)
                if transcription:
                    history.append({"role": "user", "content": transcription})
                else:
                    history.append({"role": "assistant", "content": "I couldn't understand your audio. Please try again."})
            # BUGFIX: always return the history. The original fell through and
            # returned None when no audio was supplied, which wiped the
            # chatbot component's state.
            return history

        # Interaction flow: text submit -> user -> bot; audio -> transcribe -> bot.
        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(bot, chatbot, chatbot)
        audio_input.change(process_audio, [audio_input, chatbot], chatbot).then(bot, chatbot, chatbot)
        clear.click(lambda: [], None, chatbot, queue=False)

    demo.launch()

# Run the interface only when executed as a script, not when imported —
# importing this module should not start a web server.
if __name__ == "__main__":
    gradio_interface()