anemll committed
Commit f32a712 · verified · 1 Parent(s): 86baea7

Fixed GIL issue


Race condition between CoreML inference and the causal_mask update.

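For context, the pattern this commit settles on is to build the causal mask once, before any CoreML predict() call, and to share it read-only: run_prefill and generate_next_token only take slices of the pre-built tensor instead of rebuilding or updating the mask while a prediction may still be in flight on another thread. A minimal sketch of that setup follows; make_causal_mask, initialize_causal_mask, run_prefill and generate_next_token are the names used in chat.py, while the -inf fill inside make_causal_mask and the demo values under __main__ are illustrative assumptions.

```python
import numpy as np
import torch

def make_causal_mask(length, start):
    """Additive attention mask: 0 where attention is allowed, -inf elsewhere."""
    # The -inf fill is an assumption; only the final masking line appears in this diff.
    mask = np.full((1, 1, length, length), -np.inf, dtype=np.float16)
    row_indices = np.arange(length).reshape(length, 1)
    col_indices = np.arange(length).reshape(1, length)
    mask[:, :, col_indices <= (row_indices + start)] = 0
    return mask

def initialize_causal_mask(context_length):
    """Build the mask once, up front; it is never mutated afterwards."""
    return torch.tensor(make_causal_mask(context_length, 0), dtype=torch.float16)

if __name__ == "__main__":
    context_length, batch_size, pos = 512, 64, 10    # illustrative values
    causal_mask = initialize_causal_mask(context_length)
    prefill_mask = causal_mask[:, :, :batch_size, :]  # slice passed during prefill batches
    decode_mask = causal_mask[:, :, pos - 1:pos, :]   # slice passed for single-token decode
    print(prefill_mask.shape, decode_mask.shape)      # [1, 1, 64, 512] and [1, 1, 1, 512]
```

The slices shown above mirror the ones visible in the diff: causal_mask[:, :, :batch_size, :] in run_prefill and causal_mask[:, :, pos-1:pos, :] in generate_next_token.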
Files changed (1)
  1. chat.py +214 -313
chat.py CHANGED
@@ -26,10 +26,8 @@ DARK_BLUE = "\033[34m"
26
  LIGHT_GREEN = "\033[92m"
27
  RESET_COLOR = "\033[0m"
28
 
29
- # Add at the top with other constants
30
  WARMUP_TOKEN_LIMIT = 10 # Maximum tokens to generate during warmup
31
- THINKING_MODE = False
32
- THINKING_PROMPT = """You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem."""
33
 
34
  class TokenPrinter:
35
  """Handles background printing of generated tokens."""
@@ -42,12 +40,9 @@ class TokenPrinter:
42
  self.lock = threading.Lock()
43
  self.thinking = True # Track if we're still in thinking mode
44
  self.decoding_buffer = [] # Buffer for token IDs
45
- # Timing and stats tracking
46
  self.start_time = time.time()
47
  self.token_count = 0
48
- self.prefill_time = 0
49
- self.inference_time = 0
50
- self.context_pos = 0
51
  self.start()
52
 
53
  def start(self):
@@ -108,15 +103,15 @@ class TokenPrinter:
108
  self.thread.join(timeout=1.0)
109
  except Exception:
110
  pass
111
- print(RESET_COLOR) # Reset color at the end
112
  return self.buffer
113
 
114
- def set_timing(self, prefill_time, inference_time, context_pos):
115
- """Set timing information."""
116
- self.prefill_time = prefill_time
117
- self.inference_time = inference_time
118
- self.context_pos = context_pos
119
-
120
  def parse_model_path(path):
121
  """Parse model path and return full path with .mlmodelc or .mlpackage extension."""
122
  path = Path(path)
@@ -193,89 +188,6 @@ def load_model(path, function_name=None):
193
  print("\nTry using the .mlpackage version instead, or recompile the model.")
194
  raise
195
 
196
- def parse_args():
197
- parser = argparse.ArgumentParser(description='Full Chat with CoreML LLaMA with context window shifting, gil resolved (c) 2025 Anemll')
198
-
199
- # Add meta.yaml option
200
- parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
201
-
202
- # Add existing arguments
203
- parser.add_argument('--d', '--dir', type=str, default='.',
204
- help='Directory containing model files (default: current directory)')
205
- parser.add_argument('--embed', type=str, required=False,
206
- help='Path to embeddings model (relative to --dir)')
207
- parser.add_argument('--ffn', type=str, required=False,
208
- help='Path to FFN model (can be chunked, relative to --dir)')
209
- parser.add_argument('--lmhead', type=str, required=False,
210
- help='Path to LM head model (relative to --dir)')
211
- parser.add_argument('--tokenizer', type=str, required=False,
212
- help='Path to tokenizer')
213
-
214
- # Add new argument for auto-generation
215
- parser.add_argument('--prompt', type=str,
216
- help='If specified, run once with this prompt and exit')
217
-
218
- # Add no-warmup flag
219
- parser.add_argument('--nw', action='store_true',
220
- help='Skip warmup phase')
221
-
222
- # Model configuration
223
- parser.add_argument('--context-length', type=int,
224
- help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
225
- parser.add_argument('--batch-size', type=int,
226
- help='Batch size for prefill (default: 64)')
227
-
228
- args = parser.parse_args()
229
-
230
- # If meta.yaml is provided, load parameters from it
231
- if args.meta:
232
- try:
233
- with open(args.meta, 'r') as f:
234
- meta = yaml.safe_load(f)
235
- params = meta['model_info']['parameters']
236
-
237
- # Set model directory to meta.yaml directory if not specified
238
- if not args.d or args.d == '.':
239
- args.d = str(Path(args.meta).parent)
240
-
241
- # Build model paths based on parameters
242
- prefix = params.get('model_prefix', 'llama') # Default to 'llama' if not specified
243
- lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
244
- lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
245
- num_chunks = int(params['num_chunks'])
246
-
247
- # Set model paths if not specified
248
- if not args.embed:
249
- args.embed = f'{prefix}_embeddings'
250
- if not args.lmhead:
251
- args.lmhead = f'{prefix}_lm_head{lut_lmhead}'
252
- if not args.ffn:
253
- args.ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
254
- if not args.tokenizer:
255
- args.tokenizer = args.d
256
-
257
- # Set other parameters if not overridden by command line
258
- if args.context_length is None:
259
- args.context_length = int(params['context_length'])
260
- if args.batch_size is None:
261
- args.batch_size = int(params['batch_size'])
262
- args.num_chunks = num_chunks
263
-
264
- print(f"\nLoaded parameters from {args.meta}:")
265
- print(f" Context Length: {args.context_length}")
266
- print(f" Batch Size: {args.batch_size}")
267
- print(f" Num Chunks: {args.num_chunks}")
268
- print(f" Models Directory: {args.d}")
269
- print(f" Embeddings: {args.embed}")
270
- print(f" LM Head: {args.lmhead}")
271
- print(f" FFN: {args.ffn}")
272
-
273
- except Exception as e:
274
- print(f"\nError loading meta.yaml: {str(e)}")
275
- sys.exit(1)
276
-
277
- return args
278
-
279
  def load_metadata(model,args):
280
  # Extract metadata and config parameters
281
  metadata = {}
@@ -474,74 +386,84 @@ def make_causal_mask(length, start):
474
  mask[:, :, col_indices <= (row_indices + start)] = 0
475
  return mask
476
 
477
- def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state, causal_mask):
478
  """Run prefill on the input sequence."""
479
- #print(f"[DEBUG] Running prefill from 0 to {current_pos}")
 
 
 
480
 
481
  # Process in batches
482
  batch_pos = 0
483
- while batch_pos < current_pos:
484
- batch_end = min(batch_pos + batch_size, current_pos)
485
  current_batch_size = batch_end - batch_pos
486
 
487
- #print(f"[DEBUG] Prefill batch {batch_pos}-{batch_end} (size={current_batch_size})")
488
-
489
  # Get current batch
490
  batch_input = input_ids[:, batch_pos:batch_end]
491
 
492
- # Pad to full batch size
493
  batch_input = F.pad(
494
  batch_input,
495
  (0, batch_size - current_batch_size),
496
  value=0
497
  )
498
 
499
- # Generate position IDs for this batch
500
- position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
501
-
502
- # Use the pre-initialized causal mask and extract the batch portion
503
- batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
504
 
505
  # Run embeddings
506
  hidden_states = torch.from_numpy(
507
  embed_model.predict({'input_ids': batch_input.numpy()})['hidden_states']
508
  )
509
 
510
- # Run through FFN chunks
511
  for ffn_model in ffn_models:
512
  if isinstance(ffn_model, dict):
513
  inputs = {
514
- 'hidden_states': hidden_states.numpy(),
515
- 'position_ids': position_ids.numpy(),
516
- 'causal_mask': batch_causal_mask.numpy(),
517
- 'current_pos': np.array([batch_pos], dtype=np.int32)
518
  }
519
  output = ffn_model['prefill'].predict(inputs, state)
520
  hidden_states = torch.from_numpy(output['output_hidden_states'])
521
 
522
  batch_pos = batch_end
523
 
524
- return torch.tensor([current_pos], dtype=torch.int32)
525
 
526
- def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state, causal_mask, temperature=0.0):
527
  """Generate the next token."""
528
  # Get current token
529
- current_token = input_ids[:, pos-1:pos]
530
 
531
  # Run embeddings
532
  hidden_states = torch.from_numpy(
533
  embed_model.predict({'input_ids': current_token.numpy()})['hidden_states']
534
- )
535
 
536
  # Create masks
537
  update_mask = torch.zeros((1, 1, context_length, 1), dtype=torch.float16)
538
  update_mask[0, 0, pos-1, 0] = 1.0
539
- position_ids = torch.tensor([pos-1], dtype=torch.int32)
540
 
541
- # Use the pre-initialized causal mask and extract the single position portion
542
- single_causal_mask = causal_mask[:, :, pos-1:pos, :]
 
 
 
 
543
 
544
- # Run through FFN chunks
545
  for ffn_model in ffn_models:
546
  if isinstance(ffn_model, dict):
547
  inputs = {
@@ -554,19 +476,25 @@ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, c
554
  output = ffn_model['infer'].predict(inputs, state)
555
  hidden_states = torch.from_numpy(output['output_hidden_states'])
556
 
557
- # Run LM head and get next token
558
  lm_output = lmhead_model.predict({'hidden_states': hidden_states.numpy()})
 
 
559
 
 
560
  if 'logits1' in lm_output:
 
561
  logits_parts = []
562
  for i in range(1, 9):
563
  key = f'logits{i}'
564
  if key in lm_output:
565
  logits_parts.append(torch.from_numpy(lm_output[key]))
566
- logits = torch.cat(logits_parts, dim=-1)
567
  else:
 
568
  logits = torch.from_numpy(lm_output['output_logits'])
569
 
 
570
  if temperature > 0:
571
  logits = logits / temperature
572
  probs = F.softmax(logits[0, -1, :], dim=-1)
@@ -588,93 +516,36 @@ def create_unified_state(ffn_models, context_length):
588
  print("\nCreated unified transformer state")
589
  return state
590
 
591
- def initialize_causal_mask(context_length):
592
- """Initialize causal mask for transformer attention."""
593
- causal_mask = make_causal_mask(context_length, 0)
594
- causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
595
- print(f"\nInitialized causal mask for context length {context_length}")
596
- return causal_mask
597
-
598
- def get_user_input():
599
- """Get input from user, handling special key combinations."""
600
- global THINKING_MODE
601
- try:
602
- import termios
603
- import tty
604
- import sys
605
-
606
- def _getch():
607
- fd = sys.stdin.fileno()
608
- old_settings = termios.tcgetattr(fd)
609
- try:
610
- tty.setraw(sys.stdin.fileno())
611
- ch = sys.stdin.read(1)
612
- finally:
613
- termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
614
- return ch
615
-
616
- buffer = []
617
- while True:
618
- char = _getch()
619
-
620
- # Debug: print the character code
621
- print(f"\nKey pressed: {repr(char)} (hex: {hex(ord(char))})")
622
-
623
- # Check for Enter key
624
- if char == '\r' or char == '\n':
625
- print() # Move to next line
626
- input_text = ''.join(buffer)
627
- # Check if the command is /t
628
- if input_text == '/t':
629
- THINKING_MODE = not THINKING_MODE
630
- print(f"Thinking mode {'ON' if THINKING_MODE else 'OFF'}")
631
- buffer = [] # Clear buffer
632
- print(f"\n{LIGHT_GREEN}You{' (thinking)' if THINKING_MODE else ''}:{RESET_COLOR}", end=' ', flush=True)
633
- continue
634
- return input_text
635
-
636
- # Handle backspace
637
- if char == '\x7f': # backspace
638
- if buffer:
639
- buffer.pop()
640
- sys.stdout.write('\b \b') # Erase character
641
- sys.stdout.flush()
642
- continue
643
-
644
- # Handle Ctrl-C
645
- if char == '\x03': # Ctrl-C
646
- print("^C")
647
- raise KeyboardInterrupt
648
-
649
- # Print character and add to buffer
650
- sys.stdout.write(char)
651
- sys.stdout.flush()
652
- buffer.append(char)
653
-
654
- except ImportError:
655
- # Fallback for systems without termios
656
- return input("> ")
657
-
658
- def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask, auto_prompt=None, warmup=False):
659
  """Interactive chat loop."""
660
- global THINKING_MODE
661
  context_length = metadata.get('context_length')
662
  batch_size = metadata.get('batch_size', 64)
663
 
664
  if not warmup:
665
  print(f"\nUsing context length: {context_length}")
666
  print("\nStarting chat session. Press Ctrl+D to exit.")
667
- print("Type your message and press Enter to chat. Use /t to toggle thinking mode.")
668
- print(f"Thinking mode is {'ON' if THINKING_MODE else 'OFF'}")
669
 
670
- # Keep track of conversation history
671
  conversation = []
672
 
673
  try:
674
  while True:
675
  try:
676
  if not warmup:
677
- print(f"\n{LIGHT_GREEN}You{' (thinking)' if THINKING_MODE else ''}:{RESET_COLOR}", end=' ', flush=True)
678
  if auto_prompt is not None:
679
  user_input = auto_prompt
680
  if not warmup:
@@ -685,69 +556,41 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
685
  if not warmup:
686
  print("\nExiting chat...")
687
  break
688
-
689
  if not user_input:
690
  continue
691
-
692
- # Handle /t command
693
- if user_input == "/t":
694
- THINKING_MODE = not THINKING_MODE
695
- print(f"Thinking mode {'ON' if THINKING_MODE else 'OFF'}")
696
- continue
697
-
698
- # Add user message to conversation
699
- conversation.append({"role": "user", "content": user_input})
700
 
701
- # Format using chat template with full history
702
- if THINKING_MODE:
703
- # Add thinking prompt to system message
704
- conversation_with_thinking = [{"role": "system", "content": THINKING_PROMPT}] + conversation
705
- base_input_ids = tokenizer.apply_chat_template(
706
- conversation_with_thinking,
707
  return_tensors="pt",
708
  add_generation_prompt=True
709
  ).to(torch.int32)
710
  else:
711
- base_input_ids = tokenizer.apply_chat_template(
712
- conversation,
 
 
713
  return_tensors="pt",
714
- add_generation_prompt=True
715
- ).to(torch.int32)
716
 
717
- # Check if we need to trim history
718
- while base_input_ids.size(1) > context_length - 100: # Leave room for response
719
- # Remove oldest message pair (user + assistant)
720
- if len(conversation) > 2:
721
- conversation = conversation[2:] # Remove oldest pair
722
- base_input_ids = tokenizer.apply_chat_template(
723
- conversation,
724
- return_tensors="pt",
725
- add_generation_prompt=True
726
- ).to(torch.int32)
727
- else:
728
- # If only current message remains and still too long, truncate
729
- base_input_ids = base_input_ids[:, -context_length//2:]
730
- break
731
-
732
- context_pos = base_input_ids.size(1)
733
-
734
- # Pad sequence to context_size
735
- input_ids = F.pad(
736
- base_input_ids,
737
- (0, context_length - context_pos),
738
- value=0
739
- )
740
 
741
  if not warmup:
742
  print(f"\n{LIGHT_BLUE}Assistant:{RESET_COLOR}", end=' ', flush=True)
743
 
744
- # Initialize token printer and collect response
745
  token_printer = TokenPrinter(tokenizer)
746
- response_tokens = []
747
- generation_start_time = time.time()
748
 
749
  try:
750
- # Run prefill on entire context
 
 
 
751
  current_pos = run_prefill(
752
  embed_model,
753
  ffn_models,
@@ -758,51 +601,20 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
758
  state,
759
  causal_mask
760
  )
761
- #print(f"\n[DEBUG] After initial prefill - current_pos: {current_pos}")
762
 
763
- # Generation loop
764
  pos = context_pos
765
- tokens_generated = 0
766
- inference_start = time.time() # Start inference timing
767
 
768
- while True:
769
- # Check if we need to shift window
770
- if pos >= context_length - 2:
771
- # Calculate shift to maintain full batches
772
- batch_size = metadata.get('batch_size', 64)
773
- # Calculate max batches that fit in context
774
- max_batches = context_length // batch_size
775
- desired_batches = max(1, max_batches - 2) # Leave room for new tokens
776
- new_size = min(desired_batches * batch_size, context_length - batch_size)
777
-
778
- # Create shifted input_ids
779
- tmp = torch.zeros((1, context_length), dtype=torch.int32)
780
- tmp[:,0:new_size] = input_ids[:,pos-new_size:pos]
781
- input_ids = tmp
782
-
783
- # Reset state and run prefill
784
- # keep the same state
785
- #state = create_unified_state(ffn_models, context_length)
786
- current_pos = run_prefill(
787
- embed_model,
788
- ffn_models,
789
- input_ids,
790
- new_size, # Prefill the entire shifted content
791
- context_length,
792
- batch_size,
793
- state,
794
- causal_mask
795
- )
796
-
797
- # Start generating from the next position
798
- pos = new_size # Don't back up, continue from where we left off
799
-
800
- #print(f"\n[DEBUG] After shift - next token will be at pos {pos}")
801
- #print(f"[DEBUG] Context before next token: {tokenizer.decode(input_ids[0, pos-40:pos])}")
802
-
803
- window_shifted = True
804
-
805
- # Generate next token
806
  next_token = generate_next_token(
807
  embed_model,
808
  ffn_models,
@@ -814,54 +626,143 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
814
  causal_mask
815
  )
816
 
817
- # Add token
818
- input_ids[0, pos] = next_token
819
  if not warmup:
820
  token_printer.add_token(next_token)
821
  token_printer.drain_buffer()
822
- response_tokens.append(next_token)
823
 
824
  pos += 1
825
  tokens_generated += 1
 
826
 
827
- # In warmup mode, limit tokens
828
  if warmup and tokens_generated >= WARMUP_TOKEN_LIMIT:
829
  break
830
-
831
  if next_token == tokenizer.eos_token_id:
832
  break
833
 
834
- inference_time = time.time() - inference_start # Calculate inference time
835
-
836
- # Add assistant response to conversation
837
- response_text = token_printer.stop()
838
- conversation.append({"role": "assistant", "content": response_text})
839
 
840
- # Print stats only if not in warmup
841
  if not warmup:
842
- total_time = time.time() - generation_start_time
843
- prefill_time = total_time - inference_time
844
- inference_tokens_per_sec = len(response_tokens) / inference_time if inference_time > 0 else 0
845
- prefill_ms = prefill_time * 1000
846
- prefill_tokens_per_sec = context_pos / prefill_time if prefill_time > 0 else 0
847
- print(f"{DARK_BLUE}{inference_tokens_per_sec:.1f} t/s, "
848
- f"TTFT: {prefill_ms:.1f}ms ({prefill_tokens_per_sec:.1f} t/s), "
849
- f"{len(response_tokens)} tokens{RESET_COLOR}")
 
850
 
 
851
  if auto_prompt is not None:
852
  break
853
 
854
  except KeyboardInterrupt:
855
- if not warmup:
856
- print("\nGeneration interrupted")
857
  token_printer.stop()
858
  continue
859
 
860
  except Exception as e:
861
- if not warmup:
862
- print(f"\nError in chat loop: {str(e)}")
863
- import traceback
864
- traceback.print_exc()
865
 
866
  def main():
867
  args = parse_args()
@@ -926,7 +827,7 @@ def main():
926
  lmhead_model=lmhead_model,
927
  tokenizer=tokenizer,
928
  metadata=metadata,
929
- state=state, # Pass the state
930
  causal_mask=causal_mask, # Pass the causal mask
931
  warmup=True,
932
  auto_prompt="who are you?"
@@ -939,7 +840,7 @@ def main():
939
  lmhead_model=lmhead_model,
940
  tokenizer=tokenizer,
941
  metadata=metadata,
942
- state=state, # Pass the state
943
  causal_mask=causal_mask, # Pass the causal mask
944
  warmup=False,
945
  auto_prompt=args.prompt
 
26
  LIGHT_GREEN = "\033[92m"
27
  RESET_COLOR = "\033[0m"
28
 
29
+ # Add at top with other constants
30
  WARMUP_TOKEN_LIMIT = 10 # Maximum tokens to generate during warmup
 
 
31
 
32
  class TokenPrinter:
33
  """Handles background printing of generated tokens."""
 
40
  self.lock = threading.Lock()
41
  self.thinking = True # Track if we're still in thinking mode
42
  self.decoding_buffer = [] # Buffer for token IDs
43
+ # Add token counting and timing
44
  self.start_time = time.time()
45
  self.token_count = 0
 
 
 
46
  self.start()
47
 
48
  def start(self):
 
103
  self.thread.join(timeout=1.0)
104
  except Exception:
105
  pass
106
+ # Calculate and print tokens/s with shorter format in blue
107
+ elapsed = time.time() - self.start_time
108
+ if elapsed > 0 and self.token_count > 0:
109
+ tokens_per_sec = self.token_count / elapsed
110
+ print(f"\n{DARK_BLUE}{tokens_per_sec:.1f} t/s{RESET_COLOR}")
111
+ else:
112
+ print(RESET_COLOR) # Reset color at the end
113
  return self.buffer
114
 
 
 
 
 
 
 
115
  def parse_model_path(path):
116
  """Parse model path and return full path with .mlmodelc or .mlpackage extension."""
117
  path = Path(path)
 
188
  print("\nTry using the .mlpackage version instead, or recompile the model.")
189
  raise
190
 
191
  def load_metadata(model,args):
192
  # Extract metadata and config parameters
193
  metadata = {}
 
386
  mask[:, :, col_indices <= (row_indices + start)] = 0
387
  return mask
388
 
389
+ def initialize_causal_mask(context_length):
390
+ """Initialize causal mask for transformer attention."""
391
+ causal_mask = make_causal_mask(context_length, 0)
392
+ causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
393
+ print(f"\nInitialized causal mask for context length {context_length}")
394
+ return causal_mask
395
+
396
+ def run_prefill(embed_model, ffn_models, input_ids, context_pos, context_length, batch_size=64, state=None, causal_mask=None):
397
  """Run prefill on the input sequence."""
398
+ # Use provided causal mask or create one if not provided
399
+ if causal_mask is None:
400
+ causal_mask = make_causal_mask(context_length, 0)
401
+ causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
402
 
403
  # Process in batches
404
  batch_pos = 0
405
+ while batch_pos < context_pos:
406
+ batch_end = min(batch_pos + batch_size, context_pos)
407
  current_batch_size = batch_end - batch_pos
408
 
 
 
409
  # Get current batch
410
  batch_input = input_ids[:, batch_pos:batch_end]
411
 
412
+ # Always pad to full batch size for prefill
413
  batch_input = F.pad(
414
  batch_input,
415
  (0, batch_size - current_batch_size),
416
  value=0
417
  )
418
 
419
+ # Generate position IDs for full batch size
420
+ position_ids = torch.arange(batch_size, dtype=torch.int32) # Changed: Always use full batch size
421
+ batch_causal_mask = causal_mask[:, :, :batch_size, :] # Changed: Use full batch size
 
 
422
 
423
  # Run embeddings
424
  hidden_states = torch.from_numpy(
425
  embed_model.predict({'input_ids': batch_input.numpy()})['hidden_states']
426
  )
427
 
428
+ # Run through FFN chunks with state
429
  for ffn_model in ffn_models:
430
  if isinstance(ffn_model, dict):
431
  inputs = {
432
+ 'hidden_states': hidden_states.numpy(), # [1, 64, hidden_size]
433
+ 'position_ids': position_ids.numpy(), # [64]
434
+ 'causal_mask': batch_causal_mask.numpy(), # [1, 1, 64, context_length]
435
+ 'current_pos': np.array([batch_pos], dtype=np.int32) # [1]
436
  }
437
  output = ffn_model['prefill'].predict(inputs, state)
438
  hidden_states = torch.from_numpy(output['output_hidden_states'])
439
 
440
  batch_pos = batch_end
441
 
442
+ return torch.tensor([context_pos], dtype=torch.int32)
443
 
444
+ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state=None, causal_mask=None, temperature=0.0):
445
  """Generate the next token."""
446
  # Get current token
447
+ current_token = input_ids[:, pos-1:pos] # [1, 1]
448
 
449
  # Run embeddings
450
  hidden_states = torch.from_numpy(
451
  embed_model.predict({'input_ids': current_token.numpy()})['hidden_states']
452
+ ) # [1, 1, hidden_size]
453
 
454
  # Create masks
455
  update_mask = torch.zeros((1, 1, context_length, 1), dtype=torch.float16)
456
  update_mask[0, 0, pos-1, 0] = 1.0
457
+ position_ids = torch.tensor([pos-1], dtype=torch.int32) # [1]
458
 
459
+ # Use provided causal mask or create one if not provided
460
+ if causal_mask is None:
461
+ causal_mask_data = make_causal_mask(context_length, 0)
462
+ single_causal_mask = torch.tensor(causal_mask_data[:, :, pos-1:pos, :], dtype=torch.float16) # [1, 1, 1, context_length]
463
+ else:
464
+ single_causal_mask = causal_mask[:, :, pos-1:pos, :]
465
 
466
+ # Run through FFN chunks with state
467
  for ffn_model in ffn_models:
468
  if isinstance(ffn_model, dict):
469
  inputs = {
 
476
  output = ffn_model['infer'].predict(inputs, state)
477
  hidden_states = torch.from_numpy(output['output_hidden_states'])
478
 
479
+ # Run LM head
480
  lm_output = lmhead_model.predict({'hidden_states': hidden_states.numpy()})
481
+ # Debug print
482
+ #print("\nLM Head output keys:", list(lm_output.keys()))
483
 
484
+ # Combine logits1-8 if they exist
485
  if 'logits1' in lm_output:
486
+ # Concatenate all logits parts
487
  logits_parts = []
488
  for i in range(1, 9):
489
  key = f'logits{i}'
490
  if key in lm_output:
491
  logits_parts.append(torch.from_numpy(lm_output[key]))
492
+ logits = torch.cat(logits_parts, dim=-1) # Concatenate along vocab dimension
493
  else:
494
+ # Try output_logits as fallback
495
  logits = torch.from_numpy(lm_output['output_logits'])
496
 
497
+ # Apply temperature and sample
498
  if temperature > 0:
499
  logits = logits / temperature
500
  probs = F.softmax(logits[0, -1, :], dim=-1)
 
516
  print("\nCreated unified transformer state")
517
  return state
518
 
519
+ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask=None, auto_prompt=None, warmup=False):
520
  """Interactive chat loop."""
 
521
  context_length = metadata.get('context_length')
522
  batch_size = metadata.get('batch_size', 64)
523
 
524
  if not warmup:
525
  print(f"\nUsing context length: {context_length}")
526
  print("\nStarting chat session. Press Ctrl+D to exit.")
527
+ print("Type your message and press Enter to chat.")
528
+
529
+ # Check if tokenizer has chat template and if it works
530
+ has_chat_template = False
531
+ try:
532
+ # Test if chat template works
533
+ test_messages = [{"role": "user", "content": "test"}]
534
+ tokenizer.apply_chat_template(test_messages, return_tensors="pt")
535
+ has_chat_template = True
536
+ if not warmup:
537
+ print("\nUsing chat template for prompts")
538
+ except:
539
+ if not warmup:
540
+ print("\nUsing manual formatting for prompts")
541
 
 
542
  conversation = []
543
 
544
  try:
545
  while True:
546
  try:
547
  if not warmup:
548
+ print(f"\n{LIGHT_GREEN}You:{RESET_COLOR}", end=' ', flush=True)
549
  if auto_prompt is not None:
550
  user_input = auto_prompt
551
  if not warmup:
 
556
  if not warmup:
557
  print("\nExiting chat...")
558
  break
559
+
560
  if not user_input:
561
  continue
562
 
563
+ # Format prompt based on tokenizer capabilities
564
+ if has_chat_template:
565
+ messages = [{"role": "user", "content": user_input}]
566
+ input_ids = tokenizer.apply_chat_template(
567
+ messages,
 
568
  return_tensors="pt",
569
  add_generation_prompt=True
570
  ).to(torch.int32)
571
  else:
572
+ # Manual formatting for Llama models without chat template
573
+ formatted_prompt = f"[INST] {user_input} [/INST]"
574
+ input_ids = tokenizer(
575
+ formatted_prompt,
576
  return_tensors="pt",
577
+ add_special_tokens=True
578
+ ).input_ids.to(torch.int32)
579
 
580
+ context_pos = input_ids.size(1)
581
 
582
  if not warmup:
583
  print(f"\n{LIGHT_BLUE}Assistant:{RESET_COLOR}", end=' ', flush=True)
584
 
585
+ # Initialize token printer
586
  token_printer = TokenPrinter(tokenizer)
587
+ tokens_generated = 0 # Track number of tokens
 
588
 
589
  try:
590
+ # Start prefill timing
591
+ prefill_start = time.time()
592
+
593
+ # Run prefill with state and causal mask
594
  current_pos = run_prefill(
595
  embed_model,
596
  ffn_models,
 
601
  state,
602
  causal_mask
603
  )
 
604
 
605
+ # Calculate prefill timing
606
+ prefill_time = time.time() - prefill_start
607
+ prefill_tokens = context_pos # Number of tokens in input
608
+ prefill_tokens_per_sec = prefill_tokens / prefill_time if prefill_time > 0 else 0
609
+
610
+ # Generation loop with state
611
+ input_ids = input_ids
612
  pos = context_pos
613
+ inference_start = time.time()
614
+ inference_tokens = 0
615
 
616
+ while pos < context_length - 1:
617
+ # Generate next token with causal mask
618
  next_token = generate_next_token(
619
  embed_model,
620
  ffn_models,
 
626
  causal_mask
627
  )
628
 
629
+ # Add token to sequence
630
+ if pos < input_ids.size(1):
631
+ input_ids[0, pos] = next_token
632
+ else:
633
+ input_ids = torch.cat([
634
+ input_ids,
635
+ torch.tensor([[next_token]], dtype=torch.int32)
636
+ ], dim=1)
637
+
638
+ # Add to printer only if not in warmup
639
  if not warmup:
640
  token_printer.add_token(next_token)
641
  token_printer.drain_buffer()
 
642
 
643
  pos += 1
644
  tokens_generated += 1
645
+ inference_tokens += 1
646
 
647
+ # Check limits
648
  if warmup and tokens_generated >= WARMUP_TOKEN_LIMIT:
649
  break
650
+
651
  if next_token == tokenizer.eos_token_id:
652
  break
653
 
654
+ # Calculate inference timing
655
+ inference_time = time.time() - inference_start
656
+ inference_tokens_per_sec = inference_tokens / inference_time if inference_time > 0 else 0
 
 
657
 
658
+ # Get final response and add to conversation
659
  if not warmup:
660
+ response = token_printer.stop()
661
+ # Print timing stats
662
+ prefill_ms = prefill_time * 1000 # Convert to milliseconds
663
+ print(f"\nPrefill: {prefill_ms:.1f}ms ({prefill_tokens_per_sec:.1f} t/s)")
664
+ print(f"Inference: {inference_tokens_per_sec:.1f} t/s")
665
+ print(f"Total: Generated {tokens_generated} tokens in {prefill_time + inference_time:.2f}s")
666
+ conversation.append({"role": "assistant", "content": response})
667
+ else:
668
+ token_printer.stop() # Clean up without printing stats
669
 
670
+ # Exit after one response in auto_prompt mode
671
  if auto_prompt is not None:
672
  break
673
 
674
  except KeyboardInterrupt:
675
+ print("\nGeneration interrupted")
 
676
  token_printer.stop()
677
  continue
678
 
679
  except Exception as e:
680
+ print(f"\nError in chat loop: {str(e)}")
681
+ import traceback
682
+ traceback.print_exc()
683
+
684
+ def parse_args():
685
+ parser = argparse.ArgumentParser(description='Chat with CoreML LLaMA, gil resolved (c) 2025 Anemll')
686
+
687
+ # Add meta.yaml option
688
+ parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
689
+
690
+ # Model paths
691
+ parser.add_argument('--d', '--dir', type=str, default='.',
692
+ help='Directory containing model files (default: current directory)')
693
+ parser.add_argument('--embed', type=str, required=False,
694
+ help='Path to embeddings model (relative to --dir)')
695
+ parser.add_argument('--ffn', type=str, required=False,
696
+ help='Path to FFN model (can be chunked, relative to --dir)')
697
+ parser.add_argument('--lmhead', type=str, required=False,
698
+ help='Path to LM head model (relative to --dir)')
699
+ parser.add_argument('--tokenizer', type=str, required=False,
700
+ help='Path to tokenizer')
701
+
702
+ # Add new argument for auto-generation
703
+ parser.add_argument('--prompt', type=str,
704
+ help='If specified, run once with this prompt and exit')
705
+
706
+ # Add no-warmup flag
707
+ parser.add_argument('--nw', action='store_true',
708
+ help='Skip warmup phase')
709
+
710
+ # Model configuration
711
+ parser.add_argument('--context-length', type=int,
712
+ help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
713
+ parser.add_argument('--batch-size', type=int,
714
+ help='Batch size for prefill (default: 64)')
715
+
716
+ args = parser.parse_args()
717
+
718
+ # If meta.yaml is provided, load parameters from it
719
+ if args.meta:
720
+ try:
721
+ with open(args.meta, 'r') as f:
722
+ meta = yaml.safe_load(f)
723
+ params = meta['model_info']['parameters']
724
+
725
+ # Set model directory to meta.yaml directory if not specified
726
+ if not args.d or args.d == '.':
727
+ args.d = str(Path(args.meta).parent)
728
+
729
+ # Build model paths based on parameters
730
+ prefix = params.get('model_prefix', 'llama') # Default to 'llama' if not specified
731
+ lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
732
+ lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
733
+ num_chunks = int(params['num_chunks'])
734
+
735
+ # Set model paths if not specified
736
+ if not args.embed:
737
+ args.embed = f'{prefix}_embeddings'
738
+ if not args.lmhead:
739
+ args.lmhead = f'{prefix}_lm_head{lut_lmhead}'
740
+ if not args.ffn:
741
+ args.ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
742
+ if not args.tokenizer:
743
+ args.tokenizer = args.d
744
+
745
+ # Set other parameters if not overridden by command line
746
+ if args.context_length is None:
747
+ args.context_length = int(params['context_length'])
748
+ if args.batch_size is None:
749
+ args.batch_size = int(params['batch_size'])
750
+ args.num_chunks = num_chunks
751
+
752
+ print(f"\nLoaded parameters from {args.meta}:")
753
+ print(f" Context Length: {args.context_length}")
754
+ print(f" Batch Size: {args.batch_size}")
755
+ print(f" Num Chunks: {args.num_chunks}")
756
+ print(f" Models Directory: {args.d}")
757
+ print(f" Embeddings: {args.embed}")
758
+ print(f" LM Head: {args.lmhead}")
759
+ print(f" FFN: {args.ffn}")
760
+
761
+ except Exception as e:
762
+ print(f"\nError loading meta.yaml: {str(e)}")
763
+ sys.exit(1)
764
+
765
+ return args
766
 
767
  def main():
768
  args = parse_args()
 
827
  lmhead_model=lmhead_model,
828
  tokenizer=tokenizer,
829
  metadata=metadata,
830
+ state=state,
831
  causal_mask=causal_mask, # Pass the causal mask
832
  warmup=True,
833
  auto_prompt="who are you?"
 
840
  lmhead_model=lmhead_model,
841
  tokenizer=tokenizer,
842
  metadata=metadata,
843
+ state=state,
844
  causal_mask=causal_mask, # Pass the causal mask
845
  warmup=False,
846
  auto_prompt=args.prompt