Upload folder using huggingface_hub
- README.md +38 -0
- config.json +6 -0
- labels.json +1 -0
- preprocessor.py +52 -0
- svm_vgg_model.pkl +3 -0
README.md
ADDED
@@ -0,0 +1,38 @@
---
tags:
- svm
- computer-vision
- image-classification
---

# SVM-VGG Image Classifier

A hybrid SVM model that classifies images into 28 chart/scientific-diagram types using VGG-based CNN features.

## Usage

```python
from huggingface_hub import hf_hub_download
import joblib
import json
import torch
from torchvision import transforms
from torchvision.models import vgg16
from PIL import Image
import numpy as np

# Load the SVM
model_path = hf_hub_download(repo_id="ApostolosK/svm-vgg-model", filename="svm_vgg_model.pkl")
svm_model = joblib.load(model_path)

# Load label names
labels_path = hf_hub_download(repo_id="ApostolosK/svm-vgg-model", filename="labels.json")
with open(labels_path) as f:
    labels = json.load(f)

# VGG16 backbone; FC-CNN features come from the classifier with its last two layers removed
vgg_model = vgg16(pretrained=True).eval()
fc_cnn_model = torch.nn.Sequential(*list(vgg_model.classifier.children())[:-2])

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def extract_features(image_path):
    """Concatenate FC-CNN (classifier) and FV-CNN (convolutional) features, as in preprocessor.py."""
    image = Image.open(image_path).convert("RGB")
    x = preprocess(image).unsqueeze(0)
    with torch.no_grad():
        conv = vgg_model.features(x)                                   # FV-CNN features
        fc = fc_cnn_model(torch.flatten(vgg_model.avgpool(conv), 1))   # FC-CNN features
    return np.concatenate((fc.squeeze().numpy().flatten(),
                           conv.squeeze().numpy().flatten()))

# Make a prediction
features = extract_features("your-image.png")
prediction = svm_model.predict([features])[0]
print(labels[prediction])
```

`preprocessor.py` in this repo packages the same feature extraction as a `FeatureExtractor` class.
config.json
ADDED
@@ -0,0 +1,6 @@
{
  "model_type": "svm",
  "feature_extractor": "vgg-based-cnn",
  "input_size": 224,
  "num_classes": 28
}
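Since the classifier is a joblib-serialized SVM rather than a `transformers` model, nothing reads config.json automatically; a minimal sketch (assuming only the keys shown above) of fetching it and using `input_size` to parameterize preprocessing:

```python
from huggingface_hub import hf_hub_download
import json

# Sketch: read config.json by hand; none of this is loaded automatically by a library.
config_path = hf_hub_download(repo_id="ApostolosK/svm-vgg-model", filename="config.json")
with open(config_path) as f:
    config = json.load(f)

size = config["input_size"]           # 224, the VGG input resolution
num_classes = config["num_classes"]   # 28 diagram classes
print(f"{config['model_type']} over {config['feature_extractor']} features, "
      f"{num_classes} classes, {size}x{size} inputs")
```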
labels.json
ADDED
@@ -0,0 +1 @@
["3D objects", "Algorithm", "Area chart", "Bar plots", "Block diagram", "Box plot", "Bubble Chart", "Confusion matrix", "Contour plot", "Flow chart", "Geographic map", "Graph plots", "Heat map", "Histogram", "Mask", "Medical images", "Natural images", "Pareto charts", "Pie chart", "Polar plot", "Radar chart", "Scatter plot", "Sketches", "Surface plot", "Tables", "Tree Diagram", "Vector plot", "Venn Diagram"]
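As used in the README (`labels[prediction]`), the list is index-aligned with the integer classes the SVM predicts; a quick consistency check against config.json, as a sketch assuming both files sit in the working directory:

```python
import json

# The label list should line up with config.json's num_classes.
with open("labels.json") as f:
    labels = json.load(f)
with open("config.json") as f:
    config = json.load(f)

assert len(labels) == config["num_classes"] == 28
print(labels[7])  # "Confusion matrix"
```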
preprocessor.py
ADDED
@@ -0,0 +1,52 @@
import torch
import numpy as np
from torchvision import transforms
from torchvision.models import vgg16
from PIL import Image

# VGG16-based feature extractor (matches the training setup)
class FeatureExtractor:
    def __init__(self):
        self.vgg_model = vgg16(pretrained=True)
        self.vgg_model.eval()  # Set to evaluation mode

        # For FC-CNN features (classifier-based): drop the last two layers
        self.fc_extractor = torch.nn.Sequential(
            *list(self.vgg_model.classifier.children())[:-2]
        )

        # Standard VGG preprocessing
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def extract_fc_cnn_features(self, image_path):
        """Extract fully-connected (classifier) layer features."""
        image = Image.open(image_path).convert('RGB')
        image_tensor = self.preprocess(image).unsqueeze(0)

        with torch.no_grad():
            # The classifier layers expect flattened conv features, so run the
            # convolutional stack and average pooling first.
            conv = self.vgg_model.features(image_tensor)
            pooled = torch.flatten(self.vgg_model.avgpool(conv), 1)
            features = self.fc_extractor(pooled)

        return features.squeeze().numpy().flatten()

    def extract_fv_cnn_features(self, image_path):
        """Extract convolutional layer features."""
        image = Image.open(image_path).convert('RGB')
        image_tensor = self.preprocess(image).unsqueeze(0)

        with torch.no_grad():
            conv_features = self.vgg_model.features(image_tensor)

        return conv_features.squeeze().numpy().flatten()

    def extract_combined_features(self, image_path):
        """Concatenate FC-CNN and FV-CNN features."""
        fc = self.extract_fc_cnn_features(image_path)
        fv = self.extract_fv_cnn_features(image_path)
        return np.concatenate((fc, fv))
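End to end, the class above plugs into the SVM from this repo roughly as follows; this is a sketch that assumes preprocessor.py has been downloaded next to your script, and `chart.png` is a placeholder image path:

```python
from huggingface_hub import hf_hub_download
import joblib
import json

from preprocessor import FeatureExtractor  # the file listed above, saved locally

# Fetch the SVM and label names from the repo
model_path = hf_hub_download(repo_id="ApostolosK/svm-vgg-model", filename="svm_vgg_model.pkl")
labels_path = hf_hub_download(repo_id="ApostolosK/svm-vgg-model", filename="labels.json")
svm_model = joblib.load(model_path)
with open(labels_path) as f:
    labels = json.load(f)

# Combined FC-CNN + FV-CNN features, then a single SVM prediction
extractor = FeatureExtractor()
features = extractor.extract_combined_features("chart.png")  # placeholder path
print(labels[svm_model.predict([features])[0]])
```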
svm_vgg_model.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a27b4b58469a03cc418cbee34f8186dfe650c1952974df7f224a8c7e6dc7df39
size 8091210000