kkngan commited on
Commit
8c0ca02
·
verified ·
1 Parent(s): 82c845d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -45
app.py CHANGED
@@ -2,8 +2,8 @@ import streamlit as st
2
  from streamlit_mic_recorder import mic_recorder
3
  from transformers import pipeline
4
  import torch
5
- from transformers import BertTokenizer, BertForSequenceClassification, AutoModelForSequenceClassification, AutoTokenizer
6
- from transformers import WhisperForConditionalGeneration, WhisperProcessor
7
  import numpy as np
8
  import pandas as pd
9
  import time
@@ -15,22 +15,27 @@ def callback():
15
  st.audio(audio_bytes)
16
 
17
 
 
 
 
 
 
 
18
  def translate(inputs, model="openai/whisper-medium"):
19
  pipe = pipeline("automatic-speech-recognition", model=model)
20
- # transcribe_result = pipe(upload, generate_kwargs={'task': 'transcribe'})
21
  translate_result = pipe(inputs, generate_kwargs={'task': 'translate'})
22
  return translate_result['text']
23
 
24
 
25
- def encode_depracated(docs, tokenizer):
26
- '''
27
- This function takes list of texts and returns input_ids and attention_mask of texts
28
- '''
29
- encoded_dict = tokenizer.batch_encode_plus(docs, add_special_tokens=True, max_length=128, padding='max_length',
30
- return_attention_mask=True, truncation=True, return_tensors='pt')
31
- input_ids = encoded_dict['input_ids']
32
- attention_masks = encoded_dict['attention_mask']
33
- return input_ids, attention_masks
34
 
35
 
36
  # def load_model_deprecated():
@@ -44,8 +49,8 @@ def encode_depracated(docs, tokenizer):
44
  # model.load_state_dict(torch.load(CUSTOMMODEL_PATH, map_location ='cpu'))
45
  # return model, tokenizer
46
 
47
-
48
- def load_model():
49
  PRETRAINED_LM = "kkngan/bert-base-uncased-it-service-classification"
50
  model = AutoModelForSequenceClassification.from_pretrained(PRETRAINED_LM, num_labels=8)
51
  tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_LM)
@@ -71,8 +76,28 @@ def predict(text, model, tokenizer):
71
  outputs = model(**inputs)
72
  predicted_class_id = outputs.logits.argmax().item()
73
  predicted_label = lookup_key.get(predicted_class_id)
74
- confidence = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().detach().numpy()
75
- return predicted_label, confidence
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
 
78
  def main():
@@ -83,8 +108,7 @@ def main():
83
 
84
  with st.sidebar:
85
  st.image('front_page_image.jpg' , use_column_width=True)
86
- text_to_speech_model = st.selectbox("Pick select a speech to text model",
87
- ["openai/whisper-base", "openai/whisper-medium", "openai/whisper-large", "openai/whisper-large-v3"])
88
  options = st.selectbox("Pick select an input method", ["Start a recording", "Upload an audio", "Enter a transcript"])
89
  if options == "Start a recording":
90
  audio = mic_recorder(key='my_recorder', callback=callback)
@@ -94,47 +118,27 @@ def main():
94
  text = st.text_area("Please input the transcript (Only support English)")
95
  button = st.button('Submit')
96
 
97
- if button:
98
-
99
  with st.spinner(text="Loading... It may take a while if you are running the app for the first time."):
100
  start_time = time.time()
101
- model, tokenizer = load_model()
102
  if options == "Start a recording":
103
  # transcibe_text, translate_text = transcribe_and_translate(upload=audio["bytes"])
104
  translate_text = translate(inputs=audio["bytes"], model=text_to_speech_model)
105
- prediction, confidence = predict(text=translate_text, model=model, tokenizer=tokenizer)
106
  elif options == "Upload an audio":
107
  # transcibe_text, translate_text = transcribe_and_translate(upload=audio.getvalue())
108
  translate_text = translate(inputs=audio.getvalue(), model=text_to_speech_model)
109
- prediction, confidence = predict(text=translate_text, model=model, tokenizer=tokenizer)
110
  else:
111
  translate_text = text
112
- prediction, confidence = predict(text=text, model=model, tokenizer=tokenizer)
 
113
  end_time = time.time()
114
- # st.markdown('<font color="blue"><b>Transcript:</b></font>', unsafe_allow_html=True)
115
- # st.write(f'{transcibe_text}')
116
- # st.write(f'\n')
117
- # if options != "Enter a transcript":
118
- st.markdown('<font color="purple"><b>(Translated) Text:</b></font>', unsafe_allow_html=True)
119
- st.write(f'{translate_text}')
120
- st.write(f'\n')
121
- st.write(f'\n')
122
- st.markdown('<font color="green"><b>Predicted Class:</b></font>', unsafe_allow_html=True)
123
- st.write(f'{prediction}')
124
 
125
- # Convert confidence to bar cart
126
- st.write(f'\n')
127
- st.write(f'\n')
128
- category = ('Hardware', 'Access', 'Miscellaneous', 'HR Support', 'Purchase', 'Administrative rights', 'Storage', 'Internal Project')
129
- confidence = np.array(confidence[0])
130
- df = pd.DataFrame({'Category': category, 'Confidence (%)': confidence * 100})
131
- df['Confidence (%)'] = df['Confidence (%)'].apply(lambda x: round(x, 2))
132
- st.bar_chart(data=df, x='Category', y='Confidence (%)')
133
- # df = df.sort_values(by='Confidence (%)', ascending=False).reset_index(drop=True)
134
- # st.write(df)
135
  st.write(f'\n')
