File size: 5,984 Bytes
894bc0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import torch
from geoclip import GeoCLIP
from PIL import Image
import tempfile
from pathlib import Path
import gradio as gr
import spaces
from geopy.geocoders import Nominatim
from transformers import CLIPProcessor, CLIPModel
from torchvision import transforms
import reverse_geocoder as rg
from models.huggingface import Geolocalizer
import folium
import json
from geopy.exc import GeocoderTimedOut

# Build GeoCLIP once at import time and move it to the GPU when CUDA is present.
geoclip_model = GeoCLIP()
if torch.cuda.is_available():
    geoclip_model = geoclip_model.to("cuda")

# Shared Nominatim client used for the geocoding lookups in this module.
geolocator = Nominatim(user_agent="predictGeolocforImage")

# StreetCLIP weights and the matching processor, loaded once at import time.
streetclip_model = CLIPModel.from_pretrained("geolocal/StreetCLIP")
streetclip_processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")

# Candidate country names passed as the text prompts for StreetCLIP.
labels = [
    'Albania', 'Andorra', 'Argentina', 'Australia', 'Austria', 'Bangladesh',
    'Belgium', 'Bermuda', 'Bhutan', 'Bolivia', 'Botswana', 'Brazil',
    'Bulgaria', 'Cambodia', 'Canada', 'Chile', 'China', 'Colombia',
    'Croatia', 'Czech Republic', 'Denmark', 'Dominican Republic', 'Ecuador',
    'Estonia', 'Finland', 'France', 'Germany', 'Ghana', 'Greece',
    'Greenland', 'Guam', 'Guatemala', 'Hungary', 'Iceland', 'India',
    'Indonesia', 'Ireland', 'Israel', 'Italy', 'Japan', 'Jordan', 'Kenya',
    'Kyrgyzstan', 'Laos', 'Latvia', 'Lesotho', 'Lithuania', 'Luxembourg',
    'Macedonia', 'Madagascar', 'Malaysia', 'Malta', 'Mexico', 'Monaco',
    'Mongolia', 'Montenegro', 'Netherlands', 'New Zealand', 'Nigeria',
    'Norway', 'Pakistan', 'Palestine', 'Peru', 'Philippines', 'Poland',
    'Portugal', 'Puerto Rico', 'Romania', 'Russia', 'Rwanda', 'Senegal',
    'Serbia', 'Singapore', 'Slovakia', 'Slovenia', 'South Africa',
    'South Korea', 'Spain', 'Sri Lanka', 'Swaziland', 'Sweden',
    'Switzerland', 'Taiwan', 'Thailand', 'Tunisia', 'Turkey', 'Uganda',
    'Ukraine', 'United Arab Emirates', 'United Kingdom', 'United States',
    'Uruguay',
]

# Target size handed to transforms.Resize in transform_image.
IMAGE_SIZE = (224, 224)
# Identifier passed to Geolocalizer.from_pretrained below.
GEOLOC_MODEL_NAME = "osv5m/baseline"
# Loaded once at import time and reused by every call to infer().
geoloc_model = Geolocalizer.from_pretrained(GEOLOC_MODEL_NAME)
# NOTE(review): presumably switches the model to evaluation mode (torch
# nn.Module semantics) -- confirm that Geolocalizer is an nn.Module.
geoloc_model.eval()

