# LegalGPT / app.py
# Author: yasserrmd — commit 19c5106 ("Create app.py")
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
import tiktoken
import math
# GPT model definition (GPTConfig, LayerNorm, CausalSelfAttention, MLP, Block, GPT).
# These classes, including their attribute names, must match the code the checkpoint
# was trained with so that load_state_dict() finds every parameter key.
@dataclass
class GPTConfig:
    """Architecture hyper-parameters for :class:`GPT`.

    Field order is significant: the fields may be supplied positionally,
    so new fields must only ever be appended after ``bias``.
    """

    block_size: int  # maximum context length; sizes the position embedding table
    vocab_size: int  # number of token ids covered by the embedding and output head
    n_layer: int  # number of transformer Blocks stacked in the model
    n_head: int  # attention heads per Block; n_embd must be divisible by it
    n_embd: int  # width of the embeddings / residual stream
    dropout: float = 0.1  # probability passed to every nn.Dropout in the model
    bias: bool = True  # whether Linear and LayerNorm layers carry a bias term
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with an optional bias term."""

    def __init__(self, ndim, bias):
        super().__init__()
        # Learnable per-feature scale, starting at 1.
        self.weight = nn.Parameter(torch.ones(ndim))
        # Learnable per-feature shift, starting at 0; None means "no bias".
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, x):
        """Normalise the trailing ``ndim`` features of *x* (eps = 1e-5)."""
        return F.layer_norm(
            x,
            normalized_shape=self.weight.shape,
            weight=self.weight,
            bias=self.bias,
            eps=1e-5,
        )
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which each position attends only to itself and earlier positions."""

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        width = config.n_embd
        # A single projection produces queries, keys and values side by side.
        self.c_attn = nn.Linear(width, 3 * width, bias=config.bias)
        # Projection of the merged heads back to the residual width.
        self.c_proj = nn.Linear(width, width, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = width
        # Prefer PyTorch's fused attention kernel when this build provides it.
        self.flash = hasattr(F, 'scaled_dot_product_attention')
        if not self.flash:
            # Lower-triangular mask for the manual fallback, shaped (1, 1, T, T)
            # so it broadcasts over batch and heads. Registered as "bias" to keep
            # state_dict keys unchanged.
            size = config.block_size
            causal = torch.tril(torch.ones(size, size)).view(1, 1, size, size)
            self.register_buffer("bias", causal)

    def forward(self, x):
        """Attend over *x* of shape (B, T, C) and return a tensor of the same shape."""
        batch, seq, width = x.size()
        head_dim = width // self.n_head
        # (B, T, 3C) -> three tensors of shape (B, n_head, T, head_dim).
        query, key, value = (
            part.view(batch, seq, self.n_head, head_dim).transpose(1, 2)
            for part in self.c_attn(x).split(self.n_embd, dim=2)
        )
        if not self.flash:
            scores = (query @ key.transpose(-2, -1)) * (1.0 / math.sqrt(key.size(-1)))
            # Hide future positions before the softmax.
            scores = scores.masked_fill(self.bias[:, :, :seq, :seq] == 0, float('-inf'))
            weights = self.attn_dropout(F.softmax(scores, dim=-1))
            heads = weights @ value
        else:
            drop_p = self.attn_dropout.p if self.training else 0.0
            heads = F.scaled_dot_product_attention(
                query, key, value, attn_mask=None, dropout_p=drop_p, is_causal=True
            )
        # Merge heads: (B, n_head, T, head_dim) -> (B, T, C).
        merged = heads.transpose(1, 2).contiguous().view(batch, seq, width)
        return self.resid_dropout(self.c_proj(merged))
class MLP(nn.Module):
    """Position-wise feed-forward network: widen 4x, GELU, project back, dropout."""

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply the feed-forward stack to the last dimension of *x*."""
        h = self.c_fc(x)
        h = self.gelu(h)
        h = self.c_proj(h)
        return self.dropout(h)
class Block(nn.Module):
    """Pre-norm transformer layer: attention then MLP, each inside a residual connection."""

    def __init__(self, config):
        super().__init__()
        # Attribute names (ln1, attn, ln2, mlp) are part of the checkpoint's
        # state_dict keys and must not be renamed.
        self.ln1 = LayerNorm(config.n_embd, config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln2 = LayerNorm(config.n_embd, config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        """Run both sub-layers; each reads a normalised copy of *x* and adds its output back."""
        for norm, sublayer in ((self.ln1, self.attn), (self.ln2, self.mlp)):
            x = x + sublayer(norm(x))
        return x
class GPT(nn.Module):
    """Decoder-only GPT language model with tied token-embedding / output-head weights.

    Submodule names (``transformer.wte``, ``transformer.h``, ``lm_head`` ...)
    become state_dict keys, so they must stay in sync with saved checkpoints.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),  # token embeddings
            wpe=nn.Embedding(config.block_size, config.n_embd),  # learned position embeddings
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the token embedding and the output head share one matrix.
        self.transformer.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)
        # Residual output projections get a smaller init, scaled by network depth.
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights ~ N(0, 0.02) and Linear biases to zero."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits for the (B, T) token-id tensor *idx*.

        Returns:
            ``(logits, loss)``. With *targets*, logits cover every position and
            loss is the cross-entropy (target value -1 is ignored). Without
            targets, only the last position's logits, shape (B, 1, vocab), are
            computed and loss is ``None``.
        """
        device = idx.device
        _, t = idx.size()
        assert t <= self.config.block_size, (
            f"Cannot forward sequence of length {t}; block_size is {self.config.block_size}"
        )
        pos = torch.arange(0, t, dtype=torch.long, device=device)
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            return logits, loss
        # Inference: only the last position is needed; the list index [-1]
        # keeps the time dimension so the shape stays (B, 1, vocab).
        logits = self.lm_head(x[:, [-1], :])
        return logits, None

    @staticmethod
    def _apply_top_k(logits, top_k):
        """Return a copy of the (B, V) *logits* with everything below each row's k-th largest value set to -inf."""
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        return logits.masked_fill(logits < v[:, [-1]], -float('Inf'))

    @staticmethod
    def _apply_top_p(logits, top_p):
        """Return a copy of the (B, V) *logits* with nucleus (top-p) filtering applied to each row independently."""
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_to_remove = cumulative_probs > top_p
        # Shift right so the token that first crosses the threshold is kept,
        # and always keep the single most probable token.
        sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
        sorted_to_remove[..., 0] = False
        # Map the mask from sorted order back to vocabulary order per row.
        # (Indexing logits[:, sorted_indices[mask]] would mix rows when B > 1.)
        to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
        return logits.masked_fill(to_remove, -float('Inf'))

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, top_p=None):
        """Autoregressively append *max_new_tokens* tokens to the (B, T) prompt *idx*.

        Args:
            idx: token ids of the prompt(s), shape (B, T).
            max_new_tokens: number of tokens to generate.
            temperature: logits are divided by this before sampling; ``0``
                selects greedy (argmax) decoding instead of sampling.
            top_k: if given, sample only among the ``top_k`` highest logits.
            top_p: if given, sample only from the smallest set of tokens whose
                cumulative probability exceeds ``top_p`` (per row).

        Returns:
            Tensor of shape (B, T + max_new_tokens): the prompt followed by the
            generated ids.
        """
        for _ in range(max_new_tokens):
            # Crop the context to the model's maximum length.
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            if temperature == 0:
                # Dividing by zero would produce inf/NaN probabilities.
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None:
                    logits = self._apply_top_k(logits, top_k)
                if top_p is not None:
                    logits = self._apply_top_p(logits, top_p)
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
# --- Load checkpoint and tokenizer ---
checkpoint_path = "best_model_params.pt"  # update path if needed

