Spaces:
Sleeping
Sleeping
import streamlit as st | |
from streamlit_mic_recorder import mic_recorder | |
from transformers import pipeline | |
import torch | |
from transformers import BertTokenizer, BertForSequenceClassification | |
# from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
import numpy as np | |
import pandas as pd | |
import time | |
import altair as alt | |
def callback():
    """Play back the latest microphone recording, if there is one.

    Registered as the ``callback`` of ``mic_recorder(key='my_recorder', ...)``.
    Reads ``st.session_state.my_recorder_output`` (presumably populated by the
    mic_recorder component) and, when it is truthy, renders its ``'bytes'``
    payload with ``st.audio``.
    """
    recording = st.session_state.my_recorder_output
    if not recording:
        # Nothing recorded yet: nothing to play back.
        return
    st.audio(recording['bytes'])
@st.cache_resource
def load_text_to_speech_model(model="openai/whisper-base"):
    """Build a Hugging Face automatic-speech-recognition pipeline and cache it.

    Despite its name this loads a *speech-to-text* model (``openai/whisper-base``
    by default). It is wrapped in ``st.cache_resource``, so the pipeline is
    built once per distinct ``model`` id and reused across Streamlit reruns.
    Without the cache, the weights would be reloaded every time the script
    re-executes.

    Args:
        model: Hugging Face model id of the ASR checkpoint to load.

    Returns:
        The ``transformers`` pipeline object for ``"automatic-speech-recognition"``.
    """
    return pipeline("automatic-speech-recognition", model=model)
def translate(inputs, model="openai/whisper-base"):
    """Convert audio into text using Whisper's ``translate`` task.

    The pipeline is obtained from ``load_text_to_speech_model`` rather than
    constructed here. This removes the duplicated pipeline construction and
    lets the loader's caching (if any) avoid reloading weights on every call.

    Args:
        inputs: Audio accepted by the ASR pipeline; callers pass raw bytes
            (a mic recording or the contents of an uploaded file).
        model: Hugging Face model id of the Whisper checkpoint to use.

    Returns:
        The ``'text'`` field of the pipeline result.
    """
    pipe = load_text_to_speech_model(model=model)
    # task='translate' asks Whisper to translate rather than just transcribe,
    # so the downstream classifier receives translated text.
    translate_result = pipe(inputs, generate_kwargs={'task': 'translate'})
    return translate_result['text']
# def encode_depracated(docs, tokenizer): | |
# ''' | |
# This function takes list of texts and returns input_ids and attention_mask of texts | |
# ''' | |
# encoded_dict = tokenizer.batch_encode_plus(docs, add_special_tokens=True, max_length=128, padding='max_length', | |
# return_attention_mask=True, truncation=True, return_tensors='pt') | |
# input_ids = encoded_dict['input_ids'] | |
# attention_masks = encoded_dict['attention_mask'] | |
# return input_ids, attention_masks | |
# def load_classification_model(): | |
# CUSTOMMODEL_PATH = "./bert-itserviceclassification" | |
# PRETRAINED_LM = "bert-base-uncased" | |
# tokenizer = BertTokenizer.from_pretrained(PRETRAINED_LM, do_lower_case=True) | |
# model = BertForSequenceClassification.from_pretrained(PRETRAINED_LM, | |
# num_labels=8, | |
# output_attentions=False, | |
# output_hidden_states=False) | |
# model.load_state_dict(torch.load(CUSTOMMODEL_PATH, map_location ='cpu')) | |
# return model, tokenizer | |
@st.cache_resource
def load_classification_model():
    """Load the fine-tuned BERT IT-service classifier and its tokenizer, once.

    The checkpoint ``kkngan/bert-base-uncased-it-service-classification`` is
    loaded with an 8-way sequence-classification head. ``st.cache_resource``
    keeps the returned objects alive across Streamlit reruns, so the weights
    are loaded on the first call instead of on every Submit.

    Returns:
        ``(model, tokenizer)``: a ``BertForSequenceClassification`` and the
        matching ``BertTokenizer`` (lower-casing enabled).
    """
    PRETRAINED_LM = "kkngan/bert-base-uncased-it-service-classification"
    tokenizer = BertTokenizer.from_pretrained(PRETRAINED_LM, do_lower_case=True)
    model = BertForSequenceClassification.from_pretrained(PRETRAINED_LM,
                                                          num_labels=8)
    return model, tokenizer
def predict(text, model, tokenizer):
    """Classify an IT service request into one of eight categories.

    Args:
        text: The request text (a single string).
        model: Sequence-classification model; its output must expose
            ``.logits`` with one column per category.
        tokenizer: Tokenizer matching ``model``. It is called with
            ``padding=True, truncation=True, return_tensors='pt'`` and its
            output is unpacked into ``model(**inputs)``.

    Returns:
        ``(predicted_label, predicted_class_id, probability)``:
        - ``predicted_label``: category name for the top class, or ``None``
          if the id is not in the lookup table.
        - ``predicted_class_id``: int index of the top class for the first
          (and, for a single string, only) input row.
        - ``probability``: NumPy array of softmax scores with the same shape
          as the logits (``(1, num_labels)`` for a single string).
    """
    lookup_key = {0: 'Hardware',
                  1: 'Access',
                  2: 'Miscellaneous',
                  3: 'HR Support',
                  4: 'Purchase',
                  5: 'Administrative rights',
                  6: 'Storage',
                  7: 'Internal Project'}
    inputs = tokenizer(text,
                       padding=True,
                       truncation=True,
                       return_tensors='pt')
    # Pure inference: disable autograd so no computation graph is recorded
    # (saves memory/time and makes .detach() unnecessary).
    with torch.no_grad():
        logits = model(**inputs).logits
        probability = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()
    # argmax over the first row explicitly; a bare .argmax() flattens the
    # tensor and would yield an out-of-range id for batched logits.
    predicted_class_id = logits[0].argmax().item()
    predicted_label = lookup_key.get(predicted_class_id)
    return predicted_label, predicted_class_id, probability
