import os
import re
import tempfile
from io import BytesIO

import requests
from PIL import Image
from huggingface_hub import InferenceClient

import gradio as gr

read_token = os.getenv('HF_READ')
write_token = os.getenv('HF_WRITE')  # reserved for write operations; not used below

HEADERS = {"Authorization": f"Bearer {read_token}"}
BASE_URL = 'https://api-inference.huggingface.co/models/'

# Model IDs; each is appended to BASE_URL when calling the Inference API directly.
CHAT_MODEL = "mistralai/Mistral-Nemo-Instruct-2407"
WHISPER_API_URL = "distil-whisper/distil-large-v2"
BARK_API_URL = "suno/bark"
FLUX_API_URL = "enhanceaiteam/Flux-uncensored"

client = InferenceClient(api_key=read_token)

system_prompt = """
You are an empathetic and knowledgeable AI assistant designed to engage in meaningful conversations,
assist with tasks, and provide accurate information.
You can also generate vivid visuals!

To request an image, include a description between the IMG tags, like this:
##IMG: A serene forest at dawn with a golden glow :IMG##

To request a web search, include a search query between the URL tags, like this:
##URL: information about rainbow farting unicorns :URL##
"""

chat_history = []
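
# A reply following the prompt above might look like this (hypothetical example):
#   "Of course! ##IMG: a lighthouse in a storm :IMG## Here it is."
# tagger() below strips these markers out and collects their contents.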

def tagger(bot_response):
    """
    Extract tags from the bot response and return the filtered response text and tags.

    Args:
        bot_response (str): The full response text from the chatbot.

    Returns:
        tuple: A tuple containing:
            - filtered_response (str): The response text with tags removed.
            - tags (dict): A dictionary of extracted tags and their values.
    """
    tags = {}
    filtered_response = bot_response

    # Web-search requests: ##URL: ... :URL##
    url_pattern = r"##URL:(.+?):URL##"
    url_matches = re.findall(url_pattern, bot_response)
    if url_matches:
        tags['url'] = url_matches
        filtered_response = re.sub(url_pattern, "", filtered_response).strip()

    # Image requests: ##IMG: ... :IMG##
    img_pattern = r"##IMG:(.+?):IMG##"
    img_matches = re.findall(img_pattern, bot_response)
    if img_matches:
        tags['images'] = img_matches
        filtered_response = re.sub(img_pattern, "", filtered_response).strip()

    return filtered_response, tags
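
# Example (illustrative):
#   tagger("Sure! ##IMG: a serene forest at dawn :IMG##")
#   -> ("Sure!", {"images": [" a serene forest at dawn "]})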

def speech_to_text(filename):
    """Convert speech to text using the Whisper API."""
    try:
        with open(filename, "rb") as f:
            data = f.read()
        response = requests.post(BASE_URL + WHISPER_API_URL, headers=HEADERS, data=data)
        if response.status_code == 200:
            return response.json().get("text", "Could not recognize speech")
        print(f"Whisper Error: {response.status_code} - {response.text}")
    except Exception as e:
        print(f"Exception in speech_to_text: {e}")
    return None
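
# Note: speech_to_text() assumes the ASR endpoint answers with JSON shaped like
# {"text": "<transcription>"}; any other shape falls back to the default string above.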

def chatbot_logic(input_text):
    """Generate a response from the chatbot and handle tags."""
    global chat_history
    chat_history.append({"role": "user", "content": input_text})
    messages = [{"role": "system", "content": system_prompt}] + chat_history

    try:
        completion = client.chat.completions.create(
            model=CHAT_MODEL,
            messages=messages,
            max_tokens=500
        )
        response_text = completion.choices[0].message.content

        # Strip IMG/URL tags out of the reply before it is shown or spoken.
        response_text, tags = tagger(response_text)
        chat_history.append({"role": "assistant", "content": response_text})

        # Only the first image request (if any) is rendered.
        image_prompt = tags["images"][0] if "images" in tags else None

        return response_text, image_prompt
    except Exception as e:
        print(f"Chatbot Error: {e}")
        return None, None
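
# Example (illustrative): chatbot_logic("paint me a fox") could return
# ("Here you go!", " a red fox in fresh snow ") if the model emits an IMG tag,
# or ("Hello!", None) if it does not.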

def text_to_speech(text):
    """Convert text to speech using the Bark API."""
    try:
        response = requests.post(BASE_URL + BARK_API_URL, headers=HEADERS, json={"inputs": text})
        if response.status_code == 200:
            return response.content
        print(f"Bark Error: {response.status_code} - {response.text}")
    except Exception as e:
        print(f"Exception in text_to_speech: {e}")
    return None

def generate_image(prompt):
    """Generate an image using the Flux API."""
    try:
        response = requests.post(BASE_URL + FLUX_API_URL, headers=HEADERS, json={"inputs": prompt})
        if response.status_code == 200:
            return Image.open(BytesIO(response.content))
        print(f"Flux Error: {response.status_code} - {response.text}")
    except Exception as e:
        print(f"Exception in generate_image: {e}")
    return None

def process_chat(audio_file):
    """Process user input, generate a response, and optionally create media."""
    recognized_text = speech_to_text(audio_file)
    if not recognized_text:
        return "Speech recognition failed.", None, None

    response_text, image_prompt = chatbot_logic(recognized_text)
    if not response_text:
        return "Failed to generate chatbot response.", None, None

    # Bark returns raw audio bytes (assumed FLAC, the Inference API default); write
    # them to a temporary file so the gr.Audio output component can play them.
    audio_response = text_to_speech(response_text)
    audio_path = None
    if audio_response:
        with tempfile.NamedTemporaryFile(suffix=".flac", delete=False) as tmp:
            tmp.write(audio_response)
            audio_path = tmp.name

    generated_image = generate_image(image_prompt) if image_prompt else None

    return response_text, audio_path, generated_image

def create_ui():
    """Build the Gradio interface."""
    with gr.Blocks(title="Enhanced Voice-to-Voice Chatbot with Images") as ui:
        gr.Markdown("## Voice-to-Voice AI Chatbot\nTalk to the AI and see its responses, including images it generates!")
        audio_input = gr.Audio(type="filepath", label="Input Audio File")

        submit_button = gr.Button("Submit")

        with gr.Row():
            chatbot_response = gr.Textbox(label="Chatbot Response", lines=4)
        with gr.Row():
            audio_output = gr.Audio(label="Audio Response")
            image_output = gr.Image(label="Generated Image")

        submit_button.click(
            fn=process_chat,
            inputs=audio_input,
            outputs=[chatbot_response, audio_output, image_output],
            show_progress=True
        )

    return ui

if __name__ == "__main__":
    create_ui().launch(debug=True)