"""
🔮 PHOENIX Retention Research Platform - FINAL INTEGRATED VERSION
Zero-shot Model Burning + Optional Fine-tuning
✅ Zero-shot Conversion (No Dataset Required)
✅ Optional Fine-tuning (Dataset-based)
✅ GQA Support
✅ HuggingFace Hub Integration
✅ Comprehensive Evaluation
VIDraft AI Research Lab
"""
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import sqlite3
import json
import time
import numpy as np
from datetime import datetime
from pathlib import Path
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
from typing import Dict, List, Any, Tuple, Optional
import chromadb
from chromadb.config import Settings
from transformers import (
AutoModel, AutoTokenizer, AutoConfig, AutoModelForCausalLM,
get_cosine_schedule_with_warmup, TrainingArguments, Trainer
)
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator
from tqdm import tqdm
import copy
import shutil
# =====================================================
# Global configuration
# =====================================================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
STORAGE_PATH = "/data"
DB_PATH = f"{STORAGE_PATH}/phoenix_experiments.db"
VECTOR_DB_PATH = f"{STORAGE_PATH}/vector_store"
MODELS_PATH = f"{STORAGE_PATH}/phoenix_models"
DEFAULT_MODEL = "ibm-granite/granite-4.0-h-350m"
Path(STORAGE_PATH).mkdir(parents=True, exist_ok=True)
Path(VECTOR_DB_PATH).mkdir(parents=True, exist_ok=True)
Path(MODELS_PATH).mkdir(parents=True, exist_ok=True)
print(f"🚀 PHOENIX Platform initialized on {DEVICE}")
print(f"💾 Storage: {STORAGE_PATH}")
print(f"🎯 Default Base Model: {DEFAULT_MODEL}")
# =====================================================
# PHOENIX Retention with GQA Support
# =====================================================
class MultiScaleRetention(nn.Module):
"""Retention attention with GQA support (replaces standard softmax attention)"""
def __init__(self, config, layer_idx=0):
super().__init__()
self.config = config
self.layer_idx = layer_idx
# Q dimensions
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
# K/V dimensions (GQA)
if hasattr(config, 'num_key_value_heads'):
self.num_key_value_heads = config.num_key_value_heads
else:
self.num_key_value_heads = self.num_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.kv_head_dim = self.head_dim
self.kv_dim = self.num_key_value_heads * self.kv_head_dim
# Internal state storage for KV cache simulation
self.register_buffer('_internal_state', None, persistent=False)
self.register_buffer('_state_initialized', torch.tensor(False), persistent=False)
# Projections with correct dimensions
self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
# Retention parameters
decay_values = torch.linspace(0.95, 0.99, self.num_heads)
self.decay = nn.Parameter(decay_values, requires_grad=True)
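# Note: these raw values are passed through a sigmoid in _compute_retention,
# yielding learnable per-head retention (decay) rates in (0, 1).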
# Group norm
self.group_norm = nn.GroupNorm(
num_groups=self.num_heads,
num_channels=self.hidden_size
)
def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""Repeat K/V heads to match Q heads (GQA)"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
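# Illustrative shape example: 4 K/V heads expanded with n_rep=3 -> 12 heads, matching 12 query heads.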
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def reset_state(self):
"""Reset internal state"""
self._internal_state = None
self._state_initialized = torch.tensor(False)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.Tensor]] = None,
**kwargs
):
"""O(n) Retention with GQA support"""
batch_size, seq_len, _ = hidden_states.shape
if past_key_values is not None:
past_key_value = past_key_values
# Q, K, V projections
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Reshape
query_states = query_states.view(
batch_size, seq_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim
).transpose(1, 2)
value_states = value_states.view(
batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim
).transpose(1, 2)
# Repeat K/V to match Q heads (GQA)
key_states = self._repeat_kv(key_states, self.num_key_value_groups)
value_states = self._repeat_kv(value_states, self.num_key_value_groups)
# Retention computation
past_state = self._internal_state if (use_cache and self._state_initialized) else None
retention_states, new_state = self._compute_retention(
query_states, key_states, value_states, past_state
)
# Store state internally
if use_cache:
self._internal_state = new_state.detach()
self._state_initialized = torch.tensor(True)
# Reshape back
retention_states = retention_states.transpose(1, 2).contiguous()
retention_states = retention_states.reshape(
batch_size, seq_len, self.hidden_size
)
# Group norm
if not next(self.group_norm.parameters()).is_cuda and retention_states.is_cuda:
self.group_norm = self.group_norm.to(retention_states.device, dtype=retention_states.dtype)
elif next(self.group_norm.parameters()).dtype != retention_states.dtype:
self.group_norm = self.group_norm.to(dtype=retention_states.dtype)
retention_states = self.group_norm(
retention_states.transpose(1, 2)
).transpose(1, 2)
retention_states = torch.clamp(retention_states, min=-10.0, max=10.0)
# Output projection
attn_output = self.o_proj(retention_states)
return (attn_output, None)
def _compute_retention(
self,
queries: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
past_state: Optional[torch.Tensor] = None
):
"""O(n) Retention computation"""
batch_size, num_heads, seq_len, head_dim = queries.shape
if past_state is not None:
state = past_state.to(queries.device, dtype=queries.dtype)
else:
state = torch.zeros(
batch_size, num_heads, head_dim, head_dim,
dtype=queries.dtype,
device=queries.device
) + 1e-6
outputs = []
decay = torch.sigmoid(self.decay).view(1, -1, 1, 1).to(
device=queries.device,
dtype=queries.dtype
)
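# Per-step recurrence (the core of O(n) retention):
#   S_t = decay * S_{t-1} + k_t^T v_t   (one head_dim x head_dim state per head)
#   o_t = q_t @ S_t
# Clamping keeps the state and updates numerically stable in fp16.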
for t in range(seq_len):
q_t = queries[:, :, t, :]
k_t = keys[:, :, t, :]
v_t = values[:, :, t, :]
state = decay * state
kv_update = torch.einsum('bhd,bhe->bhde', k_t, v_t)
kv_update = torch.clamp(kv_update, min=-5.0, max=5.0)
state = state + kv_update
state = torch.clamp(state, min=-10.0, max=10.0)
output_t = torch.einsum('bhd,bhde->bhe', q_t, state)
outputs.append(output_t)
output = torch.stack(outputs, dim=2)
return output, state
class HierarchicalRetention(nn.Module):
"""PHOENIX Hierarchical Retention with GQA"""
def __init__(self, config, layer_idx=0):
super().__init__()
self.base_retention = MultiScaleRetention(config, layer_idx)
hidden_size = config.hidden_size
self.d_state = hidden_size // 2
self.short_proj = nn.Linear(hidden_size, self.d_state)
self.medium_proj = nn.Linear(self.d_state, self.d_state)
self.long_proj = nn.Linear(self.d_state, self.d_state * 2)
self.fusion = nn.Linear(self.d_state * 4, hidden_size)
self.short_decay = 0.5
self.medium_decay = 0.8
self.long_decay = 0.95
self.norm = nn.LayerNorm(hidden_size)
if next(self.base_retention.parameters()).is_cuda:
device = next(self.base_retention.parameters()).device
dtype = next(self.base_retention.parameters()).dtype
self.short_proj = self.short_proj.to(device, dtype=dtype)
self.medium_proj = self.medium_proj.to(device, dtype=dtype)
self.long_proj = self.long_proj.to(device, dtype=dtype)
self.fusion = self.fusion.to(device, dtype=dtype)
self.norm = self.norm.to(device, dtype=dtype)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.Tensor]] = None,
**kwargs
):
"""Hierarchical forward pass"""
batch_size, seq_len, hidden_size = hidden_states.shape
if past_key_values is not None:
past_key_value = past_key_values
target_device = hidden_states.device
target_dtype = hidden_states.dtype
if not next(self.short_proj.parameters()).is_cuda and hidden_states.is_cuda:
self.short_proj = self.short_proj.to(target_device, dtype=target_dtype)
self.medium_proj = self.medium_proj.to(target_device, dtype=target_dtype)
self.long_proj = self.long_proj.to(target_device, dtype=target_dtype)
self.fusion = self.fusion.to(target_device, dtype=target_dtype)
self.norm = self.norm.to(target_device, dtype=target_dtype)
elif next(self.short_proj.parameters()).dtype != target_dtype:
self.short_proj = self.short_proj.to(dtype=target_dtype)
self.medium_proj = self.medium_proj.to(dtype=target_dtype)
self.long_proj = self.long_proj.to(dtype=target_dtype)
self.fusion = self.fusion.to(dtype=target_dtype)
self.norm = self.norm.to(dtype=target_dtype)
base_result = self.base_retention(
hidden_states, attention_mask, position_ids,
past_key_value, output_attentions, use_cache
)
retention_output = base_result[0]
# Hierarchical states
short_state = torch.zeros(batch_size, self.d_state, dtype=hidden_states.dtype, device=target_device)
medium_state = torch.zeros(batch_size, self.d_state, dtype=hidden_states.dtype, device=target_device)
long_state = torch.zeros(batch_size, self.d_state * 2, dtype=hidden_states.dtype, device=target_device)
hierarchical_outputs = []
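# Multi-timescale recurrence: the short state updates every token, the medium state
# every 8 tokens, and the long state every 64 tokens; their concatenation
# (4 * d_state = 2 * hidden_size) is projected back to hidden_size by self.fusion.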
for t in range(seq_len):
x_t = retention_output[:, t, :]
short_input = self.short_proj(x_t)
short_state = self.short_decay * short_state + short_input
if t % 8 == 0:
medium_state = self.medium_decay * medium_state + \
self.medium_proj(short_state)
if t % 64 == 0:
long_state = self.long_decay * long_state + \
self.long_proj(medium_state)
combined = torch.cat([short_state, medium_state, long_state], dim=-1)
output_t = self.fusion(combined)
hierarchical_outputs.append(output_t)
output = torch.stack(hierarchical_outputs, dim=1)
output = self.norm(output)
return (output, None)
# =====================================================
# Model conversion functions
# =====================================================
def replace_attention_with_retention(model, use_hierarchical=True):
"""Transformer Attention → PHOENIX Retention (GQA Support)"""
print("🔄 Starting Attention → Retention conversion (GQA support)...")
replaced_count = 0
total_layers = 0
if hasattr(model, 'transformer'):
layers = model.transformer.h
elif hasattr(model, 'model') and hasattr(model.model, 'layers'):
layers = model.model.layers
elif hasattr(model, 'layers'):
layers = model.layers
else:
print("⚠️ Unknown model structure")
return model, 0, 0
total_layers = len(layers)
# Check first layer for GQA
first_layer = layers[0]
if hasattr(first_layer, 'self_attn'):
old_attn = first_layer.self_attn
if hasattr(old_attn, 'q_proj'):
q_shape = old_attn.q_proj.weight.shape
k_shape = old_attn.k_proj.weight.shape
if k_shape[0] != q_shape[0]:
print(f" ✅ GQA detected! (K/V dim: {k_shape[0]} < Q dim: {q_shape[0]})")
if not hasattr(model.config, 'num_key_value_heads'):
num_kv_heads = k_shape[0] // (model.config.hidden_size // model.config.num_attention_heads)
model.config.num_key_value_heads = num_kv_heads
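# Replace each layer's self-attention with a retention module, copying Q/K/V/O weights
# when shapes match exactly, partially copying K/V under GQA, or falling back to Xavier init.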
for layer_idx, layer in enumerate(layers):
try:
if hasattr(layer, 'self_attn'):
old_attn = layer.self_attn
if use_hierarchical:
new_retention = HierarchicalRetention(model.config, layer_idx)
else:
new_retention = MultiScaleRetention(model.config, layer_idx)
# Copy weights
if hasattr(old_attn, 'q_proj'):
try:
if use_hierarchical:
target = new_retention.base_retention
else:
target = new_retention
q_match = old_attn.q_proj.weight.shape == target.q_proj.weight.shape
k_match = old_attn.k_proj.weight.shape == target.k_proj.weight.shape
v_match = old_attn.v_proj.weight.shape == target.v_proj.weight.shape
o_match = old_attn.o_proj.weight.shape == target.o_proj.weight.shape
if q_match and k_match and v_match and o_match:
target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
target.k_proj.weight.data = old_attn.k_proj.weight.data.clone()
target.v_proj.weight.data = old_attn.v_proj.weight.data.clone()
target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
print(f" ✅ Layer {layer_idx}: Perfect match")
elif q_match and o_match:
target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
k_copy_size = min(old_attn.k_proj.weight.shape[0], target.k_proj.weight.shape[0])
v_copy_size = min(old_attn.v_proj.weight.shape[0], target.v_proj.weight.shape[0])
target.k_proj.weight.data[:k_copy_size] = old_attn.k_proj.weight.data[:k_copy_size].clone()
target.v_proj.weight.data[:v_copy_size] = old_attn.v_proj.weight.data[:v_copy_size].clone()
print(f" ✅ Layer {layer_idx}: Partial (GQA)")
else:
nn.init.xavier_uniform_(target.q_proj.weight)
nn.init.xavier_uniform_(target.k_proj.weight)
nn.init.xavier_uniform_(target.v_proj.weight)
nn.init.xavier_uniform_(target.o_proj.weight)
print(f" ⚠️ Layer {layer_idx}: Xavier init")
except Exception as e:
print(f" ⚠️ Layer {layer_idx}: Weight copy failed - {e}")
layer.self_attn = new_retention
replaced_count += 1
except Exception as e:
print(f" ❌ Layer {layer_idx}: Failed - {e}")
continue
print(f"\n✅ Conversion complete: {replaced_count}/{total_layers} layers")
return model, replaced_count, total_layers
# =====================================================
# Database
# =====================================================
class ExperimentDatabase:
"""SQLite database"""
def __init__(self, db_path: str):
self.db_path = db_path
self.init_database()
def init_database(self):
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS experiments (
id INTEGER PRIMARY KEY AUTOINCREMENT,
model_type TEXT NOT NULL,
sequence_length INTEGER,
use_hierarchical BOOLEAN,
attention_replaced BOOLEAN,
layers_converted INTEGER,
total_layers INTEGER,
elapsed_time REAL,
memory_mb REAL,
throughput REAL,
config_json TEXT,
metrics_json TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
)
""")
# Burning history table
cursor.execute("""
CREATE TABLE IF NOT EXISTS burning_history (
id INTEGER PRIMARY KEY AUTOINCREMENT,
model_url TEXT NOT NULL,
output_path TEXT NOT NULL,
use_hierarchical BOOLEAN,
dataset_used BOOLEAN,
conversion_rate REAL,
training_steps INTEGER,
final_loss REAL,
evaluation_score REAL,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
)
""")
conn.commit()
def save_experiment(self, config: Dict, metrics: Dict) -> int:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute("""
INSERT INTO experiments (
model_type, sequence_length, use_hierarchical,
attention_replaced, layers_converted, total_layers,
elapsed_time, memory_mb, throughput,
config_json, metrics_json
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
config.get('model_type'),
config.get('sequence_length'),
config.get('use_hierarchical'),
config.get('attention_replaced'),
config.get('layers_converted'),
config.get('total_layers'),
metrics.get('elapsed_time'),
metrics.get('memory_mb'),
metrics.get('throughput'),
json.dumps(config),
json.dumps(metrics)
))
conn.commit()
return cursor.lastrowid
def save_burning(self, burning_info: Dict) -> int:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute("""
INSERT INTO burning_history (
model_url, output_path, use_hierarchical,
dataset_used, conversion_rate, training_steps,
final_loss, evaluation_score
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", (
burning_info.get('model_url'),
burning_info.get('output_path'),
burning_info.get('use_hierarchical'),
burning_info.get('dataset_used'),
burning_info.get('conversion_rate'),
burning_info.get('training_steps', 0),
burning_info.get('final_loss'),
burning_info.get('evaluation_score'),
))
conn.commit()
return cursor.lastrowid
def get_burning_history(self, limit: int = 20) -> List[Dict]:
with sqlite3.connect(self.db_path) as conn:
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
cursor.execute("SELECT * FROM burning_history ORDER BY timestamp DESC LIMIT ?", (limit,))
return [dict(row) for row in cursor.fetchall()]
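# Usage note: a single ExperimentDatabase instance is created below as `db` and used by the
# burning UI to record results (see save_burning / get_burning_history).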
# =====================================================
# Model burning (Zero-shot + Optional Fine-tuning)
# =====================================================
def evaluate_model_quality(model, tokenizer, test_prompts=None):
"""
Simple heuristic model quality evaluation
Returns:
score: 0.0 ~ 1.0 (higher is better)
"""
if test_prompts is None:
test_prompts = [
"The capital of France is",
"In machine learning, overfitting means",
"2 + 2 =",
]
model.eval()
scores = []
with torch.no_grad():
for prompt in test_prompts:
try:
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=20,
do_sample=False,
pad_token_id=tokenizer.eos_token_id,
)
generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Simple quality checks
score = 0.0
if len(generated) > len(prompt): # something was generated
score += 0.3
if not any(char in generated[len(prompt):] for char in ['\ufffd', '[UNK]']): # no broken/unknown characters
score += 0.3
if len(generated.split()) > len(prompt.split()) + 2: # meaningful new words were generated
score += 0.4
scores.append(score)
except Exception as e:
print(f" ⚠️ Evaluation error for '{prompt}': {e}")
scores.append(0.0)
return sum(scores) / len(scores) if scores else 0.0
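# Example (illustrative): quality = evaluate_model_quality(model, tokenizer)  # 1.0 = all heuristic checks passed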
def burn_model_zero_shot(
model_url: str,
output_dir: str,
use_hierarchical: bool = True,
test_prompts: List[str] = None,
):
"""
Zero-shot Model Burning (no dataset required)
1. Load the model
2. Convert Attention → Retention
3. Evaluate quality
4. Save
Returns:
status, model_path, metrics
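Example (illustrative output path):
result = burn_model_zero_shot(DEFAULT_MODEL, f"{MODELS_PATH}/phoenix_granite_350m")
if result['status'] == 'success': print(result['model_path'], result['quality_score'])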
"""
print("="*80)
print("🔥 PHOENIX Zero-shot Model Burning")
print("="*80)
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
try:
# 1. Load model
print(f"\n📥 Loading model: {model_url}")
start_time = time.time()
config = AutoConfig.from_pretrained(model_url, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_url,
trust_remote_code=True,
torch_dtype=torch.float16,
).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
load_time = time.time() - start_time
print(f"✅ Loaded in {load_time:.1f}s")
# 2. Convert
print(f"\n🔄 Converting Attention → Retention...")
convert_start = time.time()
model.model, converted, total = replace_attention_with_retention(
model.model,
use_hierarchical=use_hierarchical
)
convert_time = time.time() - convert_start
conversion_rate = converted / total if total > 0 else 0
print(f"✅ Converted {converted}/{total} layers ({conversion_rate*100:.1f}%) in {convert_time:.1f}s")
# 3. Evaluate
print(f"\n📊 Evaluating model quality...")
eval_start = time.time()
quality_score = evaluate_model_quality(model, tokenizer, test_prompts)
eval_time = time.time() - eval_start
print(f"✅ Quality Score: {quality_score:.2f}/1.00 (in {eval_time:.1f}s)")
# 4. Save
print(f"\n💾 Saving PHOENIX model...")
save_start = time.time()
model.save_pretrained(output_path)
tokenizer.save_pretrained(output_path)
# Save metadata
metadata = {
'phoenix_version': '1.0.0',
'original_model': model_url,
'use_hierarchical': use_hierarchical,
'conversion_rate': conversion_rate,
'layers_converted': converted,
'total_layers': total,
'quality_score': quality_score,
'burning_type': 'zero_shot',
'timestamp': datetime.now().isoformat(),
}
with open(output_path / 'phoenix_metadata.json', 'w') as f:
json.dump(metadata, f, indent=2)
save_time = time.time() - save_start
print(f"✅ Saved to {output_path} in {save_time:.1f}s")
# Total time
total_time = time.time() - start_time
result = {
'status': 'success',
'model_path': str(output_path),
'conversion_rate': conversion_rate,
'quality_score': quality_score,
'total_time': total_time,
'load_time': load_time,
'convert_time': convert_time,
'eval_time': eval_time,
'save_time': save_time,
}
print(f"\n{'='*80}")
print(f"✅ Zero-shot Burning Complete!")
print(f" Total Time: {total_time:.1f}s")
print(f" Model Path: {output_path}")
print(f" Quality: {quality_score:.2f}/1.00")
print(f"{'='*80}\n")
return result
except Exception as e:
import traceback
error_msg = traceback.format_exc()
print(f"\n❌ Zero-shot burning failed:\n{error_msg}")
return {
'status': 'failed',
'error': str(e),
'traceback': error_msg
}
def burn_model_with_finetuning(
model_url: str,
output_dir: str,
dataset_path: str,
use_hierarchical: bool = True,
num_epochs: int = 1,
batch_size: int = 4,
learning_rate: float = 5e-5,
max_steps: int = 100,
):
"""
Fine-tuning Model Burning (dataset-based)
1. Load and convert the model
2. Load the dataset
3. Fine-tune
4. Evaluate and save
Returns:
status, model_path, metrics
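Example (illustrative output and dataset paths):
result = burn_model_with_finetuning(DEFAULT_MODEL, f"{MODELS_PATH}/phoenix_ft", dataset_path="/data/train.txt", max_steps=100)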
"""
print("="*80)
print("🔥 PHOENIX Fine-tuning Model Burning")
print("="*80)
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
try:
# 1. Load & Convert
print(f"\n📥 Loading model: {model_url}")
config = AutoConfig.from_pretrained(model_url, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_url,
trust_remote_code=True,
torch_dtype=torch.float16,
).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
print(f"\n🔄 Converting...")
model.model, converted, total = replace_attention_with_retention(
model.model,
use_hierarchical=use_hierarchical
)
conversion_rate = converted / total if total > 0 else 0
print(f"✅ Converted {converted}/{total} layers")
# 2. Load dataset
print(f"\n📊 Loading dataset: {dataset_path}")
if dataset_path.endswith('.txt'):
with open(dataset_path, 'r', encoding='utf-8') as f:
texts = [line.strip() for line in f if line.strip()]
# Simple tokenization
def tokenize_fn(text):
return tokenizer(
text,
truncation=True,
max_length=512,
padding='max_length',
return_tensors='pt'
)
tokenized_data = [tokenize_fn(text) for text in texts[:1000]] # Limit to 1000
else:
# Try loading as HF dataset
from datasets import load_dataset
dataset = load_dataset('text', data_files=dataset_path)
def tokenize_function(examples):
return tokenizer(
examples['text'],
truncation=True,
max_length=512,
padding='max_length',
)
dataset = dataset.map(tokenize_function, batched=True)
tokenized_data = dataset['train']
print(f"✅ Loaded {len(tokenized_data)} samples")
# 3. Quick fine-tuning
print(f"\n🚀 Starting fine-tuning...")
print(f" Epochs: {num_epochs}")
print(f" Batch Size: {batch_size}")
print(f" Max Steps: {max_steps}")
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
step = 0
total_loss = 0.0
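# Minimal training loop: causal LM loss (labels = input_ids), one optimizer step per batch,
# stopping after max_steps; no gradient accumulation or LR scheduling.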
for epoch in range(num_epochs):
for i in range(0, len(tokenized_data), batch_size):
if step >= max_steps:
break
batch = tokenized_data[i:i+batch_size]
# Simple batch processing
if isinstance(batch, list):
input_ids = torch.stack([item['input_ids'].squeeze() for item in batch]).to(DEVICE)
attention_mask = torch.stack([item['attention_mask'].squeeze() for item in batch]).to(DEVICE)
else:
input_ids = torch.tensor(batch['input_ids']).to(DEVICE)
attention_mask = torch.tensor(batch['attention_mask']).to(DEVICE)
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
loss = outputs.loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
total_loss += loss.item()
step += 1
if step % 10 == 0:
avg_loss = total_loss / step
print(f" Step {step}/{max_steps} - Loss: {avg_loss:.4f}")
final_loss = total_loss / step if step > 0 else 0.0
print(f"✅ Training complete - Final Loss: {final_loss:.4f}")
# 4. Evaluate & Save
print(f"\n📊 Evaluating...")
model.eval()
quality_score = evaluate_model_quality(model, tokenizer)
print(f"✅ Quality Score: {quality_score:.2f}/1.00")
print(f"\n💾 Saving model...")
model.save_pretrained(output_path)
tokenizer.save_pretrained(output_path)
metadata = {
'phoenix_version': '1.0.0',
'original_model': model_url,
'use_hierarchical': use_hierarchical,
'conversion_rate': conversion_rate,
'quality_score': quality_score,
'burning_type': 'fine_tuning',
'training_steps': step,
'final_loss': final_loss,
'dataset': dataset_path,
'timestamp': datetime.now().isoformat(),
}
with open(output_path / 'phoenix_metadata.json', 'w') as f:
json.dump(metadata, f, indent=2)
print(f"✅ Saved to {output_path}")
result = {
'status': 'success',
'model_path': str(output_path),
'conversion_rate': conversion_rate,
'quality_score': quality_score,
'training_steps': step,
'final_loss': final_loss,
}
print(f"\n{'='*80}")
print(f"✅ Fine-tuning Burning Complete!")
print(f"{'='*80}\n")
return result
except Exception as e:
import traceback
error_msg = traceback.format_exc()
print(f"\n❌ Fine-tuning burning failed:\n{error_msg}")
return {
'status': 'failed',
'error': str(e),
'traceback': error_msg
}
# =====================================================
# Gradio UI Functions
# =====================================================
def convert_model_to_phoenix(model_url, use_hierarchical=True, gpu_type="L40S"):
"""Convert model to PHOENIX (existing function, kept as-is)"""
try:
start_time = time.time()
print(f"📥 Loading model: {model_url}")
config = AutoConfig.from_pretrained(model_url, trust_remote_code=True)
model = AutoModel.from_pretrained(
model_url,
trust_remote_code=True,
torch_dtype=torch.float16
).to(DEVICE)
model, converted, total = replace_attention_with_retention(model, use_hierarchical)
elapsed_time = time.time() - start_time
conversion_pct = (converted / total * 100) if total > 0 else 0
result = f"""
✅ **Conversion Complete!**
**Model**: {model_url}
**Converted**: {converted}/{total} layers ({conversion_pct:.1f}%)
**Time**: {elapsed_time:.1f}s
**GPU**: {gpu_type}
🎯 GQA-aware O(n) complexity!
"""
return result
except Exception as e:
return f"❌ Conversion failed: {str(e)}"
def generate_text_phoenix(
model_url, use_hierarchical, convert_attention,
prompt, max_new_tokens, temperature
):
"""PHOENIX text generation (existing function, simplified)"""
try:
if not convert_attention or not model_url.strip():
return "⚠️ Enable 'Attention Replace' and provide model URL", ""
print(f"📥 Loading model: {model_url}")
model = AutoModelForCausalLM.from_pretrained(
model_url,
trust_remote_code=True,
torch_dtype=torch.float16
).to(DEVICE)
print(f"🔄 Converting...")
model.model, converted, total = replace_attention_with_retention(
model.model,
use_hierarchical=use_hierarchical
)
tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
print(f"🚀 Generating...")
start_time = time.time()
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=temperature > 0.01,
pad_token_id=tokenizer.eos_token_id,
)
elapsed = time.time() - start_time
generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
output_md = f"""
## 📝 Generated Text
```
{generated}
```
"""
stats_md = f"""
## 📊 Statistics
- **Time**: {elapsed:.2f}s
- **Converted**: {converted}/{total} layers
- **Tokens/s**: {max_new_tokens/elapsed:.1f}
"""
return output_md, stats_md
except Exception as e:
import traceback
return f"❌ Error:\n```\n{traceback.format_exc()}\n```", ""
def burn_phoenix_model_ui(
model_url,
use_hierarchical,
dataset_path,
output_name,
use_finetuning,
num_epochs,
batch_size,
learning_rate,
max_steps,
):
"""
Model burning function for the Gradio UI
"""
try:
if not model_url.strip():
return "⚠️ Model URL required", None
if not output_name.strip():
output_name = f"phoenix_{model_url.split('/')[-1]}_{int(time.time())}"
output_dir = f"{MODELS_PATH}/{output_name}"
# Dataset check
has_dataset = dataset_path and dataset_path.strip() and Path(dataset_path).exists()
if use_finetuning and not has_dataset:
return "⚠️ Fine-tuning requires dataset path", None
# Choose burning method
if use_finetuning and has_dataset:
result = burn_model_with_finetuning(
model_url=model_url,
output_dir=output_dir,
dataset_path=dataset_path,
use_hierarchical=use_hierarchical,
num_epochs=num_epochs,
batch_size=batch_size,
learning_rate=learning_rate,
max_steps=max_steps,
)
else:
result = burn_model_zero_shot(
model_url=model_url,
output_dir=output_dir,
use_hierarchical=use_hierarchical,
)
if result['status'] == 'success':
# Save to database
burning_info = {
'model_url': model_url,
'output_path': result['model_path'],
'use_hierarchical': use_hierarchical,
'dataset_used': has_dataset,
'conversion_rate': result.get('conversion_rate', 0.0),
'training_steps': result.get('training_steps', 0),
'final_loss': result.get('final_loss'),
'evaluation_score': result.get('quality_score', 0.0),
}
db.save_burning(burning_info)
# Format output
output_md = f"""
# 🔥 Model Burning Complete!
## 📦 Model Information
- **Original**: {model_url}
- **Output**: `{result['model_path']}`
- **Type**: {'Fine-tuning' if has_dataset else 'Zero-shot'}
## 📊 Metrics
- **Conversion Rate**: {result['conversion_rate']*100:.1f}%
- **Quality Score**: {result.get('quality_score', 0.0):.2f}/1.00
"""
if 'training_steps' in result:
output_md += f"""
## 🚀 Training
- **Steps**: {result['training_steps']}
- **Final Loss**: {result.get('final_loss', 0.0):.4f}
"""
output_md += f"""
## ⏱️ Time Breakdown
- **Total**: {result.get('total_time', 0):.1f}s
"""
if 'load_time' in result:
output_md += f"- **Load**: {result['load_time']:.1f}s\n"
output_md += f"- **Convert**: {result['convert_time']:.1f}s\n"
output_md += f"- **Evaluate**: {result['eval_time']:.1f}s\n"
output_md += f"- **Save**: {result['save_time']:.1f}s\n"
output_md += f"""
## 🎯 Usage
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("{result['model_path']}")
tokenizer = AutoTokenizer.from_pretrained("{result['model_path']}")
inputs = tokenizer("Your prompt", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0]))
```
✅ **PHOENIX Model Ready!**
"""
# Create simple plot
fig = go.Figure()
fig.add_trace(go.Bar(
x=['Conversion', 'Quality'],
y=[result['conversion_rate'], result.get('quality_score', 0.0)],
text=[f"{result['conversion_rate']*100:.1f}%", f"{result.get('quality_score', 0.0):.2f}"],
textposition='auto',
))
fig.update_layout(
title="Burning Metrics",
yaxis_range=[0, 1],
template='plotly_white'
)
return output_md, fig
else:
return f"❌ Burning failed:\n```\n{result.get('error', 'Unknown error')}\n```", None
except Exception as e:
import traceback
return f"❌ Error:\n```\n{traceback.format_exc()}\n```", None
def view_burning_history():
"""View burning history"""
try:
history = db.get_burning_history(limit=20)
if not history:
return "📭 No burning history yet", None
df = pd.DataFrame(history)
fig = px.scatter(
df,
x='timestamp',
y='evaluation_score',
size='conversion_rate',
color='dataset_used',
hover_data=['model_url', 'output_path'],
title='Burning History'
)
cols = ['id', 'model_url', 'output_path', 'conversion_rate',
'evaluation_score', 'training_steps', 'timestamp']
available = [c for c in cols if c in df.columns]
return f"## 📊 Burning History\n\n{df[available].to_markdown(index=False)}", fig
except Exception as e:
return f"❌ Error: {e}", None
# Global initialization
db = ExperimentDatabase(DB_PATH)
CONVERTED_MODELS = {}
# =====================================================
# Gradio UI
# =====================================================
with gr.Blocks(
title="🔮 PHOENIX - Model Burning Platform",
theme=gr.themes.Soft(),
) as demo:
gr.Markdown("""
# 🔮 PHOENIX Retention Platform
**Zero-shot Model Burning + Optional Fine-tuning**
✅ Zero-shot Conversion (no dataset required!)
✅ Optional Fine-tuning (dataset-based)
✅ GQA Support
✅ O(n) Complexity
---
""")
with gr.Tabs():
with gr.Tab("🔄 Quick Convert"):
gr.Markdown("""
### Quick Conversion Test
Loads the model and performs only the Attention → Retention conversion. (Nothing is saved.)
""")
with gr.Row():
with gr.Column(scale=1):
convert_url = gr.Textbox(
label="🔗 Model URL",
value=DEFAULT_MODEL,
placeholder="ibm-granite/granite-4.0-h-350m"
)
convert_hierarchical = gr.Checkbox(value=True, label="Hierarchical Retention")
convert_gpu = gr.Radio(choices=["L40S", "H100"], value="L40S", label="GPU")
convert_btn = gr.Button("🔄 Convert", variant="primary")
with gr.Column(scale=2):
convert_output = gr.Markdown()
convert_btn.click(
convert_model_to_phoenix,
[convert_url, convert_hierarchical, convert_gpu],
[convert_output]
)
with gr.Tab("🔥 Model Burning"):
gr.Markdown("""
### 🔥 PHOENIX Model Burning
**Convert the model and save it!**
- **Zero-shot**: conversion only, no dataset required (fast!)
- **Fine-tuning**: additional training on a dataset (better quality)
""")
with gr.Row():
with gr.Column(scale=1):
burn_model_url = gr.Textbox(
label="🔗 Model URL",
value=DEFAULT_MODEL,
placeholder="ibm-granite/granite-4.0-h-350m"
)
burn_hierarchical = gr.Checkbox(value=True, label="Hierarchical Retention")
burn_output_name = gr.Textbox(
label="💾 Output Name",
placeholder="phoenix_my_model (auto-generated if empty)"
)
gr.Markdown("---")
gr.Markdown("### 📊 Dataset (Optional)")
burn_dataset = gr.Textbox(
label="📁 Dataset Path (Optional)",
placeholder="/path/to/dataset.txt (leave empty for zero-shot)",
value=""
)
burn_use_finetuning = gr.Checkbox(
value=False,
label="🚀 Enable Fine-tuning (requires dataset)"
)
with gr.Accordion("⚙️ Fine-tuning Config", open=False):
burn_epochs = gr.Slider(1, 5, 1, step=1, label="Epochs")
burn_batch = gr.Slider(1, 16, 4, step=1, label="Batch Size")
burn_lr = gr.Number(value=5e-5, label="Learning Rate")
burn_max_steps = gr.Slider(10, 500, 100, step=10, label="Max Steps")
burn_btn = gr.Button("🔥 Burn Model", variant="primary", size="lg")
with gr.Column(scale=2):
burn_output = gr.Markdown()
burn_plot = gr.Plot()
burn_btn.click(
burn_phoenix_model_ui,
[
burn_model_url,
burn_hierarchical,
burn_dataset,
burn_output_name,
burn_use_finetuning,
burn_epochs,
burn_batch,
burn_lr,
burn_max_steps,
],
[burn_output, burn_plot]
)
with gr.Tab("💬 Text Generation"):
gr.Markdown("""
### PHOENIX Text Generation
Generate text with the converted model.
""")
with gr.Row():
with gr.Column(scale=1):
gen_model_url = gr.Textbox(label="🔗 Model URL", value=DEFAULT_MODEL)
gen_hierarchical = gr.Checkbox(value=True, label="Hierarchical")
gen_convert = gr.Checkbox(value=True, label="Enable Conversion")
gen_prompt = gr.Textbox(
label="📝 Prompt",
lines=3,
value="The future of AI is"
)
gen_max_tokens = gr.Slider(16, 256, 64, step=16, label="Max Tokens")
gen_temperature = gr.Slider(0.1, 2.0, 0.7, step=0.1, label="Temperature")
gen_btn = gr.Button("🚀 Generate", variant="primary")
with gr.Column(scale=2):
gen_output = gr.Markdown()
gen_stats = gr.Markdown()
gen_btn.click(
generate_text_phoenix,
[gen_model_url, gen_hierarchical, gen_convert, gen_prompt,
gen_max_tokens, gen_temperature],
[gen_output, gen_stats]
)
with gr.Tab("📊 Burning History"):
gr.Markdown("""
### 📊 Model Burning History
Review the saved model burning records.
""")
with gr.Row():
with gr.Column(scale=1):
hist_btn = gr.Button("📊 Load History", variant="primary")
with gr.Column(scale=2):
hist_output = gr.Markdown()
hist_plot = gr.Plot()
hist_btn.click(view_burning_history, outputs=[hist_output, hist_plot])
gr.Markdown("""
---
## 🔥 PHOENIX Model Burning
### Zero-shot (no dataset required!)
1. Enter the model URL
2. Click "Burn Model"
3. Done! → saved under `/data/phoenix_models/`
### Fine-tuning (optional)
1. Enter the Dataset Path
2. Check "Enable Fine-tuning"
3. Click "Burn Model"
**VIDraft AI Research Lab** | PHOENIX v1.0
""")
if __name__ == "__main__":
demo.queue(max_size=20)
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)