Gregniuki committed
Commit cc97bdc · verified · 1 Parent(s): a05745b

Update model/utils.py

Files changed (1):
model/utils.py +30 -7
model/utils.py CHANGED
@@ -76,21 +76,44 @@ def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d
 
     return num / den.clamp(min=1.0)
 
-
-# simple utf-8 tokenizer, since paper went character based
 def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]:  # noqa: F722
-    list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text]  # ByT5 style
+    # Split each string into words
+    list_words = [t.split() for t in text]
+
+    # Convert words to tensors (assuming words are already in byte format)
+    list_tensors = [torch.tensor([*bytes(" ".join(words), "UTF-8")]) for words in list_words]  # ByT5 style
     text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
     return text
 
-
-# char tokenizer, based on custom dataset's extracted .txt file
 def list_str_to_idx(
     text: list[str] | list[list[str]],
-    vocab_char_map: dict[str, int],  # {char: idx}
+    vocab_map: dict[str, int],  # {word: idx}
     padding_value=-1,
 ) -> int["b nt"]:  # noqa: F722
-    list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text]  # pinyin or char style
+    # Split each string into words if not already split
+    if isinstance(text[0], str):
+        list_words = []
+        for t in text:
+            # Split the text by triple spaces
+            parts = t.split("   ")
+            words = []
+            for i, part in enumerate(parts):
+                # Split each part into words (by single spaces)
+                words.extend(part.split())
+                # Add a space token if there are more parts (i.e., triple spaces were present)
+                if i < len(parts) - 1:
+                    words.append(" ")  # Add a space token
+            list_words.append(words)
+    else:
+        list_words = text
+
+    # Convert words to their corresponding indices using vocab_map
+    list_idx_tensors = [
+        torch.tensor([vocab_map.get(word, 0) for word in words])  # Use 0 for unknown words
+        for words in list_words
+    ]
+
+    # Pad the sequences
     text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
     return text
 
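For reference, a minimal usage sketch of the updated word-level tokenizer. It assumes the repo root is on PYTHONPATH (the file above lives at model/utils.py); the vocab_map below is a made-up toy example, not a real vocabulary file.

# Hypothetical example exercising list_str_to_idx's triple-space convention.
from model.utils import list_str_to_idx

# Made-up {word: idx} mapping for illustration only; index 0 is the
# unknown-word fallback used inside list_str_to_idx.
vocab_map = {" ": 1, "hello": 2, "world": 3}

# Triple spaces between words become an explicit " " token; "unknown"
# is absent from vocab_map, so it maps to 0. Shorter rows pad with -1.
batch = ["hello   world", "hello unknown"]
ids = list_str_to_idx(batch, vocab_map, padding_value=-1)
print(ids)
# Expected, per the logic above:
# tensor([[ 2,  1,  3],
#         [ 2,  0, -1]])

The practical effect of the triple-space rule is that a dataset can mark an explicit space token (e.g., a deliberate pause) with "   ", while single spaces merely separate words and produce no token of their own.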