import math
import pickle
import random
import string
from abc import ABC, abstractmethod
from functools import partial
from typing import Optional

import numpy as np

try:
    import matplotlib.pyplot as plt
except ImportError:
    # matplotlib is only needed by the histogram demo under ``__main__``;
    # keeping it optional lets the generators be imported without it.
    plt = None

# Length limit and IAM corpus location used by ``get_generator``.
_DEFAULT_MAX_LENGTH = 20
_IAM_PICKLE_PATH = "files/IAM-32.pickle"


class TextGenerator(ABC):
    """Abstract source of text samples.

    Args:
        max_lenght: Length limit stored as ``self.max_length``; each subclass
            decides how it is enforced. The misspelled parameter name is kept
            so existing keyword callers keep working.
    """

    def __init__(self, max_lenght: Optional[int] = None):
        self.max_length = max_lenght

    @abstractmethod
    def generate(self):
        """Return the next text sample."""


class AugmentedGenerator(TextGenerator):
    """Base class for generators that perturb text with a tunable strength.

    Args:
        strength: Augmentation strength, stored as ``self.strength``.
        alphabet: Characters available to the generator. A copy is stored in
            ``self.alphabet`` with the first ``"%"`` entry removed
            (NOTE(review): presumably ``"%"`` is reserved elsewhere - confirm).
        max_lenght: Forwarded to ``TextGenerator``.
    """

    def __init__(self, strength: float, alphabet: list, max_lenght: Optional[int] = None):
        super().__init__(max_lenght)
        self.strength = strength
        self.alphabet = list(alphabet)
        if "%" in alphabet:
            self.alphabet.remove("%")

    @abstractmethod
    def generate(self):
        """Return the next augmented text sample."""

    def set_strength(self, strength: float):
        """Replace the augmentation strength."""
        self.strength = strength

    def get_strength(self):
        """Return the current augmentation strength."""
        return self.strength


class ProportionalAugmentedGenerator(AugmentedGenerator):
    """Replace characters of generated words in proportion to their frequency.

    Character frequencies are estimated by sampling the wrapped generator.
    Frequent characters are the most likely to be replaced and the least
    likely to be drawn as a replacement.

    Args:
        max_length: Forwarded to the base class.
        generator: Generator whose words are augmented.
        alphabet: Characters tracked in the statistics and usable as
            replacements (``"%"`` is dropped by the base class).
        strength: Scales the per-character replacement probability.
    """

    def __init__(self, max_length: int, generator: TextGenerator, alphabet: list, strength: float = 0.5):
        super().__init__(strength, alphabet, max_length)
        self.generator = generator
        self.char_stats = {}
        self.sampling_probs = {}
        self.init_statistics()

    def init_statistics(self, num_samples: int = 10000):
        """Estimate character statistics from ``num_samples`` generated words.

        ``self.char_stats`` maps each alphabet character to its occurrence
        count divided by the count of the most frequent character, so the most
        frequent character gets 1.0. ``self.sampling_probs`` holds
        ``1 - char_stats`` and is used as replacement sampling weights.

        Characters outside ``self.alphabet`` are ignored rather than raising
        ``KeyError``: the wrapped generator may emit them (for example ``"%"``,
        which the base class removes from the alphabet).
        """
        occurrences = dict.fromkeys(self.alphabet, 0)
        for _ in range(num_samples):
            for char in self.generator.generate():
                if char in occurrences:
                    occurrences[char] += 1

        # With nothing counted every statistic stays 0.0 (no division by zero).
        peak = max(occurrences.values(), default=0)
        if peak:
            self.char_stats = {char: count / peak for char, count in occurrences.items()}
        else:
            self.char_stats = dict.fromkeys(occurrences, 0.0)
        self.sampling_probs = {char: 1.0 - stat for char, stat in self.char_stats.items()}

    def random_char(self):
        """Draw one replacement character weighted by ``self.sampling_probs``."""
        candidates = list(self.sampling_probs)
        weights = list(self.sampling_probs.values())
        if not any(weights):
            # random.choices rejects an all-zero weight vector (e.g. a
            # one-character alphabet); fall back to a uniform draw.
            return random.choice(candidates)
        return random.choices(candidates, weights=weights, k=1)[0]

    def generate(self):
        """Return a word from the wrapped generator after augmentation."""
        return self.augment(self.generator.generate())

    def augment(self, word):
        """Randomly replace characters of ``word``.

        Position ``i`` is replaced when a uniform draw in [0, 1) falls below
        ``strength * char_stats[word[i]]``. Characters without statistics are
        treated as frequency 0.0 and therefore never replaced.
        """
        draws = np.random.rand(len(word))
        chars = list(word)
        for index, char in enumerate(word):
            if draws[index] < self.strength * self.char_stats.get(char, 0.0):
                chars[index] = self.random_char()
        return "".join(chars)


class FileTextGenerator(TextGenerator):
    """Serve random words from a text file.

    Args:
        max_length: Only words strictly shorter than this are kept.
        file_path: Path of the word list.
        alphabet: Words containing any character outside it are dropped.
    """

    def __init__(self, max_length: int, file_path: str, alphabet: list):
        super().__init__(max_length)
        allowed = set(alphabet)  # build once instead of once per word
        with open(file_path, 'r') as f:
            # NOTE(review): the original statement here was garbled and never
            # read the file; one word per line is assumed - confirm the format.
            lines = f.read().splitlines()
        self.words = [l for l in lines if len(l) < self.max_length and set(l) <= allowed]

    def generate(self):
        """Return a uniformly random word (``IndexError`` if none were kept)."""
        return random.choice(self.words)


