andrewdalpino committed
Commit fe4abd1 · verified
1 Parent(s): c6186b4

Upload model

Files changed (3)
  1. config.json +5 -0
  2. model.py +671 -0
  3. pytorch_model.bin +1 -1
config.json CHANGED
@@ -2,9 +2,14 @@
  "architectures": [
    "LightGPTHuggingFaceModel"
  ],
+ "auto_map": {
+   "AutoConfig": "model.LightGPTHuggingFaceConfig",
+   "AutoModel": "model.LightGPTHuggingFaceModel"
+ },
  "dropout": 0.1,
  "embedding_dimensions": 1024,
  "feed_forward_ratio": 4,
+ "model_type": "lightgpt",
  "num_heads": 16,
  "num_layers": 24,
  "padding_index": -100,
model.py ADDED
@@ -0,0 +1,671 @@
+ from math import sqrt
+ from dataclasses import dataclass
+ from functools import partial, cached_property
+ from typing import Iterator, Self
+
+ import torch
+
+ from torch import Tensor
+ from torch.nn import (
+     Module,
+     ModuleList,
+     Sequential,
+     Embedding,
+     MultiheadAttention,
+     Linear,
+     SiLU,
+     RMSNorm,
+     Dropout1d,
+     CrossEntropyLoss,
+     Parameter,
+ )
+
+ from torch.nn.functional import softmax, log_softmax
+ from torch.nn.utils.parametrize import register_parametrization, remove_parametrizations
+ from torch.utils.checkpoint import checkpoint as torch_checkpoint
+
+ from transformers import PretrainedConfig, PreTrainedModel
+
+
+ class LightGPT(Module):
+     """A generative pretrained transformer with no positional embeddings."""
+
+     def __init__(
+         self,
+         vocabulary_size: int,
+         embedding_dimensions: int,
+         num_heads: int,
+         num_layers: int,
+         feed_forward_ratio: int,
+         dropout: float,
+         padding_index: int,
+     ):
+         super().__init__()
+
+         if vocabulary_size <= 0:
+             raise ValueError(
+                 f"Vocabulary size must be greater than 0, {vocabulary_size} given."
+             )
+
+         if num_layers <= 0:
+             raise ValueError(f"Num layers must be greater than 0, {num_layers} given.")
+
+         if feed_forward_ratio not in {1, 2, 4}:
+             raise ValueError("Feed-forward ratio must be either 1, 2, or 4.")
+
+         token_embeddings = Embedding(
+             vocabulary_size, embedding_dimensions, padding_idx=padding_index
+         )
+
+         output_layer = Linear(embedding_dimensions, vocabulary_size, bias=False)
+
+         output_layer.weight = token_embeddings.weight  # Tie weights
+
+         self.token_embeddings = token_embeddings
+
+         self.body = ModuleList(
+             [
+                 CausalSelfAttentionBlock(
+                     embedding_dimensions,
+                     num_heads,
+                     feed_forward_ratio,
+                     dropout,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+
+         self.checkpoint = lambda layer, x, attention_mask: layer(x, attention_mask)
+
+         self.output_norm = RMSNorm(embedding_dimensions)
+         self.output_layer = output_layer
+
+         self.loss_function = CrossEntropyLoss(ignore_index=padding_index)
+
+         self.vocabulary_size = vocabulary_size
+
+     @cached_property
+     def num_trainable_params(self) -> int:
+         return sum(param.numel() for param in self.parameters() if param.requires_grad)
+
+     def enable_activation_checkpointing(self) -> None:
+         """Instead of memorizing the activations of the forward pass, recompute them at various checkpoints."""
+         self.checkpoint = partial(torch_checkpoint, use_reentrant=False)
+
+     def resize_token_embeddings(self, num_tokens: int) -> None:
+         """Resize the token embeddings to accommodate a new vocabulary size."""
+
+         new_embeddings = Embedding(num_tokens, self.token_embeddings.embedding_dim).to(
+             self.token_embeddings.weight.device
+         )
+
+         num_tokens_to_copy = min(num_tokens, self.token_embeddings.num_embeddings)
+
+         new_embeddings.weight[:num_tokens_to_copy, :] = self.token_embeddings.weight[
+             :num_tokens_to_copy, :
+         ]
+
+         self.token_embeddings = new_embeddings
+
+         self.output_layer.weight = self.token_embeddings.weight
+
+         self.vocabulary_size = num_tokens
+
+     def forward(
+         self, x: Tensor, y: Tensor | None = None
+     ) -> tuple[Tensor, Tensor | None]:
+         """A forward pass optimized for batch training."""
+
+         z = self.token_embeddings(x)
+
+         b, t, d = z.size()
+
+         causal_mask = torch.full((t, t), float("-inf"), dtype=z.dtype, device=z.device)
+         causal_mask = torch.triu(causal_mask, diagonal=1)
+
+         for layer in self.body:
+             z = self.checkpoint(layer, z, causal_mask)
+
+         z = self.output_norm(z)
+         z = self.output_layer(z)
+
+         if y is not None:
+             y_pred = z.view(-1, z.size(-1))
+             labels = y.view(-1)  # Flatten the batch dimension.
+
+             loss = self.loss_function(y_pred, labels)
+         else:
+             loss = None
+
+         return z, loss
+
+     @torch.no_grad()
+     def predict(self, x: Tensor) -> Tensor:
+         """A forward pass optimized for batch next-token prediction."""
+
+         z = self.token_embeddings(x)
+
+         b, t, d = z.size()
+
+         causal_mask = torch.full((t, t), float("-inf"), dtype=z.dtype, device=z.device)
+         causal_mask = torch.triu(causal_mask, diagonal=1)
+
+         for layer in self.body:
+             z = layer(z, causal_mask)
+
+         z = self.output_norm(z)
+
+         z = z[:, -1, :]  # Pluck only the last token embedding from each batch.
+
+         z = self.output_layer(z)
+
+         return z
+
+     @torch.no_grad()
+     def generate(
+         self,
+         prompt: Tensor,
+         max_tokens: int = 1000,
+         context_length: int = 1024,
+         temperature: float = 1.0,
+         top_k: int = 500,
+         top_p: float = 0.9,
+         eos_indices: set = set(),
+     ) -> Iterator:
+         """
+         Given a prompt, sample the next {max_tokens} tokens from the model weighted
+         by their predicted probabilities and filtered by the {top_k} and {top_p}.
+         """
+
+         if max_tokens <= 0:
+             raise ValueError(f"Max tokens must be greater than 0, {max_tokens} given.")
+
+         if temperature <= 0:
+             raise ValueError(
+                 f"Temperature must be greater than 0, {temperature} given."
+             )
+
+         if top_k <= 0 or top_k > self.vocabulary_size:
+             raise ValueError(
+                 f"Top k must be between 1 and {self.vocabulary_size}, {top_k} given."
+             )
+
+         if top_p <= 0.0 or top_p > 1.0:
+             raise ValueError(f"Top p must be between 0 and 1, {top_p} given.")
+
+         context_window = prompt
+
+         for _ in range(max_tokens):
+             context_window = context_window[-context_length:]
+
+             logits = self.predict(context_window.unsqueeze(0)).squeeze()
+
+             logits, indices = torch.topk(logits, top_k, sorted=True)
+
+             probabilities = softmax(logits, dim=0)
+
+             cumulative_probability_mass = torch.cumsum(probabilities, dim=0)
+
+             min_probability_mass = cumulative_probability_mass[0]
+
+             threshold_p = max(top_p, min_probability_mass.item())
+
+             selected_indices = cumulative_probability_mass <= threshold_p
+
+             logits = logits[selected_indices]
+             indices = indices[selected_indices]
+
+             logits /= temperature
+
+             probabilities = softmax(logits, dim=0)
+
+             offset = torch.multinomial(probabilities, num_samples=1).squeeze()
+
+             next_token = indices[offset]
+
+             if next_token.item() in eos_indices:
+                 break
+
+             yield next_token
+
+             context_window = torch.cat((context_window, next_token.unsqueeze(0)))
+
+     @torch.no_grad()
+     def beam_search(
+         self,
+         prompt: Tensor,
+         max_tokens: int = 100,
+         context_length: int = 1024,
+         num_candidates: int = 3,
+         beam_width: int = 16,
+         length_penalty: float = 1.0,
+         eos_indices: set = set(),
+     ) -> list:
+         """
+         Given a prompt, return the {num_candidates} highest probability sequences. Note that
+         this method is often best for generating shorter sequences and is typically less
+         natural sounding than sequences that are more random in nature.
+         """
+
+         if max_tokens <= 0:
+             raise ValueError(f"Max tokens must be greater than 0, {max_tokens} given.")
+
+         if num_candidates <= 0:
+             raise ValueError(
+                 f"Num candidates must be greater than 0, {num_candidates} given."
+             )
+
+         if beam_width <= 0:
+             raise ValueError(f"Beam width must be greater than 0, {beam_width} given.")
+
+         if length_penalty <= 0:
+             raise ValueError(
+                 f"Length penalty must be greater than 0, {length_penalty} given."
+             )
+
+         @dataclass
+         class Candidate:
+             cumulative_log_probability: float
+             tokens: Tensor
+
+             def priority(self) -> float:
+                 return (
+                     self.cumulative_log_probability / len(self.tokens) ** length_penalty
+                 )
+
+         sort_candidates = partial(
+             sorted,
+             key=lambda candidate: candidate.priority(),
+             reverse=True,
+         )
+
+         candidates: list[Candidate] = []
+         completed: list[Candidate] = []
+
+         tokens = torch.tensor([], dtype=prompt.dtype).to(prompt.device)
+
+         candidates.append(Candidate(0.0, tokens))
+
+         while len(candidates) > 0:
+             candidate = candidates.pop()
+
+             if len(completed) >= num_candidates:
+                 completed = sort_candidates(completed)
+
+                 completed = completed[:num_candidates]
+
+                 worst_candidate = completed[-1]
+
+                 if (
+                     candidate.cumulative_log_probability
+                     < worst_candidate.cumulative_log_probability
+                 ):
+                     break
+
+             if len(candidate.tokens) > 0:
+                 last_token = candidate.tokens[-1]
+
+                 if last_token.item() in eos_indices:
+                     candidate.tokens = candidate.tokens[:-1]
+
+                     completed.append(candidate)
+
+                     continue
+
+             if len(candidate.tokens) >= max_tokens:
+                 completed.append(candidate)
+
+                 continue
+
+             context_window = torch.cat((prompt, candidate.tokens))
+
+             context_window = context_window[-context_length:]
+
+             logits = self.predict(context_window.unsqueeze(0)).squeeze()
+
+             logits, indices = torch.topk(logits, beam_width, sorted=False)
+
+             log_probabilities = log_softmax(logits, dim=0)
+
+             for log_probability, index in zip(log_probabilities, indices):
+                 cumulative_log_probability = (
+                     candidate.cumulative_log_probability + log_probability
+                 )
+
+                 tokens = torch.cat((candidate.tokens, index.unsqueeze(0)))
+
+                 candidates.append(Candidate(cumulative_log_probability, tokens))
+
+             candidates = sort_candidates(candidates)
+
+             candidates = candidates[:beam_width]
+
+         return completed
+
+
+ class LightGPTInstruct(Module):
+     """
+     A wrapper for pretrained GPT models that applies a LoRA reparameterization
+     to the intermediate layers of the network.
+     """
+
+     def __init__(
+         self,
+         model: LightGPT,
+         vocabulary_size: int,
+         rank: int,
+         alpha: float,
+         dropout: float,
+     ):
+         super().__init__()
+
+         if vocabulary_size <= 0:
+             raise ValueError(
+                 f"Vocabulary size must be greater than 0, {vocabulary_size} given."
+             )
+
+         if rank <= 0:
+             raise ValueError(f"Rank must be greater than 0, {rank} given.")
+
+         if alpha <= 0.0:
+             raise ValueError(f"Alpha must be greater than 0, {alpha} given.")
+
+         if vocabulary_size != model.vocabulary_size:
+             model.resize_token_embeddings(vocabulary_size)
+
+         for param in model.parameters():
+             param.requires_grad = False
+
+         for i in range(vocabulary_size, model.vocabulary_size, -1):
+             model.output_layer.weight[i - 1].requires_grad = True
+
+         for module in model.body:
+             out_features, in_features = module.attention.in_proj_weight.shape
+
+             register_parametrization(
+                 module.attention,
+                 "in_proj_weight",
+                 LoRA(in_features, out_features, rank, alpha, dropout),
+             )
+
+             out_features, in_features = module.attention.out_proj.weight.shape
+
+             register_parametrization(
+                 module.attention.out_proj,
+                 "weight",
+                 LoRA(in_features, out_features, rank, alpha, dropout),
+             )
+
+             for layer in module.mlp.layers:
+                 if isinstance(layer, Linear):
+                     register_parametrization(
+                         layer,
+                         "weight",
+                         LoRA.from_linear(layer, rank, alpha, dropout),
+                     )
+
+         register_parametrization(
+             model.output_layer,
+             "weight",
+             LoRA.from_linear(model.output_layer, rank, alpha, dropout),
+         )
+
+         self.model = model
+
+     @property
+     def num_trainable_params(self) -> int:
+         return self.model.num_trainable_params
+
+     def state_dict(self):
+         return {
+             name: module
+             for name, module in super().state_dict().items()
+             if "lora" in name
+         }
+
+     def merge_lora_parameters(self):
+         """Merge the LoRA parameters with the original parameters."""
+
+         for module in self.model.modules():
+             if hasattr(module, "parametrizations"):
+                 lora_params = [name for name in module.parametrizations.keys()]
+
+                 for name in lora_params:
+                     remove_parametrizations(module, name, leave_parametrized=True)
+
+     def forward(
+         self, x: Tensor, y: Tensor | None = None
+     ) -> tuple[Tensor, Tensor | None]:
+         return self.model.forward(x, y)
+
+     def predict(self, x: Tensor) -> Tensor:
+         return self.model.predict(x)
+
+     def generate(
+         self,
+         prompt: Tensor,
+         max_tokens: int = 1000,
+         context_length: int = 1024,
+         temperature: float = 1.0,
+         top_k: int = 500,
+         top_p: float = 0.9,
+         eos_indices: set = set(),
+     ) -> Iterator:
+         return self.model.generate(
+             prompt, max_tokens, context_length, temperature, top_k, top_p, eos_indices
+         )
+
+     def beam_search(
+         self,
+         prompt: Tensor,
+         max_tokens: int = 100,
+         context_length: int = 1024,
+         num_candidates: int = 3,
+         beam_width: int = 16,
+         length_penalty: float = 1.0,
+         eos_indices: set = set(),
+     ) -> list:
+         return self.model.beam_search(
+             prompt,
+             max_tokens,
+             context_length,
+             num_candidates,
+             beam_width,
+             length_penalty,
+             eos_indices,
+         )
+
+
+ class LightGPTHuggingFaceConfig(PretrainedConfig):
+     """Provide a monolithic configuration object to compensate for HuggingFace Transformers' API."""
+
+     model_type = "lightgpt"
+
+     def __init__(
+         self,
+         vocabulary_size: int = 50257,
+         embedding_dimensions: int = 1024,
+         num_heads: int = 16,
+         num_layers: int = 24,
+         feed_forward_ratio: int = 4,
+         dropout: float = 0.1,
+         padding_index: int = -100,
+         **kwargs,
+     ):
+         self.vocabulary_size = vocabulary_size
+         self.embedding_dimensions = embedding_dimensions
+         self.num_heads = num_heads
+         self.num_layers = num_layers
+         self.feed_forward_ratio = feed_forward_ratio
+         self.dropout = dropout
+         self.padding_index = padding_index
+
+         super().__init__(**kwargs)
+
+
+ class LightGPTHuggingFaceModel(PreTrainedModel):
+     """Compensate for HuggingFace Transformers' API using a model wrapper."""
+
+     config_class = LightGPTHuggingFaceConfig
+
+     def __init__(self, config: LightGPTHuggingFaceConfig):
+         super().__init__(config)
+
+         self.model = LightGPT(
+             config.vocabulary_size,
+             config.embedding_dimensions,
+             config.num_heads,
+             config.num_layers,
+             config.feed_forward_ratio,
+             config.dropout,
+             config.padding_index,
+         )
+
+     def forward(
+         self, x: Tensor, y: Tensor | None = None
+     ) -> tuple[Tensor, Tensor | None]:
+         logits, loss = self.model.forward(x, y)
+
+         return {
+             "logits": logits,
+             "loss": loss,
+         }
+
+
+ class ONNXModel(Module):
+     """This wrapper provides a clean inferencing API for ONNX production models."""
+
+     def __init__(self, model: LightGPT | LightGPTInstruct):
+         super().__init__()
+
+         self.model = model
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self.model.predict(x)
+
+
+ class CausalSelfAttentionBlock(Module):
+     """Causal self-attention block with residual connections."""
+
+     def __init__(
+         self,
+         embedding_dimensions: int,
+         num_heads: int,
+         feed_forward_ratio: int,
+         dropout: float,
+     ):
+         super().__init__()
+
+         if embedding_dimensions <= 0:
+             raise ValueError(
+                 f"Embedding dimensions must be greater than 0, {embedding_dimensions} given."
+             )
+
+         if num_heads <= 0:
+             raise ValueError(f"Num heads must be greater than 0, {num_heads} given.")
+
+         if dropout < 0 or dropout > 1:
+             raise ValueError(f"Dropout must be between 0 and 1, {dropout} given")
+
+         self.norm1 = RMSNorm(embedding_dimensions)
+         self.attention = MultiheadAttention(
+             embedding_dimensions,
+             num_heads,
+             batch_first=True,
+             dropout=dropout,
+             bias=False,
+         )
+
+         hidden_dimensions = feed_forward_ratio * embedding_dimensions
+
+         self.norm2 = RMSNorm(embedding_dimensions)
+         self.mlp = MLP(embedding_dimensions, hidden_dimensions, dropout)
+
+     def forward(self, x: Tensor, attention_mask: Tensor) -> Tensor:
+         z = self.norm1(x)
+         z, _ = self.attention(z, z, z, attn_mask=attention_mask, is_causal=True)
+
+         z = x + z  # Residual connection
+
+         x = z
+
+         z = self.norm2(x)
+         z = self.mlp(z)
+
+         z = x + z  # Residual connection
+
+         return z
+
+
+ class MLP(Module):
+     """A two-layer fully-connected network with dropout."""
+
+     def __init__(
+         self, embedding_dimensions: int, hidden_dimensions: int, dropout: float
+     ):
+         super().__init__()
+
+         if embedding_dimensions <= 0:
+             raise ValueError(
+                 f"Embedding dimensions must be greater than 0, {embedding_dimensions} given."
+             )
+
+         if hidden_dimensions <= 0:
+             raise ValueError(
+                 f"Hidden dimensions must be greater than 0, {hidden_dimensions} given."
+             )
+
+         self.layers = Sequential(
+             Linear(embedding_dimensions, hidden_dimensions, bias=False),
+             SiLU(),
+             Linear(hidden_dimensions, embedding_dimensions, bias=False),
+         )
+
+         self.dropout = Dropout1d(p=dropout)
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self.dropout(self.layers(x))
+
+
+ class LoRA(Module):
+     """Rank decomposition transformation."""
+
+     @classmethod
+     def from_linear(
+         cls, linear: Linear, rank: int, alpha: float, dropout: float
+     ) -> Self:
+         out_features, in_features = linear.weight.shape
+
+         return cls(in_features, out_features, rank, alpha, dropout)
+
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         rank: int,
+         alpha: float,
+         dropout: float,
+     ):
+         super().__init__()
+
+         if rank <= 0:
+             raise ValueError(f"Rank must be greater than 0, {rank} given.")
+
+         if alpha <= 0.0:
+             raise ValueError(f"Alpha must be greater than 0, {alpha} given.")
+
+         std_dev = 1.0 / sqrt(rank)
+
+         self.lora_a = Parameter(torch.randn(rank, in_features) * std_dev)
+         self.lora_b = Parameter(torch.zeros(out_features, rank))
+
+         self.dropout = Dropout1d(p=dropout)
+
+         self.alpha = alpha
+
+     def forward(self, x: Tensor) -> Tensor:
+         z = self.lora_b @ self.dropout(self.lora_a)
+
+         z *= self.alpha
+
+         return x + z
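
For reference, a minimal sketch of driving the wrapper classes above directly. The model here is randomly initialized for illustration only; in practice the weights come from pytorch_model.bin via from_pretrained, and the prompt would come from a tokenizer, which is not part of this commit:

import torch

from model import LightGPTHuggingFaceConfig, LightGPTHuggingFaceModel

config = LightGPTHuggingFaceConfig()  # Defaults match the values shown in config.json above.
model = LightGPTHuggingFaceModel(config).eval()

prompt = torch.randint(0, config.vocabulary_size, (16,))  # Dummy token ids.

# Forward pass returns a dict with "logits" of shape (1, 16, vocabulary_size) and "loss" (None here).
out = model(prompt.unsqueeze(0))

# Stream up to 32 sampled tokens from the underlying LightGPT sampler.
for token in model.model.generate(prompt, max_tokens=32):
    print(token.item())
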
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:c7f31cd073f8fab5e04d300e4619afe2dd82bfa4f621d73db477dca3c94855df
+ oid sha256:f2716db68bb0143039012e287e16b005dc5b071d545f109fc40236d2ba2ab333
  size 1414060818
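
The updated LFS pointer above records the SHA-256 and byte size of the uploaded weights. A quick integrity-check sketch for a locally downloaded copy (the local file path is an assumption):

import hashlib
import os

path = "pytorch_model.bin"  # Assumed local path to the downloaded weights.

digest = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        digest.update(chunk)

# Both checks should match the pointer file above.
print(os.path.getsize(path) == 1414060818)
print(digest.hexdigest() == "f2716db68bb0143039012e287e16b005dc5b071d545f109fc40236d2ba2ab333")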