# classifier.py
import torch
import time
from model_loader import classifier_model
from paraphraser import paraphrase_comment
from metrics import compute_semantic_similarity, compute_empathy_score, compute_bias_score, compute_hallucination_score

def compute_reward_scores(original, paraphrased):
    """
    Compute all reward scores for a paraphrase.
    Returns a dictionary with empathy, toxicity, bias, hallucination, and overall reward.
    """
    try:
        # Toxicity of the paraphrase: the 4th value returned by
        # classify_toxic_comment is the toxicity score of the text passed in.
        _, _, _, toxicity, *_ = classify_toxic_comment(paraphrased)
        toxicity = toxicity if toxicity is not None else 0.5

        # Compute the other metrics, falling back to a neutral 0.5 only when a
        # metric returns None (a bare `or` would also override a legitimate 0.0).
        empathy = compute_empathy_score(paraphrased)
        empathy = empathy if empathy is not None else 0.5
        bias = compute_bias_score(paraphrased)
        bias = bias if bias is not None else 0.5
        hallucination = compute_hallucination_score(original, paraphrased)
        hallucination = hallucination if hallucination is not None else 0.5

        # Overall reward: weighted combination (adjust weights as needed), clamped to [0.0, 1.0]
        reward = (0.4 * empathy) - (0.2 * toxicity) - (0.2 * bias) - (0.2 * hallucination)
        reward = max(0.0, min(1.0, round(reward, 2)))

        return {
            "empathy": empathy,
            "toxicity": toxicity,
            "bias": bias,
            "hallucination": hallucination,
            "reward": reward
        }
    except Exception as e:
        print(f"Error computing reward scores: {str(e)}")
        return {
            "empathy": 0.5,
            "toxicity": 0.5,
            "bias": 0.5,
            "hallucination": 0.5,
            "reward": 0.5
        }
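
# Worked example of the reward formula above (illustrative numbers only):
# with empathy=0.9, toxicity=0.1, bias=0.05, hallucination=0.1 the raw score is
#   0.4*0.9 - 0.2*0.1 - 0.2*0.05 - 0.2*0.1 = 0.36 - 0.02 - 0.01 - 0.02 = 0.31,
# which lies inside the [0.0, 1.0] clamp, so the returned reward is 0.31.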

def classify_toxic_comment(comment):
    """
    Classify a comment as toxic or non-toxic using the fine-tuned XLM-RoBERTa model.
    If toxic, paraphrase the comment, re-evaluate, and compute essential metrics.
    Returns a 13-element tuple: prediction label, confidence, color, toxicity score, bias score,
    then the paraphrased comment and its label, confidence, color, toxicity score, bias score,
    semantic similarity, and empathy score (all None if the input is non-toxic).
    """
    start_total = time.time()
    print("Starting classification...")

    if not comment.strip():
        return "Error: Please enter a comment.", None, None, None, None, None, None, None, None, None, None, None, None

    # Access the model and tokenizer
    model = classifier_model.model
    tokenizer = classifier_model.tokenizer

    # Tokenize the input comment
    start_classification = time.time()
    inputs = tokenizer(comment, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Run inference
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Get the predicted class (0 = non-toxic, 1 = toxic) and its probability
    probs = torch.softmax(logits, dim=1)[0]
    predicted_class = torch.argmax(logits, dim=1).item()
    label = "Toxic" if predicted_class == 1 else "Non-Toxic"
    confidence = probs[predicted_class].item()
    label_color = "red" if label == "Toxic" else "green"

    # Compute Toxicity Score (approximated as the probability of the toxic class)
    toxicity_score = round(probs[1].item(), 2)

    # Simulate Bias Score (placeholder)
    bias_score = 0.01 if label == "Non-Toxic" else 0.15
    bias_score = round(bias_score, 2)
    print(f"Classification took {time.time() - start_classification:.2f} seconds")

    # If the comment is toxic, paraphrase it and compute essential metrics
    paraphrased_comment = None
    paraphrased_label = None
    paraphrased_confidence = None
    paraphrased_color = None
    paraphrased_toxicity_score = None
    paraphrased_bias_score = None
    semantic_similarity = None
    empathy_score = None

    if label == "Toxic":
        # Paraphrase the comment
        start_paraphrase = time.time()
        paraphrased_comment = paraphrase_comment(comment)
        print(f"Paraphrasing took {time.time() - start_paraphrase:.2f} seconds")

        # Re-evaluate the paraphrased comment
        start_reclassification = time.time()
        paraphrased_inputs = tokenizer(paraphrased_comment, return_tensors="pt", truncation=True, padding=True, max_length=512)
        with torch.no_grad():
            paraphrased_outputs = model(**paraphrased_inputs)
            paraphrased_logits = paraphrased_outputs.logits

        paraphrased_probs = torch.softmax(paraphrased_logits, dim=1)[0]
        paraphrased_predicted_class = torch.argmax(paraphrased_logits, dim=1).item()
        paraphrased_label = "Toxic" if paraphrased_predicted_class == 1 else "Non-Toxic"
        paraphrased_confidence = paraphrased_probs[paraphrased_predicted_class].item()
        paraphrased_color = "red" if paraphrased_label == "Toxic" else "green"
        paraphrased_toxicity_score = round(paraphrased_probs[1].item(), 2)
        paraphrased_bias_score = 0.01 if paraphrased_label == "Non-Toxic" else 0.15  # Placeholder
        print(f"Reclassification of paraphrased comment took {time.time() - start_reclassification:.2f} seconds")

        # Compute essential metrics
        start_metrics = time.time()
        semantic_similarity = compute_semantic_similarity(comment, paraphrased_comment)
        empathy_score = compute_empathy_score(paraphrased_comment)
        print(f"Metrics computation took {time.time() - start_metrics:.2f} seconds")

    print(f"Total processing time: {time.time() - start_total:.2f} seconds")

    return (
        f"Prediction: {label}", confidence, label_color, toxicity_score, bias_score,
        paraphrased_comment, f"Prediction: {paraphrased_label}" if paraphrased_comment else None,
        paraphrased_confidence, paraphrased_color, paraphrased_toxicity_score, paraphrased_bias_score,
        semantic_similarity, empathy_score
    )
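
# Minimal usage sketch (an assumption, not part of the app's entry point): it
# presumes model_loader, paraphraser, and metrics import cleanly with their
# models already loaded, and the sample comment below is purely illustrative.
if __name__ == "__main__":
    sample_comment = "You clearly have no idea what you are talking about."
    (prediction, confidence, color, toxicity, bias,
     paraphrased, paraphrased_prediction, paraphrased_confidence, paraphrased_color,
     paraphrased_toxicity, paraphrased_bias, similarity, empathy) = classify_toxic_comment(sample_comment)
    print(prediction, f"confidence={confidence}", f"toxicity={toxicity}")
    if paraphrased:
        print(paraphrased_prediction, f"paraphrase={paraphrased!r}")
        print(compute_reward_scores(sample_comment, paraphrased))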