"""Reward scoring for paraphrased comments.

Combines a toxicity classifier with Sentence-BERT semantic similarity to
produce empathy, toxicity, bias, hallucination, and overall reward scores.
"""

import time

import numpy as np
import torch

from model_loader import classifier_model, metrics_models


def softmax(logits):
    """Numerically stable softmax over a 1-D array of logits."""
    exp_logits = np.exp(logits - np.max(logits))
    return exp_logits / exp_logits.sum()


def compute_reward_scores(original, paraphrase):
    """
    Compute reward scores for a paraphrased comment.

    Returns a dictionary with empathy, toxicity, bias, hallucination,
    and reward scores.
    """
    try:
        start_time = time.time()
        print("Starting reward computation...")

        # Bail out early if paraphrase generation failed upstream.
        if not isinstance(paraphrase, str) or "Error: Unable to generate paraphrase" in paraphrase:
            print(f"Invalid paraphrase: {paraphrase}. Returning default scores.")
            return {
                "empathy": 0.0,
                "toxicity": 1.0,
                "bias": 1.0,
                "hallucination": 1.0,
                "reward": 0.0,
            }

        # Classify the paraphrase with the toxicity classifier.
        print("Starting classification...")
        inputs = classifier_model.tokenizer(
            paraphrase,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512,
        ).to(classifier_model.device)

        with torch.no_grad():
            outputs = classifier_model.model(**inputs)
        logits = outputs.logits.cpu().numpy()[0]
        probs = softmax(logits)

        # The classifier is treated as binary with index 1 = toxic; empathy is
        # the complement of toxicity, and bias reuses the toxic probability as a proxy.
        toxicity = float(probs[1])
        empathy = 1.0 - toxicity
        bias = float(probs[1])
        print(f"Classification took {time.time() - start_time:.2f} seconds")

        # Semantic similarity between the original and the paraphrase; low
        # similarity is treated as hallucination.
        print("Computing semantic similarity...")
        sentence_bert = metrics_models.sentence_bert
        embeddings = sentence_bert.encode([original, paraphrase], convert_to_tensor=True)
        similarity = torch.cosine_similarity(embeddings[0], embeddings[1], dim=0).item()
        hallucination = 1.0 - similarity
        print(f"Semantic similarity computed: {similarity}")

        # Weighted combination of the individual scores, clamped to [0, 1].
        reward = 0.4 * empathy - 0.2 * toxicity - 0.2 * bias - 0.2 * hallucination
        reward = max(0.0, min(1.0, reward))

        print(f"Total processing time: {time.time() - start_time:.2f} seconds")
        return {
            "empathy": empathy,
            "toxicity": toxicity,
            "bias": bias,
            "hallucination": hallucination,
            "reward": reward,
        }

    except Exception as e:
        print(f"Error in reward computation: {str(e)}")
        return {
            "empathy": 0.0,
            "toxicity": 1.0,
            "bias": 1.0,
            "hallucination": 1.0,
            "reward": 0.0,
        }
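

# Example usage (a minimal sketch): the comment strings below are hypothetical
# and only illustrate the (original, paraphrase) call signature; running this
# requires the models loaded in model_loader to be available.
if __name__ == "__main__":
    example_scores = compute_reward_scores(
        "This take is completely clueless.",
        "I see this differently; here is another way to look at it.",
    )
    print(example_scores)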