stefanodangelo commited on
Commit
e8ec216
·
verified ·
1 Parent(s): c952138

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +38 -2
README.md CHANGED
@@ -64,7 +64,42 @@ Users should carefully evaluate the model on their specific data to ensure compa
64
  Use the following code snippet to get started:
65
 
66
  ```python
67
- Coming soon
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  ```
69
 
70
  ## Training Details
@@ -152,8 +187,9 @@ The model uses a Swin Transformer-based architecture adapted for binary image cl
152
 
153
  - Windows 11
154
  - Python 3.11
 
155
  - HuggingFace Transformers Library
156
- - PyTorch
157
 
158
  ## Citation
159
 
 
64
  Use the following code snippet to get started:
65
 
66
  ```python
67
+ from huggingface_hub import snapshot_download
68
+ from transformers import AutoImageProcessor, AutoConfig, SwinForImageClassification
69
+ from torchvision import transforms
70
+ import torch
71
+
72
+ # Download the model
73
+ snapshot_download(repo_id='stefanodangelo/chartdet', local_dir='./', allow_patterns='*.pt')
74
+
75
+ # Load the model and processor
76
+ processor = AutoImageProcessor.from_pretrained("microsoft/swin-large-patch4-window7-224")
77
+ config = AutoConfig.from_pretrained("microsoft/swin-large-patch4-window7-224", num_labels=2)
78
+ model = SwinForImageClassification(config)
79
+
80
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
81
+ model.load_state_dict(torch.load("models/chart_detection_model.pt", map_location=torch.device(device)))
82
+
83
+ # Define transformations
84
+ transform = transforms.Compose([
85
+ transforms.Resize((224, 224)),
86
+ transforms.ToTensor(),
87
+ transforms.Normalize(mean=processor.image_mean, std=processor.image_std)
88
+ ])
89
+
90
+ id2class = {0: "Picture", 1: "Chart"}
91
+
92
+ # Prepare an image
93
+ from PIL import Image
94
+ image = Image.open("YOUR_IMG_PATH")
95
+ image_tensor = transform(image).unsqueeze(0).to(device) # Add batch dimension and move to device
96
+
97
+ # Make prediction
98
+ with torch.no_grad():
99
+ outputs = model(image_tensor).logits # Get logits from the model
100
+ predicted_class = torch.argmax(outputs, dim=1).item() # Get the class index
101
+
102
+ print(f"Predicted class: {id2class[predicted_class]}")
103
  ```
104
 
105
  ## Training Details
 
187
 
188
  - Windows 11
189
  - Python 3.11
190
+ - CUDA 11.7
191
  - HuggingFace Transformers Library
192
+ - PyTorch & TorchVision
193
 
194
  ## Citation
195