import gradio as gr
import torch
import os
from tokenizers import Tokenizer
from model import LunarisCodex, LunarisCodexConfig
from collections import defaultdict
import tempfile
from huggingface_hub import hf_hub_download
import gc

# Global cache for the model and tokenizer
_cached_model = None
_cached_tokenizer = None
_cached_checkpoint = None


def unwrap_model_keys(state_dict):
    """Remove unnecessary prefixes from state_dict keys."""
    unwrapped = {}
    prefixes_to_remove = ['_orig_mod.module.', 'module.', '_orig_mod.']
    for k, v in state_dict.items():
        new_k = k
        for prefix in prefixes_to_remove:
            if new_k.startswith(prefix):
                new_k = new_k[len(prefix):]
                break
        unwrapped[new_k] = v
    return unwrapped


def apply_repetition_penalty(logits, input_ids, penalty=1.0):
    """Apply a repetition penalty to tokens that have already appeared in the sequence."""
    if penalty == 1.0:
        return logits
    token_counts = defaultdict(int)
    for token_id in input_ids.flatten():
        token_counts[token_id.item()] += 1
    for token_id, count in token_counts.items():
        if count > 0:
            if logits[0, token_id] > 0:
                logits[0, token_id] = logits[0, token_id] / penalty
            else:
                logits[0, token_id] = logits[0, token_id] * penalty
    return logits


def apply_frequency_penalty(logits, input_ids, penalty=0.0):
    """Apply a linear frequency penalty based on the number of occurrences."""
    if penalty == 0.0:
        return logits
    token_counts = defaultdict(int)
    for token_id in input_ids.flatten():
        token_counts[token_id.item()] += 1
    for token_id, count in token_counts.items():
        if count > 0:
            logits[0, token_id] = logits[0, token_id] - penalty * count
    return logits


def apply_presence_penalty(logits, input_ids, penalty=0.0):
    """Apply a presence penalty: penalize tokens that have appeared at least once."""
    if penalty == 0.0:
        return logits
    unique_tokens = set(input_ids.flatten().tolist())
    for token_id in unique_tokens:
        logits[0, token_id] = logits[0, token_id] - penalty
    return logits


def safe_softmax_sampling(logits, temperature=1.0):
    """Apply softmax for sampling safely, avoiding invalid values."""
    if temperature <= 1e-5:
        temperature = 1e-5
    logits = logits / temperature
    logits = torch.clamp(logits, min=-1e4, max=1e4)
    probs = torch.softmax(logits, dim=-1)
    if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)) or torch.any(probs < 0):
        probs = torch.ones_like(probs) / probs.size(-1)
    probs = probs / probs.sum(dim=-1, keepdim=True)
    return probs
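
# Illustrative sketch (not executed by the app): how the sampling helpers above
# compose for a single decoding step. Shapes follow the helpers' assumptions:
# `logits` is (1, vocab_size) and `input_ids` is (1, seq_len); the penalty values
# below are arbitrary placeholders, not tuned defaults.
#
#   logits = apply_repetition_penalty(logits, input_ids, penalty=1.2)  # divide/multiply logits of seen tokens
#   logits = apply_frequency_penalty(logits, input_ids, penalty=0.3)   # subtract penalty * occurrence count
#   logits = apply_presence_penalty(logits, input_ids, penalty=0.3)    # flat penalty once per seen token
#   probs = safe_softmax_sampling(logits, temperature=0.8)             # clamped, NaN-safe softmax
#   next_id = torch.multinomial(probs, num_samples=1)                  # sample the next token id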


def load_model_and_tokenizer(checkpoint_name):
    """Load the model and tokenizer, reusing the cache when possible."""
    global _cached_model, _cached_tokenizer, _cached_checkpoint

    # Reuse the cached model if this checkpoint is already loaded
    if _cached_checkpoint == checkpoint_name and _cached_model is not None:
        return _cached_model, _cached_tokenizer

    # Clear the previous cache to save memory
    if _cached_model is not None:
        del _cached_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    device = torch.device('cpu')  # Force CPU for the free Space tier

    try:
        # Download the checkpoint
        checkpoint_path = hf_hub_download(
            repo_id="meryyllebr543/Lunaris-0.6B-base",
            filename=f"checkpoints/{checkpoint_name}",
            cache_dir="./cache"
        )
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

        # Load the configuration
        config = None
        if 'config' in checkpoint:
            config_data = checkpoint['config']
            if isinstance(config_data, dict) and 'model' in config_data:
                model_config = config_data['model']
                if isinstance(model_config, dict):
                    config = LunarisCodexConfig(**model_config)
                elif isinstance(model_config, LunarisCodexConfig):
                    config = model_config
            elif isinstance(config_data, dict):
                config = LunarisCodexConfig(**config_data)
            elif isinstance(config_data, LunarisCodexConfig):
                config = config_data

        if config is None:
            raise ValueError("Could not load the configuration from the checkpoint")

        # Load the model
        model = LunarisCodex(config)
        unwrapped_state_dict = unwrap_model_keys(checkpoint['model'])
        model.load_state_dict(unwrapped_state_dict)
        model.to(device)
        model.eval()

        # Load the tokenizer (assumed to be in the same repo)
        try:
            tokenizer_path = hf_hub_download(
                repo_id="meryyllebr543/Lunaris-0.6B-base",
                filename="lunaris-tokenizer.json",
                cache_dir="./cache"
            )
            tokenizer = Tokenizer.from_file(tokenizer_path)
        except Exception:
            # Fall back to a local tokenizer if it exists
            if os.path.exists("lunaris-tokenizer.json"):
                tokenizer = Tokenizer.from_file("lunaris-tokenizer.json")
            else:
                raise ValueError("Tokenizer not found")

        # Update the cache
        _cached_model = model
        _cached_tokenizer = tokenizer
        _cached_checkpoint = checkpoint_name

        return model, tokenizer

    except Exception as e:
        raise gr.Error(f"Error loading model: {str(e)}")


def generate_text(
    checkpoint_name,
    prompt,
    max_new_tokens,
    temperature,
    top_k,
    repetition_penalty,
    frequency_penalty,
    presence_penalty,
    min_length
):
    """Main text-generation function."""
    if not prompt.strip():
        return "Please enter a prompt."

    try:
        # Load model and tokenizer
        model, tokenizer = load_model_and_tokenizer(checkpoint_name)
        device = next(model.parameters()).device

        # Prepare the input
        start_ids = tokenizer.encode(prompt).ids
        x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]

        # Special end-of-sequence tokens
        eos_tokens = {0, 1, 2}
        original_length = x.size(1)
        past_key_values = None

        with torch.no_grad():
            for step in range(max_new_tokens):
                # Check the sequence-length limit
                current_len = past_key_values[0][0].shape[-2] if past_key_values else x.shape[1]
                if current_len >= model.config.max_seq_len:
                    break

                # Use the KV cache when available
                idx_cond = x if past_key_values is None else x[:, -1:]

                # Forward pass
                logits, _, past_key_values = model(idx_cond, targets=None, past_key_values=past_key_values)
                logits = logits[:, -1, :]

                # Apply penalties
                logits = apply_repetition_penalty(logits, x, repetition_penalty)
                logits = apply_frequency_penalty(logits, x, frequency_penalty)
                logits = apply_presence_penalty(logits, x, presence_penalty)

                # Block end tokens until the minimum length is reached
                current_length = x.size(1)
                if current_length - original_length < min_length:
                    for eos_token in eos_tokens:
                        if eos_token < logits.size(-1):
                            logits[0, eos_token] = -float('inf')

                # Apply top-k filtering
                if top_k is not None and top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('inf')

                # Sampling
                probs = safe_softmax_sampling(logits, temperature)
                idx_next = torch.multinomial(probs, num_samples=1)

                # Append the token
                x = torch.cat((x, idx_next), dim=1)

                # Check the stop condition
                if idx_next.item() in eos_tokens and current_length - original_length >= min_length:
                    break

        # Decode the result
        generated_text = tokenizer.decode(x[0].tolist())
        return generated_text

    except Exception as e:
        return f"Error during generation: {str(e)}"
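
# Illustrative sketch (not executed by the app): calling generate_text directly,
# bypassing the Gradio UI. The checkpoint name matches the list built in
# create_interface(); the prompt and sampling values are arbitrary placeholders.
#
#   sample = generate_text(
#       checkpoint_name="ckpt_20000.pt",
#       prompt="def fibonacci(n):",
#       max_new_tokens=128,
#       temperature=0.8,
#       top_k=50,
#       repetition_penalty=1.1,
#       frequency_penalty=0.0,
#       presence_penalty=0.0,
#       min_length=16,
#   )
#   print(sample)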


# ---- UI: refreshed layout, keeping the same logic/functions ----

def _clear_fields():
    return "", ""


def create_interface():
    # Available checkpoints (unchanged)
    checkpoints = [f"ckpt_{i}.pt" for i in range(1000, 21000, 1000)]

    custom_css = """
    :root {
        --brand: #6d28d9;       /* elegant purple */
        --brand-600: #7c3aed;
        --bg: #0b0b10;
        --card: rgba(255,255,255,0.04);
        --border: rgba(255,255,255,0.12);
    }
    .gradio-container {max-width: 1100px !important; margin: 0 auto;}

    /* Header */
    .lc-header {text-align:center; padding: 24px 0 8px;}
    .lc-title {font-size: 26px; font-weight: 800; letter-spacing: .2px; margin: 0;}
    .lc-sub {opacity: .85; margin-top: 6px;}

    /* Cards / panels */
    .lc-card {background: var(--card); border: 1px solid var(--border); border-radius: 16px; padding: 16px;}
    .lc-footer {text-align:center; opacity:.8;}
    .gr-button.primary {background: var(--brand) !important; border-color: var(--brand) !important;}
    .gr-button.primary:hover {filter: brightness(1.1);}
    .gr-accordion .label-wrap {font-weight:600}

    /* Responsive layout */
    @media (max-width: 900px){
        .gradio-container {padding: 0 10px;}
    }
    """

    theme = gr.themes.Soft(primary_hue=gr.themes.colors.purple, neutral_hue=gr.themes.colors.gray)

    with gr.Blocks(title="🌙 Lunaris-0.6B Text Generation", theme=theme, css=custom_css) as demo:
        # Header
        gr.HTML(
            """