# app.py: Gradio app for the "clip-satellite-demo" Hugging Face Space (NemesisAlm/clip-satellite-demo)
# Compares the original CLIP model against a CLIP checkpoint fine-tuned on satellite imagery.
import gradio as gr
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# The 21 land-use classes of the UC Merced dataset
LIST_LABELS = ['agricultural land', 'airplane', 'baseball diamond', 'beach', 'buildings', 'chaparral', 'dense residential area', 'forest', 'freeway', 'golf course', 'harbor', 'intersection', 'medium residential area', 'mobilehome park', 'overpass', 'parking lot', 'river', 'runway', 'sparse residential area', 'storage tanks', 'tennis court']

# Prompt template used as the text side of CLIP's image-text matching
CLIP_LABELS = [f"A satellite image of {label}" for label in LIST_LABELS]

# CLIP checkpoint fine-tuned on satellite images, hosted on the Hugging Face Hub
MODEL_NAME = "NemesisAlm/clip-fine-tuned-satellite"

device = "cuda" if torch.cuda.is_available() else "cpu"

# Baseline zero-shot CLIP model, used for comparison
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Fine-tuned CLIP model
fine_tuned_model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
fine_tuned_processor = CLIPProcessor.from_pretrained(MODEL_NAME)


def classify(image_path, model_number):
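    """Classify a satellite image against the UC Merced label set.

    `model_number` selects the model: "CLIP" runs the original
    openai/clip-vit-base-patch32 checkpoint; any other value runs the
    fine-tuned checkpoint. Returns a {label: probability} dict in the
    format expected by gr.Label.
    """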
    if model_number == "CLIP":
        processor = clip_processor
        model = clip_model
    else:
        processor = fine_tuned_processor
        model = fine_tuned_model
    image = Image.open(image_path).convert('RGB')
    inputs = processor(text=CLIP_LABELS, images=image, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # logits_per_image holds image-text similarity scores; softmax turns them into class probabilities
    logits_per_image = outputs.logits_per_image
    prediction = logits_per_image.softmax(dim=1)
    confidences = {LIST_LABELS[i]: prediction[0][i].item() for i in range(len(LIST_LABELS))}
    return confidences

DESCRIPTION = """
<div style="font-family: Arial, sans-serif; line-height: 1.6;  margin: auto; text-align: center;">
    <h2 style="color: #333;">CLIP Fine-Tuned Satellite Model Demo</h2>
    <p>
        This space demonstrates the capabilities of a <strong>fine-tuned CLIP-based model</strong> 
        in classifying satellite images. The model has been specifically trained on the 
        <em>UC Merced</em> satellite image dataset.
    </p>
    <p>
        After just <strong>2 epochs of training</strong>, updating only 30% of the model's parameters, 
        classification accuracy on the test set improved from <strong>58.8%</strong> to <strong>96.9%</strong>.
    </p>
    <p>
        Explore this space to see its performance and compare it with the initial CLIP model.
    </p>
</div>
"""

FOOTER = """
<div style="margin-top:50px">
    Link to model: <a href='https://huggingface.co/NemesisAlm/clip-fine-tuned-satellite'>https://huggingface.co/NemesisAlm/clip-fine-tuned-satellite</a><br>
    Link to dataset: <a href='https://huggingface.co/datasets/blanchon/UC_Merced'>https://huggingface.co/datasets/blanchon/UC_Merced</a>
</div>
"""

with gr.Blocks(title="Satellite image classification", css="") as demo:
    logo = gr.HTML("<img src='file/logo_gradio.png' style='margin:auto'/>")
    description = gr.HTML(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type='filepath', label='Input image')
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Column():
            title_1 = gr.HTML("<h1 style='text-align:center'>Original CLIP Model</h1>")
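            # Hidden textbox used as a constant input so classify() knows which model to run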
            model_1 = gr.Textbox("CLIP", visible=False)
            output_labels_clip = gr.Label(num_top_classes=10, label="Top 10 classes")
        with gr.Column():
            title_2 = gr.HTML("<h1 style='text-align:center'>Fine-tuned Model</h1>")
            model_2 = gr.Textbox("Fine-tuned", visible=False)
            output_labels_finetuned = gr.Label(num_top_classes=10, label="Top 10 classes")
    examples = gr.Examples([["0.jpg"], ["1.jpg"], ["2.jpg"], ["3.jpg"]], input_image)
    footer = gr.HTML(FOOTER)
    # Classify with the baseline CLIP model first, then with the fine-tuned model on the same image
    submit_btn.click(
        fn=classify, inputs=[input_image, model_1], outputs=output_labels_clip
    ).then(
        fn=classify, inputs=[input_image, model_2], outputs=output_labels_finetuned
    )


demo.queue()
demo.launch(server_name="0.0.0.0",favicon_path='favicon.ico', allowed_paths=["logo_gradio.png", "0.jpg", "1.jpg", "2.jpg", "3.jpg"])
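
# Run locally with `python app.py`; Gradio serves on http://0.0.0.0:7860 by default.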