Fix sharegpt prompt
src/axolotl/prompt_tokenizers.py

@@ -371,15 +371,16 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                     ]
                     # not masked out from labels
                     labels = copy.deepcopy(res["input_ids"])
+                elif part[0] == "SYSTEM:":
+                    part = part[1]  # Ignore the system role from preamble
+                    # this is only ever the first part, should include the bos token and the user query
+                    res = self._tokenize(
+                        part.strip(), add_eos_token=False, strip_bos_token=False
+                    )
+                    # everything from this is masked out from the labels
+                    labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
                 else:
                     logging.warning(f"unhandled role: {part[0]}")
-            else:
-                # this is only ever the first part, should include the bos token and the user query
-                res = self._tokenize(
-                    part.strip(), add_eos_token=False, strip_bos_token=False
-                )
-                # everything from this is masked out from the labels
-                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
 
             # pylint: disable=duplicate-code
             result, current_len = parse_tokenized_to_result(
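For context, a minimal, self-contained sketch (not the axolotl code itself) of the label-masking rule the new elif part[0] == "SYSTEM:" branch implements: the preamble now arrives as a ("SYSTEM:", text) tuple, the role tag is dropped (part = part[1]), and the whole tokenized preamble is masked out of the labels with IGNORE_TOKEN_ID, while assistant parts keep their token ids as labels. The fake_tokenize helper and the ASSISTANT-branch tokenization are hypothetical stand-ins for self._tokenize; BOS/EOS handling is omitted.

import copy
import logging
from typing import Dict, List, Tuple

IGNORE_TOKEN_ID = -100


def fake_tokenize(text: str) -> Dict[str, List[int]]:
    # stand-in for self._tokenize: one fake token id per whitespace-separated word
    return {"input_ids": list(range(len(text.split())))}


def labels_for_part(part: Tuple[str, str]) -> Tuple[List[int], List[int]]:
    role, text = part
    if role == "ASSISTANT:":
        # assistant turns are learned: labels mirror the input ids
        res = fake_tokenize((role + text).strip())
        labels = copy.deepcopy(res["input_ids"])  # not masked out from labels
    elif role == "SYSTEM:":
        # tokenize only the preamble text, ignoring the "SYSTEM:" role tag,
        # and mask everything out of the labels
        res = fake_tokenize(text.strip())
        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
    else:
        logging.warning("unhandled role: %s", role)
        res = {"input_ids": []}
        labels = []
    return res["input_ids"], labels


print(labels_for_part(("SYSTEM:", "A chat between a user and an assistant.###")))
print(labels_for_part(("ASSISTANT:", " hi there")))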
src/axolotl/prompters.py

@@ -3,7 +3,7 @@
 import dataclasses
 import logging
 from enum import Enum, auto
-from typing import Generator, List, Optional, Union
+from typing import Generator, List, Optional, Tuple, Union
 
 IGNORE_TOKEN_ID = -100
 
@@ -235,16 +235,16 @@ class Conversation:
     sep: str = "###"
     sep2: Optional[str] = None
 
-    def get_prompt(self) -> Generator[str, None, None]:
+    def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
         # seps = [self.sep, self.sep2]
         preamble = self.system + self.sep
-        yield preamble
+        yield ("SYSTEM:", preamble)
         for _, (role, message) in enumerate(self.messages):
             if message:
-                yield role + ":"
+                yield (role + ":", " " + message)
             else:
                 logging.warning(f"role with empty message: {role}")
-                yield role + ":"
+                yield (role + ":", "")
 
     def copy(self):
         return Conversation(
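The net effect on the prompter side, sketched below with a hypothetical MiniConversation stand-in (not imported from axolotl), is that get_prompt() now yields (role_tag, text) tuples rather than pre-joined strings, with the preamble tagged "SYSTEM:" so the tokenizing strategy above can recognise it and mask it out of the labels.

import logging
from dataclasses import dataclass, field
from typing import Generator, List, Tuple


@dataclass
class MiniConversation:
    # stripped-down stand-in mirroring only the fields used in the hunk above
    system: str
    messages: List[Tuple[str, str]] = field(default_factory=list)
    sep: str = "###"

    def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
        preamble = self.system + self.sep
        yield ("SYSTEM:", preamble)
        for role, message in self.messages:
            if message:
                yield (role + ":", " " + message)
            else:
                logging.warning("role with empty message: %s", role)
                yield (role + ":", "")


conv = MiniConversation(
    system="A chat between a curious user and an assistant.",
    messages=[("USER", "hello"), ("ASSISTANT", "hi!")],
)
for part in conv.get_prompt():
    print(part)
# ('SYSTEM:', 'A chat between a curious user and an assistant.###')
# ('USER:', ' hello')
# ('ASSISTANT:', ' hi!')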