import gradio as gr
import torch
import os
from tokenizers import Tokenizer
from model import LunarisCodex, LunarisCodexConfig
from collections import defaultdict
from huggingface_hub import hf_hub_download
import gc
# Global cache for the model and tokenizer
_cached_model = None
_cached_tokenizer = None
_cached_checkpoint = None
def unwrap_model_keys(state_dict):
"""Remove prefixos desnecessários das chaves do state_dict"""
unwrapped = {}
prefixes_to_remove = ['_orig_mod.module.', 'module.', '_orig_mod.']
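    # Check the combined prefix first; stripping '_orig_mod.' alone would leave a dangling 'module.'.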
for k, v in state_dict.items():
new_k = k
for prefix in prefixes_to_remove:
if new_k.startswith(prefix):
new_k = new_k[len(prefix):]
break
unwrapped[new_k] = v
return unwrapped
def apply_repetition_penalty(logits, input_ids, penalty=1.0):
"""Aplica penalidade de repetição aos tokens que já apareceram na sequência."""
if penalty == 1.0:
return logits
token_counts = defaultdict(int)
for token_id in input_ids.flatten():
token_counts[token_id.item()] += 1
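    # CTRL-style scaling: dividing positive logits and multiplying negative ones
    # both lower the token's probability, regardless of sign.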
for token_id, count in token_counts.items():
if count > 0:
if logits[0, token_id] > 0:
logits[0, token_id] = logits[0, token_id] / penalty
else:
logits[0, token_id] = logits[0, token_id] * penalty
return logits
def apply_frequency_penalty(logits, input_ids, penalty=0.0):
"""Aplica penalidade de frequência linear baseada no número de ocorrências."""
if penalty == 0.0:
return logits
token_counts = defaultdict(int)
for token_id in input_ids.flatten():
token_counts[token_id.item()] += 1
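    # Subtract penalty * count from each seen token's logit, analogous to the
    # frequency_penalty parameter in the OpenAI API.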
for token_id, count in token_counts.items():
if count > 0:
logits[0, token_id] = logits[0, token_id] - penalty * count
return logits
def apply_presence_penalty(logits, input_ids, penalty=0.0):
"""Aplica penalidade de presença - penaliza tokens que já apareceram pelo menos uma vez."""
if penalty == 0.0:
return logits
unique_tokens = set(input_ids.flatten().tolist())
for token_id in unique_tokens:
logits[0, token_id] = logits[0, token_id] - penalty
return logits
def safe_softmax_sampling(logits, temperature=1.0):
"""Aplica softmax e amostragem de forma segura, evitando valores inválidos."""
if temperature <= 1e-5:
temperature = 1e-5
logits = logits / temperature
logits = torch.clamp(logits, min=-1e4, max=1e4)
probs = torch.softmax(logits, dim=-1)
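    # Fall back to a uniform distribution if softmax produced NaN, Inf, or negative values.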
if torch.any(torch.isnan(probs)) or torch.any(torch.isinf(probs)) or torch.any(probs < 0):
probs = torch.ones_like(probs) / probs.size(-1)
probs = probs / probs.sum(dim=-1, keepdim=True)
return probs
def load_model_and_tokenizer(checkpoint_name):
"""Carrega modelo e tokenizer, usando cache quando possível"""
global _cached_model, _cached_tokenizer, _cached_checkpoint
    # Reuse the cached model if this checkpoint is already loaded
if _cached_checkpoint == checkpoint_name and _cached_model is not None:
return _cached_model, _cached_tokenizer
    # Drop the previously cached model to free memory
    if _cached_model is not None:
        del _cached_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    device = torch.device('cpu')  # Force CPU for the free Space tier
try:
        # Download the checkpoint
checkpoint_path = hf_hub_download(
repo_id="meryyllebr543/Lunaris-0.6B-base",
filename=f"checkpoints/{checkpoint_name}",
cache_dir="./cache"
)
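        # weights_only=False: the checkpoint may pickle a full config object, not just tensors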
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        # Load the model configuration
config = None
if 'config' in checkpoint:
config_data = checkpoint['config']
if isinstance(config_data, dict) and 'model' in config_data:
model_config = config_data['model']
if isinstance(model_config, dict):
config = LunarisCodexConfig(**model_config)
elif isinstance(model_config, LunarisCodexConfig):
config = model_config
elif isinstance(config_data, dict):
config = LunarisCodexConfig(**config_data)
elif isinstance(config_data, LunarisCodexConfig):
config = config_data
if config is None:
raise ValueError("Não foi possível carregar a configuração do checkpoint")
        # Build the model and load its weights
model = LunarisCodex(config)
unwrapped_state_dict = unwrap_model_keys(checkpoint['model'])
model.load_state_dict(unwrapped_state_dict)
model.to(device)
model.eval()
        # Load the tokenizer (assumed to be in the same repo)
try:
tokenizer_path = hf_hub_download(
repo_id="meryyllebr543/Lunaris-0.6B-base",
filename="lunaris-tokenizer.json",
cache_dir="./cache"
)
tokenizer = Tokenizer.from_file(tokenizer_path)
        except Exception:
            # Fall back to a local tokenizer file if the download fails
if os.path.exists("lunaris-tokenizer.json"):
tokenizer = Tokenizer.from_file("lunaris-tokenizer.json")
else:
raise ValueError("Tokenizer não encontrado")
        # Update the cache
_cached_model = model
_cached_tokenizer = tokenizer
_cached_checkpoint = checkpoint_name
return model, tokenizer
except Exception as e:
raise gr.Error(f"Erro ao carregar modelo: {str(e)}")
def generate_text(
checkpoint_name,
prompt,
max_new_tokens,
temperature,
top_k,
repetition_penalty,
frequency_penalty,
presence_penalty,
min_length
):
"""Função principal de geração de texto"""
if not prompt.strip():
return "Por favor, insira um prompt."
try:
        # Load the model and tokenizer
model, tokenizer = load_model_and_tokenizer(checkpoint_name)
device = next(model.parameters()).device
        # Encode the prompt
start_ids = tokenizer.encode(prompt).ids
x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
        # Special end-of-sequence token ids
eos_tokens = set([0, 1, 2])
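        # Track the prompt length so min_length counts only newly generated tokens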
original_length = x.size(1)
past_key_values = None
with torch.no_grad():
for step in range(max_new_tokens):
                # Stop once the sequence reaches the model's context window
current_len = past_key_values[0][0].shape[-2] if past_key_values else x.shape[1]
if current_len >= model.config.max_seq_len:
break
                # With a KV cache, only the newest token needs a forward pass
idx_cond = x if past_key_values is None else x[:, -1:]
# Forward pass
logits, _, past_key_values = model(idx_cond, targets=None, past_key_values=past_key_values)
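                # Keep only the logits for the last position in the sequence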
logits = logits[:, -1, :]
                # Apply sampling penalties
logits = apply_repetition_penalty(logits, x, repetition_penalty)
logits = apply_frequency_penalty(logits, x, frequency_penalty)
logits = apply_presence_penalty(logits, x, presence_penalty)
                # Block end tokens until the minimum length is reached
current_length = x.size(1)
if current_length - original_length < min_length:
for eos_token in eos_tokens:
if eos_token < logits.size(-1):
logits[0, eos_token] = -float('inf')
                # Apply top-k filtering
if top_k is not None and top_k > 0:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('inf')
                # Sample the next token
probs = safe_softmax_sampling(logits, temperature)
idx_next = torch.multinomial(probs, num_samples=1)
                # Append the sampled token
x = torch.cat((x, idx_next), dim=1)
                # Stop on an end token once the minimum length is satisfied
if idx_next.item() in eos_tokens and current_length - original_length >= min_length:
break
        # Decode the full sequence
generated_text = tokenizer.decode(x[0].tolist())
return generated_text
except Exception as e:
return f"Erro durante a geração: {str(e)}"
# ---- UI: refreshed layout, keeping the same logic/functions ----
def _clear_fields():
return "", ""
def create_interface():
    # Available checkpoints: ckpt_1000.pt through ckpt_20000.pt
checkpoints = [f"ckpt_{i}.pt" for i in range(1000, 21000, 1000)]
custom_css = """
:root {
        --brand: #6d28d9; /* elegant purple */
--brand-600: #7c3aed;
--bg: #0b0b10;
--card: rgba(255,255,255,0.04);
--border: rgba(255,255,255,0.12);
}
.gradio-container {max-width: 1100px !important; margin: 0 auto;}
    /* Header */
.lc-header {text-align:center; padding: 24px 0 8px;}
.lc-title {font-size: 26px; font-weight: 800; letter-spacing: .2px; margin: 0;}
.lc-sub {opacity: .85; margin-top: 6px;}
    /* Cards / panels */
.lc-card {background: var(--card); border: 1px solid var(--border); border-radius: 16px; padding: 16px;}
.lc-footer {text-align:center; opacity:.8;}
.gr-button.primary {background: var(--brand) !important; border-color: var(--brand) !important;}
.gr-button.primary:hover {filter: brightness(1.1);}
.gr-accordion .label-wrap {font-weight:600}
    /* Responsive layout */
@media (max-width: 900px){
.gradio-container {padding: 0 10px;}
}
"""
theme = gr.themes.Soft(primary_hue=gr.themes.colors.purple, neutral_hue=gr.themes.colors.gray)
with gr.Blocks(title="🌙 Lunaris-0.6B Text Generation", theme=theme, css=custom_css) as demo:
# Header
gr.HTML(
"""
<div class="lc-header">
<h1 class="lc-title">🌙 Lunaris 0.6B — Text Generation</h1>
<div class="lc-sub">Uma interface enxuta e profissional para gerar texto com seus checkpoints</div>
</div>
"""
)
with gr.Row():
with gr.Column(scale=5):
with gr.Group(elem_classes=["lc-card"]):
prompt_input = gr.Textbox(
lines=6,
placeholder="Digite seu prompt aqui…",
label="Prompt",
value="The first step to build a rocket is",
autofocus=True,
show_label=True,
)
with gr.Row():
generate_btn = gr.Button("Gerar texto", variant="primary", size="lg")
clear_btn = gr.Button("Limpar")
output_text = gr.Textbox(
lines=16,
label="Texto gerado",
interactive=False,
show_copy_button=True,
)
with gr.Column(scale=4):
with gr.Group(elem_classes=["lc-card"]):
checkpoint_dropdown = gr.Dropdown(
choices=checkpoints,
value="ckpt_20000.pt",
label="Checkpoint do modelo",
info="Selecione qual checkpoint usar"
)
with gr.Row():
                        max_tokens = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Max tokens")
                        temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature", info="Higher = more random")
with gr.Row():
                        top_k = gr.Slider(minimum=1, maximum=500, value=50, step=10, label="Top‑k", info="Keep only the k most likely options")
                        min_length = gr.Slider(minimum=0, maximum=100, value=0, step=5, label="Minimum length")
                    with gr.Accordion("Advanced settings", open=False):
                        repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty", info=">1 reduces loops")
                        frequency_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="Frequency penalty")
                        presence_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="Presence penalty")
with gr.Group(elem_classes=["lc-card"]):
gr.Markdown("**Dicas:** Use *temperature* alto para ser criativo e *top‑k* menor para respostas mais focadas.")
gr.Examples(
examples=[
["ckpt_20000.pt", "The first step to build a rocket is", 100, 0.7, 50, 1.1, 0.0, 0.0, 0],
["ckpt_15000.pt", "Once upon a time in a magical forest", 150, 0.8, 40, 1.2, 0.1, 0.0, 10],
["ckpt_10000.pt", "To create artificial intelligence, we need", 120, 0.6, 60, 1.0, 0.0, 0.1, 5],
["ckpt_20000.pt", "The future of technology will be", 80, 0.9, 30, 1.3, 0.2, 0.1, 0],
],
inputs=[
checkpoint_dropdown, prompt_input, max_tokens, temperature,
top_k, repetition_penalty, frequency_penalty, presence_penalty, min_length
],
label="Exemplos"
)
gr.HTML(
"""
<div class="lc-footer" style="margin-top:12px">
                ⚡ Running on CPU • 🤗 <a href="https://huggingface.co/meryyllebr543/Lunaris-0.6B-base" target="_blank">Lunaris‑0.6B‑base</a>
</div>
"""
)
        # Actions: same function and parameter order as before
input_components = [
checkpoint_dropdown, prompt_input, max_tokens, temperature,
top_k, repetition_penalty, frequency_penalty, presence_penalty, min_length
]
generate_btn.click(
fn=generate_text,
inputs=input_components,
outputs=output_text,
show_progress=True,
)
        # Submit with Enter
prompt_input.submit(
fn=generate_text,
inputs=input_components,
outputs=output_text,
show_progress=True,
)
        # Clear the prompt and output
clear_btn.click(fn=_clear_fields, inputs=None, outputs=[prompt_input, output_text])
return demo
if __name__ == "__main__":
demo = create_interface()
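    # Bind to all interfaces on port 7860, the default for Hugging Face Spaces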
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
show_error=True
)