""" ๐Ÿ”ฎ PHOENIX Retention Research Platform - FINAL INTEGRATED VERSION Zero-shot Model Burning + Optional Fine-tuning โœ… Zero-shot Conversion (No Dataset Required) โœ… Optional Fine-tuning (Dataset-based) โœ… GQA Support โœ… HuggingFace Hub Integration โœ… Comprehensive Evaluation VIDraft AI Research Lab """ import gradio as gr import torch import torch.nn as nn import torch.nn.functional as F import sqlite3 import json import time import numpy as np from datetime import datetime from pathlib import Path import plotly.graph_objects as go import plotly.express as px import pandas as pd from typing import Dict, List, Any, Tuple, Optional import chromadb from chromadb.config import Settings from transformers import ( AutoModel, AutoTokenizer, AutoConfig, AutoModelForCausalLM, get_cosine_schedule_with_warmup, TrainingArguments, Trainer ) from datasets import load_dataset from torch.utils.data import Dataset, DataLoader from accelerate import Accelerator from tqdm import tqdm import copy import shutil # ===================================================== # ์ „์—ญ ์„ค์ • # ===================================================== DEVICE = "cuda" if torch.cuda.is_available() else "cpu" STORAGE_PATH = "/data" DB_PATH = f"{STORAGE_PATH}/phoenix_experiments.db" VECTOR_DB_PATH = f"{STORAGE_PATH}/vector_store" MODELS_PATH = f"{STORAGE_PATH}/phoenix_models" DEFAULT_MODEL = "ibm-granite/granite-4.0-h-350m" Path(STORAGE_PATH).mkdir(parents=True, exist_ok=True) Path(VECTOR_DB_PATH).mkdir(parents=True, exist_ok=True) Path(MODELS_PATH).mkdir(parents=True, exist_ok=True) print(f"๐Ÿš€ PHOENIX Platform initialized on {DEVICE}") print(f"๐Ÿ’พ Storage: {STORAGE_PATH}") print(f"๐ŸŽฏ Default Base Model: {DEFAULT_MODEL}") # ===================================================== # PHOENIX Retention with GQA Support # ===================================================== class MultiScaleRetention(nn.Module): """์ง„์งœ Retention Attention with GQA Support""" def __init__(self, config, layer_idx=0): super().__init__() self.config = config self.layer_idx = layer_idx # Q dimensions self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads # K/V dimensions (GQA) if hasattr(config, 'num_key_value_heads'): self.num_key_value_heads = config.num_key_value_heads else: self.num_key_value_heads = self.num_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.kv_head_dim = self.head_dim self.kv_dim = self.num_key_value_heads * self.kv_head_dim # Internal state storage for KV cache simulation self.register_buffer('_internal_state', None, persistent=False) self.register_buffer('_state_initialized', torch.tensor(False), persistent=False) # Projections with correct dimensions self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False) self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) # Retention parameters decay_values = torch.linspace(0.95, 0.99, self.num_heads) self.decay = nn.Parameter(decay_values, requires_grad=True) # Group norm self.group_norm = nn.GroupNorm( num_groups=self.num_heads, num_channels=self.hidden_size ) def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """Repeat K/V heads to match Q heads (GQA)""" batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = 
    def reset_state(self):
        """Reset internal state"""
        self._internal_state = None
        self._state_initialized = torch.tensor(False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        **kwargs
    ):
        """O(n) Retention with GQA support"""
        batch_size, seq_len, _ = hidden_states.shape

        if past_key_values is not None:
            past_key_value = past_key_values

        # Q, K, V projections
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Reshape
        query_states = query_states.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim
        ).transpose(1, 2)

        # Repeat K/V to match Q heads (GQA)
        key_states = self._repeat_kv(key_states, self.num_key_value_groups)
        value_states = self._repeat_kv(value_states, self.num_key_value_groups)

        # Retention computation
        past_state = self._internal_state if (use_cache and self._state_initialized) else None
        retention_states, new_state = self._compute_retention(
            query_states, key_states, value_states, past_state
        )

        # Store state internally
        if use_cache:
            self._internal_state = new_state.detach()
            self._state_initialized = torch.tensor(True)

        # Reshape back
        retention_states = retention_states.transpose(1, 2).contiguous()
        retention_states = retention_states.reshape(
            batch_size, seq_len, self.hidden_size
        )

        # Group norm (moved to the input's device/dtype on first use)
        if not next(self.group_norm.parameters()).is_cuda and retention_states.is_cuda:
            self.group_norm = self.group_norm.to(retention_states.device, dtype=retention_states.dtype)
        elif next(self.group_norm.parameters()).dtype != retention_states.dtype:
            self.group_norm = self.group_norm.to(dtype=retention_states.dtype)

        retention_states = self.group_norm(
            retention_states.transpose(1, 2)
        ).transpose(1, 2)
        retention_states = torch.clamp(retention_states, min=-10.0, max=10.0)

        # Output projection
        attn_output = self.o_proj(retention_states)

        return (attn_output, None)

    def _compute_retention(
        self,
        queries: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        past_state: Optional[torch.Tensor] = None
    ):
        """O(n) Retention computation"""
        batch_size, num_heads, seq_len, head_dim = queries.shape

        if past_state is not None:
            state = past_state.to(queries.device, dtype=queries.dtype)
        else:
            state = torch.zeros(
                batch_size, num_heads, head_dim, head_dim,
                dtype=queries.dtype, device=queries.device
            ) + 1e-6

        outputs = []
        decay = torch.sigmoid(self.decay).view(1, -1, 1, 1).to(
            device=queries.device, dtype=queries.dtype
        )

        for t in range(seq_len):
            q_t = queries[:, :, t, :]
            k_t = keys[:, :, t, :]
            v_t = values[:, :, t, :]

            state = decay * state
            kv_update = torch.einsum('bhd,bhe->bhde', k_t, v_t)
            kv_update = torch.clamp(kv_update, min=-5.0, max=5.0)
            state = state + kv_update
            state = torch.clamp(state, min=-10.0, max=10.0)

            output_t = torch.einsum('bhd,bhde->bhe', q_t, state)
            outputs.append(output_t)

        output = torch.stack(outputs, dim=2)
        return output, state
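
# ---------------------------------------------------------------------
# Illustration (not used by the platform): the per-head recurrence in
# `_compute_retention`, S_t = decay * S_{t-1} + k_t v_t^T with
# o_t = q_t S_t, is equivalent (ignoring the clamping and the 1e-6 state
# init above) to the decay-weighted, softmax-free attention
# o_t = sum_{s<=t} decay^(t-s) (q_t . k_s) v_s, which is why a single
# left-to-right pass gives O(n) cost. Hedged sanity-check sketch only;
# shapes and values are illustrative, not the module's real tensors.
# ---------------------------------------------------------------------
def _retention_recurrence_sanity_check(seq_len: int = 8, dim: int = 4, gamma: float = 0.9) -> bool:
    """Compare the O(n) recurrence against the explicit O(n^2) sum."""
    torch.manual_seed(0)
    q = torch.randn(seq_len, dim)
    k = torch.randn(seq_len, dim)
    v = torch.randn(seq_len, dim)

    # O(n) recurrent form
    state = torch.zeros(dim, dim)
    rec_out = []
    for t in range(seq_len):
        state = gamma * state + torch.outer(k[t], v[t])
        rec_out.append(q[t] @ state)
    rec_out = torch.stack(rec_out)

    # O(n^2) explicit decay-weighted form
    exp_out = torch.stack([
        sum((gamma ** (t - s)) * (q[t] @ k[s]) * v[s] for s in range(t + 1))
        for t in range(seq_len)
    ])
    return torch.allclose(rec_out, exp_out, atol=1e-5)
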
class HierarchicalRetention(nn.Module):
    """PHOENIX Hierarchical Retention with GQA"""

    def __init__(self, config, layer_idx=0):
        super().__init__()
        self.base_retention = MultiScaleRetention(config, layer_idx)

        hidden_size = config.hidden_size
        self.d_state = hidden_size // 2

        self.short_proj = nn.Linear(hidden_size, self.d_state)
        self.medium_proj = nn.Linear(self.d_state, self.d_state)
        self.long_proj = nn.Linear(self.d_state, self.d_state * 2)
        self.fusion = nn.Linear(self.d_state * 4, hidden_size)

        self.short_decay = 0.5
        self.medium_decay = 0.8
        self.long_decay = 0.95

        self.norm = nn.LayerNorm(hidden_size)

        if next(self.base_retention.parameters()).is_cuda:
            device = next(self.base_retention.parameters()).device
            dtype = next(self.base_retention.parameters()).dtype
            self.short_proj = self.short_proj.to(device, dtype=dtype)
            self.medium_proj = self.medium_proj.to(device, dtype=dtype)
            self.long_proj = self.long_proj.to(device, dtype=dtype)
            self.fusion = self.fusion.to(device, dtype=dtype)
            self.norm = self.norm.to(device, dtype=dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        **kwargs
    ):
        """Hierarchical forward pass"""
        batch_size, seq_len, hidden_size = hidden_states.shape

        if past_key_values is not None:
            past_key_value = past_key_values

        target_device = hidden_states.device
        target_dtype = hidden_states.dtype

        # Move sub-modules to the input's device/dtype on first use
        if not next(self.short_proj.parameters()).is_cuda and hidden_states.is_cuda:
            self.short_proj = self.short_proj.to(target_device, dtype=target_dtype)
            self.medium_proj = self.medium_proj.to(target_device, dtype=target_dtype)
            self.long_proj = self.long_proj.to(target_device, dtype=target_dtype)
            self.fusion = self.fusion.to(target_device, dtype=target_dtype)
            self.norm = self.norm.to(target_device, dtype=target_dtype)
        elif next(self.short_proj.parameters()).dtype != target_dtype:
            self.short_proj = self.short_proj.to(dtype=target_dtype)
            self.medium_proj = self.medium_proj.to(dtype=target_dtype)
            self.long_proj = self.long_proj.to(dtype=target_dtype)
            self.fusion = self.fusion.to(dtype=target_dtype)
            self.norm = self.norm.to(dtype=target_dtype)

        base_result = self.base_retention(
            hidden_states, attention_mask, position_ids,
            past_key_value, output_attentions, use_cache
        )
        retention_output = base_result[0]

        # Hierarchical states
        short_state = torch.zeros(batch_size, self.d_state, dtype=hidden_states.dtype, device=target_device)
        medium_state = torch.zeros(batch_size, self.d_state, dtype=hidden_states.dtype, device=target_device)
        long_state = torch.zeros(batch_size, self.d_state * 2, dtype=hidden_states.dtype, device=target_device)

        hierarchical_outputs = []
        for t in range(seq_len):
            x_t = retention_output[:, t, :]

            short_input = self.short_proj(x_t)
            short_state = self.short_decay * short_state + short_input

            if t % 8 == 0:
                medium_state = self.medium_decay * medium_state + \
                    self.medium_proj(short_state)

            if t % 64 == 0:
                long_state = self.long_decay * long_state + \
                    self.long_proj(medium_state)

            combined = torch.cat([short_state, medium_state, long_state], dim=-1)
            output_t = self.fusion(combined)
            hierarchical_outputs.append(output_t)

        output = torch.stack(hierarchical_outputs, dim=1)
        output = self.norm(output)

        return (output, None)
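
# ---------------------------------------------------------------------
# Rough intuition for the three timescales above (hedged, back-of-the-
# envelope only): a state that is updated every `interval` tokens with
# decay `gamma` retains information over roughly interval / (1 - gamma)
# tokens. With the defaults (0.5 every step, 0.8 every 8 steps, 0.95
# every 64 steps) that gives about 2 / 40 / 1280 tokens for the
# short / medium / long paths, which is the intended hierarchy.
# ---------------------------------------------------------------------
def _hierarchical_horizons() -> dict:
    """Approximate memory horizon (in tokens) of each hierarchical state."""
    horizons = {}
    for name, gamma, interval in [("short", 0.5, 1), ("medium", 0.8, 8), ("long", 0.95, 64)]:
        horizons[name] = interval / (1.0 - gamma)
    return horizons  # {'short': 2.0, 'medium': 40.0, 'long': 1280.0}
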
# =====================================================
# Model conversion function
# =====================================================
def replace_attention_with_retention(model, use_hierarchical=True):
    """Transformer Attention → PHOENIX Retention (GQA support)"""
    print("🔄 Starting Attention → Retention conversion (GQA support)...")

    replaced_count = 0
    total_layers = 0

    if hasattr(model, 'transformer'):
        layers = model.transformer.h
    elif hasattr(model, 'model') and hasattr(model.model, 'layers'):
        layers = model.model.layers
    elif hasattr(model, 'layers'):
        layers = model.layers
    else:
        print("⚠️ Unknown model structure")
        return model, 0, 0

    total_layers = len(layers)

    # Check first layer for GQA
    first_layer = layers[0]
    if hasattr(first_layer, 'self_attn'):
        old_attn = first_layer.self_attn
        if hasattr(old_attn, 'q_proj'):
            q_shape = old_attn.q_proj.weight.shape
            k_shape = old_attn.k_proj.weight.shape
            if k_shape[0] != q_shape[0]:
                print(f"   ✅ GQA detected! (K/V dim: {k_shape[0]} < Q dim: {q_shape[0]})")
                if not hasattr(model.config, 'num_key_value_heads'):
                    num_kv_heads = k_shape[0] // (model.config.hidden_size // model.config.num_attention_heads)
                    model.config.num_key_value_heads = num_kv_heads

    for layer_idx, layer in enumerate(layers):
        try:
            if hasattr(layer, 'self_attn'):
                old_attn = layer.self_attn

                if use_hierarchical:
                    new_retention = HierarchicalRetention(model.config, layer_idx)
                else:
                    new_retention = MultiScaleRetention(model.config, layer_idx)

                # Copy weights
                if hasattr(old_attn, 'q_proj'):
                    try:
                        if use_hierarchical:
                            target = new_retention.base_retention
                        else:
                            target = new_retention

                        q_match = old_attn.q_proj.weight.shape == target.q_proj.weight.shape
                        k_match = old_attn.k_proj.weight.shape == target.k_proj.weight.shape
                        v_match = old_attn.v_proj.weight.shape == target.v_proj.weight.shape
                        o_match = old_attn.o_proj.weight.shape == target.o_proj.weight.shape

                        if q_match and k_match and v_match and o_match:
                            target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
                            target.k_proj.weight.data = old_attn.k_proj.weight.data.clone()
                            target.v_proj.weight.data = old_attn.v_proj.weight.data.clone()
                            target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
                            print(f"   ✅ Layer {layer_idx}: Perfect match")
                        elif q_match and o_match:
                            target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
                            target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
                            k_copy_size = min(old_attn.k_proj.weight.shape[0], target.k_proj.weight.shape[0])
                            v_copy_size = min(old_attn.v_proj.weight.shape[0], target.v_proj.weight.shape[0])
                            target.k_proj.weight.data[:k_copy_size] = old_attn.k_proj.weight.data[:k_copy_size].clone()
                            target.v_proj.weight.data[:v_copy_size] = old_attn.v_proj.weight.data[:v_copy_size].clone()
                            print(f"   ✅ Layer {layer_idx}: Partial (GQA)")
                        else:
                            nn.init.xavier_uniform_(target.q_proj.weight)
                            nn.init.xavier_uniform_(target.k_proj.weight)
                            nn.init.xavier_uniform_(target.v_proj.weight)
                            nn.init.xavier_uniform_(target.o_proj.weight)
                            print(f"   ⚠️ Layer {layer_idx}: Xavier init")
                    except Exception as e:
                        print(f"   ⚠️ Layer {layer_idx}: Weight copy failed - {e}")

                layer.self_attn = new_retention
                replaced_count += 1

        except Exception as e:
            print(f"   ❌ Layer {layer_idx}: Failed - {e}")
            continue

    print(f"\n✅ Conversion complete: {replaced_count}/{total_layers} layers")
    return model, replaced_count, total_layers
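
# ---------------------------------------------------------------------
# Minimal conversion smoke test (not wired into the UI): mirrors how the
# burning functions below drive the converter, i.e. they pass the inner
# `model.model` of a causal LM and then generate a few tokens. Hedged
# sketch; it assumes the checkpoint exposes `model.model.layers` with
# `self_attn` sub-modules, which is what the converter looks for.
# ---------------------------------------------------------------------
def _conversion_smoke_test(model_url: str = DEFAULT_MODEL, prompt: str = "Hello") -> str:
    model = AutoModelForCausalLM.from_pretrained(
        model_url, trust_remote_code=True, torch_dtype=torch.float16
    ).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Swap every self-attention block for PHOENIX Retention
    model.model, converted, total = replace_attention_with_retention(
        model.model, use_hierarchical=True
    )
    print(f"Smoke test: {converted}/{total} layers converted")

    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=16, do_sample=False,
                                 pad_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
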
# =====================================================
# Database
# =====================================================
class ExperimentDatabase:
    """SQLite database"""

    def __init__(self, db_path: str):
        self.db_path = db_path
        self.init_database()

    def init_database(self):
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS experiments (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    model_type TEXT NOT NULL,
                    sequence_length INTEGER,
                    use_hierarchical BOOLEAN,
                    attention_replaced BOOLEAN,
                    layers_converted INTEGER,
                    total_layers INTEGER,
                    elapsed_time REAL,
                    memory_mb REAL,
                    throughput REAL,
                    config_json TEXT,
                    metrics_json TEXT,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Burning history table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS burning_history (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    model_url TEXT NOT NULL,
                    output_path TEXT NOT NULL,
                    use_hierarchical BOOLEAN,
                    dataset_used BOOLEAN,
                    conversion_rate REAL,
                    training_steps INTEGER,
                    final_loss REAL,
                    evaluation_score REAL,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
                )
            """)
            conn.commit()

    def save_experiment(self, config: Dict, metrics: Dict) -> int:
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("""
                INSERT INTO experiments (
                    model_type, sequence_length, use_hierarchical,
                    attention_replaced, layers_converted, total_layers,
                    elapsed_time, memory_mb, throughput,
                    config_json, metrics_json
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                config.get('model_type'),
                config.get('sequence_length'),
                config.get('use_hierarchical'),
                config.get('attention_replaced'),
                config.get('layers_converted'),
                config.get('total_layers'),
                metrics.get('elapsed_time'),
                metrics.get('memory_mb'),
                metrics.get('throughput'),
                json.dumps(config),
                json.dumps(metrics)
            ))
            conn.commit()
            return cursor.lastrowid

    def save_burning(self, burning_info: Dict) -> int:
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("""
                INSERT INTO burning_history (
                    model_url, output_path, use_hierarchical, dataset_used,
                    conversion_rate, training_steps, final_loss, evaluation_score
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                burning_info.get('model_url'),
                burning_info.get('output_path'),
                burning_info.get('use_hierarchical'),
                burning_info.get('dataset_used'),
                burning_info.get('conversion_rate'),
                burning_info.get('training_steps', 0),
                burning_info.get('final_loss'),
                burning_info.get('evaluation_score'),
            ))
            conn.commit()
            return cursor.lastrowid

    def get_burning_history(self, limit: int = 20) -> List[Dict]:
        with sqlite3.connect(self.db_path) as conn:
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM burning_history ORDER BY timestamp DESC LIMIT ?", (limit,))
            return [dict(row) for row in cursor.fetchall()]

# =====================================================
# Model Burning (Zero-shot + Optional Fine-tuning)
# =====================================================
def evaluate_model_quality(model, tokenizer, test_prompts=None):
    """
    Simple model quality evaluation.

    Returns:
        score: 0.0 ~ 1.0 (higher is better)
    """
    if test_prompts is None:
        test_prompts = [
            "The capital of France is",
            "In machine learning, overfitting means",
            "2 + 2 =",
        ]

    model.eval()
    scores = []

    with torch.no_grad():
        for prompt in test_prompts:
            try:
                inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=20,
                    do_sample=False,
                    pad_token_id=tokenizer.eos_token_id,
                )
                generated = tokenizer.decode(outputs[0], skip_special_tokens=True)

                # Simple quality checks
                score = 0.0
                if len(generated) > len(prompt):
                    # Something was generated
                    score += 0.3
                if not any(char in generated[len(prompt):] for char in ['\ufffd', '[UNK]']):
                    # No broken characters
                    score += 0.3
                if len(generated.split()) > len(prompt.split()) + 2:
                    # Meaningful number of new words generated
                    score += 0.4

                scores.append(score)
            except Exception as e:
                print(f"   ⚠️ Evaluation error for '{prompt}': {e}")
                scores.append(0.0)

    return sum(scores) / len(scores) if scores else 0.0
def burn_model_zero_shot(
    model_url: str,
    output_dir: str,
    use_hierarchical: bool = True,
    test_prompts: List[str] = None,
):
    """
    Zero-shot Model Burning (no dataset required).

    1. Load the model
    2. Convert Attention → Retention
    3. Evaluate quality
    4. Save

    Returns:
        status, model_path, metrics
    """
    print("="*80)
    print("🔥 PHOENIX Zero-shot Model Burning")
    print("="*80)

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    try:
        # 1. Load model
        print(f"\n📥 Loading model: {model_url}")
        start_time = time.time()

        config = AutoConfig.from_pretrained(model_url, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_url,
            trust_remote_code=True,
            torch_dtype=torch.float16,
        ).to(DEVICE)
        tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        load_time = time.time() - start_time
        print(f"✅ Loaded in {load_time:.1f}s")

        # 2. Convert
        print(f"\n🔄 Converting Attention → Retention...")
        convert_start = time.time()

        model.model, converted, total = replace_attention_with_retention(
            model.model, use_hierarchical=use_hierarchical
        )

        convert_time = time.time() - convert_start
        conversion_rate = converted / total if total > 0 else 0
        print(f"✅ Converted {converted}/{total} layers ({conversion_rate*100:.1f}%) in {convert_time:.1f}s")

        # 3. Evaluate
        print(f"\n📊 Evaluating model quality...")
        eval_start = time.time()

        quality_score = evaluate_model_quality(model, tokenizer, test_prompts)

        eval_time = time.time() - eval_start
        print(f"✅ Quality Score: {quality_score:.2f}/1.00 (in {eval_time:.1f}s)")

        # 4. Save
        print(f"\n💾 Saving PHOENIX model...")
        save_start = time.time()

        model.save_pretrained(output_path)
        tokenizer.save_pretrained(output_path)

        # Save metadata
        metadata = {
            'phoenix_version': '1.0.0',
            'original_model': model_url,
            'use_hierarchical': use_hierarchical,
            'conversion_rate': conversion_rate,
            'layers_converted': converted,
            'total_layers': total,
            'quality_score': quality_score,
            'burning_type': 'zero_shot',
            'timestamp': datetime.now().isoformat(),
        }
        with open(output_path / 'phoenix_metadata.json', 'w') as f:
            json.dump(metadata, f, indent=2)

        save_time = time.time() - save_start
        print(f"✅ Saved to {output_path} in {save_time:.1f}s")

        # Total time
        total_time = time.time() - start_time

        result = {
            'status': 'success',
            'model_path': str(output_path),
            'conversion_rate': conversion_rate,
            'quality_score': quality_score,
            'total_time': total_time,
            'load_time': load_time,
            'convert_time': convert_time,
            'eval_time': eval_time,
            'save_time': save_time,
        }

        print(f"\n{'='*80}")
        print(f"✅ Zero-shot Burning Complete!")
        print(f"   Total Time: {total_time:.1f}s")
        print(f"   Model Path: {output_path}")
        print(f"   Quality: {quality_score:.2f}/1.00")
        print(f"{'='*80}\n")

        return result

    except Exception as e:
        import traceback
        error_msg = traceback.format_exc()
        print(f"\n❌ Zero-shot burning failed:\n{error_msg}")
        return {
            'status': 'failed',
            'error': str(e),
            'traceback': error_msg
        }
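
# ---------------------------------------------------------------------
# Hedged usage sketch (not called anywhere): a burned model directory
# contains the usual save_pretrained() files plus the
# 'phoenix_metadata.json' written above. This shows how that metadata
# could be inspected before reloading a burned checkpoint; the directory
# path passed in is illustrative only.
# ---------------------------------------------------------------------
def _inspect_burned_model(model_dir: str) -> dict:
    """Read the PHOENIX metadata written by the burning functions."""
    meta_file = Path(model_dir) / 'phoenix_metadata.json'
    with open(meta_file) as f:
        metadata = json.load(f)
    print(f"Original model : {metadata.get('original_model')}")
    print(f"Burning type   : {metadata.get('burning_type')}")
    print(f"Conversion rate: {metadata.get('conversion_rate', 0) * 100:.1f}%")
    print(f"Quality score  : {metadata.get('quality_score', 0):.2f}/1.00")
    return metadata
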
def burn_model_with_finetuning(
    model_url: str,
    output_dir: str,
    dataset_path: str,
    use_hierarchical: bool = True,
    num_epochs: int = 1,
    batch_size: int = 4,
    learning_rate: float = 5e-5,
    max_steps: int = 100,
):
    """
    Fine-tuning Model Burning (dataset-based).

    1. Load & convert the model
    2. Load the dataset
    3. Fine-tune
    4. Evaluate & save

    Returns:
        status, model_path, metrics
    """
    print("="*80)
    print("🔥 PHOENIX Fine-tuning Model Burning")
    print("="*80)

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    try:
        # 1. Load & Convert
        print(f"\n📥 Loading model: {model_url}")

        config = AutoConfig.from_pretrained(model_url, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_url,
            trust_remote_code=True,
            torch_dtype=torch.float16,
        ).to(DEVICE)
        tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        print(f"\n🔄 Converting...")
        model.model, converted, total = replace_attention_with_retention(
            model.model, use_hierarchical=use_hierarchical
        )
        conversion_rate = converted / total if total > 0 else 0
        print(f"✅ Converted {converted}/{total} layers")

        # 2. Load dataset
        print(f"\n📊 Loading dataset: {dataset_path}")

        if dataset_path.endswith('.txt'):
            with open(dataset_path, 'r', encoding='utf-8') as f:
                texts = [line.strip() for line in f if line.strip()]

            # Simple tokenization
            def tokenize_fn(text):
                return tokenizer(
                    text,
                    truncation=True,
                    max_length=512,
                    padding='max_length',
                    return_tensors='pt'
                )

            tokenized_data = [tokenize_fn(text) for text in texts[:1000]]  # Limit to 1000
        else:
            # Try loading as HF dataset
            from datasets import load_dataset
            dataset = load_dataset('text', data_files=dataset_path)

            def tokenize_function(examples):
                return tokenizer(
                    examples['text'],
                    truncation=True,
                    max_length=512,
                    padding='max_length',
                )

            dataset = dataset.map(tokenize_function, batched=True)
            tokenized_data = dataset['train']

        print(f"✅ Loaded {len(tokenized_data)} samples")

        # 3. Quick fine-tuning
        print(f"\n🚀 Starting fine-tuning...")
        print(f"   Epochs: {num_epochs}")
        print(f"   Batch Size: {batch_size}")
        print(f"   Max Steps: {max_steps}")

        model.train()
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

        step = 0
        total_loss = 0.0

        for epoch in range(num_epochs):
            for i in range(0, len(tokenized_data), batch_size):
                if step >= max_steps:
                    break

                batch = tokenized_data[i:i+batch_size]

                # Simple batch processing
                if isinstance(batch, list):
                    input_ids = torch.stack([item['input_ids'].squeeze() for item in batch]).to(DEVICE)
                    attention_mask = torch.stack([item['attention_mask'].squeeze() for item in batch]).to(DEVICE)
                else:
                    input_ids = torch.tensor(batch['input_ids']).to(DEVICE)
                    attention_mask = torch.tensor(batch['attention_mask']).to(DEVICE)

                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
                loss = outputs.loss

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                total_loss += loss.item()
                step += 1

                if step % 10 == 0:
                    avg_loss = total_loss / step
                    print(f"   Step {step}/{max_steps} - Loss: {avg_loss:.4f}")

        final_loss = total_loss / step if step > 0 else 0.0
        print(f"✅ Training complete - Final Loss: {final_loss:.4f}")

        # 4. Evaluate & Save
        print(f"\n📊 Evaluating...")
        model.eval()
        quality_score = evaluate_model_quality(model, tokenizer)
        print(f"✅ Quality Score: {quality_score:.2f}/1.00")

        print(f"\n💾 Saving model...")
        model.save_pretrained(output_path)
        tokenizer.save_pretrained(output_path)

        metadata = {
            'phoenix_version': '1.0.0',
            'original_model': model_url,
            'use_hierarchical': use_hierarchical,
            'conversion_rate': conversion_rate,
            'quality_score': quality_score,
            'burning_type': 'fine_tuning',
            'training_steps': step,
            'final_loss': final_loss,
            'dataset': dataset_path,
            'timestamp': datetime.now().isoformat(),
        }
        with open(output_path / 'phoenix_metadata.json', 'w') as f:
            json.dump(metadata, f, indent=2)

        print(f"✅ Saved to {output_path}")

        result = {
            'status': 'success',
            'model_path': str(output_path),
            'conversion_rate': conversion_rate,
            'quality_score': quality_score,
            'training_steps': step,
            'final_loss': final_loss,
        }

        print(f"\n{'='*80}")
        print(f"✅ Fine-tuning Burning Complete!")
        print(f"{'='*80}\n")

        return result

    except Exception as e:
        import traceback
        error_msg = traceback.format_exc()
        print(f"\n❌ Fine-tuning burning failed:\n{error_msg}")
        return {
            'status': 'failed',
            'error': str(e),
            'traceback': error_msg
        }

# =====================================================
# Gradio UI Functions
# =====================================================
def convert_model_to_phoenix(model_url, use_hierarchical=True, gpu_type="L40S"):
    """Convert a model to PHOENIX (kept from the original version)"""
    try:
        start_time = time.time()

        print(f"📥 Loading model: {model_url}")
        config = AutoConfig.from_pretrained(model_url, trust_remote_code=True)
        model = AutoModel.from_pretrained(
            model_url,
            trust_remote_code=True,
            torch_dtype=torch.float16
        ).to(DEVICE)

        model, converted, total = replace_attention_with_retention(model, use_hierarchical)

        elapsed_time = time.time() - start_time
        conversion_pct = (converted / total * 100) if total > 0 else 0

        result = f"""
✅ **Conversion Complete!**

**Model**: {model_url}
**Converted**: {converted}/{total} layers ({conversion_pct:.1f}%)
**Time**: {elapsed_time:.1f}s
**GPU**: {gpu_type}

🎯 GQA-aware O(n) complexity!
"""
        return result

    except Exception as e:
        return f"❌ Conversion failed: {str(e)}"

def generate_text_phoenix(
    model_url, use_hierarchical, convert_attention,
    prompt, max_new_tokens, temperature
):
    """PHOENIX text generation (kept from the original version, simplified)"""
    try:
        if not convert_attention or not model_url.strip():
            return "⚠️ Enable 'Attention Replace' and provide model URL", ""

        print(f"📥 Loading model: {model_url}")
        model = AutoModelForCausalLM.from_pretrained(
            model_url,
            trust_remote_code=True,
            torch_dtype=torch.float16
        ).to(DEVICE)

        print(f"🔄 Converting...")
        model.model, converted, total = replace_attention_with_retention(
            model.model, use_hierarchical=use_hierarchical
        )

        tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

        print(f"🚀 Generating...")
        start_time = time.time()
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=temperature > 0.01,
            pad_token_id=tokenizer.eos_token_id,
        )
        elapsed = time.time() - start_time

        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)

        output_md = f"""
## 📝 Generated Text

```
{generated}
```
"""
        stats_md = f"""
## 📊 Statistics

- **Time**: {elapsed:.2f}s
- **Converted**: {converted}/{total} layers
- **Tokens/s**: {max_new_tokens/elapsed:.1f}
"""
        return output_md, stats_md

    except Exception as e:
        import traceback
        return f"❌ Error:\n```\n{traceback.format_exc()}\n```", ""
def burn_phoenix_model_ui(
    model_url,
    use_hierarchical,
    dataset_path,
    output_name,
    use_finetuning,
    num_epochs,
    batch_size,
    learning_rate,
    max_steps,
):
    """Model burning function for the Gradio UI"""
    try:
        if not model_url.strip():
            return "⚠️ Model URL required", None

        if not output_name.strip():
            output_name = f"phoenix_{model_url.split('/')[-1]}_{int(time.time())}"

        output_dir = f"{MODELS_PATH}/{output_name}"

        # Dataset check
        has_dataset = dataset_path and dataset_path.strip() and Path(dataset_path).exists()

        if use_finetuning and not has_dataset:
            return "⚠️ Fine-tuning requires dataset path", None

        # Choose burning method
        if use_finetuning and has_dataset:
            result = burn_model_with_finetuning(
                model_url=model_url,
                output_dir=output_dir,
                dataset_path=dataset_path,
                use_hierarchical=use_hierarchical,
                num_epochs=num_epochs,
                batch_size=batch_size,
                learning_rate=learning_rate,
                max_steps=max_steps,
            )
        else:
            result = burn_model_zero_shot(
                model_url=model_url,
                output_dir=output_dir,
                use_hierarchical=use_hierarchical,
            )

        if result['status'] == 'success':
            # Save to database
            burning_info = {
                'model_url': model_url,
                'output_path': result['model_path'],
                'use_hierarchical': use_hierarchical,
                'dataset_used': has_dataset,
                'conversion_rate': result.get('conversion_rate', 0.0),
                'training_steps': result.get('training_steps', 0),
                'final_loss': result.get('final_loss'),
                'evaluation_score': result.get('quality_score', 0.0),
            }
            db.save_burning(burning_info)

            # Format output
            output_md = f"""
# 🔥 Model Burning Complete!

## 📦 Model Information
- **Original**: {model_url}
- **Output**: `{result['model_path']}`
- **Type**: {'Fine-tuning' if has_dataset else 'Zero-shot'}

## 📊 Metrics
- **Conversion Rate**: {result['conversion_rate']*100:.1f}%
- **Quality Score**: {result.get('quality_score', 0.0):.2f}/1.00
"""
            if 'training_steps' in result:
                output_md += f"""
## 🚀 Training
- **Steps**: {result['training_steps']}
- **Final Loss**: {result.get('final_loss', 0.0):.4f}
"""
            output_md += f"""
## ⏱️ Time Breakdown
- **Total**: {result.get('total_time', 0):.1f}s
"""
            if 'load_time' in result:
                output_md += f"- **Load**: {result['load_time']:.1f}s\n"
                output_md += f"- **Convert**: {result['convert_time']:.1f}s\n"
                output_md += f"- **Evaluate**: {result['eval_time']:.1f}s\n"
                output_md += f"- **Save**: {result['save_time']:.1f}s\n"

            output_md += f"""
## 🎯 Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("{result['model_path']}")
tokenizer = AutoTokenizer.from_pretrained("{result['model_path']}")

inputs = tokenizer("Your prompt", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0]))
```

✅ **PHOENIX Model Ready!**
"""

            # Create simple plot
            fig = go.Figure()
            fig.add_trace(go.Bar(
                x=['Conversion', 'Quality'],
                y=[result['conversion_rate'], result.get('quality_score', 0.0)],
                text=[f"{result['conversion_rate']*100:.1f}%", f"{result.get('quality_score', 0.0):.2f}"],
                textposition='auto',
            ))
            fig.update_layout(
                title="Burning Metrics",
                yaxis_range=[0, 1],
                template='plotly_white'
            )

            return output_md, fig
        else:
            return f"❌ Burning failed:\n```\n{result.get('error', 'Unknown error')}\n```", None

    except Exception as e:
        import traceback
        return f"❌ Error:\n```\n{traceback.format_exc()}\n```", None

def view_burning_history():
    """View burning history"""
    try:
        history = db.get_burning_history(limit=20)

        if not history:
            return "📭 No burning history yet", None

        df = pd.DataFrame(history)

        fig = px.scatter(
            df,
            x='timestamp',
            y='evaluation_score',
            size='conversion_rate',
            color='dataset_used',
            hover_data=['model_url', 'output_path'],
            title='Burning History'
        )

        cols = ['id', 'model_url', 'output_path', 'conversion_rate',
                'evaluation_score', 'training_steps', 'timestamp']
        available = [c for c in cols if c in df.columns]

        return f"## 📊 Burning History\n\n{df[available].to_markdown(index=False)}", fig

    except Exception as e:
        return f"❌ Error: {e}", None

# Global initialization
db = ExperimentDatabase(DB_PATH)
CONVERTED_MODELS = {}

# =====================================================
# Gradio UI
# =====================================================
with gr.Blocks(
    title="🔮 PHOENIX - Model Burning Platform",
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown("""
# 🔮 PHOENIX Retention Platform

**Zero-shot Model Burning + Optional Fine-tuning**

✅ Zero-shot Conversion (no dataset required!)
✅ Optional Fine-tuning (dataset-based)
✅ GQA Support
✅ O(n) Complexity

---
""")

    with gr.Tabs():
        with gr.Tab("🔄 Quick Convert"):
            gr.Markdown("""
### Quick conversion test

Loads a model and only performs the Attention → Retention conversion. (Nothing is saved.)
""")
            with gr.Row():
                with gr.Column(scale=1):
                    convert_url = gr.Textbox(
                        label="🔗 Model URL",
                        value=DEFAULT_MODEL,
                        placeholder="ibm-granite/granite-4.0-h-350m"
                    )
                    convert_hierarchical = gr.Checkbox(value=True, label="Hierarchical Retention")
                    convert_gpu = gr.Radio(choices=["L40S", "H100"], value="L40S", label="GPU")
                    convert_btn = gr.Button("🔄 Convert", variant="primary")
                with gr.Column(scale=2):
                    convert_output = gr.Markdown()

            convert_btn.click(
                convert_model_to_phoenix,
                [convert_url, convert_hierarchical, convert_gpu],
                [convert_output]
            )

        with gr.Tab("🔥 Model Burning"):
            gr.Markdown("""
### 🔥 PHOENIX Model Burning

**Converts the model and saves it!**

- **Zero-shot**: conversion only, no dataset needed (fast!)
- **Fine-tuning**: additional training on a dataset (improved performance)
""")
            with gr.Row():
                with gr.Column(scale=1):
                    burn_model_url = gr.Textbox(
                        label="🔗 Model URL",
                        value=DEFAULT_MODEL,
                        placeholder="ibm-granite/granite-4.0-h-350m"
                    )
                    burn_hierarchical = gr.Checkbox(value=True, label="Hierarchical Retention")
                    burn_output_name = gr.Textbox(
                        label="💾 Output Name",
                        placeholder="phoenix_my_model (auto-generated if empty)"
                    )

                    gr.Markdown("---")
                    gr.Markdown("### 📊 Dataset (Optional)")

                    burn_dataset = gr.Textbox(
                        label="📁 Dataset Path (Optional)",
                        placeholder="/path/to/dataset.txt (leave empty for zero-shot)",
                        value=""
                    )
                    burn_use_finetuning = gr.Checkbox(
                        value=False,
                        label="🚀 Enable Fine-tuning (requires dataset)"
                    )

                    with gr.Accordion("⚙️ Fine-tuning Config", open=False):
                        burn_epochs = gr.Slider(1, 5, 1, step=1, label="Epochs")
                        burn_batch = gr.Slider(1, 16, 4, step=1, label="Batch Size")
                        burn_lr = gr.Number(value=5e-5, label="Learning Rate")
                        burn_max_steps = gr.Slider(10, 500, 100, step=10, label="Max Steps")

                    burn_btn = gr.Button("🔥 Burn Model", variant="primary", size="lg")

                with gr.Column(scale=2):
                    burn_output = gr.Markdown()
                    burn_plot = gr.Plot()

            burn_btn.click(
                burn_phoenix_model_ui,
                [
                    burn_model_url, burn_hierarchical, burn_dataset, burn_output_name,
                    burn_use_finetuning, burn_epochs, burn_batch, burn_lr, burn_max_steps,
                ],
                [burn_output, burn_plot]
            )

        with gr.Tab("💬 Text Generation"):
            gr.Markdown("""
### PHOENIX text generation

Generate text with a converted model.
""")
            with gr.Row():
                with gr.Column(scale=1):
                    gen_model_url = gr.Textbox(label="🔗 Model URL", value=DEFAULT_MODEL)
                    gen_hierarchical = gr.Checkbox(value=True, label="Hierarchical")
                    gen_convert = gr.Checkbox(value=True, label="Enable Conversion")
                    gen_prompt = gr.Textbox(
                        label="📝 Prompt",
                        lines=3,
                        value="The future of AI is"
                    )
                    gen_max_tokens = gr.Slider(16, 256, 64, step=16, label="Max Tokens")
                    gen_temperature = gr.Slider(0.1, 2.0, 0.7, step=0.1, label="Temperature")
                    gen_btn = gr.Button("🚀 Generate", variant="primary")
                with gr.Column(scale=2):
                    gen_output = gr.Markdown()
                    gen_stats = gr.Markdown()

            gen_btn.click(
                generate_text_phoenix,
                [gen_model_url, gen_hierarchical, gen_convert,
                 gen_prompt, gen_max_tokens, gen_temperature],
                [gen_output, gen_stats]
            )

        with gr.Tab("📊 Burning History"):
            gr.Markdown("""
### 📊 Model Burning History

Review the saved model burning records.
""")
            with gr.Row():
                with gr.Column(scale=1):
                    hist_btn = gr.Button("📊 Load History", variant="primary")
                with gr.Column(scale=2):
                    hist_output = gr.Markdown()
                    hist_plot = gr.Plot()

            hist_btn.click(view_burning_history, outputs=[hist_output, hist_plot])

    gr.Markdown("""
---

## 🔥 PHOENIX Model Burning

### Zero-shot (no dataset required!)
1. Enter the model URL
2. Click "Burn Model"
3. Done! → saved under `/data/phoenix_models/`

### Fine-tuning (optional)
1. Enter the Dataset Path
2. Check "Enable Fine-tuning"
3. Click "Burn Model"

**VIDraft AI Research Lab** | PHOENIX v1.0
""")

if __name__ == "__main__":
    demo.queue(max_size=20)
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)