File size: 2,642 Bytes
9c8dfb8
de2aabe
 
f173f8e
de2aabe
9c8dfb8
 
 
 
 
 
 
 
 
 
 
 
 
 
de2aabe
9c8dfb8
 
 
de2aabe
9c8dfb8
 
 
 
 
 
 
 
 
 
de2aabe
9c8dfb8
de2aabe
9c8dfb8
 
 
 
 
f173f8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de2aabe
9c8dfb8
 
 
f173f8e
 
 
 
 
 
 
 
 
 
 
 
 
de2aabe
f173f8e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
import torch
from PIL import Image
import gradio as gr

def load_model_from_hub(repo_id: str):
    """
    Fetch an image-classification model and its processor from the Hugging Face Hub.

    Args:
        repo_id: Hub repository ID, e.g. 'username/model-name'.
    Returns:
        tuple: ``(model, processor)`` -- the model loaded with
        ``AutoModelForImageClassification`` and the processor loaded with
        ``AutoFeatureExtractor``, both from ``repo_id``.
    """
    # The model is fetched first, then the preprocessing config from the same repo.
    loaded_model = AutoModelForImageClassification.from_pretrained(repo_id)
    loaded_processor = AutoFeatureExtractor.from_pretrained(repo_id)
    return loaded_model, loaded_processor

def predict(image_path: str, model, processor):
    """
    Make a prediction for the image stored at ``image_path``.

    Args:
        image_path: Path to the input image file.
        model: Loaded image-classification model.
        processor: Feature extractor/processor matching ``model``.
    Returns:
        predictions: Softmax probabilities computed over the last dimension
        of the model's logits.
    """
    # Open inside a context manager so the file handle is released once the
    # pixels are loaded. convert("RGB") forces pixel loading and normalises
    # grayscale/palette/RGBA files; the model presumably expects 3-channel
    # RGB input (TODO confirm against the checkpoint's preprocessing config).
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)
        predictions = outputs.logits.softmax(-1)

    return predictions

def predict_image(image):
    """
    Gradio interface function for prediction.

    Uses the module-level ``model`` and ``processor`` loaded at startup.

    Args:
        image: Image uploaded through the Gradio interface -- a numpy array
            or a PIL image; ``None`` when the user submits without an image.
    Returns:
        str: Top prediction with its confidence score, or a short message
        when no image was provided.
    """
    # Gradio passes None for an empty input; answer with a readable message
    # instead of letting Image.fromarray(None) raise.
    if image is None:
        return "No image provided."

    # Convert from numpy array to PIL Image, then normalise the mode so RGBA
    # or grayscale uploads are fed to the processor as RGB.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")

    # Process image and get prediction
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = outputs.logits.softmax(-1)[0]

    # Top prediction; torch.argmax returns the first maximal index on ties,
    # matching the previous max(range(...)) behaviour.
    top_pred_idx = int(probabilities.argmax())
    confidence = probabilities[top_pred_idx].item()

    # transformers configs normally define id2label (possibly with generic
    # defaults), so attribute presence does not guarantee this index is
    # mapped; fall back to a generic name instead of raising KeyError.
    id2label = getattr(model.config, "id2label", None) or {}
    label = id2label.get(top_pred_idx, f"Class {top_pred_idx}")

    return f"{label} (Confidence: {confidence:.2%})"

# Hugging Face Hub repository holding the classifier and its processor.
MODEL_REPO_ID = "srtangirala/resnet50-exp"

# Loaded once at import time; predict_image reads these module-level names
# on every request.
model, processor = load_model_from_hub(MODEL_REPO_ID)

# Example inputs shown under the upload widget, one list per example row,
# e.g. ["path/to/example1.jpg"]. Empty until sample images are added.
example_images = []

# Gradio UI wiring: an image in, the formatted prediction string out.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(),
    outputs=gr.Text(),
    title="Image Classification",
    description="Upload an image to classify it!",
    examples=example_images,
)

if __name__ == "__main__":
    iface.launch()