Spaces:
Sleeping
Sleeping
import gradio as gr | |
import edge_tts | |
import asyncio | |
import tempfile | |
import numpy as np | |
from pydub import AudioSegment | |
import torch | |
import sentencepiece as spm | |
import onnxruntime as ort | |
from huggingface_hub import hf_hub_download | |
# Orderable items mapped to their price in dollars (rendered as "$x.xx" by
# generate_menu). Insertion order is the order the menu is read out in.
MENU = dict(
    Pizza=10.99,
    Burger=6.99,
    Pasta=8.49,
    Salad=5.49,
    Soda=1.99,
    Coffee=2.99,
)
# Item names added so far, in the order they were added; module-level state
# that handle_cart appends to and clears on submit.
cart = list()
# Speech-to-text model configuration: a Citrinet checkpoint on the Hugging
# Face Hub whose inference assets live in the repo's "onnx" subfolder.
model_name = "neongeckocom/stt_en_citrinet_512_gamma_0_25"
sample_rate = 16000  # Hz; transcription input is resampled to this rate


def _fetch_onnx_asset(filename):
    """Download ``filename`` from the model repo's ``onnx`` subfolder; return its local path."""
    return hf_hub_download(model_name, filename, subfolder="onnx")


# Loaded once at import time, in this order: TorchScript preprocessor,
# ONNX Runtime encoder session, SentencePiece tokenizer.
preprocessor = torch.jit.load(_fetch_onnx_asset("preprocessor.ts"))
encoder = ort.InferenceSession(_fetch_onnx_asset("model.onnx"))
tokenizer = spm.SentencePieceProcessor(_fetch_onnx_asset("tokenizer.spm"))
async def text_to_speech(text):
    """Synthesize ``text`` to speech with edge-tts and return the audio file path.

    edge-tts produces MP3 audio (its default output format), so the file is
    given an ``.mp3`` suffix. A ``.wav`` name would mislabel the content for
    players and for anything that trusts the extension.

    The temporary file is created with ``delete=False`` because the path is
    handed back to the caller (and on to the Gradio audio output) after this
    coroutine returns. It is not cleaned up here.

    Args:
        text: The reply text to speak.

    Returns:
        Path of the written audio file.
    """
    communicate = edge_tts.Communicate(text)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
        tmp_path = tmp_file.name
    # Write only after the handle is closed, so the path can be reopened on every OS.
    await communicate.save(tmp_path)
    return tmp_path
def resample(audio_fp32, sr):
    """Resample float32 audio from ``sr`` Hz to the model rate ``sample_rate``.

    Args:
        audio_fp32: float32 samples (transcribe passes a 1-D mono array).
        sr: Current sampling rate of ``audio_fp32`` in Hz.

    Returns:
        The signal at ``sample_rate`` Hz; the input itself when it is already
        at that rate.
    """
    if sr == sample_rate:
        # Nothing to convert; this also skips the resampler dependency.
        return audio_fp32
    # soxr is not imported at module level, so referencing it here raised
    # NameError on every call. Import it where it is actually used.
    import soxr

    return soxr.resample(audio_fp32, sr, sample_rate)
def to_float32(audio_buffer):
    """Convert an integer PCM array to float32 by dividing by its dtype's maximum.

    For example, int16 samples are divided by 32767, mapping full scale to
    approximately [-1, 1]. The most negative value lands slightly below -1.
    """
    full_scale = np.float32(np.iinfo(audio_buffer.dtype).max)
    return audio_buffer.astype(np.float32) / full_scale
def transcribe(audio_path):
    """Transcribe an audio file to text with the Citrinet CTC model.

    Pipeline: decode with pydub -> down-mix to mono -> scale integer PCM to
    float32 -> resample to ``sample_rate`` Hz -> TorchScript preprocessor ->
    ONNX encoder -> greedy CTC decoding -> SentencePiece detokenization.

    Args:
        audio_path: Path of an audio file that pydub can open.

    Returns:
        The decoded transcript (possibly an empty string).
    """
    audio_file = AudioSegment.from_file(audio_path)
    # get_array_of_samples() interleaves channels, so without this a stereo
    # file would be fed to the model as one garbled signal at twice its length.
    audio_file = audio_file.set_channels(1)
    sr = audio_file.frame_rate
    audio_buffer = np.array(audio_file.get_array_of_samples())
    audio_fp32 = to_float32(audio_buffer)
    audio_16k = resample(audio_fp32, sr)

    # unsqueeze(0) adds a leading batch dimension of size 1.
    input_signal = torch.tensor(audio_16k).unsqueeze(0)
    length = torch.tensor(len(audio_16k)).unsqueeze(0)
    processed_signal, _ = preprocessor.forward(input_signal=input_signal, length=length)
    logits = encoder.run(
        None,
        {"audio_signal": processed_signal.numpy(), "length": length.numpy()},
    )[0][0]  # first model output, first (only) batch item

    # The blank id is one past the last tokenizer id (ids 0..vocab_size-1 are real tokens).
    blank_id = tokenizer.vocab_size()
    token_ids = _ctc_greedy_collapse(logits.argmax(axis=1).tolist(), blank_id)
    return tokenizer.decode_ids(token_ids)


