Atharva Mete committed · Commit dd4b88e
1 Parent(s): 57b4d23

fixed initialization

modeling_molmo.py  +14  -3
modeling_molmo.py CHANGED

@@ -1759,9 +1759,14 @@ class Molmo(nn.Module):
         self.__num_fwd_flops: Optional[int] = None
 
         self.total_vocab_size = config.vocab_size + config.additional_vocab_size + config.skill_vocab_size
-
-
-
+
+    def init_weights(self):
+        if hasattr(self.transformer, "skill_ff_out"):
+            nn.init.xavier_uniform_(self.transformer.skill_ff_out.weight)
+            if self.transformer.skill_ff_out.bias is not None:
+                nn.init.zeros_(self.transformer.skill_ff_out.bias)
+        if hasattr(self.transformer.wte, "skill_embedding"):
+            nn.init.xavier_uniform_(self.transformer.wte.skill_embedding)
 
     def reset_parameters(self):
         if self.vision_backbone is not None:
@@ -1773,6 +1778,9 @@ class Molmo(nn.Module):
         if hasattr(self.transformer.wte, "new_embedding"):
             nn.init.normal_(self.transformer.wte.new_embedding, std=self.config.new_embedding_init_range)
 
+        if hasattr(self.transformer.wte, "skill_embedding"):
+            nn.init.xavier_uniform_(self.transformer.wte.skill_embedding)
+
         if hasattr(self.transformer, "wpe"):
             nn.init.normal_(self.transformer.wpe, mean=0.0, std=1.0)
 
@@ -1780,6 +1788,9 @@ class Molmo(nn.Module):
 
         if hasattr(self.transformer, "ff_out"):
             nn.init.normal_(self.transformer.ff_out, mean=0.0, std=0.02)
+
+        if hasattr(self.transformer, "skill_ff_out"):
+            nn.init.normal_(self.transformer.skill_ff_out, mean=0.0, std=0.02)
 
         if self.config.block_group_size == 1:
             for block in self.transformer.blocks:
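For context, here is a minimal, self-contained sketch of the initialization pattern this commit introduces. It is not the actual Molmo source: the class name (SkillHeadDemo), the dimensions, and the assumption that skill_ff_out is an nn.Linear head while skill_embedding is a raw nn.Parameter are illustrative guesses inferred from the diff. The diff itself is ambiguous on the attribute type: init_weights initializes skill_ff_out.weight (treating it as a module), while reset_parameters passes self.transformer.skill_ff_out directly to nn.init.normal_ (which accepts a tensor), so only one of the two call sites can match the real attribute.

# Minimal sketch under the assumptions stated above; not the actual Molmo code.
import torch
import torch.nn as nn


class SkillHeadDemo(nn.Module):
    """Toy stand-in for the skill head and skill embedding added in this commit."""

    def __init__(self, d_model: int = 64, skill_vocab_size: int = 16):
        super().__init__()
        # Assumed: a linear head projecting hidden states onto the skill vocabulary.
        self.skill_ff_out = nn.Linear(d_model, skill_vocab_size, bias=True)
        # Assumed: extra embedding rows for skill tokens, kept as a raw parameter.
        self.skill_embedding = nn.Parameter(torch.empty(skill_vocab_size, d_model))
        self.init_weights()

    def init_weights(self) -> None:
        # Mirrors the new Molmo.init_weights: Xavier-uniform weights, zero bias.
        nn.init.xavier_uniform_(self.skill_ff_out.weight)
        if self.skill_ff_out.bias is not None:
            nn.init.zeros_(self.skill_ff_out.bias)
        nn.init.xavier_uniform_(self.skill_embedding)


if __name__ == "__main__":
    demo = SkillHeadDemo()
    print(demo.skill_ff_out.bias.abs().max().item())  # 0.0 after init
    print(demo.skill_embedding.std().item())          # small, fan-scaled spread

Xavier-uniform scales the initial range by the layer's fan-in and fan-out, which keeps the skill logits at roughly unit variance at the start of training, and the zeroed bias means no skill token is favored before any gradient step.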