wendys-llc commited on
Commit
5ae3ade
·
verified ·
1 Parent(s): 7bf8f2e

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +4 -55
model.py CHANGED
@@ -1,10 +1,6 @@
1
- from transformers import PreTrainedModel, PretrainedConfig, ImageProcessingMixin
2
- import torch
3
  import torch.nn as nn
4
  from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
5
- from torchvision import transforms
6
- from PIL import Image
7
- import numpy as np
8
 
9
  class CheckboxConfig(PretrainedConfig):
10
  model_type = "checkbox-classifier"
@@ -13,44 +9,6 @@ class CheckboxConfig(PretrainedConfig):
13
  super().__init__(num_labels=num_labels, **kwargs)
14
  self.dropout_rate = dropout_rate
15
 
16
- class CheckboxImageProcessor(ImageProcessingMixin):
17
- """Simple image processor for checkbox classifier"""
18
-
19
- def __init__(self, **kwargs):
20
- super().__init__(**kwargs)
21
- self.size = {"height": 128, "width": 128}
22
- self.image_mean = [0.485, 0.456, 0.406]
23
- self.image_std = [0.229, 0.224, 0.225]
24
-
25
- self.transform = transforms.Compose([
26
- transforms.Resize((self.size["height"], self.size["width"])),
27
- transforms.ToTensor(),
28
- transforms.Normalize(mean=self.image_mean, std=self.image_std)
29
- ])
30
-
31
- def preprocess(self, images, **kwargs):
32
- """Preprocess images for model input"""
33
- if not isinstance(images, list):
34
- images = [images]
35
-
36
- processed = []
37
- for image in images:
38
- if isinstance(image, str):
39
- image = Image.open(image).convert('RGB')
40
- elif isinstance(image, np.ndarray):
41
- image = Image.fromarray(image).convert('RGB')
42
- elif not isinstance(image, Image.Image):
43
- raise ValueError(f"Unsupported image type: {type(image)}")
44
-
45
- processed.append(self.transform(image))
46
-
47
- # Stack into batch
48
- pixel_values = torch.stack(processed)
49
- return {"pixel_values": pixel_values}
50
-
51
- def __call__(self, images, **kwargs):
52
- return self.preprocess(images, **kwargs)
53
-
54
  class CheckboxClassifier(PreTrainedModel):
55
  config_class = CheckboxConfig
56
 
@@ -58,7 +16,7 @@ class CheckboxClassifier(PreTrainedModel):
58
  super().__init__(config)
59
  self.num_labels = config.num_labels
60
 
61
- self.backbone = efficientnet_v2_s(weights=None) # Don't load pretrained weights here
62
  num_features = self.backbone.classifier[1].in_features
63
 
64
  self.backbone.classifier = nn.Sequential(
@@ -74,15 +32,6 @@ class CheckboxClassifier(PreTrainedModel):
74
  nn.Linear(256, config.num_labels)
75
  )
76
 
77
- def forward(self, pixel_values, labels=None):
78
  outputs = self.backbone(pixel_values)
79
-
80
- loss = None
81
- if labels is not None:
82
- loss_fct = nn.CrossEntropyLoss()
83
- loss = loss_fct(outputs, labels)
84
-
85
- return {
86
- "loss": loss,
87
- "logits": outputs,
88
- }
 
1
+ from transformers import PreTrainedModel, PretrainedConfig
 
2
  import torch.nn as nn
3
  from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
 
 
 
4
 
5
  class CheckboxConfig(PretrainedConfig):
6
  model_type = "checkbox-classifier"
 
9
  super().__init__(num_labels=num_labels, **kwargs)
10
  self.dropout_rate = dropout_rate
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  class CheckboxClassifier(PreTrainedModel):
13
  config_class = CheckboxConfig
14
 
 
16
  super().__init__(config)
17
  self.num_labels = config.num_labels
18
 
19
+ self.backbone = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
20
  num_features = self.backbone.classifier[1].in_features
21
 
22
  self.backbone.classifier = nn.Sequential(
 
32
  nn.Linear(256, config.num_labels)
33
  )
34
 
35
+ def forward(self, pixel_values):
36
  outputs = self.backbone(pixel_values)
37
+ return {"logits": outputs}