Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import os | |
from tokenizers import Tokenizer | |
from model import LunarisCodex, LunarisCodexConfig | |
from collections import defaultdict | |
import tempfile | |
from huggingface_hub import hf_hub_download | |
import gc | |
# Cache global para modelo e tokenizer | |
_cached_model = None | |
_cached_tokenizer = None | |
_cached_checkpoint = None | |
def unwrap_model_keys(state_dict): | |
"""Remove prefixos desnecessários das chaves do state_dict""" | |
unwrapped = {} | |
prefixes_to_remove = ['_orig_mod.module.', 'module.', '_orig_mod.'] | |
for k, v in state_dict.items(): | |
new_k = k | |
for prefix in prefixes_to_remove: | |
if new_k.startswith(prefix): | |
new_k = new_k[len(prefix):] | |
break | |
unwrapped[new_k] = v | |
return unwrapped | |
def apply_repetition_penalty(logits, input_ids, penalty=1.0): | |
"""Aplica penalidade de repetição aos tokens que já apareceram na sequência.""" | |
if penalty == 1.0: | |
return logits | |
token_counts = defaultdict(int) | |
for token_id in input_ids.flatten(): | |
token_counts[token_id.item()] += 1 | |
for token_id, count in token_counts.items(): | |
if count > 0: | |
if logits[0, token_id] > 0: | |
logits[0, token_id] = logits[0, token_id] / penalty | |
else: | |
logits[0, token_id] = logits[0, token_id] * penalty | |
return logits | |
def apply_frequency_penalty(logits, input_ids, penalty=0.0): | |
"""Aplica penalidade de frequência linear baseada no número de ocorrências.""" | |
if penalty == 0.0: | |
return logits | |
token_counts = defaultdict(int) | |
for token_id in input_ids.flatten(): | |
token_counts[token_id.item()] += 1 | |
for token_id, count in token_counts.items(): | |
if count > 0: | |
logits[0, token_id] = logits[0, token_id] - penalty * count | |
return logits | |
def apply_presence_penalty(logits, input_ids, penalty=0.0): | |
"""Aplica penalidade de presença - penaliza tokens que já apareceram pelo menos uma vez.""" | |
if penalty == 0.0: | |
return logits | |
unique_tokens = set(input_ids.flatten().tolist()) | |
for token_id in unique_tokens: | |
logits[0, token_id] = logits[0, token_id] - penalty | |
return logits | |
def safe_softmax_sampling(logits, temperature=1.0): | |
"""Aplica softmax e amostragem de forma segura, evitando valores inválidos.""" | |
if temperature <= 1e-5: | |
temperature = 1e-5 | |
logits = logits / temperature | |
logits = torch.clamp(logits, min=-1e4, max=1e4) | |
probs = torch.softmax(logits, dim=-1) | |
if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)) or torch.any(probs < 0): | |
probs = torch.ones_like(probs) / probs.size(-1) | |
probs = probs / probs.sum(dim=-1, keepdim=True) | |
return probs | |
def load_model_and_tokenizer(checkpoint_name): | |
"""Carrega modelo e tokenizer, usando cache quando possível""" | |
global _cached_model, _cached_tokenizer, _cached_checkpoint | |
# Se já temos este checkpoint carregado, reutiliza | |
if _cached_checkpoint == checkpoint_name and _cached_model is not None: | |
return _cached_model, _cached_tokenizer | |
# Limpa cache anterior para economizar memória | |
if _cached_model is not None: | |
del _cached_model | |
gc.collect() | |
torch.cuda.empty_cache() if torch.cuda.is_available() else None | |
device = torch.device('cpu') # Force CPU para space gratuito | |
try: | |
# Download do checkpoint | |
checkpoint_path = hf_hub_download( | |
repo_id="meryyllebr543/Lunaris-0.6B-base", | |
filename=f"checkpoints/{checkpoint_name}", | |
cache_dir="./cache" | |
) | |
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) | |
# Carrega configuração | |
config = None | |
if 'config' in checkpoint: | |
config_data = checkpoint['config'] | |
if isinstance(config_data, dict) and 'model' in config_data: | |
model_config = config_data['model'] | |
if isinstance(model_config, dict): | |
config = LunarisCodexConfig(**model_config) | |
elif isinstance(model_config, LunarisCodexConfig): | |
config = model_config | |
elif isinstance(config_data, dict): | |
config = LunarisCodexConfig(**config_data) | |
elif isinstance(config_data, LunarisCodexConfig): | |
config = config_data | |
if config is None: | |
raise ValueError("Não foi possível carregar a configuração do checkpoint") | |
# Carrega modelo | |
model = LunarisCodex(config) | |
unwrapped_state_dict = unwrap_model_keys(checkpoint['model']) | |
model.load_state_dict(unwrapped_state_dict) | |
model.to(device) | |
model.eval() | |
# Carrega tokenizer (assume que está no mesmo repo) | |
try: | |
tokenizer_path = hf_hub_download( | |
repo_id="meryyllebr543/Lunaris-0.6B-base", | |
filename="lunaris-tokenizer.json", | |
cache_dir="./cache" | |
) | |
tokenizer = Tokenizer.from_file(tokenizer_path) | |
except: | |
# Fallback para tokenizer local se existir | |
if os.path.exists("lunaris-tokenizer.json"): | |
tokenizer = Tokenizer.from_file("lunaris-tokenizer.json") | |
else: | |
raise ValueError("Tokenizer não encontrado") | |
# Atualiza cache | |
_cached_model = model | |
_cached_tokenizer = tokenizer | |
_cached_checkpoint = checkpoint_name | |
return model, tokenizer | |
except Exception as e: | |
raise gr.Error(f"Erro ao carregar modelo: {str(e)}") | |
def generate_text( | |
checkpoint_name, | |
prompt, | |
max_new_tokens, | |
temperature, | |
top_k, | |
repetition_penalty, | |
frequency_penalty, | |
presence_penalty, | |
min_length | |
): | |
"""Função principal de geração de texto""" | |
if not prompt.strip(): | |
return "Por favor, insira um prompt." | |
try: | |
# Carrega modelo e tokenizer | |
model, tokenizer = load_model_and_tokenizer(checkpoint_name) | |
device = next(model.parameters()).device | |
# Prepara entrada | |
start_ids = tokenizer.encode(prompt).ids | |
x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...] | |
# Tokens especiais de fim | |
eos_tokens = set([0, 1, 2]) | |
original_length = x.size(1) | |
past_key_values = None | |
with torch.no_grad(): | |
for step in range(max_new_tokens): | |
# Verifica limite de sequência | |
current_len = past_key_values[0][0].shape[-2] if past_key_values else x.shape[1] | |
if current_len >= model.config.max_seq_len: | |
break | |
# Usa cache KV se disponível | |
idx_cond = x if past_key_values is None else x[:, -1:] | |
# Forward pass | |
logits, _, past_key_values = model(idx_cond, targets=None, past_key_values=past_key_values) | |
logits = logits[:, -1, :] | |
# Aplica penalidades | |
logits = apply_repetition_penalty(logits, x, repetition_penalty) | |
logits = apply_frequency_penalty(logits, x, frequency_penalty) | |
logits = apply_presence_penalty(logits, x, presence_penalty) | |
# Evita tokens de fim se não atingiu comprimento mínimo | |
current_length = x.size(1) | |
if current_length - original_length < min_length: | |
for eos_token in eos_tokens: | |
if eos_token < logits.size(-1): | |
logits[0, eos_token] = -float('inf') | |
# Aplica top-k | |
if top_k is not None and top_k > 0: | |
v, _ = torch.topk(logits, min(top_k, logits.size(-1))) | |
logits[logits < v[:, [-1]]] = -float('inf') | |
# Amostragem | |
probs = safe_softmax_sampling(logits, temperature) | |
idx_next = torch.multinomial(probs, num_samples=1) | |
# Adiciona token | |
x = torch.cat((x, idx_next), dim=1) | |
# Verifica parada | |
if idx_next.item() in eos_tokens and current_length - original_length >= min_length: | |
break | |
# Decodifica resultado | |
generated_text = tokenizer.decode(x[0].tolist()) | |
return generated_text | |
except Exception as e: | |
return f"Erro durante a geração: {str(e)}" | |
# ---- UI: versão repaginada, mantendo a mesma lógica/funções ---- | |
def _clear_fields(): | |
return "", "" | |
def create_interface(): | |
# Lista de checkpoints disponíveis (mantido) | |
checkpoints = [f"ckpt_{i}.pt" for i in range(1000, 21000, 1000)] | |
custom_css = """ | |
:root { | |
--brand: #6d28d9; /* roxo elegante */ | |
--brand-600: #7c3aed; | |
--bg: #0b0b10; | |
--card: rgba(255,255,255,0.04); | |
--border: rgba(255,255,255,0.12); | |
} | |
.gradio-container {max-width: 1100px !important; margin: 0 auto;} | |
/* Cabeçalho */ | |
.lc-header {text-align:center; padding: 24px 0 8px;} | |
.lc-title {font-size: 26px; font-weight: 800; letter-spacing: .2px; margin: 0;} | |
.lc-sub {opacity: .85; margin-top: 6px;} | |
/* Cartões / painéis */ | |
.lc-card {background: var(--card); border: 1px solid var(--border); border-radius: 16px; padding: 16px;} | |
.lc-footer {text-align:center; opacity:.8;} | |
.gr-button.primary {background: var(--brand) !important; border-color: var(--brand) !important;} | |
.gr-button.primary:hover {filter: brightness(1.1);} | |
.gr-accordion .label-wrap {font-weight:600} | |
/* Layout responsivo */ | |
@media (max-width: 900px){ | |
.gradio-container {padding: 0 10px;} | |
} | |
""" | |
theme = gr.themes.Soft(primary_hue=gr.themes.colors.purple, neutral_hue=gr.themes.colors.gray) | |
with gr.Blocks(title="🌙 Lunaris-0.6B Text Generation", theme=theme, css=custom_css) as demo: | |
# Header | |
gr.HTML( | |
""" | |
<div class="lc-header"> | |
<h1 class="lc-title">🌙 Lunaris 0.6B — Text Generation</h1> | |
<div class="lc-sub">Uma interface enxuta e profissional para gerar texto com seus checkpoints</div> | |
</div> | |
""" | |
) | |
with gr.Row(): | |
with gr.Column(scale=5): | |
with gr.Group(elem_classes=["lc-card"]): | |
prompt_input = gr.Textbox( | |
lines=6, | |
placeholder="Digite seu prompt aqui…", | |
label="Prompt", | |
value="The first step to build a rocket is", | |
autofocus=True, | |
show_label=True, | |
) | |
with gr.Row(): | |
generate_btn = gr.Button("Gerar texto", variant="primary", size="lg") | |
clear_btn = gr.Button("Limpar") | |
output_text = gr.Textbox( | |
lines=16, | |
label="Texto gerado", | |
interactive=False, | |
show_copy_button=True, | |
) | |
with gr.Column(scale=4): | |
with gr.Group(elem_classes=["lc-card"]): | |
checkpoint_dropdown = gr.Dropdown( | |
choices=checkpoints, | |
value="ckpt_20000.pt", | |
label="Checkpoint do modelo", | |
info="Selecione qual checkpoint usar" | |
) | |
with gr.Row(): | |
max_tokens = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Máx. tokens") | |
temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature", info="Mais alto = mais aleatório") | |
with gr.Row(): | |
top_k = gr.Slider(minimum=1, maximum=500, value=50, step=10, label="Top‑k", info="Filtra para as k melhores opções") | |
min_length = gr.Slider(minimum=0, maximum=100, value=0, step=5, label="Comprimento mínimo") | |
with gr.Accordion("Configurações avançadas", open=False): | |
repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.1, label="Penalidade de repetição", info=">1 reduz loops") | |
frequency_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="Penalidade de frequência") | |
presence_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="Penalidade de presença") | |
with gr.Group(elem_classes=["lc-card"]): | |
gr.Markdown("**Dicas:** Use *temperature* alto para ser criativo e *top‑k* menor para respostas mais focadas.") | |
gr.Examples( | |
examples=[ | |
["ckpt_20000.pt", "The first step to build a rocket is", 100, 0.7, 50, 1.1, 0.0, 0.0, 0], | |
["ckpt_15000.pt", "Once upon a time in a magical forest", 150, 0.8, 40, 1.2, 0.1, 0.0, 10], | |
["ckpt_10000.pt", "To create artificial intelligence, we need", 120, 0.6, 60, 1.0, 0.0, 0.1, 5], | |
["ckpt_20000.pt", "The future of technology will be", 80, 0.9, 30, 1.3, 0.2, 0.1, 0], | |
], | |
inputs=[ | |
checkpoint_dropdown, prompt_input, max_tokens, temperature, | |
top_k, repetition_penalty, frequency_penalty, presence_penalty, min_length | |
], | |
label="Exemplos" | |
) | |
gr.HTML( | |
""" | |
<div class="lc-footer" style="margin-top:12px"> | |
⚡ Execução em CPU • 🤗 <a href="https://huggingface.co/meryyllebr543/Lunaris-0.6B-base" target="_blank">Lunaris‑0.6B‑base</a> | |
</div> | |
""" | |
) | |
# Ações: mantém mesma função e ordem de parâmetros | |
input_components = [ | |
checkpoint_dropdown, prompt_input, max_tokens, temperature, | |
top_k, repetition_penalty, frequency_penalty, presence_penalty, min_length | |
] | |
generate_btn.click( | |
fn=generate_text, | |
inputs=input_components, | |
outputs=output_text, | |
show_progress=True, | |
) | |
# Enviar com Enter | |
prompt_input.submit( | |
fn=generate_text, | |
inputs=input_components, | |
outputs=output_text, | |
show_progress=True, | |
) | |
# Limpar prompt e saída | |
clear_btn.click(fn=_clear_fields, inputs=None, outputs=[prompt_input, output_text]) | |
return demo | |
if __name__ == "__main__": | |
demo = create_interface() | |
demo.launch( | |
server_name="0.0.0.0", | |
server_port=7860, | |
share=False, | |
show_error=True | |
) | |