|
|
|
|
|
|
|
|
|
|
|
DEFAULT_PROMPTS = ["Provide 3 reasons why cats make good pets.", "Why should I consider using an LLM?"]
|
MAX_GENERATION_LENGTH = 100

|
import os |
|
import sys |
|
import json |
|
import time |
|
import struct |
|
import math |
|
from typing import List, Tuple, Dict, Union, Optional |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
def load_json(file_path: str) -> Dict: |
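    """Read a JSON file and return its contents as a dictionary."""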
|
|
|
with open(file_path, 'r', encoding='utf-8') as f: |
|
return json.load(f) |
|
|
|
def timed_step(start: float, step_name: str) -> float: |
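    """Print the time elapsed since `start`, labelled with `step_name`, and return the current time."""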
|
|
|
end = time.time() |
|
print(f"Time taken for {step_name}: {end - start:.4f} seconds") |
|
return end |
|
|
|
|
|
|
|
class RMSNorm(nn.Module): |
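    """Root-mean-square layer normalization with a learned scale and no bias."""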
|
|
|
def __init__(self, dim: int, eps: float = 1e-5): |
|
super().__init__() |
|
self.eps = eps |
|
self.weight = nn.Parameter(torch.ones(dim)) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
|
norm_x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) |
|
return self.weight * norm_x |
|
|
|
def silu(x: torch.Tensor) -> torch.Tensor: |
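    """SiLU (swish) activation: x * sigmoid(x)."""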
|
|
|
return x * torch.sigmoid(x) |
|
|
|
class RotaryEmbedding(nn.Module): |
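    """Rotary position embedding (RoPE) angle generator for one attention head dimension."""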
|
|
|
    def __init__(self, dim: int, base: float = 10000.0):
|
super().__init__() |
|
self.dim = dim |
|
self.base = base |
|
self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) |
|
|
|
    def forward(self, seq_len: int, device: torch.device) -> torch.Tensor:
        # Rotation angles for positions [0, seq_len); the frequencies are duplicated so
        # the last dimension matches head_dim (first and second halves pair up).
        t = torch.arange(seq_len, device=device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq.to(device))
        return torch.cat((freqs, freqs), dim=-1)
|
|
|
def apply_rotary_emb(pos: torch.Tensor, t: torch.Tensor) -> torch.Tensor: |
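    """Rotate `t` by the positional angles `pos` (rotary position embedding)."""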
|
|
|
    # Evaluate cos/sin in float32 for precision, then cast the rotated tensor back to
    # t's dtype so downstream matmuls see consistent dtypes.
    pos = pos.to(torch.float32)
    t32 = t.to(torch.float32)
    rotated = (t32 * torch.cos(pos)) + (rotate_half(t32) * torch.sin(pos))
    return rotated.to(t.dtype)
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor: |
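    """Split the last dimension in half and return (-second_half, first_half) concatenated."""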
|
|
|
x1 = x[..., : x.shape[-1] // 2] |
|
x2 = x[..., x.shape[-1] // 2 :] |
|
return torch.cat((-x2, x1), dim=-1) |
|
|
|
class LlamaAttention(nn.Module): |
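    """Multi-head self-attention with rotary embeddings, grouped-query KV heads, and an optional KV cache."""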
|
|
|
def __init__(self, config: Dict): |
|
super().__init__() |
|
self.config = config |
|
self.hidden_size = config['hidden_size'] |
|
self.num_heads = config['num_attention_heads'] |
|
self.head_dim = self.hidden_size // self.num_heads |
|
self.num_key_value_heads = config["num_key_value_heads"] |
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
|
self.rope_theta = config['rope_theta'] |
|
|
|
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        # Key/value projections are narrower than the query projection when grouped-query
        # attention is in use (num_key_value_heads < num_attention_heads).
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
|
|
self.rotary_emb = RotaryEmbedding(self.head_dim, base=self.rope_theta) |
|
self.attn_dropout = nn.Dropout(config['attention_dropout']) |
|
|
|
def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: |
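        """Attend over hidden_states; returns the output and, when use_cache is set, the updated (key, value) cache."""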
|
|
|
|
|
batch_size, seq_length, _ = hidden_states.size() |
|
query_states = self.q_proj(hidden_states).view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) |
|
key_states = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
value_states = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
|
|
        if position_ids is not None:
            # The angle table must cover the largest absolute position, which can exceed
            # the current chunk length when a KV cache is in use.
            angles = self.rotary_emb(int(position_ids.max().item()) + 1, device=hidden_states.device)
            pos = angles[position_ids].unsqueeze(1)  # (batch, 1, seq, head_dim), broadcast over heads
            query_states = apply_rotary_emb(pos, query_states)
            key_states = apply_rotary_emb(pos, key_states)
|
|
|
if past_key_value is not None: |
|
key_states = torch.cat([past_key_value[0], key_states], dim=2) |
|
value_states = torch.cat([past_key_value[1], value_states], dim=2) |
|
|
|
if use_cache: |
|
present_key_value = (key_states, value_states) |
|
else: |
|
present_key_value = None |
|
|
|
seq_length_k = key_states.shape[-2] |
|
|
|
key_states = repeat_kv(key_states, self.num_key_value_groups) |
|
value_states = repeat_kv(value_states, self.num_key_value_groups) |
|
|
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) |
|
|
|
if attn_weights.size() != (batch_size, self.num_heads, seq_length, seq_length_k): |
|
raise ValueError( |
|
f"Attention weights should be of size {(batch_size, self.num_heads, seq_length, seq_length_k)}, but is" |
|
f" {attn_weights.size()}" |
|
) |
|
|
|
if attention_mask is not None: |
|
attn_weights = attn_weights + attention_mask |
|
|
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) |
|
attn_weights = self.attn_dropout(attn_weights) |
|
|
|
attn_output = torch.matmul(attn_weights, value_states) |
|
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_length, self.hidden_size) |
|
attn_output = self.o_proj(attn_output) |
|
return attn_output, present_key_value |
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
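    """Repeat KV heads n_rep times: (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim)."""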
|
|
|
|
|
batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape |
|
if n_rep == 1: |
|
return hidden_states |
|
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seq_len, head_dim) |
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim) |
|
|
|
class LlamaMLP(nn.Module): |
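    """Gated feed-forward network: down_proj(act_fn(gate_proj(x)) * up_proj(x))."""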
|
|
|
def __init__(self, config: Dict): |
|
super().__init__() |
|
hidden_size = config['hidden_size'] |
|
intermediate_size = config['intermediate_size'] |
|
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) |
|
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) |
|
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) |
|
self.act_fn = silu if config['hidden_act'] == 'silu' else getattr(F, config['hidden_act']) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
|
|
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) |
|
|
|
class LlamaBlock(nn.Module): |
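    """One transformer layer: pre-norm attention and pre-norm MLP, each with a residual connection."""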
|
|
|
def __init__(self, config: Dict): |
|
super().__init__() |
|
self.hidden_size = config['hidden_size'] |
|
self.self_attn = LlamaAttention(config) |
|
self.mlp = LlamaMLP(config) |
|
self.input_layernorm = RMSNorm(self.hidden_size, eps=config['rms_norm_eps']) |
|
self.post_attention_layernorm = RMSNorm(self.hidden_size, eps=config['rms_norm_eps']) |
|
|
|
def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: |
|
|
|
residual = hidden_states |
|
hidden_states = self.input_layernorm(hidden_states) |
|
hidden_states, present_key_value = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache) |
|
hidden_states = residual + hidden_states |
|
residual = hidden_states |
|
hidden_states = self.post_attention_layernorm(hidden_states) |
|
hidden_states = self.mlp(hidden_states) |
|
hidden_states = residual + hidden_states |
|
return hidden_states, present_key_value |
|
|
|
class SmolLM2_360M(nn.Module): |
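    """SmolLM2-360M: a Llama-style decoder-only transformer configured from a Hugging Face config.json."""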
|
|
|
def __init__(self, config_path: str): |
|
super().__init__() |
|
self.config = load_json(config_path) |
|
self.hidden_size = self.config['hidden_size'] |
|
self.vocab_size = self.config['vocab_size'] |
|
self.num_hidden_layers = self.config['num_hidden_layers'] |
|
self.max_position_embeddings = self.config['max_position_embeddings'] |
|
self.torch_dtype = self.config.get('torch_dtype', 'bfloat16') |
|
self.use_cache = self.config.get('use_cache', True) |
|
if self.torch_dtype == "bfloat16": |
|
if not torch.cuda.is_available(): |
|
print ("Warning: System does not have a CUDA device, using torch.float32 dtype instead of bfloat16.") |
|
self.torch_dtype = torch.float32 |
|
else: |
|
self.torch_dtype = torch.bfloat16 |
|
elif self.torch_dtype == "float16": |
|
if not torch.cuda.is_available(): |
|
print ("Warning: System does not have a CUDA device, using torch.float32 dtype instead of float16.") |
|
self.torch_dtype = torch.float32 |
|
else: |
|
self.torch_dtype = torch.float16 |
|
else: |
|
self.torch_dtype = torch.float32 |
|
self.embed_tokens = nn.Embedding(self.vocab_size, self.hidden_size) |
|
self.layers = nn.ModuleList([LlamaBlock(self.config) for _ in range(self.num_hidden_layers)]) |
|
self.norm = RMSNorm(self.hidden_size, eps=self.config['rms_norm_eps']) |
|
self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False) |
|
self.past_keys_values = None |
|
|
|
def load_weights(self, weights_path: str): |
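        """Load parameters from a .safetensors checkpoint (requires the safetensors package)."""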
|
|
|
start = time.time() |
|
try: |
|
from safetensors import safe_open |
|
with safe_open(weights_path, framework="pt", device='cpu') as f: |
|
weights = f.get_tensor("model.embed_tokens.weight") |
|
self.embed_tokens.weight = nn.Parameter(weights) |
|
self.lm_head.weight = nn.Parameter(f.get_tensor("lm_head.weight")) |
|
for i in range(self.num_hidden_layers): |
|
self.layers[i].input_layernorm.weight = nn.Parameter(f.get_tensor(f"model.layers.{i}.input_layernorm.weight")) |
|
self.layers[i].post_attention_layernorm.weight = nn.Parameter(f.get_tensor(f"model.layers.{i}.post_attention_layernorm.weight")) |
|
self.layers[i].self_attn.q_proj.weight = nn.Parameter(f.get_tensor(f"model.layers.{i}.self_attn.q_proj.weight")) |
|
self.layers[i].self_attn.k_proj.weight = nn.Parameter(f.get_tensor(f"model.layers.{i}.self_attn.k_proj.weight")) |
|
self.layers[i].self_attn.v_proj.weight = nn.Parameter(f.get_tensor(f"model.layers.{i}.self_attn.v_proj.weight")) |
|
self.layers[i].self_attn.o_proj.weight = nn.Parameter(f.get_tensor(f"model.layers.{i}.self_attn.o_proj.weight")) |
|
self.layers[i].mlp.gate_proj.weight = nn.Parameter(f.get_tensor(f"model.layers.{i}.mlp.gate_proj.weight")) |
|
self.layers[i].mlp.up_proj.weight = nn.Parameter(f.get_tensor(f"model.layers.{i}.mlp.up_proj.weight")) |
|
self.layers[i].mlp.down_proj.weight = nn.Parameter(f.get_tensor(f"model.layers.{i}.mlp.down_proj.weight")) |
|
except ImportError: |
|
print("Error: Safetensors library not found. Please install it with 'pip install safetensors'.") |
|
sys.exit(1) |
|
except Exception as e: |
|
print(f"An error occurred while loading weights: {e}") |
|
sys.exit(1) |
|
end = timed_step(start, "Weight Loading") |
|
|
|
def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None, use_cache: Optional[bool] = None) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]: |
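        """Run the decoder stack over input_ids; returns logits and, if caching, per-layer (key, value) tuples."""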
|
|
|
use_cache = use_cache if use_cache is not None else self.use_cache |
|
batch_size, seq_length = input_ids.shape |
|
        if position_ids is None:
            position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device).unsqueeze(0)
            if past_key_values is not None and past_key_values[0] is not None:
                # Offset positions by the number of tokens already in the KV cache.
                position_ids = position_ids + past_key_values[0][0].shape[-2]
            # Clamp (rather than slice) so position_ids stays aligned with input_ids.
            position_ids = position_ids.clamp(max=self.max_position_embeddings - 1)

        inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        if attention_mask is None and seq_length > 1:
            # Build a causal mask: each token attends to itself, earlier tokens in the
            # current chunk, and everything already in the KV cache.
            past_length = 0 if past_key_values is None or past_key_values[0] is None else past_key_values[0][0].shape[-2]
            causal_mask = torch.full((seq_length, past_length + seq_length), float("-inf"), device=input_ids.device)
            causal_mask = torch.triu(causal_mask, diagonal=1 + past_length)
            attention_mask = causal_mask[None, None, :, :]

        if past_key_values is None:
            past_key_values = [None] * len(self.layers)

        present_key_values = [] if use_cache else None
|
|
|
for i in range(self.num_hidden_layers): |
|
hidden_states, present_key_value = self.layers[i]( |
|
hidden_states, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_value=past_key_values[i], |
|
use_cache=use_cache, |
|
) |
|
if use_cache: |
|
present_key_values.append(present_key_value) |
|
|
|
hidden_states = self.norm(hidden_states) |
|
logits = self.lm_head(hidden_states) |
|
|
|
return logits, present_key_values |
|
|
|
|
|
|
|
class SmolLM2Tokenizer: |
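    """Tokenizer that prefers SentencePiece and falls back to a rudimentary BPE built from vocab.json/merges.txt."""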
|
|
|
def __init__(self, tokenizer_path: str = ".", special_tokens_map_path: str = ".", config_path: str = "."): |
|
self.tokenizer_path = tokenizer_path |
|
self.special_tokens_map_path = special_tokens_map_path |
|
self.config = load_json(config_path) if config_path else None |
|
self.vocab_size = self.config['vocab_size'] if self.config else None |
|
self.use_sentencepiece = True |
|
self.special_tokens_map = load_json(special_tokens_map_path) if special_tokens_map_path else {} |
|
|
|
|
|
self.additional_special_tokens = self.special_tokens_map.get("additional_special_tokens",[]) |
|
self.inv_special_tokens_map = {v['content']: k for k, v in self.special_tokens_map.items() if isinstance(v,dict)} |
|
self.additional_special_tokens_inv_map = {token: f"additional_special_tokens_{i}" for i, token in enumerate(self.additional_special_tokens)} |
|
|
|
try: |
|
import sentencepiece as spm |
|
self.sp_model = spm.SentencePieceProcessor(model_file=os.path.join(tokenizer_path, 'tokenizer.model')) |
|
|
|
self.bos_token_id = self.sp_model.bos_id() |
|
self.eos_token_id = self.sp_model.eos_id() |
|
self.pad_token_id = self.sp_model.pad_id() if self.sp_model.pad_id() >=0 else self.eos_token_id |
|
self.unk_token_id = self.sp_model.unk_id() |
|
self.additional_special_tokens_ids = [self.sp_model.piece_to_id(token) for token in self.additional_special_tokens] |
|
|
|
self.update_special_tokens_from_sp() |
|
        except (ImportError, OSError):
            print("Warning: SentencePiece (or tokenizer.model) is unavailable; falling back to a rudimentary BPE tokenizer. Install SentencePiece for better results.")
|
self.use_sentencepiece = False |
|
            self.vocab = load_json(os.path.join(tokenizer_path, 'vocab.json'))
            with open(os.path.join(tokenizer_path, 'merges.txt'), 'r', encoding='utf-8') as merges_file:
                # Skip the "#version" header and any blank lines in merges.txt.
                merge_lines = [line for line in merges_file.read().split('\n') if line and not line.startswith('#')]
            self.merges = [tuple(merge.split()) for merge in merge_lines]
            # Rank lookup so bpe() does not rescan the merge list for every pair.
            self.bpe_ranks = {merge: rank for rank, merge in enumerate(self.merges)}
            if isinstance(self.vocab, dict):
                # A standard Hugging Face vocab.json already maps token -> id.
                self.token_to_id = dict(self.vocab)
            else:
                self.token_to_id = {token: token_id for token_id, token in enumerate(self.vocab)}
            self.id_to_token = {token_id: token for token, token_id in self.token_to_id.items()}
|
self.bos_token = self.special_tokens_map.get('bos_token', {}).get('content') |
|
self.eos_token = self.special_tokens_map.get('eos_token', {}).get('content') |
|
self.unk_token = self.special_tokens_map.get('unk_token', {}).get('content') |
|
self.pad_token = '<PAD>' |
|
self.bos_token_id = self.token_to_id.get(self.bos_token, -1) |
|
self.eos_token_id = self.token_to_id.get(self.eos_token, -1) |
|
self.unk_token_id = self.token_to_id.get(self.unk_token, -1) |
|
self.pad_token_id = self.token_to_id.get(self.pad_token, -1) |
|
self.additional_special_tokens_ids = [self.token_to_id.get(token, -1) for token in self.additional_special_tokens] |
|
|
|
def update_special_tokens_from_sp(self): |
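        """Resolve BOS/EOS/UNK ids from the SentencePiece model when it recognizes those tokens."""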
|
|
|
        for token_name, token_data in self.special_tokens_map.items():
            if not isinstance(token_data, dict):
                # Skip list-valued entries such as "additional_special_tokens".
                continue
            sp_id = self.sp_model.piece_to_id(token_data['content'])
            if sp_id != self.sp_model.unk_id():
                if token_name == 'bos_token':
                    self.bos_token_id = sp_id
                elif token_name == 'eos_token':
                    self.eos_token_id = sp_id
                elif token_name == 'unk_token':
                    self.unk_token_id = sp_id
|
|
|
|
|
def get_special_tokens_dict(self) -> Dict[str, Union[str, int]]: |
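        """Return a dictionary of the tokenizer's special tokens and their ids."""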
|
|
|
|
|
        def sp_piece(token_id: Optional[int]) -> Optional[str]:
            # Token text for a SentencePiece id, or None if the id is unset/invalid.
            if token_id is None or token_id < 0:
                return None
            return self.sp_model.id_to_piece(token_id)

        result_dict = {
            'bos_token': sp_piece(self.bos_token_id) if self.use_sentencepiece else self.bos_token,
            'eos_token': sp_piece(self.eos_token_id) if self.use_sentencepiece else self.eos_token,
            'unk_token': sp_piece(self.unk_token_id) if self.use_sentencepiece else self.unk_token,
            'pad_token': sp_piece(self.pad_token_id) if self.use_sentencepiece else self.pad_token,
            'bos_token_id': self.bos_token_id,
            'eos_token_id': self.eos_token_id,
            'unk_token_id': self.unk_token_id,
            'pad_token_id': self.pad_token_id,
            'additional_special_tokens': self.additional_special_tokens,
            'additional_special_tokens_ids': self.additional_special_tokens_ids,
        }
        result_dict.update(self.additional_special_tokens_inv_map)
        return result_dict
|
|
|
|
|
def bpe(self, token: str) -> List[str]: |
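        """Greedily merge the characters of a word using the loaded BPE merge ranks (fallback tokenizer only)."""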
|
|
|
        if self.use_sentencepiece:
            return []
        word = list(token)
        while len(word) > 1:
            pairs = [(word[i], word[i + 1]) for i in range(len(word) - 1)]
            # Merge the pair with the lowest rank; stop once no adjacent pair is mergeable.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = new_word
        return word
|
|
|
def encode(self, text: str, add_special_tokens: bool = True) -> List[int]: |
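        """Encode text into token ids, optionally prepending the BOS token."""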
|
|
|
        if self.use_sentencepiece:
            token_ids = self.sp_model.encode(text, out_type=int)
            if add_special_tokens and self.bos_token_id is not None and self.bos_token_id >= 0:
                # Prepend BOS only; appending EOS to a prompt would signal it is already complete.
                token_ids = [self.bos_token_id] + token_ids
            return token_ids
        else:
            tokens = []
            for word in text.split():
                tokens.extend(self.bpe(word))
            token_ids = [self.token_to_id.get(token, self.unk_token_id) for token in tokens]
            if add_special_tokens and self.bos_token_id != -1:
                token_ids = [self.bos_token_id] + token_ids
            return token_ids
|
|
|
def decode(self, token_ids: List[int]) -> str: |
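        """Decode a list of token ids back into text."""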
|
|
|
if self.use_sentencepiece: |
|
return self.sp_model.decode(token_ids) |
|
else: |
|
tokens = [self.id_to_token.get(token_id, self.unk_token) for token_id in token_ids] |
|
return " ".join(tokens) |
|
|
|
|
|
|
|
|
|
def generate_text(model: SmolLM2_360M, tokenizer: SmolLM2Tokenizer, prompt: str, max_new_tokens: int = 100, device: torch.device = torch.device('cpu')) -> str:
    """Greedily generate up to max_new_tokens continuation tokens for a prompt."""
    input_ids = tokenizer.encode(prompt, add_special_tokens=True)
    input_ids = torch.tensor([input_ids], dtype=torch.long, device=device)

    generated_ids = input_ids
    past_key_values = None
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # The first pass processes the whole prompt (prefill); afterwards only the
            # newest token is fed, since earlier keys/values live in the cache.
            logits, past_key_values = model(input_ids=input_ids, past_key_values=past_key_values)
            next_token_logits = logits[:, -1, :]
            next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(1)
            generated_ids = torch.cat([generated_ids, next_token_id], dim=1)
            input_ids = next_token_id
            if next_token_id.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(generated_ids[0].tolist())
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
start = time.time() |
|
config_path = "config.json" |
|
weights_path = "model.safetensors" |
|
tokenizer_path = "." |
|
special_tokens_map_path = "special_tokens_map.json" |
|
|
|
config = load_json(config_path) |
|
tokenizer = SmolLM2Tokenizer(tokenizer_path, special_tokens_map_path, config_path) |
|
|
|
model = SmolLM2_360M(config_path) |
|
model.load_weights(weights_path) |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
special_tokens = tokenizer.get_special_tokens_dict() |
|
print("Special Tokens:") |
|
for k, v in special_tokens.items(): |
|
print(f"\t{k}: {v}") |
|
|
|
model.to(device, dtype=model.torch_dtype).eval() |
|
|
|
end = timed_step(start, "Model initialization") |
|
|
|
start = time.time() |
|
|
|
    for prompt in DEFAULT_PROMPTS:
|
print(f"\nDefault Prompt: {prompt}") |
|
        generated_text = generate_text(model, tokenizer, prompt, max_new_tokens=MAX_GENERATION_LENGTH, device=device)
|
print(f"Generated Text: {generated_text}") |
|
end = timed_step(start, "Default Prompt Generation") |
|
|
|
|
|
while True: |
|
user_input = input("\nEnter prompt (or 'exit' to quit, 'hyper' for hyperparameters): ") |
|
if user_input.lower() == "exit": |
|
break |
|
elif "hyper" in user_input.lower(): |
|
print("\nHyperparameters:") |
|
for key, value in config.items(): |
|
print(f"\t{key}: {value}") |
|
else: |
|
start = time.time() |
|
            generated_text = generate_text(model, tokenizer, user_input, max_new_tokens=MAX_GENERATION_LENGTH, device=device)
|
print(f"Generated Text: {generated_text}") |
|
end = timed_step(start, "Prompt Generation") |
|
|
|
|