Fixed GIL issue
race condition between CoreML and causal_mask update
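The gist of the change: the full causal attention mask is now built once up front (initialize_causal_mask) and passed down into run_prefill and generate_next_token, which only slice it, instead of mask tensors being rebuilt per call while CoreML predictions are still in flight. Below is a minimal, self-contained sketch of that pattern; make_causal_mask is a plausible reconstruction consistent with the two lines of it visible in this diff, and the shapes in the comments assume the defaults used here (context length 512, prefill batch 64).

import numpy as np
import torch

def make_causal_mask(length, start):
    # Reconstruction: 0 where attention is allowed, -inf elsewhere, matching the
    # two visible lines of chat.py's make_causal_mask.
    mask = np.full((1, 1, length, length), -np.inf, dtype=np.float16)
    row_indices = np.arange(length).reshape(length, 1)
    col_indices = np.arange(length).reshape(1, length)
    mask[:, :, col_indices <= (row_indices + start)] = 0
    return mask

def initialize_causal_mask(context_length):
    # Build the full [1, 1, context_length, context_length] mask once, up front.
    return torch.tensor(make_causal_mask(context_length, 0), dtype=torch.float16)

context_length, batch_size, pos = 512, 64, 100
causal_mask = initialize_causal_mask(context_length)

# Prefill slices a batch_size-high window; decode slices the single current row.
batch_causal_mask = causal_mask[:, :, :batch_size, :]    # [1, 1, 64, 512]
single_causal_mask = causal_mask[:, :, pos-1:pos, :]     # [1, 1, 1, 512]
print(batch_causal_mask.shape, single_causal_mask.shape)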
chat.py
CHANGED
@@ -26,10 +26,8 @@ DARK_BLUE = "\033[34m"
 LIGHT_GREEN = "\033[92m"
 RESET_COLOR = "\033[0m"

-# Add at
+# Add at top with other constants
 WARMUP_TOKEN_LIMIT = 10  # Maximum tokens to generate during warmup
-THINKING_MODE = False
-THINKING_PROMPT = """You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem."""

 class TokenPrinter:
     """Handles background printing of generated tokens."""
@@ -42,12 +40,9 @@ class TokenPrinter:
         self.lock = threading.Lock()
         self.thinking = True  # Track if we're still in thinking mode
         self.decoding_buffer = []  # Buffer for token IDs
-        #
+        # Add token counting and timing
         self.start_time = time.time()
         self.token_count = 0
-        self.prefill_time = 0
-        self.inference_time = 0
-        self.context_pos = 0
         self.start()

     def start(self):
@@ -108,15 +103,15 @@ class TokenPrinter:
                 self.thread.join(timeout=1.0)
             except Exception:
                 pass
-
+            # Calculate and print tokens/s with shorter format in blue
+            elapsed = time.time() - self.start_time
+            if elapsed > 0 and self.token_count > 0:
+                tokens_per_sec = self.token_count / elapsed
+                print(f"\n{DARK_BLUE}{tokens_per_sec:.1f} t/s{RESET_COLOR}")
+            else:
+                print(RESET_COLOR)  # Reset color at the end
         return self.buffer

-    def set_timing(self, prefill_time, inference_time, context_pos):
-        """Set timing information."""
-        self.prefill_time = prefill_time
-        self.inference_time = inference_time
-        self.context_pos = context_pos
-
 def parse_model_path(path):
     """Parse model path and return full path with .mlmodelc or .mlpackage extension."""
     path = Path(path)
@@ -193,89 +188,6 @@ def load_model(path, function_name=None):
         print("\nTry using the .mlpackage version instead, or recompile the model.")
         raise

-def parse_args():
-    parser = argparse.ArgumentParser(description='Full Chat with CoreML LLaMA with context window shifting, gil resolved (c) 2025 Anemll')
-
-    # Add meta.yaml option
-    parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
-
-    # Add existing arguments
-    parser.add_argument('--d', '--dir', type=str, default='.',
-                        help='Directory containing model files (default: current directory)')
-    parser.add_argument('--embed', type=str, required=False,
-                        help='Path to embeddings model (relative to --dir)')
-    parser.add_argument('--ffn', type=str, required=False,
-                        help='Path to FFN model (can be chunked, relative to --dir)')
-    parser.add_argument('--lmhead', type=str, required=False,
-                        help='Path to LM head model (relative to --dir)')
-    parser.add_argument('--tokenizer', type=str, required=False,
-                        help='Path to tokenizer')
-
-    # Add new argument for auto-generation
-    parser.add_argument('--prompt', type=str,
-                        help='If specified, run once with this prompt and exit')
-
-    # Add no-warmup flag
-    parser.add_argument('--nw', action='store_true',
-                        help='Skip warmup phase')
-
-    # Model configuration
-    parser.add_argument('--context-length', type=int,
-                        help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
-    parser.add_argument('--batch-size', type=int,
-                        help='Batch size for prefill (default: 64)')
-
-    args = parser.parse_args()
-
-    # If meta.yaml is provided, load parameters from it
-    if args.meta:
-        try:
-            with open(args.meta, 'r') as f:
-                meta = yaml.safe_load(f)
-            params = meta['model_info']['parameters']
-
-            # Set model directory to meta.yaml directory if not specified
-            if not args.d or args.d == '.':
-                args.d = str(Path(args.meta).parent)
-
-            # Build model paths based on parameters
-            prefix = params.get('model_prefix', 'llama')  # Default to 'llama' if not specified
-            lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
-            lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
-            num_chunks = int(params['num_chunks'])
-
-            # Set model paths if not specified
-            if not args.embed:
-                args.embed = f'{prefix}_embeddings'
-            if not args.lmhead:
-                args.lmhead = f'{prefix}_lm_head{lut_lmhead}'
-            if not args.ffn:
-                args.ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
-            if not args.tokenizer:
-                args.tokenizer = args.d
-
-            # Set other parameters if not overridden by command line
-            if args.context_length is None:
-                args.context_length = int(params['context_length'])
-            if args.batch_size is None:
-                args.batch_size = int(params['batch_size'])
-            args.num_chunks = num_chunks
-
-            print(f"\nLoaded parameters from {args.meta}:")
-            print(f"  Context Length: {args.context_length}")
-            print(f"  Batch Size: {args.batch_size}")
-            print(f"  Num Chunks: {args.num_chunks}")
-            print(f"  Models Directory: {args.d}")
-            print(f"  Embeddings: {args.embed}")
-            print(f"  LM Head: {args.lmhead}")
-            print(f"  FFN: {args.ffn}")
-
-        except Exception as e:
-            print(f"\nError loading meta.yaml: {str(e)}")
-            sys.exit(1)
-
-    return args
-
 def load_metadata(model,args):
     # Extract metadata and config parameters
     metadata = {}
@@ -474,74 +386,84 @@ def make_causal_mask(length, start):
     mask[:, :, col_indices <= (row_indices + start)] = 0
     return mask

-def
+def initialize_causal_mask(context_length):
+    """Initialize causal mask for transformer attention."""
+    causal_mask = make_causal_mask(context_length, 0)
+    causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
+    print(f"\nInitialized causal mask for context length {context_length}")
+    return causal_mask
+
+def run_prefill(embed_model, ffn_models, input_ids, context_pos, context_length, batch_size=64, state=None, causal_mask=None):
     """Run prefill on the input sequence."""
-    #
+    # Use provided causal mask or create one if not provided
+    if causal_mask is None:
+        causal_mask = make_causal_mask(context_length, 0)
+        causal_mask = torch.tensor(causal_mask, dtype=torch.float16)

     # Process in batches
     batch_pos = 0
-    while batch_pos <
-        batch_end = min(batch_pos + batch_size,
+    while batch_pos < context_pos:
+        batch_end = min(batch_pos + batch_size, context_pos)
         current_batch_size = batch_end - batch_pos

-        #print(f"[DEBUG] Prefill batch {batch_pos}-{batch_end} (size={current_batch_size})")
-
         # Get current batch
         batch_input = input_ids[:, batch_pos:batch_end]

-        #
+        # Always pad to full batch size for prefill
        batch_input = F.pad(
             batch_input,
             (0, batch_size - current_batch_size),
             value=0
         )

-        # Generate position IDs for
-        position_ids = torch.arange(
-
-        # Use the pre-initialized causal mask and extract the batch portion
-        batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
+        # Generate position IDs for full batch size
+        position_ids = torch.arange(batch_size, dtype=torch.int32)  # Changed: Always use full batch size
+        batch_causal_mask = causal_mask[:, :, :batch_size, :]  # Changed: Use full batch size

         # Run embeddings
         hidden_states = torch.from_numpy(
             embed_model.predict({'input_ids': batch_input.numpy()})['hidden_states']
         )

-        # Run through FFN chunks
+        # Run through FFN chunks with state
         for ffn_model in ffn_models:
             if isinstance(ffn_model, dict):
                 inputs = {
-                    'hidden_states': hidden_states.numpy(),
-                    'position_ids': position_ids.numpy(),
-                    'causal_mask': batch_causal_mask.numpy(),
-                    'current_pos': np.array([batch_pos], dtype=np.int32)
+                    'hidden_states': hidden_states.numpy(),  # [1, 64, hidden_size]
+                    'position_ids': position_ids.numpy(),  # [64]
+                    'causal_mask': batch_causal_mask.numpy(),  # [1, 1, 64, context_length]
+                    'current_pos': np.array([batch_pos], dtype=np.int32)  # [1]
                 }
                 output = ffn_model['prefill'].predict(inputs, state)
                 hidden_states = torch.from_numpy(output['output_hidden_states'])

         batch_pos = batch_end

-    return torch.tensor([
+    return torch.tensor([context_pos], dtype=torch.int32)

-def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state, causal_mask, temperature=0.0):
+def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state=None, causal_mask=None, temperature=0.0):
     """Generate the next token."""
     # Get current token
-    current_token = input_ids[:, pos-1:pos]
+    current_token = input_ids[:, pos-1:pos]  # [1, 1]

     # Run embeddings
     hidden_states = torch.from_numpy(
         embed_model.predict({'input_ids': current_token.numpy()})['hidden_states']
-    )
+    )  # [1, 1, hidden_size]

     # Create masks
     update_mask = torch.zeros((1, 1, context_length, 1), dtype=torch.float16)
     update_mask[0, 0, pos-1, 0] = 1.0
-    position_ids = torch.tensor([pos-1], dtype=torch.int32)
+    position_ids = torch.tensor([pos-1], dtype=torch.int32)  # [1]

-    # Use
-
+    # Use provided causal mask or create one if not provided
+    if causal_mask is None:
+        causal_mask_data = make_causal_mask(context_length, 0)
+        single_causal_mask = torch.tensor(causal_mask_data[:, :, pos-1:pos, :], dtype=torch.float16)  # [1, 1, 1, context_length]
+    else:
+        single_causal_mask = causal_mask[:, :, pos-1:pos, :]

-    # Run through FFN chunks
+    # Run through FFN chunks with state
     for ffn_model in ffn_models:
         if isinstance(ffn_model, dict):
             inputs = {
@@ -554,19 +476,25 @@ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, c
                 output = ffn_model['infer'].predict(inputs, state)
                 hidden_states = torch.from_numpy(output['output_hidden_states'])

-    # Run LM head
+    # Run LM head
     lm_output = lmhead_model.predict({'hidden_states': hidden_states.numpy()})
+    # Debug print
+    #print("\nLM Head output keys:", list(lm_output.keys()))

+    # Combine logits1-8 if they exist
     if 'logits1' in lm_output:
+        # Concatenate all logits parts
         logits_parts = []
         for i in range(1, 9):
             key = f'logits{i}'
             if key in lm_output:
                 logits_parts.append(torch.from_numpy(lm_output[key]))
-        logits = torch.cat(logits_parts, dim=-1)
+        logits = torch.cat(logits_parts, dim=-1)  # Concatenate along vocab dimension
     else:
+        # Try output_logits as fallback
         logits = torch.from_numpy(lm_output['output_logits'])

+    # Apply temperature and sample
     if temperature > 0:
         logits = logits / temperature
     probs = F.softmax(logits[0, -1, :], dim=-1)
@@ -588,93 +516,36 @@ def create_unified_state(ffn_models, context_length):
     print("\nCreated unified transformer state")
     return state

-def initialize_causal_mask(context_length):
-    """Initialize causal mask for transformer attention."""
-    causal_mask = make_causal_mask(context_length, 0)
-    causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
-    print(f"\nInitialized causal mask for context length {context_length}")
-    return causal_mask
-
-def get_user_input():
-    """Get input from user, handling special key combinations."""
-    global THINKING_MODE
-    try:
-        import termios
-        import tty
-        import sys
-
-        def _getch():
-            fd = sys.stdin.fileno()
-            old_settings = termios.tcgetattr(fd)
-            try:
-                tty.setraw(sys.stdin.fileno())
-                ch = sys.stdin.read(1)
-            finally:
-                termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
-            return ch
-
-        buffer = []
-        while True:
-            char = _getch()
-
-            # Debug: print the character code
-            print(f"\nKey pressed: {repr(char)} (hex: {hex(ord(char))})")
-
-            # Check for Enter key
-            if char == '\r' or char == '\n':
-                print()  # Move to next line
-                input_text = ''.join(buffer)
-                # Check if the command is /t
-                if input_text == '/t':
-                    THINKING_MODE = not THINKING_MODE
-                    print(f"Thinking mode {'ON' if THINKING_MODE else 'OFF'}")
-                    buffer = []  # Clear buffer
-                    print(f"\n{LIGHT_GREEN}You{' (thinking)' if THINKING_MODE else ''}:{RESET_COLOR}", end=' ', flush=True)
-                    continue
-                return input_text
-
-            # Handle backspace
-            if char == '\x7f':  # backspace
-                if buffer:
-                    buffer.pop()
-                    sys.stdout.write('\b \b')  # Erase character
-                    sys.stdout.flush()
-                continue
-
-            # Handle Ctrl-C
-            if char == '\x03':  # Ctrl-C
-                print("^C")
-                raise KeyboardInterrupt
-
-            # Print character and add to buffer
-            sys.stdout.write(char)
-            sys.stdout.flush()
-            buffer.append(char)
-
-    except ImportError:
-        # Fallback for systems without termios
-        return input("> ")
-
-def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask, auto_prompt=None, warmup=False):
+def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask=None, auto_prompt=None, warmup=False):
     """Interactive chat loop."""
-    global THINKING_MODE
     context_length = metadata.get('context_length')
     batch_size = metadata.get('batch_size', 64)

     if not warmup:
         print(f"\nUsing context length: {context_length}")
         print("\nStarting chat session. Press Ctrl+D to exit.")
-        print("Type your message and press Enter to chat.
-
+        print("Type your message and press Enter to chat.")
+
+    # Check if tokenizer has chat template and if it works
+    has_chat_template = False
+    try:
+        # Test if chat template works
+        test_messages = [{"role": "user", "content": "test"}]
+        tokenizer.apply_chat_template(test_messages, return_tensors="pt")
+        has_chat_template = True
+        if not warmup:
+            print("\nUsing chat template for prompts")
+    except:
+        if not warmup:
+            print("\nUsing manual formatting for prompts")

-    # Keep track of conversation history
     conversation = []

     try:
         while True:
             try:
                 if not warmup:
-                    print(f"\n{LIGHT_GREEN}You
+                    print(f"\n{LIGHT_GREEN}You:{RESET_COLOR}", end=' ', flush=True)
                 if auto_prompt is not None:
                     user_input = auto_prompt
                     if not warmup:
@@ -685,69 +556,41 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                 if not warmup:
                     print("\nExiting chat...")
                 break
-
+
             if not user_input:
                 continue
-
-            # Handle /t command
-            if user_input == "/t":
-                THINKING_MODE = not THINKING_MODE
-                print(f"Thinking mode {'ON' if THINKING_MODE else 'OFF'}")
-                continue
-
-            # Add user message to conversation
-            conversation.append({"role": "user", "content": user_input})

-            # Format
-            if
-
-
-
-                    conversation_with_thinking,
+            # Format prompt based on tokenizer capabilities
+            if has_chat_template:
+                messages = [{"role": "user", "content": user_input}]
+                input_ids = tokenizer.apply_chat_template(
+                    messages,
                     return_tensors="pt",
                     add_generation_prompt=True
                 ).to(torch.int32)
             else:
-
-
+                # Manual formatting for Llama models without chat template
+                formatted_prompt = f"[INST] {user_input} [/INST]"
+                input_ids = tokenizer(
+                    formatted_prompt,
                     return_tensors="pt",
-
-                ).to(torch.int32)
+                    add_special_tokens=True
+                ).input_ids.to(torch.int32)

-
-            while base_input_ids.size(1) > context_length - 100:  # Leave room for response
-                # Remove oldest message pair (user + assistant)
-                if len(conversation) > 2:
-                    conversation = conversation[2:]  # Remove oldest pair
-                    base_input_ids = tokenizer.apply_chat_template(
-                        conversation,
-                        return_tensors="pt",
-                        add_generation_prompt=True
-                    ).to(torch.int32)
-                else:
-                    # If only current message remains and still too long, truncate
-                    base_input_ids = base_input_ids[:, -context_length//2:]
-                    break
-
-            context_pos = base_input_ids.size(1)
-
-            # Pad sequence to context_size
-            input_ids = F.pad(
-                base_input_ids,
-                (0, context_length - context_pos),
-                value=0
-            )
+            context_pos = input_ids.size(1)

             if not warmup:
                 print(f"\n{LIGHT_BLUE}Assistant:{RESET_COLOR}", end=' ', flush=True)

-            # Initialize token printer
+            # Initialize token printer
             token_printer = TokenPrinter(tokenizer)
-
-            generation_start_time = time.time()
+            tokens_generated = 0  # Track number of tokens

             try:
-                #
+                # Start prefill timing
+                prefill_start = time.time()
+
+                # Run prefill with state and causal mask
                 current_pos = run_prefill(
                     embed_model,
                     ffn_models,
@@ -758,51 +601,20 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                     state,
                     causal_mask
                 )
-                #print(f"\n[DEBUG] After initial prefill - current_pos: {current_pos}")

-                #
+                # Calculate prefill timing
+                prefill_time = time.time() - prefill_start
+                prefill_tokens = context_pos  # Number of tokens in input
+                prefill_tokens_per_sec = prefill_tokens / prefill_time if prefill_time > 0 else 0
+
+                # Generation loop with state
+                input_ids = input_ids
                 pos = context_pos
-
-
+                inference_start = time.time()
+                inference_tokens = 0

-                while
-                    #
-                    if pos >= context_length - 2:
-                        # Calculate shift to maintain full batches
-                        batch_size = metadata.get('batch_size', 64)
-                        # Calculate max batches that fit in context
-                        max_batches = context_length // batch_size
-                        desired_batches = max(1, max_batches - 2)  # Leave room for new tokens
-                        new_size = min(desired_batches * batch_size, context_length - batch_size)
-
-                        # Create shifted input_ids
-                        tmp = torch.zeros((1, context_length), dtype=torch.int32)
-                        tmp[:,0:new_size] = input_ids[:,pos-new_size:pos]
-                        input_ids = tmp
-
-                        # Reset state and run prefill
-                        # keep the same state
-                        #state = create_unified_state(ffn_models, context_length)
-                        current_pos = run_prefill(
-                            embed_model,
-                            ffn_models,
-                            input_ids,
-                            new_size,  # Prefill the entire shifted content
-                            context_length,
-                            batch_size,
-                            state,
-                            causal_mask
-                        )
-
-                        # Start generating from the next position
-                        pos = new_size  # Don't back up, continue from where we left off
-
-                        #print(f"\n[DEBUG] After shift - next token will be at pos {pos}")
-                        #print(f"[DEBUG] Context before next token: {tokenizer.decode(input_ids[0, pos-40:pos])}")
-
-                        window_shifted = True
-
-                    # Generate next token
+                while pos < context_length - 1:
+                    # Generate next token with causal mask
                     next_token = generate_next_token(
                         embed_model,
                         ffn_models,
@@ -814,54 +626,143 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                         causal_mask
                     )

-                    # Add token
-
+                    # Add token to sequence
+                    if pos < input_ids.size(1):
+                        input_ids[0, pos] = next_token
+                    else:
+                        input_ids = torch.cat([
+                            input_ids,
+                            torch.tensor([[next_token]], dtype=torch.int32)
+                        ], dim=1)
+
+                    # Add to printer only if not in warmup
                     if not warmup:
                         token_printer.add_token(next_token)
                         token_printer.drain_buffer()
-                    response_tokens.append(next_token)

                     pos += 1
                     tokens_generated += 1
+                    inference_tokens += 1

-                    #
+                    # Check limits
                     if warmup and tokens_generated >= WARMUP_TOKEN_LIMIT:
                         break
-
+
                     if next_token == tokenizer.eos_token_id:
                         break

-
-
-
-                response_text = token_printer.stop()
-                conversation.append({"role": "assistant", "content": response_text})
+                # Calculate inference timing
+                inference_time = time.time() - inference_start
+                inference_tokens_per_sec = inference_tokens / inference_time if inference_time > 0 else 0

-                #
+                # Get final response and add to conversation
                 if not warmup:
-
-
-
-                    prefill_ms
-
-                    print(f"{
-
-
+                    response = token_printer.stop()
+                    # Print timing stats
+                    prefill_ms = prefill_time * 1000  # Convert to milliseconds
+                    print(f"\nPrefill: {prefill_ms:.1f}ms ({prefill_tokens_per_sec:.1f} t/s)")
+                    print(f"Inference: {inference_tokens_per_sec:.1f} t/s")
+                    print(f"Total: Generated {tokens_generated} tokens in {prefill_time + inference_time:.2f}s")
+                    conversation.append({"role": "assistant", "content": response})
+                else:
+                    token_printer.stop()  # Clean up without printing stats

+                # Exit after one response in auto_prompt mode
                 if auto_prompt is not None:
                     break

             except KeyboardInterrupt:
-
-                print("\nGeneration interrupted")
+                print("\nGeneration interrupted")
                 token_printer.stop()
                 continue

     except Exception as e:
-
-
-
-
+        print(f"\nError in chat loop: {str(e)}")
+        import traceback
+        traceback.print_exc()
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Chat with CoreML LLaMA, gil resolved (c) 2025 Anemll')
+
+    # Add meta.yaml option
+    parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
+
+    # Model paths
+    parser.add_argument('--d', '--dir', type=str, default='.',
+                        help='Directory containing model files (default: current directory)')
+    parser.add_argument('--embed', type=str, required=False,
+                        help='Path to embeddings model (relative to --dir)')
+    parser.add_argument('--ffn', type=str, required=False,
+                        help='Path to FFN model (can be chunked, relative to --dir)')
+    parser.add_argument('--lmhead', type=str, required=False,
+                        help='Path to LM head model (relative to --dir)')
+    parser.add_argument('--tokenizer', type=str, required=False,
+                        help='Path to tokenizer')
+
+    # Add new argument for auto-generation
+    parser.add_argument('--prompt', type=str,
+                        help='If specified, run once with this prompt and exit')
+
+    # Add no-warmup flag
+    parser.add_argument('--nw', action='store_true',
+                        help='Skip warmup phase')
+
+    # Model configuration
+    parser.add_argument('--context-length', type=int,
+                        help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
+    parser.add_argument('--batch-size', type=int,
+                        help='Batch size for prefill (default: 64)')
+
+    args = parser.parse_args()
+
+    # If meta.yaml is provided, load parameters from it
+    if args.meta:
+        try:
+            with open(args.meta, 'r') as f:
+                meta = yaml.safe_load(f)
+            params = meta['model_info']['parameters']
+
+            # Set model directory to meta.yaml directory if not specified
+            if not args.d or args.d == '.':
+                args.d = str(Path(args.meta).parent)
+
+            # Build model paths based on parameters
+            prefix = params.get('model_prefix', 'llama')  # Default to 'llama' if not specified
+            lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
+            lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
+            num_chunks = int(params['num_chunks'])
+
+            # Set model paths if not specified
+            if not args.embed:
+                args.embed = f'{prefix}_embeddings'
+            if not args.lmhead:
+                args.lmhead = f'{prefix}_lm_head{lut_lmhead}'
+            if not args.ffn:
+                args.ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
+            if not args.tokenizer:
+                args.tokenizer = args.d
+
+            # Set other parameters if not overridden by command line
+            if args.context_length is None:
+                args.context_length = int(params['context_length'])
+            if args.batch_size is None:
+                args.batch_size = int(params['batch_size'])
+            args.num_chunks = num_chunks
+
+            print(f"\nLoaded parameters from {args.meta}:")
+            print(f"  Context Length: {args.context_length}")
+            print(f"  Batch Size: {args.batch_size}")
+            print(f"  Num Chunks: {args.num_chunks}")
+            print(f"  Models Directory: {args.d}")
+            print(f"  Embeddings: {args.embed}")
+            print(f"  LM Head: {args.lmhead}")
+            print(f"  FFN: {args.ffn}")
+
+        except Exception as e:
+            print(f"\nError loading meta.yaml: {str(e)}")
+            sys.exit(1)
+
+    return args

 def main():
     args = parse_args()
@@ -926,7 +827,7 @@ def main():
         lmhead_model=lmhead_model,
         tokenizer=tokenizer,
         metadata=metadata,
-        state=state,
+        state=state,
         causal_mask=causal_mask,  # Pass the causal mask
         warmup=True,
         auto_prompt="who are you?"
@@ -939,7 +840,7 @@ def main():
         lmhead_model=lmhead_model,
         tokenizer=tokenizer,
        metadata=metadata,
-        state=state,
+        state=state,
         causal_mask=causal_mask,  # Pass the causal mask
         warmup=False,
         auto_prompt=args.prompt
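For orientation, the two chat_loop(...) calls at the bottom of the diff imply that main() now creates the unified KV state and the causal mask once and reuses them for both the warmup pass and the real session. A hypothetical wrapper sketching that wiring follows; only initialize_causal_mask, create_unified_state and chat_loop are names from chat.py, everything else here is illustrative.

def run_session(embed_model, ffn_models, lmhead_model, tokenizer, metadata, args):
    # Hypothetical glue, inferred from the chat_loop(...) calls in this diff.
    causal_mask = initialize_causal_mask(metadata.get('context_length'))
    state = create_unified_state(ffn_models, metadata.get('context_length'))

    if not args.nw:  # --nw skips the warmup phase
        chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata,
                  state=state, causal_mask=causal_mask,  # same mask, same state
                  warmup=True, auto_prompt="who are you?")

    chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata,
              state=state, causal_mask=causal_mask,
              warmup=False, auto_prompt=args.prompt)

Because every call only takes views of the one precomputed float16 mask (causal_mask[:, :, pos-1:pos, :] and causal_mask[:, :, :batch_size, :]), nothing is rebuilt or mutated between generation steps, which is presumably what eliminates the update race mentioned in the commit message.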