|
import streamlit as st |
|
import torch |
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
from transformers import AutoTokenizer, AutoModel |
|
import os |
|
from transformers import AutoTokenizer, AutoModelForMaskedLM |
|
from datetime import datetime |
|
|
|
|
|
@st.cache_resource
def _load_bert(model_name="AshenR/AshenBERTo"):
    """Load the AshenBERTo tokenizer and encoder once per server process.

    Streamlit re-executes this whole script on every widget interaction, so
    calling ``from_pretrained`` at module level would rebuild the model and
    re-read its weights on every click. ``st.cache_resource`` keeps a single
    shared copy instead.

    Returns:
        ``(tokenizer, model)`` for ``model_name``.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    # Hidden states are requested because get_embeddings2 pools hidden_states[-1].
    enc = AutoModel.from_pretrained(model_name, output_hidden_states=True)
    return tok, enc


tokenizer, modelBert = _load_bert()

# Size of one sentence embedding fed to the Siamese network. Written as a
# 1-tuple so it is an actual shape; the old ``(768)`` was just the int 768
# (both reshape identically).
ishape = (768,)
|
def get_embeddings2(text, token_length, device='cuda'):
    """Embed *text* with the module-level BERT model and mean-pool the last layer.

    The text is tokenized to exactly ``token_length`` tokens (truncated, or
    padded with ``padding='max_length'``), run through ``modelBert`` without
    gradient tracking, and the final hidden state is averaged over axis 1
    (the token axis).

    NOTE(review): the average covers every position, padding included; the
    attention mask is passed to the model but not used for pooling. The saved
    Siamese model was presumably trained on features produced exactly this
    way, so changing the pooling would likely require retraining — confirm
    before touching it.

    Args:
        text: Sentence (or list of sentences) accepted by the tokenizer.
        token_length: Fixed sequence length to pad/truncate to.
        device: Preferred torch device; CPU is used when CUDA is unavailable.

    Returns:
        NumPy array of pooled vectors, one row per input text
        (presumably ``(batch, hidden_size)``).
    """
    target = torch.device(device if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {target}")

    encoded = tokenizer(
        text,
        max_length=token_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    ids = encoded.input_ids.to(target)
    mask = encoded.attention_mask.to(target)

    # nn.Module.to moves the shared model's parameters in place.
    modelBert.to(target)

    with torch.no_grad():
        last_hidden = modelBert(ids, attention_mask=mask).hidden_states[-1]

    pooled = last_hidden.mean(dim=1)
    return pooled.cpu().detach().numpy()
|
|
|
import tensorflow as tf |
|
from tensorflow.keras import layers, Model, callbacks |
|
from tensorflow.keras.layers import ( |
|
Dense, Dropout, BatchNormalization, Activation, Input, Bidirectional, LSTM, GlobalAveragePooling1D |
|
) |
|
from tensorflow.keras.regularizers import l2 |
|
import os |
|
|
|
class SiameseNetwork(Model):
    """Siamese binary classifier over pairs of fixed-size inputs.

    Both inputs pass through one shared feature extractor. For every entry of
    ``featExtractorConfig`` it applies Dense -> BatchNormalization -> ReLU ->
    Dropout, optionally followed by a bidirectional LSTM and global average
    pooling. The two feature vectors are combined according to
    ``distance_metric``, and a single sigmoid unit produces the output score.

    The functional graph lives in ``self.model``; ``call`` delegates to it.
    ``save_model`` / ``load_model`` persist that inner model together with
    the constructor arguments (as JSON).
    """

    def __init__(self, inputShape, featExtractorConfig, lstm_units=128, dropout_rate=0.5,
                 add_lstm=True, distance_metric="concat", regularization=0.01):
        """Build the shared extractor and the two-input scoring graph.

        Args:
            inputShape: Shape of one input, excluding the batch axis.
            featExtractorConfig: Unit counts of the Dense blocks, in order.
            lstm_units: Units per direction of the optional BiLSTM.
            dropout_rate: Dropout after each Dense block and inside the LSTM.
            add_lstm: Whether to append the BiLSTM + pooling stage.
            distance_metric: ``"concat"``, ``"euclidean"`` or ``"cosine"``.
            regularization: L2 factor applied to the Dense kernels.

        Raises:
            ValueError: for an unsupported ``distance_metric``, or when
                ``add_lstm`` is set but ``featExtractorConfig`` is empty.
        """
        super().__init__()

        # Validate up front so a bad configuration fails before any layer is built.
        if distance_metric not in ("concat", "euclidean", "cosine"):
            raise ValueError(f"Unsupported distance metric: {distance_metric}")
        if add_lstm and not featExtractorConfig:
            # The LSTM stage reshapes by the last Dense width, so it needs one.
            raise ValueError(
                "featExtractorConfig must list at least one layer size when add_lstm=True"
            )

        # Constructor arguments are stored verbatim; save_model() serialises them.
        self.inputShape = inputShape
        self.featExtractorConfig = featExtractorConfig
        self.add_lstm = add_lstm
        self.lstm_units = lstm_units
        self.dropout_rate = dropout_rate
        self.distance_metric = distance_metric
        self.regularization = regularization

        inp_a = layers.Input(shape=inputShape, name="Input_A")
        inp_b = layers.Input(shape=inputShape, name="Input_B")

        # A single extractor instance is applied to both inputs, so the two
        # branches share weights.
        self.feature_extractor = self.build_feature_extractor()
        feats_a = self.feature_extractor(inp_a)
        feats_b = self.feature_extractor(inp_b)

        if distance_metric == "concat":
            distance = layers.Concatenate()([feats_a, feats_b])
        elif distance_metric == "euclidean":
            # sqrt(max(sum_sq, eps)) rather than tf.norm: the gradient of the
            # norm is NaN when both feature vectors coincide (distance 0).
            distance = layers.Lambda(
                lambda t: tf.sqrt(
                    tf.maximum(
                        tf.reduce_sum(tf.square(t[0] - t[1]), axis=1, keepdims=True),
                        1e-7,
                    )
                )
            )([feats_a, feats_b])
        else:  # "cosine"
            # Dot(normalize=True) yields the cosine similarity with shape (batch, 1).
            # The former tf.keras.losses.cosine_similarity Lambda returned the
            # *negated* similarity as a rank-1 tensor, whose last axis is the
            # undefined batch axis, so the Dense head below could not be built.
            distance = layers.Dot(axes=1, normalize=True)([feats_a, feats_b])

        outputs = layers.Dense(1, activation="sigmoid", name="Output")(distance)

        self.model = Model(inputs=[inp_a, inp_b], outputs=outputs)

    def build_feature_extractor(self):
        """Return the shared per-branch encoder as a standalone ``Model``.

        Each entry of ``featExtractorConfig`` adds a linear, L2-regularised
        Dense layer followed by BatchNormalization, ReLU and Dropout. With
        ``add_lstm`` the result is reshaped to ``(-1, featExtractorConfig[-1])``,
        run through a bidirectional LSTM, and averaged over the sequence
        axis. When the Dense output is a flat ``(batch, width)`` vector, the
        LSTM therefore sees a single time step.
        """
        inputs = Input(shape=self.inputShape)
        x = inputs

        for n_units in self.featExtractorConfig:
            x = Dense(n_units, activation=None, kernel_regularizer=l2(self.regularization))(x)
            x = BatchNormalization()(x)
            x = Activation('relu')(x)
            x = Dropout(self.dropout_rate)(x)

        if self.add_lstm:
            x = layers.Reshape((-1, self.featExtractorConfig[-1]))(x)
            x = Bidirectional(LSTM(self.lstm_units, return_sequences=True, dropout=self.dropout_rate))(x)
            x = GlobalAveragePooling1D()(x)

        return Model(inputs, x, name="FeatureExtractor")

    def call(self, inputs):
        """Forward ``[input_a, input_b]`` through the wrapped functional model."""
        return self.model(inputs)

    def save_model(self, filepath):
        """Save the inner model and the constructor config into directory *filepath*.

        Writes ``siamese_model.keras`` (the functional graph and weights) and
        ``model_config.json`` (the constructor arguments). The directory is
        created if needed.
        """
        os.makedirs(filepath, exist_ok=True)

        self.model.save(os.path.join(filepath, 'siamese_model.keras'))

        config = {
            'inputShape': self.inputShape,
            'featExtractorConfig': self.featExtractorConfig,
            'lstm_units': self.lstm_units,
            'dropout_rate': self.dropout_rate,
            'add_lstm': self.add_lstm,
            'distance_metric': self.distance_metric,
            'regularization': self.regularization
        }

        import json
        with open(os.path.join(filepath, 'model_config.json'), 'w') as f:
            json.dump(config, f)

        print(f"Model saved successfully to {filepath}")

    @classmethod
    def load_model(cls, filepath):
        """Recreate a network written by ``save_model`` from directory *filepath*.

        The constructor is re-run from ``model_config.json``, which gives a
        fresh Keras wrapper. Its newly initialised graph is then replaced by
        the saved ``siamese_model.keras``.

        NOTE(review): models saved with the "euclidean" metric contain a
        Python lambda; Keras 3 only deserialises those with
        ``safe_mode=False``. Confirm against the installed Keras version.
        """
        import json

        with open(os.path.join(filepath, 'model_config.json'), 'r') as f:
            config = json.load(f)

        siamese_net = cls(**config)
        siamese_net.model = tf.keras.models.load_model(os.path.join(filepath, 'siamese_model.keras'))

        # Point feature_extractor at the trained sub-model inside the loaded
        # graph; otherwise it keeps referencing the untrained extractor the
        # constructor just built. Graphs lacking that layer name keep the
        # constructor's extractor, which matches the previous behaviour.
        try:
            siamese_net.feature_extractor = siamese_net.model.get_layer("FeatureExtractor")
        except ValueError:
            pass

        print(f"Model loaded successfully from {filepath}")
        return siamese_net
|
|
|
|
|
|
|
|
|
@st.cache_resource
def _load_siamese_net(filepath='./saved_siamese_model'):
    """Load the trained Siamese network once and share it across reruns and sessions."""
    return SiameseNetwork.load_model(filepath)


@st.cache_resource
def predict(input_1, input_2):
    """Score how similar two sentences are.

    Both sentences are embedded with ``get_embeddings2`` (100-token window)
    and reshaped to ``ishape``. They are then fed to the Siamese network as a
    batch holding one pair, and the network's output is formatted as a
    percentage.

    Results are cached per ``(input_1, input_2)`` by ``st.cache_resource``.
    The network itself comes from the separately cached ``_load_siamese_net``,
    so it is no longer re-read from disk for every new sentence pair.

    Returns:
        A string of the form ``"Predicted Similarity: <score> %"``.
    """
    siamese_net = _load_siamese_net()

    embedding1 = get_embeddings2(input_1, token_length=100).reshape(ishape)
    embedding2 = get_embeddings2(input_2, token_length=100).reshape(ishape)

    # Add the leading batch axis the Keras model expects.
    x1_test = np.array([embedding1])
    x2_test = np.array([embedding2])

    score = siamese_net.predict([x1_test, x2_test])[0][0]
    pred = str(round(score * 100, 2)) + " %"

    result = f"Predicted Similarity: {pred}"
    print(input_1, input_2, pred)
    return result
|
|
|
|
|
import pymongo |
|
|
|
def connect_to_mongodb(connection_string, server_selection_timeout_ms=5000):
    """Open a MongoDB client and verify that the server is reachable.

    Args:
        connection_string: MongoDB URI (``mongodb://`` or ``mongodb+srv://``).
        server_selection_timeout_ms: How long to wait for a reachable server
            before giving up. pymongo's own default is 30 s, which would
            freeze the Streamlit page that long on every failed save.

    Returns:
        A connected ``pymongo.MongoClient``, or ``None`` if the URI is
        invalid, authentication fails, or no server answers in time. A
        client created before the failure is closed so that its background
        resources are released.
    """
    client = None
    try:
        client = pymongo.MongoClient(
            connection_string,
            serverSelectionTimeoutMS=server_selection_timeout_ms,
        )
        # "ping" is the lightweight connectivity check; "ismaster" is legacy.
        client.admin.command("ping")
    except pymongo.errors.PyMongoError as e:
        # PyMongoError also covers ConfigurationError (bad URI / SRV lookup)
        # and OperationFailure (bad credentials), not just ConnectionFailure.
        print(f"Could not connect to MongoDB Atlas: {e}")
        if client is not None:
            client.close()
        return None

    print("Successfully connected to MongoDB Atlas!")
    return client
|
|
|
|
|
|
|
|
|
|
|
def _seed_feedback_flags():
    """Create the per-session feedback flags, defaulting to False.

    Streamlit re-executes the script on every interaction, so existing
    values are left untouched and only missing keys are added.
    """
    for flag in ("feedback_submitted", "show_feedback"):
        if flag not in st.session_state:
            st.session_state[flag] = False


_seed_feedback_flags()
|
|
|
def save_feedback(input_1, input_2, prediction, rating, feedback, connection_string=None):
    """Store one comparison record (with optional user feedback) in MongoDB.

    The document goes to the ``survey`` collection of the ``user_predictions``
    database. The sentences, prediction and feedback are stored as strings;
    ``timestamp`` is the current UTC time.

    Args:
        input_1: First compared sentence.
        input_2: Second compared sentence.
        prediction: Prediction text shown to the user.
        rating: User-supplied accuracy rating.
        feedback: Free-text feedback.
        connection_string: MongoDB URI. Defaults to the ``MONGODB_URI``
            environment variable. Credentials must not live in source: the
            URI that used to be hard-coded here was exposed and should be
            rotated.

    Returns:
        True if the document was inserted; False if no URI is configured,
        the connection fails, or the insert fails. An insert failure is also
        reported via ``st.error``.
    """
    from datetime import timezone

    if connection_string is None:
        connection_string = os.environ.get("MONGODB_URI")
    if not connection_string:
        print("MongoDB connection string not configured (set MONGODB_URI); feedback not saved.")
        return False

    client = connect_to_mongodb(connection_string)
    if client is None:
        return False

    try:
        collection = client["user_predictions"]["survey"]
        document = {
            # pymongo interprets naive datetimes as UTC, so a naive local
            # datetime.now() would be stored with the wrong instant.
            "timestamp": datetime.now(timezone.utc),
            "input_1": str(input_1),
            "input_2": str(input_2),
            "prediction": str(prediction),
            "rating": rating,
            "feedback": str(feedback)
        }
        collection.insert_one(document)
    except Exception as e:
        # UI boundary: show the failure instead of crashing the page.
        st.error(f"Database Error: {e}")
        return False
    finally:
        client.close()

    print("Data Saved to mongoDB")
    return True
|
|
|
|
|
|
|
def _clear_app_state():
    """Reset the app to its initial state; used as the Clear button's callback.

    Streamlit only allows a widget-backed key to be overwritten before that
    widget is created in the current run. ``on_click`` callbacks run before
    the script re-executes, so this is where the sentence inputs can be
    blanked.
    """
    st.session_state.clear()
    st.session_state["input_1"] = ""
    st.session_state["input_2"] = ""


st.title("Sinhala Short Sentence Similarity")

st.write("Compare the similarity between two Sinhala sentences.")

# Widget keys bind the inputs to st.session_state. Without them, writing
# st.session_state.input_1 has no effect on the text boxes.
input_1 = st.text_input("First Sentence:", key="input_1")

input_2 = st.text_input("Second Sentence:", key="input_2")

if st.button("Compare Sentences", type="primary"):
    if input_1 and input_2:
        with st.spinner("Calculating similarity..."):
            result = predict(input_1, input_2)
            if result:
                st.success(result)
                st.session_state.show_feedback = True
                # A fresh comparison re-opens the feedback form. Without this
                # reset it would stay hidden for the rest of the session after
                # the first submission.
                st.session_state.feedback_submitted = False
                # Remember the pair that was actually scored, so later feedback
                # is not attributed to sentences edited after comparing.
                st.session_state.last_comparison = (input_1, input_2, result)
                # Every comparison is logged with a placeholder rating/feedback.
                save_feedback(input_1, input_2, result, 0, "Null")
    else:
        st.warning("Please enter both sentences to compare.")

if st.session_state.get("show_feedback", False) and not st.session_state.get("feedback_submitted", False):
    st.subheader("Feedback")
    is_correct = st.radio("Is this similarity assessment correct?", ("Yes", "No"))

    # NOTE(review): only "No" answers can be submitted; a "Yes" is never stored.
    if is_correct == "No":
        rating = st.slider("How accurate was the prediction?", 0.0, 1.0, 0.5, 0.1)
        feedback = st.text_area("Please provide detailed feedback:")

        if st.button("Submit Feedback"):
            compared = st.session_state.get("last_comparison")
            if compared is None:
                # Fallback: score whatever is currently in the inputs.
                compared = (input_1, input_2, predict(input_1, input_2))
            first, second, scored = compared
            if save_feedback(first, second, scored, rating, feedback):
                st.success("Thank you for your feedback!")
                st.session_state.feedback_submitted = True

# Offer a reset once feedback has been submitted. This is a top-level button
# with a callback: a button nested under another button's `if` is not rendered
# on the rerun its own click triggers, so that click would be lost.
if st.session_state.get("feedback_submitted", False):
    st.button("Clear", on_click=_clear_app_state)

st.markdown("""
    <hr>
    <p style="text-align: center; color: gray; font-size: 0.8em;">
        Developed by Ashen | Version 1.0
    </p>
""", unsafe_allow_html=True)
|
|
|
|