feat: add Metharme prompt strategy (#446)
* Add Metharme tokenizing strategy

This strategy accounts for how the Metharme JSONLs are formatted and adds duplicated EOS tokens, which can help trim model output length (see the short sketch after the changed-files summary below).

I haven't had a chance to test this yet, and probably won't for quite a while, so I'm committing it now.
* Redo Metharme tokenizing strategy
lol
* fix: oops
* Rearrange a conditional
* chore: reformat code in accordance with linter
* chore: Make lint not freak out
* chore: fix lint
---------
Co-authored-by: NanoCode012 <[email protected]>
- README.md +4 -0
- src/axolotl/prompt_strategies/metharme.py +76 -0
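
For context on the duplicated-EOS behaviour mentioned in the commit message, here is a minimal, self-contained sketch. The helper name `append_extra_eos`, the token ids, and `eos_token_id=0` are invented for illustration; the actual logic lives in `MetharmePromptTokenizingStrategy._tokenize` in the diff below.

```python
# Illustrative only: this helper and its inputs are hypothetical; the real
# behaviour is implemented in MetharmePromptTokenizingStrategy._tokenize below.
def append_extra_eos(input_ids, attention_mask, eos_token_id, num_eos_tokens=3, max_len=2048):
    # An EOS the tokenizer already emitted counts toward the total.
    if input_ids and input_ids[-1] == eos_token_id:
        num_eos_tokens -= 1
    for _ in range(max(num_eos_tokens, 0)):
        if len(input_ids) < max_len:
            input_ids.append(eos_token_id)
            attention_mask.append(1)
    return input_ids, attention_mask


ids, mask = append_extra_eos([5, 42, 7, 0], [1, 1, 1, 1], eos_token_id=0)
print(ids)   # [5, 42, 7, 0, 0, 0] -- one EOS was already present, so two more were added
print(mask)  # [1, 1, 1, 1, 1, 1]
```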
README.md CHANGED

@@ -257,6 +257,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
   ```json
   {"conversations": [{"role": "...", "value": "..."}]}
   ```
+- `metharme`: instruction, adds additional eos tokens
+  ```json
+  {"prompt": "...", "generation": "..."}
+  ```
 - `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
   ```json
   {"conversations": [{"role": "...", "value": "..."}]}
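
To make the new `metharme` entry concrete, here is a small sketch that writes a dataset in the documented `{"prompt": ..., "generation": ...}` shape. The file name, the records, and the `<|system|>`/`<|user|>`/`<|model|>` role markers are illustrative assumptions (those markers are commonly associated with Metharme-style prompts; the strategy itself only reads the two JSON keys).

```python
# Illustrative only: file name, records, and role markers are made up;
# the strategy requires only the "prompt" and "generation" keys shown in the README.
import json

records = [
    {
        "prompt": "<|system|>Enter roleplay mode.<|user|>Hello there.<|model|>",
        "generation": "Hi! What would you like to talk about?",
    },
    {
        "prompt": "<|user|>Summarise the plot so far.<|model|>",
        "generation": "So far, the travellers have reached the mountain pass.",
    },
]

with open("metharme_sample.jsonl", "w", encoding="utf-8") as fout:
    for record in records:
        fout.write(json.dumps(record) + "\n")
```

A file like this would then presumably be referenced from an axolotl dataset config via `type: metharme`, the module name added in this commit.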
src/axolotl/prompt_strategies/metharme.py ADDED

@@ -0,0 +1,76 @@

```python
"""Module containing the MetharmePromptTokenizingStrategy and MetharmePrompter class"""

import logging
from typing import Tuple

from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter

LOG = logging.getLogger("axolotl")

IGNORE_TOKEN_ID = -100

# pylint: disable=duplicate-code


class MetharmePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for the Metharme models
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        return (prompt["prompt"], "", prompt["generation"])

    def _tokenize(
        self,
        prompt: str,
        add_eos_token: bool = True,
        strip_bos_token: bool = False,
        num_eos_tokens: int = 3,
    ):
        result = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.sequence_len,
            padding=False,
            return_tensors=None,
        )
        if len(result["input_ids"]) == 0:
            LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
        # If there's already an EOS token there, subtract from the number added
        if result["input_ids"][-1] == self.tokenizer.eos_token_id:
            num_eos_tokens -= 1

        if num_eos_tokens > 0 and add_eos_token and len(result["input_ids"]) > 0:
            for _ in range(num_eos_tokens):
                if len(result["input_ids"]) < self.sequence_len:
                    result["input_ids"].append(self.tokenizer.eos_token_id)
                    result["attention_mask"].append(1)

        if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
            result["input_ids"] = result["input_ids"][1:]
            result["attention_mask"] = result["attention_mask"][1:]

        result["labels"] = result["input_ids"].copy()
        return result


class MetharmePrompter(AlpacaPrompter):
    """
    Prompter for the Metharme models.
    """

    system_prompt = ""
    system_no_input_prompt = ""
    system_format = ""
    turn_format = "{instruction}"
    turn_no_input_format = "{instruction}"

    def __init__(self, *args, **kwargs):  # pylint: disable=super-init-not-called
        pass


def load(tokenizer, cfg):
    return MetharmePromptTokenizingStrategy(
        MetharmePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
```