Atharva Mete committed on
Commit dd4b88e · 1 Parent(s): 57b4d23

fixed initialization

Files changed (1):
  modeling_molmo.py +14 -3
modeling_molmo.py CHANGED
@@ -1759,9 +1759,14 @@ class Molmo(nn.Module):
         self.__num_fwd_flops: Optional[int] = None
 
         self.total_vocab_size = config.vocab_size + config.additional_vocab_size + config.skill_vocab_size
-        torch.nn.init.xavier_uniform_(self.transformer.skill_ff_out.weight)
-        if self.transformer.skill_ff_out.bias is not None:
-            torch.nn.init.zeros_(self.transformer.skill_ff_out.bias)
+
+    def init_weights(self):
+        if hasattr(self.transformer, "skill_ff_out"):
+            nn.init.xavier_uniform_(self.transformer.skill_ff_out.weight)
+            if self.transformer.skill_ff_out.bias is not None:
+                nn.init.zeros_(self.transformer.skill_ff_out.bias)
+        if hasattr(self.transformer.wte, "skill_embedding"):
+            nn.init.xavier_uniform_(self.transformer.wte.skill_embedding)
 
     def reset_parameters(self):
         if self.vision_backbone is not None:
@@ -1773,6 +1778,9 @@ class Molmo(nn.Module):
         if hasattr(self.transformer.wte, "new_embedding"):
             nn.init.normal_(self.transformer.wte.new_embedding, std=self.config.new_embedding_init_range)
 
+        if hasattr(self.transformer.wte, "skill_embedding"):
+            nn.init.xavier_uniform_(self.transformer.wte.skill_embedding)
+
         if hasattr(self.transformer, "wpe"):
             nn.init.normal_(self.transformer.wpe, mean=0.0, std=1.0)
 
@@ -1780,6 +1788,9 @@ class Molmo(nn.Module):
 
         if hasattr(self.transformer, "ff_out"):
             nn.init.normal_(self.transformer.ff_out, mean=0.0, std=0.02)
+
+        if hasattr(self.transformer, "skill_ff_out"):
+            nn.init.normal_(self.transformer.skill_ff_out, mean=0.0, std=0.02)
 
         if self.config.block_group_size == 1:
             for block in self.transformer.blocks:
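
Below is a minimal, self-contained sketch of the hasattr-guarded initialization pattern the diff introduces (Xavier-uniform weights, zero-filled bias). The TinyTransformer class, its sizes, and its attribute names are hypothetical stand-ins for illustration, not the actual Molmo internals.

# Minimal sketch of the guarded init pattern; module and names are hypothetical.
import torch
import torch.nn as nn

class TinyTransformer(nn.Module):
    def __init__(self, hidden: int = 16, skill_vocab: int = 8):
        super().__init__()
        # Optional extras, mirroring skill_ff_out / skill_embedding in the diff above.
        self.skill_ff_out = nn.Linear(hidden, skill_vocab, bias=True)
        self.skill_embedding = nn.Parameter(torch.empty(skill_vocab, hidden))

    def init_weights(self):
        # Only touch attributes that actually exist on this instance.
        if hasattr(self, "skill_ff_out"):
            nn.init.xavier_uniform_(self.skill_ff_out.weight)
            if self.skill_ff_out.bias is not None:
                nn.init.zeros_(self.skill_ff_out.bias)
        if hasattr(self, "skill_embedding"):
            nn.init.xavier_uniform_(self.skill_embedding)

model = TinyTransformer()
model.init_weights()
print(model.skill_ff_out.bias.abs().sum())  # tensor(0.) after zero init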