import streamlit as st
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from transformers import pipeline
import librosa
import torch
# from spleeter.separator import Separator
from pydub import AudioSegment
from IPython.display import Audio
import os
import accelerate






# preprocess and crop audio file
def audio_preprocess(file_name = '/test1/vocals.wav', start_time = 60000, end_time = 110000):
   """Load an audio file and return a cropped excerpt of it.

   Vocal/music separation is NOT performed here: the previous body called
   ``Separator`` (whose import is disabled in this file) with undefined
   ``input_file``/``output_file`` names, so it raised ``NameError`` on every
   call. The default ``file_name`` suggests an already-separated vocal track
   is expected as input -- TODO confirm with the pipeline owner.

   Args:
      file_name: Path of the audio file to load with ``AudioSegment.from_file``.
      start_time: Start of the excerpt in milliseconds (default 60000 = 60 s).
      end_time: End of the excerpt in milliseconds (default 110000 = 110 s).

   Returns:
      The ``AudioSegment`` slice ``audio[start_time:end_time]``.

   Raises:
      ValueError: If the crop window is not ``0 <= start_time < end_time``.
   """
   # Validate before touching the file so a bad window fails fast and cheaply.
   if start_time < 0 or end_time <= start_time:
       raise ValueError(
           f"invalid crop window: start_time={start_time}, end_time={end_time} "
           "(expected 0 <= start_time < end_time, in milliseconds)"
       )

   audio = AudioSegment.from_file(file_name)

   # pydub slices are expressed in milliseconds.
   processed_audio = audio[start_time:end_time]
   # processed_audio.export('cropped_vocals.wav', format='wav')  # optionally save the excerpt
   return processed_audio




# ASR transcription
def asr_model(processed_audio):
   """Transcribe an audio input with the fine-tuned Whisper checkpoint.

   Args:
      processed_audio: Audio source handed straight to ``librosa.load``.

   Returns:
      The first decoded sequence produced by the model, as text with
      special tokens removed. The text is also printed to stdout.
   """
   checkpoint = "RexChan/ISOM5240-whisper-small-zhhk_1"

   # Decode the audio first, requesting a 16 kHz sampling rate.
   waveform, sample_rate = librosa.load(processed_audio, sr=16000)

   # Processor and model both come from the same checkpoint.
   whisper_processor = WhisperProcessor.from_pretrained(checkpoint)
   whisper = WhisperForConditionalGeneration.from_pretrained(
       checkpoint,
       low_cpu_mem_usage=True,
   )

   # Override generation-related config fields on the loaded model.
   config_overrides = (
       ("forced_decoder_ids", None),
       ("suppress_tokens", []),
       ("use_cache", False),
   )
   for field_name, field_value in config_overrides:
       setattr(whisper.config, field_name, field_value)

   features = whisper_processor(waveform, sampling_rate=sample_rate, return_tensors="pt")
   generation = whisper.generate(
       input_features=features.input_features,
       output_scores=True,
       return_dict_in_generate=True,
   )
   decoded = whisper_processor.batch_decode(generation.sequences, skip_special_tokens=True)
   transcription = decoded[0]

   # Echo the recognised lyrics to the console.
   print(f"Song lyrics = {transcription}")

   return transcription




# sentiment analysis
def senti_model(transcription):
   """Classify the sentiment of the transcribed lyrics.

   Args:
      transcription: Text passed to the text-classification pipeline.

   Returns:
      A sentence reporting the label and score (as a percentage with one
      decimal) of the first pipeline result. The sentence is also printed.
   """
   classifier = pipeline(
       "text-classification",
       model="lxyuan/distilbert-base-multilingual-cased-sentiments-student",
   )
   final_result = classifier(transcription)

   # Only the first entry of the pipeline output is reported.
   top_prediction = final_result[0]
   label = top_prediction['label']
   confidence_pct = top_prediction['score']*100

   display = f"Sentiment Analysis shows that this song is {label}. Confident level of this analysis is {confidence_pct:.1f}%."
   print(display)
   return display




# main
def main(input_file):
   """Run transcription and sentiment analysis, then show the results.

   Args:
      input_file: The audio to analyse, passed unchanged to ``asr_model``
         and ``st.audio``. ``None`` means nothing was uploaded.

   Returns:
      None. Output is written to the Streamlit page.
   """
   # Clicking "Run Analysis" without an upload used to crash inside the
   # audio loader; report the problem on the page instead.
   if input_file is None:
       st.warning("Please upload an mp3 file before running the analysis.")
       return

   # processed_audio = audio_preprocess(input_file)
   processed_audio = input_file

   transcription = asr_model(processed_audio)
   final_result = senti_model(transcription)
   st.write(final_result)

   # The loader may have consumed the upload; rewind file-like inputs so the
   # player receives the data from the first byte.
   if hasattr(input_file, "seek"):
       input_file.seek(0)

   # Render the player directly. The previous "Play Audio" button referenced
   # an undefined `audio_data` and, being nested under another button's
   # branch, could never run: clicking it reruns the script with the outer
   # button reset, so this function is not reached. The player widget has its
   # own play control.
   st.audio(input_file, format="audio/mpeg")




if __name__ == '__main__':
   # Streamlit page setup: title, header and the mp3 uploader.
   st.set_page_config(page_title="Sentiment Analysis on Your Cantonese Song",)
   st.header("Cantonese Song Sentiment Analyzer")
   input_file = st.file_uploader("upload a song in mp3 format", type="mp3")

   # Tell the user whether an upload is currently available.
   if input_file is None:
       st.write("No file uploaded.")
   else:
       st.write("File uploaded successfully!")
       st.write(input_file)

   # Start the analysis only when the primary button is pressed.
   if st.button("Run Analysis", type="primary"):
       main(input_file=input_file)