VOIDER commited on
Commit
d4c04db
·
verified ·
1 Parent(s): 9d90ffe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -22
app.py CHANGED
@@ -9,14 +9,10 @@ import os
9
  from collections import OrderedDict
10
 
11
  # --- 1. Define the Full Model Architecture ---
12
- # This class's structure is matched to the keys in the downloaded state_dict.
13
  class AestheticScorerModel(nn.Module):
14
  def __init__(self, clip_model_id, embedding_dim):
15
  super().__init__()
16
- # Use the base CLIPVisionModel, which provides the correct pooler_output dimension
17
  self.backbone = CLIPVisionModel.from_pretrained(clip_model_id)
18
-
19
- # Define the seven prediction heads
20
  self.aesthetic_head = nn.Sequential(nn.Linear(embedding_dim, 1))
21
  self.quality_head = nn.Sequential(nn.Linear(embedding_dim, 1))
22
  self.composition_head = nn.Sequential(nn.Linear(embedding_dim, 1))
@@ -26,11 +22,8 @@ class AestheticScorerModel(nn.Module):
26
  self.content_head = nn.Sequential(nn.Linear(embedding_dim, 1))
27
 
28
  def forward(self, pixel_values):
29
- # Use the un-projected 'pooler_output' which has dimension 768
30
  outputs = self.backbone(pixel_values=pixel_values)
31
  embedding = outputs.pooler_output
32
-
33
- # Calculate and concatenate scores from each head
34
  scores = torch.cat([
35
  self.aesthetic_head(embedding), self.quality_head(embedding),
36
  self.composition_head(embedding), self.light_head(embedding),
@@ -44,14 +37,13 @@ print("Loading model and processor...")
44
  CACHE_DIR = "hf_cache"
45
  MODEL_REPO_ID = "rsinema/aesthetic-scorer"
46
  CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
47
- EMBEDDING_DIM = 768 # This is the hidden_size of the base model
48
 
49
  processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID, cache_dir=CACHE_DIR)
50
  model = AestheticScorerModel(clip_model_id=CLIP_MODEL_ID, embedding_dim=EMBEDDING_DIM)
51
  model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="model.pt", cache_dir=CACHE_DIR)
52
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
53
 
54
- # Key renaming logic to align saved weights with the model's structure
55
  corrected_state_dict = OrderedDict()
56
  for key, value in state_dict.items():
57
  if key.startswith('backbone.'):
@@ -64,15 +56,11 @@ model.load_state_dict(corrected_state_dict, strict=False)
64
  model.eval()
65
  print("Model and processor loaded successfully.")
66
 
67
- # Define the aesthetic categories in the correct order for table headers
68
  AESTHETIC_CATEGORIES = ["Overall", "Quality", "Composition", "Lighting", "Color", "Depth of Field", "Content"]
69
  TABLE_HEADERS = ["Preview", "File"] + AESTHETIC_CATEGORIES
70
 
71
  # --- 3. Core Processing Function ---
72
  def score_images(files):
73
- """
74
- Processes uploaded images, scores them, and returns a DataFrame with image previews.
75
- """
76
  if not files:
77
  return pd.DataFrame(columns=TABLE_HEADERS)
78
 
@@ -86,8 +74,9 @@ def score_images(files):
86
  inputs = processor(images=image, return_tensors="pt")
87
  scores = model(**inputs)[0]
88
 
89
- # Create a dictionary for the current image's scores, including the preview path
90
- image_scores = {"Preview": file_path, "File": filename}
 
91
  for category, score in zip(AESTHETIC_CATEGORIES, scores):
92
  image_scores[category] = f"{score.item():.2f} / 5"
93
  results_list.append(image_scores)
@@ -97,11 +86,9 @@ def score_images(files):
97
  error_row = {"Preview": None, "File": filename, **{cat: "Processing Error" for cat in AESTHETIC_CATEGORIES}}
98
  results_list.append(error_row)
99
 
100
- # Create DataFrame with specified column order
101
  return pd.DataFrame(results_list, columns=TABLE_HEADERS)
102
 
103
  def clear_all():
104
- """Returns None to clear all specified components."""
105
  return None, None
106
 
107
  # --- 4. Gradio Interface using Blocks for Layout Control ---
@@ -110,7 +97,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}")
110
  gr.Markdown("Upload one or more images to compare their aesthetic scores across seven distinct categories. This application uses the **rsinema/aesthetic-scorer** model.")
111
 
112
  with gr.Column():
113
- # Input section
114
  file_uploader = gr.File(
115
  label="Upload Images",
116
  file_count="multiple",
@@ -120,16 +106,15 @@ with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}")
120
  clear_button = gr.Button("Clear")
121
  submit_button = gr.Button("Submit", variant="primary")
122
 
