Subh775 commited on
Commit
cc91aff
·
verified ·
1 Parent(s): 6971ec7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import models, transforms
3
+ from PIL import Image
4
+ import gradio as gr
5
+
6
+ ## Define the device
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+
9
+ # Load the trained model
10
+ def load_model():
11
+ model = models.resnet50(pretrained=False)
12
+ num_classes = 4 # Update based on your rice disease classes
13
+ model.fc = torch.nn.Sequential(
14
+ torch.nn.Linear(model.fc.in_features, 256),
15
+ torch.nn.ReLU(),
16
+ torch.nn.Linear(256, num_classes)
17
+ )
18
+ model.load_state_dict(torch.load(r"/kaggle/input/rice_epoch8/pytorch/default/1/best_model_epoch_8.pth", map_location=device), strict=False)
19
+ model = model.to(device)
20
+ model.eval()
21
+ return model
22
+
23
+ # Define preprocessing steps
24
+ transform = transforms.Compose([
25
+ transforms.Resize((224, 224)),
26
+ transforms.ToTensor(),
27
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
28
+ ])
29
+
30
+ # Prediction function
31
+ def predict(image):
32
+ # Ensure image is in RGB
33
+ image = image.convert("RGB")
34
+
35
+ input_tensor = transform(image).unsqueeze(0).to(device)
36
+
37
+ # Perform inference
38
+ with torch.no_grad():
39
+ outputs = model(input_tensor)
40
+ _, predicted_class = torch.max(outputs, 1)
41
+
42
+ # Map predicted class index to actual labels
43
+ class_names = ["Brown Spot", "Healthy", "Leaf Blast", "Neck Blast"]
44
+ predicted_label = class_names[predicted_class.item()]
45
+
46
+ # Calculate confidence scores
47
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
48
+ confidence = probabilities[predicted_class.item()].item()
49
+
50
+ return f"Predicted Disease: {predicted_label}\nConfidence: {confidence*100:.2f}%"
51
+
52
+ # Load the model globally
53
+ model = load_model()
54
+
55
+ # Create Gradio interface
56
+ def launch_interface():
57
+ # Create a Gradio interface
58
+ iface = gr.Interface(
59
+ theme="Subh775/orchid_candy",
60
+ fn=predict,
61
+ inputs=gr.Image(type="pil", label="Upload Rice Leaf Image"),
62
+ outputs=gr.Textbox(label="Prediction Results"),
63
+ title="Rice Disease Classification",
64
+ description="Upload a rice leaf image to detect disease type",
65
+ examples=[
66
+ ["https://doa.gov.lk/wp-content/uploads/2020/06/brownspot3-1024x683.jpg"],
67
+ ["https://arkansascrops.uada.edu/posts/crops/rice/images/Fig%206%20Rice%20leaf%20blast%20coalesced%20lesions.png"],
68
+ ["https://th.bing.com/th/id/OIP._5ejX_5Z-M0cO5c2QUmPlwHaE7?w=280&h=187&c=7&r=0&o=5&dpr=1.1&pid=1.7"],
69
+ ["https://www.weknowrice.com/wp-content/uploads/2022/11/how-to-grow-rice.jpeg"],
70
+ ],
71
+ allow_flagging="never"
72
+ )
73
+
74
+ return iface
75
+
76
+ # Launch the interface
77
+ if __name__ == "__main__":
78
+ interface = launch_interface()
79
+ interface.launch(share=True)