"""Prompts and utilities used for training the ether0 model."""
import re
from enum import Enum, StrEnum
from typing import assert_never
# Tokens to surround reasoning and answer in XML format
THINK_START = "<|think_start|>"
THINK_END = "<|think_end|>"
ANSWER_START = "<|answer_start|>"
ANSWER_END = "<|answer_end|>"

# Keys: True (reasoning + answer), False (answer only)
# Use strict regex for ether0 models, as we can SFT or RL the models into compliance
# Capture groups, in order:
#   True:  (thought, text between THINK_END and ANSWER_START, answer)
#   False: (answer,), where the answer must begin with a non-whitespace character
# Both patterns tolerate at most one leading whitespace character and are anchored
# at both ends, so they are meant to consume the entire response.
STRICT_XML_ANSWER_SPLIT_PATTERNS: dict[bool, re.Pattern] = {
    True: re.compile(
        rf"^\s?{re.escape(THINK_START)}\s*([\s\S]*?)\s*{re.escape(THINK_END)}([\s\S]*?){re.escape(ANSWER_START)}\s*([\s\S]*?)\s*{re.escape(ANSWER_END)}$"
    ),
    False: re.compile(
        rf"^\s?{re.escape(ANSWER_START)}\s*(\S[\s\S]*?)\s*{re.escape(ANSWER_END)}$"
    ),
}
# Use loose regex for other models because:
# 1. <think> may be out-of-distribution from the model's training data,
#    so requiring thoughts may degrade performance.
# 2. We allow baseline models to add extra whitespace and/or preceding or trailing text
#    around answer XML, again to maximize performance.
# 3. Similarly, we allow models to ramble for a bit mentioning <answer>,
#    and then we just keep the last <answer> XML.
# 4. We want to avoid prompt engineering tricks to get around the previous items.
# The single capture group is the answer with surrounding whitespace trimmed;
# it must contain at least one non-whitespace character.
LOOSE_XML_ANSWER_LOOSE_PATTERN = r"<answer>\s*(\S[\s\S]*?)\s*<\/answer>"
class XMLAnswerPrompts(StrEnum):
    """Enum of prompts describing the token-delimited response format to use."""

    # Asks for reasoning wrapped in the think tokens, then the answer
    # wrapped in the answer tokens
    REASONING_ANSWER = (
        "A conversation between User and Assistant."
        " The user asks a question, and the Assistant solves it."
        " The assistant first thinks about the reasoning process"
        " in the mind and then provides the user with the answer."
        " The reasoning process and answer are enclosed within"
        f" {THINK_START} {THINK_END} and {ANSWER_START} {ANSWER_END} tags,"
        " respectively, i.e.,"
        f" {THINK_START} reasoning process here {THINK_END}"
        f"{ANSWER_START} answer here {ANSWER_END}"
    )
    # Asks only for the answer wrapped in the answer tokens
    ANSWER_ONLY = (
        "A conversation between User and Assistant."
        " The user asks a question, and the Assistant solves it."
        " The assistant encloses its answer within"
        f" {ANSWER_START} {ANSWER_END} tags, i.e.,"
        f" {ANSWER_START} answer here {ANSWER_END}"
    )

    def pattern(self) -> re.Pattern:
        """Return the strict compiled regex for this prompt's response format.

        The lookup key into STRICT_XML_ANSWER_SPLIT_PATTERNS is True for
        REASONING_ANSWER and False for every other member.
        """
        return STRICT_XML_ANSWER_SPLIT_PATTERNS[
            self == XMLAnswerPrompts.REASONING_ANSWER
        ]
class SysPrompt(Enum): # Use Enum over StrEnum for trl.TrlParser compatibility | |
"""Possible system prompts for making a conversation to train upon.""" | |
SCIENTIFIC_AI = "scientific_ai" | |
def get_sys_prompt(self) -> str: | |
match self: | |
case SysPrompt.SCIENTIFIC_AI: | |
return "You are a scientific reasoning AI assistant." | |
case _: | |
assert_never(self) | |
class ProblemPrompt(Enum): # Use Enum over StrEnum for trl.TrlParser compatibility | |
"""Possible user prompts for making a conversation to train upon.""" | |
NONE = "none" | |
THINK_ANSWER = "think_answer" | |
ANSWER = "answer" | |
def get_prompt(self) -> str: | |
match self: | |
case ProblemPrompt.NONE: | |
return "" | |
case ProblemPrompt.THINK_ANSWER: | |
return XMLAnswerPrompts.REASONING_ANSWER.value | |
case ProblemPrompt.ANSWER: | |
return XMLAnswerPrompts.ANSWER_ONLY.value | |
case _: | |
assert_never(self) | |
def extract_thought_answer_strict(
    text: str, reasoning: bool
) -> tuple[str | None, str | None]:
    """Extract thought and answer from text using a strict XML pattern.

    Args:
        text: Full response text to parse.
        reasoning: Key into STRICT_XML_ANSWER_SPLIT_PATTERNS. True expects
            thought followed by answer, False expects the answer only.

    Returns:
        Two-tuple of (thought, answer). Thought is always None when
        `reasoning` is False. Both are None when the text does not match the
        strict pattern, when a start token is nested inside a captured
        section, or when text remains after the match (e.g. a trailing
        newline). Answer alone is None when the captured answer is empty.
    """
    # The pattern is anchored, so splitting with `maxsplit=1` yields
    # [prefix, *capture groups, suffix] on a match, or just [text] otherwise
    matches = STRICT_XML_ANSWER_SPLIT_PATTERNS[reasoning].split(text, maxsplit=1)
    try:
        _, *inner, suffix = matches
    except ValueError:
        return None, None  # Consider no match as a failure
    if reasoning:
        thought, inter, answer = inner
    else:
        thought, inter = None, None
        (answer,) = inner
    if (
        THINK_START not in (thought or "")
        and THINK_START not in (inter or "")
        and ANSWER_START not in answer
        and not suffix
    ):
        return thought, answer or None
    # Consider nested or repeated (2+) sections, or trailing text, as a failure
    return None, None
# Instruction text asking for the final answer inside <answer></answer> tags
LOOSE_XML_ANSWER_USER_PROMPT = (
    "When answering,"
    " be sure to place the final answer as"
    " SMILES notation into XML tags <answer></answer>."
    " An example is <answer>CCO</answer>."
)
def extract_answer_loose(text: str | None) -> str:
    """
    Extract the last answer from text using a loose XML pattern.

    SEE: LOOSE_XML_ANSWER_LOOSE_PATTERN for when to use this.

    Args:
        text: Response text to search; None is treated as an empty string.

    Returns:
        The contents of the last <answer>...</answer> match, or an empty
        string when there is no match or when that match itself contains
        a nested "<answer>".
    """
    matches = re.findall(LOOSE_XML_ANSWER_LOOSE_PATTERN, text or "")
    if not matches:
        return ""  # Consider no answer as a failure
    last_answer = matches[-1]  # Last answer in the response
    if "<answer>" in last_answer:
        return ""  # Consider nested answer as a failure
    return last_answer