Spaces:
Running
Running
File size: 5,984 Bytes
894bc0c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import torch
from geoclip import GeoCLIP
from PIL import Image
import tempfile
from pathlib import Path
import gradio as gr
import spaces
from geopy.geocoders import Nominatim
from transformers import CLIPProcessor, CLIPModel
from torchvision import transforms
import reverse_geocoder as rg
from models.huggingface import Geolocalizer
import folium
import json
from geopy.exc import GeocoderTimedOut
# Build GeoCLIP once at import time and move it to the GPU only when CUDA is
# actually available; otherwise the freshly constructed model is used as-is.
geoclip_model = GeoCLIP()
if torch.cuda.is_available():
    geoclip_model = geoclip_model.to("cuda")

# Nominatim client shared by the forward and reverse geocoding lookups below.
geolocator = Nominatim(user_agent="predictGeolocforImage")

# StreetCLIP weights and the matching processor come from the same hub repo.
streetclip_model = CLIPModel.from_pretrained("geolocal/StreetCLIP")
streetclip_processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
# Candidate country names that classify_streetclip scores against the image.
labels = [
    "Albania", "Andorra", "Argentina", "Australia", "Austria",
    "Bangladesh", "Belgium", "Bermuda", "Bhutan", "Bolivia", "Botswana",
    "Brazil", "Bulgaria",
    "Cambodia", "Canada", "Chile", "China", "Colombia", "Croatia",
    "Czech Republic",
    "Denmark", "Dominican Republic",
    "Ecuador", "Estonia",
    "Finland", "France",
    "Germany", "Ghana", "Greece", "Greenland", "Guam", "Guatemala",
    "Hungary",
    "Iceland", "India", "Indonesia", "Ireland", "Israel", "Italy",
    "Japan", "Jordan",
    "Kenya", "Kyrgyzstan",
    "Laos", "Latvia", "Lesotho", "Lithuania", "Luxembourg",
    "Macedonia", "Madagascar", "Malaysia", "Malta", "Mexico", "Monaco",
    "Mongolia", "Montenegro",
    "Netherlands", "New Zealand", "Nigeria", "Norway",
    "Pakistan", "Palestine", "Peru", "Philippines", "Poland", "Portugal",
    "Puerto Rico",
    "Romania", "Russia", "Rwanda",
    "Senegal", "Serbia", "Singapore", "Slovakia", "Slovenia",
    "South Africa", "South Korea", "Spain", "Sri Lanka", "Swaziland",
    "Sweden", "Switzerland",
    "Taiwan", "Thailand", "Tunisia", "Turkey",
    "Uganda", "Ukraine", "United Arab Emirates", "United Kingdom",
    "United States", "Uruguay",
]
# Hub identifier of the checkpoint loaded into geoloc_model.
GEOLOC_MODEL_NAME = "osv5m/baseline"

# (height, width) that transform_image resizes every input image to.
IMAGE_SIZE = (224, 224)

