add model
Browse files- modeling_word2vec.py +5 -3
modeling_word2vec.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from transformers import PreTrainedModel
|
2 |
from torch import nn
|
3 |
import torch
|
4 |
from .configuration_word2vec import PretrainedWord2VecHFConfig
|
@@ -14,5 +14,7 @@ class PretrainedWord2VecHFModel(PreTrainedModel):
|
|
14 |
self.embeddings = nn.Embedding.from_pretrained(torch.tensor(embeddings))
|
15 |
|
16 |
def forward(self, input_ids, **kwargs):
|
17 |
-
|
18 |
-
|
|
|
|
|
|
1 |
+
from transformers import PreTrainedModel, modeling_outputs
|
2 |
from torch import nn
|
3 |
import torch
|
4 |
from .configuration_word2vec import PretrainedWord2VecHFConfig
|
|
|
14 |
self.embeddings = nn.Embedding.from_pretrained(torch.tensor(embeddings))
|
15 |
|
16 |
def forward(self, input_ids, **kwargs):
|
17 |
+
if type(input_ids) != torch.tensor: # e.g., list or np.array
|
18 |
+
input_ids = torch.tensor(input_ids)
|
19 |
+
x = self.embeddings(input_ids)
|
20 |
+
return modeling_outputs.BaseModelOutput(last_hidden_state=x)
|