File size: 8,482 Bytes
0e903f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import torch
import logging
import torch.nn.functional as F
from PIL import Image
from transformers import AutoTokenizer, AutoModel, Swinv2Model
from torchvision import transforms
from src.model.model import MisinformationDetectionModel

logger = logging.getLogger(__name__)


class MisinformationPredictor:
    def __init__(
        self,
        model_path,
        device="cuda" if torch.cuda.is_available() else "cpu",
        embed_dim=256,
        num_heads=8,
        dropout=0.1,
        hidden_dim=64,
        num_classes=3,
        mlp_ratio=4.0,
        text_input_dim=384,
        image_input_dim=1024,
        fused_attn=False,
        text_encoder="microsoft/deberta-v3-xsmall",
    ):
        """
        Initialize the predictor with a trained model and required encoders.

        Args:
            model_path: Path to the saved model checkpoint
            text_encoder: Name/path of the text encoder model
            device: Device to run inference on
            Other args: Model architecture parameters
        """
        self.device = torch.device(device)

        # Initialize tokenizer and encoders
        logger.info("Loading encoders...")
        self.tokenizer = AutoTokenizer.from_pretrained(text_encoder)
        self.text_encoder = AutoModel.from_pretrained(text_encoder).to(self.device)
        self.image_encoder = Swinv2Model.from_pretrained(
            "microsoft/swinv2-base-patch4-window8-256"
        ).to(self.device)

        # Set encoders to eval mode
        self.text_encoder.eval()
        self.image_encoder.eval()

        # Initialize model
        self.model = MisinformationDetectionModel(
            text_input_dim=text_input_dim,
            image_input_dim=image_input_dim,
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            hidden_dim=hidden_dim,
            num_classes=num_classes,
            mlp_ratio=mlp_ratio,
            fused_attn=fused_attn,
        ).to(self.device)

        # Load model weights
        logger.info(f"Loading model from {model_path}")
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model.eval()

        # Image preprocessing
        self.image_transform = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        # Class mapping
        self.idx_to_label = {0: "support", 1: "not_enough_information", 2: "refute"}

    def process_image(self, image_path):
        """Process image from path to tensor."""
        try:
            image = Image.open(image_path).convert("RGB")
            image = self.image_transform(image).unsqueeze(0)  # Add batch dimension
            return image.to(self.device)
        except Exception as e:
            logger.error(f"Error processing image {image_path}: {e}")
            return None

    @torch.no_grad()
    def evaluate(
        self, claim_text, claim_image_path, evidence_text, evidence_image_path
    ):
        """
        Evaluate a single claim-evidence pair.

        Args:
            claim_text (str): The claim text
            claim_image_path (str): Path to the claim image
            evidence_text (str): The evidence text
            evidence_image_path (str): Path to the evidence image

        Returns:
            dict: Dictionary containing predictions from all modality combinations
        """
        try:
            # Process text inputs
            claim_text_inputs = self.tokenizer(
                claim_text,
                truncation=True,
                padding="max_length",
                max_length=512,
                return_tensors="pt",
            ).to(self.device)

            evidence_text_inputs = self.tokenizer(
                evidence_text,
                truncation=True,
                padding="max_length",
                max_length=512,
                return_tensors="pt",
            ).to(self.device)

            # Get text embeddings
            claim_text_embeds = self.text_encoder(**claim_text_inputs).last_hidden_state
            evidence_text_embeds = self.text_encoder(
                **evidence_text_inputs
            ).last_hidden_state

            # Process image inputs
            claim_image = self.process_image(claim_image_path)
            evidence_image = self.process_image(evidence_image_path)

            # Process claim image
            if claim_image is not None:
                claim_image_embeds = self.image_encoder(claim_image).last_hidden_state
            else:
                logger.warning(
                    "Claim image processing failed, setting embedding to None"
                )
                claim_image_embeds = None

            # Process evidence image
            if evidence_image is not None:
                evidence_image_embeds = self.image_encoder(
                    evidence_image
                ).last_hidden_state
            else:
                logger.warning(
                    "Evidence image processing failed, setting embedding to None"
                )
                evidence_image_embeds = None

            # Get model predictions
            (y_t_t, y_t_i), (y_i_t, y_i_i) = self.model(
                X_t=claim_text_embeds,
                X_i=claim_image_embeds,
                E_t=evidence_text_embeds,
                E_i=evidence_image_embeds,
            )

            # Process predictions with confidence scores
            predictions = {}

            def process_output(output, path_name):
                if output is not None:
                    probs = F.softmax(output, dim=-1)
                    pred_idx = probs.argmax(dim=-1).item()
                    confidence = probs[0][pred_idx].item()
                    return {
                        "label": self.idx_to_label[pred_idx],
                        "confidence": confidence,
                        "probabilities": {
                            self.idx_to_label[i]: p.item()
                            for i, p in enumerate(probs[0])
                        },
                    }
                return None

            predictions["text_text"] = process_output(y_t_t, "text_text")
            predictions["text_image"] = process_output(y_t_i, "text_image")
            predictions["image_text"] = process_output(y_i_t, "image_text")
            predictions["image_image"] = process_output(y_i_i, "image_image")

            return {
                path: pred["label"] if pred else None
                for path, pred in predictions.items()
            }

        except Exception as e:
            logger.error(f"Error during evaluation: {e}")
            return None


if __name__ == "__main__":
    # Example usage
    logging.basicConfig(level=logging.INFO)

    predictor = MisinformationPredictor(model_path="ckpts/model.pt", device="cpu")

    # Example prediction
    predictions = predictor.evaluate(
        claim_text="Musician Kodak Black was shot outside of a nightclub in Florida in December 2016.",
        claim_image_path="./data/raw/factify/extracted/images/test/0_claim.jpg",
        evidence_text="On 26 December 2016, the web site Gummy Post published an article claiming \
                        that musician Kodak Black was shot outside a nightclub in Florida. \
                        This article is a hoax. While Gummy Post cited a 'police report', no records exist \
                        of any shooting involving Kodak Black (real name Dieuson Octave) in Florida during December 2016. \
                        Additionally, the video Gummy Post shared as evidence showed an unrelated crime scene.",
        evidence_image_path="./data/raw/factify/extracted/images/test/0_evidence.jpg",
    )

    print(predictions)
    # Print predictions
    # if predictions:
    #     print("\nPredictions:")
    #     for path, pred in predictions.items():
    #         if pred:
    #             print(f"\n{path}:")
    #             print(f"  Label: {pred['label']}")
    #             print(f"  Confidence: {pred['confidence']:.4f}")
    #             print("  Probabilities:")
    #             for label, prob in pred["probabilities"].items():
    #                 print(f"    {label}: {prob:.4f}")