# File size: 4,460 Bytes
# Hashes shown in the original listing: 17f6c62, 33b6e89
from transformers import ViTImageProcessor, ViTForImageClassification
import gradio as gr
from datasets import load_dataset
import torch
import random
import numpy as np
import pandas as pd



def get_predictions(image):
    """Score ``image`` against every label the classifier knows.

    The image is preprocessed with the module-level ``processor`` and fed
    to the module-level ``model`` with gradient tracking disabled.

    Args:
        image: The picture to classify, in whatever form ``processor``
            accepts.

    Returns:
        A dict mapping each label name from ``model.config.id2label`` to
        its softmax probability (a torch tensor element), inserted from
        the highest-scoring class to the lowest.
    """
    encoded = processor(image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**encoded).logits

    # Rank all class ids of the first (only) row, best first, then pull
    # the probabilities out in that same order.
    ranking = logits[0].argsort(dim=-1, descending=True)
    ranked_probs = torch.softmax(logits, dim=-1)[0, ranking]

    return {
        model.config.id2label[class_id.item()]: prob
        for class_id, prob in zip(ranking, ranked_probs)
    }

# Images for the game come from the "test" split of the dataset; the
# classifier and its preprocessor are loaded from the same checkpoint name.
data = load_dataset("marcelomoreno26/geoguessr", split="test")
model_name = "marcelomoreno26/vit-base-patch-16-384-geoguessr"

processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)


length = len(data)
countries = []

# One country name per line. Read as UTF-8 explicitly so names with
# non-ASCII characters do not depend on the platform's default encoding.
with open("countries.txt", "r", encoding="utf-8") as file:
    for line in file:
        name = line.strip()
        # Skip blank lines: an empty name could be offered as an option
        # and would then fail the per-country prediction lookup.
        if name:
            countries.append(name)


def get_result(selection):
    """Compare the player's pick with the answer and with the model's pick.

    Reads the current round from the module-level ``correct_country``,
    ``model_prediction`` and ``filtered_predictions``.

    Args:
        selection: The country the player chose in the radio group.

    Returns:
        A ``(result_text, ai_confidence)`` tuple. ``result_text`` names the
        correct country and the outcome of the round; ``ai_confidence`` is
        a markdown string naming the model's guess followed by a table of
        its confidence over the offered options, highest first.
    """
    player_correct = selection == correct_country
    ai_correct = correct_country == model_prediction

    if player_correct and ai_correct:
        result = "It's a draw!"
    elif player_correct:
        result = "Congratulations! You won!"
    elif ai_correct:
        result = "Sorry, you lost. The AI guessed it right!"
    else:
        result = "Sorry, you both lost."

    # Renormalise over just the offered options so the table sums to ~100%.
    scores = {country: float(score) for country, score in filtered_predictions.items()}
    total_prob = sum(scores.values())

    if total_prob > 0:
        # Scale to a percentage first and round last; rounding and then
        # multiplying by 100 reintroduces floating-point noise.
        prob_per_country = [
            (country, np.round(100.0 * score / total_prob, 1))
            for country, score in scores.items()
        ]
    else:
        # Every option's probability is zero (e.g. underflow): report 0%
        # instead of dividing by zero.
        prob_per_country = [(country, 0.0) for country in scores]

    df = pd.DataFrame(
        prob_per_country, columns=["Country", "Model Confidence (%)"]
    ).sort_values(by="Model Confidence (%)", ascending=False)
    ai_confidence = (
        f"The AI's guess was {model_prediction}\n\nAI's Results:\n"
        + df.to_markdown(index=False)
    )

    return f"The correct country was: {correct_country}\n{result}", ai_confidence


def load():
    """Set up one round: pick an image, build the options, ask the model.

    Side effect: stores the model's scores for the offered options in the
    module-level ``filtered_predictions``.

    Returns:
        ``(image, options, correct_country, model_prediction)``.
        ``options`` is a shuffled list holding the correct country plus up
        to four distinct other countries; ``model_prediction`` is the
        option the model scored highest.
    """
    global filtered_predictions

    # Randomly select an image; index the dataset once and reuse the row
    # rather than fetching the same example twice.
    i = random.randint(0, len(data) - 1)
    row = data[i]
    image = row['image']
    correct_country = row['label']

    # Draw the distractors from a pool that excludes the answer, so every
    # round offers the same number of choices. Sampling from the full list
    # and filtering afterwards gave 4 or 5 options depending on chance.
    distractor_pool = [country for country in countries if country != correct_country]
    options = random.sample(distractor_pool, min(4, len(distractor_pool)))
    options.append(correct_country)
    random.shuffle(options)

    # Get model predictions, restricted to the options shown to the player
    predictions = get_predictions(image)
    filtered_predictions = {country: predictions[country] for country in options}
    model_prediction = max(filtered_predictions, key=filtered_predictions.get)

    return image, options, correct_country, model_prediction


def reload():
    """Start a new round and return the UI updates for it.

    The round itself (image choice, options, model scores) is produced by
    ``load()``, which also refreshes the module-level
    ``filtered_predictions``. This function publishes the new answer and
    model guess to the module globals that ``get_result`` reads.

    Returns:
        A 4-tuple: an image component with the new picture, a radio
        component with the new choices, and two empty strings that clear
        the previous round's text outputs.
    """
    global correct_country
    global model_prediction

    # Reuse load() instead of duplicating the round-setup logic here.
    image, options, correct_country, model_prediction = load()

    return gr.Image(image), gr.Radio(choices=options, label ="Select the country:"),  "", ""



# Prepare the first round before building the page; this also defines the
# module-level correct_country / model_prediction that get_result reads.
image, options, correct_country, model_prediction = load()

with gr.Blocks() as demo:
    # Header and instructions.
    gr.Markdown("# GeoGuessr - Can You Beat the AI?")
    gr.Markdown("Try to guess the country in the image. Can you beat the AI?")
    gr.Markdown("## Instructions:")
    gr.Markdown(
        "\n1. First to Start playing press **Get New Image** at the bottom "
        "(the server needs to refresh from the cache and previous user)"
        "\n2. Select the country where you think the image was taken."
        "\n3. Review the results."
        "\n4. Play again by clicking **Get New Image**"
    )

    # Round widgets: picture, answer choices, then the two result panes.
    img = gr.Image(value=image)
    radio = gr.Radio(choices=options, label="Select the country:")
    ai_pred = gr.Markdown()
    text = gr.Text(label="Result")

    # Picking an option scores the round immediately.
    radio.select(fn=get_result, inputs=radio, outputs=[text, ai_pred])

    # The button swaps in a fresh round and clears both result panes.
    btn = gr.Button(value="Get New Image")
    btn.click(fn=reload, inputs=None, outputs=[img, radio, text, ai_pred])

demo.launch()