import streamlit as st
from transformers import pipeline
import numpy as np
from scipy.io.wavfile import write

# Title of the Streamlit app
st.title("Text-to-Speech+ Generation App")

# Text area for user input
text_input = st.text_area('Enter text prompt')


@st.cache_resource
def _load_pipeline():
    """Build the "suno/bark-small" pipeline once and reuse it across reruns.

    Streamlit re-executes this script from the top on every widget
    interaction; without caching, the model would be re-instantiated each
    time. Exceptions raised while loading propagate to the caller.
    """
    return pipeline(model="suno/bark-small")


# Create the audio generation pipeline
try:
    pipe = _load_pipeline()
except ImportError as e:
    st.error(f"Error importing pipeline from transformers: {e}")
    st.stop()
except OSError as e:
    # Raised when the model files cannot be downloaded or read from disk;
    # show a readable message instead of a raw traceback.
    st.error(f"Error loading model 'suno/bark-small': {e}")
    st.stop()

# Generate audio based on user input.
# `output` is bound up front so the debug section below can distinguish
# "nothing generated during this run" from a real result; previously it was
# only assigned inside the branch, so ticking the checkbox with an empty
# prompt raised NameError.
output = None
if text_input and text_input.strip():  # skip empty / whitespace-only prompts
    with st.spinner('Generating audio...'):
        output = pipe(text_input)

        # Extract the sampling rate and the audio samples from the output
        sampling_rate = output["sampling_rate"]

        # Convert to a float32 numpy array and drop any length-1 axes
        # (presumably the pipeline returns a leading batch/channel axis of
        # size 1 - TODO confirm against the model's output shape).
        audio_array = np.squeeze(
            np.asarray(output["audio"], dtype=np.float32)
        )

        # Save the audio array as a WAV file.
        # NOTE: the fixed path is shared by every session of this app, so
        # concurrent users can overwrite each other's file.
        write("output.wav", sampling_rate, audio_array)

        # Read the saved WAV file back; the context manager closes the
        # handle even if reading fails.
        with open("output.wav", "rb") as audio_file:
            audio_bytes = audio_file.read()

        # Display the output audio
        st.audio(audio_bytes, format="audio/wav")

# Optional: Display JSON output for debugging
if st.checkbox('Show raw output'):
    if output is None:
        st.info("No output to show yet - enter a text prompt first.")
    else:
        st.json(output)