# Architecture hyper-parameters. They must match the ones used to train the
# checkpoint, otherwise load_state_dict() raises on missing keys or shape
# mismatches. dropout has no effect once model.eval() is called.
config = GPTConfig(
    block_size=128,
    vocab_size=50257,
    n_layer=6,
    n_head=6,
    n_embd=384,
    dropout=0.1,
    bias=True,
)

model = GPT(config)
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust. Recent PyTorch versions accept weights_only=True to restrict this.
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
model.eval()  # inference mode: disables dropout
enc = tiktoken.get_encoding("gpt2")

# --- Gradio interface ---
# Example prompts offered in the UI's dropdown; the first one is preselected.
samples = [
    "The Fourth Amendment protects citizens against unreasonable searches and seizures.",
    "Under the doctrine of stare decisis, courts follow precedent to ensure legal consistency.",
    "The Commerce Clause grants Congress the power to regulate interstate commerce.",
    "Due process requires that the government respect all legal rights owed to a person.",
    "The principle of double jeopardy prevents a defendant from being tried twice for the same offense.",
]
def generate_text(prompt, max_new_tokens=150, temperature=0.7, top_k=50, top_p=0.9):
    """Generate a continuation of *prompt* using the module-level ``model`` and ``enc``.

    Args:
        prompt: text to continue; may be empty.
        max_new_tokens: number of tokens to generate (coerced to int).
        temperature, top_k, top_p: sampling controls forwarded to ``model.generate``.

    Returns:
        The decoded prompt followed by the generated text. For an empty prompt,
        only the generated text is returned.
    """
    prompt_ids = enc.encode_ordinary(prompt or "")
    # An empty prompt would give a float tensor of shape (1, 0) that the model
    # cannot embed or take a last position from. Seed generation with the
    # end-of-text token instead and strip it from the output.
    seeded = not prompt_ids
    if seeded:
        prompt_ids = [enc.eot_token]
    # nn.Embedding requires integer ids, so force dtype=torch.long.
    input_ids = torch.tensor([prompt_ids], dtype=torch.long)
    output_ids = model.generate(
        input_ids,
        max_new_tokens=int(max_new_tokens),
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    # Take the single batch row explicitly; squeeze() would turn a (1, 1)
    # result into a 0-d tensor whose tolist() is a bare int.
    token_ids = output_ids[0].tolist()
    if seeded:
        token_ids = token_ids[1:]
    return enc.decode(token_ids)
import gradio as gr
with gr.Blocks() as demo:
    gr.Markdown("# Legal GPT Text Generation Demo")

    # Choosing a sample copies it into the editable prompt box.
    sample_dropdown = gr.Dropdown(label="Sample prompts", choices=samples, value=samples[0])
    prompt_input = gr.Textbox(label="Input Prompt", lines=3, value=samples[0])

    def update_prompt(selected):
        """Return the chosen sample unchanged so it becomes the prompt text."""
        return selected

    sample_dropdown.change(fn=update_prompt, inputs=sample_dropdown, outputs=prompt_input)

    # Generation: the prompt box is the only input; other sampling
    # parameters use generate_text's defaults.
    generate_button = gr.Button("Generate Text")
    output_text = gr.Textbox(label="Generated Output", lines=15)
    generate_button.click(fn=generate_text, inputs=prompt_input, outputs=output_text)

demo.launch()