import os
import json
import regex as re
from natsort import natsorted
from tqdm import tqdm

# Marathi-aware regex pattern for pre-tokenizing text before byte-level BPE
MARATHI_PATTERN = re.compile(r"""
    # Contractions and common affixes
    'चा|'ची|'चे|'ला|'ले|'नी|
    # Words with optional vowel signs and modifiers
    [\p{L}\p{M}]+|
    # Numbers
    \p{N}+|
    # Punctuation and special characters
    [^\s\p{L}\p{N}\p{M}]+|
    # Whitespace
    \s+
""", re.VERBOSE)

def text_to_bytes(text):
    """Convert text to byte tokens after applying Marathi regex"""
    words = MARATHI_PATTERN.findall(text)
    all_bytes = []
    for word in words:
        bytes_tokens = [b for c in word for b in c.encode('utf-8')]
        all_bytes.extend(bytes_tokens)
    return all_bytes
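
# Illustrative example: each matched piece is flattened into its UTF-8 bytes,
# so a two-letter Devanagari word becomes six byte values, e.g.
#   text_to_bytes("जग") -> [224, 164, 156, 224, 164, 151]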

def read_text_files(folder_path='train', limit=10):
    """Read up to `limit` text files from `folder_path` and return their combined byte tokens."""
    # Check if the folder exists
    if not os.path.exists(folder_path):
        print(f"Error: The folder '{folder_path}' does not exist.")
        return
    
    # Get list of all files in the folder
    files = os.listdir(folder_path)
    
    # Filter for text files and sort them naturally
    text_files = natsorted([f for f in files if f.endswith(('.txt', '.text'))])
    
    if not text_files:
        print(f"No text files found in '{folder_path}' folder.")
        return
    
    # Take only the first 'limit' files
    text_files = text_files[:limit]
    
    # Initialize list to store all tokens
    all_tokens = []
    
    # Read and tokenize the contents of each file
    for file_name in text_files:
        file_path = os.path.join(folder_path, file_name)
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                content = file.read()
                # Convert text to bytes using Marathi-aware tokenization
                tokens = text_to_bytes(content)
                all_tokens.extend(tokens)
        except Exception as e:
            print(f"Error reading {file_name}: {str(e)}")
    
    print("\n=== Combined Statistics ===")
    print("Total number of tokens:", len(all_tokens))
    print("First 100 tokens:", all_tokens[:100])
    return all_tokens

def get_stats(ids):
    """Count how often each adjacent pair of tokens occurs in `ids`."""
    counts = {}
    for pair in zip(ids, ids[1:]): # Pythonic way to iterate consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts
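
# Illustrative example:
#   get_stats([1, 2, 3, 1, 2]) -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}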

def merge(ids, pair, idx):
    """Replace every occurrence of `pair` in `ids` with the new token `idx`."""
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids
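
# Illustrative example:
#   merge([1, 2, 3, 1, 2], (1, 2), 256) -> [256, 3, 256]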

def encode(text, merges):
    """
    Encode text into tokens using the learned merges
    """
    # First convert text to bytes using Marathi-aware tokenization
    ids = text_to_bytes(text)
    
    # Apply the merges in order of their token indices
    # Sort by the token index to ensure consistent ordering
    sorted_merges = sorted(merges.items(), key=lambda x: x[1])
    for (p1, p2), idx in sorted_merges:
        ids = merge(ids, (p1, p2), idx)
    
    return ids

def decode(ids, merges):
    """
    Decode tokens back to text using the learned merges
    """
    # Create reverse mapping from token to pair
    reverse_merges = {idx: pair for pair, idx in merges.items()}
    
    # Expand all tokens recursively
    def expand_token(token):
        if token < 256:  # Base case: token is a byte
            return bytes([token])
        
        # Recursive case: expand the token into its constituent pair
        pair = reverse_merges[token]
        return expand_token(pair[0]) + expand_token(pair[1])
    
    # Expand all tokens and concatenate
    bytes_list = [expand_token(id) for id in ids]
    bytes_data = b''.join(bytes_list)
    
    # Convert bytes back to text
    try:
        return bytes_data.decode('utf-8')
    except UnicodeDecodeError:
        return "[DECODE_ERROR]"

class Tokenizer:
    def __init__(self, merges=None):
        self.merges = merges or {}
    
    def encode(self, text):
        return encode(text, self.merges)
    
    def decode(self, ids):
        return decode(ids, self.merges)
    
    def save(self, path):
        """Save the tokenizer to a JSON file"""
        # Convert tuple keys to strings for JSON serialization
        serializable_merges = {f"{p1},{p2}": idx for (p1, p2), idx in self.merges.items()}
        with open(path, 'w') as f:
            json.dump(serializable_merges, f)
    
    @classmethod
    def load(cls, path):
        """Load a tokenizer from a JSON file"""
        with open(path, 'r') as f:
            serialized_merges = json.load(f)
        # Convert string keys back to tuples
        merges = {tuple(map(int, k.split(','))): v for k, v in serialized_merges.items()}
        return cls(merges)
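
# Minimal usage sketch (illustrative; assumes a trained `merges` dict and uses a
# hypothetical file name):
#   tok = Tokenizer(merges)
#   ids = tok.encode("नमस्कार")
#   assert tok.decode(ids) == "नमस्कार"
#   tok.save("marathi_tokenizer.json")
#   assert Tokenizer.load("marathi_tokenizer.json").merges == tok.merges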

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, help='Path to tokenizer checkpoint')
    parser.add_argument('--train', action='store_true', help='Train a new tokenizer')
    parser.add_argument('--encode', type=str, help='Text to encode')
    parser.add_argument('--decode', type=str, help='Comma-separated integers to decode')
    args = parser.parse_args()

    if args.train:
        # Train new tokenizer
        all_tokens = read_text_files(limit=100)
        if not all_tokens:
            print("Error: no training tokens were read; aborting.")
            exit(1)
        initial_len = len(all_tokens)

        # --- BPE training ---
        vocab_size = 5000 # the desired final vocabulary size
        num_merges = vocab_size - 256
        ids = list(all_tokens) # copy so we don't destroy the original list

        merges = {} # (int, int) -> int
        pbar = tqdm(range(num_merges), desc="Merging tokens")
        for i in pbar:
            stats = get_stats(ids)
            pair = max(stats, key=stats.get)
            idx = 256 + i
            ids = merge(ids, pair, idx)
            merges[pair] = idx
            current_ratio = initial_len / len(ids)
            pbar.write(f"Iteration {i+1}: compression ratio: {current_ratio:.2f}X")

        print("\nFinal Statistics:")
        print("Initial tokens length:", initial_len)
        print("Final ids length:", len(ids))
        print(f"Final compression ratio: {initial_len / len(ids):.2f}X")

        tokenizer = Tokenizer(merges)
        
        if args.checkpoint:
            tokenizer.save(args.checkpoint)
            print(f"Saved tokenizer to {args.checkpoint}")

    elif args.encode or args.decode:
        if not args.checkpoint:
            print("Error: --checkpoint is required for encode/decode operations")
            exit(1)
        
        # Load tokenizer for encoding/decoding
        tokenizer = Tokenizer.load(args.checkpoint)
        print(f"Loaded tokenizer from {args.checkpoint}")

        if args.encode:
            # Encode the provided text
            encoded = tokenizer.encode(args.encode)
            print(f"\nEncoding: {args.encode}")
            print(f"Encoded tokens: {encoded}")

        if args.decode:
            # Decode the provided tokens
            try:
                tokens = [int(x.strip()) for x in args.decode.split(',')]
                decoded = tokenizer.decode(tokens)
                print(f"\nDecoding: {tokens}")
                print(f"Decoded text: {decoded}")
            except ValueError:
                print("Error: decode argument should be comma-separated integers")
                exit(1)
    
    else:
        parser.print_help()
        exit(1)
    # Test encode/decode
    test_text = "नमस्कार, जग! ही एक चाचणी आहे."
    encoded = tokenizer.encode(test_text)
    decoded = tokenizer.decode(encoded)
    print("\nEncoding/Decoding Test:")
    print(f"Original: {test_text}")
    print(f"Encoded: {encoded}")
    print(f"Decoded: {decoded}")
    print(f"Successful roundtrip: {test_text == decoded}")