ZivK committed
Commit 781bf2a · Parent: 4b4d15d

Added the full interface

Files changed (2)
  1. app.py +44 -0
  2. model.py +95 -0
app.py ADDED
@@ -0,0 +1,44 @@
+ import os
+ import torch
+ import gradio as gr
+ from model import SmolLM
+ from huggingface_hub import hf_hub_download
+
+ hf_token = os.environ.get("HF_TOKEN")
+ repo_id = "ZivK/smollm2-end-of-sentence"
+ model_options = {
+     "Word-level Model": "word_model.ckpt",
+     "Token-level Model": "token_model.ckpt"
+ }
+ models = {}
+ for model_name, filename in model_options.items():
+     print(f"Loading {model_name} ...")
+     checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename, token=hf_token)
+     models[model_name] = SmolLM.load_from_checkpoint(checkpoint_path)
+
+
+ def classify_sentence(sentence, model_choice):
+     model = models[model_choice]
+     # Tokenize and move the inputs to the model's device
+     inputs = model.tokenizer(sentence, return_tensors="pt", padding=True, truncation=True).to(model.device)
+     with torch.no_grad():  # inference only, no gradients needed
+         logits = model(inputs)
+     confidence = torch.sigmoid(logits).item() * 100
+     # Always display the confidence of the predicted class
+     confidence_to_display = confidence if confidence > 50.0 else 100 - confidence
+     label = "Complete" if confidence > 50.0 else "Incomplete"
+
+     return f"{label} Sentence\nConfidence: {confidence_to_display:.2f}%"
+
+
+ # Create the Gradio interface
+ interface = gr.Interface(
+     fn=classify_sentence,
+     inputs=[
+         gr.Textbox(lines=1, placeholder="Enter your sentence here..."),
+         gr.Dropdown(choices=list(model_options.keys()), label="Select Model")
+     ],
+     outputs="text",
+     title="Complete Sentence Classifier",
+     description="## Enter a sentence to determine if it's complete or if it might be cut off"
+ )
+
+ # Launch the demo
+ interface.launch()
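
Note that the confidence shown is always that of the predicted class: torch.sigmoid gives the probability of "Complete", so values below 50% are flipped before display. A minimal standalone sketch of that mapping (the logit values are illustrative, not real model outputs):

import torch

def to_display(logit: float) -> str:
    # Same label/confidence mapping as classify_sentence above
    confidence = torch.sigmoid(torch.tensor(logit)).item() * 100
    shown = confidence if confidence > 50.0 else 100 - confidence
    label = "Complete" if confidence > 50.0 else "Incomplete"
    return f"{label} Sentence\nConfidence: {shown:.2f}%"

print(to_display(2.0))   # -> Complete Sentence, Confidence: 88.08%
print(to_display(-2.0))  # -> Incomplete Sentence, Confidence: 88.08%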
model.py ADDED
@@ -0,0 +1,95 @@
+ import pytorch_lightning as pl
+ import torch
+ from peft import LoraConfig, get_peft_model
+ from torch import nn
+ from torchmetrics import Accuracy
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+ base_checkpoint = "HuggingFaceTB/SmolLM2-360M"
+ device = "mps" if torch.backends.mps.is_available() else "cpu"
+ criterion = nn.BCEWithLogitsLoss()
+
+
+ class SmolLM(pl.LightningModule):
+     def __init__(self, learning_rate=3e-4):
+         super().__init__()
+         self.learning_rate = learning_rate
+         self.criterion = criterion
+         self.tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)
+         self.tokenizer.pad_token = self.tokenizer.eos_token
+         self.base_model = AutoModelForCausalLM.from_pretrained(base_checkpoint).to(device)
+         # Replace the LM head with an identity so the model returns hidden states
+         self.base_model.lm_head = nn.Identity()
+         self.classifier = nn.Sequential(
+             nn.Linear(960, 128),  # 960 is the hidden size of SmolLM2-360M
+             nn.ReLU(),
+             nn.Linear(128, 1),
+         )
+         # Freeze the SmolLM2 parameters
+         for param in self.base_model.parameters():
+             param.requires_grad = False
+         # LoRA fine-tuning (with DoRA) over the attention and MLP projections
+         lora_config = LoraConfig(
+             r=8,
+             lora_alpha=32,
+             target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+             lora_dropout=0.0,
+             bias="none",
+             use_dora=True
+         )
+         self.base_model = get_peft_model(self.base_model, lora_config)
+         self.base_model.print_trainable_parameters()
+         self.save_hyperparameters()
+         self.val_accuracy = Accuracy(task="binary")
+
+     def forward(self, x):
+         input_ids = x["input_ids"]
+         attention_mask = x["attention_mask"]
+
+         # Forward pass through the base model using the attention mask
+         out = self.base_model(input_ids, attention_mask=attention_mask)
+         logits = out.logits  # shape: (batch_size, seq_len, hidden_dim), since lm_head is an identity
+
+         # Calculate the index of the last non-padding token for each sequence
+         last_token_indices = attention_mask.sum(dim=1) - 1  # shape: (batch_size,)
+         real_batch_size = logits.size(0)
+         batch_indices = torch.arange(real_batch_size, device=logits.device)
+
+         # Select the hidden state of the last non-padding token
+         last_logits = logits[batch_indices, last_token_indices, :]  # shape: (batch_size, hidden_dim)
+
+         # Pass the selected hidden states through the classifier
+         output_logits = self.classifier(last_logits)
+         return output_logits.squeeze(-1)
+
+     def training_step(self, batch, batch_idx):
+         sentences = batch["sentence"]
+         labels = batch["eos_label"].to(device)
+         inputs = self.tokenizer(sentences, return_tensors="pt", padding=True, truncation=True).to(device)
+         logits = self(inputs)
+         loss = self.criterion(logits, labels)
+         self.log('Train Step Loss', loss, prog_bar=True)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         sentences = batch["sentence"]
+         labels = batch["eos_label"].to(device)
+         inputs = self.tokenizer(sentences, return_tensors="pt", padding=True, truncation=True).to(device)
+         logits = self(inputs)
+         loss = self.criterion(logits, labels)
+         preds = (torch.sigmoid(logits) > 0.5).long()
+         self.val_accuracy.update(preds, labels.long())
+         self.log('Validation Step Loss', loss, prog_bar=True)
+         return loss
+
+     def on_validation_epoch_end(self):
+         # Compute and log the overall validation accuracy
+         acc = self.val_accuracy.compute()
+         self.log('Validation Accuracy', acc, prog_bar=True)
+         self.val_accuracy.reset()
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)
+         return optimizer
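
The forward pass pools each sequence into a single vector by taking the hidden state of the last non-padding token, located by summing the attention mask. A minimal sketch of that indexing with toy shapes, assuming right-padded batches as the tokenizer produces here (the tensor values are illustrative):

import torch

hidden = torch.randn(2, 4, 960)       # (batch, seq_len, hidden_dim), like out.logits above
mask = torch.tensor([[1, 1, 1, 0],    # sequence 1: three real tokens, one pad
                     [1, 1, 1, 1]])   # sequence 2: four real tokens
last = mask.sum(dim=1) - 1            # tensor([2, 3]): index of each last real token
rows = torch.arange(hidden.size(0))
picked = hidden[rows, last, :]        # (2, 960): one pooled vector per sequence
assert picked.shape == (2, 960)

This selection relies on right padding; with left padding, the last real token would simply sit at index -1 for every sequence.

Since the module defines its own training and validation steps, fitting it is a standard Lightning call. A hypothetical entry point, assuming batches shaped the way the steps above expect, with "sentence" as a list of strings and "eos_label" as a float tensor (the toy data and trainer settings are assumptions, not from the repo):

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from model import SmolLM

# Toy examples in the assumed dataset format
data = [{"sentence": "The cat sat on the mat.", "eos_label": 1.0},
        {"sentence": "The cat sat on the", "eos_label": 0.0}]

def collate(batch):
    # Keep sentences as a list of strings; stack labels into a float tensor
    return {"sentence": [ex["sentence"] for ex in batch],
            "eos_label": torch.tensor([ex["eos_label"] for ex in batch])}

loader = DataLoader(data, batch_size=2, collate_fn=collate)
model = SmolLM(learning_rate=3e-4)
trainer = pl.Trainer(max_epochs=1, accelerator="auto")  # hypothetical settings; matches model.py's mps-or-cpu device assumption
trainer.fit(model, train_dataloaders=loader, val_dataloaders=loader)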