ombhojane commited on
Commit
2cd69df
·
verified ·
1 Parent(s): f97382d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -70
app.py CHANGED
@@ -1,112 +1,73 @@
1
  import streamlit as st
2
  from PIL import Image
3
- import torch
4
- from torchvision import transforms
5
- import torchvision.models as models
6
- import requests
7
- import os
8
- import json
9
 
10
  def load_model():
11
- """Load a pre-trained MobileNetV2 model for plant disease classification."""
12
- # Create model directory if it doesn't exist
13
- if not os.path.exists("model"):
14
- os.makedirs("model")
15
-
16
- # Download model if not already downloaded
17
- model_path = "model/plant_disease_model.pth"
18
- if not os.path.exists(model_path):
19
- st.info("Downloading model for the first time. This might take a moment...")
20
- # Replace this URL with the actual model URL if you have one
21
- # For now, we'll use a pre-trained MobileNetV2 and fine-tune it
22
- model = models.mobilenet_v2(pretrained=True)
23
- num_classes = 38 # Example number of plant disease classes
24
- model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, num_classes)
25
- torch.save(model.state_dict(), model_path)
26
-
27
- # Load the model
28
- model = models.mobilenet_v2(pretrained=False)
29
- num_classes = 38 # Same number as above
30
- model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, num_classes)
31
- model.load_state_dict(torch.load(model_path))
32
- model.eval()
33
-
34
  return model
35
 
36
- def get_class_names():
37
- """Get the class names for the plant disease model."""
38
- # Define a mapping of class indices to disease names
39
- # This is a placeholder - replace with your actual class mapping
40
- class_names = {
41
- 0: "Apple___Apple_scab",
42
- 1: "Apple___Black_rot",
43
- # ... add all your classes here
44
- 37: "Tomato___healthy"
45
- }
46
- return class_names
47
-
48
  def predict_disease(image_file):
49
- """Predicts the disease of a plant from an image using PyTorch.
50
 
51
  Args:
52
  image_file: The uploaded image file.
53
 
54
  Returns:
55
- A string representing the predicted disease.
56
  """
57
  try:
58
  # Load the model
59
  model = load_model()
60
 
61
- # Define image transformations
62
- transform = transforms.Compose([
63
- transforms.Resize(256),
64
- transforms.CenterCrop(224),
65
- transforms.ToTensor(),
66
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
67
- ])
68
-
69
  # Process the image
70
- image = Image.open(image_file).convert("RGB")
71
- image_tensor = transform(image).unsqueeze(0)
 
72
 
73
  # Make prediction
74
- with torch.no_grad():
75
- outputs = model(image_tensor)
76
- _, predicted = torch.max(outputs, 1)
77
- predicted_idx = predicted.item()
78
 
79
- # Get class names
80
- class_names = get_class_names()
81
- predicted_label = class_names[predicted_idx]
 
82
 
83
- return predicted_label
84
  except Exception as e:
85
  return f"Error: {str(e)}"
86
 
87
  def main():
88
  """Creates the Streamlit app."""
89
- st.title("Plant Disease Detection App")
 
90
 
91
  # Upload an image
92
- image_file = st.file_uploader("Upload an image of a plant", type=["jpg", "jpeg", "png"])
93
 
94
- # Predict the disease
95
  if image_file is not None:
96
  # Display the image
97
  image = Image.open(image_file)
98
- st.image(image, caption="Uploaded Plant Image", use_column_width=True)
99
 
100
  # Add a prediction button
101
- if st.button("Detect Disease"):
102
  with st.spinner("Analyzing image..."):
103
- disease = predict_disease(image_file)
104
 
105
  # Display the prediction
106
- if disease.startswith("Error"):
107
- st.error(disease)
108
  else:
109
- st.success(f"Predicted disease: {disease}")
110
 
111
  if __name__ == "__main__":
112
  main()
 
1
  import streamlit as st
2
  from PIL import Image
3
+ import tensorflow as tf
4
+ import numpy as np
 
 
 
 
5
 
6
  def load_model():
7
+ """Load a pre-trained TensorFlow model for image classification."""
8
+ # Use a TensorFlow Hub model or a local TensorFlow model
9
+ model = tf.keras.applications.MobileNetV2(
10
+ input_shape=(224, 224, 3),
11
+ include_top=True,
12
+ weights="imagenet"
13
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  return model
15
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def predict_disease(image_file):
17
+ """Predicts the class of an image using TensorFlow.
18
 
19
  Args:
20
  image_file: The uploaded image file.
21
 
22
  Returns:
23
+ A string representing the predicted class.
24
  """
25
  try:
26
  # Load the model
27
  model = load_model()
28
 
 
 
 
 
 
 
 
 
29
  # Process the image
30
+ image = Image.open(image_file).convert("RGB").resize((224, 224))
31
+ image_array = np.array(image) / 255.0
32
+ image_array = np.expand_dims(image_array, axis=0)
33
 
34
  # Make prediction
35
+ predictions = model.predict(image_array)
36
+ predicted_class = np.argmax(predictions[0])
 
 
37
 
38
+ # Get the class label from ImageNet (as an example)
39
+ # In a real app, you'd map this to plant diseases
40
+ from tensorflow.keras.applications.mobilenet_v2 import decode_predictions
41
+ _, label, confidence = decode_predictions(predictions, top=1)[0][0]
42
 
43
+ return f"{label} (confidence: {confidence:.2f})"
44
  except Exception as e:
45
  return f"Error: {str(e)}"
46
 
47
  def main():
48
  """Creates the Streamlit app."""
49
+ st.title("Image Classification App")
50
+ st.caption("Note: This is using a general ImageNet model, not a plant disease model")
51
 
52
  # Upload an image
53
+ image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
54
 
55
+ # Predict the class
56
  if image_file is not None:
57
  # Display the image
58
  image = Image.open(image_file)
59
+ st.image(image, caption="Uploaded Image", use_column_width=True)
60
 
61
  # Add a prediction button
62
+ if st.button("Classify Image"):
63
  with st.spinner("Analyzing image..."):
64
+ result = predict_disease(image_file)
65
 
66
  # Display the prediction
67
+ if result.startswith("Error"):
68
+ st.error(result)
69
  else:
70
+ st.success(f"Prediction: {result}")
71
 
72
  if __name__ == "__main__":
73
  main()