Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -9,14 +9,10 @@ import os
|
|
9 |
from collections import OrderedDict
|
10 |
|
11 |
# --- 1. Define the Full Model Architecture ---
|
12 |
-
# This class's structure is matched to the keys in the downloaded state_dict.
|
13 |
class AestheticScorerModel(nn.Module):
|
14 |
def __init__(self, clip_model_id, embedding_dim):
|
15 |
super().__init__()
|
16 |
-
# Use the base CLIPVisionModel, which provides the correct pooler_output dimension
|
17 |
self.backbone = CLIPVisionModel.from_pretrained(clip_model_id)
|
18 |
-
|
19 |
-
# Define the seven prediction heads
|
20 |
self.aesthetic_head = nn.Sequential(nn.Linear(embedding_dim, 1))
|
21 |
self.quality_head = nn.Sequential(nn.Linear(embedding_dim, 1))
|
22 |
self.composition_head = nn.Sequential(nn.Linear(embedding_dim, 1))
|
@@ -26,11 +22,8 @@ class AestheticScorerModel(nn.Module):
|
|
26 |
self.content_head = nn.Sequential(nn.Linear(embedding_dim, 1))
|
27 |
|
28 |
def forward(self, pixel_values):
|
29 |
-
# Use the un-projected 'pooler_output' which has dimension 768
|
30 |
outputs = self.backbone(pixel_values=pixel_values)
|
31 |
embedding = outputs.pooler_output
|
32 |
-
|
33 |
-
# Calculate and concatenate scores from each head
|
34 |
scores = torch.cat([
|
35 |
self.aesthetic_head(embedding), self.quality_head(embedding),
|
36 |
self.composition_head(embedding), self.light_head(embedding),
|
@@ -44,14 +37,13 @@ print("Loading model and processor...")
|
|
44 |
CACHE_DIR = "hf_cache"
|
45 |
MODEL_REPO_ID = "rsinema/aesthetic-scorer"
|
46 |
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
|
47 |
-
EMBEDDING_DIM = 768
|
48 |
|
49 |
processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID, cache_dir=CACHE_DIR)
|
50 |
model = AestheticScorerModel(clip_model_id=CLIP_MODEL_ID, embedding_dim=EMBEDDING_DIM)
|
51 |
model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="model.pt", cache_dir=CACHE_DIR)
|
52 |
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
|
53 |
|
54 |
-
# Key renaming logic to align saved weights with the model's structure
|
55 |
corrected_state_dict = OrderedDict()
|
56 |
for key, value in state_dict.items():
|
57 |
if key.startswith('backbone.'):
|
@@ -64,15 +56,11 @@ model.load_state_dict(corrected_state_dict, strict=False)
|
|
64 |
model.eval()
|
65 |
print("Model and processor loaded successfully.")
|
66 |
|
67 |
-
# Define the aesthetic categories in the correct order for table headers
|
68 |
AESTHETIC_CATEGORIES = ["Overall", "Quality", "Composition", "Lighting", "Color", "Depth of Field", "Content"]
|
69 |
TABLE_HEADERS = ["Preview", "File"] + AESTHETIC_CATEGORIES
|
70 |
|
71 |
# --- 3. Core Processing Function ---
|
72 |
def score_images(files):
|
73 |
-
"""
|
74 |
-
Processes uploaded images, scores them, and returns a DataFrame with image previews.
|
75 |
-
"""
|
76 |
if not files:
|
77 |
return pd.DataFrame(columns=TABLE_HEADERS)
|
78 |
|
@@ -86,8 +74,9 @@ def score_images(files):
|
|
86 |
inputs = processor(images=image, return_tensors="pt")
|
87 |
scores = model(**inputs)[0]
|
88 |
|
89 |
-
#
|
90 |
-
|
|
|
91 |
for category, score in zip(AESTHETIC_CATEGORIES, scores):
|
92 |
image_scores[category] = f"{score.item():.2f} / 5"
|
93 |
results_list.append(image_scores)
|
@@ -97,11 +86,9 @@ def score_images(files):
|
|
97 |
error_row = {"Preview": None, "File": filename, **{cat: "Processing Error" for cat in AESTHETIC_CATEGORIES}}
|
98 |
results_list.append(error_row)
|
99 |
|
100 |
-
# Create DataFrame with specified column order
|
101 |
return pd.DataFrame(results_list, columns=TABLE_HEADERS)
|
102 |
|
103 |
def clear_all():
|
104 |
-
"""Returns None to clear all specified components."""
|
105 |
return None, None
|
106 |
|
107 |
# --- 4. Gradio Interface using Blocks for Layout Control ---
|
@@ -110,7 +97,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}")
|
|
110 |
gr.Markdown("Upload one or more images to compare their aesthetic scores across seven distinct categories. This application uses the **rsinema/aesthetic-scorer** model.")
|
111 |
|
112 |
with gr.Column():
|
113 |
-
# Input section
|
114 |
file_uploader = gr.File(
|
115 |
label="Upload Images",
|
116 |
file_count="multiple",
|
@@ -120,16 +106,15 @@ with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}")
|
|
120 |
clear_button = gr.Button("Clear")
|
121 |
submit_button = gr.Button("Submit", variant="primary")
|
122 |
|
123 |
-
# Output section is now below the inputs
|
124 |
output_df = gr.Dataframe(
|
125 |
headers=TABLE_HEADERS,
|
126 |
label="Aesthetic Scores Comparison",
|
127 |
interactive=False,
|
128 |
-
#
|
129 |
-
|
|
|
130 |
)
|
131 |
|
132 |
-
# Event listeners
|
133 |
submit_button.click(fn=score_images, inputs=file_uploader, outputs=output_df)
|
134 |
clear_button.click(fn=clear_all, inputs=None, outputs=[file_uploader, output_df])
|
135 |
|
|
|
9 |
from collections import OrderedDict
|
10 |
|
11 |
# --- 1. Define the Full Model Architecture ---
|
|
|
12 |
class AestheticScorerModel(nn.Module):
|
13 |
def __init__(self, clip_model_id, embedding_dim):
|
14 |
super().__init__()
|
|
|
15 |
self.backbone = CLIPVisionModel.from_pretrained(clip_model_id)
|
|
|
|
|
16 |
self.aesthetic_head = nn.Sequential(nn.Linear(embedding_dim, 1))
|
17 |
self.quality_head = nn.Sequential(nn.Linear(embedding_dim, 1))
|
18 |
self.composition_head = nn.Sequential(nn.Linear(embedding_dim, 1))
|
|
|
22 |
self.content_head = nn.Sequential(nn.Linear(embedding_dim, 1))
|
23 |
|
24 |
def forward(self, pixel_values):
|
|
|
25 |
outputs = self.backbone(pixel_values=pixel_values)
|
26 |
embedding = outputs.pooler_output
|
|
|
|
|
27 |
scores = torch.cat([
|
28 |
self.aesthetic_head(embedding), self.quality_head(embedding),
|
29 |
self.composition_head(embedding), self.light_head(embedding),
|
|
|
37 |
CACHE_DIR = "hf_cache"
|
38 |
MODEL_REPO_ID = "rsinema/aesthetic-scorer"
|
39 |
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
|
40 |
+
EMBEDDING_DIM = 768
|
41 |
|
42 |
processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID, cache_dir=CACHE_DIR)
|
43 |
model = AestheticScorerModel(clip_model_id=CLIP_MODEL_ID, embedding_dim=EMBEDDING_DIM)
|
44 |
model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="model.pt", cache_dir=CACHE_DIR)
|
45 |
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
|
46 |
|
|
|
47 |
corrected_state_dict = OrderedDict()
|
48 |
for key, value in state_dict.items():
|
49 |
if key.startswith('backbone.'):
|
|
|
56 |
model.eval()
|
57 |
print("Model and processor loaded successfully.")
|
58 |
|
|
|
59 |
AESTHETIC_CATEGORIES = ["Overall", "Quality", "Composition", "Lighting", "Color", "Depth of Field", "Content"]
|
60 |
TABLE_HEADERS = ["Preview", "File"] + AESTHETIC_CATEGORIES
|
61 |
|
62 |
# --- 3. Core Processing Function ---
|
63 |
def score_images(files):
|
|
|
|
|
|
|
64 |
if not files:
|
65 |
return pd.DataFrame(columns=TABLE_HEADERS)
|
66 |
|
|
|
74 |
inputs = processor(images=image, return_tensors="pt")
|
75 |
scores = model(**inputs)[0]
|
76 |
|
77 |
+
# --- THIS IS THE CORRECTED LINE ---
|
78 |
+
# Provide the image data as a (filepath, alt_text) tuple for rendering.
|
79 |
+
image_scores = {"Preview": (file_path, filename), "File": filename}
|
80 |
for category, score in zip(AESTHETIC_CATEGORIES, scores):
|
81 |
image_scores[category] = f"{score.item():.2f} / 5"
|
82 |
results_list.append(image_scores)
|
|
|
86 |
error_row = {"Preview": None, "File": filename, **{cat: "Processing Error" for cat in AESTHETIC_CATEGORIES}}
|
87 |
results_list.append(error_row)
|
88 |
|
|
|
89 |
return pd.DataFrame(results_list, columns=TABLE_HEADERS)
|
90 |
|
91 |
def clear_all():
|
|
|
92 |
return None, None
|
93 |
|
94 |
# --- 4. Gradio Interface using Blocks for Layout Control ---
|
|
|
97 |
gr.Markdown("Upload one or more images to compare their aesthetic scores across seven distinct categories. This application uses the **rsinema/aesthetic-scorer** model.")
|
98 |
|
99 |
with gr.Column():
|
|
|
100 |
file_uploader = gr.File(
|
101 |
label="Upload Images",
|
102 |
file_count="multiple",
|
|
|
106 |
clear_button = gr.Button("Clear")
|
107 |
submit_button = gr.Button("Submit", variant="primary")
|
108 |
|
|
|
109 |
output_df = gr.Dataframe(
|
110 |
headers=TABLE_HEADERS,
|
111 |
label="Aesthetic Scores Comparison",
|
112 |
interactive=False,
|
113 |
+
# Specify the data type for the Preview column as 'image'
|
114 |
+
datatype=["image", "str"] + ["str"] * len(AESTHETIC_CATEGORIES),
|
115 |
+
column_widths=[15, 35] + [10] * len(AESTHETIC_CATEGORIES)
|
116 |
)
|
117 |
|
|
|
118 |
submit_button.click(fn=score_images, inputs=file_uploader, outputs=output_df)
|
119 |
clear_button.click(fn=clear_all, inputs=None, outputs=[file_uploader, output_df])
|
120 |
|