K Sudhakar Reddy commited on
Commit
c11a16d
·
unverified ·
1 Parent(s): cc9c515

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -10
app.py CHANGED
@@ -1,16 +1,68 @@
1
- from transformers import pipeline
2
  import gradio as gr
 
 
 
 
3
 
 
 
 
 
 
 
 
 
4
 
5
- model = pipeline(
6
- "summarization",
7
- )
 
 
 
 
 
 
 
 
 
8
 
9
- def predict(prompt):
10
- summary = model(prompt)[0]["summary_text"]
11
- return summary
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- # create an interface for the model
15
- with gr.Interface(predict, "textbox", "text") as interface:
16
- interface.launch()
 
 
1
  import gradio as gr
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ import os
6
 
7
+ class CatDogClassifier:
8
+ def __init__(self, model_path="model.pt"):
9
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
+
11
+ # Load the traced model
12
+ self.model = torch.jit.load(model_path)
13
+ self.model = self.model.to(self.device)
14
+ self.model.eval()
15
 
16
+ # Define the same transforms used during training/testing
17
+ self.transform = transforms.Compose([
18
+ transforms.Resize((160, 160)),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize(
21
+ mean=[0.485, 0.456, 0.406],
22
+ std=[0.229, 0.224, 0.225]
23
+ )
24
+ ])
25
+
26
+ # Class labels
27
+ self.labels = ['Dog', 'Cat']
28
 
29
+ @torch.no_grad()
30
+ def predict(self, image):
31
+ if image is None:
32
+ return None
33
+
34
+ # Convert to PIL Image if needed
35
+ if not isinstance(image, Image.Image):
36
+ image = Image.fromarray(image).convert('RGB')
37
+
38
+ # Preprocess image
39
+ img_tensor = self.transform(image).unsqueeze(0).to(self.device)
40
+
41
+ # Get prediction
42
+ output = self.model(img_tensor)
43
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
44
+
45
+ # Create prediction dictionary
46
+ return {
47
+ self.labels[idx]: float(prob)
48
+ for idx, prob in enumerate(probabilities)
49
+ }
50
 
51
+ # Create classifier instance
52
+ classifier = CatDogClassifier()
53
+
54
+ # Create Gradio interface
55
+ demo = gr.Interface(
56
+ fn=classifier.predict,
57
+ inputs=gr.Image(),
58
+ outputs=gr.Label(num_top_classes=2),
59
+ title="Cat vs Dog Classifier",
60
+ description="Upload an image to classify whether it's a cat or a dog",
61
+ examples=[
62
+ ["examples/cat.jpg"],
63
+ ["examples/dog.jpg"]
64
+ ]
65
+ )
66
 
67
+ if __name__ == "__main__":
68
+ demo.launch()