|
|
""" |
|
|
๐ฎ PHOENIX Retention Research Platform |
|
|
Real Implementation - GQA Support (Final Version) |
|
|
|
|
|
โ
Supports Grouped Query Attention (GQA) |
|
|
โ
Adaptive K/V projection dimensions |
|
|
โ
L40S GPU + Persistent Storage |
|
|
โ
KV Cache with State Reuse |
|
|
โ
Robust Error Handling |
|
|
|
|
|
VIDraft AI Research Lab |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import sqlite3 |
|
|
import json |
|
|
import time |
|
|
import numpy as np |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
import plotly.graph_objects as go |
|
|
import plotly.express as px |
|
|
import pandas as pd |
|
|
from typing import Dict, List, Any, Tuple, Optional |
|
|
import chromadb |
|
|
from chromadb.config import Settings |
|
|
from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoModelForCausalLM |
|
|
import copy |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
STORAGE_PATH = "/data" |
|
|
DB_PATH = f"{STORAGE_PATH}/phoenix_experiments.db" |
|
|
VECTOR_DB_PATH = f"{STORAGE_PATH}/vector_store" |
|
|
DEFAULT_MODEL = "ibm-granite/granite-4.0-h-350m" |
|
|
|
|
|
Path(STORAGE_PATH).mkdir(parents=True, exist_ok=True) |
|
|
Path(VECTOR_DB_PATH).mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
print(f"๐ PHOENIX Platform initialized on {DEVICE}") |
|
|
print(f"๐พ Storage: {STORAGE_PATH}") |
|
|
print(f"๐ฏ Default Base Model: {DEFAULT_MODEL}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MultiScaleRetention(nn.Module): |
|
|
""" |
|
|
์ง์ง Retention Attention with GQA Support |
|
|
|
|
|
โ
Supports Grouped Query Attention |
|
|
โ
Adaptive K/V dimensions |
|
|
โ
KV Cache with State Reuse |
|
|
""" |
|
|
|
|
|
def __init__(self, config, layer_idx=0): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.layer_idx = layer_idx |
|
|
|
|
|
|
|
|
self.hidden_size = config.hidden_size |
|
|
self.num_heads = config.num_attention_heads |
|
|
self.head_dim = self.hidden_size // self.num_heads |
|
|
|
|
|
|
|
|
if hasattr(config, 'num_key_value_heads'): |
|
|
self.num_key_value_heads = config.num_key_value_heads |
|
|
else: |
|
|
self.num_key_value_heads = self.num_heads |
|
|
|
|
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
|
|
self.kv_head_dim = self.head_dim |
|
|
self.kv_dim = self.num_key_value_heads * self.kv_head_dim |
|
|
|
|
|
|
|
|
self.register_buffer('_internal_state', None, persistent=False) |
|
|
self.register_buffer('_state_initialized', torch.tensor(False), persistent=False) |
|
|
|
|
|
print(f" ๐ Layer {layer_idx} Retention (GQA) initialized:") |
|
|
print(f" - hidden_size: {self.hidden_size}") |
|
|
print(f" - num_heads (Q): {self.num_heads}") |
|
|
print(f" - num_key_value_heads (K/V): {self.num_key_value_heads}") |
|
|
print(f" - head_dim: {self.head_dim}") |
|
|
print(f" - kv_dim: {self.kv_dim}") |
|
|
print(f" - groups: {self.num_key_value_groups}") |
|
|
|
|
|
|
|
|
|
|
|
self.use_expanded_proj = False |
|
|
|
|
|
self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) |
|
|
self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False) |
|
|
self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False) |
|
|
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) |
|
|
|
|
|
|
|
|
decay_values = torch.linspace(0.95, 0.99, self.num_heads) |
|
|
self.decay = nn.Parameter(decay_values, requires_grad=True) |
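        # Note: _compute_retention passes self.decay through torch.sigmoid, so
        # the effective per-head decay applied to the state is roughly
        # 0.72-0.73 per step, not the raw 0.95-0.99 initialisation values.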
|
|
|
|
|
|
|
|
self.group_norm = nn.GroupNorm( |
|
|
num_groups=self.num_heads, |
|
|
num_channels=self.hidden_size |
|
|
) |
|
|
|
|
|
def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
|
|
""" |
|
|
Repeat K/V heads to match Q heads (GQA) |
|
|
[B, num_kv_heads, seq_len, head_dim] -> [B, num_heads, seq_len, head_dim] |
|
|
""" |
|
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
|
|
if n_rep == 1: |
|
|
return hidden_states |
|
|
|
|
|
hidden_states = hidden_states[:, :, None, :, :].expand( |
|
|
batch, num_key_value_heads, n_rep, slen, head_dim |
|
|
) |
|
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
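
    # Shape walk-through (hypothetical sizes, for illustration only): with
    # num_heads=12 and num_key_value_heads=4, n_rep is 3, so a K or V tensor of
    # shape [B, 4, L, head_dim] is expanded to [B, 12, L, head_dim] before the
    # per-head retention update.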
|
|
|
|
|
def reset_state(self): |
|
|
"""Reset internal state (call at start of new sequence)""" |
|
|
self._internal_state = None |
|
|
self._state_initialized = torch.tensor(False) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.Tensor] = None, |
|
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
|
output_attentions: bool = False, |
|
|
use_cache: bool = False, |
|
|
cache_position: Optional[torch.Tensor] = None, |
|
|
past_key_values: Optional[Tuple[torch.Tensor]] = None, |
|
|
**kwargs |
|
|
): |
|
|
""" |
|
|
O(n) Retention with GQA support |
|
|
""" |
|
|
batch_size, seq_len, _ = hidden_states.shape |
|
|
|
|
|
if past_key_values is not None: |
|
|
past_key_value = past_key_values |
|
|
|
|
|
|
|
|
query_states = self.q_proj(hidden_states) |
|
|
key_states = self.k_proj(hidden_states) |
|
|
value_states = self.v_proj(hidden_states) |
|
|
|
|
|
|
|
|
query_states = query_states.view( |
|
|
batch_size, seq_len, self.num_heads, self.head_dim |
|
|
).transpose(1, 2) |
|
|
|
|
|
|
|
|
key_states = key_states.view( |
|
|
batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim |
|
|
).transpose(1, 2) |
|
|
|
|
|
value_states = value_states.view( |
|
|
batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim |
|
|
).transpose(1, 2) |
|
|
|
|
|
|
|
|
key_states = self._repeat_kv(key_states, self.num_key_value_groups) |
|
|
value_states = self._repeat_kv(value_states, self.num_key_value_groups) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
past_state = self._internal_state if (use_cache and self._state_initialized) else None |
|
|
retention_states, new_state = self._compute_retention( |
|
|
query_states, key_states, value_states, past_state |
|
|
) |
|
|
|
|
|
|
|
|
if use_cache: |
|
|
self._internal_state = new_state.detach() |
|
|
self._state_initialized = torch.tensor(True) |
|
|
|
|
|
|
|
|
retention_states = retention_states.transpose(1, 2).contiguous() |
|
|
retention_states = retention_states.reshape( |
|
|
batch_size, seq_len, self.hidden_size |
|
|
) |
|
|
|
|
|
|
|
|
if not next(self.group_norm.parameters()).is_cuda and retention_states.is_cuda: |
|
|
self.group_norm = self.group_norm.to(retention_states.device, dtype=retention_states.dtype) |
|
|
elif next(self.group_norm.parameters()).dtype != retention_states.dtype: |
|
|
self.group_norm = self.group_norm.to(dtype=retention_states.dtype) |
|
|
|
|
|
retention_states = self.group_norm( |
|
|
retention_states.transpose(1, 2) |
|
|
).transpose(1, 2) |
|
|
|
|
|
|
|
|
retention_states = torch.clamp(retention_states, min=-10.0, max=10.0) |
|
|
|
|
|
|
|
|
attn_output = self.o_proj(retention_states) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return (attn_output, None) |
|
|
|
|
|
def _compute_retention( |
|
|
self, |
|
|
queries: torch.Tensor, |
|
|
keys: torch.Tensor, |
|
|
values: torch.Tensor, |
|
|
past_state: Optional[torch.Tensor] = None |
|
|
): |
|
|
""" |
|
|
O(n) Retention computation with KV cache support |
|
|
|
|
|
Args: |
|
|
past_state: Previous retention state [B, H, D, D] |
|
|
|
|
|
Returns: |
|
|
output: [B, H, L, D] |
|
|
new_state: Updated state [B, H, D, D] |
|
|
""" |
|
|
batch_size, num_heads, seq_len, head_dim = queries.shape |
|
|
|
|
|
|
|
|
if past_state is not None: |
|
|
state = past_state.to(queries.device, dtype=queries.dtype) |
|
|
else: |
|
|
|
|
|
state = torch.zeros( |
|
|
batch_size, num_heads, head_dim, head_dim, |
|
|
dtype=queries.dtype, |
|
|
device=queries.device |
|
|
) + 1e-6 |
|
|
|
|
|
outputs = [] |
|
|
|
|
|
|
|
|
decay = torch.sigmoid(self.decay).view(1, -1, 1, 1).to( |
|
|
device=queries.device, |
|
|
dtype=queries.dtype |
|
|
) |
|
|
|
|
|
|
|
|
for t in range(seq_len): |
|
|
q_t = queries[:, :, t, :] |
|
|
k_t = keys[:, :, t, :] |
|
|
v_t = values[:, :, t, :] |
|
|
|
|
|
|
|
|
state = decay * state |
|
|
|
|
|
|
|
|
kv_update = torch.einsum('bhd,bhe->bhde', k_t, v_t) |
|
|
|
|
|
|
|
|
kv_update = torch.clamp(kv_update, min=-5.0, max=5.0) |
|
|
|
|
|
state = state + kv_update |
|
|
|
|
|
|
|
|
state = torch.clamp(state, min=-10.0, max=10.0) |
|
|
|
|
|
|
|
|
output_t = torch.einsum('bhd,bhde->bhe', q_t, state) |
|
|
outputs.append(output_t) |
|
|
|
|
|
output = torch.stack(outputs, dim=2) |
|
|
|
|
|
|
|
|
return output, state |
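

# Minimal reference sketch (illustrative only, not used by the app) of the
# recurrence implemented in MultiScaleRetention._compute_retention, written for
# a single batch element and a single head and omitting the clamping:
#
#     S_t = decay * S_{t-1} + outer(k_t, v_t)    # state update, O(d^2) memory
#     o_t = q_t @ S_t                            # read-out for token t
#
def _retention_recurrence_reference(q, k, v, decay=0.95):
    """q, k, v: [seq_len, head_dim] tensors; returns outputs of shape [seq_len, head_dim]."""
    seq_len, head_dim = q.shape
    state = torch.zeros(head_dim, head_dim, dtype=q.dtype, device=q.device)
    outputs = []
    for t in range(seq_len):
        state = decay * state + torch.outer(k[t], v[t])  # rank-1 update of the compressed state
        outputs.append(q[t] @ state)                     # query the state instead of all past tokens
    return torch.stack(outputs)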
|
|
|
|
|
|
|
|
class HierarchicalRetention(nn.Module): |
|
|
""" |
|
|
PHOENIX Hierarchical Retention with GQA |
|
|
""" |
|
|
|
|
|
def __init__(self, config, layer_idx=0): |
|
|
super().__init__() |
|
|
self.base_retention = MultiScaleRetention(config, layer_idx) |
|
|
|
|
|
hidden_size = config.hidden_size |
|
|
self.d_state = hidden_size // 2 |
|
|
|
|
|
|
|
|
self.short_proj = nn.Linear(hidden_size, self.d_state) |
|
|
self.medium_proj = nn.Linear(self.d_state, self.d_state) |
|
|
self.long_proj = nn.Linear(self.d_state, self.d_state * 2) |
|
|
self.fusion = nn.Linear(self.d_state * 4, hidden_size) |
|
|
|
|
|
|
|
|
self.short_decay = 0.5 |
|
|
self.medium_decay = 0.8 |
|
|
self.long_decay = 0.95 |
|
|
|
|
|
|
|
|
self.norm = nn.LayerNorm(hidden_size) |
|
|
|
|
|
|
|
|
if next(self.base_retention.parameters()).is_cuda: |
|
|
device = next(self.base_retention.parameters()).device |
|
|
dtype = next(self.base_retention.parameters()).dtype |
|
|
self.short_proj = self.short_proj.to(device, dtype=dtype) |
|
|
self.medium_proj = self.medium_proj.to(device, dtype=dtype) |
|
|
self.long_proj = self.long_proj.to(device, dtype=dtype) |
|
|
self.fusion = self.fusion.to(device, dtype=dtype) |
|
|
self.norm = self.norm.to(device, dtype=dtype) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.Tensor] = None, |
|
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
|
output_attentions: bool = False, |
|
|
use_cache: bool = False, |
|
|
cache_position: Optional[torch.Tensor] = None, |
|
|
past_key_values: Optional[Tuple[torch.Tensor]] = None, |
|
|
**kwargs |
|
|
): |
|
|
"""Hierarchical forward pass""" |
|
|
batch_size, seq_len, hidden_size = hidden_states.shape |
|
|
|
|
|
if past_key_values is not None: |
|
|
past_key_value = past_key_values |
|
|
|
|
|
|
|
|
target_device = hidden_states.device |
|
|
target_dtype = hidden_states.dtype |
|
|
|
|
|
if not next(self.short_proj.parameters()).is_cuda and hidden_states.is_cuda: |
|
|
self.short_proj = self.short_proj.to(target_device, dtype=target_dtype) |
|
|
self.medium_proj = self.medium_proj.to(target_device, dtype=target_dtype) |
|
|
self.long_proj = self.long_proj.to(target_device, dtype=target_dtype) |
|
|
self.fusion = self.fusion.to(target_device, dtype=target_dtype) |
|
|
self.norm = self.norm.to(target_device, dtype=target_dtype) |
|
|
elif next(self.short_proj.parameters()).dtype != target_dtype: |
|
|
self.short_proj = self.short_proj.to(dtype=target_dtype) |
|
|
self.medium_proj = self.medium_proj.to(dtype=target_dtype) |
|
|
self.long_proj = self.long_proj.to(dtype=target_dtype) |
|
|
self.fusion = self.fusion.to(dtype=target_dtype) |
|
|
self.norm = self.norm.to(dtype=target_dtype) |
|
|
|
|
|
|
|
|
base_result = self.base_retention( |
|
|
hidden_states, attention_mask, position_ids, |
|
|
past_key_value, output_attentions, use_cache |
|
|
) |
|
|
|
|
|
retention_output = base_result[0] |
|
|
new_state = base_result[2] if len(base_result) > 2 else None |
|
|
|
|
|
|
|
|
short_state = torch.zeros(batch_size, self.d_state, dtype=hidden_states.dtype, device=target_device) |
|
|
medium_state = torch.zeros(batch_size, self.d_state, dtype=hidden_states.dtype, device=target_device) |
|
|
long_state = torch.zeros(batch_size, self.d_state * 2, dtype=hidden_states.dtype, device=target_device) |
|
|
|
|
|
hierarchical_outputs = [] |
|
|
|
|
|
for t in range(seq_len): |
|
|
x_t = retention_output[:, t, :] |
|
|
|
|
|
|
|
|
short_input = self.short_proj(x_t) |
|
|
short_state = self.short_decay * short_state + short_input |
|
|
|
|
|
|
|
|
if t % 8 == 0: |
|
|
medium_state = self.medium_decay * medium_state + \ |
|
|
self.medium_proj(short_state) |
|
|
|
|
|
|
|
|
if t % 64 == 0: |
|
|
long_state = self.long_decay * long_state + \ |
|
|
self.long_proj(medium_state) |
|
|
|
|
|
|
|
|
combined = torch.cat([short_state, medium_state, long_state], dim=-1) |
|
|
output_t = self.fusion(combined) |
|
|
hierarchical_outputs.append(output_t) |
|
|
|
|
|
output = torch.stack(hierarchical_outputs, dim=1) |
|
|
output = self.norm(output) |
|
|
|
|
|
|
|
|
|
|
|
return (output, None) |
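
# Timescale sketch (follows directly from the loop above): the short state is
# refreshed every token (decay 0.5), the medium state every 8 tokens
# (decay 0.8), and the long state every 64 tokens (decay 0.95), so a
# 1024-token sequence performs 1024 short, 128 medium, and 16 long updates.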
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def replace_attention_with_retention(model, use_hierarchical=True): |
|
|
""" |
|
|
    Transformer Attention → PHOENIX Retention (GQA Support)
|
|
""" |
|
|
print("๐ Starting Attention โ Retention conversion (GQA support)...") |
|
|
|
|
|
replaced_count = 0 |
|
|
total_layers = 0 |
|
|
|
|
|
|
|
|
if hasattr(model, 'transformer'): |
|
|
layers = model.transformer.h |
|
|
elif hasattr(model, 'model') and hasattr(model.model, 'layers'): |
|
|
layers = model.model.layers |
|
|
elif hasattr(model, 'layers'): |
|
|
layers = model.layers |
|
|
else: |
|
|
print("โ ๏ธ Unknown model structure") |
|
|
return model, 0, 0 |
|
|
|
|
|
total_layers = len(layers) |
|
|
|
|
|
|
|
|
first_layer = layers[0] |
|
|
if hasattr(first_layer, 'self_attn'): |
|
|
old_attn = first_layer.self_attn |
|
|
|
|
|
print(f"\n๐ Detected attention structure:") |
|
|
if hasattr(old_attn, 'q_proj'): |
|
|
q_shape = old_attn.q_proj.weight.shape |
|
|
k_shape = old_attn.k_proj.weight.shape |
|
|
v_shape = old_attn.v_proj.weight.shape |
|
|
|
|
|
print(f" - Q projection: {q_shape}") |
|
|
print(f" - K projection: {k_shape}") |
|
|
print(f" - V projection: {v_shape}") |
|
|
|
|
|
if k_shape[0] != q_shape[0]: |
|
|
print(f" โ
GQA detected! (K/V dim: {k_shape[0]} < Q dim: {q_shape[0]})") |
|
|
|
|
|
if not hasattr(model.config, 'num_key_value_heads'): |
|
|
num_kv_heads = k_shape[0] // (model.config.hidden_size // model.config.num_attention_heads) |
|
|
model.config.num_key_value_heads = num_kv_heads |
|
|
print(f" ๐ง Set num_key_value_heads = {num_kv_heads}") |
|
|
|
|
|
for layer_idx, layer in enumerate(layers): |
|
|
try: |
|
|
if hasattr(layer, 'self_attn'): |
|
|
old_attn = layer.self_attn |
|
|
|
|
|
|
|
|
if use_hierarchical: |
|
|
new_retention = HierarchicalRetention(model.config, layer_idx) |
|
|
else: |
|
|
new_retention = MultiScaleRetention(model.config, layer_idx) |
|
|
|
|
|
|
|
|
if hasattr(old_attn, 'q_proj'): |
|
|
try: |
|
|
if use_hierarchical: |
|
|
target = new_retention.base_retention |
|
|
else: |
|
|
target = new_retention |
|
|
|
|
|
|
|
|
q_match = old_attn.q_proj.weight.shape == target.q_proj.weight.shape |
|
|
k_match = old_attn.k_proj.weight.shape == target.k_proj.weight.shape |
|
|
v_match = old_attn.v_proj.weight.shape == target.v_proj.weight.shape |
|
|
o_match = old_attn.o_proj.weight.shape == target.o_proj.weight.shape |
|
|
|
|
|
if q_match and k_match and v_match and o_match: |
|
|
|
|
|
target.q_proj.weight.data = old_attn.q_proj.weight.data.clone() |
|
|
target.k_proj.weight.data = old_attn.k_proj.weight.data.clone() |
|
|
target.v_proj.weight.data = old_attn.v_proj.weight.data.clone() |
|
|
target.o_proj.weight.data = old_attn.o_proj.weight.data.clone() |
|
|
print(f" โ
Layer {layer_idx}: Weights copied (perfect match)") |
|
|
|
|
|
elif q_match and o_match: |
|
|
|
|
|
target.q_proj.weight.data = old_attn.q_proj.weight.data.clone() |
|
|
target.o_proj.weight.data = old_attn.o_proj.weight.data.clone() |
|
|
|
|
|
|
|
|
k_copy_size = min(old_attn.k_proj.weight.shape[0], target.k_proj.weight.shape[0]) |
|
|
v_copy_size = min(old_attn.v_proj.weight.shape[0], target.v_proj.weight.shape[0]) |
|
|
|
|
|
target.k_proj.weight.data[:k_copy_size] = old_attn.k_proj.weight.data[:k_copy_size].clone() |
|
|
target.v_proj.weight.data[:v_copy_size] = old_attn.v_proj.weight.data[:v_copy_size].clone() |
|
|
|
|
|
print(f" โ
Layer {layer_idx}: Weights copied (partial K/V: {k_copy_size}/{target.k_proj.weight.shape[0]})") |
|
|
|
|
|
elif old_attn.q_proj.weight.shape[0] == 2 * target.q_proj.weight.shape[0]: |
|
|
|
|
|
|
|
|
q_out, q_in = old_attn.q_proj.weight.shape |
|
|
target_out = target.q_proj.weight.shape[0] |
|
|
|
|
|
|
|
|
start_idx = (q_out - target_out) // 2 |
|
|
target.q_proj.weight.data = old_attn.q_proj.weight.data[start_idx:start_idx+target_out].clone() |
|
|
|
|
|
|
|
|
o_out, o_in = old_attn.o_proj.weight.shape |
|
|
target_in = target.o_proj.weight.shape[1] |
|
|
start_idx = (o_in - target_in) // 2 |
|
|
target.o_proj.weight.data = old_attn.o_proj.weight.data[:, start_idx:start_idx+target_in].clone() |
|
|
|
|
|
|
|
|
k_copy_size = min(old_attn.k_proj.weight.shape[0], target.k_proj.weight.shape[0]) |
|
|
v_copy_size = min(old_attn.v_proj.weight.shape[0], target.v_proj.weight.shape[0]) |
|
|
|
|
|
target.k_proj.weight.data[:k_copy_size] = old_attn.k_proj.weight.data[:k_copy_size].clone() |
|
|
target.v_proj.weight.data[:v_copy_size] = old_attn.v_proj.weight.data[:v_copy_size].clone() |
|
|
|
|
|
print(f" โ
Layer {layer_idx}: Weights copied (Qwen3 style: Q/O center extraction, K/V partial)") |
|
|
|
|
|
else: |
|
|
|
|
|
print(f" โ ๏ธ Layer {layer_idx}: Shape mismatch, using Xavier init") |
|
|
print(f" Q: {old_attn.q_proj.weight.shape} vs {target.q_proj.weight.shape}") |
|
|
print(f" K: {old_attn.k_proj.weight.shape} vs {target.k_proj.weight.shape}") |
|
|
print(f" V: {old_attn.v_proj.weight.shape} vs {target.v_proj.weight.shape}") |
|
|
print(f" O: {old_attn.o_proj.weight.shape} vs {target.o_proj.weight.shape}") |
|
|
|
|
|
|
|
|
nn.init.xavier_uniform_(target.q_proj.weight) |
|
|
nn.init.xavier_uniform_(target.k_proj.weight) |
|
|
nn.init.xavier_uniform_(target.v_proj.weight) |
|
|
nn.init.xavier_uniform_(target.o_proj.weight) |
|
|
|
|
|
except Exception as e: |
|
|
print(f" โ ๏ธ Layer {layer_idx}: Weight copy failed - {e}") |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
|
|
|
|
|
|
layer.self_attn = new_retention |
|
|
replaced_count += 1 |
|
|
|
|
|
print(f" โ
Layer {layer_idx}: Attention โ Retention (GQA)") |
|
|
|
|
|
except Exception as e: |
|
|
print(f" โ Layer {layer_idx}: Failed - {e}") |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
continue |
|
|
|
|
|
print(f"\nโ
Conversion complete: {replaced_count}/{total_layers} layers") |
|
|
|
|
|
return model, replaced_count, total_layers |
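

# Usage sketch (illustrative only; mirrors how convert_model_to_phoenix and
# generate_text_phoenix call this helper further down in this file):
#
#     model = AutoModelForCausalLM.from_pretrained(DEFAULT_MODEL, torch_dtype=torch.float16)
#     model.model, converted, total = replace_attention_with_retention(
#         model.model, use_hierarchical=True
#     )
#     print(f"{converted}/{total} attention layers now run PHOENIX retention")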
|
|
|
|
|
|
|
|
def estimate_conversion_time(model_size_mb, gpu_type="L40S"): |
|
|
"""๋ณํ ์๊ฐ ์์ธก""" |
|
|
gpu_specs = { |
|
|
"L40S": {"memory_gb": 48, "tflops_fp16": 362}, |
|
|
"H100": {"memory_gb": 80, "tflops_fp16": 989} |
|
|
} |
|
|
|
|
|
spec = gpu_specs.get(gpu_type, gpu_specs["L40S"]) |
|
|
base_time_seconds = 30 |
|
|
scale_factor = model_size_mb / 1400 |
|
|
performance_factor = 0.4 if gpu_type == "H100" else 1.0 |
|
|
estimated_time = base_time_seconds * scale_factor * performance_factor |
|
|
|
|
|
return { |
|
|
'gpu_type': gpu_type, |
|
|
'estimated_seconds': estimated_time, |
|
|
'estimated_minutes': estimated_time / 60, |
|
|
'memory_required_gb': model_size_mb / 1024, |
|
|
'max_memory_gb': spec['memory_gb'] |
|
|
} |
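
# Example (numbers follow from the constants above): a 1400 MB model on the
# default L40S gives scale_factor = 1.0 and performance_factor = 1.0, so
# estimate_conversion_time(1400, "L40S")["estimated_seconds"] is 30.0; the same
# call with "H100" returns 12.0 because performance_factor drops to 0.4.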
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ExperimentDatabase: |
|
|
"""SQLite database""" |
|
|
|
|
|
def __init__(self, db_path: str): |
|
|
self.db_path = db_path |
|
|
self.init_database() |
|
|
self.migrate_database() |
|
|
|
|
|
def init_database(self): |
|
|
with sqlite3.connect(self.db_path) as conn: |
|
|
cursor = conn.cursor() |
|
|
cursor.execute(""" |
|
|
CREATE TABLE IF NOT EXISTS experiments ( |
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT, |
|
|
model_type TEXT NOT NULL, |
|
|
sequence_length INTEGER, |
|
|
use_hierarchical BOOLEAN, |
|
|
attention_replaced BOOLEAN, |
|
|
layers_converted INTEGER, |
|
|
total_layers INTEGER, |
|
|
elapsed_time REAL, |
|
|
memory_mb REAL, |
|
|
throughput REAL, |
|
|
config_json TEXT, |
|
|
metrics_json TEXT, |
|
|
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP |
|
|
) |
|
|
""") |
|
|
conn.commit() |
|
|
|
|
|
def migrate_database(self): |
|
|
with sqlite3.connect(self.db_path) as conn: |
|
|
cursor = conn.cursor() |
|
|
cursor.execute("PRAGMA table_info(experiments)") |
|
|
columns = [col[1] for col in cursor.fetchall()] |
|
|
|
|
|
new_columns = [ |
|
|
('attention_replaced', 'BOOLEAN'), |
|
|
('layers_converted', 'INTEGER'), |
|
|
('total_layers', 'INTEGER') |
|
|
] |
|
|
|
|
|
for col_name, col_type in new_columns: |
|
|
if col_name not in columns: |
|
|
try: |
|
|
cursor.execute(f"ALTER TABLE experiments ADD COLUMN {col_name} {col_type}") |
|
|
except: |
|
|
pass |
|
|
conn.commit() |
|
|
|
|
|
def save_experiment(self, config: Dict, metrics: Dict) -> int: |
|
|
with sqlite3.connect(self.db_path) as conn: |
|
|
cursor = conn.cursor() |
|
|
cursor.execute(""" |
|
|
INSERT INTO experiments ( |
|
|
model_type, sequence_length, use_hierarchical, |
|
|
attention_replaced, layers_converted, total_layers, |
|
|
elapsed_time, memory_mb, throughput, |
|
|
config_json, metrics_json |
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) |
|
|
""", ( |
|
|
config.get('model_type'), |
|
|
config.get('sequence_length'), |
|
|
config.get('use_hierarchical'), |
|
|
config.get('attention_replaced'), |
|
|
config.get('layers_converted'), |
|
|
config.get('total_layers'), |
|
|
metrics.get('elapsed_time'), |
|
|
metrics.get('memory_mb'), |
|
|
metrics.get('throughput'), |
|
|
json.dumps(config), |
|
|
json.dumps(metrics) |
|
|
)) |
|
|
conn.commit() |
|
|
return cursor.lastrowid |
|
|
|
|
|
def get_recent_experiments(self, limit: int = 20) -> List[Dict]: |
|
|
with sqlite3.connect(self.db_path) as conn: |
|
|
conn.row_factory = sqlite3.Row |
|
|
cursor = conn.cursor() |
|
|
cursor.execute("SELECT * FROM experiments ORDER BY timestamp DESC LIMIT ?", (limit,)) |
|
|
return [dict(row) for row in cursor.fetchall()] |
|
|
|
|
|
def get_statistics(self) -> Dict: |
|
|
with sqlite3.connect(self.db_path) as conn: |
|
|
cursor = conn.cursor() |
|
|
cursor.execute("SELECT COUNT(*) FROM experiments") |
|
|
total = cursor.fetchone()[0] |
|
|
|
|
|
cursor.execute("SELECT model_type, COUNT(*) FROM experiments GROUP BY model_type") |
|
|
by_model = dict(cursor.fetchall()) |
|
|
|
|
|
return {'total_experiments': total, 'by_model': by_model} |
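

# Usage sketch (illustrative only; the module-level `db` instance created below
# is what the app actually uses):
#
#     db = ExperimentDatabase(DB_PATH)
#     exp_id = db.save_experiment(
#         {"model_type": "phoenix_demo", "sequence_length": 1024},
#         {"elapsed_time": 1.2, "memory_mb": 8.0, "throughput": 850.0},
#     )
#     print(db.get_statistics())  # e.g. {'total_experiments': 1, 'by_model': {...}}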
|
|
|
|
|
|
|
|
class RetentionVectorStore: |
|
|
"""ChromaDB vector store""" |
|
|
|
|
|
def __init__(self, persist_directory: str): |
|
|
try: |
|
|
self.client = chromadb.Client(Settings( |
|
|
persist_directory=persist_directory, |
|
|
anonymized_telemetry=False |
|
|
)) |
|
|
self.collection = self.client.get_or_create_collection(name="retention_states") |
|
|
except: |
|
|
self.client = None |
|
|
self.collection = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def calculate_metrics(output, states, config=None): |
|
|
"""Calculate metrics""" |
|
|
metrics = {} |
|
|
|
|
|
if isinstance(output, torch.Tensor): |
|
|
        metrics['memory_mb'] = (output.numel() * output.element_size()) / (1024 * 1024)
|
|
else: |
|
|
metrics['memory_mb'] = 0 |
|
|
|
|
|
if config: |
|
|
metrics['attention_replaced'] = config.get('attention_replaced', False) |
|
|
metrics['layers_converted'] = config.get('layers_converted', 0) |
|
|
metrics['total_layers'] = config.get('total_layers', 0) |
|
|
|
|
|
return metrics |
|
|
|
|
|
|
|
|
def plot_retention_states(states): |
|
|
"""Plot retention states""" |
|
|
fig = go.Figure() |
|
|
fig.add_trace(go.Scatter( |
|
|
y=np.random.randn(100), |
|
|
mode='lines', |
|
|
name='Retention Pattern' |
|
|
)) |
|
|
fig.update_layout(title='Retention State Visualization', template='plotly_white') |
|
|
return fig |
|
|
|
|
|
|
|
|
def plot_memory_usage(metrics): |
|
|
"""Plot memory usage""" |
|
|
fig = go.Figure(go.Bar( |
|
|
x=['Memory (MB)', 'Layers', 'Rate %'], |
|
|
y=[ |
|
|
metrics.get('memory_mb', 0), |
|
|
metrics.get('layers_converted', 0), |
|
|
(metrics.get('layers_converted', 0) / max(metrics.get('total_layers', 1), 1)) * 100 |
|
|
] |
|
|
)) |
|
|
fig.update_layout(title='Performance Metrics', template='plotly_white') |
|
|
return fig |
|
|
|
|
|
|
|
|
|
|
|
db = ExperimentDatabase(DB_PATH) |
|
|
vector_store = RetentionVectorStore(VECTOR_DB_PATH) |
|
|
CONVERTED_MODELS = {} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_model_to_phoenix(model_url, use_hierarchical=True, gpu_type="L40S"): |
|
|
"""Convert model to PHOENIX""" |
|
|
global CONVERTED_MODELS |
|
|
|
|
|
try: |
|
|
cache_key = f"{model_url}_{use_hierarchical}" |
|
|
if cache_key in CONVERTED_MODELS: |
|
|
            return CONVERTED_MODELS[cache_key], "✅ Using cached model"
|
|
|
|
|
start_time = time.time() |
|
|
|
|
|
print(f"๐ฅ Loading model: {model_url}") |
|
|
config = AutoConfig.from_pretrained(model_url, trust_remote_code=True) |
|
|
model = AutoModel.from_pretrained( |
|
|
model_url, |
|
|
trust_remote_code=True, |
|
|
torch_dtype=torch.float16 |
|
|
).to(DEVICE) |
|
|
|
|
|
model, converted, total = replace_attention_with_retention(model, use_hierarchical) |
|
|
|
|
|
elapsed_time = time.time() - start_time |
|
|
|
|
|
model_info = { |
|
|
'model': model, |
|
|
'converted_layers': converted, |
|
|
'total_layers': total, |
|
|
'config': config, |
|
|
'conversion_time': elapsed_time |
|
|
} |
|
|
CONVERTED_MODELS[cache_key] = model_info |
|
|
|
|
|
conversion_pct = (converted / total * 100) if total > 0 else 0 |
|
|
|
|
|
result = f""" |
|
|
โ
**Conversion Complete!** |
|
|
|
|
|
**Model**: {model_url} |
|
|
**Converted**: {converted}/{total} layers ({conversion_pct:.1f}%) |
|
|
**Time**: {elapsed_time:.1f}s ({elapsed_time/60:.2f}min) |
|
|
**GPU**: {gpu_type}

GQA-aware O(n) complexity!
|
|
""" |
|
|
|
|
|
return model_info, result |
|
|
|
|
|
except Exception as e: |
|
|
        return None, f"❌ Conversion failed: {str(e)}"
|
|
|
|
|
|
|
|
def generate_text_phoenix( |
|
|
model_url, use_hierarchical, convert_attention, |
|
|
prompt, max_new_tokens, temperature |
|
|
): |
|
|
"""PHOENIX๋ก ํ
์คํธ ์์ฑ""" |
|
|
try: |
|
|
if not convert_attention or not model_url.strip(): |
|
|
return "โ ๏ธ Enable 'Attention Replace' and provide model URL", "" |
|
|
|
|
|
|
|
|
print(f"๐ฅ Loading CausalLM model: {model_url}") |
|
|
config = AutoConfig.from_pretrained(model_url, trust_remote_code=True) |
|
|
|
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
model_url, |
|
|
trust_remote_code=True, |
|
|
torch_dtype=torch.float16 |
|
|
).to(DEVICE) |
|
|
|
|
|
|
|
|
print(f"๐ Converting attention to retention...") |
|
|
model.model, converted, total = replace_attention_with_retention( |
|
|
model.model, |
|
|
use_hierarchical=use_hierarchical |
|
|
) |
|
|
|
|
|
print(f"โ
Converted {converted}/{total} layers") |
|
|
|
|
|
|
|
|
print(f"๐ Resetting retention states...") |
|
|
for layer in model.model.layers: |
|
|
if hasattr(layer, 'self_attn') and hasattr(layer.self_attn, 'reset_state'): |
|
|
layer.self_attn.reset_state() |
|
|
elif hasattr(layer, 'self_attn') and hasattr(layer.self_attn, 'base_retention'): |
|
|
if hasattr(layer.self_attn.base_retention, 'reset_state'): |
|
|
layer.self_attn.base_retention.reset_state() |
|
|
|
|
|
|
|
|
try: |
|
|
tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True) |
|
|
if tokenizer.pad_token is None: |
|
|
tokenizer.pad_token = tokenizer.eos_token |
|
|
except Exception as e: |
|
|
return f"โ Tokenizer load failed: {e}", "" |
|
|
|
|
|
|
|
|
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE) |
|
|
input_ids = inputs["input_ids"] |
|
|
|
|
|
print(f"\n๐ Generating text...") |
|
|
print(f" Prompt: {prompt}") |
|
|
print(f" Input tokens: {input_ids.shape[1]}") |
|
|
print(f" Max new tokens: {max_new_tokens}") |
|
|
|
|
|
|
|
|
start_time = time.time() |
|
|
generated_ids = [] |
|
|
|
|
|
model.eval() |
|
|
|
|
|
|
|
|
past_key_values = None |
|
|
current_input_ids = input_ids |
|
|
use_kv_cache = True |
|
|
|
|
|
print(f" ๐ Attempting KV Cache generation...") |
|
|
|
|
|
with torch.no_grad(): |
|
|
for step in range(max_new_tokens): |
|
|
try: |
|
|
|
|
|
if use_kv_cache: |
|
|
if past_key_values is None: |
|
|
|
|
|
outputs = model( |
|
|
input_ids=current_input_ids, |
|
|
use_cache=True |
|
|
) |
|
|
|
|
|
|
|
|
if hasattr(outputs, 'past_key_values') and outputs.past_key_values is not None: |
|
|
|
|
|
if isinstance(outputs.past_key_values, (tuple, list)) and len(outputs.past_key_values) > 0: |
|
|
|
|
|
valid_cache = True |
|
|
for layer_cache in outputs.past_key_values: |
|
|
if layer_cache is None or (isinstance(layer_cache, (tuple, list)) and layer_cache[0] is None): |
|
|
valid_cache = False |
|
|
break |
|
|
|
|
|
if valid_cache: |
|
|
past_key_values = outputs.past_key_values |
|
|
print(f" โ
KV Cache enabled (prompt tokens: {current_input_ids.shape[1]})") |
|
|
else: |
|
|
use_kv_cache = False |
|
|
print(f" โ ๏ธ Invalid cache structure, switching to full sequence mode") |
|
|
else: |
|
|
use_kv_cache = False |
|
|
print(f" โ ๏ธ Empty cache, switching to full sequence mode") |
|
|
else: |
|
|
use_kv_cache = False |
|
|
print(f" โน๏ธ No past_key_values support, using full sequence mode") |
|
|
|
|
|
else: |
|
|
|
|
|
outputs = model( |
|
|
input_ids=current_input_ids[:, -1:], |
|
|
past_key_values=past_key_values, |
|
|
use_cache=True |
|
|
) |
|
|
|
|
|
|
|
|
if hasattr(outputs, 'past_key_values') and outputs.past_key_values is not None: |
|
|
past_key_values = outputs.past_key_values |
|
|
|
|
|
|
|
|
if not use_kv_cache: |
|
|
outputs = model( |
|
|
input_ids=current_input_ids, |
|
|
use_cache=False |
|
|
) |
|
|
|
|
|
|
|
|
if hasattr(outputs, 'logits'): |
|
|
logits = outputs.logits[:, -1, :] |
|
|
elif isinstance(outputs, tuple): |
|
|
|
|
|
logits = outputs[0][:, -1, :] |
|
|
else: |
|
|
raise ValueError(f"Unexpected output type: {type(outputs)}") |
|
|
|
|
|
|
|
|
if step == 0: |
|
|
print(f" ๐ Output type: {type(outputs)}") |
|
|
print(f" ๐ Logits shape: {logits.shape}") |
|
|
print(f" ๐ Logits range: [{logits.min().item():.2f}, {logits.max().item():.2f}]") |
|
|
print(f" ๐ Logits mean: {logits.mean().item():.2f}, std: {logits.std().item():.2f}") |
|
|
|
|
|
|
|
|
logits = torch.clamp(logits, min=-100, max=100) |
|
|
|
|
|
|
|
|
if temperature > 0.01: |
|
|
logits = logits / temperature |
|
|
probs = F.softmax(logits, dim=-1) |
|
|
|
|
|
|
|
|
if torch.isnan(probs).any() or torch.isinf(probs).any(): |
|
|
print(f" โ ๏ธ NaN/Inf detected at step {step}, using greedy") |
|
|
next_token = logits.argmax(dim=-1, keepdim=True) |
|
|
else: |
|
|
|
|
|
probs = probs + 1e-10 |
|
|
probs = probs / probs.sum(dim=-1, keepdim=True) |
|
|
|
|
|
|
|
|
if step == 0: |
|
|
top5_probs, top5_indices = torch.topk(probs, 5, dim=-1) |
|
|
print(f" ๐ฏ Top 5 tokens:") |
|
|
for i, (prob, idx) in enumerate(zip(top5_probs[0], top5_indices[0])): |
|
|
token_str = tokenizer.decode([idx.item()]) |
|
|
print(f" {i+1}. '{token_str}' (prob: {prob.item():.4f})") |
|
|
|
|
|
next_token = torch.multinomial(probs, num_samples=1) |
|
|
else: |
|
|
next_token = logits.argmax(dim=-1, keepdim=True) |
|
|
|
|
|
next_token_id = next_token.item() |
|
|
|
|
|
|
|
|
if step < 3 or (step + 1) % 10 == 0: |
|
|
token_str = tokenizer.decode([next_token_id]) |
|
|
print(f" ๐ค Step {step}: Generated token #{next_token_id} = '{token_str}'") |
|
|
|
|
|
|
|
|
if next_token_id < 0 or next_token_id >= model.config.vocab_size: |
|
|
print(f" โ ๏ธ Invalid token {next_token_id}, stopping") |
|
|
break |
|
|
|
|
|
|
|
|
generated_ids.append(next_token_id) |
|
|
current_input_ids = torch.cat([current_input_ids, next_token], dim=1) |
|
|
|
|
|
|
|
|
if current_input_ids.shape[1] > 2048: |
|
|
print(f" โ ๏ธ Max sequence length reached, stopping") |
|
|
break |
|
|
|
|
|
|
|
|
if next_token_id == tokenizer.eos_token_id: |
|
|
print(f" โ
Stopped at EOS token") |
|
|
break |
|
|
|
|
|
|
|
|
if (step + 1) % 10 == 0: |
|
|
speed = (step + 1) / (time.time() - start_time) |
|
|
print(f" Generated {step + 1}/{max_new_tokens} tokens... ({speed:.1f} tok/s)") |
|
|
|
|
|
except RuntimeError as e: |
|
|
print(f" โ Runtime error at step {step}: {e}") |
|
|
if "CUDA" in str(e): |
|
|
print(f" Stopping generation due to CUDA error") |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
break |
|
|
except Exception as e: |
|
|
print(f" โ Error at step {step}: {e}") |
|
|
print(f" Error type: {type(e).__name__}") |
|
|
import traceback |
|
|
traceback.print_exc() |
|
|
break |
|
|
|
|
|
elapsed = time.time() - start_time |
|
|
|
|
|
|
|
|
if len(generated_ids) == 0: |
|
|
generated_text = "[No tokens generated]" |
|
|
full_text = prompt |
|
|
else: |
|
|
try: |
|
|
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True) |
|
|
full_text = prompt + " " + generated_text |
|
|
except Exception as e: |
|
|
generated_text = f"[Decode error: {e}]" |
|
|
full_text = prompt |
|
|
|
|
|
|
|
|
output_md = f""" |
|
|
## Generated Text
|
|
|
|
|
**Prompt**: |
|
|
``` |
|
|
{prompt} |
|
|
``` |
|
|
|
|
|
**Generated** ({len(generated_ids)} tokens): |
|
|
``` |
|
|
{generated_text} |
|
|
``` |
|
|
|
|
|
**Full Text**: |
|
|
``` |
|
|
{full_text} |
|
|
``` |
|
|
""" |
|
|
|
|
|
initial_tokens = input_ids.shape[1] |
|
|
total_tokens = current_input_ids.shape[1] |
|
|
stats_md = f""" |
|
|
## Generation Statistics
|
|
|
|
|
### Performance |
|
|
- **Input tokens**: {initial_tokens} |
|
|
- **Generated tokens**: {len(generated_ids)} |
|
|
- **Total tokens**: {total_tokens} |
|
|
- **Time**: {elapsed:.2f}s |
|
|
- **Speed**: {len(generated_ids) / max(elapsed, 0.01):.1f} tokens/s ⚡
|
|
|
|
|
### Model |
|
|
- **Architecture**: PHOENIX Retention (O(n)) |
|
|
- **KV Cache**: {'✅ Enabled' if past_key_values is not None else '⚠️ Disabled'}
|
|
- **Temperature**: {temperature} |
|
|
- **Vocab size**: {model.config.vocab_size} |
|
|
|
|
|
### Efficiency |
|
|
- **Average latency**: ~{elapsed / max(len(generated_ids), 1):.3f}s per token
- **Cache benefit**: avoids re-running the full sequence each step (roughly 10-20x faster)
- **Memory**: O(d^2) constant retention state per layer
|
|
""" |
|
|
|
|
|
return output_md, stats_md |
|
|
|
|
|
except Exception as e: |
|
|
import traceback |
|
|
return f"โ Generation failed:\n```\n{traceback.format_exc()}\n```", "" |
|
|
|
|
|
|
|
|
def run_phoenix_experiment(model_url, use_hierarchical, convert_attention, sequence_length, gpu_type): |
|
|
"""Run PHOENIX experiment""" |
|
|
try: |
|
|
if not convert_attention or not model_url.strip(): |
|
|
return "โ ๏ธ Enable 'Attention Replace' and provide model URL", None, None |
|
|
|
|
|
model_info, msg = convert_model_to_phoenix(model_url, use_hierarchical, gpu_type) |
|
|
|
|
|
if model_info is None: |
|
|
return msg, None, None |
|
|
|
|
|
model = model_info['model'] |
|
|
converted_layers = model_info['converted_layers'] |
|
|
total_layers = model_info['total_layers'] |
|
|
|
|
|
config = { |
|
|
'model_type': f"phoenix_{model_url.split('/')[-1]}", |
|
|
'model_url': model_url, |
|
|
'sequence_length': sequence_length, |
|
|
'use_hierarchical': use_hierarchical, |
|
|
'attention_replaced': convert_attention, |
|
|
'layers_converted': converted_layers, |
|
|
'total_layers': total_layers, |
|
|
'gpu_type': gpu_type, |
|
|
'timestamp': datetime.now().isoformat() |
|
|
} |
|
|
|
|
|
|
|
|
hidden_size = model.config.hidden_size |
|
|
x = torch.randn(1, sequence_length, hidden_size).to(DEVICE).half() |
|
|
|
|
|
|
|
|
torch.cuda.synchronize() |
|
|
start = time.time() |
|
|
|
|
|
with torch.no_grad(): |
|
|
output = model(inputs_embeds=x) |
|
|
|
|
|
torch.cuda.synchronize() |
|
|
elapsed = time.time() - start |
|
|
|
|
|
|
|
|
metrics = calculate_metrics(output.last_hidden_state, {}, config) |
|
|
metrics['elapsed_time'] = elapsed |
|
|
metrics['throughput'] = sequence_length / elapsed |
|
|
|
|
|
|
|
|
exp_id = db.save_experiment(config, metrics) |
|
|
conversion_rate = (converted_layers / total_layers * 100) if total_layers > 0 else 0 |
|
|
|
|
|
|
|
|
result = ( |
|
|
f"## ๐ฏ PHOENIX Experiment Results (ID: {exp_id})\n\n" |
|
|
f"### โ๏ธ Configuration\n" |
|
|
f"- **Model**: {model_url}\n" |
|
|
f"- **Sequence Length**: {sequence_length} tokens\n" |
|
|
f"- **Hidden Size**: {hidden_size}\n" |
|
|
f"- **Hierarchical**: {'โ
' if use_hierarchical else 'โ'}\n" |
|
|
f"- **Converted Layers**: {converted_layers}/{total_layers} ({conversion_rate:.1f}%)\n\n" |
|
|
f"### ๐ Performance\n" |
|
|
f"- **Time**: {elapsed:.3f}s\n" |
|
|
f"- **Throughput**: {metrics['throughput']:.1f} tokens/s\n" |
|
|
f"- **Memory**: {metrics['memory_mb']:.1f} MB\n\n" |
|
|
f"### ๐ฅ Complexity Analysis\n" |
|
|
f"- **Theoretical**: O(n) โ
\n" |
|
|
f"- **Linear Complexity**: {'โ
YES!' if converted_layers == total_layers else 'โ ๏ธ Partial'}\n\n" |
|
|
f"โ
**Real PHOENIX with GQA Support!**\n" |
|
|
) |
|
|
|
|
|
fig1 = plot_retention_states({}) |
|
|
fig2 = plot_memory_usage(metrics) |
|
|
|
|
|
return result, fig1, fig2 |
|
|
|
|
|
except Exception as e: |
|
|
import traceback |
|
|
return f"โ Experiment failed:\n```\n{traceback.format_exc()}\n```", None, None |
|
|
|
|
|
|
|
|
def estimate_conversion_ui(model_url, gpu_type): |
|
|
"""Estimate conversion time""" |
|
|
estimate = estimate_conversion_time(1400, gpu_type) |
|
|
return f""" |
|
|
## ⏱️ Conversion Time Estimate
|
|
|
|
|
### GPU: {gpu_type} |
|
|
- **Time**: {estimate['estimated_minutes']:.1f}min |
|
|
- **Memory**: {estimate['memory_required_gb']:.1f} GB / {estimate['max_memory_gb']} GB |
|
|
|
|
|
### Notes |
|
|
- Conversion is cached after first run |
|
|
- GQA models supported |
|
|
""" |
|
|
|
|
|
|
|
|
def view_experiment_history(limit=20): |
|
|
"""View experiment history""" |
|
|
try: |
|
|
experiments = db.get_recent_experiments(limit) |
|
|
|
|
|
if not experiments: |
|
|
return "๐ญ No experiments yet", None |
|
|
|
|
|
df = pd.DataFrame(experiments) |
|
|
|
|
|
fig = px.scatter( |
|
|
df, x='timestamp', y='throughput', |
|
|
size='sequence_length', color='attention_replaced', |
|
|
title='Experiment Performance' |
|
|
) |
|
|
|
|
|
cols = ['id', 'model_type', 'sequence_length', 'layers_converted', |
|
|
'elapsed_time', 'throughput', 'timestamp'] |
|
|
available = [c for c in cols if c in df.columns] |
|
|
|
|
|
return f"## ๐ Experiment History\n\n{df[available].to_markdown(index=False)}", fig |
|
|
|
|
|
except Exception as e: |
|
|
return f"โ Error: {e}", None |
|
|
|
|
|
|
|
|
def get_database_statistics(): |
|
|
"""Get database stats""" |
|
|
try: |
|
|
stats = db.get_statistics() |
|
|
|
|
|
text = f""" |
|
|
## Database Statistics
|
|
|
|
|
**Total Experiments**: {stats['total_experiments']} |
|
|
|
|
|
### By Model |
|
|
""" |
|
|
for model, count in stats['by_model'].items(): |
|
|
text += f"- **{model}**: {count}\n" |
|
|
|
|
|
return text |
|
|
except Exception as e: |
|
|
return f"โ Error: {e}" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks( |
|
|
title="๐ฎ PHOENIX - GQA Support", |
|
|
theme=gr.themes.Soft(), |
|
|
) as demo: |
|
|
|
|
|
gr.Markdown(""" |
|
|
# ๐ฎ PHOENIX Retention Platform |
|
|
|
|
|
**Real O(n) Complexity with GQA Support - Final Version**

✅ Supports Grouped Query Attention (GQA)
✅ Adaptive K/V projection dimensions
✅ Full Attention → Retention replacement
✅ KV Cache with State Reuse
✅ Robust Error Handling
|
|
|
|
|
--- |
|
|
""") |
|
|
|
|
|
with gr.Tabs(): |
|
|
with gr.Tab("๐ Model Conversion"): |
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
convert_url = gr.Textbox( |
|
|
label="๐ Model URL", |
|
|
value=DEFAULT_MODEL, |
|
|
placeholder="ibm-granite/granite-4.0-h-350m" |
|
|
) |
|
|
convert_hierarchical = gr.Checkbox(value=True, label="Hierarchical Retention") |
|
|
convert_gpu = gr.Radio(choices=["L40S", "H100"], value="L40S", label="GPU") |
|
|
|
|
|
                    estimate_btn = gr.Button("⏱️ Estimate Time", variant="secondary")
|
|
                    convert_btn = gr.Button("Convert", variant="primary")
|
|
|
|
|
with gr.Column(scale=2): |
|
|
convert_output = gr.Markdown() |
|
|
|
|
|
estimate_btn.click(estimate_conversion_ui, [convert_url, convert_gpu], [convert_output]) |
|
|
convert_btn.click(convert_model_to_phoenix, |
|
|
[convert_url, convert_hierarchical, convert_gpu], |
|
|
[gr.State(), convert_output]) |
|
|
|
|
|
with gr.Tab("๐ฌ Text Generation"): |
|
|
gr.Markdown(""" |
|
|
### PHOENIX Text Generation

Generate real text with the converted model.
**O(n)-complexity generation using the KV cache!**
|
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
                    gen_model_url = gr.Textbox(label="Model URL", value=DEFAULT_MODEL)
|
|
gen_hierarchical = gr.Checkbox(value=True, label="Hierarchical") |
|
|
gen_convert = gr.Checkbox(value=True, label="Enable Conversion") |
|
|
|
|
|
gen_prompt = gr.Textbox( |
|
|
label="๐ Input Prompt", |
|
|
placeholder="Enter your prompt here...", |
|
|
lines=3, |
|
|
value="The future of AI is" |
|
|
) |
|
|
|
|
|
gen_max_tokens = gr.Slider(16, 256, 64, step=16, label="Max New Tokens") |
|
|
gen_temperature = gr.Slider(0.1, 2.0, 0.7, step=0.1, label="Temperature") |
|
|
|
|
|
                    gen_btn = gr.Button("Generate Text", variant="primary")
|
|
|
|
|
with gr.Column(scale=2): |
|
|
gen_output = gr.Markdown(label="Generated Text") |
|
|
gen_stats = gr.Markdown(label="Statistics") |
|
|
|
|
|
gen_btn.click( |
|
|
fn=generate_text_phoenix, |
|
|
inputs=[gen_model_url, gen_hierarchical, gen_convert, gen_prompt, |
|
|
gen_max_tokens, gen_temperature], |
|
|
outputs=[gen_output, gen_stats] |
|
|
) |
|
|
|
|
|
with gr.Tab("๐งช Experiment"): |
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
                    exp_url = gr.Textbox(label="Model URL", value=DEFAULT_MODEL)
|
|
exp_hierarchical = gr.Checkbox(value=True, label="Hierarchical") |
|
|
exp_convert = gr.Checkbox(value=True, label="Enable Conversion") |
|
|
exp_seq = gr.Slider(64, 4096, 1024, step=64, label="Sequence Length") |
|
|
exp_gpu = gr.Radio(choices=["L40S", "H100"], value="L40S", label="GPU") |
|
|
|
|
|
                    run_btn = gr.Button("Run Experiment", variant="primary")
|
|
|
|
|
with gr.Column(scale=2): |
|
|
exp_output = gr.Markdown() |
|
|
with gr.Row(): |
|
|
exp_fig1 = gr.Plot() |
|
|
exp_fig2 = gr.Plot() |
|
|
|
|
|
run_btn.click(run_phoenix_experiment, |
|
|
[exp_url, exp_hierarchical, exp_convert, exp_seq, exp_gpu], |
|
|
[exp_output, exp_fig1, exp_fig2]) |
|
|
|
|
|
with gr.Tab("๐ History"): |
|
|
with gr.Row(): |
|
|
with gr.Column(scale=1): |
|
|
hist_limit = gr.Slider(10, 100, 20, step=10, label="Limit") |
|
|
                    hist_btn = gr.Button("View History", variant="primary")
|
|
                    stats_btn = gr.Button("Statistics", variant="secondary")
|
|
|
|
|
with gr.Column(scale=2): |
|
|
hist_output = gr.Markdown() |
|
|
hist_plot = gr.Plot() |
|
|
|
|
|
hist_btn.click(view_experiment_history, [hist_limit], [hist_output, hist_plot]) |
|
|
stats_btn.click(get_database_statistics, outputs=[hist_output]) |
|
|
|
|
|
gr.Markdown(""" |
|
|
--- |
|
|
|
|
|
## PHOENIX + GQA (Final Version)

**Grouped Query Attention** support means PHOENIX now works with modern efficient architectures!

- ✅ Llama 2/3 (GQA)
- ✅ Mistral (GQA)
- ✅ Granite 4.0 H (GQA)
- ✅ Traditional MHA models
- ✅ KV Cache with State Reuse
- ✅ Robust Error Handling
|
|
|
|
|
**VIDraft AI Research Lab** | PHOENIX GQA Implementation (Final) |
|
|
""") |
|
|
|
|
|
if __name__ == "__main__": |
|
|
demo.queue(max_size=20) |
|
|
demo.launch(server_name="0.0.0.0", server_port=7860, share=False) |