import gradio as gr
import regex as re
import torch
import nltk
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from nltk.tokenize import sent_tokenize
import plotly.express as px
import time
import tqdm
nltk.download('punkt_tab')
# Define the model and tokenizer
checkpoint = "sadickam/sdg-classification-bert"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
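# Note: per the app description below, this checkpoint covers 16 of the 17 SDGs,
# which is why the label lists used later contain 16 entries.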
# Define the function for preprocessing text
def prep_text(text):
    clean_sents = []
    sent_tokens = sent_tokenize(str(text))
    for sent_token in sent_tokens:
        word_tokens = [str(word_token).strip().lower() for word_token in sent_token.split()]
        clean_sents.append(' '.join(word_tokens))
    joined = ' '.join(clean_sents).strip(' ')
    joined = re.sub(r'`', "", joined)
    joined = re.sub(r'"', "", joined)
    return joined
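# Illustrative behaviour (hypothetical example, not from the source):
#   prep_text('He said, "Clean water for ALL".') -> 'he said, clean water for all.'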
# APP INFO
def app_info():
    check = """
    Please go to either the "Single-Text-Prediction" or "Multi-Text-Prediction" tab to analyse your text.
    """
    return check
# Create Gradio interface for general app information
iface1 = gr.Interface(
    fn=app_info, inputs=None, outputs=['text'], title="General-Information",
    description='''
    This app, powered by the sdgBERT model (sadickam/sdg-classification-bert), automatically classifies text with respect to
    the UN Sustainable Development Goals (SDGs). Note that 16 of the 17 SDG labels are covered. The app is intended for sustainability
    assessment and benchmarking and is not limited to a specific industry. The model powering this app was developed using the
    OSDG Community Dataset (OSDG-CD) [Link - https://zenodo.org/record/5550238#.Y8Sd5f5ByF5].
    The app has two analysis modules, summarised below:
    - Single-Text-Prediction - Analyses text pasted into a text box and returns an SDG prediction.
    - Multi-Text-Prediction - Analyses multiple rows of text in an uploaded CSV file and returns a downloadable CSV file with an SDG prediction for each row.
    This app runs on a free server and may therefore not be suitable for analysing large CSV and PDF files.
    If you need assistance with analysing large CSV or PDF files, do get in touch using the contact information in the Contact section.
    <h3>Contact</h3>
    <p>We would be happy to receive your feedback on this app. If you would also like to collaborate with us to explore use cases for the model
    powering this app, we would be glad to hear from you.</p>
    ''')
# SINGLE TEXT
# Define the prediction function
def predict_sdg(text):
    # Preprocess the input text
    cleaned_text = prep_text(text)
    if cleaned_text == "":
        raise gr.Error('This model needs some text input to return a prediction')

    # Tokenize the preprocessed text
    tokenized_text = tokenizer(cleaned_text, return_tensors="pt", truncation=True, max_length=512, padding=True)

    # Predict
    text_logits = model(**tokenized_text).logits
    predictions = torch.softmax(text_logits, dim=1).tolist()[0]

    # SDG labels
    label_list = [
        'GOAL 1: No Poverty',
        'GOAL 2: Zero Hunger',
        'GOAL 3: Good Health and Well-being',
        'GOAL 4: Quality Education',
        'GOAL 5: Gender Equality',
        'GOAL 6: Clean Water and Sanitation',
        'GOAL 7: Affordable and Clean Energy',
        'GOAL 8: Decent Work and Economic Growth',
        'GOAL 9: Industry, Innovation and Infrastructure',
        'GOAL 10: Reduced Inequality',
        'GOAL 11: Sustainable Cities and Communities',
        'GOAL 12: Responsible Consumption and Production',
        'GOAL 13: Climate Action',
        'GOAL 14: Life Below Water',
        'GOAL 15: Life on Land',
        'GOAL 16: Peace, Justice and Strong Institutions'
    ]

    # Dictionary with label as key and probability as value
    pred_dict = dict(zip(label_list, predictions))

    # Sort 'pred_dict' by value and index the highest at [0]
    sorted_preds = sorted(pred_dict.items(), key=lambda x: x[1], reverse=True)

    # Make dataframe for plotly bar chart
    df2 = pd.DataFrame(sorted_preds, columns=['SDG', 'Likelihood'])

    # Plot graph of predictions
    fig = px.bar(df2, x="Likelihood", y="SDG", orientation="h")
    fig.update_layout(
        # barmode='stack',
        template='seaborn', font=dict(family="Arial", size=12, color="black"),
        autosize=True,
        # width=800,
        # height=500,
        xaxis_title="Likelihood of SDG",
        yaxis_title="Sustainable development goals (SDG)",
        # legend_title="Topics"
    )
    fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_annotations(font_size=12)  # this changes y_axis, x_axis and subplot title font sizes

    # Return the top prediction and the full likelihood chart
    top_prediction = sorted_preds[0]
    return {top_prediction[0]: round(top_prediction[1], 3)}, fig
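# Illustrative usage (hypothetical input and score; actual values depend on the model):
#   top_label, fig = predict_sdg("Affordable and clean energy access for rural communities")
#   top_label -> {'GOAL 7: Affordable and Clean Energy': 0.97}; fig is a Plotly bar chart of all 16 SDG likelihoods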
# Create Gradio interface for single text prediction
iface2 = gr.Interface(
    fn=predict_sdg,
    inputs=gr.Textbox(lines=7, label="Paste or type text here"),
    outputs=[gr.Label(label="Top SDG Predicted", show_label=True),
             gr.Plot(label="Likelihood of all SDGs", show_label=True)],
    title="Single Text Prediction",
    article="**Note:** The quality of model predictions may depend on the quality of information provided."
)
# UPLOAD CSV
# Define the prediction function for CSV upload
def predict_sdg_from_csv(file, progress=gr.Progress()):
    # Read the CSV file
    df_docs = pd.read_csv(file)
    text_list = df_docs["text_inputs"].tolist()

    # SDG labels list
    label_list = [
        'GOAL 1: No Poverty',
        'GOAL 2: Zero Hunger',
        'GOAL 3: Good Health and Well-being',
        'GOAL 4: Quality Education',
        'GOAL 5: Gender Equality',
        'GOAL 6: Clean Water and Sanitation',
        'GOAL 7: Affordable and Clean Energy',
        'GOAL 8: Decent Work and Economic Growth',
        'GOAL 9: Industry, Innovation and Infrastructure',
        'GOAL 10: Reduced Inequality',
        'GOAL 11: Sustainable Cities and Communities',
        'GOAL 12: Responsible Consumption and Production',
        'GOAL 13: Climate Action',
        'GOAL 14: Life Below Water',
        'GOAL 15: Life on Land',
        'GOAL 16: Peace, Justice and Strong Institutions'
    ]

    # Lists for appending predictions
    predicted_labels = []
    prediction_score = []

    # Preprocess text and make predictions
    for text_input in progress.tqdm(text_list, desc="Analysing data"):
        time.sleep(0.02)  # Sleep to avoid rate limiting
        cleaned_text = prep_text(text_input)
        tokenized_text = tokenizer(cleaned_text, return_tensors="pt", truncation=True, max_length=512, padding=True)
        text_logits = model(**tokenized_text).logits
        predictions = torch.softmax(text_logits, dim=1).tolist()[0]
        pred_dict = dict(zip(label_list, predictions))
        sorted_preds = sorted(pred_dict.items(), key=lambda g: g[1], reverse=True)
        predicted_labels.append(sorted_preds[0][0])
        prediction_score.append(sorted_preds[0][1])

    # Append predictions to the DataFrame and save a downloadable copy
    df_docs['SDG_predicted'] = predicted_labels
    df_docs['prediction_score'] = prediction_score
    df_docs.to_csv('sdg_predictions.csv', index=False)
    output_csv = gr.File(value='sdg_predictions.csv', visible=True)

    # Create the histogram of predicted SDGs
    fig = px.histogram(df_docs, y="SDG_predicted")
    fig.update_layout(
        template='seaborn',
        font=dict(family="Arial", size=12, color="black"),
        autosize=True,
        # width=800,
        # height=500,
        xaxis_title="SDG counts",
        yaxis_title="Sustainable development goals (SDG)",
    )
    fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_annotations(font_size=12)

    return fig, output_csv
# Define the input component
file_input = gr.File(label="Upload CSV file here", show_label=True, file_types=[".csv"])

# Create the Gradio interface for CSV upload
iface3 = gr.Interface(
    fn=predict_sdg_from_csv,
    inputs=file_input,
    outputs=[gr.Plot(label='Frequency of SDGs', show_label=True),
             gr.File(label='Download output CSV', show_label=True)],
    title="Multi-text Prediction (CSV)",
    description='**NOTE:** The column to be analysed must be titled ***text_inputs***'
)
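# Illustrative shape of the expected input CSV (only the "text_inputs" column is required;
# the example rows are made up):
#
#   text_inputs
#   "Expanding access to clean drinking water in rural communities."
#   "New wind farms will supply affordable electricity to thousands of homes."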
# Combine the three modules into a tabbed interface
demo = gr.TabbedInterface(
    interface_list=[iface1, iface2, iface3],
    tab_names=["General-App-Info", "Single-Text-Prediction", "Multi-Text-Prediction (CSV)"],
    title="Sustainable Development Goals (SDG) Text Classifier App",
    theme='soft'
)

# Run the interface
demo.queue().launch()