136
  st.write(f'\n')
137
- st.markdown(f'*It took {(end_time-start_time):.2f} sec to process the input', unsafe_allow_html=True)
 
138
 
139
  if __name__ == '__main__':
140
  main()
 
2
  from streamlit_mic_recorder import mic_recorder
3
  from transformers import pipeline
4
  import torch
5
+ # from transformers import BertTokenizer, BertForSequenceClassification
6
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
7
  import numpy as np
8
  import pandas as pd
9
  import time
 
15
  st.audio(audio_bytes)
16
 
17
 
18
+ @st.cache_resource
19
+ def load_text_to_speech_model(model="openai/whisper-medium"):
20
+ pipe = pipeline("automatic-speech-recognition", model=model)
21
+ return pipe
22
+
23
+
24
  def translate(inputs, model="openai/whisper-medium"):
25
  pipe = pipeline("automatic-speech-recognition", model=model)
 
26
  translate_result = pipe(inputs, generate_kwargs={'task': 'translate'})
27
  return translate_result['text']
28
 
29
 
30
+ # def encode_depracated(docs, tokenizer):
31
+ # '''
32
+ # This function takes list of texts and returns input_ids and attention_mask of texts
33
+ # '''
34
+ # encoded_dict = tokenizer.batch_encode_plus(docs, add_special_tokens=True, max_length=128, padding='max_length',
35
+ # return_attention_mask=True, truncation=True, return_tensors='pt')
36
+ # input_ids = encoded_dict['input_ids']
37
+ # attention_masks = encoded_dict['attention_mask']
38
+ # return input_ids, attention_masks
39
 
40
 
41
  # def load_model_deprecated():
 
49
  # model.load_state_dict(torch.load(CUSTOMMODEL_PATH, map_location ='cpu'))
50
  # return model, tokenizer
51
 
52
+ @st.cache_resource
53
+ def load_classification_model():
54
  PRETRAINED_LM = "kkngan/bert-base-uncased-it-service-classification"
55
  model = AutoModelForSequenceClassification.from_pretrained(PRETRAINED_LM, num_labels=8)
56
  tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_LM)
 
76
  outputs = model(**inputs)
77
  predicted_class_id = outputs.logits.argmax().item()
78
  predicted_label = lookup_key.get(predicted_class_id)
79
+ probability = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().detach().numpy()
80
+ return predicted_label, predicted_class_id, probability
81
+
82
+
83
+ def display_result(translate_text, prediction, predicted_class_id, probability):
84
+ st.markdown('<font color="purple"><b>Text:</b></font>', unsafe_allow_html=True)
85
+ st.write(f'{translate_text}')
86
+ st.write(f'\n')
87
+ st.write(f'\n')
88
+
89
+ st.markdown(f'<font color="green"><b>Predicted Class: (Probability: {(probability[0][predicted_class_id] * 100):.2f}%) </b></font>', unsafe_allow_html=True)
90
+ st.write(f'{prediction}')
91
+
92
+ # Convert probability to bar cart
93
+ st.write(f'\n')
94
+ st.write(f'\n')
95
+
96
+ category = ('Hardware', 'Access', 'Miscellaneous', 'HR Support', 'Purchase', 'Administrative rights', 'Storage', 'Internal Project')
97
+ probability = np.array(probability[0])
98
+ df = pd.DataFrame({'Category': category, 'Probability (%)': probability * 100})
99
+ df['Probability (%)'] = df['Probability (%)'].apply(lambda x: round(x, 2))
100
+ st.bar_chart(data=df, x='Category', y='Probability (%)')
101
 
102
 
103
  def main():
 
108
 
109
  with st.sidebar:
110
  st.image('front_page_image.jpg' , use_column_width=True)
111
+ text_to_speech_model = st.selectbox("Pick select a speech to text model", ["openai/whisper-base", "openai/whisper-medium", "openai/whisper-large", "openai/whisper-large-v3"])
 
112
  options = st.selectbox("Pick select an input method", ["Start a recording", "Upload an audio", "Enter a transcript"])
113
  if options == "Start a recording":
114
  audio = mic_recorder(key='my_recorder', callback=callback)
 
118
  text = st.text_area("Please input the transcript (Only support English)")
119
  button = st.button('Submit')
120
 
121
+ if button:
 
122
  with st.spinner(text="Loading... It may take a while if you are running the app for the first time."):
123
  start_time = time.time()
 
124
  if options == "Start a recording":
125
  # transcibe_text, translate_text = transcribe_and_translate(upload=audio["bytes"])
126
  translate_text = translate(inputs=audio["bytes"], model=text_to_speech_model)
 
127
  elif options == "Upload an audio":
128
  # transcibe_text, translate_text = transcribe_and_translate(upload=audio.getvalue())
129
  translate_text = translate(inputs=audio.getvalue(), model=text_to_speech_model)
 
130
  else:
131
  translate_text = text
132
+ model, tokenizer = load_classification_model()
133
+ prediction, predicted_class_id, probability = predict(text=translate_text, model=model, tokenizer=tokenizer)
134
  end_time = time.time()
 
 
 
 
 
 
 
 
 
 
135
 
136
+ display_result(translate_text, prediction, predicted_class_id, probability)
137
+
 
 
 
 
 
 
 
 
138
  st.write(f'\n')
139
  st.write(f'\n')
140
+ st.markdown(f'*It took {(end_time-start_time):.2f} sec to process the input.', unsafe_allow_html=True)
141
+
142
 
143
  if __name__ == '__main__':
144
  main()