Spaces:
Sleeping
Sleeping
import os | |
from pickle import load, dump | |
import streamlit as st | |
import numpy as np | |
from numpy.typing import NDArray | |
from sentence_transformers import SentenceTransformer | |
from langdetect import detect | |
# Load the multilingual sentence-embedding model a single time at import so
# every caller shares one instance instead of reloading it per request.
model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")
def get_embedding(documents: list[str]) -> NDArray[np.float32]:
    """
    Encode one or more documents with the shared SentenceTransformer model.

    Args:
        documents (list[str]): The texts to embed. A bare string is also
            accepted and is treated as a one-element list.

    Returns:
        NDArray: The value returned by ``model.encode`` for the batch of
            documents.
    """
    # Wrap a lone string so the encoder always receives a list of documents.
    batch = [documents] if isinstance(documents, str) else documents
    return model.encode(batch)
class Guardrail:
    """
    Screen user queries with a pre-trained guardrail classifier.

    The classifier is unpickled from ``MODEL_PATH`` when the object is built
    and is applied to the sentence embeddings returned by ``get_embedding``.

    Attributes:
        guardrail (Any): The unpickled classifier. It must expose ``predict``
            and, for ``incremental_learning``, ``partial_fit``.
    """

    # Location of the pickled classifier, relative to the current working
    # directory. Shared by the loader and by ``incremental_learning``.
    MODEL_PATH = os.path.join(
        "server", "security", "storage", "guardrail_multi.pkl"
    )

    # Language codes accepted by ``analyze_language``
    # (English, French, German, Spanish).
    SUPPORTED_LANGUAGES = frozenset({"en", "fr", "de", "es"})

    def __init__(self):
        """
        Load the pickled guardrail classifier from ``MODEL_PATH``.

        Raises:
            OSError: If the pickle file cannot be opened.
        """
        # NOTE(security): unpickling can execute arbitrary code. This file
        # must only ever come from a trusted source.
        with open(self.MODEL_PATH, "rb") as f:
            self.guardrail = load(f)

    def analyze_language(self, query: str) -> bool:
        """
        Check whether the query is written in a supported language.

        Args:
            query (str): The input query to be analyzed.

        Returns:
            bool: `True` if the detected language is English, French, German
                or Spanish; `False` otherwise, including when no language
                can be detected at all.
        """
        # Imported locally so this class does not depend on a change to the
        # module-level import block.
        from langdetect.lang_detect_exception import LangDetectException

        try:
            detected = detect(query)
        except LangDetectException:
            # langdetect raises instead of returning a code when the text has
            # no usable features (empty string, digits or punctuation only).
            # Such input is treated as "unsupported" rather than crashing.
            return False
        return detected in self.SUPPORTED_LANGUAGES

    def analyze_query(self, query: str) -> bool:
        """
        Run the guardrail classifier on a single query.

        Args:
            query (str): The input query to be analyzed.

        Returns:
            bool: `False` if the classifier labels the query as class 1
                (flagged), `True` otherwise.
        """
        embedding = get_embedding(documents=[query])
        pred = self.guardrail.predict(embedding.reshape(1, -1))
        # ``predict`` yields one label per input row. Pull out the single
        # label so callers receive a plain ``bool`` rather than a one-element
        # NumPy array.
        return bool(np.asarray(pred).ravel()[0] != 1)

    def incremental_learning(self, X_new, y_new):
        """
        Continue training the guardrail on new examples and persist it.

        Args:
            X_new (str | list[str]): One prompt, or a list of prompts, to
                partially fit the classifier on.
            y_new (int | list[int]): The class label (0 or 1) of the prompt,
                or one label per prompt.
        """
        # Feature extraction. ``get_embedding`` wraps a bare string in a list
        # so the feature matrix is 2-D even for a single prompt.
        features = get_embedding(X_new)
        # Labels must be 1-D with one entry per row of ``features``.
        labels = np.atleast_1d(y_new)

        # Incremental update of the classifier.
        self.guardrail.partial_fit(features, labels, classes=[0, 1])

        # Persist the updated classifier to the same file it was loaded from.
        with open(self.MODEL_PATH, "wb") as f:
            dump(self.guardrail, f)