"""
color extraction and palette generation

ColorAgent picks a theme color (from the affiliation logo, else from the key
figure, else from the configured fallback) using a vision LLM, derives a small
palette from it, stores the palette in state["color_scheme"] and writes it to
<output_dir>/content/color_scheme.json.
"""

import json
import base64
import mimetypes
import re
from pathlib import Path
from typing import Dict, Any, Optional, Tuple

from src.state.poster_state import PosterState
from utils.langgraph_utils import LangGraphAgent, extract_json, load_prompt
from utils.src.logging_utils import log_agent_info, log_agent_success, log_agent_error, log_agent_warning
from src.config.poster_config import load_config


# WCAG 2.x defines contrast ratio as (L1 + 0.05) / (L2 + 0.05). This is a
# different constant from the sRGB linearization offset (0.055), so it must
# not be taken from the "srgb" config section.
WCAG_LUMINANCE_OFFSET = 0.05

# accepts "#rrggbb", "rrggbb", "#rgb" or "rgb" (case-insensitive)
_HEX_COLOR_RE = re.compile(r"#?([0-9a-fA-F]{3}|[0-9a-fA-F]{6})")


class ColorAgent:
    """extracts theme colors and generates color schemes"""

    def __init__(self):
        self.name = "color_agent"
        self.logo_extraction_prompt = load_prompt("config/prompts/extract_theme_from_logo.txt")
        self.figure_color_prompt = load_prompt("config/prompts/extract_color_from_figure.txt")
        self.config = load_config()
        self.color_config = self.config["colors"]

    def __call__(self, state: PosterState) -> PosterState:
        """Run color analysis and store the palette in ``state``.

        Failures are logged and appended to ``state["errors"]`` instead of
        being raised; in that case ``state["color_scheme"]`` is not written.
        """
        log_agent_info(self.name, "starting color analysis")

        try:
            aff_logo_path = state.get("aff_logo_path")
            if aff_logo_path and Path(aff_logo_path).exists():
                log_agent_info(self.name, "extracting theme from affiliation logo")
                theme_color = self._extract_theme_from_logo(aff_logo_path, state)
            else:
                log_agent_info(self.name, "no logo found, using visual fallback")
                theme_color = self._extract_theme_from_visuals(state)

            color_scheme = self._generate_color_scheme(theme_color)
            color_scheme = self._add_contrast_color(color_scheme)

            state["color_scheme"] = color_scheme
            state["current_agent"] = self.name

            self._save_color_scheme(state)

            log_agent_success(self.name, f"theme: {theme_color}, {len(color_scheme)} colors")

        except Exception as e:
            log_agent_error(self.name, f"failed: {e}")
            state["errors"].append(f"{self.name}: {e}")

        return state

    # ------------------------------------------------------------------
    # theme extraction
    # ------------------------------------------------------------------

    def _fallback_theme(self) -> str:
        """configured default theme color, used whenever extraction fails"""
        return self.color_config["fallback_theme"]

    @staticmethod
    def _image_data_url(image_path: str) -> str:
        """read an image file and return it as a base64 ``data:`` URL.

        The mime type is guessed from the file extension (so .jpg logos are not
        mislabelled as PNG); unknown or non-image types default to image/png.
        """
        mime_type, _ = mimetypes.guess_type(str(image_path))
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = "image/png"
        with open(image_path, "rb") as f:
            encoded = base64.b64encode(f.read()).decode()
        return f"data:{mime_type};base64,{encoded}"

    def _ask_vision_model(self, role: str, prompt: str, image_path: str,
                          state: PosterState) -> Dict[str, Any]:
        """send ``prompt`` plus the image to the vision model and return the parsed JSON.

        Token usage is recorded before parsing so it is counted even when the
        response is not valid JSON. Raises ValueError if the parsed response is
        not a JSON object.
        """
        # encode first: an unreadable image fails before any agent is created
        data_url = self._image_data_url(image_path)

        agent = LangGraphAgent(role, state["vision_model"])
        messages = [
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": {"url": data_url}},
        ]
        response = agent.step(json.dumps(messages))
        state["tokens"].add_vision(response.input_tokens, response.output_tokens)

        result = extract_json(response.content)
        if not isinstance(result, dict):
            raise ValueError(f"expected a JSON object from the vision model, got {type(result).__name__}")
        return result

    def _extract_theme_from_logo(self, logo_path: str, state: PosterState) -> str:
        """extract theme color from affiliation logo using vision LLM.

        Any failure (I/O, model error, unparsable or invalid color) falls back
        to ``_extract_theme_from_visuals``.
        """
        log_agent_info(self.name, f"analyzing affiliation logo: {Path(logo_path).name}")

        try:
            result = self._ask_vision_model(
                "color extraction specialist for academic institutions",
                self.logo_extraction_prompt,
                logo_path,
                state,
            )

            raw_color = result.get("extracted_color", self._fallback_theme())
            suitability_score = result.get("suitability_score", 0)

            log_agent_info(self.name, f"logo analysis: {result.get('color_name', 'unknown')} (score: {suitability_score})")
            adjustment = result.get("adjustment_made")
            # a missing key must not be reported as an adjustment
            if adjustment and adjustment != "none":
                log_agent_info(self.name, f"color adjusted: {adjustment}")

            extracted_color = self._normalize_hex(raw_color)
            if extracted_color is None:
                raise ValueError(f"model returned invalid color {raw_color!r}")
            return extracted_color

        except Exception as e:
            log_agent_warning(self.name, f"logo extraction failed: {e}, using fallback")
            return self._extract_theme_from_visuals(state)

    def _extract_theme_from_visuals(self, state: PosterState) -> str:
        """fallback: extract theme from the key visual, else the configured default"""
        # `or {}` also covers keys that are present but set to None
        classified = state.get("classified_visuals") or {}
        key_visual = classified.get("key_visual")

        if not key_visual:
            log_agent_warning(self.name, "no key visual found, using default navy color")
            return self._fallback_theme()

        # get path to key visual; only "figure_<id>" keys map to state["images"]
        images = state.get("images") or {}
        visual_path = None

        if isinstance(key_visual, str) and key_visual.startswith("figure_"):
            fig_id = key_visual.replace("figure_", "")
            if fig_id in images:
                visual_path = images[fig_id].get("path")

        if not visual_path or not Path(visual_path).exists():
            log_agent_warning(self.name, "key visual path not found, using default navy color")
            return self._fallback_theme()

        # analyze figure to extract prominent color
        try:
            return self._analyze_figure_for_color(visual_path, state)
        except Exception as e:
            log_agent_warning(self.name, f"visual color extraction failed: {e}, using default navy color")
            return self._fallback_theme()

    def _analyze_figure_for_color(self, image_path: str, state: PosterState) -> str:
        """analyze figure to extract theme color.

        Returns a normalized '#rrggbb' string; raises if the model call fails
        or the returned color is not a valid hex color (the caller falls back).
        """
        log_agent_info(self.name, "analyzing figure for color extraction")

        result = self._ask_vision_model(
            "color extraction expert for academic poster design",
            self.figure_color_prompt,
            image_path,
            state,
        )

        raw_color = result.get("theme_color", self._fallback_theme())
        theme_color = self._normalize_hex(raw_color)
        if theme_color is None:
            raise ValueError(f"model returned invalid color {raw_color!r}")
        return theme_color

    # ------------------------------------------------------------------
    # palette generation
    # ------------------------------------------------------------------

    def _generate_color_scheme(self, theme_color: str) -> Dict[str, str]:
        """build the base palette (theme, light/dark variants, text colors).

        Raises ValueError if ``theme_color`` is not a hex color.
        """
        r, g, b = self._hex_to_rgb(theme_color)

        return {
            "theme": self._normalize_hex(theme_color),
            # mono_light: configured saturation/brightness at the theme hue
            "mono_light": self._generate_enhanced_light_variant(r, g, b),
            # mono_dark: configured saturation, darker than the theme
            "mono_dark": self._generate_enhanced_dark_variant(r, g, b),
            "text": self.color_config["constants"]["black_text"],
            "text_on_theme": self._get_contrast_text_color(theme_color),
        }

    def _add_contrast_color(self, color_scheme: Dict[str, str]) -> Dict[str, str]:
        """add contrast color for keyword highlighting (mutates and returns color_scheme)"""
        r, g, b = self._hex_to_rgb(color_scheme["theme"])
        comp_r, comp_g, comp_b = self._generate_complementary_color(r, g, b)
        color_scheme["contrast"] = self._reduce_saturation_brightness(comp_r, comp_g, comp_b)
        return color_scheme

    def _generate_enhanced_light_variant(self, r: int, g: int, b: int) -> str:
        """generate light background color: theme hue with configured saturation/brightness"""
        h, _, _ = self._rgb_to_hsv(r, g, b)
        light = self.color_config["mono_light"]
        return self._rgb_to_hex(*self._hsv_to_rgb(h, light["saturation"], light["brightness"]))

    def _generate_enhanced_dark_variant(self, r: int, g: int, b: int) -> str:
        """generate darker variant: theme hue, configured saturation, reduced brightness"""
        adjustments = self.color_config["saturation_adjustments"]
        bounds = self.color_config["hsv_bounds"]
        dark = self.color_config["mono_dark"]
        h, _, v = self._rgb_to_hsv(r, g, b)

        # saturation comes from config (not the input color), floored at the range minimum
        dark_s = min(1.0, max(dark["saturation_range"][0], dark["saturation_default"]))
        dark_v = max(bounds["brightness_min"], v - adjustments["dark_decrease"])
        return self._rgb_to_hex(*self._hsv_to_rgb(h, dark_s, dark_v))

    def _generate_complementary_color(self, r: int, g: int, b: int) -> Tuple[int, int, int]:
        """generate complementary color by rotating the hue by the configured offset"""
        h, s, v = self._rgb_to_hsv(r, g, b)
        comp_h = (h + self.color_config["complementary"]["hue_offset"]) % 1.0
        comp_r, comp_g, comp_b = self._hsv_to_rgb(comp_h, s, v)
        # round (not truncate) so float error like 73.9999 does not lose a unit
        return round(comp_r), round(comp_g), round(comp_b)

    def _reduce_saturation_brightness(self, r: int, g: int, b: int) -> str:
        """optimize contrast color for readability against white.

        Starting at ``brightness_start``, brightness is increased by
        ``brightness_step`` until the WCAG contrast ratio against white reaches
        ``wcag_contrast_ratio`` or ``brightness_max`` is hit; the result never
        exceeds ``brightness_max``.

        NOTE(review): at fixed hue/saturation, raising brightness only lowers
        contrast against white, so if ``brightness_start`` fails the ratio no
        later step can pass — confirm the intended search direction.
        """
        h, _, _ = self._rgb_to_hsv(r, g, b)
        params = self.color_config["contrast_color"]
        font_s = params["saturation"]
        font_v = params["brightness_start"]
        max_brightness = params["brightness_max"]
        step = params["brightness_step"]
        required_ratio = params["wcag_contrast_ratio"]
        white_rgb = self.color_config["constants"]["white_rgb"]

        if step <= 0 and font_v < max_brightness:
            # a non-positive step would loop forever
            raise ValueError(f"contrast_color.brightness_step must be positive, got {step}")

        while font_v < max_brightness:
            candidate = self._hsv_to_rgb(h, font_s, font_v)
            if self._calculate_contrast_ratio(*candidate, *white_rgb) >= required_ratio:
                break
            font_v += step

        # the loop overshoots brightness_max by up to one step when no candidate
        # passes; without the clamp v could exceed 1.0 and yield malformed hex
        font_v = min(font_v, max_brightness)
        return self._rgb_to_hex(*self._hsv_to_rgb(h, font_s, font_v))

    # ------------------------------------------------------------------
    # color math
    # ------------------------------------------------------------------

    def _calculate_contrast_ratio(self, r1: int, g1: int, b1: int, r2: int, g2: int, b2: int) -> float:
        """calculate WCAG contrast ratio (1.0 .. 21.0) between two colors"""
        lighter = self._get_relative_luminance(r1, g1, b1)
        darker = self._get_relative_luminance(r2, g2, b2)
        if lighter < darker:
            lighter, darker = darker, lighter
        return (lighter + WCAG_LUMINANCE_OFFSET) / (darker + WCAG_LUMINANCE_OFFSET)

    def _get_relative_luminance(self, r: int, g: int, b: int) -> float:
        """calculate relative luminance for WCAG contrast calculations"""
        max_rgb = self.color_config["constants"]["max_rgb"]
        r_norm = r / max_rgb
        g_norm = g / max_rgb
        b_norm = b / max_rgb

        def gamma_correct(c):
            # sRGB linearization: linear segment below the threshold, power curve above
            threshold = self.color_config["srgb"]["gamma_threshold"]
            if c <= threshold:
                return c / self.color_config["constants"]["gamma_linear_divisor"]
            offset = self.color_config["srgb"]["gamma_offset"]
            divisor = self.color_config["srgb"]["gamma_divisor"]
            exponent = self.color_config["srgb"]["gamma_exponent"]
            return pow((c + offset) / divisor, exponent)

        r_linear = gamma_correct(r_norm)
        g_linear = gamma_correct(g_norm)
        b_linear = gamma_correct(b_norm)

        weights = self.color_config["luminance_weights"]
        return weights["red"] * r_linear + weights["green"] * g_linear + weights["blue"] * b_linear

    def _rgb_to_hsv(self, r: int, g: int, b: int) -> Tuple[float, float, float]:
        """convert rgb (0..max_rgb) to hsv with h, s, v all in [0, 1]"""
        max_rgb = self.color_config["constants"]["max_rgb"]
        r, g, b = r / max_rgb, g / max_rgb, b / max_rgb

        max_val = max(r, g, b)
        min_val = min(r, g, b)
        diff = max_val - min_val

        if diff == 0:
            h = 0
        elif max_val == r:
            h = (60 * ((g - b) / diff) + 360) % 360
        elif max_val == g:
            h = (60 * ((b - r) / diff) + 120) % 360
        else:
            h = (60 * ((r - g) / diff) + 240) % 360

        s = 0 if max_val == 0 else diff / max_val
        v = max_val

        return h / 360.0, s, v

    def _hsv_to_rgb(self, h: float, s: float, v: float) -> Tuple[float, float, float]:
        """convert hsv (each in [0, 1]) to float rgb scaled to 0..max_rgb"""
        h = h * 360  # convert back to degrees
        c = v * s
        x = c * (1 - abs((h / 60) % 2 - 1))
        m = v - c

        if 0 <= h < 60:
            r, g, b = c, x, 0
        elif 60 <= h < 120:
            r, g, b = x, c, 0
        elif 120 <= h < 180:
            r, g, b = 0, c, x
        elif 180 <= h < 240:
            r, g, b = 0, x, c
        elif 240 <= h < 300:
            r, g, b = x, 0, c
        else:
            r, g, b = c, 0, x

        max_rgb = self.color_config["constants"]["max_rgb"]
        return (r + m) * max_rgb, (g + m) * max_rgb, (b + m) * max_rgb

    @staticmethod
    def _normalize_hex(color: Any) -> Optional[str]:
        """return ``color`` as lowercase '#rrggbb', or None if it is not a hex color.

        Accepts an optional leading '#' and 3-digit shorthand ('#abc' -> '#aabbcc');
        surrounding whitespace is ignored.
        """
        if not isinstance(color, str):
            return None
        match = _HEX_COLOR_RE.fullmatch(color.strip())
        if match is None:
            return None
        digits = match.group(1).lower()
        if len(digits) == 3:
            digits = "".join(ch * 2 for ch in digits)
        return f"#{digits}"

    @classmethod
    def _hex_to_rgb(cls, color: str) -> Tuple[int, int, int]:
        """parse a hex color into (r, g, b) ints; raises ValueError on invalid input"""
        normalized = cls._normalize_hex(color)
        if normalized is None:
            raise ValueError(f"invalid hex color: {color!r}")
        return int(normalized[1:3], 16), int(normalized[3:5], 16), int(normalized[5:7], 16)

    @staticmethod
    def _rgb_to_hex(r: float, g: float, b: float) -> str:
        """format float rgb channels as '#rrggbb'.

        Channels are rounded (truncation turns float error like 73.9999 into 73)
        and clamped to 0..255 so the output is always a valid 6-digit hex color.
        """
        def channel(value: float) -> int:
            return int(min(255, max(0, round(value))))

        return "#" + "".join(f"{channel(value):02x}" for value in (r, g, b))

    def _get_contrast_text_color(self, bg_color: str) -> str:
        """determine appropriate text color (white or black) for given background"""
        r, g, b = self._hex_to_rgb(bg_color)

        constants = self.color_config["constants"]
        brightness = (r * constants["red_weight"] + g * constants["green_weight"] +
                      b * constants["blue_weight"]) / constants["brightness_divisor"]

        return (constants["white_text"] if brightness < constants["brightness_threshold"]
                else constants["black_text"])

    def _save_color_scheme(self, state: PosterState):
        """save color scheme to <output_dir>/content/color_scheme.json"""
        output_dir = Path(state["output_dir"]) / "content"
        output_dir.mkdir(parents=True, exist_ok=True)

        with open(output_dir / "color_scheme.json", "w", encoding='utf-8') as f:
            json.dump(state.get("color_scheme", {}), f, indent=2)


def color_agent_node(state: PosterState) -> Dict[str, Any]:
    """LangGraph node: run ColorAgent and return the updated state fields."""
    result = ColorAgent()(state)
    return {
        **state,
        # .get: ColorAgent only sets these on success, so they may be absent
        # when it failed (unless the initial state already defines them)
        "color_scheme": result.get("color_scheme"),
        "tokens": result["tokens"],
        "current_agent": result.get("current_agent"),
        "errors": result["errors"]
    }