kkngan committed on
Commit 99b23ed · verified · 1 Parent(s): e4d5b68

Update app.py

Files changed (1)
  1. app.py +54 -35
app.py CHANGED
@@ -3,20 +3,24 @@ from streamlit_mic_recorder import mic_recorder
 from transformers import pipeline
 import torch
 from transformers import BertTokenizer, BertForSequenceClassification, AutoModelForSequenceClassification, AutoTokenizer
+from transformers import WhisperForConditionalGeneration, WhisperProcessor
 import numpy as np
 import pandas as pd
+import time
+
 
 def callback():
     if st.session_state.my_recorder_output:
         audio_bytes = st.session_state.my_recorder_output['bytes']
         st.audio(audio_bytes)
 
-def transcribe_and_translate(upload):
-    # pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large")
-    pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3")
-    transcribe_result = pipe(upload, generate_kwargs={'task': 'transcribe'})
-    translate_result = pipe(upload, generate_kwargs={'task': 'translate'})
-    return transcribe_result['text'], translate_result['text']
+
+def translate(inputs, model="openai/whisper-medium"):
+    pipe = pipeline("automatic-speech-recognition", model=model)
+    # transcribe_result = pipe(upload, generate_kwargs={'task': 'transcribe'})
+    translate_result = pipe(inputs, generate_kwargs={'task': 'translate'})
+    return translate_result['text']
+
 
 def encode_depracated(docs, tokenizer):
     '''
@@ -29,16 +33,16 @@ def encode_depracated(docs, tokenizer):
     return input_ids, attention_masks
 
 
-def load_model():
-    CUSTOMMODEL_PATH = "./bert-itserviceclassification"
-    PRETRAINED_LM = "bert-base-uncased"
-    tokenizer = BertTokenizer.from_pretrained(PRETRAINED_LM, do_lower_case=True)
-    model = BertForSequenceClassification.from_pretrained(PRETRAINED_LM,
-                                                          num_labels=8,
-                                                          output_attentions=False,
-                                                          output_hidden_states=False)
-    model.load_state_dict(torch.load(CUSTOMMODEL_PATH, map_location ='cpu'))
-    return model, tokenizer
+# def load_model_deprecated():
+#     CUSTOMMODEL_PATH = "./bert-itserviceclassification"
+#     PRETRAINED_LM = "bert-base-uncased"
+#     tokenizer = BertTokenizer.from_pretrained(PRETRAINED_LM, do_lower_case=True)
+#     model = BertForSequenceClassification.from_pretrained(PRETRAINED_LM,
+#                                                           num_labels=8,
+#                                                           output_attentions=False,
+#                                                           output_hidden_states=False)
+#     model.load_state_dict(torch.load(CUSTOMMODEL_PATH, map_location ='cpu'))
+#     return model, tokenizer
 
 
 def load_model():
@@ -67,18 +71,19 @@ def predict(text, model, tokenizer):
     outputs = model(**inputs)
     predicted_class_id = outputs.logits.argmax().item()
     predicted_label = lookup_key.get(predicted_class_id)
-    probability = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().detach().numpy()
-    return predicted_label, probability
+    confidence = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().detach().numpy()
+    return predicted_label, confidence
 
 
 def main():
-
     st.set_page_config(layout="wide", page_title="NLP IT Service Classification", page_icon="🤖",)
     st.markdown('<b>🤖 Welcome to IT Service Classification Assistant!!! 🤖</b>', unsafe_allow_html=True)
     st.write(f'\n')
+    st.write(f'\n')
 
     with st.sidebar:
         st.image('front_page_image.jpg' , use_column_width=True)
+        text_to_speech_model = st.selectbox("Pick select a text_to_speech_model", ["openai/whisper-base", "openai/whisper-medium", "openai/whisper-large-v3"])
         options = st.selectbox("Pick select an input method", ["Start a recording", "Upload an audio", "Enter a transcript"])
         if options == "Start a recording":
             audio = mic_recorder(key='my_recorder', callback=callback)
@@ -89,32 +94,46 @@ def main():
         button = st.button('Submit')
 
     if button:
+
         with st.spinner(text="Loading... It may take a while if you are running the app for the first time."):
+            start_time = time.time()
             model, tokenizer = load_model()
             if options == "Start a recording":
-                transcibe_text, translate_text = transcribe_and_translate(upload=audio["bytes"])
-                prediction, probability = predict(text=translate_text, model=model, tokenizer=tokenizer)
+                # transcibe_text, translate_text = transcribe_and_translate(upload=audio["bytes"])
+                translate_text = translate(inputs=audio["bytes"], model=text_to_speech_model)
+                prediction, confidence = predict(text=translate_text, model=model, tokenizer=tokenizer)
             elif options == "Upload an audio":
-                transcibe_text, translate_text = transcribe_and_translate(upload=audio.getvalue())
-                prediction, probability = predict(text=translate_text, model=model, tokenizer=tokenizer)
+                # transcibe_text, translate_text = transcribe_and_translate(upload=audio.getvalue())
+                translate_text = translate(inputs=audio.getvalue(), model=text_to_speech_model)
+                prediction, confidence = predict(text=translate_text, model=model, tokenizer=tokenizer)
             else:
-                transcibe_text = text
-                prediction, probability = predict(text=text, model=model, tokenizer=tokenizer)
-            st.markdown('<font color="blue"><b>Transcript:</b></font>', unsafe_allow_html=True)
-            st.write(f'{transcibe_text}')
+                translate_text = text
+                prediction, confidence = predict(text=text, model=model, tokenizer=tokenizer)
+            end_time = time.time()
+            # st.markdown('<font color="blue"><b>Transcript:</b></font>', unsafe_allow_html=True)
+            # st.write(f'{transcibe_text}')
+            # st.write(f'\n')
+            # if options != "Enter a transcript":
+            st.markdown('<font color="purple"><b>(Translated) Text:</b></font>', unsafe_allow_html=True)
+            st.write(f'{translate_text}')
+            st.write(f'\n')
             st.write(f'\n')
-            if options != "Enter a transcript":
-                st.markdown('<font color="red"><b>Translation:</b></font>', unsafe_allow_html=True)
-                st.write(f'{translate_text}')
-                st.write(f'\n')
             st.markdown('<font color="green"><b>Predicted Class:</b></font>', unsafe_allow_html=True)
             st.write(f'{prediction}')
 
-            # Convert probability to bar
+            # Convert confidence to bar cart
+            st.write(f'\n')
+            st.write(f'\n')
+            category = ('Hardware', 'Access', 'Miscellaneous', 'HR Support', 'Purchase', 'Administrative rights', 'Storage', 'Internal Project')
+            confidence = np.array(confidence[0])
+            df = pd.DataFrame({'Category': category, 'Confidence (%)': confidence * 100})
+            df['Confidence (%)'] = df['Confidence (%)'].apply(lambda x: round(x, 2))
+            st.bar_chart(data=df, x='Category', y='Confidence (%)')
+            # df = df.sort_values(by='Confidence (%)', ascending=False).reset_index(drop=True)
+            # st.write(df)
+            st.write(f'\n')
             st.write(f'\n')
-            objects = ('Hardware', 'Access', 'Miscellaneous', 'HR Support', 'Purchase', 'Administrative rights', 'Storage', 'Internal Project')
-            df = pd.DataFrame({'Categories': objects, 'Probability': np.around(probability[0])})
-            st.bar_chart(data=df, x='Categories', y='Probability')
+            st.markdown(f'*It took {(end_time-start_time):.2f} sec to process the input', unsafe_allow_html=True)
 
 if __name__ == '__main__':
     main()
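
For readers who want to exercise the new translate() helper outside Streamlit, here is a minimal sketch. It assumes transformers is installed with ffmpeg available for audio decoding, uses a placeholder file name (sample.wav) that is not part of this repository, and picks model IDs that mirror the choices offered in the new sidebar selectbox.

from transformers import pipeline

def translate(inputs, model="openai/whisper-medium"):
    # Same helper as in this commit: Whisper's "translate" task returns English text
    # for speech in any supported language.
    pipe = pipeline("automatic-speech-recognition", model=model)
    translate_result = pipe(inputs, generate_kwargs={'task': 'translate'})
    return translate_result['text']

if __name__ == '__main__':
    # "sample.wav" is a placeholder path used only for illustration.
    with open('sample.wav', 'rb') as f:
        audio_bytes = f.read()
    print(translate(audio_bytes, model="openai/whisper-base"))

The app passes raw bytes the same way (audio["bytes"] from the recorder, audio.getvalue() from the uploader), so this mirrors what happens when Submit is pressed.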