def transform_image(image):
    """Convert an image into a normalized tensor with a leading batch axis.

    The image is resized to ``IMAGE_SIZE``, turned into a tensor, normalized
    per channel with the fixed mean/std values below, and finally given an
    extra leading dimension through ``unsqueeze(0)``.
    """
    steps = [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    pipeline = transforms.Compose(steps)
    tensor = pipeline(image)
    return tensor.unsqueeze(0)

def create_map(lat, lon):
    """Return the HTML of a folium map centred on and marking (lat, lon).

    The map starts at zoom level 4 and carries a single marker at the
    given coordinates.
    """
    world_map = folium.Map(location=[lat, lon], zoom_start=4)
    marker = folium.Marker([lat, lon])
    marker.add_to(world_map)
    return world_map._repr_html_()

def get_country_coordinates(country_name):
    """Geocode a country name to a ``(latitude, longitude)`` pair.

    Args:
        country_name: Name to look up with the module-level ``geolocator``.

    Returns:
        A ``(latitude, longitude)`` tuple, or ``None`` when the geocoder finds
        no match or the lookup fails with a geocoder service error (timeout,
        service unavailable, quota exceeded, ...).
    """
    # GeocoderServiceError is the geopy base class for service failures;
    # GeocoderTimedOut (the only error handled previously) derives from it.
    from geopy.exc import GeocoderServiceError

    try:
        location = geolocator.geocode(country_name, timeout=10)
    except GeocoderServiceError:
        # Best effort: callers treat None as "no coordinates available".
        return None

    if not location:
        return None
    return location.latitude, location.longitude

@spaces.GPU
def predict_geoclip(image):
    """Predict coordinates for an image with GeoCLIP and describe the result.

    Args:
        image: Image object exposing ``save(path)`` (the Gradio input is
            configured with ``type="pil"``).

    Returns:
        A ``(text, map_html)`` tuple: a one-line summary with latitude,
        longitude and country, and the HTML of a map marking the location.
        The country is reported as ``"Unknown"`` when reverse geocoding
        fails or yields no country.
    """
    # Local import: GeocoderServiceError is the geopy base class covering
    # timeouts, unavailable service and quota errors.
    from geopy.exc import GeocoderServiceError

    # geoclip_model.predict is called with a file path, so the image is
    # written to a temporary file that is removed when the block exits.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmppath = Path(tmp_dir) / "tmp.jpg"
        image.save(str(tmppath))
        top_pred_gps, _ = geoclip_model.predict(str(tmppath), top_k=50)

    # Only the first returned prediction is reported. Plain floats keep the
    # formatting, geocoding and map code independent of the tensor type.
    lat, lon = (float(coord) for coord in top_pred_gps[0])

    country = "Unknown"
    try:
        location = geolocator.reverse((lat, lon), exactly_one=True, timeout=10)
    except GeocoderServiceError:
        location = None
    # reverse() yields None when no address matches (e.g. open ocean), and
    # the raw payload is not guaranteed to carry an 'address' entry.
    if location is not None:
        address = location.raw.get('address') or {}
        country = address.get('country') or country

    prediction = f"Latitude: {lat:.6f}, Longitude: {lon:.6f} - Country: {country}"
    return prediction, create_map(lat, lon)

@spaces.GPU
def classify_streetclip(image):
    """Pick the most likely country for an image using StreetCLIP.

    Every entry of ``labels`` is scored against the image; the label with the
    highest softmax probability is reported.

    Returns:
        A ``(text, map_html)`` tuple. ``map_html`` is a map centred on the
        geocoded country, or the string "Map not available" when no
        coordinates could be obtained.
    """
    batch = streetclip_processor(text=labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        result = streetclip_model(**batch)

    # One probability per label for the single input image.
    probabilities = result.logits_per_image.softmax(dim=1)[0].tolist()

    # max() keeps the earliest index among equal scores, which matches the
    # first element of a stable descending sort.
    best_index = max(range(len(labels)), key=lambda idx: probabilities[idx])
    top_label = labels[best_index]

    coords = get_country_coordinates(top_label)
    if coords:
        map_html = create_map(*coords)
    else:
        map_html = "Map not available"
    return f"Country: {top_label}", map_html

def infer(image):
    """Predict coordinates for an image with the OSV-5M baseline model.

    Args:
        image: Image accepted by ``transform_image``.

    Returns:
        A ``(text, map_html)`` tuple. On success the text holds latitude,
        longitude and the nearest place's ``admin1`` / ``cc`` fields from
        ``reverse_geocoder``. If any step raises, the text is a failure
        message containing the exception and ``map_html`` is ``None``.
    """
    try:
        img_tensor = transform_image(image)
        # Inference only: skip autograd bookkeeping, as classify_streetclip does.
        with torch.no_grad():
            gps_radians = geoloc_model(img_tensor)
        # The model output is treated as radians and converted to degrees.
        gps_degrees = torch.rad2deg(gps_radians).squeeze(0).cpu().tolist()
        lat, lon = gps_degrees[0], gps_degrees[1]
        location_query = rg.search((lat, lon))[0]
        map_html = create_map(lat, lon)
        return f"Latitude: {lat:.6f}, Longitude: {lon:.6f} - Country: {location_query['admin1']} - {location_query['cc']}", map_html
    except Exception as e:
        # UI boundary: report the failure in the text box instead of raising.
        return f"Failed to predict the location: {e}", None

# GeoCLIP tab: image in, prediction text and map HTML out.
geoclip_interface = gr.Interface(
    fn=predict_geoclip,
    inputs=gr.Image(
        type="pil",
        label="Upload Image",
        elem_id="geoclip_image_input",
    ),
    outputs=[
        gr.Textbox(label="Prediction", elem_id="geoclip_output"),
        gr.HTML(label="Map", elem_id="geoclip_map_output"),
    ],
    title="GeoCLIP",
)

# StreetCLIP tab: image in, predicted country text and map HTML out.
streetclip_interface = gr.Interface(
    fn=classify_streetclip,
    inputs=gr.Image(
        type="pil",
        label="Upload Image",
        elem_id="streetclip_image_input",
    ),
    outputs=[
        gr.Textbox(label="Prediction", elem_id="streetclip_output"),
        gr.HTML(label="Map", elem_id="streetclip_map_output"),
    ],
    title="StreetCLIP",
)

# OSV-5M tab: image in, prediction text and map HTML out.
osv5m_interface = gr.Interface(
    fn=infer,
    inputs=gr.Image(
        label="Upload Image",
        type="pil",
        elem_id="osv5m_image_input",
    ),
    outputs=[
        gr.Textbox(label="Prediction", elem_id="result_text"),
        gr.HTML(label="Map", elem_id="map_output"),
    ],
    title="OSV-5M Baseline",
)

# Bundle the three single-model interfaces into one tabbed app and start it.
demo = gr.TabbedInterface(
    [geoclip_interface, streetclip_interface, osv5m_interface],
    tab_names=["GeoCLIP", "StreetCLIP", "OSV-5M Baseline"],
)

demo.launch()