import io
import json

import torch
from PIL import Image
from torchvision import models, transforms


class CustomResNet:
    """ResNet-152 classifier restored from a fine-tuned state dict.

    The torchvision ResNet-152 architecture is created without pretrained
    weights. Its final fully connected layer (``fc``) is replaced by one with
    ``num_classes`` outputs, and all parameters are then loaded from
    ``model_path``. The model runs on CUDA when available, otherwise on CPU.
    """

    def __init__(self, model_path, num_classes):
        """Build the network, load its weights and set up preprocessing.

        Args:
            model_path: Path or file-like object holding a state dict for a
                ResNet-152 whose ``fc`` layer has ``num_classes`` outputs.
            num_classes: Number of outputs of the classification head.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # ``weights=None`` builds the bare architecture; it replaces the
        # ``pretrained=False`` flag that torchvision deprecated in 0.13.
        self.model = models.resnet152(weights=None)
        self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)

        # NOTE(review): torch.load unpickles the checkpoint, which can execute
        # arbitrary code. Only load files from a trusted source, or pass
        # weights_only=True once the minimum supported torch allows it.
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        # Inference behavior: batch-norm layers use their running statistics.
        self.model.eval()

        # Steps:
        #   1. Resize straight to 224x224 (aspect ratio is not preserved).
        #   2. Convert to a float tensor.
        #   3. Normalize with the standard ImageNet channel mean/std.
        # Presumably this mirrors the training-time preprocessing; keep the
        # two in sync.
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    @staticmethod
    def _load_image(source):
        """Decode *source* into an RGB PIL image and release the file handle."""
        if isinstance(source, (bytes, bytearray, memoryview)):
            # Image.open treats a bare ``bytes`` object as a filesystem path,
            # so raw encoded image data must be wrapped in a binary stream.
            source = io.BytesIO(source)
        with Image.open(source) as image:
            # convert() loads the pixels and returns a new image object, so
            # the result stays usable after the source is closed.
            return image.convert("RGB")

    def predict(self, image_bytes):
        """Return the index of the highest-scoring class for a single image.

        Args:
            image_bytes: Encoded image data as ``bytes``/``bytearray``/
                ``memoryview``, a binary file-like object, or a filesystem
                path (``str`` or ``os.PathLike``).

        Returns:
            int: Index of the largest logit produced by the model.
        """
        image = self._load_image(image_bytes)
        # Add a leading batch dimension of size 1 before moving to the device.
        batch = self.preprocess(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(batch)
        return int(logits.argmax(dim=1).item())


def load_model(config_path="config.json", model_path="trained_model.pth"):
    """Build a CustomResNet from a JSON config file and a saved state dict.

    Args:
        config_path: JSON file whose ``"num_labels"`` entry gives the number
            of output classes of the classifier.
        model_path: State-dict checkpoint passed to ``CustomResNet``.

    Returns:
        CustomResNet: The constructed model.

    Raises:
        FileNotFoundError: If either file does not exist.
        json.JSONDecodeError: If the config file is not valid JSON.
        KeyError: If the config has no ``"num_labels"`` entry.
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    return CustomResNet(model_path, config["num_labels"])