ZubairAhmed777 committed on
Commit
c3e07b2
·
verified ·
1 Parent(s): be4c742

Create vocab.py

Browse files
Files changed (1) hide show
  1. vocab.py +45 -0
vocab.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import re
4
+ from collections import defaultdict
5
+ import glob
6
+ import numpy as np
7
+ import time
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torchvision.models as models
12
+ import torch.nn.functional as F
13
+ from torch import optim
14
+ from torch.utils.data import Dataset
15
+ from torchvision import transforms
16
+ from torch.utils.data import DataLoader
17
+
18
+ from PIL import Image
19
+
20
class Vocabulary:
    """Word <-> index lookup built from a one-token-per-line text file.

    The index of a word is its zero-based line number in the file.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, vocabulary_file_path: str):
        """Load the vocabulary file and build the lookup structures.

        Args:
            vocabulary_file_path: Path to a text file with one word per line.

        Raises:
            OSError: If the file cannot be opened.
            UnicodeDecodeError: If the file is not valid UTF-8.
        """
        # Ordered list of words; position in the list is the word's index.
        self.vocabulary = self._load_vocabulary(vocabulary_file_path)
        # Reverse mapping word -> index. If a word occurs on several lines,
        # the last occurrence wins (later entries overwrite earlier ones).
        self.vocabulary2idx = {
            word: idx for idx, word in enumerate(self.vocabulary)
        }
        # Number of lines read from the file (duplicates/blank lines included).
        self.vocabulary_size = len(self.vocabulary)

    def _load_vocabulary(self, vocabulary_file_path: str):
        """Read the vocabulary file and return its lines as a list of words.

        Leading and trailing whitespace is stripped from every line. Blank
        lines are kept (as empty strings) so that indices keep matching the
        line numbers of the file.

        Args:
            vocabulary_file_path: Path to a text file with one word per line.

        Returns:
            A list of stripped lines, in file order.
        """
        # Decode explicitly as UTF-8: without an explicit encoding, open()
        # falls back to the platform locale, which makes the loaded tokens
        # (and therefore lookups) depend on the machine running the code.
        with open(vocabulary_file_path, 'r', encoding="utf-8") as file:
            return [line.strip() for line in file]

    def word2idx(self, word: str):
        """Convert a word to its index.

        Args:
            word: The word to look up.

        Returns:
            The index of ``word``; if ``word`` is not in the vocabulary, the
            index of ``'<unk>'``; ``None`` if ``'<unk>'`` is missing as well.
        """
        unknown_idx = self.vocabulary2idx.get('<unk>')
        return self.vocabulary2idx.get(word, unknown_idx)

    def idx2word(self, idx: int) -> str:
        """Convert an index back to its word.

        Args:
            idx: Position in the vocabulary list. Negative values follow
                normal Python list indexing (counting from the end).

        Returns:
            The word stored at position ``idx``.

        Raises:
            IndexError: If ``idx`` is outside the vocabulary.
        """
        return self.vocabulary[idx]
45
+