class CVLFileTextIterator(TextGenerator):
    """Cycle through the annotations of a comma-separated file in file order.

    The first line is skipped. On every other line the text after the first
    comma is the annotation, so annotations may themselves contain commas.

    Args:
        max_length: Only annotations strictly shorter than this are kept.
        file_path: Path of the comma-separated file.
        alphabet: Annotations containing any character outside it are dropped.
    """

    def __init__(self, max_length: int, file_path: str, alphabet: list):
        super().__init__(max_length)
        allowed = set(alphabet)
        with open(file_path, 'r') as f:
            next(f, None)  # skip the first line; tolerate an empty file
            annotations = [line.rstrip().partition(",")[2] for line in f]
        self.words = [l for l in annotations if len(l) < self.max_length and set(l) <= allowed]
        self.index = 0

    def generate(self):
        """Return the next annotation, wrapping around at the end."""
        word = self.words[self.index % len(self.words)]
        self.index += 1
        return word


def set_char(s, character, location):
    """Return ``s`` with the character at ``location`` replaced by ``character``."""
    return s[:location] + character + s[location + 1:]


class GibberishGenerator(TextGenerator):
    """Generate random strings of letters, digits and punctuation.

    Args:
        max_length: Optional inclusive upper bound on the generated length.
    """

    def __init__(self, max_length: int = None):
        super().__init__(max_length)
        self.lower_case = list(string.ascii_lowercase)
        self.upper_case = list(string.ascii_uppercase)
        self.special = list(' .-\',"&();#:!?+*/')
        self.numbers = [str(i) for i in range(10)]

    def get_word_length(self) -> int:
        """Sample a word length from a chi-square(8) distribution, rounded up.

        Samples above ``max_length`` are rejected and redrawn.

        Raises:
            ValueError: If ``max_length`` is below 1. Chi-square samples are
                positive, so rounding up yields at least 1 and the rejection
                loop could never terminate.
        """
        if self.max_length is not None and self.max_length < 1:
            raise ValueError(f"max_length must be at least 1, got {self.max_length}")
        while True:
            length = int(math.ceil(np.random.chisquare(8)))
            if self.max_length is None or length <= self.max_length:
                return length

    def generate(self):
        """Return a random string; see ``generate_random``."""
        return self.generate_random()

    def generate_random(self):
        """Return uniformly drawn characters joined to a sampled length."""
        alphabet = self.upper_case + self.lower_case + self.special + self.numbers
        return ''.join(random.choices(alphabet, k=self.get_word_length()))


class IAMTextGenerator(TextGenerator):
    """Serve random labels from a pickled, per-author dataset.

    The pickle is indexed as ``data[subset][author_id]``, each entry being a
    list of dicts with a ``'label'`` key.

    Args:
        max_length: Labels up to and including this length are kept.
        path: Path of the pickle file.
        subset: Top-level key selecting the split to read.
    """

    def __init__(self, max_length: int, path: str, subset: str = 'train'):
        super().__init__(max_length)
        # SECURITY: unpickling can execute arbitrary code; only load dataset
        # files that come from a trusted source.
        with open(path, 'rb') as f:
            data = pickle.load(f)[subset]
        self.words = [
            image_dict['label']
            for images in data.values()
            for image_dict in images
            if len(image_dict['label']) <= self.max_length
        ]

    def generate(self):
        """Return a uniformly random label (``IndexError`` if none were kept)."""
        return random.choice(self.words)


def get_generator(args):
    """Build the text generator selected by ``args``.

    Reads ``args.text_augment_strength`` and, depending on the branch taken,
    ``args.text_aug_type``, ``args.corpus``, ``args.english_words_path`` and
    ``args.alphabet``.

    * Strength > 0 and type ``'gibberish'``: a ``GibberishGenerator``. The
      corpus is not loaded, since that generator never uses it.
    * Strength > 0 otherwise: the corpus generator wrapped in a
      ``ProportionalAugmentedGenerator``.
    * Strength <= 0: the plain corpus generator.

    The corpus generator is file based when ``args.corpus == "standard"``
    (``.csv`` paths use ``CVLFileTextIterator``) and IAM based otherwise.
    """
    augmented = args.text_augment_strength > 0
    if augmented and args.text_aug_type == 'gibberish':
        return GibberishGenerator(_DEFAULT_MAX_LENGTH)

    if args.corpus == "standard":
        if args.english_words_path.endswith(".csv"):
            corpus_cls = CVLFileTextIterator
        else:
            corpus_cls = FileTextGenerator
        generator = corpus_cls(_DEFAULT_MAX_LENGTH, args.english_words_path, args.alphabet)
    else:
        generator = IAMTextGenerator(_DEFAULT_MAX_LENGTH, _IAM_PICKLE_PATH, 'train')

    if augmented:
        return ProportionalAugmentedGenerator(
            _DEFAULT_MAX_LENGTH, generator, args.alphabet, args.text_augment_strength
        )
    return generator


if __name__ == "__main__":
    if plt is None:
        raise SystemExit("matplotlib is required to plot the word-length histogram")

    alphabet = list('Only thewigsofrcvdampbkuq.A-210xT5\'MDL,RYHJ"ISPWENj&BC93VGFKz();#:!7U64Q8?+*ZX/%')
    original_generator = FileTextGenerator(max_length=20, file_path="../files/english_words.txt", alphabet=alphabet)
    gib = ProportionalAugmentedGenerator(20, original_generator, alphabet=alphabet, strength=0.5)

    generated_words = []  # lengths of the generated words, for the histogram
    for _ in range(1000):
        word = gib.generate()
        generated_words.append(len(word))
        # Print words that contain at least one repeated character.
        if len(set(word)) < len(word):
            print(word)

    plt.hist(generated_words)