File size: 3,077 Bytes
2bf3c18
781cf08
 
2bf3c18
 
644173a
03c454c
2bf3c18
276b619
781cf08
276b619
781cf08
644173a
2bc793a
644173a
2bc793a
644173a
 
2bc793a
644173a
2bc793a
 
 
 
644173a
 
2bc793a
 
 
 
 
 
 
 
644173a
 
781cf08
644173a
2bc793a
 
644173a
 
2bc793a
781cf08
2bc793a
 
 
 
 
 
644173a
781cf08
 
 
 
2bf3c18
644173a
781cf08
2bc793a
03c454c
2bc793a
 
 
781cf08
03c454c
2bc793a
2bf3c18
03c454c
2bc793a
 
 
2bf3c18
781cf08
fb90d20
09d93d9
 
 
 
 
fb90d20
 
03c454c
781cf08
 
 
 
03c454c
 
 
644173a
781cf08
 
fb90d20
781cf08
644173a
781cf08
 
 
644173a
781cf08
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import html
import json
import os
import traceback

import gradio as gr
import requests
from datasets import load_dataset
from transformers import pipeline

# Lookalike photo dataset from the Hugging Face Hub, loaded once at import
# time; fetch_images_for_label filters its "train" split by label.
dataset = load_dataset("SaladSlayer00/twin_matcher_data")

# Image-classification pipeline, built once at import time. Its predicted
# label is used both to select dataset images and to query the celebrity API.
image_classifier = pipeline("image-classification", model="SaladSlayer00/twin_matcher_beta")

def format_info(info_json):
    """Render a JSON array of celebrity records as an HTML table.

    Args:
        info_json: JSON text expected to decode to a list of flat objects.
            A single JSON object is treated as a one-row list.

    Returns:
        An HTML ``<table>`` string. The header row is the union of all record
        keys, in first-seen order, and there is one row per record. All text
        is HTML-escaped. If the JSON contains no records (e.g. ``[]``, which
        the API returns when nobody matches), the fallback text
        "A shining star for sure." is returned instead.

    Raises:
        ValueError: if ``info_json`` is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    info_data = json.loads(info_json)
    if isinstance(info_data, dict):
        info_data = [info_data]
    records = (
        [entry for entry in info_data if isinstance(entry, dict)]
        if isinstance(info_data, list)
        else []
    )
    if not records:
        return "A shining star for sure."

    # Union of keys across all records (dict keeps insertion order), so every
    # row lines up with the header even when records carry different fields.
    columns = list(dict.fromkeys(key for entry in records for key in entry))

    def _cell(value):
        # Show lists as "a, b" rather than a Python repr, and escape the
        # API-supplied text so it cannot inject markup into the page.
        if isinstance(value, list):
            value = ", ".join(str(item) for item in value)
        return html.escape(str(value))

    cell_style = "border: 1px solid #dddddd; text-align: left; padding: 8px;"
    parts = [
        "<table style='border-collapse: collapse; width: 80%; margin: 20px;'>",
        "<tr style='background-color: #f2f2f2;'>",
    ]
    parts.extend(
        f"<th style='{cell_style}'><b>{html.escape(str(key).capitalize())}</b></th>"
        for key in columns
    )
    parts.append("</tr>")
    for entry in records:
        parts.append("<tr>")
        parts.extend(
            f"<td style='{cell_style}'>{_cell(entry.get(column, ''))}</td>"
            for column in columns
        )
        parts.append("</tr>")
    parts.append("</table>")
    return "".join(parts)



def fetch_info(celebrity_label):
    """Look up a celebrity on the API Ninjas celebrity endpoint and render it.

    Args:
        celebrity_label: classifier label such as ``"megan_fox"``. It is
            converted to ``"Megan Fox"`` before querying.

    The API key is read from the ``TOKEN`` environment variable.

    Returns:
        HTML produced by ``format_info`` on success. Otherwise the fallback
        text "A shining star for sure.", used for a non-200 status, a network
        error or timeout, or a response body that cannot be formatted.
    """
    formatted_label = " ".join(part.capitalize() for part in celebrity_label.split("_"))
    api_url = "https://api.api-ninjas.com/v1/celebrity"
    token = os.getenv('TOKEN')
    try:
        # Let requests URL-encode the name. The timeout keeps a slow or
        # unreachable API from hanging the live Gradio handler indefinitely.
        response = requests.get(
            api_url,
            params={'name': formatted_label},
            headers={'X-Api-Key': token},
            timeout=10,
        )
        if response.status_code == 200:
            return format_info(response.text)
    except (requests.RequestException, ValueError, LookupError):
        # Info is a best-effort extra: log the failure and fall back to the
        # generic text rather than failing the whole prediction.
        traceback.print_exc()
    return "A shining star for sure."

def fetch_images_for_label(label):
    """Return every image in the dataset's "train" split whose label equals ``label``.

    NOTE(review): ``label`` is passed through unchanged from the classifier;
    this assumes the dataset's "label" column holds the same kind of values
    (e.g. the same strings). Confirm this, otherwise nothing will match.
    """
    def has_label(example):
        return example['label'] == label

    matching_rows = dataset['train'].filter(has_label)
    return [row['image'] for row in matching_rows]


def predict_and_fetch_images(input_image):
    """Classify ``input_image`` and gather everything the UI displays.

    Returns:
        A 5-tuple matching the interface outputs: (predicted label, score,
        lookalike images, info HTML, status message). The status is
        "No Error" on success.

    Because the interface is ``live``, this also runs when the image is
    cleared (``input_image is None``). That case, an empty prediction list,
    and any failure during prediction return empty outputs plus an
    explanatory status message instead of raising.
    """
    if input_image is None:
        return "", None, [], "", "No image provided."
    try:
        predictions = image_classifier(input_image)
        if not predictions:
            return "", None, [], "", "No prediction could be made for this image."
        top_prediction = max(predictions, key=lambda prediction: prediction['score'])
        label, score = top_prediction['label'], top_prediction['score']
        # Lookalike images and celebrity info are both keyed on the top label.
        images = fetch_images_for_label(label)
        info = fetch_info(label)
    except Exception as exc:  # UI boundary: log and report instead of crashing the handler
        traceback.print_exc()
        return "", None, [], "", f"Error: {exc}"
    return label, score, images, info, "No Error"


# Sample photos offered as clickable examples in the UI (relative paths
# under images/, one per celebrity base name).
example_images = [
    f"images/{name}.png"
    for name in (
        "megan_fox",
        "chris_evans",
        "millie_bobby_brown",
        "alvaro_morte",
        "amber_heard",
    )
]
        
# Gradio interface
iface = gr.Interface(
    fn=predict_and_fetch_images,
    inputs=gr.Image(type="pil", label="Upload or Take a Snapshot"),
    # One output component per element of the 5-tuple returned by
    # predict_and_fetch_images, in the same order.
    outputs=[
        gr.Textbox(label="Predicted celebrity"),
        gr.Number(label="Prediction score"),
        gr.Gallery(label="Lookalike Images"),
        gr.HTML(label="Celebrity info"),
        # Read-only output showing the status string returned by
        # predict_and_fetch_images (e.g. "No Error"). It is not an input, so
        # it must not invite the user to type feedback into it.
        gr.Textbox(label="Status"),
    ],
    examples=example_images,
    # Re-run the prediction whenever the input image changes.
    live=True,
    title="Celebrity Lookalike Predictor",
    description="Take a snapshot or upload an image to see which celebrity you look like!",
)


iface.launch()