"""
color extraction and palette generation
"""
import base64
import json
import mimetypes
from pathlib import Path
from typing import Any, Dict, Tuple

from src.config.poster_config import load_config
from src.state.poster_state import PosterState
from utils.langgraph_utils import LangGraphAgent, extract_json, load_prompt
from utils.src.logging_utils import log_agent_info, log_agent_success, log_agent_error, log_agent_warning
class ColorAgent:
    """extracts theme colors and generates color schemes

    The theme color is taken from the affiliation logo when that file exists,
    otherwise from the key visual, otherwise from the configured
    ``fallback_theme``. All tuning values are read from the ``colors`` section
    of the poster config.
    """

    # WCAG defines the contrast ratio as (L1 + 0.05) / (L2 + 0.05). This term is
    # not the sRGB transfer-function offset (config "srgb.gamma_offset"), so it
    # is kept as its own constant.
    _WCAG_LUMINANCE_OFFSET = 0.05
    _HEX_DIGITS = frozenset("0123456789abcdefABCDEF")
    # used when the file extension does not map to an image mime type
    _DEFAULT_IMAGE_MIME = "image/png"

    def __init__(self):
        self.name = "color_agent"
        self.logo_extraction_prompt = load_prompt("config/prompts/extract_theme_from_logo.txt")
        self.figure_color_prompt = load_prompt("config/prompts/extract_color_from_figure.txt")
        self.config = load_config()
        self.color_config = self.config["colors"]

    def __call__(self, state: "PosterState") -> "PosterState":
        """run color analysis and store the resulting scheme on the state

        On success sets ``state["color_scheme"]`` and ``state["current_agent"]``
        and writes the scheme to disk. On any failure the error is logged and
        appended to ``state["errors"]``; the state is returned either way.
        """
        log_agent_info(self.name, "starting color analysis")
        try:
            aff_logo_path = state.get("aff_logo_path")
            if aff_logo_path and Path(aff_logo_path).exists():
                log_agent_info(self.name, "extracting theme from affiliation logo")
                theme_color = self._extract_theme_from_logo(aff_logo_path, state)
            else:
                log_agent_info(self.name, "no logo found, using visual fallback")
                theme_color = self._extract_theme_from_visuals(state)

            # the color comes from an LLM answer, so validate it before doing math on it
            theme_color = self._resolve_theme_color(theme_color)

            color_scheme = self._generate_color_scheme(theme_color)
            color_scheme = self._add_contrast_color(color_scheme)

            state["color_scheme"] = color_scheme
            state["current_agent"] = self.name
            self._save_color_scheme(state)
            log_agent_success(self.name, f"theme: {theme_color}, {len(color_scheme)} colors")
        except Exception as e:
            log_agent_error(self.name, f"failed: {e}")
            # setdefault: the handler itself must not raise when "errors" is absent
            state.setdefault("errors", []).append(f"{self.name}: {e}")
        return state

    @classmethod
    def _normalize_hex_color(cls, color: Any) -> str:
        """return ``color`` as ``#`` followed by six hex digits

        Accepts an optional leading ``#``, surrounding whitespace and the
        three-digit shorthand (``abc`` -> ``#aabbcc``). Letter case is kept.

        Raises:
            ValueError: if ``color`` is not a string or not a hex color.
        """
        if not isinstance(color, str):
            raise ValueError(f"invalid hex color: {color!r}")
        digits = color.strip().lstrip("#")
        if len(digits) == 3:
            digits = "".join(ch * 2 for ch in digits)
        # explicit digit check: int(x, 16) would also accept "0x", "_" and signs
        if len(digits) != 6 or not set(digits) <= cls._HEX_DIGITS:
            raise ValueError(f"invalid hex color: {color!r}")
        return f"#{digits}"

    @classmethod
    def _hex_to_rgb(cls, color: str) -> Tuple[int, int, int]:
        """parse a hex color string into integer (r, g, b) channels"""
        digits = cls._normalize_hex_color(color)[1:]
        return int(digits[0:2], 16), int(digits[2:4], 16), int(digits[4:6], 16)

    @staticmethod
    def _to_hex(r: float, g: float, b: float) -> str:
        """format channels as ``#rrggbb``; values are truncated and clamped to 0..255"""
        channels = (min(255, max(0, int(c))) for c in (r, g, b))
        return "#" + "".join(f"{c:02x}" for c in channels)

    @classmethod
    def _encode_image_as_data_url(cls, image_path: str) -> str:
        """read an image file and return it as a base64 ``data:`` url

        The mime type is guessed from the file extension so that e.g. a jpeg is
        not labelled as png; unknown extensions keep the previous png default.
        """
        mime_type, _ = mimetypes.guess_type(str(image_path))
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = cls._DEFAULT_IMAGE_MIME
        with open(image_path, "rb") as f:
            img_data = base64.b64encode(f.read()).decode()
        return f"data:{mime_type};base64,{img_data}"

    def _resolve_theme_color(self, candidate: Any) -> str:
        """return ``candidate`` normalized, or the configured fallback if it is not a hex color"""
        try:
            return self._normalize_hex_color(candidate)
        except ValueError:
            fallback = self.color_config["fallback_theme"]
            log_agent_warning(self.name, f"invalid theme color {candidate!r}, using fallback {fallback}")
            return fallback

    def _extract_theme_from_logo(self, logo_path: str, state: "PosterState") -> str:
        """extract theme color from affiliation logo using vision LLM

        Any failure (file, model call, response parsing) is logged and the
        key-visual fallback is used instead.
        """
        log_agent_info(self.name, f"analyzing affiliation logo: {Path(logo_path).name}")
        try:
            image_url = self._encode_image_as_data_url(logo_path)
            agent = LangGraphAgent(
                "color extraction specialist for academic institutions",
                state["vision_model"]
            )
            messages = [
                {"type": "text", "text": self.logo_extraction_prompt},
                {"type": "image_url", "image_url": {"url": image_url}}
            ]
            response = agent.step(json.dumps(messages))
            result = extract_json(response.content)

            # add token usage
            state["tokens"].add_vision(response.input_tokens, response.output_tokens)

            extracted_color = result.get("extracted_color", self.color_config["fallback_theme"])
            suitability_score = result.get("suitability_score", 0)
            log_agent_info(self.name, f"logo analysis: {result.get('color_name', 'unknown')} (score: {suitability_score})")

            # only report an adjustment when the response actually names one
            adjustment = result.get("adjustment_made")
            if adjustment and adjustment != "none":
                log_agent_info(self.name, f"color adjusted: {adjustment}")
            return extracted_color
        except Exception as e:
            log_agent_warning(self.name, f"logo extraction failed: {e}, using fallback")
            return self._extract_theme_from_visuals(state)

    def _extract_theme_from_visuals(self, state: "PosterState") -> str:
        """fallback: extract theme from key visuals

        Returns the configured fallback theme when there is no key visual, its
        file cannot be found, or the color analysis fails.
        """
        fallback = self.color_config["fallback_theme"]

        # "or {}" also covers keys that are present but set to None
        classified = state.get("classified_visuals") or {}
        key_visual = classified.get("key_visual")
        if not key_visual:
            log_agent_warning(self.name, "no key visual found, using default navy color")
            return fallback

        # get path to key visual
        images = state.get("images") or {}
        visual_path = None
        prefix = "figure_"
        if isinstance(key_visual, str) and key_visual.startswith(prefix):
            fig_id = key_visual[len(prefix):]  # strip the prefix only, not every occurrence
            if fig_id in images:
                visual_path = images[fig_id].get("path")

        if not visual_path or not Path(visual_path).exists():
            log_agent_warning(self.name, "key visual path not found, using default navy color")
            return fallback

        # analyze figure to extract prominent color
        try:
            return self._analyze_figure_for_color(visual_path, state)
        except Exception as e:
            log_agent_warning(self.name, f"visual color extraction failed: {e}, using default navy color")
            return fallback

    def _analyze_figure_for_color(self, image_path: str, state: "PosterState") -> str:
        """analyze figure to extract theme color

        Errors are not handled here; the caller decides on the fallback.
        """
        log_agent_info(self.name, "analyzing figure for color extraction")
        image_url = self._encode_image_as_data_url(image_path)
        agent = LangGraphAgent(
            "color extraction expert for academic poster design",
            state["vision_model"]
        )
        messages = [
            {"type": "text", "text": self.figure_color_prompt},
            {"type": "image_url", "image_url": {"url": image_url}}
        ]
        response = agent.step(json.dumps(messages))
        result = extract_json(response.content)

        # add token usage
        state["tokens"].add_vision(response.input_tokens, response.output_tokens)
        return result.get("theme_color", self.color_config["fallback_theme"])

    def _generate_color_scheme(self, theme_color: str) -> Dict[str, str]:
        """build the base scheme (theme, light/dark variants, text colors)

        Raises:
            ValueError: if ``theme_color`` is not a hex color.
        """
        theme = self._normalize_hex_color(theme_color)
        r, g, b = self._hex_to_rgb(theme)

        # generate monochromatic variations
        mono_light = self._generate_enhanced_light_variant(r, g, b)
        mono_dark = self._generate_enhanced_dark_variant(r, g, b)

        return {
            "theme": theme,
            "mono_light": mono_light,
            "mono_dark": mono_dark,
            "text": self.color_config["constants"]["black_text"],
            "text_on_theme": self._get_contrast_text_color(theme)
        }

    def _add_contrast_color(self, color_scheme: Dict[str, str]) -> Dict[str, str]:
        """add contrast color for keyword highlighting

        Mutates ``color_scheme`` in place (adds ``"contrast"``) and returns it.
        """
        r, g, b = self._hex_to_rgb(color_scheme["theme"])
        comp_r, comp_g, comp_b = self._generate_complementary_color(r, g, b)
        color_scheme["contrast"] = self._reduce_saturation_brightness(comp_r, comp_g, comp_b)
        return color_scheme

    def _generate_enhanced_light_variant(self, r: int, g: int, b: int) -> str:
        """generate light background color

        Keeps the hue of the input; saturation and brightness come from config.
        """
        hue, _, _ = self._rgb_to_hsv(r, g, b)
        light = self.color_config["mono_light"]
        return self._to_hex(*self._hsv_to_rgb(hue, light["saturation"], light["brightness"]))

    def _generate_enhanced_dark_variant(self, r: int, g: int, b: int) -> str:
        """generate darker variant

        Keeps the hue, lowers brightness by the configured ``dark_decrease``
        (not below ``brightness_min``).
        """
        color_params = self.color_config["saturation_adjustments"]
        bounds = self.color_config["hsv_bounds"]
        dark = self.color_config["mono_dark"]

        hue, _, value = self._rgb_to_hsv(r, g, b)

        # NOTE(review): saturation is derived from config only; the input color's
        # saturation and the upper end of saturation_range are not used - confirm intent
        enhanced_s = min(1.0, max(dark["saturation_range"][0], dark["saturation_default"]))
        enhanced_v = max(bounds["brightness_min"], value - color_params["dark_decrease"])
        return self._to_hex(*self._hsv_to_rgb(hue, enhanced_s, enhanced_v))

    def _generate_complementary_color(self, r: int, g: int, b: int) -> Tuple[int, int, int]:
        """generate complementary color by shifting the hue by the configured offset"""
        hue, sat, value = self._rgb_to_hsv(r, g, b)
        comp_h = (hue + self.color_config["complementary"]["hue_offset"]) % 1.0
        comp_r, comp_g, comp_b = self._hsv_to_rgb(comp_h, sat, value)
        return int(comp_r), int(comp_g), int(comp_b)

    def _reduce_saturation_brightness(self, r: int, g: int, b: int) -> str:
        """optimize contrast color for readability

        Keeps the hue, uses the configured saturation, and steps brightness from
        ``brightness_start`` towards ``brightness_max`` until the contrast ratio
        against ``white_rgb`` reaches ``wcag_contrast_ratio``.
        """
        hue, _, _ = self._rgb_to_hsv(r, g, b)
        params = self.color_config["contrast_color"]
        font_s = params["saturation"]
        font_v = params["brightness_start"]
        max_brightness = params["brightness_max"]
        step = params["brightness_step"]
        required_ratio = params["wcag_contrast_ratio"]
        white_rgb = self.color_config["constants"]["white_rgb"]

        # NOTE(review): with a positive step each iteration gets brighter, which
        # presumably lowers contrast against white; if the start value fails the
        # loop ends at max brightness - confirm this is the intended search direction
        while font_v < max_brightness:
            test_r, test_g, test_b = self._hsv_to_rgb(hue, font_s, font_v)
            if self._calculate_contrast_ratio(test_r, test_g, test_b, *white_rgb) >= required_ratio:
                break
            font_v += step

        return self._to_hex(*self._hsv_to_rgb(hue, font_s, font_v))

    def _calculate_contrast_ratio(self, r1: int, g1: int, b1: int, r2: int, g2: int, b2: int) -> float:
        """calculate WCAG contrast ratio between two colors (order independent)"""
        lum_a = self._get_relative_luminance(r1, g1, b1)
        lum_b = self._get_relative_luminance(r2, g2, b2)
        lighter, darker = max(lum_a, lum_b), min(lum_a, lum_b)
        offset = self._WCAG_LUMINANCE_OFFSET
        return (lighter + offset) / (darker + offset)

    def _get_relative_luminance(self, r: int, g: int, b: int) -> float:
        """calculate relative luminance for WCAG contrast calculations"""
        max_rgb = self.color_config["constants"]["max_rgb"]
        srgb = self.color_config["srgb"]
        linear_divisor = self.color_config["constants"]["gamma_linear_divisor"]

        def gamma_correct(c):
            # piecewise transfer function: linear segment below the threshold
            if c <= srgb["gamma_threshold"]:
                return c / linear_divisor
            return pow((c + srgb["gamma_offset"]) / srgb["gamma_divisor"], srgb["gamma_exponent"])

        r_linear = gamma_correct(r / max_rgb)
        g_linear = gamma_correct(g / max_rgb)
        b_linear = gamma_correct(b / max_rgb)

        weights = self.color_config["luminance_weights"]
        return weights["red"] * r_linear + weights["green"] * g_linear + weights["blue"] * b_linear

    def _rgb_to_hsv(self, r: int, g: int, b: int) -> Tuple[float, float, float]:
        """convert rgb to hsv

        Returns (h, s, v) where h is the hue in degrees divided by 360.
        """
        max_rgb = self.color_config["constants"]["max_rgb"]
        r, g, b = r / max_rgb, g / max_rgb, b / max_rgb
        max_val = max(r, g, b)
        min_val = min(r, g, b)
        diff = max_val - min_val

        if diff == 0:
            h = 0  # no dominant channel: hue defaults to 0
        elif max_val == r:
            h = (60 * ((g - b) / diff) + 360) % 360
        elif max_val == g:
            h = (60 * ((b - r) / diff) + 120) % 360
        else:
            h = (60 * ((r - g) / diff) + 240) % 360

        s = 0 if max_val == 0 else diff / max_val
        v = max_val
        return h / 360.0, s, v

    def _hsv_to_rgb(self, h: float, s: float, v: float) -> Tuple[float, float, float]:
        """convert hsv to rgb

        ``h`` is the hue divided by 360 (as returned by ``_rgb_to_hsv``). The
        result is three floats scaled by ``constants.max_rgb``, not rounded.
        """
        h = h * 360  # convert back to degrees
        c = v * s
        x = c * (1 - abs((h / 60) % 2 - 1))
        m = v - c

        if 0 <= h < 60:
            r, g, b = c, x, 0
        elif 60 <= h < 120:
            r, g, b = x, c, 0
        elif 120 <= h < 180:
            r, g, b = 0, c, x
        elif 180 <= h < 240:
            r, g, b = 0, x, c
        elif 240 <= h < 300:
            r, g, b = x, 0, c
        else:
            r, g, b = c, 0, x

        max_rgb = self.color_config["constants"]["max_rgb"]
        return (r + m) * max_rgb, (g + m) * max_rgb, (b + m) * max_rgb

    def _get_contrast_text_color(self, bg_color: str) -> str:
        """determine appropriate text color for given background

        Uses the configured channel weights to compute a brightness value and
        returns white text below ``brightness_threshold``, black text otherwise.
        """
        r, g, b = self._hex_to_rgb(bg_color)
        constants = self.color_config["constants"]
        brightness = (r * constants["red_weight"] +
                      g * constants["green_weight"] +
                      b * constants["blue_weight"]) / constants["brightness_divisor"]
        if brightness < constants["brightness_threshold"]:
            return constants["white_text"]
        return constants["black_text"]

    def _save_color_scheme(self, state: "PosterState"):
        """save color scheme to json file (<output_dir>/content/color_scheme.json)"""
        output_dir = Path(state["output_dir"]) / "content"
        output_dir.mkdir(parents=True, exist_ok=True)
        with open(output_dir / "color_scheme.json", "w", encoding='utf-8') as f:
            json.dump(state.get("color_scheme", {}), f, indent=2)
def color_agent_node(state: PosterState) -> Dict[str, Any]:
    """run ColorAgent on ``state`` and return a new dict with its results

    The returned dict is a shallow copy of ``state`` overlaid with the keys the
    agent reports. ``ColorAgent.__call__`` sets "color_scheme" and
    "current_agent" only when it succeeds, so each key is copied only if
    present; a failure the agent already recorded must not turn into a KeyError.
    """
    result = ColorAgent()(state)
    updated = {**state}
    for key in ("color_scheme", "tokens", "current_agent", "errors"):
        if key in result:
            updated[key] = result[key]
    return updated