seawolf2357 committed (verified)
Commit e6bcdb0 · 1 Parent(s): b49c3a3

Update app.py

Files changed (1):
  1. app.py +222 -79
app.py CHANGED
@@ -3,6 +3,7 @@
 Complete Integration - Single File
 
 L40S GPU + Persistent Storage (SQLite + ChromaDB)
+Base Model: IBM Granite 4.0 H 350M
 VIDraft AI Research Lab
 """
 
@@ -23,6 +24,7 @@ from typing import Dict, List, Any, Tuple, Optional
 import chromadb
 from chromadb.config import Settings
 from einops import rearrange, repeat
+from transformers import AutoModel, AutoTokenizer, AutoConfig
 
 # =====================================================
 # Global settings
@@ -32,6 +34,7 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 STORAGE_PATH = "/data"  # HF Spaces persistent storage
 DB_PATH = f"{STORAGE_PATH}/phoenix_experiments.db"
 VECTOR_DB_PATH = f"{STORAGE_PATH}/vector_store"
+DEFAULT_MODEL = "ibm-granite/granite-4.0-h-350m"
 
 # Create directories
 Path(STORAGE_PATH).mkdir(parents=True, exist_ok=True)
@@ -39,6 +42,7 @@ Path(VECTOR_DB_PATH).mkdir(parents=True, exist_ok=True)
 
 print(f"🚀 PHOENIX Platform initialized on {DEVICE}")
 print(f"💾 Storage: {STORAGE_PATH}")
+print(f"🎯 Default Base Model: {DEFAULT_MODEL}")
 
 # =====================================================
 # Database management class
@@ -61,6 +65,7 @@ class ExperimentDatabase:
             CREATE TABLE IF NOT EXISTS experiments (
                 id INTEGER PRIMARY KEY AUTOINCREMENT,
                 model_type TEXT NOT NULL,
+                base_model_url TEXT,
                 sequence_length INTEGER,
                 power_mode TEXT,
                 compression_level REAL,
@@ -87,6 +92,11 @@ class ExperimentDatabase:
             ON experiments(timestamp DESC)
         """)
 
+        cursor.execute("""
+            CREATE INDEX IF NOT EXISTS idx_base_model
+            ON experiments(base_model_url)
+        """)
+
         conn.commit()
         print("✅ Database initialized")
 
@@ -97,13 +107,14 @@ class ExperimentDatabase:
 
         cursor.execute("""
             INSERT INTO experiments (
-                model_type, sequence_length, power_mode,
+                model_type, base_model_url, sequence_length, power_mode,
                 compression_level, use_hierarchical, elapsed_time,
                 memory_mb, throughput, avg_retention, compression_ratio,
                 config_json, metrics_json
-            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
         """, (
            config.get('model_type'),
+            config.get('base_model_url'),
            config.get('sequence_length'),
            config.get('power_mode'),
            config.get('compression_level'),
@@ -160,9 +171,18 @@ class ExperimentDatabase:
         """)
         by_model = dict(cursor.fetchall())
 
+        cursor.execute("""
+            SELECT base_model_url, COUNT(*) as count
+            FROM experiments
+            WHERE base_model_url IS NOT NULL
+            GROUP BY base_model_url
+        """)
+        by_base_model = dict(cursor.fetchall())
+
         return {
             'total_experiments': total,
-            'by_model': by_model
+            'by_model': by_model,
+            'by_base_model': by_base_model
         }
 
 class RetentionVectorStore:
@@ -182,7 +202,6 @@ class RetentionVectorStore:
 
     def add_retention_state(self, experiment_id: int, states: Dict, metadata: Dict):
         """Store a retention state"""
-        # Convert the state to a vector
         state_vector = self._states_to_vector(states)
 
         self.collection.add(
@@ -223,7 +242,6 @@ class RetentionVectorStore:
                 vectors.append(value.mean().item())
                 vectors.append(value.std().item())
 
-        # Pad/truncate to a fixed size
         target_size = 128
         if len(vectors) < target_size:
             vectors.extend([0.0] * (target_size - len(vectors)))
@@ -234,7 +252,6 @@ class RetentionVectorStore:
 
     def _text_to_vector(self, text: str) -> np.ndarray:
         """Convert text to a vector (simple hash-based)"""
-        # sentence-transformers is recommended in practice
         hash_val = hash(text) % (2**31)
         np.random.seed(hash_val)
         return np.random.randn(128)
@@ -346,24 +363,60 @@ class DynamicPowerRetention(nn.Module):
 class PHOENIXRetention(nn.Module):
     """PHOENIX Retention unified model"""
 
-    def __init__(self, d_model=512, d_state=256, num_layers=12, device='cuda'):
+    def __init__(self, d_model=512, d_state=256, num_layers=12, device='cuda', base_model_url=None):
         super().__init__()
         self.d_model = d_model
         self.d_state = d_state
         self.num_layers = num_layers
         self.device = device
+        self.base_model_url = base_model_url
+
+        # Load the base model (optional)
+        self.base_model = None
+        if base_model_url:
+            try:
+                print(f"📥 Loading base model: {base_model_url}")
+                self.base_model = AutoModel.from_pretrained(
+                    base_model_url,
+                    trust_remote_code=True
+                ).to(device)
+
+                # Take the base model's hidden size
+                if hasattr(self.base_model.config, 'hidden_size'):
+                    self.d_model = self.base_model.config.hidden_size
+
+                print(f"✅ Base model loaded: {base_model_url}")
+                print(f"📏 Model dimension: {self.d_model}")
+            except Exception as e:
+                print(f"⚠️ Base model loading failed: {e}")
+                print(f"   Continuing with default architecture...")
 
         # Core components
-        self.hierarchical = HierarchicalRetention(d_model, d_state)
+        self.hierarchical = HierarchicalRetention(self.d_model, d_state)
         self.compressor = AdaptiveCompression(d_state)
-        self.power_adapter = DynamicPowerRetention(d_model)
+        self.power_adapter = DynamicPowerRetention(self.d_model)
 
         # Layer norm
-        self.norm = nn.LayerNorm(d_model)
+        self.norm = nn.LayerNorm(self.d_model)
+
+        # Projection (bridges the base model)
+        if self.base_model:
+            self.base_projection = nn.Linear(self.d_model, self.d_model)
 
         self.to(device)
 
     def forward(self, x, return_states=True):
+        # Run the base model first (if present)
+        if self.base_model is not None:
+            with torch.no_grad():
+                base_output = self.base_model(
+                    inputs_embeds=x,
+                    output_hidden_states=True
+                )
+                # Use the last hidden state
+                x = base_output.hidden_states[-1]
+            x = self.base_projection(x)
+
         # Hierarchical retention
         h_out, states = self.hierarchical(x)
 
@@ -383,64 +436,93 @@ class PHOENIXRetention(nn.Module):
                 'medium_state': states['medium_state'],
                 'long_state': states['long_state'],
                 'compression_ratio': compression_ratio,
-                'dynamic_power': power
+                'dynamic_power': power,
+                'base_model_used': self.base_model is not None
             }
         return output
 
-class BrumbyRetention(nn.Module):
-    """Brumby baseline"""
+class TransformerBaseline(nn.Module):
+    """Transformer baseline"""
 
-    def __init__(self, d_model=512, d_state=256, power=2, device='cuda'):
+    def __init__(self, d_model=512, d_state=256, device='cuda', base_model_url=None):
         super().__init__()
         self.d_model = d_model
         self.d_state = d_state
-        self.power = power
         self.device = device
-
-        self.proj_q = nn.Linear(d_model, d_state)
-        self.proj_k = nn.Linear(d_model, d_state)
-        self.proj_v = nn.Linear(d_model, d_state)
-        self.proj_out = nn.Linear(d_state, d_model)
+        self.base_model_url = base_model_url
+
+        # Load the base model
+        self.base_model = None
+        if base_model_url:
+            try:
+                self.base_model = AutoModel.from_pretrained(
+                    base_model_url,
+                    trust_remote_code=True
+                ).to(device)
+
+                if hasattr(self.base_model.config, 'hidden_size'):
+                    self.d_model = self.base_model.config.hidden_size
+
+                print(f"✅ Transformer baseline loaded: {base_model_url}")
+            except Exception as e:
+                print(f"⚠️ Transformer baseline loading failed: {e}")
 
         self.to(device)
 
     def forward(self, x, return_states=True):
-        batch_size, seq_len, _ = x.shape
-
-        Q = self.proj_q(x)
-        K = self.proj_k(x)
-        V = self.proj_v(x)
-
-        # Simple retention (simplified)
-        state = torch.zeros(batch_size, self.d_state).to(x.device)
-        outputs = []
-
-        for t in range(seq_len):
-            state = 0.9 * state + V[:, t, :] @ K[:, t, :].T
-            output_t = state @ Q[:, t, :].unsqueeze(-1)
-            outputs.append(output_t.squeeze(-1))
-
-        outputs = torch.stack(outputs, dim=1)
-        outputs = self.proj_out(outputs)
-
-        if return_states:
-            return outputs, {
-                'state': state,
-                'power': self.power
-            }
-        return outputs
+        if self.base_model is not None:
+            output = self.base_model(
+                inputs_embeds=x,
+                output_hidden_states=True
+            )
+            last_hidden = output.hidden_states[-1]
+
+            if return_states:
+                return last_hidden, {
+                    'state': last_hidden[:, -1, :],
+                    'base_model_used': True
+                }
+            return last_hidden
+        else:
+            # Fallback: simple identity
+            if return_states:
+                return x, {'state': x[:, -1, :], 'base_model_used': False}
+            return x
 
 # =====================================================
 # Utility functions
 # =====================================================
 
+def load_custom_model(model_url: str, model_type: str = "phoenix"):
+    """Load a user-specified model"""
+    try:
+        if model_type == "phoenix":
+            model = PHOENIXRetention(
+                d_model=512,
+                d_state=256,
+                num_layers=12,
+                device=DEVICE,
+                base_model_url=model_url if model_url.strip() else None
+            )
+        else:  # transformer
+            model = TransformerBaseline(
+                d_model=512,
+                d_state=256,
+                device=DEVICE,
+                base_model_url=model_url if model_url.strip() else None
+            )
+
+        return model, None
+    except Exception as e:
+        return None, str(e)
+
 def calculate_metrics(output, states):
     """Compute metrics"""
     metrics = {}
 
     # Memory usage (approximate)
     total_params = sum(p.numel() for p in [output] if isinstance(p, torch.Tensor))
-    metrics['memory_mb'] = (total_params * 4) / (1024 * 1024)  # float32 = 4 bytes
+    metrics['memory_mb'] = (total_params * 4) / (1024 * 1024)
 
     # Retention ratios
     if 'short_state' in states:
@@ -509,8 +591,8 @@ def plot_memory_usage(metrics):
         x=['Memory (MB)', 'State Size', 'Compression Ratio'],
         y=[
             metrics.get('memory_mb', 0),
-            metrics.get('state_size', 0) / 10,  # Scale down
-            metrics.get('compression_ratio', 0) * 100  # Percentage
+            metrics.get('state_size', 0) / 10,
+            metrics.get('compression_ratio', 0) * 100
         ],
         marker_color=['lightblue', 'lightgreen', 'lightyellow']
     ))
@@ -527,7 +609,6 @@ def plot_performance_comparison(df):
     """Performance comparison visualization"""
     fig = go.Figure()
 
-    # Speed comparison
     fig.add_trace(go.Bar(
         name='Execution Time (s)',
         x=df['model'],
@@ -535,7 +616,6 @@ def plot_performance_comparison(df):
         marker_color='indianred'
     ))
 
-    # Throughput comparison
     fig.add_trace(go.Bar(
         name='Throughput (tokens/s)',
         x=df['model'],
@@ -563,33 +643,38 @@ def plot_performance_comparison(df):
 # Model initialization
 # =====================================================
 
-def initialize_models():
-    """Initialize the models"""
+def initialize_default_models():
+    """Initialize the default models"""
     models = {}
 
     try:
-        models['phoenix_small'] = PHOENIXRetention(
+        # PHOENIX with Granite
+        models['phoenix_granite'] = PHOENIXRetention(
             d_model=512,
             d_state=256,
             num_layers=12,
-            device=DEVICE
+            device=DEVICE,
+            base_model_url=DEFAULT_MODEL
         )
 
-        models['phoenix_medium'] = PHOENIXRetention(
-            d_model=1024,
-            d_state=512,
-            num_layers=24,
-            device=DEVICE
+        # PHOENIX without a base model
+        models['phoenix_standalone'] = PHOENIXRetention(
+            d_model=512,
+            d_state=256,
+            num_layers=12,
+            device=DEVICE,
+            base_model_url=None
         )
 
-        models['brumby_baseline'] = BrumbyRetention(
+        # Transformer baseline
+        models['transformer_granite'] = TransformerBaseline(
             d_model=512,
             d_state=256,
-            power=2,
-            device=DEVICE
+            device=DEVICE,
+            base_model_url=DEFAULT_MODEL
         )
 
-        print("✅ Models initialized successfully")
+        print("✅ Default models initialized")
         return models
 
     except Exception as e:
@@ -599,28 +684,36 @@ def initialize_models():
 # Initialize database and models
 db = ExperimentDatabase(DB_PATH)
 vector_store = RetentionVectorStore(VECTOR_DB_PATH)
-MODELS = initialize_models()
+MODELS = initialize_default_models()
 
 # =====================================================
 # Gradio interface functions
 # =====================================================
 
 def run_retention_experiment(
-    model_type, input_text, sequence_length,
+    model_type, custom_model_url, input_text, sequence_length,
     power_mode, compression_level, use_hierarchical
 ):
     """Run a PHOENIX Retention experiment"""
     try:
         start_time = time.time()
 
-        if model_type not in MODELS:
-            return "❌ Model not found.", None, None
-
-        model = MODELS[model_type]
+        # Load the custom model if a URL is given
+        if custom_model_url and custom_model_url.strip():
+            model, error = load_custom_model(custom_model_url, "phoenix")
+            if error:
+                return f"❌ Model loading failed: {error}", None, None
+            model_name = f"phoenix_custom_{custom_model_url.split('/')[-1]}"
+        else:
+            if model_type not in MODELS:
+                return "❌ Model not found.", None, None
+            model = MODELS[model_type]
+            model_name = model_type
 
         # Experiment configuration
         config = {
-            'model_type': model_type,
+            'model_type': model_name,
+            'base_model_url': custom_model_url if custom_model_url else model.base_model_url,
            'sequence_length': sequence_length,
            'power_mode': power_mode,
            'compression_level': compression_level,
@@ -649,15 +742,18 @@ def run_retention_experiment(
         vector_store.add_retention_state(experiment_id, states, config)
 
         # Result text
+        base_model_info = f"**Base Model**: {config['base_model_url']}\n" if config.get('base_model_url') else ""
+
         result_text = f"""
 ## 🎯 Experiment Results (ID: {experiment_id})
 
 ### ⚙️ Settings
-- **Model**: {model_type}
-- **Sequence length**: {sequence_length} tokens
+- **Model**: {model_name}
+{base_model_info}- **Sequence length**: {sequence_length} tokens
 - **Power mode**: {power_mode}
 - **Compression level**: {compression_level}
 - **Hierarchical retention**: {"✅" if use_hierarchical else "❌"}
+- **Base model used**: {"✅" if states.get('base_model_used') else "❌"}
 
 ### 📊 Performance Metrics
 - **Execution time**: {elapsed_time:.3f}s
@@ -682,11 +778,12 @@ def run_retention_experiment(
     except Exception as e:
         return f"❌ Experiment failed: {str(e)}", None, None
 
-def compare_retention_methods(input_text, sequence_length, benchmark_tasks):
+def compare_retention_methods(custom_model_url, input_text, sequence_length, benchmark_tasks):
     """Compare models"""
     try:
         results = []
 
+        # Test the default models
         for model_name, model in MODELS.items():
             start_time = time.time()
 
@@ -705,6 +802,26 @@ def compare_retention_methods(input_text, sequence_length, benchmark_tasks):
                 'throughput': sequence_length / elapsed_time
             })
 
+        # Test the custom model
+        if custom_model_url and custom_model_url.strip():
+            custom_model, error = load_custom_model(custom_model_url, "phoenix")
+            if not error:
+                start_time = time.time()
+                x = torch.randn(1, sequence_length, custom_model.d_model).to(DEVICE)
+
+                with torch.no_grad():
+                    output, states = custom_model(x, return_states=True)
+
+                elapsed_time = time.time() - start_time
+                metrics = calculate_metrics(output, states)
+
+                results.append({
+                    'model': f"custom_{custom_model_url.split('/')[-1]}",
+                    'time': elapsed_time,
+                    'memory': metrics.get('memory_mb', 0),
+                    'throughput': sequence_length / elapsed_time
+                })
+
         df = pd.DataFrame(results)
         fig = plot_performance_comparison(df)
 
@@ -744,6 +861,7 @@ def search_experiments(query, top_k=10):
             search_text += f"""
 ### {i}. Experiment #{exp_id} (similarity: {score:.3f})
 - **Model**: {metadata.get('model_type', 'N/A')}
+- **Base Model**: {metadata.get('base_model_url', 'N/A')}
 - **Sequence length**: {metadata.get('sequence_length', 'N/A')}
 - **Time**: {metadata.get('timestamp', 'N/A')}
 ---
@@ -764,7 +882,6 @@ def view_experiment_history(limit=20):
 
         df = pd.DataFrame(experiments)
 
-        # Performance trend over time
         fig = px.line(
             df,
             x='timestamp',
@@ -776,7 +893,7 @@ def view_experiment_history(limit=20):
         history_text = f"""
 ## 📊 Experiment History ({len(df)} entries)
 
-{df[['id', 'model_type', 'sequence_length', 'elapsed_time', 'throughput', 'timestamp']].to_markdown(index=False)}
+{df[['id', 'model_type', 'base_model_url', 'sequence_length', 'elapsed_time', 'throughput', 'timestamp']].to_markdown(index=False)}
 """
 
         return history_text, fig
@@ -800,6 +917,10 @@ def get_database_statistics():
         for model, count in stats['by_model'].items():
             stats_text += f"- **{model}**: {count}\n"
 
+        stats_text += "\n### Experiments per base model\n"
+        for base_model, count in stats['by_base_model'].items():
+            stats_text += f"- **{base_model}**: {count}\n"
+
         return stats_text
 
     except Exception as e:
@@ -819,7 +940,8 @@ with gr.Blocks(
 
     **Post-Hierarchical Optimized Efficient Neural Infinite-conteXt**
 
-    A next-generation attention-free architecture research platform that surpasses Brumby
+    A next-generation attention-free architecture research platform
+    Base Model: **IBM Granite 4.0 H 350M** (or a user-specified model)
 
     ---
     """)
@@ -832,8 +954,15 @@ with gr.Blocks(
         with gr.Column(scale=1):
             model_select = gr.Dropdown(
                 choices=list(MODELS.keys()),
-                value='phoenix_small',
-                label="Select model"
+                value='phoenix_granite',
+                label="Select a default model"
+            )
+
+            custom_model_url = gr.Textbox(
+                label="🔗 Custom base model URL (optional)",
+                placeholder="e.g. ibm-granite/granite-4.0-h-350m or meta-llama/Llama-3.2-1B",
+                value="",
+                info="Enter a Hugging Face model URL to use that model as the base"
            )
 
            input_text = gr.Textbox(
@@ -875,7 +1004,7 @@ with gr.Blocks(
 
     run_btn.click(
         fn=run_retention_experiment,
-        inputs=[model_select, input_text, sequence_length,
+        inputs=[model_select, custom_model_url, input_text, sequence_length,
                 power_mode, compression_level, use_hierarchical],
         outputs=[result_output, states_plot, memory_plot]
     )
@@ -884,6 +1013,12 @@ with gr.Blocks(
 
    with gr.Tab("⚔️ Model Comparison"):
        with gr.Row():
            with gr.Column(scale=1):
+                compare_custom_url = gr.Textbox(
+                    label="🔗 Additional comparison model URL (optional)",
+                    placeholder="e.g. microsoft/phi-2",
+                    value=""
+                )
+
                compare_text = gr.Textbox(
                    label="Comparison text",
                    lines=5,
@@ -909,7 +1044,7 @@ with gr.Blocks(
 
    compare_btn.click(
        fn=compare_retention_methods,
-        inputs=[compare_text, compare_length, benchmark_tasks],
+        inputs=[compare_custom_url, compare_text, compare_length, benchmark_tasks],
        outputs=[compare_result, compare_plot]
    )
 
@@ -967,6 +1102,14 @@ with gr.Blocks(
    2. **Adaptive compression** - importance-based dynamic compression
    3. **Dynamic power** - automatic optimization based on the input
    4. **Parallel paths** - multiple strategies run concurrently
+    5. **Custom base** - any HF model supported
+
+    ### 📚 Recommended Base Models
+    - `ibm-granite/granite-4.0-h-350m` (default)
+    - `meta-llama/Llama-3.2-1B`
+    - `microsoft/phi-2`
+    - `Qwen/Qwen2.5-0.5B`
+    - `google/gemma-2-2b`
 
    **VIDraft AI Research Lab** | L40S GPU + Persistent Storage
    """)
 
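A note on the schema change above: `CREATE TABLE IF NOT EXISTS` only runs when the table is absent, so an `experiments` table created on the persistent `/data` volume before this commit will not pick up the new `base_model_url` column, and the widened 13-value `INSERT` would then fail. A minimal migration sketch (the helper name and placement are hypothetical, not part of this commit):

```python
import sqlite3

def migrate_add_base_model_url(db_path: str = "/data/phoenix_experiments.db") -> None:
    """Add base_model_url to an experiments table created before this commit."""
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk)
    columns = [row[1] for row in cursor.execute("PRAGMA table_info(experiments)")]
    if "base_model_url" not in columns:
        cursor.execute("ALTER TABLE experiments ADD COLUMN base_model_url TEXT")
        cursor.execute(
            "CREATE INDEX IF NOT EXISTS idx_base_model ON experiments(base_model_url)"
        )
        conn.commit()
    conn.close()
```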
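For reference, the new custom-base-model path can be exercised in isolation roughly as follows; a sketch assuming the definitions above are importable as a module (`from app import ...`), mirroring what `compare_retention_methods` does internally:

```python
import torch

from app import DEVICE, load_custom_model  # assumes app.py imports cleanly as a module

model, error = load_custom_model("ibm-granite/granite-4.0-h-350m", "phoenix")
if error:
    raise RuntimeError(f"model load failed: {error}")

# model.d_model is overwritten with the base model's hidden_size at load time,
# so this random tensor has the width that inputs_embeds expects downstream.
x = torch.randn(1, 128, model.d_model).to(DEVICE)
with torch.no_grad():
    output, states = model(x, return_states=True)

print(states["base_model_used"])  # True only when the HF base model actually loaded
```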