def _ctc_greedy_collapse(frame_ids, blank_id):
    """Turn per-frame argmax ids into a token sequence: merge repeats, then drop blanks.

    This is greedy CTC decoding. Without the merge step, a token predicted on
    several consecutive frames would be emitted several times. A blank between
    two identical ids keeps them as two separate tokens.
    """
    from itertools import groupby

    return [token for token, _ in groupby(frame_ids) if token != blank_id]
def generate_menu():
    """Return the menu as text: a header line, one "<item>: $<price>" line per
    MENU entry (price to two decimals), then a closing question.
    """
    lines = ["Here is our menu:"]
    lines.extend(f"{name}: ${cost:.2f}" for name, cost in MENU.items())
    lines.append("What would you like to order?")
    return "\n".join(lines)
def handle_cart(command):
    """Interpret a transcribed voice command, update the cart, and return a reply.

    Matching is case-insensitive substring matching, applied in this order:

    * "menu" -> reply with the menu; no items are added.
    * otherwise, every MENU item name found in the command is added to the cart.
    * "cart" -> reply with the cart contents.
    * "submit" / "done" -> reply with the final order and empty the cart.

    A later rule's reply replaces an earlier one's. If no rule produces a
    reply, a short help message is returned instead of an empty string, which
    would otherwise be handed to text-to-speech.

    NOTE(review): ``cart`` is module-global, so all visitors of the app share
    one cart. Isolating users would need per-session state (e.g. ``gr.State``).

    Args:
        command: Transcribed user utterance; ``None`` is treated as empty.

    Returns:
        The reply text to display and speak.
    """
    text = (command or "").lower()
    response = ""

    if "menu" in text:
        response = generate_menu()
    else:
        # Add every mentioned item, not just the first one found.
        added = [item for item in MENU if item.lower() in text]
        cart.extend(added)
        if len(added) == 1:
            response = f"{added[0]} has been added to your cart."
        elif added:
            response = f"{', '.join(added[:-1])} and {added[-1]} have been added to your cart."

    if "cart" in text:
        if cart:
            response = "Your cart contains:\n" + ", ".join(cart)
        else:
            response = "Your cart is empty."

    if "submit" in text or "done" in text:
        if cart:
            response = "Your final order is:\n" + ", ".join(cart) + ". Thank you for your order!"
            cart.clear()  # empty the shared cart in place
        else:
            response = "Your cart is empty. Add some items before submitting."

    if not response:
        response = (
            "Sorry, I didn't catch that. Say 'menu' to hear our menu, name an item "
            "to add it, 'cart' to review your order, or 'done' to submit it."
        )
    return response
async def respond(audio):
    """Click handler: transcribe ``audio``, update the cart, and speak the reply.

    Args:
        audio: File path from the ``gr.Audio(type="filepath")`` input, or
            ``None`` when nothing was recorded or uploaded.

    Returns:
        ``(transcript, reply_text, reply_audio_path)`` for the three output
        components. On failure the text fields carry a message and the audio
        path is ``None``. Failures are logged rather than silently discarded,
        and each stage reports its own error.
    """
    import logging

    logger = logging.getLogger(__name__)

    if audio is None:
        return "", "Please record or upload some audio first.", None

    try:
        # transcribe() runs blocking model inference. Run it in the default
        # executor so this coroutine does not stall the event loop.
        loop = asyncio.get_running_loop()
        user_command = await loop.run_in_executor(None, transcribe, audio)
    except Exception:
        logger.exception("Speech recognition failed for %r", audio)
        return "Error: Could not transcribe audio.", "Error: Could not process your request.", None

    try:
        reply = handle_cart(user_command)
    except Exception:
        logger.exception("Could not handle command %r", user_command)
        return user_command, "Error: Could not process your request.", None

    try:
        reply_audio_path = await text_to_speech(reply)
    except Exception:
        # Speech output is best-effort: still show the transcript and reply text.
        logger.exception("Text-to-speech failed")
        return user_command, reply, None

    return user_command, reply, reply_audio_path
# UI layout: an input row (audio + Submit button) above an output row
# (transcript, reply text, reply audio). Submit runs respond(), whose
# three return values fill the output components in order.
with gr.Blocks() as demo:
    with gr.Row():
        audio_input = gr.Audio(type="filepath", label="Speak Here")
        submit = gr.Button(value="Submit")
    with gr.Row():
        transcribed_text = gr.Textbox(label="Transcribed Text")
        response_text = gr.Textbox(label="GPT Response")
        response_audio = gr.Audio(label="Response Audio")
    submit.click(
        respond,
        [audio_input],
        [transcribed_text, response_text, response_audio],
    )
if __name__ == "__main__":
    # Enable Gradio's request queue on the app, then start the web server.
    queued_app = demo.queue()
    queued_app.launch()