def display_result(translate_text, prediction, predicted_class_id, probability):
    """Render the classified text, its predicted category and a probability chart.

    Args:
        translate_text: The text that was classified (typed transcript or
            Whisper output).
        prediction: Category label returned by ``predict``.
        predicted_class_id: Index of the predicted category; accepted for
            interface compatibility but not used when rendering.
        probability: Softmax scores as returned by ``predict``. Only
            ``probability[0]`` is plotted; it is assumed to hold one score per
            category, in the order of ``categories`` below.
    """
    def spacer():
        # Two blank writes separate consecutive sections visually.
        for _ in range(2):
            st.write(f'\n')

    st.markdown('<font color="purple"><b>Text:</b></font>', unsafe_allow_html=True)
    st.write(f'{translate_text}')
    spacer()

    st.markdown('<font color="green"><b>Predicted Class:</b></font>', unsafe_allow_html=True)
    st.write(f'{prediction}')
    spacer()

    # Category names in class-id order (same order as predict()'s lookup table).
    categories = (
        'Hardware',
        'Access',
        'Miscellaneous',
        'HR Support',
        'Purchase',
        'Administrative rights',
        'Storage',
        'Internal Project',
    )
    percentages = np.array(probability[0]) * 100
    chart_data = pd.DataFrame({'Category': categories, 'Probability (%)': percentages})
    chart_data['Probability (%)'] = chart_data['Probability (%)'].apply(lambda pct: round(pct, 2))

    encoded = alt.Chart(chart_data).encode(
        x='Probability (%)',
        y=alt.Y('Category').sort('-x'),  # descending by probability
        tooltip=['Category', alt.Tooltip('Probability (%)', format=",.2f")],
        text='Probability (%)',
    )
    titled = encoded.properties(title="Probability of each Service Category")
    # Bars plus a text layer that prints each value just right of its bar.
    layered = titled.mark_bar() + titled.mark_text(align='left', dx=2)
    st.altair_chart(layered, use_container_width=True)
def main():
    """Render the IT Service Classification Streamlit page.

    The sidebar offers a Whisper checkpoint and one of three input methods
    (microphone recording, uploaded audio file, or typed transcript). On
    Submit, audio is converted to text with ``translate``. The text is then
    classified with the model from ``load_classification_model`` via
    ``predict``. The result and the elapsed processing time are displayed.
    If the chosen input is missing, a warning is shown instead of raising.
    """
    # Uncomment to drop cached resources while developing.
    # st.cache_resource.clear()
    st.set_page_config(layout="wide", page_title="NLP IT Service Classification", page_icon="🤖",)
    st.markdown('<b>🤖 Welcome to IT Service Classification Assistant!!! 🤖</b>', unsafe_allow_html=True)
    st.write(f'\n')
    st.write(f'\n')

    # Pre-bind both inputs so neither name is unbound whichever method is chosen.
    audio = None
    text = ""
    with st.sidebar:
        st.image('front_page_image.jpg', use_column_width=True)
        text_to_speech_model = st.selectbox("Pick select a speech to text model", ["openai/whisper-base", "openai/whisper-large-v3"])
        options = st.selectbox("Pick select an input method", ["Start a recording", "Upload an audio", "Enter a transcript"])
        if options == "Start a recording":
            audio = mic_recorder(key='my_recorder', callback=callback)
        elif options == "Upload an audio":
            audio = st.file_uploader("Please upload an audio", type=["wav", "mp3"])
        else:
            text = st.text_area("Please input the transcript (Only support English)")
        button = st.button('Submit')

    if not button:
        return

    # Validate before doing any work: the recorder/uploader can return None
    # when nothing has been recorded/uploaded yet, which would otherwise fail
    # on audio["bytes"] / audio.getvalue(); an empty transcript is rejected too.
    if options == "Enter a transcript":
        if not text.strip():
            st.warning("Please enter a transcript before submitting.")
            return
    elif audio is None:
        st.warning("Please record or upload an audio before submitting.")
        return

    with st.spinner(text="Loading... It may take a while if you are running the app for the first time."):
        # perf_counter is monotonic and meant for measuring elapsed time.
        start_time = time.perf_counter()
        if options == "Start a recording":
            translate_text = translate(inputs=audio["bytes"], model=text_to_speech_model)
        elif options == "Upload an audio":
            translate_text = translate(inputs=audio.getvalue(), model=text_to_speech_model)
        else:
            translate_text = text
        model, tokenizer = load_classification_model()
        prediction, predicted_class_id, probability = predict(text=translate_text, model=model, tokenizer=tokenizer)
        end_time = time.perf_counter()
    display_result(translate_text, prediction, predicted_class_id, probability)
    st.write(f'\n')
    st.write(f'\n')
    st.markdown(f'*It took {(end_time-start_time):.2f} sec to process the input.', unsafe_allow_html=True)
if __name__ == "__main__":
    # Render the app only when this file is executed as a script, not on import.
    main()