123
- # Output section is now below the inputs
124
  output_df = gr.Dataframe(
125
  headers=TABLE_HEADERS,
126
  label="Aesthetic Scores Comparison",
127
  interactive=False,
128
- # Set column widths to make the Preview column larger
129
- column_widths=[20, 30] + [10] * len(AESTHETIC_CATEGORIES)
 
130
  )
131
 
132
- # Event listeners
133
  submit_button.click(fn=score_images, inputs=file_uploader, outputs=output_df)
134
  clear_button.click(fn=clear_all, inputs=None, outputs=[file_uploader, output_df])
135
 
 
9
  from collections import OrderedDict
10
 
11
  # --- 1. Define the Full Model Architecture ---
 
12
  class AestheticScorerModel(nn.Module):
13
  def __init__(self, clip_model_id, embedding_dim):
14
  super().__init__()
 
15
  self.backbone = CLIPVisionModel.from_pretrained(clip_model_id)
 
 
16
  self.aesthetic_head = nn.Sequential(nn.Linear(embedding_dim, 1))
17
  self.quality_head = nn.Sequential(nn.Linear(embedding_dim, 1))
18
  self.composition_head = nn.Sequential(nn.Linear(embedding_dim, 1))
 
22
  self.content_head = nn.Sequential(nn.Linear(embedding_dim, 1))
23
 
24
  def forward(self, pixel_values):
 
25
  outputs = self.backbone(pixel_values=pixel_values)
26
  embedding = outputs.pooler_output
 
 
27
  scores = torch.cat([
28
  self.aesthetic_head(embedding), self.quality_head(embedding),
29
  self.composition_head(embedding), self.light_head(embedding),
 
37
  CACHE_DIR = "hf_cache"
38
  MODEL_REPO_ID = "rsinema/aesthetic-scorer"
39
  CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
40
+ EMBEDDING_DIM = 768
41
 
42
  processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID, cache_dir=CACHE_DIR)
43
  model = AestheticScorerModel(clip_model_id=CLIP_MODEL_ID, embedding_dim=EMBEDDING_DIM)
44
  model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename="model.pt", cache_dir=CACHE_DIR)
45
  state_dict = torch.load(model_path, map_location=torch.device('cpu'))
46
 
 
47
  corrected_state_dict = OrderedDict()
48
  for key, value in state_dict.items():
49
  if key.startswith('backbone.'):
 
56
  model.eval()
57
  print("Model and processor loaded successfully.")
58
 
 
59
  AESTHETIC_CATEGORIES = ["Overall", "Quality", "Composition", "Lighting", "Color", "Depth of Field", "Content"]
60
  TABLE_HEADERS = ["Preview", "File"] + AESTHETIC_CATEGORIES
61
 
62
  # --- 3. Core Processing Function ---
63
  def score_images(files):
 
 
 
64
  if not files:
65
  return pd.DataFrame(columns=TABLE_HEADERS)
66
 
 
74
  inputs = processor(images=image, return_tensors="pt")
75
  scores = model(**inputs)[0]
76
 
77
+ # --- THIS IS THE CORRECTED LINE ---
78
+ # Provide the image data as a (filepath, alt_text) tuple for rendering.
79
+ image_scores = {"Preview": (file_path, filename), "File": filename}
80
  for category, score in zip(AESTHETIC_CATEGORIES, scores):
81
  image_scores[category] = f"{score.item():.2f} / 5"
82
  results_list.append(image_scores)
 
86
  error_row = {"Preview": None, "File": filename, **{cat: "Processing Error" for cat in AESTHETIC_CATEGORIES}}
87
  results_list.append(error_row)
88
 
 
89
  return pd.DataFrame(results_list, columns=TABLE_HEADERS)
90
 
91
  def clear_all():
 
92
  return None, None
93
 
94
  # --- 4. Gradio Interface using Blocks for Layout Control ---
 
97
  gr.Markdown("Upload one or more images to compare their aesthetic scores across seven distinct categories. This application uses the **rsinema/aesthetic-scorer** model.")
98
 
99
  with gr.Column():
 
100
  file_uploader = gr.File(
101
  label="Upload Images",
102
  file_count="multiple",
 
106
  clear_button = gr.Button("Clear")
107
  submit_button = gr.Button("Submit", variant="primary")
108
 
 
109
  output_df = gr.Dataframe(
110
  headers=TABLE_HEADERS,
111
  label="Aesthetic Scores Comparison",
112
  interactive=False,
113
+ # Specify the data type for the Preview column as 'image'
114
+ datatype=["image", "str"] + ["str"] * len(AESTHETIC_CATEGORIES),
115
+ column_widths=[15, 35] + [10] * len(AESTHETIC_CATEGORIES)
116
  )
117
 
 
118
  submit_button.click(fn=score_images, inputs=file_uploader, outputs=output_df)
119
  clear_button.click(fn=clear_all, inputs=None, outputs=[file_uploader, output_df])
120