Update PrateritumGPT.py
Browse files- PrateritumGPT.py +4 -3
PrateritumGPT.py
CHANGED
|
@@ -4,6 +4,7 @@ import torch.nn as nn
|
|
| 4 |
from torch.utils.data import Dataset, DataLoader
|
| 5 |
from torch.nn.utils.rnn import pad_sequence
|
| 6 |
import math
|
|
|
|
| 7 |
|
| 8 |
tokens = list("azertyuiopqsdfghjklmwxcvbnäüöß—– ")
|
| 9 |
tokensdict = {}
|
|
@@ -12,7 +13,7 @@ for i in range(len(tokens)):
|
|
| 12 |
tokensdict.update({tokens[i]: [0] * i + [0] * (len(tokens) - (i + 1))})
|
| 13 |
|
| 14 |
# Ouvrir le fichier CSV
|
| 15 |
-
with open("
|
| 16 |
# Créer un objet lecteur CSV
|
| 17 |
reader = [i for i in csv.reader(file)][1:]
|
| 18 |
|
|
@@ -37,12 +38,12 @@ for i in reader:
|
|
| 37 |
for j in i[2]:
|
| 38 |
k += [tokens.index(j)]
|
| 39 |
k += [len(tokens) + 1] * (25 - len(k))
|
| 40 |
-
features += [torch.Tensor(k)]
|
| 41 |
k = []
|
| 42 |
for j in i[8]:
|
| 43 |
k += [tokens.index(j)]
|
| 44 |
k += [len(tokens) + 1] * (25 - len(k))
|
| 45 |
-
labels += [torch.Tensor(k)]
|
| 46 |
|
| 47 |
MyDataset = CSVDataset(features=features, labels=labels)
|
| 48 |
|
|
|
|
| 4 |
from torch.utils.data import Dataset, DataLoader
|
| 5 |
from torch.nn.utils.rnn import pad_sequence
|
| 6 |
import math
|
| 7 |
+
import os
|
| 8 |
|
| 9 |
tokens = list("azertyuiopqsdfghjklmwxcvbnäüöß—– ")
|
| 10 |
tokensdict = {}
|
|
|
|
| 13 |
tokensdict.update({tokens[i]: [0] * i + [0] * (len(tokens) - (i + 1))})
|
| 14 |
|
| 15 |
# Ouvrir le fichier CSV
|
| 16 |
+
with open(os.path.dirname(os.path.abspath(__file__))+"\\top-german-verbs.csv", 'r', encoding="utf-8") as file:
|
| 17 |
# Créer un objet lecteur CSV
|
| 18 |
reader = [i for i in csv.reader(file)][1:]
|
| 19 |
|
|
|
|
| 38 |
for j in i[2]:
|
| 39 |
k += [tokens.index(j)]
|
| 40 |
k += [len(tokens) + 1] * (25 - len(k))
|
| 41 |
+
features += [torch.Tensor([k])]
|
| 42 |
k = []
|
| 43 |
for j in i[8]:
|
| 44 |
k += [tokens.index(j)]
|
| 45 |
k += [len(tokens) + 1] * (25 - len(k))
|
| 46 |
+
labels += [torch.Tensor([k])]
|
| 47 |
|
| 48 |
MyDataset = CSVDataset(features=features, labels=labels)
|
| 49 |
|