winglian committed
Commit 5079753 · unverified
2 Parent(s): 8eb5811 0136f51

Merge pull request #131 from OpenAccess-AI-Collective/fix-packing-mask

src/axolotl/datasets.py CHANGED
@@ -127,6 +127,11 @@ class ConstantLengthDataset(IterableDataset):
                         input_ids = example["input_ids"]
                         attention_mask = example["attention_mask"]
                         labels = example["labels"]
+                        if (
+                            buffer["input_ids"]
+                            and input_ids[0] == self.tokenizer.bos_token_id
+                        ):
+                            attention_mask[0] = 0
 
                         if add_concat_token:
                             input_ids.append(self.concat_token_id)
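
For context: the hunk above makes ConstantLengthDataset zero out the attention-mask entry for a packed example's leading BOS token whenever the packing buffer already holds tokens, marking the boundary between concatenated samples. Below is a minimal standalone sketch of that idea on plain Python lists with made-up token ids (1 standing in for bos_token_id); it is an illustration of the behaviour, not the library's actual API.

# Minimal sketch of the packing-boundary mask reset (hypothetical, standalone).
BOS_TOKEN_ID = 1  # stand-in for tokenizer.bos_token_id

def pack_examples(examples):
    """Concatenate tokenized examples, zeroing the attention mask on the
    leading BOS of every example after the first (mirrors the hunk above)."""
    buffer = {"input_ids": [], "attention_mask": []}
    for example in examples:
        input_ids = list(example["input_ids"])
        attention_mask = list(example["attention_mask"])
        if buffer["input_ids"] and input_ids[0] == BOS_TOKEN_ID:
            attention_mask[0] = 0  # mark the boundary BOS as masked
        buffer["input_ids"] += input_ids
        buffer["attention_mask"] += attention_mask
    return buffer

packed = pack_examples(
    [
        {"input_ids": [1, 10, 11], "attention_mask": [1, 1, 1]},
        {"input_ids": [1, 20, 21], "attention_mask": [1, 1, 1]},
    ]
)
assert packed["attention_mask"] == [1, 1, 1, 0, 1, 1]  # second BOS is masked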
tests/fixtures/alpaca/alpaca.json ADDED
@@ -0,0 +1,12 @@
+[
+    {
+        "instruction": "You will be given a series of words. Output these words in reverse order, with each word on its own line.",
+        "input": "Words: ['Hello', 'world'].",
+        "output": "['world', 'Hello']"
+    },
+    {
+        "instruction": "In this task, you're given a short description of an event. Your job is to order the steps involved in the event from first to last. Note that there may be multiple correct answers for each event.",
+        "input": "Description: A man walks into a bar and orders a drink. He pays for his drink and leaves the bar.",
+        "output": "1. The man walks into the bar.\n2. He orders a drink.\n3. He pays for his drink.\n4. He leaves the bar."
+    }
+]
tests/test_packed_dataset.py ADDED
@@ -0,0 +1,65 @@
+"""Module for testing dataset sequence packing"""
+
+import unittest
+from pathlib import Path
+
+from datasets import Dataset, load_dataset
+from transformers import AutoTokenizer
+
+from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
+from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
+from axolotl.prompters import AlpacaPrompter
+
+
+class TestPacking(unittest.TestCase):
+    """
+    Test class for packing dataset sequences
+    """
+
+    def setUp(self) -> None:
+        # pylint: disable=duplicate-code
+        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+        self.tokenizer.add_special_tokens(
+            {
+                "bos_token": "<s>",
+                "eos_token": "</s>",
+                "unk_token": "<unk>",
+            }
+        )
+
+    def test_resets_attention(self):
+        prompter = AlpacaPrompter("chat")
+        strat = AlpacaPromptTokenizingStrategy(
+            prompter,
+            self.tokenizer,
+            False,
+            2048,
+        )
+        dateset = load_dataset(
+            "json",
+            data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"),
+        )["train"]
+        dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dateset)))
+
+        constant_len_dataset = ConstantLengthDataset(
+            self.tokenizer,
+            [dataset],
+            seq_length=2048,
+        )
+        packed_dataset = Dataset.from_list(list(constant_len_dataset))
+        example = packed_dataset[0]
+        next_bos_index = (
+            example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1
+        )  # add one since we sliced
+
+        # first example doesn't have mask reset
+        assert example["input_ids"][0] == self.tokenizer.bos_token_id
+        assert example["attention_mask"][0] == 1
+
+        # but subsequent one does
+        assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
+        assert example["attention_mask"][next_bos_index] == 0
+
+
+if __name__ == "__main__":
+    unittest.main()
tests/test_prompt_tokenizers.py CHANGED
@@ -18,6 +18,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
     """
 
     def setUp(self) -> None:
+        # pylint: disable=duplicate-code
         self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
         self.tokenizer.add_special_tokens(
             {
  {