# Load the geolocalizer once at import time and switch it to evaluation mode.
geoloc_model = Geolocalizer.from_pretrained(GEOLOC_MODEL_NAME)
geoloc_model.eval()
def transform_image(image):
    """Turn an image into a normalized tensor with a leading batch axis.

    The image is resized to ``IMAGE_SIZE``, converted to a tensor, normalized
    per channel with the constants below, and finally given an extra
    dimension at position 0.

    NOTE(review): the mean/std values look like the common ImageNet
    statistics -- confirm they match what the loaded checkpoint expects.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    # Apply the three steps one after another instead of via a Compose object.
    resized = transforms.Resize(IMAGE_SIZE)(image)
    as_tensor = transforms.ToTensor()(resized)
    normalized = transforms.Normalize(mean=channel_mean, std=channel_std)(as_tensor)
    return normalized.unsqueeze(0)
def create_map(lat, lon):
    """Return the HTML for a folium map centred on a point, with one marker.

    Args:
        lat: Latitude of the point.
        lon: Longitude of the point.

    Returns:
        The HTML string produced by the map's ``_repr_html_`` method.
    """
    world_map = folium.Map(location=[lat, lon], zoom_start=4)
    pin = folium.Marker([lat, lon])
    pin.add_to(world_map)
    return world_map._repr_html_()
def get_country_coordinates(country_name):
    """Look up coordinates for a place name through the shared geolocator.

    This is a best-effort lookup: any geocoder service failure is treated the
    same as "no result" so that callers can fall back gracefully instead of
    crashing after the (more expensive) model prediction has succeeded.

    Args:
        country_name: Name passed to the geocoder as the query.

    Returns:
        A ``(latitude, longitude)`` tuple, or ``None`` when the geocoder
        returns no match or the request fails.
    """
    # Local import keeps this function self-contained; geopy is already a
    # dependency of this file.
    from geopy.exc import GeocoderServiceError

    try:
        location = geolocator.geocode(country_name, timeout=10)
    except GeocoderServiceError:
        # GeocoderTimedOut (the only error handled previously) is a subclass
        # of GeocoderServiceError, as are the "unavailable" and rate-limit
        # errors, so timeouts are still covered here.
        return None

    if location is None:
        return None
    return location.latitude, location.longitude
@spaces.GPU
def predict_geoclip(image):
    """Predict GPS coordinates for an image with GeoCLIP and describe them.

    The image is written to a temporary JPEG because ``geoclip_model.predict``
    is called with a file path. The top prediction is then reverse-geocoded
    to obtain a country name and rendered on a map.

    Args:
        image: PIL image to localize, or ``None`` when nothing was uploaded.

    Returns:
        A ``(text, map_html)`` tuple. ``text`` holds the predicted latitude,
        longitude and country; ``map_html`` is the HTML from ``create_map``.
        When ``image`` is ``None`` the tuple is ``("No image provided.", None)``.
    """
    # Local import keeps this function self-contained; geopy is already a
    # dependency of this file.
    from geopy.exc import GeocoderServiceError

    if image is None:
        return "No image provided.", None

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmppath = Path(tmp_dir) / "tmp.jpg"
        # JPEG cannot store alpha or palette modes, so normalize to RGB first.
        image.convert("RGB").save(str(tmppath))
        # Only the best prediction is reported, so ask for a single result.
        top_pred_gps, _ = geoclip_model.predict(str(tmppath), top_k=1)

    # Plain floats behave predictably in formatting, geocoding and mapping.
    lat, lon = (float(value) for value in top_pred_gps[0])

    # Reverse geocoding is best effort: a slow or unavailable service, or a
    # point with no address (for example open water), must not discard the
    # coordinates that were already predicted.
    try:
        location = geolocator.reverse((lat, lon), exactly_one=True, timeout=10)
    except GeocoderServiceError:
        location = None

    country = "Unknown"
    if location is not None:
        address = location.raw.get("address", {})
        country = address.get("country") or "Unknown"

    prediction = f"Latitude: {lat:.6f}, Longitude: {lon:.6f} - Country: {country}"
    return prediction, create_map(lat, lon)
@spaces.GPU
def classify_streetclip(image):
    """Pick the most likely country for an image using StreetCLIP.

    Every entry of ``labels`` is scored against the image, the scores are
    turned into probabilities with a softmax, and the highest-scoring label
    is reported together with a map of that country's coordinates.

    Args:
        image: Image handed to the StreetCLIP processor.

    Returns:
        A ``(text, map_html)`` tuple; ``map_html`` is the literal string
        ``"Map not available"`` when no coordinates could be looked up.
    """
    batch = streetclip_processor(text=labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        image_logits = streetclip_model(**batch).logits_per_image
        probabilities = image_logits.softmax(dim=1)

    # max() keeps the first of any equal scores, matching a stable
    # descending sort of the same (label, score) pairs.
    scores = probabilities[0].tolist()
    top_label, _ = max(zip(labels, scores), key=lambda pair: pair[1])

    coords = get_country_coordinates(top_label)
    if coords:
        map_html = create_map(*coords)
    else:
        map_html = "Map not available"
    return f"Country: {top_label}", map_html
def infer(image):
    """Predict a location for an image with the OSV-5M baseline model.

    The image is preprocessed with ``transform_image``, passed through
    ``geoloc_model``, and the output (treated as radians) is converted to
    degrees. The nearest known place is looked up offline with
    ``reverse_geocoder`` and the point is drawn on a map.

    Args:
        image: Image to localize.

    Returns:
        A ``(text, map_html)`` tuple. On any failure ``text`` is an error
        message and ``map_html`` is ``None``.
    """
    # Broad catch is deliberate: this is the UI boundary, and any failure is
    # reported to the user as text rather than raised.
    try:
        img_tensor = transform_image(image)
        # Inference only: skip autograd bookkeeping, as classify_streetclip does.
        with torch.no_grad():
            gps_radians = geoloc_model(img_tensor)
        gps_degrees = torch.rad2deg(gps_radians).squeeze(0).cpu().tolist()
        lat, lon = gps_degrees[0], gps_degrees[1]

        location_query = rg.search((lat, lon))[0]
        map_html = create_map(lat, lon)
        # NOTE(review): the text labels the `admin1` field as "Country";
        # confirm that wording is intended.
        return f"Latitude: {lat:.6f}, Longitude: {lon:.6f} - Country: {location_query['admin1']} - {location_query['cc']}", map_html
    except Exception as e:
        return f"Failed to predict the location: {e}", None
# One Gradio interface per model. Each takes a PIL image and returns a text
# prediction plus an HTML map.
geoclip_interface = gr.Interface(
    fn=predict_geoclip,
    inputs=gr.Image(
        type="pil",
        label="Upload Image",
        elem_id="geoclip_image_input",
    ),
    outputs=[
        gr.Textbox(label="Prediction", elem_id="geoclip_output"),
        gr.HTML(label="Map", elem_id="geoclip_map_output"),
    ],
    title="GeoCLIP",
)

streetclip_interface = gr.Interface(
    fn=classify_streetclip,
    inputs=gr.Image(
        type="pil",
        label="Upload Image",
        elem_id="streetclip_image_input",
    ),
    outputs=[
        gr.Textbox(label="Prediction", elem_id="streetclip_output"),
        gr.HTML(label="Map", elem_id="streetclip_map_output"),
    ],
    title="StreetCLIP",
)

osv5m_interface = gr.Interface(
    fn=infer,
    inputs=gr.Image(
        label="Upload Image",
        type="pil",
        elem_id="osv5m_image_input",
    ),
    outputs=[
        gr.Textbox(label="Prediction", elem_id="result_text"),
        gr.HTML(label="Map", elem_id="map_output"),
    ],
    title="OSV-5M Baseline",
)

# Combine the three interfaces into tabs and start the app.
demo = gr.TabbedInterface(
    [geoclip_interface, streetclip_interface, osv5m_interface],
    tab_names=["GeoCLIP", "StreetCLIP", "OSV-5M Baseline"],
)
demo.launch()
|