# generate.py (VERSÃO CORRIGIDA PARA COMPATIBILIDADE COM model.py) import torch import argparse from tokenizers import Tokenizer from model import LunarisCodex, LunarisCodexConfig from collections import defaultdict import math def unwrap_model_keys(state_dict): unwrapped = {} prefixes_to_remove = ['_orig_mod.module.', 'module.', '_orig_mod.'] for k, v in state_dict.items(): new_k = k for prefix in prefixes_to_remove: if new_k.startswith(prefix): new_k = new_k[len(prefix):] break unwrapped[new_k] = v return unwrapped def apply_repetition_penalty(logits, input_ids, penalty=1.0): """ Aplica penalidade de repetição aos tokens que já apareceram na sequência. """ if penalty == 1.0: return logits # Conta a frequência de cada token token_counts = defaultdict(int) for token_id in input_ids.flatten(): token_counts[token_id.item()] += 1 # Aplica penalidade baseada na frequência for token_id, count in token_counts.items(): if count > 0: # Se logit é positivo, divide pela penalidade # Se logit é negativo, multiplica pela penalidade if logits[0, token_id] > 0: logits[0, token_id] = logits[0, token_id] / penalty else: logits[0, token_id] = logits[0, token_id] * penalty return logits def apply_frequency_penalty(logits, input_ids, penalty=0.0): """ Aplica penalidade de frequência linear baseada no número de ocorrências. """ if penalty == 0.0: return logits # Conta a frequência de cada token token_counts = defaultdict(int) for token_id in input_ids.flatten(): token_counts[token_id.item()] += 1 # Aplica penalidade linear baseada na frequência for token_id, count in token_counts.items(): if count > 0: logits[0, token_id] = logits[0, token_id] - penalty * count return logits def apply_presence_penalty(logits, input_ids, penalty=0.0): """ Aplica penalidade de presença - penaliza tokens que já apareceram pelo menos uma vez. """ if penalty == 0.0: return logits # Identifica tokens únicos que já apareceram unique_tokens = set(input_ids.flatten().tolist()) # Aplica penalidade fixa para tokens que já apareceram for token_id in unique_tokens: logits[0, token_id] = logits[0, token_id] - penalty return logits def apply_typical_sampling(logits, typical_p=1.0): """ Aplica Typical Sampling - mantém tokens com probabilidade "típica". """ if typical_p >= 1.0: return logits # Calcula probabilidades probs = torch.softmax(logits, dim=-1) # Calcula a entropia entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1) # Calcula surpresa de cada token surprisal = -torch.log(probs + 1e-10) # Mantém apenas tokens com surpresa próxima da entropia deviation = torch.abs(surprisal - entropy.unsqueeze(-1)) # Ordena por desvio e mantém apenas os típicos sorted_indices = torch.argsort(deviation, dim=-1) sorted_probs = probs.gather(-1, sorted_indices) # Acumula probabilidades até atingir typical_p cumulative_probs = torch.cumsum(sorted_probs, dim=-1) # Encontra o cutoff cutoff = torch.searchsorted(cumulative_probs, typical_p) cutoff = torch.clamp(cutoff, min=1) # Mantém pelo menos 1 token # Cria máscara para tokens típicos mask = torch.zeros_like(logits, dtype=torch.bool) for i in range(logits.size(0)): typical_indices = sorted_indices[i, :cutoff[i]] mask[i].scatter_(0, typical_indices, True) # Aplica máscara logits = logits.masked_fill(~mask, -float('inf')) return logits def safe_softmax_sampling(logits, temperature=1.0): """ Aplica softmax e amostragem de forma segura, evitando valores inválidos. """ # Evita temperature zero ou muito próximo de zero if temperature <= 1e-5: temperature = 1e-5 # Aplica temperatura logits = logits / temperature # Remove valores infinitos negativos extremos para evitar underflow logits = torch.clamp(logits, min=-1e4, max=1e4) # Aplica softmax probs = torch.softmax(logits, dim=-1) # Verifica se há valores inválidos if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)) or torch.any(probs < 0): # Fallback: usa distribuição uniforme print("AVISO: Detectados valores inválidos nas probabilidades. Usando distribuição uniforme.") probs = torch.ones_like(probs) / probs.size(-1) # Garante que a soma é 1 probs = probs / probs.sum(dim=-1, keepdim=True) return probs def generate_with_penalties(model, idx, max_new_tokens, temperature=1.0, top_k=None, repetition_penalty=1.0, frequency_penalty=0.0, presence_penalty=0.0, typical_p=1.0, min_length=0): """ Geração de texto com sistema completo de penalidades e verificações de segurança. CORRIGIDO PARA COMPATIBILIDADE COM model.py """ # Tokens especiais que podem indicar fim de sequência eos_tokens = set([0, 1, 2]) # Ajuste conforme necessário original_length = idx.size(1) past_key_values = None # Inicializa o cache KV for step in range(max_new_tokens): # Verifica se excedeu o tamanho máximo da sequência current_len = past_key_values[0][0].shape[-2] if past_key_values else idx.shape[1] if current_len >= model.config.max_seq_len: break # Se há cache, usa apenas o último token como input idx_cond = idx if past_key_values is None else idx[:, -1:] # CORREÇÃO: Usa a assinatura correta do forward do model.py # forward(idx, targets=None, past_key_values=None) -> (logits, loss, new_past_key_values) logits, _, past_key_values = model(idx_cond, targets=None, past_key_values=past_key_values) logits = logits[:, -1, :] # Pega apenas o último token # Aplica penalidades logits = apply_repetition_penalty(logits, idx, repetition_penalty) logits = apply_frequency_penalty(logits, idx, frequency_penalty) logits = apply_presence_penalty(logits, idx, presence_penalty) # Evita tokens de fim se ainda não atingiu o comprimento mínimo current_length = idx.size(1) if current_length - original_length < min_length: for eos_token in eos_tokens: if eos_token < logits.size(-1): logits[0, eos_token] = -float('inf') # Aplica typical sampling logits = apply_typical_sampling(logits, typical_p) # Aplica top-k if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float('inf') # Amostragem segura probs = safe_softmax_sampling(logits, temperature) # Amostra o próximo token idx_next = torch.multinomial(probs, num_samples=1) # Adiciona o token à sequência idx = torch.cat((idx, idx_next), dim=1) # Verifica se deve parar (opcional) if idx_next.item() in eos_tokens and current_length - original_length >= min_length: break return idx def generate_optimized(model, idx, max_new_tokens, temperature=1.0, top_k=None): """ Usa o método generate otimizado do modelo (com KV caching automático). Esta é uma alternativa mais eficiente para geração simples sem penalidades. """ with torch.no_grad(): return model.generate(idx, max_new_tokens, temperature, top_k) def main(args): print("--- Iniciando Geração de Texto com Sistema de Penalidade ---") torch.manual_seed(1337) device = torch.device(args.device) print(f"Usando dispositivo: {device}") try: print(f"Carregando checkpoint de: {args.checkpoint_path}") checkpoint = torch.load(args.checkpoint_path, map_location=device, weights_only=False) except FileNotFoundError: print(f"ERRO: Arquivo de checkpoint não encontrado em '{args.checkpoint_path}'") return except Exception as e: print(f"ERRO: Falha ao carregar o checkpoint: {e}") return # CORREÇÃO: Carrega configuração corretamente config = None # Primeiro, tenta diferentes estruturas de configuração if 'config' in checkpoint: config_data = checkpoint['config'] # Caso 1: Configuração aninhada como em train.py if isinstance(config_data, dict) and 'model' in config_data: model_config = config_data['model'] if isinstance(model_config, dict): config = LunarisCodexConfig(**model_config) elif isinstance(model_config, LunarisCodexConfig): config = model_config # Caso 2: Configuração direta como dict elif isinstance(config_data, dict): config = LunarisCodexConfig(**config_data) # Caso 3: Configuração já é um objeto LunarisCodexConfig elif isinstance(config_data, LunarisCodexConfig): config = config_data if config is None: print("ERRO: Não foi possível carregar a configuração do checkpoint") print("Estrutura do checkpoint:", list(checkpoint.keys())) if 'config' in checkpoint: print("Estrutura da config:", type(checkpoint['config'])) if isinstance(checkpoint['config'], dict): print("Chaves da config:", list(checkpoint['config'].keys())) return print(f"Configuração carregada: {config}") model = LunarisCodex(config) unwrapped_state_dict = unwrap_model_keys(checkpoint['model']) model.load_state_dict(unwrapped_state_dict) model.to(device) model.eval() print(f"Carregando tokenizador de: {args.tokenizer_path}") try: tokenizer = Tokenizer.from_file(args.tokenizer_path) except FileNotFoundError: print(f"ERRO: Arquivo de tokenizador não encontrado em '{args.tokenizer_path}'") return # Prepara entrada start_ids = tokenizer.encode(args.prompt).ids x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...] print("\n" + "="*50) print(f"Prompt: '{args.prompt}'") print("Parâmetros de geração:") print(f" Temperature: {args.temperature}") print(f" Top-k: {args.top_k}") print(f" Repetition penalty: {args.repetition_penalty}") print(f" Frequency penalty: {args.frequency_penalty}") print(f" Presence penalty: {args.presence_penalty}") print(f" Typical-p: {args.typical_p}") print(f" Min length: {args.min_length}") print(f" Use optimized: {args.use_optimized}") print("Gerando texto...") print("="*50) with torch.no_grad(): # Escolhe método de geração if args.use_optimized and (args.repetition_penalty == 1.0 and args.frequency_penalty == 0.0 and args.presence_penalty == 0.0 and args.typical_p == 1.0): # Usa método otimizado do modelo se não há penalidades print("Usando método otimizado (sem penalidades)") y = generate_optimized(model, x, args.max_new_tokens, args.temperature, args.top_k) else: # Usa método com penalidades print("Usando método com penalidades") y = generate_with_penalties( model, x, args.max_new_tokens, temperature=args.temperature, top_k=args.top_k, repetition_penalty=args.repetition_penalty, frequency_penalty=args.frequency_penalty, presence_penalty=args.presence_penalty, typical_p=args.typical_p, min_length=args.min_length ) generated_text = tokenizer.decode(y[0].tolist()) print(generated_text) print("\n--- Geração Concluída ---") if __name__ == '__main__': parser = argparse.ArgumentParser(description="Gerar texto com sistema de penalidades do LunarisCodex.") # Parâmetros básicos parser.add_argument('--checkpoint_path', type=str, required=True, help='Caminho para o arquivo .pt do checkpoint.') parser.add_argument('--tokenizer_path', type=str, default='./lunaris-tokenizer.json', help='Caminho para o arquivo do tokenizador.') parser.add_argument('--prompt', type=str, default='The first step to build a rocket is', help='O texto inicial para o modelo completar.') parser.add_argument('--max_new_tokens', type=int, default=50, help='Número de novos tokens a serem gerados.') parser.add_argument('--device', type=str, default='cpu', help='Dispositivo para rodar a geração (ex: "cpu" ou "cuda").') # Parâmetros de controle de geração parser.add_argument('--temperature', type=float, default=0.6, help='Controla a aleatoriedade. Valores mais altos = mais criativo.') parser.add_argument('--top_k', type=int, default=200, help='Considera apenas os k tokens mais prováveis para amostragem.') # Parâmetros de penalidade parser.add_argument('--repetition_penalty', type=float, default=1.1, help='Penalidade de repetição. 1.0 = sem penalidade, >1.0 = penaliza repetições.') parser.add_argument('--frequency_penalty', type=float, default=0.0, help='Penalidade de frequência. 0.0 = sem penalidade, >0.0 = penaliza tokens frequentes.') parser.add_argument('--presence_penalty', type=float, default=0.0, help='Penalidade de presença. 0.0 = sem penalidade, >0.0 = penaliza tokens já usados.') parser.add_argument('--typical_p', type=float, default=1.0, help='Typical sampling. 1.0 = desabilitado, <1.0 = mantém tokens típicos.') parser.add_argument('--min_length', type=int, default=0, help='Comprimento mínimo antes de permitir tokens de fim.') # Novo parâmetro para usar método otimizado parser.add_argument('--use_optimized', action='store_true', help='Usar método generate otimizado quando não há penalidades.') args = parser.parse_args() main(args)