seawolf2357 committed on
Commit ef87883 · verified
1 Parent(s): ec5f981

Update app.py

Files changed (1)
  1. app.py +724 -708
app.py CHANGED
@@ -1,12 +1,12 @@
1
  """
2
- 🔮 PHOENIX Retention Research Platform
3
- Real Implementation - GQA Support (Final Version)
4
 
5
- ✅ Supports Grouped Query Attention (GQA)
6
- ✅ Adaptive K/V projection dimensions
7
- ✅ L40S GPU + Persistent Storage
8
- ✅ KV Cache with State Reuse
9
- ✅ Robust Error Handling
10
 
11
  VIDraft AI Research Lab
12
  """
@@ -27,8 +27,16 @@ import pandas as pd
27
  from typing import Dict, List, Any, Tuple, Optional
28
  import chromadb
29
  from chromadb.config import Settings
30
- from transformers import AutoModel, AutoTokenizer, AutoConfig, AutoModelForCausalLM
31
  import copy
 
32
 
33
  # =====================================================
34
  # Global settings
@@ -38,10 +46,12 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
38
  STORAGE_PATH = "/data"
39
  DB_PATH = f"{STORAGE_PATH}/phoenix_experiments.db"
40
  VECTOR_DB_PATH = f"{STORAGE_PATH}/vector_store"
 
41
  DEFAULT_MODEL = "ibm-granite/granite-4.0-h-350m"
42
 
43
  Path(STORAGE_PATH).mkdir(parents=True, exist_ok=True)
44
  Path(VECTOR_DB_PATH).mkdir(parents=True, exist_ok=True)
 
45
 
46
  print(f"๐Ÿš€ PHOENIX Platform initialized on {DEVICE}")
47
  print(f"๐Ÿ’พ Storage: {STORAGE_PATH}")
@@ -52,13 +62,7 @@ print(f"๐ŸŽฏ Default Base Model: {DEFAULT_MODEL}")
52
  # =====================================================
53
 
54
  class MultiScaleRetention(nn.Module):
55
- """
56
- ์ง„์งœ Retention Attention with GQA Support
57
-
58
- โœ… Supports Grouped Query Attention
59
- โœ… Adaptive K/V dimensions
60
- โœ… KV Cache with State Reuse
61
- """
62
 
63
  def __init__(self, config, layer_idx=0):
64
  super().__init__()
@@ -77,32 +81,21 @@ class MultiScaleRetention(nn.Module):
77
  self.num_key_value_heads = self.num_heads
78
 
79
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
80
- self.kv_head_dim = self.head_dim # Same as Q head_dim
81
  self.kv_dim = self.num_key_value_heads * self.kv_head_dim
82
 
83
- # ✅ Internal state storage for KV cache simulation
84
  self.register_buffer('_internal_state', None, persistent=False)
85
  self.register_buffer('_state_initialized', torch.tensor(False), persistent=False)
86
 
87
- print(f" ๐Ÿ“ Layer {layer_idx} Retention (GQA) initialized:")
88
- print(f" - hidden_size: {self.hidden_size}")
89
- print(f" - num_heads (Q): {self.num_heads}")
90
- print(f" - num_key_value_heads (K/V): {self.num_key_value_heads}")
91
- print(f" - head_dim: {self.head_dim}")
92
- print(f" - kv_dim: {self.kv_dim}")
93
- print(f" - groups: {self.num_key_value_groups}")
94
-
95
- # ✅ Projections with correct dimensions
96
- # Check if model uses expanded projections (like Qwen3)
97
- self.use_expanded_proj = False
98
-
99
  self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
100
- self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False) # GQA!
101
- self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False) # GQA!
102
  self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
103
 
104
  # Retention parameters
105
- decay_values = torch.linspace(0.95, 0.99, self.num_heads) # ✅ higher decay (preserves more information)
106
  self.decay = nn.Parameter(decay_values, requires_grad=True)
107
 
108
  # Group norm
@@ -112,10 +105,7 @@ class MultiScaleRetention(nn.Module):
112
  )
113
 
114
  def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
115
- """
116
- Repeat K/V heads to match Q heads (GQA)
117
- [B, num_kv_heads, seq_len, head_dim] -> [B, num_heads, seq_len, head_dim]
118
- """
119
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
120
  if n_rep == 1:
121
  return hidden_states
@@ -126,7 +116,7 @@ class MultiScaleRetention(nn.Module):
126
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
127
 
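For orientation (an illustrative aside, not part of the commit): the head repetition that `_repeat_kv` performs for GQA is equivalent to `torch.repeat_interleave` along the head axis. A minimal sketch with assumed toy sizes:

```python
import torch

# Assumed toy GQA shapes: 2 K/V heads shared by 8 query heads (n_rep = 4)
B, num_kv_heads, L, D = 1, 2, 5, 16
n_rep = 4
kv = torch.randn(B, num_kv_heads, L, D)

# Each K/V head is duplicated n_rep times so K/V line up with the query heads
repeated = torch.repeat_interleave(kv, n_rep, dim=1)
print(repeated.shape)  # torch.Size([1, 8, 5, 16])
```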
128
  def reset_state(self):
129
- """Reset internal state (call at start of new sequence)"""
130
  self._internal_state = None
131
  self._state_initialized = torch.tensor(False)
132
 
@@ -142,25 +132,22 @@ class MultiScaleRetention(nn.Module):
142
  past_key_values: Optional[Tuple[torch.Tensor]] = None,
143
  **kwargs
144
  ):
145
- """
146
- O(n) Retention with GQA support
147
- """
148
  batch_size, seq_len, _ = hidden_states.shape
149
 
150
  if past_key_values is not None:
151
  past_key_value = past_key_values
152
 
153
  # Q, K, V projections
154
- query_states = self.q_proj(hidden_states) # [B, L, hidden_size]
155
- key_states = self.k_proj(hidden_states) # [B, L, kv_dim]
156
- value_states = self.v_proj(hidden_states) # [B, L, kv_dim]
157
 
158
- # Reshape Q: [B, L, hidden_size] -> [B, num_heads, L, head_dim]
159
  query_states = query_states.view(
160
  batch_size, seq_len, self.num_heads, self.head_dim
161
  ).transpose(1, 2)
162
 
163
- # Reshape K/V: [B, L, kv_dim] -> [B, num_kv_heads, L, kv_head_dim]
164
  key_states = key_states.view(
165
  batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim
166
  ).transpose(1, 2)
@@ -169,30 +156,28 @@ class MultiScaleRetention(nn.Module):
169
  batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim
170
  ).transpose(1, 2)
171
 
172
- # ✅ Repeat K/V to match Q heads (GQA)
173
  key_states = self._repeat_kv(key_states, self.num_key_value_groups)
174
  value_states = self._repeat_kv(value_states, self.num_key_value_groups)
175
 
176
- # Now all have shape [B, num_heads, L, head_dim]
177
-
178
- # Retention computation with internal state
179
  past_state = self._internal_state if (use_cache and self._state_initialized) else None
180
  retention_states, new_state = self._compute_retention(
181
  query_states, key_states, value_states, past_state
182
  )
183
 
184
- # ✅ Store state internally for next iteration
185
  if use_cache:
186
  self._internal_state = new_state.detach()
187
  self._state_initialized = torch.tensor(True)
188
 
189
- # Reshape back: [B, num_heads, L, head_dim] -> [B, L, hidden_size]
190
  retention_states = retention_states.transpose(1, 2).contiguous()
191
  retention_states = retention_states.reshape(
192
  batch_size, seq_len, self.hidden_size
193
  )
194
 
195
- # ✅ Group norm - ensure it's on the correct device AND dtype
196
  if not next(self.group_norm.parameters()).is_cuda and retention_states.is_cuda:
197
  self.group_norm = self.group_norm.to(retention_states.device, dtype=retention_states.dtype)
198
  elif next(self.group_norm.parameters()).dtype != retention_states.dtype:
@@ -202,90 +187,60 @@ class MultiScaleRetention(nn.Module):
202
  retention_states.transpose(1, 2)
203
  ).transpose(1, 2)
204
 
205
- # ✅ Additional stabilization: clip extreme values
206
  retention_states = torch.clamp(retention_states, min=-10.0, max=10.0)
207
 
208
  # Output projection
209
  attn_output = self.o_proj(retention_states)
210
 
211
- # ✅ Return format for compatibility
212
- # Granite expects: (hidden_states, attn_weights)
213
- # We return: (output, None) - no past_key_values in return signature
214
- # State is stored internally but not returned
215
  return (attn_output, None)
216
 
217
  def _compute_retention(
218
  self,
219
- queries: torch.Tensor, # [B, H, L, D]
220
- keys: torch.Tensor, # [B, H, L, D]
221
- values: torch.Tensor, # [B, H, L, D]
222
  past_state: Optional[torch.Tensor] = None
223
  ):
224
- """
225
- O(n) Retention computation with KV cache support
226
-
227
- Args:
228
- past_state: Previous retention state [B, H, D, D]
229
-
230
- Returns:
231
- output: [B, H, L, D]
232
- new_state: Updated state [B, H, D, D]
233
- """
234
  batch_size, num_heads, seq_len, head_dim = queries.shape
235
 
236
- # ✅ State initialization with correct dtype and device
237
  if past_state is not None:
238
  state = past_state.to(queries.device, dtype=queries.dtype)
239
  else:
240
- # ✅ Initialize with a small value (more stable than exact zeros)
241
  state = torch.zeros(
242
  batch_size, num_heads, head_dim, head_dim,
243
  dtype=queries.dtype,
244
  device=queries.device
245
- ) + 1e-6 # Small epsilon for stability
246
 
247
  outputs = []
248
 
249
- # ✅ Move decay to the same device/dtype as the input
250
  decay = torch.sigmoid(self.decay).view(1, -1, 1, 1).to(
251
  device=queries.device,
252
  dtype=queries.dtype
253
  )
254
 
255
- # Sequential processing (O(n))
256
  for t in range(seq_len):
257
- q_t = queries[:, :, t, :] # [B, H, D]
258
- k_t = keys[:, :, t, :] # [B, H, D]
259
- v_t = values[:, :, t, :] # [B, H, D]
260
 
261
- # Decay application
262
  state = decay * state
263
-
264
- # State update: S = decay * S + k @ v^T
265
  kv_update = torch.einsum('bhd,bhe->bhde', k_t, v_t)
266
-
267
- # ✅ Clip update to prevent explosion
268
  kv_update = torch.clamp(kv_update, min=-5.0, max=5.0)
269
-
270
  state = state + kv_update
271
-
272
- # ✅ Clip state to maintain stability
273
  state = torch.clamp(state, min=-10.0, max=10.0)
274
 
275
- # Output: q @ S
276
  output_t = torch.einsum('bhd,bhde->bhe', q_t, state)
277
  outputs.append(output_t)
278
 
279
- output = torch.stack(outputs, dim=2) # [B, H, L, D]
280
 
281
- # ✅ Return both output and updated state
282
  return output, state
283
 
284
 
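For readers skimming the diff, here is a self-contained sketch (illustrative only, with assumed toy sizes) of the per-token recurrence that `_compute_retention` runs: the state update S_t = d · S_{t-1} + k_t v_tᵀ followed by the readout o_t = q_t S_t, which is what gives the O(n) cost with an O(d²) state per head.

```python
import torch

B, H, L, D = 1, 2, 4, 8                                   # toy batch, heads, length, head dim
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))
decay = torch.sigmoid(torch.linspace(0.95, 0.99, H)).view(1, H, 1, 1)  # per-head decay, as in the module

state = torch.zeros(B, H, D, D)                           # recurrent state, constant size per head
outs = []
for t in range(L):
    state = decay * state + torch.einsum('bhd,bhe->bhde', k[:, :, t], v[:, :, t])
    outs.append(torch.einsum('bhd,bhde->bhe', q[:, :, t], state))
out = torch.stack(outs, dim=2)                            # [B, H, L, D]
```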
285
  class HierarchicalRetention(nn.Module):
286
- """
287
- PHOENIX Hierarchical Retention with GQA
288
- """
289
 
290
  def __init__(self, config, layer_idx=0):
291
  super().__init__()
@@ -294,21 +249,17 @@ class HierarchicalRetention(nn.Module):
294
  hidden_size = config.hidden_size
295
  self.d_state = hidden_size // 2
296
 
297
- # 3-tier hierarchical states
298
  self.short_proj = nn.Linear(hidden_size, self.d_state)
299
  self.medium_proj = nn.Linear(self.d_state, self.d_state)
300
  self.long_proj = nn.Linear(self.d_state, self.d_state * 2)
301
  self.fusion = nn.Linear(self.d_state * 4, hidden_size)
302
 
303
- # Decay rates
304
  self.short_decay = 0.5
305
  self.medium_decay = 0.8
306
  self.long_decay = 0.95
307
 
308
- # Layer norm
309
  self.norm = nn.LayerNorm(hidden_size)
310
 
311
- # ✅ CRITICAL: Move all submodules to same device as base_retention
312
  if next(self.base_retention.parameters()).is_cuda:
313
  device = next(self.base_retention.parameters()).device
314
  dtype = next(self.base_retention.parameters()).dtype
@@ -336,7 +287,6 @@ class HierarchicalRetention(nn.Module):
336
  if past_key_values is not None:
337
  past_key_value = past_key_values
338
 
339
- # ✅ Ensure all submodules are on correct device AND dtype
340
  target_device = hidden_states.device
341
  target_dtype = hidden_states.dtype
342
 
@@ -353,14 +303,12 @@ class HierarchicalRetention(nn.Module):
353
  self.fusion = self.fusion.to(dtype=target_dtype)
354
  self.norm = self.norm.to(dtype=target_dtype)
355
 
356
- # ✅ Base Retention - now always returns 3 values
357
  base_result = self.base_retention(
358
  hidden_states, attention_mask, position_ids,
359
  past_key_value, output_attentions, use_cache
360
  )
361
 
362
  retention_output = base_result[0]
363
- new_state = base_result[2] if len(base_result) > 2 else None
364
 
365
  # Hierarchical states
366
  short_state = torch.zeros(batch_size, self.d_state, dtype=hidden_states.dtype, device=target_device)
@@ -372,21 +320,17 @@ class HierarchicalRetention(nn.Module):
372
  for t in range(seq_len):
373
  x_t = retention_output[:, t, :]
374
 
375
- # Short-term
376
  short_input = self.short_proj(x_t)
377
  short_state = self.short_decay * short_state + short_input
378
 
379
- # Medium-term (every 8 tokens)
380
  if t % 8 == 0:
381
  medium_state = self.medium_decay * medium_state + \
382
  self.medium_proj(short_state)
383
 
384
- # Long-term (every 64 tokens)
385
  if t % 64 == 0:
386
  long_state = self.long_decay * long_state + \
387
  self.long_proj(medium_state)
388
 
389
- # Fusion
390
  combined = torch.cat([short_state, medium_state, long_state], dim=-1)
391
  output_t = self.fusion(combined)
392
  hierarchical_outputs.append(output_t)
@@ -394,8 +338,6 @@ class HierarchicalRetention(nn.Module):
394
  output = torch.stack(hierarchical_outputs, dim=1)
395
  output = self.norm(output)
396
 
397
- # ✅ Return format for compatibility with Granite
398
- # Granite expects: (hidden_states, attn_weights)
399
  return (output, None)
400
 
401
 
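A small dimensional sanity check (illustrative, not part of the commit) for the hierarchy above: `short_state` and `medium_state` are `d_state` wide, `long_state` is `d_state * 2` (refreshed every 8 and 64 tokens respectively), so the fused vector is `d_state * 4`, which is exactly what `self.fusion` expects.

```python
import torch
import torch.nn as nn

hidden_size = 768                      # assumed model width for illustration
d_state = hidden_size // 2

short_proj  = nn.Linear(hidden_size, d_state)
medium_proj = nn.Linear(d_state, d_state)
long_proj   = nn.Linear(d_state, d_state * 2)
fusion      = nn.Linear(d_state * 4, hidden_size)

x_t = torch.randn(1, hidden_size)
short_state  = short_proj(x_t)                          # [1, d_state], updated every token
medium_state = medium_proj(short_state)                 # [1, d_state], refreshed every 8 tokens
long_state   = long_proj(medium_state)                  # [1, 2 * d_state], refreshed every 64 tokens
combined = torch.cat([short_state, medium_state, long_state], dim=-1)  # [1, 4 * d_state]
print(fusion(combined).shape)                           # torch.Size([1, 768])
```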
@@ -404,15 +346,12 @@ class HierarchicalRetention(nn.Module):
404
  # =====================================================
405
 
406
  def replace_attention_with_retention(model, use_hierarchical=True):
407
- """
408
- Transformer Attention → PHOENIX Retention (GQA Support)
409
- """
410
  print("๐Ÿ”„ Starting Attention โ†’ Retention conversion (GQA support)...")
411
 
412
  replaced_count = 0
413
  total_layers = 0
414
 
415
- # Layer structure
416
  if hasattr(model, 'transformer'):
417
  layers = model.transformer.h
418
  elif hasattr(model, 'model') and hasattr(model.model, 'layers'):
@@ -425,35 +364,26 @@ def replace_attention_with_retention(model, use_hierarchical=True):
425
 
426
  total_layers = len(layers)
427
 
428
- # Check first layer for dimensions
429
  first_layer = layers[0]
430
  if hasattr(first_layer, 'self_attn'):
431
  old_attn = first_layer.self_attn
432
 
433
- print(f"\n๐Ÿ“ Detected attention structure:")
434
  if hasattr(old_attn, 'q_proj'):
435
  q_shape = old_attn.q_proj.weight.shape
436
  k_shape = old_attn.k_proj.weight.shape
437
- v_shape = old_attn.v_proj.weight.shape
438
-
439
- print(f" - Q projection: {q_shape}")
440
- print(f" - K projection: {k_shape}")
441
- print(f" - V projection: {v_shape}")
442
 
443
  if k_shape[0] != q_shape[0]:
444
  print(f" โœ… GQA detected! (K/V dim: {k_shape[0]} < Q dim: {q_shape[0]})")
445
- # Update config for GQA
446
  if not hasattr(model.config, 'num_key_value_heads'):
447
  num_kv_heads = k_shape[0] // (model.config.hidden_size // model.config.num_attention_heads)
448
  model.config.num_key_value_heads = num_kv_heads
449
- print(f" ๐Ÿ”ง Set num_key_value_heads = {num_kv_heads}")
450
 
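To make the head-count arithmetic above concrete, a worked example with assumed sizes (not values taken from the diff):

```python
# Worked example of the GQA inference above (assumed sizes, illustrative only)
hidden_size, num_attention_heads = 1024, 16
head_dim = hidden_size // num_attention_heads        # 64
k_out_dim = 256                                      # rows of k_proj.weight
num_key_value_heads = k_out_dim // head_dim          # 4 K/V heads
num_key_value_groups = num_attention_heads // num_key_value_heads  # 4 query heads per K/V head
```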
451
  for layer_idx, layer in enumerate(layers):
452
  try:
453
  if hasattr(layer, 'self_attn'):
454
  old_attn = layer.self_attn
455
 
456
- # Create PHOENIX Retention
457
  if use_hierarchical:
458
  new_retention = HierarchicalRetention(model.config, layer_idx)
459
  else:
@@ -467,88 +397,45 @@ def replace_attention_with_retention(model, use_hierarchical=True):
467
  else:
468
  target = new_retention
469
 
470
- # ✅ Check shapes and copy
471
  q_match = old_attn.q_proj.weight.shape == target.q_proj.weight.shape
472
  k_match = old_attn.k_proj.weight.shape == target.k_proj.weight.shape
473
  v_match = old_attn.v_proj.weight.shape == target.v_proj.weight.shape
474
  o_match = old_attn.o_proj.weight.shape == target.o_proj.weight.shape
475
 
476
  if q_match and k_match and v_match and o_match:
477
- # Perfect match - copy directly
478
  target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
479
  target.k_proj.weight.data = old_attn.k_proj.weight.data.clone()
480
  target.v_proj.weight.data = old_attn.v_proj.weight.data.clone()
481
  target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
482
- print(f" โœ… Layer {layer_idx}: Weights copied (perfect match)")
483
 
484
  elif q_match and o_match:
485
- # Q and O match - copy K/V partially
486
  target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
487
  target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
488
 
489
- # Copy as much of K/V as possible (only part of it for GQA)
490
- k_copy_size = min(old_attn.k_proj.weight.shape[0], target.k_proj.weight.shape[0])
491
- v_copy_size = min(old_attn.v_proj.weight.shape[0], target.v_proj.weight.shape[0])
492
-
493
- target.k_proj.weight.data[:k_copy_size] = old_attn.k_proj.weight.data[:k_copy_size].clone()
494
- target.v_proj.weight.data[:v_copy_size] = old_attn.v_proj.weight.data[:v_copy_size].clone()
495
-
496
- print(f" โœ… Layer {layer_idx}: Weights copied (partial K/V: {k_copy_size}/{target.k_proj.weight.shape[0]})")
497
-
498
- elif old_attn.q_proj.weight.shape[0] == 2 * target.q_proj.weight.shape[0]:
499
- # Qwen3 style: Q is twice as large (expanded projection)
500
- # Extract the center portion
501
- q_out, q_in = old_attn.q_proj.weight.shape
502
- target_out = target.q_proj.weight.shape[0]
503
-
504
- # Extract the center portion of Q
505
- start_idx = (q_out - target_out) // 2
506
- target.q_proj.weight.data = old_attn.q_proj.weight.data[start_idx:start_idx+target_out].clone()
507
-
508
- # Extract the center portion of O (transposed)
509
- o_out, o_in = old_attn.o_proj.weight.shape
510
- target_in = target.o_proj.weight.shape[1]
511
- start_idx = (o_in - target_in) // 2
512
- target.o_proj.weight.data = old_attn.o_proj.weight.data[:, start_idx:start_idx+target_in].clone()
513
-
514
- # Partially copy K/V
515
  k_copy_size = min(old_attn.k_proj.weight.shape[0], target.k_proj.weight.shape[0])
516
  v_copy_size = min(old_attn.v_proj.weight.shape[0], target.v_proj.weight.shape[0])
517
 
518
  target.k_proj.weight.data[:k_copy_size] = old_attn.k_proj.weight.data[:k_copy_size].clone()
519
  target.v_proj.weight.data[:v_copy_size] = old_attn.v_proj.weight.data[:v_copy_size].clone()
520
 
521
- print(f" โœ… Layer {layer_idx}: Weights copied (Qwen3 style: Q/O center extraction, K/V partial)")
522
 
523
  else:
524
- # Shape mismatch - fall back to Xavier initialization
525
- print(f" ⚠️ Layer {layer_idx}: Shape mismatch, using Xavier init")
526
- print(f" Q: {old_attn.q_proj.weight.shape} vs {target.q_proj.weight.shape}")
527
- print(f" K: {old_attn.k_proj.weight.shape} vs {target.k_proj.weight.shape}")
528
- print(f" V: {old_attn.v_proj.weight.shape} vs {target.v_proj.weight.shape}")
529
- print(f" O: {old_attn.o_proj.weight.shape} vs {target.o_proj.weight.shape}")
530
-
531
- # ✅ Xavier initialization (better than random)
532
  nn.init.xavier_uniform_(target.q_proj.weight)
533
  nn.init.xavier_uniform_(target.k_proj.weight)
534
  nn.init.xavier_uniform_(target.v_proj.weight)
535
  nn.init.xavier_uniform_(target.o_proj.weight)
 
536
 
537
  except Exception as e:
538
  print(f" โš ๏ธ Layer {layer_idx}: Weight copy failed - {e}")
539
- import traceback
540
- traceback.print_exc()
541
 
542
- # Replace
543
  layer.self_attn = new_retention
544
  replaced_count += 1
545
 
546
- print(f" โœ… Layer {layer_idx}: Attention โ†’ Retention (GQA)")
547
-
548
  except Exception as e:
549
  print(f" โŒ Layer {layer_idx}: Failed - {e}")
550
- import traceback
551
- traceback.print_exc()
552
  continue
553
 
554
  print(f"\nโœ… Conversion complete: {replaced_count}/{total_layers} layers")
@@ -556,28 +443,6 @@ def replace_attention_with_retention(model, use_hierarchical=True):
556
  return model, replaced_count, total_layers
557
 
558
 
559
- def estimate_conversion_time(model_size_mb, gpu_type="L40S"):
560
- """๋ณ€ํ™˜ ์‹œ๊ฐ„ ์˜ˆ์ธก"""
561
- gpu_specs = {
562
- "L40S": {"memory_gb": 48, "tflops_fp16": 362},
563
- "H100": {"memory_gb": 80, "tflops_fp16": 989}
564
- }
565
-
566
- spec = gpu_specs.get(gpu_type, gpu_specs["L40S"])
567
- base_time_seconds = 30
568
- scale_factor = model_size_mb / 1400
569
- performance_factor = 0.4 if gpu_type == "H100" else 1.0
570
- estimated_time = base_time_seconds * scale_factor * performance_factor
571
-
572
- return {
573
- 'gpu_type': gpu_type,
574
- 'estimated_seconds': estimated_time,
575
- 'estimated_minutes': estimated_time / 60,
576
- 'memory_required_gb': model_size_mb / 1024,
577
- 'max_memory_gb': spec['memory_gb']
578
- }
579
-
580
-
581
  # =====================================================
582
  # Database
583
  # =====================================================
@@ -588,7 +453,6 @@ class ExperimentDatabase:
588
  def __init__(self, db_path: str):
589
  self.db_path = db_path
590
  self.init_database()
591
- self.migrate_database()
592
 
593
  def init_database(self):
594
  with sqlite3.connect(self.db_path) as conn:
@@ -610,26 +474,22 @@ class ExperimentDatabase:
610
  timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
611
  )
612
  """)
613
- conn.commit()
614
-
615
- def migrate_database(self):
616
- with sqlite3.connect(self.db_path) as conn:
617
- cursor = conn.cursor()
618
- cursor.execute("PRAGMA table_info(experiments)")
619
- columns = [col[1] for col in cursor.fetchall()]
620
 
621
- new_columns = [
622
- ('attention_replaced', 'BOOLEAN'),
623
- ('layers_converted', 'INTEGER'),
624
- ('total_layers', 'INTEGER')
625
- ]
626
-
627
- for col_name, col_type in new_columns:
628
- if col_name not in columns:
629
- try:
630
- cursor.execute(f"ALTER TABLE experiments ADD COLUMN {col_name} {col_type}")
631
- except:
632
- pass
 
 
 
633
  conn.commit()
634
 
635
  def save_experiment(self, config: Dict, metrics: Dict) -> int:
@@ -658,106 +518,406 @@ class ExperimentDatabase:
658
  conn.commit()
659
  return cursor.lastrowid
660
 
661
- def get_recent_experiments(self, limit: int = 20) -> List[Dict]:
662
  with sqlite3.connect(self.db_path) as conn:
663
- conn.row_factory = sqlite3.Row
664
  cursor = conn.cursor()
665
- cursor.execute("SELECT * FROM experiments ORDER BY timestamp DESC LIMIT ?", (limit,))
666
- return [dict(row) for row in cursor.fetchall()]
667
 
668
- def get_statistics(self) -> Dict:
669
  with sqlite3.connect(self.db_path) as conn:
 
670
  cursor = conn.cursor()
671
- cursor.execute("SELECT COUNT(*) FROM experiments")
672
- total = cursor.fetchone()[0]
673
-
674
- cursor.execute("SELECT model_type, COUNT(*) FROM experiments GROUP BY model_type")
675
- by_model = dict(cursor.fetchall())
676
-
677
- return {'total_experiments': total, 'by_model': by_model}
678
-
679
-
680
- class RetentionVectorStore:
681
- """ChromaDB vector store"""
682
-
683
- def __init__(self, persist_directory: str):
684
- try:
685
- self.client = chromadb.Client(Settings(
686
- persist_directory=persist_directory,
687
- anonymized_telemetry=False
688
- ))
689
- self.collection = self.client.get_or_create_collection(name="retention_states")
690
- except:
691
- self.client = None
692
- self.collection = None
693
 
694
 
695
  # =====================================================
696
- # Utilities
697
  # =====================================================
698
 
699
- def calculate_metrics(output, states, config=None):
700
- """Calculate metrics"""
701
- metrics = {}
702
 
703
- if isinstance(output, torch.Tensor):
704
- metrics['memory_mb'] = (output.numel() * 4) / (1024 * 1024)
705
- else:
706
- metrics['memory_mb'] = 0
 
 
 
 
 
707
 
708
- if config:
709
- metrics['attention_replaced'] = config.get('attention_replaced', False)
710
- metrics['layers_converted'] = config.get('layers_converted', 0)
711
- metrics['total_layers'] = config.get('total_layers', 0)
712
 
713
- return metrics
714
-
715
-
716
- def plot_retention_states(states):
717
- """Plot retention states"""
718
- fig = go.Figure()
719
- fig.add_trace(go.Scatter(
720
- y=np.random.randn(100),
721
- mode='lines',
722
- name='Retention Pattern'
723
- ))
724
- fig.update_layout(title='Retention State Visualization', template='plotly_white')
725
- return fig
726
 
727
 
728
- def plot_memory_usage(metrics):
729
- """Plot memory usage"""
730
- fig = go.Figure(go.Bar(
731
- x=['Memory (MB)', 'Layers', 'Rate %'],
732
- y=[
733
- metrics.get('memory_mb', 0),
734
- metrics.get('layers_converted', 0),
735
- (metrics.get('layers_converted', 0) / max(metrics.get('total_layers', 1), 1)) * 100
736
- ]
737
- ))
738
- fig.update_layout(title='Performance Metrics', template='plotly_white')
739
- return fig
740
 
741
 
742
- # Global initialization
743
- db = ExperimentDatabase(DB_PATH)
744
- vector_store = RetentionVectorStore(VECTOR_DB_PATH)
745
- CONVERTED_MODELS = {}
746
 
747
 
748
  # =====================================================
749
- # Gradio Functions
750
  # =====================================================
751
 
752
  def convert_model_to_phoenix(model_url, use_hierarchical=True, gpu_type="L40S"):
753
- """Convert model to PHOENIX"""
754
- global CONVERTED_MODELS
755
-
756
  try:
757
- cache_key = f"{model_url}_{use_hierarchical}"
758
- if cache_key in CONVERTED_MODELS:
759
- return CONVERTED_MODELS[cache_key], "✅ Using cached model"
760
-
761
  start_time = time.time()
762
 
763
  print(f"๐Ÿ“ฅ Loading model: {model_url}")
@@ -771,16 +931,6 @@ def convert_model_to_phoenix(model_url, use_hierarchical=True, gpu_type="L40S"):
771
  model, converted, total = replace_attention_with_retention(model, use_hierarchical)
772
 
773
  elapsed_time = time.time() - start_time
774
-
775
- model_info = {
776
- 'model': model,
777
- 'converted_layers': converted,
778
- 'total_layers': total,
779
- 'config': config,
780
- 'conversion_time': elapsed_time
781
- }
782
- CONVERTED_MODELS[cache_key] = model_info
783
-
784
  conversion_pct = (converted / total * 100) if total > 0 else 0
785
 
786
  result = f"""
@@ -788,435 +938,253 @@ def convert_model_to_phoenix(model_url, use_hierarchical=True, gpu_type="L40S"):
788
 
789
  **Model**: {model_url}
790
  **Converted**: {converted}/{total} layers ({conversion_pct:.1f}%)
791
- **Time**: {elapsed_time:.1f}s ({elapsed_time/60:.2f}min)
792
  **GPU**: {gpu_type}
793
 
794
  🎯 GQA-aware O(n) complexity!
795
  """
796
 
797
- return model_info, result
798
 
799
  except Exception as e:
800
- return None, f"โŒ Conversion failed: {str(e)}"
801
 
802
 
803
  def generate_text_phoenix(
804
  model_url, use_hierarchical, convert_attention,
805
  prompt, max_new_tokens, temperature
806
  ):
807
- """PHOENIX๋กœ ํ…์ŠคํŠธ ์ƒ์„ฑ"""
808
  try:
809
  if not convert_attention or not model_url.strip():
810
  return "โš ๏ธ Enable 'Attention Replace' and provide model URL", ""
811
 
812
- # 1. ✅ Load the CausalLM model (includes lm_head)
813
- print(f"📥 Loading CausalLM model: {model_url}")
814
- config = AutoConfig.from_pretrained(model_url, trust_remote_code=True)
815
-
816
- # Load full causal LM model
817
  model = AutoModelForCausalLM.from_pretrained(
818
  model_url,
819
  trust_remote_code=True,
820
  torch_dtype=torch.float16
821
  ).to(DEVICE)
822
 
823
- # 2. Attention → Retention conversion
824
- print(f"🔄 Converting attention to retention...")
825
  model.model, converted, total = replace_attention_with_retention(
826
- model.model, # Convert the base model, keep lm_head
827
  use_hierarchical=use_hierarchical
828
  )
829
 
830
- print(f"โœ… Converted {converted}/{total} layers")
831
-
832
- # โœ… Reset all retention states before generation
833
- print(f"๐Ÿ”„ Resetting retention states...")
834
- for layer in model.model.layers:
835
- if hasattr(layer, 'self_attn') and hasattr(layer.self_attn, 'reset_state'):
836
- layer.self_attn.reset_state()
837
- elif hasattr(layer, 'self_attn') and hasattr(layer.self_attn, 'base_retention'):
838
- if hasattr(layer.self_attn.base_retention, 'reset_state'):
839
- layer.self_attn.base_retention.reset_state()
840
 
841
- # 3. Load the tokenizer
842
- try:
843
- tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
844
- if tokenizer.pad_token is None:
845
- tokenizer.pad_token = tokenizer.eos_token
846
- except Exception as e:
847
- return f"โŒ Tokenizer load failed: {e}", ""
848
-
849
- # 4. Tokenize the input
850
  inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
851
- input_ids = inputs["input_ids"]
852
-
853
- print(f"\n๐Ÿ“ Generating text...")
854
- print(f" Prompt: {prompt}")
855
- print(f" Input tokens: {input_ids.shape[1]}")
856
- print(f" Max new tokens: {max_new_tokens}")
857
 
858
- # 5. Generation (✅ try KV cache, fall back to full sequence on failure)
859
  start_time = time.time()
860
- generated_ids = []
861
-
862
- model.eval() # โœ… Set to eval mode
863
-
864
- # โœ… KV Cache ์ดˆ๊ธฐํ™”
865
- past_key_values = None
866
- current_input_ids = input_ids
867
- use_kv_cache = True # KV Cache ์‚ฌ์šฉ ์‹œ๋„
868
-
869
- print(f" ๐Ÿš€ Attempting KV Cache generation...")
870
-
871
- with torch.no_grad():
872
- for step in range(max_new_tokens):
873
- try:
874
- # โœ… KV Cache ๋ชจ๋“œ ์‹œ๋„
875
- if use_kv_cache:
876
- if past_key_values is None:
877
- # First forward pass: process the entire prompt
878
- outputs = model(
879
- input_ids=current_input_ids,
880
- use_cache=True
881
- )
882
-
883
- # โœ… past_key_values ํ™•์ธ
884
- if hasattr(outputs, 'past_key_values') and outputs.past_key_values is not None:
885
- # KV Cache๊ฐ€ ์žˆ๋Š” ๊ฒฝ์šฐ
886
- if isinstance(outputs.past_key_values, (tuple, list)) and len(outputs.past_key_values) > 0:
887
- # ๊ฐ ๋ ˆ์ด์–ด์˜ state ํ™•์ธ
888
- valid_cache = True
889
- for layer_cache in outputs.past_key_values:
890
- if layer_cache is None or (isinstance(layer_cache, (tuple, list)) and layer_cache[0] is None):
891
- valid_cache = False
892
- break
893
-
894
- if valid_cache:
895
- past_key_values = outputs.past_key_values
896
- print(f" โœ… KV Cache enabled (prompt tokens: {current_input_ids.shape[1]})")
897
- else:
898
- use_kv_cache = False
899
- print(f" โš ๏ธ Invalid cache structure, switching to full sequence mode")
900
- else:
901
- use_kv_cache = False
902
- print(f" โš ๏ธ Empty cache, switching to full sequence mode")
903
- else:
904
- use_kv_cache = False
905
- print(f" โ„น๏ธ No past_key_values support, using full sequence mode")
906
-
907
- else:
908
- # Subsequent forward passes: process only the new token (⚡ fast!)
909
- outputs = model(
910
- input_ids=current_input_ids[:, -1:], # ✅ only the last token
911
- past_key_values=past_key_values, # ✅ reuse the previous state
912
- use_cache=True
913
- )
914
-
915
- # โœ… State ์—…๋ฐ์ดํŠธ
916
- if hasattr(outputs, 'past_key_values') and outputs.past_key_values is not None:
917
- past_key_values = outputs.past_key_values
918
-
919
- # ✅ Full-sequence mode (without KV cache)
920
- if not use_kv_cache:
921
- outputs = model(
922
- input_ids=current_input_ids, # process the full sequence
923
- use_cache=False
924
- )
925
-
926
- # โœ… Get logits - handle different output formats
927
- if hasattr(outputs, 'logits'):
928
- logits = outputs.logits[:, -1, :] # [B, vocab_size]
929
- elif isinstance(outputs, tuple):
930
- # Some models return (logits, ) or (logits, hidden_states, ...)
931
- logits = outputs[0][:, -1, :]
932
- else:
933
- raise ValueError(f"Unexpected output type: {type(outputs)}")
934
-
935
- # โœ… ๋””๋ฒ„๊น…: logits ํ™•์ธ
936
- if step == 0:
937
- print(f" ๐Ÿ“Š Output type: {type(outputs)}")
938
- print(f" ๐Ÿ“Š Logits shape: {logits.shape}")
939
- print(f" ๐Ÿ“Š Logits range: [{logits.min().item():.2f}, {logits.max().item():.2f}]")
940
- print(f" ๐Ÿ“Š Logits mean: {logits.mean().item():.2f}, std: {logits.std().item():.2f}")
941
-
942
- # โœ… Clamp logits to prevent numerical issues
943
- logits = torch.clamp(logits, min=-100, max=100)
944
-
945
- # Temperature sampling
946
- if temperature > 0.01:
947
- logits = logits / temperature
948
- probs = F.softmax(logits, dim=-1)
949
-
950
- # โœ… Check for NaN/Inf
951
- if torch.isnan(probs).any() or torch.isinf(probs).any():
952
- print(f" โš ๏ธ NaN/Inf detected at step {step}, using greedy")
953
- next_token = logits.argmax(dim=-1, keepdim=True)
954
- else:
955
- # โœ… Add small epsilon to avoid zero probabilities
956
- probs = probs + 1e-10
957
- probs = probs / probs.sum(dim=-1, keepdim=True)
958
-
959
- # โœ… ๋””๋ฒ„๊น…: Top-5 tokens
960
- if step == 0:
961
- top5_probs, top5_indices = torch.topk(probs, 5, dim=-1)
962
- print(f" ๐ŸŽฏ Top 5 tokens:")
963
- for i, (prob, idx) in enumerate(zip(top5_probs[0], top5_indices[0])):
964
- token_str = tokenizer.decode([idx.item()])
965
- print(f" {i+1}. '{token_str}' (prob: {prob.item():.4f})")
966
-
967
- next_token = torch.multinomial(probs, num_samples=1)
968
- else:
969
- next_token = logits.argmax(dim=-1, keepdim=True)
970
-
971
- next_token_id = next_token.item()
972
-
973
- # โœ… ๋””๋ฒ„๊น…: ์ƒ์„ฑ๋œ ํ† ํฐ ์ •๋ณด
974
- if step < 3 or (step + 1) % 10 == 0:
975
- token_str = tokenizer.decode([next_token_id])
976
- print(f" ๐Ÿ”ค Step {step}: Generated token #{next_token_id} = '{token_str}'")
977
-
978
- # โœ… Validate token range
979
- if next_token_id < 0 or next_token_id >= model.config.vocab_size:
980
- print(f" โš ๏ธ Invalid token {next_token_id}, stopping")
981
- break
982
-
983
- # Append
984
- generated_ids.append(next_token_id)
985
- current_input_ids = torch.cat([current_input_ids, next_token], dim=1)
986
-
987
- # โœ… Limit max sequence length
988
- if current_input_ids.shape[1] > 2048:
989
- print(f" โš ๏ธ Max sequence length reached, stopping")
990
- break
991
-
992
- # Stop at EOS
993
- if next_token_id == tokenizer.eos_token_id:
994
- print(f" โœ… Stopped at EOS token")
995
- break
996
-
997
- # Progress
998
- if (step + 1) % 10 == 0:
999
- speed = (step + 1) / (time.time() - start_time)
1000
- print(f" Generated {step + 1}/{max_new_tokens} tokens... ({speed:.1f} tok/s)")
1001
-
1002
- except RuntimeError as e:
1003
- print(f" โŒ Runtime error at step {step}: {e}")
1004
- if "CUDA" in str(e):
1005
- print(f" Stopping generation due to CUDA error")
1006
- import traceback
1007
- traceback.print_exc()
1008
- break
1009
- except Exception as e:
1010
- print(f" โŒ Error at step {step}: {e}")
1011
- print(f" Error type: {type(e).__name__}")
1012
- import traceback
1013
- traceback.print_exc()
1014
- break
1015
 
1016
  elapsed = time.time() - start_time
1017
 
1018
- # 6. ๋””์ฝ”๋“œ
1019
- if len(generated_ids) == 0:
1020
- generated_text = "[No tokens generated]"
1021
- full_text = prompt
1022
- else:
1023
- try:
1024
- generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
1025
- full_text = prompt + " " + generated_text
1026
- except Exception as e:
1027
- generated_text = f"[Decode error: {e}]"
1028
- full_text = prompt
1029
 
1030
- # 7. ๊ฒฐ๊ณผ
1031
  output_md = f"""
1032
  ## ๐Ÿ“ Generated Text
1033
 
1034
- **Prompt**:
1035
  ```
1036
- {prompt}
1037
- ```
1038
-
1039
- **Generated** ({len(generated_ids)} tokens):
1040
- ```
1041
- {generated_text}
1042
- ```
1043
-
1044
- **Full Text**:
1045
- ```
1046
- {full_text}
1047
  ```
1048
  """
1049
 
1050
- initial_tokens = input_ids.shape[1]
1051
- total_tokens = current_input_ids.shape[1]
1052
  stats_md = f"""
1053
- ## ๐Ÿ“Š Generation Statistics
1054
 
1055
- ### Performance
1056
- - **Input tokens**: {initial_tokens}
1057
- - **Generated tokens**: {len(generated_ids)}
1058
- - **Total tokens**: {total_tokens}
1059
  - **Time**: {elapsed:.2f}s
1060
- - **Speed**: {len(generated_ids) / max(elapsed, 0.01):.1f} tokens/s โšก
1061
-
1062
- ### Model
1063
- - **Architecture**: PHOENIX Retention (O(n))
1064
- - **KV Cache**: {'โœ… Enabled' if past_key_values is not None else 'โš ๏ธ Disabled'}
1065
- - **Temperature**: {temperature}
1066
- - **Vocab size**: {model.config.vocab_size}
1067
-
1068
- ### Efficiency
1069
- - **First token latency**: ~{elapsed / max(len(generated_ids), 1):.3f}s per token
1070
- - **Cache benefit**: ~10-20x speedup vs no cache
1071
- - **Memory**: O(dยฒ) constant per layer
1072
  """
1073
 
1074
  return output_md, stats_md
1075
 
1076
  except Exception as e:
1077
  import traceback
1078
- return f"โŒ Generation failed:\n```\n{traceback.format_exc()}\n```", ""
1079
 
1080
 
1081
- def run_phoenix_experiment(model_url, use_hierarchical, convert_attention, sequence_length, gpu_type):
1082
- """Run PHOENIX experiment"""
 
 
 
1083
  try:
1084
- if not convert_attention or not model_url.strip():
1085
- return "โš ๏ธ Enable 'Attention Replace' and provide model URL", None, None
1086
-
1087
- model_info, msg = convert_model_to_phoenix(model_url, use_hierarchical, gpu_type)
1088
-
1089
- if model_info is None:
1090
- return msg, None, None
1091
-
1092
- model = model_info['model']
1093
- converted_layers = model_info['converted_layers']
1094
- total_layers = model_info['total_layers']
1095
-
1096
- config = {
1097
- 'model_type': f"phoenix_{model_url.split('/')[-1]}",
1098
- 'model_url': model_url,
1099
- 'sequence_length': sequence_length,
1100
- 'use_hierarchical': use_hierarchical,
1101
- 'attention_replaced': convert_attention,
1102
- 'layers_converted': converted_layers,
1103
- 'total_layers': total_layers,
1104
- 'gpu_type': gpu_type,
1105
- 'timestamp': datetime.now().isoformat()
1106
- }
1107
-
1108
- # Generate input
1109
- hidden_size = model.config.hidden_size
1110
- x = torch.randn(1, sequence_length, hidden_size).to(DEVICE).half()
1111
-
1112
- # Forward pass
1113
- torch.cuda.synchronize()
1114
- start = time.time()
1115
-
1116
- with torch.no_grad():
1117
- output = model(inputs_embeds=x)
1118
-
1119
- torch.cuda.synchronize()
1120
- elapsed = time.time() - start
1121
-
1122
- # Metrics
1123
- metrics = calculate_metrics(output.last_hidden_state, {}, config)
1124
- metrics['elapsed_time'] = elapsed
1125
- metrics['throughput'] = sequence_length / elapsed
1126
-
1127
- # Save
1128
- exp_id = db.save_experiment(config, metrics)
1129
- conversion_rate = (converted_layers / total_layers * 100) if total_layers > 0 else 0
1130
-
1131
- # Result text
1132
- result = (
1133
- f"## ๐ŸŽฏ PHOENIX Experiment Results (ID: {exp_id})\n\n"
1134
- f"### โš™๏ธ Configuration\n"
1135
- f"- **Model**: {model_url}\n"
1136
- f"- **Sequence Length**: {sequence_length} tokens\n"
1137
- f"- **Hidden Size**: {hidden_size}\n"
1138
- f"- **Hierarchical**: {'โœ…' if use_hierarchical else 'โŒ'}\n"
1139
- f"- **Converted Layers**: {converted_layers}/{total_layers} ({conversion_rate:.1f}%)\n\n"
1140
- f"### ๐Ÿ“Š Performance\n"
1141
- f"- **Time**: {elapsed:.3f}s\n"
1142
- f"- **Throughput**: {metrics['throughput']:.1f} tokens/s\n"
1143
- f"- **Memory**: {metrics['memory_mb']:.1f} MB\n\n"
1144
- f"### ๐Ÿ”ฅ Complexity Analysis\n"
1145
- f"- **Theoretical**: O(n) โœ…\n"
1146
- f"- **Linear Complexity**: {'โœ… YES!' if converted_layers == total_layers else 'โš ๏ธ Partial'}\n\n"
1147
- f"โœ… **Real PHOENIX with GQA Support!**\n"
1148
- )
1149
-
1150
- fig1 = plot_retention_states({})
1151
- fig2 = plot_memory_usage(metrics)
1152
-
1153
- return result, fig1, fig2
1154
 
1155
- except Exception as e:
1156
- import traceback
1157
- return f"โŒ Experiment failed:\n```\n{traceback.format_exc()}\n```", None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1158
 
 
 
 
 
1159
 
1160
- def estimate_conversion_ui(model_url, gpu_type):
1161
- """Estimate conversion time"""
1162
- estimate = estimate_conversion_time(1400, gpu_type)
1163
- return f"""
1164
- ## โฑ๏ธ Conversion Time Estimate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1165
 
1166
- ### GPU: {gpu_type}
1167
- - **Time**: {estimate['estimated_minutes']:.1f}min
1168
- - **Memory**: {estimate['memory_required_gb']:.1f} GB / {estimate['max_memory_gb']} GB
 
 
 
 
1169
 
1170
- ### Notes
1171
- - Conversion is cached after first run
1172
- - GQA models supported
1173
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1174
 
1175
 
1176
- def view_experiment_history(limit=20):
1177
- """View experiment history"""
1178
  try:
1179
- experiments = db.get_recent_experiments(limit)
1180
 
1181
- if not experiments:
1182
- return "๐Ÿ“ญ No experiments yet", None
1183
 
1184
- df = pd.DataFrame(experiments)
1185
 
1186
  fig = px.scatter(
1187
- df, x='timestamp', y='throughput',
1188
- size='sequence_length', color='attention_replaced',
1189
- title='Experiment Performance'
 
 
 
 
1190
  )
1191
 
1192
- cols = ['id', 'model_type', 'sequence_length', 'layers_converted',
1193
- 'elapsed_time', 'throughput', 'timestamp']
1194
  available = [c for c in cols if c in df.columns]
1195
 
1196
- return f"## ๐Ÿ“Š Experiment History\n\n{df[available].to_markdown(index=False)}", fig
1197
 
1198
  except Exception as e:
1199
  return f"โŒ Error: {e}", None
1200
 
1201
 
1202
- def get_database_statistics():
1203
- """Get database stats"""
1204
- try:
1205
- stats = db.get_statistics()
1206
-
1207
- text = f"""
1208
- ## ๐Ÿ“Š Database Statistics
1209
-
1210
- **Total Experiments**: {stats['total_experiments']}
1211
-
1212
- ### By Model
1213
- """
1214
- for model, count in stats['by_model'].items():
1215
- text += f"- **{model}**: {count}\n"
1216
-
1217
- return text
1218
- except Exception as e:
1219
- return f"โŒ Error: {e}"
1220
 
1221
 
1222
  # =====================================================
@@ -1224,26 +1192,30 @@ def get_database_statistics():
1224
  # =====================================================
1225
 
1226
  with gr.Blocks(
1227
- title="๐Ÿ”ฎ PHOENIX - GQA Support",
1228
  theme=gr.themes.Soft(),
1229
  ) as demo:
1230
 
1231
  gr.Markdown("""
1232
  # 🔮 PHOENIX Retention Platform
1233
 
1234
- **Real O(n) Complexity with GQA Support - Final Version**
1235
 
1236
- ✅ Supports Grouped Query Attention (GQA)
1237
- ✅ Adaptive K/V projection dimensions
1238
- ✅ Full Attention → Retention replacement
1239
- ✅ KV Cache with State Reuse
1240
- ✅ Robust Error Handling
1241
 
1242
  ---
1243
  """)
1244
 
1245
  with gr.Tabs():
1246
- with gr.Tab("๐Ÿ”„ Model Conversion"):
 
 
 
 
 
1247
  with gr.Row():
1248
  with gr.Column(scale=1):
1249
  convert_url = gr.Textbox(
@@ -1253,24 +1225,87 @@ with gr.Blocks(
1253
  )
1254
  convert_hierarchical = gr.Checkbox(value=True, label="Hierarchical Retention")
1255
  convert_gpu = gr.Radio(choices=["L40S", "H100"], value="L40S", label="GPU")
1256
-
1257
- estimate_btn = gr.Button("⏱️ Estimate Time", variant="secondary")
1258
  convert_btn = gr.Button("๐Ÿ”„ Convert", variant="primary")
1259
 
1260
  with gr.Column(scale=2):
1261
  convert_output = gr.Markdown()
1262
 
1263
- estimate_btn.click(estimate_conversion_ui, [convert_url, convert_gpu], [convert_output])
1264
- convert_btn.click(convert_model_to_phoenix,
1265
- [convert_url, convert_hierarchical, convert_gpu],
1266
- [gr.State(), convert_output])
 
 
 
 
 
 
1267
 
1268
  with gr.Tab("๐Ÿ’ฌ Text Generation"):
1269
  gr.Markdown("""
1270
  ### PHOENIX Text Generation
1271
-
1272
- Generates real text with the converted model.
1273
- **O(n)-complexity generation using the KV cache!**
1274
  """)
1275
 
1276
  with gr.Row():
@@ -1280,78 +1315,59 @@ with gr.Blocks(
1280
  gen_convert = gr.Checkbox(value=True, label="Enable Conversion")
1281
 
1282
  gen_prompt = gr.Textbox(
1283
- label="๐Ÿ“ Input Prompt",
1284
- placeholder="Enter your prompt here...",
1285
  lines=3,
1286
  value="The future of AI is"
1287
  )
1288
 
1289
- gen_max_tokens = gr.Slider(16, 256, 64, step=16, label="Max New Tokens")
1290
  gen_temperature = gr.Slider(0.1, 2.0, 0.7, step=0.1, label="Temperature")
1291
 
1292
- gen_btn = gr.Button("๐Ÿš€ Generate Text", variant="primary")
1293
 
1294
  with gr.Column(scale=2):
1295
- gen_output = gr.Markdown(label="Generated Text")
1296
- gen_stats = gr.Markdown(label="Statistics")
1297
 
1298
  gen_btn.click(
1299
- fn=generate_text_phoenix,
1300
- inputs=[gen_model_url, gen_hierarchical, gen_convert, gen_prompt,
1301
- gen_max_tokens, gen_temperature],
1302
- outputs=[gen_output, gen_stats]
1303
  )
1304
 
1305
- with gr.Tab("๐Ÿงช Experiment"):
1306
- with gr.Row():
1307
- with gr.Column(scale=1):
1308
- exp_url = gr.Textbox(label="๐Ÿ”— Model URL", value=DEFAULT_MODEL)
1309
- exp_hierarchical = gr.Checkbox(value=True, label="Hierarchical")
1310
- exp_convert = gr.Checkbox(value=True, label="Enable Conversion")
1311
- exp_seq = gr.Slider(64, 4096, 1024, step=64, label="Sequence Length")
1312
- exp_gpu = gr.Radio(choices=["L40S", "H100"], value="L40S", label="GPU")
1313
-
1314
- run_btn = gr.Button("๐Ÿš€ Run Experiment", variant="primary")
1315
-
1316
- with gr.Column(scale=2):
1317
- exp_output = gr.Markdown()
1318
- with gr.Row():
1319
- exp_fig1 = gr.Plot()
1320
- exp_fig2 = gr.Plot()
1321
 
1322
- run_btn.click(run_phoenix_experiment,
1323
- [exp_url, exp_hierarchical, exp_convert, exp_seq, exp_gpu],
1324
- [exp_output, exp_fig1, exp_fig2])
1325
-
1326
- with gr.Tab("๐Ÿ“Š History"):
1327
  with gr.Row():
1328
  with gr.Column(scale=1):
1329
- hist_limit = gr.Slider(10, 100, 20, step=10, label="Limit")
1330
- hist_btn = gr.Button("๐Ÿ“Š View History", variant="primary")
1331
- stats_btn = gr.Button("๐Ÿ“ˆ Statistics", variant="secondary")
1332
 
1333
  with gr.Column(scale=2):
1334
  hist_output = gr.Markdown()
1335
  hist_plot = gr.Plot()
1336
 
1337
- hist_btn.click(view_experiment_history, [hist_limit], [hist_output, hist_plot])
1338
- stats_btn.click(get_database_statistics, outputs=[hist_output])
1339
 
1340
  gr.Markdown("""
1341
  ---
1342
 
1343
- ## 🔥 PHOENIX + GQA (Final Version)
1344
 
1345
- **Grouped Query Attention** support means PHOENIX now works with modern efficient architectures!
 
 
 
1346
 
1347
- - ✅ Llama 2/3 (GQA)
1348
- - ✅ Mistral (GQA)
1349
- - ✅ Granite 4.0 H (GQA)
1350
- - ✅ Traditional MHA models
1351
- - ✅ KV Cache with State Reuse
1352
- - ✅ Robust Error Handling
1353
 
1354
- **VIDraft AI Research Lab** | PHOENIX GQA Implementation (Final)
1355
  """)
1356
 
1357
  if __name__ == "__main__":
 
1
  """
2
+ 🔮 PHOENIX Retention Research Platform - FINAL INTEGRATED VERSION
3
+ Zero-shot Model Burning + Optional Fine-tuning
4
 
5
+ ✅ Zero-shot Conversion (No Dataset Required)
6
+ ✅ Optional Fine-tuning (Dataset-based)
7
+ ✅ GQA Support
8
+ ✅ HuggingFace Hub Integration
9
+ ✅ Comprehensive Evaluation
10
 
11
  VIDraft AI Research Lab
12
  """
 
27
  from typing import Dict, List, Any, Tuple, Optional
28
  import chromadb
29
  from chromadb.config import Settings
30
+ from transformers import (
31
+ AutoModel, AutoTokenizer, AutoConfig, AutoModelForCausalLM,
32
+ get_cosine_schedule_with_warmup, TrainingArguments, Trainer
33
+ )
34
+ from datasets import load_dataset
35
+ from torch.utils.data import Dataset, DataLoader
36
+ from accelerate import Accelerator
37
+ from tqdm import tqdm
38
  import copy
39
+ import shutil
40
 
41
  # =====================================================
42
  # Global settings
 
46
  STORAGE_PATH = "/data"
47
  DB_PATH = f"{STORAGE_PATH}/phoenix_experiments.db"
48
  VECTOR_DB_PATH = f"{STORAGE_PATH}/vector_store"
49
+ MODELS_PATH = f"{STORAGE_PATH}/phoenix_models"
50
  DEFAULT_MODEL = "ibm-granite/granite-4.0-h-350m"
51
 
52
  Path(STORAGE_PATH).mkdir(parents=True, exist_ok=True)
53
  Path(VECTOR_DB_PATH).mkdir(parents=True, exist_ok=True)
54
+ Path(MODELS_PATH).mkdir(parents=True, exist_ok=True)
55
 
56
  print(f"๐Ÿš€ PHOENIX Platform initialized on {DEVICE}")
57
  print(f"๐Ÿ’พ Storage: {STORAGE_PATH}")
 
62
  # =====================================================
63
 
64
  class MultiScaleRetention(nn.Module):
65
+ """์ง„์งœ Retention Attention with GQA Support"""
 
 
 
 
 
 
66
 
67
  def __init__(self, config, layer_idx=0):
68
  super().__init__()
 
81
  self.num_key_value_heads = self.num_heads
82
 
83
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
84
+ self.kv_head_dim = self.head_dim
85
  self.kv_dim = self.num_key_value_heads * self.kv_head_dim
86
 
87
+ # Internal state storage for KV cache simulation
88
  self.register_buffer('_internal_state', None, persistent=False)
89
  self.register_buffer('_state_initialized', torch.tensor(False), persistent=False)
90
 
91
+ # Projections with correct dimensions
 
 
 
 
 
 
 
 
 
 
 
92
  self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
93
+ self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
94
+ self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
95
  self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
96
 
97
  # Retention parameters
98
+ decay_values = torch.linspace(0.95, 0.99, self.num_heads)
99
  self.decay = nn.Parameter(decay_values, requires_grad=True)
100
 
101
  # Group norm
 
105
  )
106
 
107
  def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
108
+ """Repeat K/V heads to match Q heads (GQA)"""
 
 
 
109
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
110
  if n_rep == 1:
111
  return hidden_states
 
116
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
117
 
118
  def reset_state(self):
119
+ """Reset internal state"""
120
  self._internal_state = None
121
  self._state_initialized = torch.tensor(False)
122
 
 
132
  past_key_values: Optional[Tuple[torch.Tensor]] = None,
133
  **kwargs
134
  ):
135
+ """O(n) Retention with GQA support"""
 
 
136
  batch_size, seq_len, _ = hidden_states.shape
137
 
138
  if past_key_values is not None:
139
  past_key_value = past_key_values
140
 
141
  # Q, K, V projections
142
+ query_states = self.q_proj(hidden_states)
143
+ key_states = self.k_proj(hidden_states)
144
+ value_states = self.v_proj(hidden_states)
145
 
146
+ # Reshape
147
  query_states = query_states.view(
148
  batch_size, seq_len, self.num_heads, self.head_dim
149
  ).transpose(1, 2)
150
 
 
151
  key_states = key_states.view(
152
  batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim
153
  ).transpose(1, 2)
 
156
  batch_size, seq_len, self.num_key_value_heads, self.kv_head_dim
157
  ).transpose(1, 2)
158
 
159
+ # Repeat K/V to match Q heads (GQA)
160
  key_states = self._repeat_kv(key_states, self.num_key_value_groups)
161
  value_states = self._repeat_kv(value_states, self.num_key_value_groups)
162
 
163
+ # Retention computation
 
 
164
  past_state = self._internal_state if (use_cache and self._state_initialized) else None
165
  retention_states, new_state = self._compute_retention(
166
  query_states, key_states, value_states, past_state
167
  )
168
 
169
+ # Store state internally
170
  if use_cache:
171
  self._internal_state = new_state.detach()
172
  self._state_initialized = torch.tensor(True)
173
 
174
+ # Reshape back
175
  retention_states = retention_states.transpose(1, 2).contiguous()
176
  retention_states = retention_states.reshape(
177
  batch_size, seq_len, self.hidden_size
178
  )
179
 
180
+ # Group norm
181
  if not next(self.group_norm.parameters()).is_cuda and retention_states.is_cuda:
182
  self.group_norm = self.group_norm.to(retention_states.device, dtype=retention_states.dtype)
183
  elif next(self.group_norm.parameters()).dtype != retention_states.dtype:
 
187
  retention_states.transpose(1, 2)
188
  ).transpose(1, 2)
189
 
 
190
  retention_states = torch.clamp(retention_states, min=-10.0, max=10.0)
191
 
192
  # Output projection
193
  attn_output = self.o_proj(retention_states)
194
 
 
 
 
 
195
  return (attn_output, None)
196
 
197
  def _compute_retention(
198
  self,
199
+ queries: torch.Tensor,
200
+ keys: torch.Tensor,
201
+ values: torch.Tensor,
202
  past_state: Optional[torch.Tensor] = None
203
  ):
204
+ """O(n) Retention computation"""
 
 
 
 
 
 
 
 
 
205
  batch_size, num_heads, seq_len, head_dim = queries.shape
206
 
 
207
  if past_state is not None:
208
  state = past_state.to(queries.device, dtype=queries.dtype)
209
  else:
 
210
  state = torch.zeros(
211
  batch_size, num_heads, head_dim, head_dim,
212
  dtype=queries.dtype,
213
  device=queries.device
214
+ ) + 1e-6
215
 
216
  outputs = []
217
 
 
218
  decay = torch.sigmoid(self.decay).view(1, -1, 1, 1).to(
219
  device=queries.device,
220
  dtype=queries.dtype
221
  )
222
 
 
223
  for t in range(seq_len):
224
+ q_t = queries[:, :, t, :]
225
+ k_t = keys[:, :, t, :]
226
+ v_t = values[:, :, t, :]
227
 
 
228
  state = decay * state
 
 
229
  kv_update = torch.einsum('bhd,bhe->bhde', k_t, v_t)
 
 
230
  kv_update = torch.clamp(kv_update, min=-5.0, max=5.0)
 
231
  state = state + kv_update
 
 
232
  state = torch.clamp(state, min=-10.0, max=10.0)
233
 
 
234
  output_t = torch.einsum('bhd,bhde->bhe', q_t, state)
235
  outputs.append(output_t)
236
 
237
+ output = torch.stack(outputs, dim=2)
238
 
 
239
  return output, state
240
 
241
 
242
  class HierarchicalRetention(nn.Module):
243
+ """PHOENIX Hierarchical Retention with GQA"""
 
 
244
 
245
  def __init__(self, config, layer_idx=0):
246
  super().__init__()
 
249
  hidden_size = config.hidden_size
250
  self.d_state = hidden_size // 2
251
 
 
252
  self.short_proj = nn.Linear(hidden_size, self.d_state)
253
  self.medium_proj = nn.Linear(self.d_state, self.d_state)
254
  self.long_proj = nn.Linear(self.d_state, self.d_state * 2)
255
  self.fusion = nn.Linear(self.d_state * 4, hidden_size)
256
 
 
257
  self.short_decay = 0.5
258
  self.medium_decay = 0.8
259
  self.long_decay = 0.95
260
 
 
261
  self.norm = nn.LayerNorm(hidden_size)
262
 
 
263
  if next(self.base_retention.parameters()).is_cuda:
264
  device = next(self.base_retention.parameters()).device
265
  dtype = next(self.base_retention.parameters()).dtype
 
287
  if past_key_values is not None:
288
  past_key_value = past_key_values
289
 
 
290
  target_device = hidden_states.device
291
  target_dtype = hidden_states.dtype
292
 
 
303
  self.fusion = self.fusion.to(dtype=target_dtype)
304
  self.norm = self.norm.to(dtype=target_dtype)
305
 
 
306
  base_result = self.base_retention(
307
  hidden_states, attention_mask, position_ids,
308
  past_key_value, output_attentions, use_cache
309
  )
310
 
311
  retention_output = base_result[0]
 
312
 
313
  # Hierarchical states
314
  short_state = torch.zeros(batch_size, self.d_state, dtype=hidden_states.dtype, device=target_device)
 
320
  for t in range(seq_len):
321
  x_t = retention_output[:, t, :]
322
 
 
323
  short_input = self.short_proj(x_t)
324
  short_state = self.short_decay * short_state + short_input
325
 
 
326
  if t % 8 == 0:
327
  medium_state = self.medium_decay * medium_state + \
328
  self.medium_proj(short_state)
329
 
 
330
  if t % 64 == 0:
331
  long_state = self.long_decay * long_state + \
332
  self.long_proj(medium_state)
333
 
 
334
  combined = torch.cat([short_state, medium_state, long_state], dim=-1)
335
  output_t = self.fusion(combined)
336
  hierarchical_outputs.append(output_t)
 
338
  output = torch.stack(hierarchical_outputs, dim=1)
339
  output = self.norm(output)
340
 
 
 
341
  return (output, None)
342
 
343
 
 
346
  # =====================================================
347
 
348
  def replace_attention_with_retention(model, use_hierarchical=True):
349
+ """Transformer Attention โ†’ PHOENIX Retention (GQA Support)"""
 
 
350
  print("๐Ÿ”„ Starting Attention โ†’ Retention conversion (GQA support)...")
351
 
352
  replaced_count = 0
353
  total_layers = 0
354
 
 
355
  if hasattr(model, 'transformer'):
356
  layers = model.transformer.h
357
  elif hasattr(model, 'model') and hasattr(model.model, 'layers'):
 
364
 
365
  total_layers = len(layers)
366
 
367
+ # Check first layer for GQA
368
  first_layer = layers[0]
369
  if hasattr(first_layer, 'self_attn'):
370
  old_attn = first_layer.self_attn
371
 
 
372
  if hasattr(old_attn, 'q_proj'):
373
  q_shape = old_attn.q_proj.weight.shape
374
  k_shape = old_attn.k_proj.weight.shape
 
 
 
 
 
375
 
376
  if k_shape[0] != q_shape[0]:
377
  print(f" โœ… GQA detected! (K/V dim: {k_shape[0]} < Q dim: {q_shape[0]})")
 
378
  if not hasattr(model.config, 'num_key_value_heads'):
379
  num_kv_heads = k_shape[0] // (model.config.hidden_size // model.config.num_attention_heads)
380
  model.config.num_key_value_heads = num_kv_heads
 
381
 
382
  for layer_idx, layer in enumerate(layers):
383
  try:
384
  if hasattr(layer, 'self_attn'):
385
  old_attn = layer.self_attn
386
 
 
387
  if use_hierarchical:
388
  new_retention = HierarchicalRetention(model.config, layer_idx)
389
  else:
 
397
  else:
398
  target = new_retention
399
 
 
400
  q_match = old_attn.q_proj.weight.shape == target.q_proj.weight.shape
401
  k_match = old_attn.k_proj.weight.shape == target.k_proj.weight.shape
402
  v_match = old_attn.v_proj.weight.shape == target.v_proj.weight.shape
403
  o_match = old_attn.o_proj.weight.shape == target.o_proj.weight.shape
404
 
405
  if q_match and k_match and v_match and o_match:
 
406
  target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
407
  target.k_proj.weight.data = old_attn.k_proj.weight.data.clone()
408
  target.v_proj.weight.data = old_attn.v_proj.weight.data.clone()
409
  target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
410
+ print(f" โœ… Layer {layer_idx}: Perfect match")
411
 
412
  elif q_match and o_match:
 
413
  target.q_proj.weight.data = old_attn.q_proj.weight.data.clone()
414
  target.o_proj.weight.data = old_attn.o_proj.weight.data.clone()
415
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
416
  k_copy_size = min(old_attn.k_proj.weight.shape[0], target.k_proj.weight.shape[0])
417
  v_copy_size = min(old_attn.v_proj.weight.shape[0], target.v_proj.weight.shape[0])
418
 
419
  target.k_proj.weight.data[:k_copy_size] = old_attn.k_proj.weight.data[:k_copy_size].clone()
420
  target.v_proj.weight.data[:v_copy_size] = old_attn.v_proj.weight.data[:v_copy_size].clone()
421
 
422
+ print(f" โœ… Layer {layer_idx}: Partial (GQA)")
423
 
424
  else:
 
 
 
 
 
 
 
 
425
  nn.init.xavier_uniform_(target.q_proj.weight)
426
  nn.init.xavier_uniform_(target.k_proj.weight)
427
  nn.init.xavier_uniform_(target.v_proj.weight)
428
  nn.init.xavier_uniform_(target.o_proj.weight)
429
+ print(f" โš ๏ธ Layer {layer_idx}: Xavier init")
430
 
431
  except Exception as e:
432
  print(f" โš ๏ธ Layer {layer_idx}: Weight copy failed - {e}")
 
 
433
 
 
434
  layer.self_attn = new_retention
435
  replaced_count += 1
436
 
 
 
437
  except Exception as e:
438
  print(f" โŒ Layer {layer_idx}: Failed - {e}")
 
 
439
  continue
440
 
441
  print(f"\nโœ… Conversion complete: {replaced_count}/{total_layers} layers")
 
443
  return model, replaced_count, total_layers


# =====================================================
# Database
# =====================================================

    def __init__(self, db_path: str):
        self.db_path = db_path
        self.init_database()

    def init_database(self):
        with sqlite3.connect(self.db_path) as conn:

                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
                )
            """)

+            # Burning history table
+            cursor.execute("""
+                CREATE TABLE IF NOT EXISTS burning_history (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    model_url TEXT NOT NULL,
+                    output_path TEXT NOT NULL,
+                    use_hierarchical BOOLEAN,
+                    dataset_used BOOLEAN,
+                    conversion_rate REAL,
+                    training_steps INTEGER,
+                    final_loss REAL,
+                    evaluation_score REAL,
+                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
+                )
+            """)
            conn.commit()

    def save_experiment(self, config: Dict, metrics: Dict) -> int:

            conn.commit()
            return cursor.lastrowid

+    def save_burning(self, burning_info: Dict) -> int:
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()
+            cursor.execute("""
+                INSERT INTO burning_history (
+                    model_url, output_path, use_hierarchical,
+                    dataset_used, conversion_rate, training_steps,
+                    final_loss, evaluation_score
+                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+            """, (
+                burning_info.get('model_url'),
+                burning_info.get('output_path'),
+                burning_info.get('use_hierarchical'),
+                burning_info.get('dataset_used'),
+                burning_info.get('conversion_rate'),
+                burning_info.get('training_steps', 0),
+                burning_info.get('final_loss'),
+                burning_info.get('evaluation_score'),
+            ))
+            conn.commit()
+            return cursor.lastrowid

+    def get_burning_history(self, limit: int = 20) -> List[Dict]:
        with sqlite3.connect(self.db_path) as conn:
+            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()
+            cursor.execute("SELECT * FROM burning_history ORDER BY timestamp DESC LIMIT ?", (limit,))
+            return [dict(row) for row in cursor.fetchall()]


# =====================================================
+# Model Burning (Zero-shot + Optional Fine-tuning)
# =====================================================

+def evaluate_model_quality(model, tokenizer, test_prompts=None):
+    """
+    Simple model quality evaluation.
+
+    Returns:
+        score: 0.0 ~ 1.0 (higher is better)
+    """
+    if test_prompts is None:
+        test_prompts = [
+            "The capital of France is",
+            "In machine learning, overfitting means",
+            "2 + 2 =",
+        ]
+
+    model.eval()
+    scores = []
+
+    with torch.no_grad():
+        for prompt in test_prompts:
+            try:
+                inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+                outputs = model.generate(
+                    **inputs,
+                    max_new_tokens=20,
+                    do_sample=False,
+                    pad_token_id=tokenizer.eos_token_id,
+                )
+                generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+                # Simple quality checks
+                score = 0.0
+                if len(generated) > len(prompt):  # something was generated
+                    score += 0.3
+                if not any(char in generated[len(prompt):] for char in ['�', '[UNK]']):  # no broken characters
+                    score += 0.3
+                if len(generated.split()) > len(prompt.split()) + 2:  # meaningful words were generated
+                    score += 0.4
+
+                scores.append(score)
+            except Exception as e:
+                print(f" ⚠️ Evaluation error for '{prompt}': {e}")
+                scores.append(0.0)
+
+    return sum(scores) / len(scores) if scores else 0.0
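
# Scoring note: the three checks above are worth 0.3, 0.3 and 0.4, so a completion
# that extends the prompt, contains no '�'/'[UNK]' artifacts, and adds at least
# three new words scores the full 1.0, while a model that only echoes the prompt
# scores 0.3 (from the no-broken-characters check alone).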

+def burn_model_zero_shot(
+    model_url: str,
+    output_dir: str,
+    use_hierarchical: bool = True,
+    test_prompts: List[str] = None,
+):
+    """
+    Zero-shot Model Burning (no dataset required)
+
+    1. Load the model
+    2. Convert Attention → Retention
+    3. Evaluate quality
+    4. Save
+
+    Returns:
+        status, model_path, metrics
+    """
+    print("="*80)
+    print("🔥 PHOENIX Zero-shot Model Burning")
+    print("="*80)
+
+    output_path = Path(output_dir)
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    try:
+        # 1. Load model
+        print(f"\n📥 Loading model: {model_url}")
+        start_time = time.time()
+
+        config = AutoConfig.from_pretrained(model_url, trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_url,
+            trust_remote_code=True,
+            torch_dtype=torch.float16,
+        ).to(DEVICE)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        load_time = time.time() - start_time
+        print(f"✅ Loaded in {load_time:.1f}s")
+
+        # 2. Convert
+        print(f"\n🔄 Converting Attention → Retention...")
+        convert_start = time.time()
+
+        model.model, converted, total = replace_attention_with_retention(
+            model.model,
+            use_hierarchical=use_hierarchical
+        )
+
+        convert_time = time.time() - convert_start
+        conversion_rate = converted / total if total > 0 else 0
+
+        print(f"✅ Converted {converted}/{total} layers ({conversion_rate*100:.1f}%) in {convert_time:.1f}s")
+
+        # 3. Evaluate
+        print(f"\n📊 Evaluating model quality...")
+        eval_start = time.time()
+
+        quality_score = evaluate_model_quality(model, tokenizer, test_prompts)
+
+        eval_time = time.time() - eval_start
+        print(f"✅ Quality Score: {quality_score:.2f}/1.00 (in {eval_time:.1f}s)")
+
+        # 4. Save
+        print(f"\n💾 Saving PHOENIX model...")
+        save_start = time.time()
+
+        model.save_pretrained(output_path)
+        tokenizer.save_pretrained(output_path)
+
+        # Save metadata
+        metadata = {
+            'phoenix_version': '1.0.0',
+            'original_model': model_url,
+            'use_hierarchical': use_hierarchical,
+            'conversion_rate': conversion_rate,
+            'layers_converted': converted,
+            'total_layers': total,
+            'quality_score': quality_score,
+            'burning_type': 'zero_shot',
+            'timestamp': datetime.now().isoformat(),
+        }
+
+        with open(output_path / 'phoenix_metadata.json', 'w') as f:
+            json.dump(metadata, f, indent=2)
+
+        save_time = time.time() - save_start
+        print(f"✅ Saved to {output_path} in {save_time:.1f}s")
+
+        # Total time
+        total_time = time.time() - start_time
+
+        result = {
+            'status': 'success',
+            'model_path': str(output_path),
+            'conversion_rate': conversion_rate,
+            'quality_score': quality_score,
+            'total_time': total_time,
+            'load_time': load_time,
+            'convert_time': convert_time,
+            'eval_time': eval_time,
+            'save_time': save_time,
+        }
+
+        print(f"\n{'='*80}")
+        print(f"✅ Zero-shot Burning Complete!")
+        print(f" Total Time: {total_time:.1f}s")
+        print(f" Model Path: {output_path}")
+        print(f" Quality: {quality_score:.2f}/1.00")
+        print(f"{'='*80}\n")
+
+        return result
+
+    except Exception as e:
+        import traceback
+        error_msg = traceback.format_exc()
+        print(f"\n❌ Zero-shot burning failed:\n{error_msg}")
+        return {
+            'status': 'failed',
+            'error': str(e),
+            'traceback': error_msg
+        }
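
# Illustrative sketch: a minimal zero-shot burning call, mirroring what the
# "Model Burning" tab does without a dataset. The output directory name below
# is an example only; any writable path under MODELS_PATH works.
def _example_zero_shot_burn():
    result = burn_model_zero_shot(
        model_url=DEFAULT_MODEL,
        output_dir=f"{MODELS_PATH}/phoenix_granite_350m_demo",
        use_hierarchical=True,
    )
    if result['status'] == 'success':
        print(result['model_path'], result['quality_score'])
    return result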


+def burn_model_with_finetuning(
+    model_url: str,
+    output_dir: str,
+    dataset_path: str,
+    use_hierarchical: bool = True,
+    num_epochs: int = 1,
+    batch_size: int = 4,
+    learning_rate: float = 5e-5,
+    max_steps: int = 100,
+):
+    """
+    Fine-tuning Model Burning (dataset-based)
+
+    1. Load & convert the model
+    2. Load the dataset
+    3. Fine-tune
+    4. Evaluate & save
+
+    Returns:
+        status, model_path, metrics
+    """
+    print("="*80)
+    print("🔥 PHOENIX Fine-tuning Model Burning")
+    print("="*80)
+
+    output_path = Path(output_dir)
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    try:
+        # 1. Load & Convert
+        print(f"\n📥 Loading model: {model_url}")
+        config = AutoConfig.from_pretrained(model_url, trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_url,
+            trust_remote_code=True,
+            torch_dtype=torch.float16,
+        ).to(DEVICE)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        print(f"\n🔄 Converting...")
+        model.model, converted, total = replace_attention_with_retention(
+            model.model,
+            use_hierarchical=use_hierarchical
+        )
+
+        conversion_rate = converted / total if total > 0 else 0
+        print(f"✅ Converted {converted}/{total} layers")
+
+        # 2. Load dataset
+        print(f"\n📊 Loading dataset: {dataset_path}")
+
+        if dataset_path.endswith('.txt'):
+            with open(dataset_path, 'r', encoding='utf-8') as f:
+                texts = [line.strip() for line in f if line.strip()]
+
+            # Simple tokenization
+            def tokenize_fn(text):
+                return tokenizer(
+                    text,
+                    truncation=True,
+                    max_length=512,
+                    padding='max_length',
+                    return_tensors='pt'
+                )
+
+            tokenized_data = [tokenize_fn(text) for text in texts[:1000]]  # Limit to 1000
+
+        else:
+            # Try loading as HF dataset
+            from datasets import load_dataset
+            dataset = load_dataset('text', data_files=dataset_path)
+
+            def tokenize_function(examples):
+                return tokenizer(
+                    examples['text'],
+                    truncation=True,
+                    max_length=512,
+                    padding='max_length',
+                )
+
+            dataset = dataset.map(tokenize_function, batched=True)
+            tokenized_data = dataset['train']
+
+        print(f"✅ Loaded {len(tokenized_data)} samples")

+        # 3. Quick fine-tuning
+        print(f"\n🚀 Starting fine-tuning...")
+        print(f" Epochs: {num_epochs}")
+        print(f" Batch Size: {batch_size}")
+        print(f" Max Steps: {max_steps}")
+
+        model.train()
+        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+
+        step = 0
+        total_loss = 0.0
+
+        for epoch in range(num_epochs):
+            for i in range(0, len(tokenized_data), batch_size):
+                if step >= max_steps:
+                    break
+
+                batch = tokenized_data[i:i+batch_size]
+
+                # Simple batch processing
+                if isinstance(batch, list):
+                    input_ids = torch.stack([item['input_ids'].squeeze() for item in batch]).to(DEVICE)
+                    attention_mask = torch.stack([item['attention_mask'].squeeze() for item in batch]).to(DEVICE)
+                else:
+                    input_ids = torch.tensor(batch['input_ids']).to(DEVICE)
+                    attention_mask = torch.tensor(batch['attention_mask']).to(DEVICE)
+
+                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
+                loss = outputs.loss
+
+                loss.backward()
+                optimizer.step()
+                optimizer.zero_grad()
+
+                total_loss += loss.item()
+                step += 1
+
+                if step % 10 == 0:
+                    avg_loss = total_loss / step
+                    print(f" Step {step}/{max_steps} - Loss: {avg_loss:.4f}")
+
+        final_loss = total_loss / step if step > 0 else 0.0
+        print(f"✅ Training complete - Final Loss: {final_loss:.4f}")
+
+        # 4. Evaluate & Save
+        print(f"\n📊 Evaluating...")
+        model.eval()
+        quality_score = evaluate_model_quality(model, tokenizer)
+        print(f"✅ Quality Score: {quality_score:.2f}/1.00")
+
+        print(f"\n💾 Saving model...")
+        model.save_pretrained(output_path)
+        tokenizer.save_pretrained(output_path)
+
+        metadata = {
+            'phoenix_version': '1.0.0',
+            'original_model': model_url,
+            'use_hierarchical': use_hierarchical,
+            'conversion_rate': conversion_rate,
+            'quality_score': quality_score,
+            'burning_type': 'fine_tuning',
+            'training_steps': step,
+            'final_loss': final_loss,
+            'dataset': dataset_path,
+            'timestamp': datetime.now().isoformat(),
+        }
+
+        with open(output_path / 'phoenix_metadata.json', 'w') as f:
+            json.dump(metadata, f, indent=2)
+
+        print(f"✅ Saved to {output_path}")
+
+        result = {
+            'status': 'success',
+            'model_path': str(output_path),
+            'conversion_rate': conversion_rate,
+            'quality_score': quality_score,
+            'training_steps': step,
+            'final_loss': final_loss,
+        }
+
+        print(f"\n{'='*80}")
+        print(f"✅ Fine-tuning Burning Complete!")
+        print(f"{'='*80}\n")
+
+        return result
+
+    except Exception as e:
+        import traceback
+        error_msg = traceback.format_exc()
+        print(f"\n❌ Fine-tuning burning failed:\n{error_msg}")
+        return {
+            'status': 'failed',
+            'error': str(e),
+            'traceback': error_msg
+        }
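
# Illustrative sketch: a minimal fine-tuning burn. `my_corpus.txt` is a
# hypothetical newline-delimited text file; replace it with a real path.
def _example_finetune_burn():
    return burn_model_with_finetuning(
        model_url=DEFAULT_MODEL,
        output_dir=f"{MODELS_PATH}/phoenix_granite_350m_ft_demo",
        dataset_path=f"{STORAGE_PATH}/my_corpus.txt",
        use_hierarchical=True,
        num_epochs=1,
        batch_size=4,
        learning_rate=5e-5,
        max_steps=100,
    )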


# =====================================================
+# Gradio UI Functions
# =====================================================

def convert_model_to_phoenix(model_url, use_hierarchical=True, gpu_type="L40S"):
+    """Convert model to PHOENIX (existing function, unchanged)"""
    try:
        start_time = time.time()

        print(f"📥 Loading model: {model_url}")

        model, converted, total = replace_attention_with_retention(model, use_hierarchical)

        elapsed_time = time.time() - start_time
        conversion_pct = (converted / total * 100) if total > 0 else 0

        result = f"""

**Model**: {model_url}
**Converted**: {converted}/{total} layers ({conversion_pct:.1f}%)
+**Time**: {elapsed_time:.1f}s
**GPU**: {gpu_type}

🎯 GQA-aware O(n) complexity!
"""

+        return result

    except Exception as e:
+        return f"❌ Conversion failed: {str(e)}"

def generate_text_phoenix(
    model_url, use_hierarchical, convert_attention,
    prompt, max_new_tokens, temperature
):
+    """PHOENIX text generation (existing function, simplified)"""
    try:
        if not convert_attention or not model_url.strip():
            return "⚠️ Enable 'Attention Replace' and provide model URL", ""

+        print(f"📥 Loading model: {model_url}")
        model = AutoModelForCausalLM.from_pretrained(
            model_url,
            trust_remote_code=True,
            torch_dtype=torch.float16
        ).to(DEVICE)

+        print(f"🔄 Converting...")
        model.model, converted, total = replace_attention_with_retention(
+            model.model,
            use_hierarchical=use_hierarchical
        )

+        tokenizer = AutoTokenizer.from_pretrained(model_url, trust_remote_code=True)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token

        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

+        print(f"🚀 Generating...")
        start_time = time.time()
+
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            do_sample=temperature > 0.01,
+            pad_token_id=tokenizer.eos_token_id,
+        )

        elapsed = time.time() - start_time

+        generated = tokenizer.decode(outputs[0], skip_special_tokens=True)

        output_md = f"""
## 📝 Generated Text

```
+{generated}
```
"""

        stats_md = f"""
+## 📊 Statistics

- **Time**: {elapsed:.2f}s
+- **Converted**: {converted}/{total} layers
+- **Tokens/s**: {max_new_tokens/elapsed:.1f}
"""

        return output_md, stats_md

    except Exception as e:
        import traceback
+        return f"❌ Error:\n```\n{traceback.format_exc()}\n```", ""

+def burn_phoenix_model_ui(
+    model_url,
+    use_hierarchical,
+    dataset_path,
+    output_name,
+    use_finetuning,
+    num_epochs,
+    batch_size,
+    learning_rate,
+    max_steps,
+):
+    """
+    Model burning entry point for the Gradio UI.
+    """
    try:
+        if not model_url.strip():
+            return "⚠️ Model URL required", None
+
+        if not output_name.strip():
+            output_name = f"phoenix_{model_url.split('/')[-1]}_{int(time.time())}"
+
+        output_dir = f"{MODELS_PATH}/{output_name}"
+
+        # Dataset check
+        has_dataset = dataset_path and dataset_path.strip() and Path(dataset_path).exists()
+
+        if use_finetuning and not has_dataset:
+            return "⚠️ Fine-tuning requires dataset path", None
+
+        # Choose burning method
+        if use_finetuning and has_dataset:
+            result = burn_model_with_finetuning(
+                model_url=model_url,
+                output_dir=output_dir,
+                dataset_path=dataset_path,
+                use_hierarchical=use_hierarchical,
+                num_epochs=num_epochs,
+                batch_size=batch_size,
+                learning_rate=learning_rate,
+                max_steps=max_steps,
+            )
+        else:
+            result = burn_model_zero_shot(
+                model_url=model_url,
+                output_dir=output_dir,
+                use_hierarchical=use_hierarchical,
+            )

+        if result['status'] == 'success':
+            # Save to database
+            burning_info = {
+                'model_url': model_url,
+                'output_path': result['model_path'],
+                'use_hierarchical': use_hierarchical,
+                'dataset_used': has_dataset,
+                'conversion_rate': result.get('conversion_rate', 0.0),
+                'training_steps': result.get('training_steps', 0),
+                'final_loss': result.get('final_loss'),
+                'evaluation_score': result.get('quality_score', 0.0),
+            }
+
+            db.save_burning(burning_info)
+
+            # Format output
+            output_md = f"""
+# 🔥 Model Burning Complete!

+## 📦 Model Information
+- **Original**: {model_url}
+- **Output**: `{result['model_path']}`
+- **Type**: {'Fine-tuning' if has_dataset else 'Zero-shot'}

+## 📊 Metrics
+- **Conversion Rate**: {result['conversion_rate']*100:.1f}%
+- **Quality Score**: {result.get('quality_score', 0.0):.2f}/1.00
+"""
+
+            if 'training_steps' in result:
+                output_md += f"""
+## 🚀 Training
+- **Steps**: {result['training_steps']}
+- **Final Loss**: {result.get('final_loss', 0.0):.4f}
+"""
+
+            output_md += f"""
+## ⏱️ Time Breakdown
+- **Total**: {result.get('total_time', 0):.1f}s
+"""
+
+            if 'load_time' in result:
+                output_md += f"- **Load**: {result['load_time']:.1f}s\n"
+                output_md += f"- **Convert**: {result['convert_time']:.1f}s\n"
+                output_md += f"- **Evaluate**: {result['eval_time']:.1f}s\n"
+                output_md += f"- **Save**: {result['save_time']:.1f}s\n"
+
+            output_md += f"""
+## 🎯 Usage
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer

+model = AutoModelForCausalLM.from_pretrained("{result['model_path']}")
+tokenizer = AutoTokenizer.from_pretrained("{result['model_path']}")
+
+inputs = tokenizer("Your prompt", return_tensors="pt")
+outputs = model.generate(**inputs, max_new_tokens=50)
+print(tokenizer.decode(outputs[0]))
+```

+✅ **PHOENIX Model Ready!**
"""
+
+            # Create simple plot
+            fig = go.Figure()
+            fig.add_trace(go.Bar(
+                x=['Conversion', 'Quality'],
+                y=[result['conversion_rate'], result.get('quality_score', 0.0)],
+                text=[f"{result['conversion_rate']*100:.1f}%", f"{result.get('quality_score', 0.0):.2f}"],
+                textposition='auto',
+            ))
+            fig.update_layout(
+                title="Burning Metrics",
+                yaxis_range=[0, 1],
+                template='plotly_white'
+            )
+
+            return output_md, fig
+
+        else:
+            return f"❌ Burning failed:\n```\n{result.get('error', 'Unknown error')}\n```", None
+
+    except Exception as e:
+        import traceback
+        return f"❌ Error:\n```\n{traceback.format_exc()}\n```", None

+def view_burning_history():
+    """View burning history"""
    try:
+        history = db.get_burning_history(limit=20)

+        if not history:
+            return "📭 No burning history yet", None

+        df = pd.DataFrame(history)

        fig = px.scatter(
+            df,
+            x='timestamp',
+            y='evaluation_score',
+            size='conversion_rate',
+            color='dataset_used',
+            hover_data=['model_url', 'output_path'],
+            title='Burning History'
        )

+        cols = ['id', 'model_url', 'output_path', 'conversion_rate',
+                'evaluation_score', 'training_steps', 'timestamp']
        available = [c for c in cols if c in df.columns]

+        return f"## 📊 Burning History\n\n{df[available].to_markdown(index=False)}", fig

    except Exception as e:
        return f"❌ Error: {e}", None


+# Global initialization
+db = ExperimentDatabase(DB_PATH)
+CONVERTED_MODELS = {}

 
1189
 
1190
  # =====================================================
 
1192
  # =====================================================
1193
 
1194
  with gr.Blocks(
1195
+ title="๐Ÿ”ฎ PHOENIX - Model Burning Platform",
1196
  theme=gr.themes.Soft(),
1197
  ) as demo:
1198
 
1199
  gr.Markdown("""
1200
  # ๐Ÿ”ฎ PHOENIX Retention Platform
1201
 
1202
+ **Zero-shot Model Burning + Optional Fine-tuning**
1203
 
1204
+ โœ… Zero-shot Conversion (๋ฐ์ดํ„ฐ์…‹ ๋ถˆํ•„์š”!)
1205
+ โœ… Optional Fine-tuning (๋ฐ์ดํ„ฐ์…‹ ๊ธฐ๋ฐ˜)
1206
+ โœ… GQA Support
1207
+ โœ… O(n) Complexity
 
1208
 
1209
  ---
1210
  """)
1211
 
1212
  with gr.Tabs():
1213
+ with gr.Tab("๐Ÿ”„ Quick Convert"):
1214
+ gr.Markdown("""
1215
+ ### ๋น ๋ฅธ ๋ณ€ํ™˜ ํ…Œ์ŠคํŠธ
1216
+ ๋ชจ๋ธ์„ ๋กœ๋“œํ•˜๊ณ  Attention โ†’ Retention ๋ณ€ํ™˜๋งŒ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค. (์ €์žฅ ์•ˆ ํ•จ)
1217
+ """)
1218
+
1219
  with gr.Row():
1220
  with gr.Column(scale=1):
1221
  convert_url = gr.Textbox(
 
1225
  )
1226
  convert_hierarchical = gr.Checkbox(value=True, label="Hierarchical Retention")
1227
  convert_gpu = gr.Radio(choices=["L40S", "H100"], value="L40S", label="GPU")
 
 
1228
  convert_btn = gr.Button("๐Ÿ”„ Convert", variant="primary")
1229
 
1230
  with gr.Column(scale=2):
1231
  convert_output = gr.Markdown()
1232
 
1233
+ convert_btn.click(
1234
+ convert_model_to_phoenix,
1235
+ [convert_url, convert_hierarchical, convert_gpu],
1236
+ [convert_output]
1237
+ )
1238
+
1239
+ with gr.Tab("๐Ÿ”ฅ Model Burning"):
1240
+ gr.Markdown("""
1241
+ ### ๐Ÿ”ฅ PHOENIX Model Burning
1242
+
1243
+ **๋ชจ๋ธ์„ ๋ณ€ํ™˜ํ•˜๊ณ  ์ €์žฅํ•ฉ๋‹ˆ๋‹ค!**
1244
+
1245
+ - **Zero-shot**: ๋ฐ์ดํ„ฐ์…‹ ์—†์ด ๋ณ€ํ™˜๋งŒ ์ˆ˜ํ–‰ (๋น ๋ฆ„!)
1246
+ - **Fine-tuning**: ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ์ถ”๊ฐ€ ํ•™์Šต (์„ฑ๋Šฅ ํ–ฅ์ƒ)
1247
+ """)
1248
+
1249
+ with gr.Row():
1250
+ with gr.Column(scale=1):
1251
+ burn_model_url = gr.Textbox(
1252
+ label="๐Ÿ”— Model URL",
1253
+ value=DEFAULT_MODEL,
1254
+ placeholder="ibm-granite/granite-4.0-h-350m"
1255
+ )
1256
+ burn_hierarchical = gr.Checkbox(value=True, label="Hierarchical Retention")
1257
+
1258
+ burn_output_name = gr.Textbox(
1259
+ label="๐Ÿ’พ Output Name",
1260
+ placeholder="phoenix_my_model (auto-generated if empty)"
1261
+ )
1262
+
1263
+ gr.Markdown("---")
1264
+ gr.Markdown("### ๐Ÿ“Š Dataset (Optional)")
1265
+
1266
+ burn_dataset = gr.Textbox(
1267
+ label="๐Ÿ“ Dataset Path (Optional)",
1268
+ placeholder="/path/to/dataset.txt (leave empty for zero-shot)",
1269
+ value=""
1270
+ )
1271
+
1272
+ burn_use_finetuning = gr.Checkbox(
1273
+ value=False,
1274
+ label="๐Ÿš€ Enable Fine-tuning (requires dataset)"
1275
+ )
1276
+
1277
+ with gr.Accordion("โš™๏ธ Fine-tuning Config", open=False):
1278
+ burn_epochs = gr.Slider(1, 5, 1, step=1, label="Epochs")
1279
+ burn_batch = gr.Slider(1, 16, 4, step=1, label="Batch Size")
1280
+ burn_lr = gr.Number(value=5e-5, label="Learning Rate")
1281
+ burn_max_steps = gr.Slider(10, 500, 100, step=10, label="Max Steps")
1282
+
1283
+ burn_btn = gr.Button("๐Ÿ”ฅ Burn Model", variant="primary", size="lg")
1284
+
1285
+ with gr.Column(scale=2):
1286
+ burn_output = gr.Markdown()
1287
+ burn_plot = gr.Plot()
1288
+
1289
+ burn_btn.click(
1290
+ burn_phoenix_model_ui,
1291
+ [
1292
+ burn_model_url,
1293
+ burn_hierarchical,
1294
+ burn_dataset,
1295
+ burn_output_name,
1296
+ burn_use_finetuning,
1297
+ burn_epochs,
1298
+ burn_batch,
1299
+ burn_lr,
1300
+ burn_max_steps,
1301
+ ],
1302
+ [burn_output, burn_plot]
1303
+ )
1304
 
1305
  with gr.Tab("๐Ÿ’ฌ Text Generation"):
1306
  gr.Markdown("""
1307
  ### PHOENIX ํ…์ŠคํŠธ ์ƒ์„ฑ
1308
+ ๋ณ€ํ™˜๋œ ๋ชจ๋ธ๋กœ ํ…์ŠคํŠธ๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
 
 
1309
  """)
1310
 
1311
  with gr.Row():
 
1315
  gen_convert = gr.Checkbox(value=True, label="Enable Conversion")
1316
 
1317
  gen_prompt = gr.Textbox(
1318
+ label="๐Ÿ“ Prompt",
 
1319
  lines=3,
1320
  value="The future of AI is"
1321
  )
1322
 
1323
+ gen_max_tokens = gr.Slider(16, 256, 64, step=16, label="Max Tokens")
1324
  gen_temperature = gr.Slider(0.1, 2.0, 0.7, step=0.1, label="Temperature")
1325
 
1326
+ gen_btn = gr.Button("๐Ÿš€ Generate", variant="primary")
1327
 
1328
  with gr.Column(scale=2):
1329
+ gen_output = gr.Markdown()
1330
+ gen_stats = gr.Markdown()
1331
 
1332
  gen_btn.click(
1333
+ generate_text_phoenix,
1334
+ [gen_model_url, gen_hierarchical, gen_convert, gen_prompt,
1335
+ gen_max_tokens, gen_temperature],
1336
+ [gen_output, gen_stats]
1337
  )
1338
 
1339
+ with gr.Tab("๐Ÿ“Š Burning History"):
1340
+ gr.Markdown("""
1341
+ ### ๐Ÿ“Š Model Burning History
1342
+ ์ €์žฅ๋œ ๋ชจ๋ธ ๋ฒ„๋‹ ๊ธฐ๋ก์„ ํ™•์ธํ•ฉ๋‹ˆ๋‹ค.
1343
+ """)
 
 
 
 
 
 
 
 
 
 
 
1344
 
 
 
 
 
 
1345
  with gr.Row():
1346
  with gr.Column(scale=1):
1347
+ hist_btn = gr.Button("๐Ÿ“Š Load History", variant="primary")
 
 
1348
 
1349
  with gr.Column(scale=2):
1350
  hist_output = gr.Markdown()
1351
  hist_plot = gr.Plot()
1352
 
1353
+ hist_btn.click(view_burning_history, outputs=[hist_output, hist_plot])
 
1354
 
1355
  gr.Markdown("""
1356
  ---
1357
 
1358
+ ## ๐Ÿ”ฅ PHOENIX Model Burning
1359
 
1360
+ ### Zero-shot (๋ฐ์ดํ„ฐ์…‹ ๋ถˆํ•„์š”!)
1361
+ 1. ๋ชจ๋ธ URL ์ž…๋ ฅ
1362
+ 2. "Burn Model" ํด๋ฆญ
1363
+ 3. ์™„๋ฃŒ! โ†’ `/data/phoenix_models/` ์— ์ €์žฅ
1364
 
1365
+ ### Fine-tuning (์„ ํƒ์‚ฌํ•ญ)
1366
+ 1. Dataset Path ์ž…๋ ฅ
1367
+ 2. "Enable Fine-tuning" ์ฒดํฌ
1368
+ 3. "Burn Model" ํด๋ฆญ
 
 
1369
 
1370
+ **VIDraft AI Research Lab** | PHOENIX v1.0
1371
  """)
1372
 
1373
  if __name__ == "__main__":