Shilpaj committed on
Commit baa4d1d · verified · 1 Parent(s): ea06c18

Upload app.py

Files changed (1)
  1. app.py +70 -53
app.py CHANGED
@@ -11,12 +11,16 @@ from transformers import GPT2Tokenizer
 import spaces
 import os
 from pathlib import Path
+import warnings

 # Local imports
 from smollmv2 import SmollmV2
 from config import SmollmConfig, DataConfig
 from smollv2_lightning import LitSmollmv2

+# Configure PyTorch to handle the device properties issue
+torch._dynamo.config.suppress_errors = True
+warnings.filterwarnings('ignore', category=UserWarning)

 def combine_model_parts(model_dir="split_models", output_file="checkpoints/last.ckpt"):
     """
@@ -56,7 +60,7 @@ def load_model():
     device = 'cuda' if torch.cuda.is_available() else 'cpu'

     # Load model directly from checkpoint
-    checkpoint_path = "last.ckpt"  # Assuming the checkpoint is in the root directory
+    checkpoint_path = "last.ckpt"

     if not os.path.exists(checkpoint_path):
         raise FileNotFoundError(
@@ -64,21 +68,25 @@ def load_model():
             "Please ensure the model checkpoint file 'last.ckpt' is present in the root directory."
         )

-    # Load the model from checkpoint using Lightning module
-    model = LitSmollmv2.load_from_checkpoint(
-        checkpoint_path,
-        model_config=SmollmConfig,
-        strict=False
-    )
-
-    model.to(device)
-    model.eval()
-
-    # Initialize tokenizer
-    tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
-    tokenizer.pad_token = tokenizer.eos_token
+    try:
+        # Load the model from checkpoint using Lightning module
+        model = LitSmollmv2.load_from_checkpoint(
+            checkpoint_path,
+            model_config=SmollmConfig,
+            strict=False
+        )
+
+        model.to(device)
+        model.eval()
+
+        # Initialize tokenizer
+        tokenizer = GPT2Tokenizer.from_pretrained(DataConfig.tokenizer_path)
+        tokenizer.pad_token = tokenizer.eos_token
+
+        return model, tokenizer, device

-    return model, tokenizer, device
+    except Exception as e:
+        raise RuntimeError(f"Error loading model: {str(e)}")


 @spaces.GPU(enable_queue=True)
@@ -86,50 +94,59 @@ def generate_text(prompt, num_tokens, temperature=0.8, top_p=0.9):
     """
     Generate text using the SmollmV2 model.
     """
-    # Ensure num_tokens doesn't exceed model's block size
-    num_tokens = min(num_tokens, SmollmConfig.block_size)
-
-    # Tokenize input prompt
-    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
-
-    # Generate tokens one at a time
-    for _ in range(num_tokens):
-        # Get the model's predictions
-        with torch.no_grad():
-            with torch.autocast(device_type=device, dtype=torch.bfloat16):
-                logits, _ = model.model(input_ids)
-
-        # Get the next token probabilities
-        logits = logits[:, -1, :] / temperature
-        probs = F.softmax(logits, dim=-1)
-
-        # Apply top-p sampling
-        if top_p > 0:
-            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
-            cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
-            sorted_indices_to_keep = cumsum_probs <= top_p
-            sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone()
-            sorted_indices_to_keep[..., 0] = 1
-            indices_to_keep = torch.zeros_like(probs, dtype=torch.bool).scatter_(-1, sorted_indices, sorted_indices_to_keep)
-            probs = torch.where(indices_to_keep, probs, torch.zeros_like(probs))
-            probs = probs / probs.sum(dim=-1, keepdim=True)
+    try:
+        # Ensure num_tokens doesn't exceed model's block size
+        num_tokens = min(num_tokens, SmollmConfig.block_size)

-        # Sample next token
-        next_token = torch.multinomial(probs, num_samples=1)
+        # Tokenize input prompt
+        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

-        # Append to input_ids
-        input_ids = torch.cat([input_ids, next_token], dim=-1)
+        # Generate tokens one at a time
+        with torch.inference_mode():  # Use inference_mode instead of no_grad
+            for _ in range(num_tokens):
+                # Get the model's predictions
+                with torch.autocast(device_type=device, dtype=torch.float16):  # Changed to float16
+                    outputs = model(input_ids)
+                    logits = outputs[0] if isinstance(outputs, tuple) else outputs
+
+                # Get the next token probabilities
+                logits = logits[:, -1, :] / temperature
+                probs = F.softmax(logits, dim=-1)
+
+                # Apply top-p sampling
+                if top_p > 0:
+                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
+                    cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
+                    sorted_indices_to_keep = cumsum_probs <= top_p
+                    sorted_indices_to_keep[..., 1:] = sorted_indices_to_keep[..., :-1].clone()
+                    sorted_indices_to_keep[..., 0] = 1
+                    indices_to_keep = torch.zeros_like(probs, dtype=torch.bool).scatter_(-1, sorted_indices, sorted_indices_to_keep)
+                    probs = torch.where(indices_to_keep, probs, torch.zeros_like(probs))
+                    probs = probs / probs.sum(dim=-1, keepdim=True)
+
+                # Sample next token
+                next_token = torch.multinomial(probs, num_samples=1)
+
+                # Append to input_ids
+                input_ids = torch.cat([input_ids, next_token], dim=-1)
+
+                # Stop if we generate an EOS token
+                if next_token.item() == tokenizer.eos_token_id:
+                    break

-        # Stop if we generate an EOS token
-        if next_token.item() == tokenizer.eos_token_id:
-            break
+        # Decode and return the generated text
+        generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
+        return generated_text

-    # Decode and return the generated text
-    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
-    return generated_text
+    except Exception as e:
+        return f"Error during text generation: {str(e)}"

 # Load the model globally
-model, tokenizer, device = load_model()
+try:
+    model, tokenizer, device = load_model()
+except Exception as e:
+    print(f"Error initializing model: {str(e)}")
+    raise

 # Create the Gradio interface
 demo = gr.Interface(
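
The rewritten generation loop keeps the same top-p (nucleus) sampling step as before: the next-token probabilities are sorted in descending order, the smallest prefix whose cumulative mass stays within top_p is kept (plus the first token that crosses the threshold), the remaining entries are zeroed out, and the distribution is renormalized before torch.multinomial draws the next token. Below is a minimal, self-contained sketch of just that filtering step; the top_p_filter name and the toy distribution are illustrative and not part of app.py.

# Minimal sketch of the top-p (nucleus) filter used in generate_text above.
# The function name and the toy example are illustrative, not part of app.py.
import torch


def top_p_filter(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    """Keep the smallest set of tokens whose cumulative probability reaches top_p,
    zero out the rest, and renormalize. Mirrors the logic in the generation loop."""
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
    # Tokens whose cumulative mass is still within top_p...
    keep = cumsum_probs <= top_p
    # ...shifted right by one so the token that crosses the threshold is also kept,
    # and the most probable token is always kept.
    keep[..., 1:] = keep[..., :-1].clone()
    keep[..., 0] = True
    # Scatter the keep mask back to the original (unsorted) vocabulary order.
    mask = torch.zeros_like(probs, dtype=torch.bool).scatter_(-1, sorted_indices, keep)
    filtered = torch.where(mask, probs, torch.zeros_like(probs))
    return filtered / filtered.sum(dim=-1, keepdim=True)


if __name__ == "__main__":
    # Toy 5-token distribution: with top_p=0.9 the two least likely tokens are dropped.
    probs = torch.tensor([[0.5, 0.3, 0.15, 0.04, 0.01]])
    filtered = top_p_filter(probs, top_p=0.9)
    print(filtered)  # tensor([[0.5263, 0.3158, 0.1579, 0.0000, 0.0000]])
    next_token = torch.multinomial(filtered, num_samples=1)
    print(next_token)

With top_p=0.9, the toy example keeps the three most likely tokens (the cumulative mass only crosses 0.9 at the third token) and rescales them to sum to 1, which is what the in-loop version does on the model's temperature-scaled logits at every step.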