File size: 5,144 Bytes
4c346eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""Prompts and utilities used for training the ether0 model."""

import re
from enum import Enum, StrEnum
from typing import assert_never

# Special tokens delimiting the reasoning ("think") and answer sections of a completion.
THINK_START, THINK_END = "<|think_start|>", "<|think_end|>"
ANSWER_START, ANSWER_END = "<|answer_start|>", "<|answer_end|>"


# Strict split patterns, keyed by whether a reasoning section is expected:
#   True  -> think block, any text, then answer block; three capture groups
#            (thought, text between the blocks, answer).
#   False -> answer block only; one capture group holding a non-empty answer.
# Both are anchored with ^ and $, and allow at most one whitespace character
# before the opening tag. Strictness is fine for ether0 models because SFT or RL
# can train them into compliance with the format.
STRICT_XML_ANSWER_SPLIT_PATTERNS: dict[bool, re.Pattern] = {
    True: re.compile(
        rf"^\s?{re.escape(THINK_START)}\s*([\s\S]*?)\s*{re.escape(THINK_END)}"
        rf"([\s\S]*?)"
        rf"{re.escape(ANSWER_START)}\s*([\s\S]*?)\s*{re.escape(ANSWER_END)}$"
    ),
    False: re.compile(
        rf"^\s?{re.escape(ANSWER_START)}"
        rf"\s*(\S[\s\S]*?)\s*"
        rf"{re.escape(ANSWER_END)}$"
    ),
}
# Use loose regex for other models because:
# 1. <think> may be out-of-distribution from the model's training data,
#    so requiring thoughts may degrade performance.
# 2. We allow baseline models to add extra whitespace and/or preceding or trailing text
#    around answer XML, again to maximize performance.
# 3. Similarly, we allow models to ramble for a bit mentioning <answer>,
#    and then we just keep the last <answer> XML.
# 4. We want to avoid prompt engineering tricks to get around the previous items.
# The negative lookahead `(?!<\/answer>)` keeps an empty `<answer></answer>` (e.g. the
# model echoing LOOSE_XML_ANSWER_USER_PROMPT) from matching. Without it, `\S` would
# consume the `<` of that `</answer>` and the capture would run to the *next*
# `</answer>`, swallowing the real answer's opening tag so it looks like a nested answer.
LOOSE_XML_ANSWER_LOOSE_PATTERN = r"<answer>\s*((?!<\/answer>)\S[\s\S]*?)\s*<\/answer>"


class XMLAnswerPrompts(StrEnum):
    """Prompts instructing the assistant to use the special-token XML answer format.

    Each member's value is the prompt text itself. ``pattern`` gives the strict
    regex used to split completions written in that member's format.
    """

    REASONING_ANSWER = (
        "A conversation between User and Assistant."
        " The user asks a question, and the Assistant solves it."
        " The assistant first thinks about the reasoning process"
        " in the mind and then provides the user with the answer."
        " The reasoning process and answer are enclosed within"
        f" {THINK_START} {THINK_END} and {ANSWER_START} {ANSWER_END} tags,"
        " respectively, i.e.,"
        f" {THINK_START} reasoning process here {THINK_END}"
        f"{ANSWER_START} answer here {ANSWER_END}"
    )
    ANSWER_ONLY = (
        "A conversation between User and Assistant."
        " The user asks a question, and the Assistant solves it."
        " The assistant encloses its answer within"
        f" {ANSWER_START} {ANSWER_END} tags, i.e.,"
        f" {ANSWER_START} answer here {ANSWER_END}"
    )

    @property
    def pattern(self) -> re.Pattern:
        """Return the strict split pattern matching this prompt's format.

        REASONING_ANSWER maps to the reasoning+answer pattern; every other
        member maps to the answer-only pattern.
        """
        return STRICT_XML_ANSWER_SPLIT_PATTERNS[
            self == XMLAnswerPrompts.REASONING_ANSWER
        ]


class SysPrompt(Enum):  # Use Enum over StrEnum for trl.TrlParser compatibility
    """Possible system prompts for making a conversation to train upon."""

    SCIENTIFIC_AI = "scientific_ai"

    def get_sys_prompt(self) -> str:
        match self:
            case SysPrompt.SCIENTIFIC_AI:
                return "You are a scientific reasoning AI assistant."
            case _:
                assert_never(self)


class ProblemPrompt(Enum):  # Use Enum over StrEnum for trl.TrlParser compatibility
    """Possible user prompts for making a conversation to train upon."""

    NONE = "none"
    THINK_ANSWER = "think_answer"
    ANSWER = "answer"

    def get_prompt(self) -> str:
        match self:
            case ProblemPrompt.NONE:
                return ""
            case ProblemPrompt.THINK_ANSWER:
                return XMLAnswerPrompts.REASONING_ANSWER.value
            case ProblemPrompt.ANSWER:
                return XMLAnswerPrompts.ANSWER_ONLY.value
            case _:
                assert_never(self)


def extract_thought_answer_strict(
    text: str, reasoning: bool
) -> tuple[str | None, str | None]:
    """
    Split a completion into (thought, answer) using the strict XML pattern.

    Args:
        text: Model completion to parse.
        reasoning: If True, expect a think block followed by an answer block;
            if False, expect only an answer block.

    Returns:
        ``(thought, answer)`` on success. ``thought`` is always None when
        ``reasoning`` is False, and ``answer`` becomes None if it is empty.
        ``(None, None)`` when the pattern does not match, text is left over
        after the match, or a nested think/answer opening tag is found.
    """
    # maxsplit=1: at most one match is split out of the text
    pieces = STRICT_XML_ANSWER_SPLIT_PATTERNS[reasoning].split(text, maxsplit=1)
    # No match leaves just [text]; a match yields [prefix, *groups, suffix]
    if len(pieces) < 2:
        return None, None  # No answer counts as a failure
    groups, trailing = pieces[1:-1], pieces[-1]
    if reasoning:
        thought, between, answer = groups
    else:
        thought = between = None
        (answer,) = groups
    has_nested_tag = (
        THINK_START in (thought or "")
        or THINK_START in (between or "")
        or ANSWER_START in answer
    )
    if has_nested_tag or trailing:
        return None, None  # Nested tags or leftover text count as a failure
    return thought, answer or None


# Instruction asking the model to put its final SMILES answer inside
# <answer></answer> tags, the format LOOSE_XML_ANSWER_LOOSE_PATTERN parses.
LOOSE_XML_ANSWER_USER_PROMPT = (
    "When answering, be sure to place the final answer"
    " as SMILES notation into XML tags <answer></answer>."
    " An example is <answer>CCO</answer>."
)


def extract_answer_loose(text: str | None) -> str:
    """
    Extract the answer from text using a loose XML pattern.

    Only the last ``<answer>...</answer>`` match is used, so a response may
    mention the tags earlier. Whitespace just inside the tags is excluded by the
    pattern.

    SEE: LOOSE_XML_ANSWER_LOOSE_PATTERN for when to use this.

    Args:
        text: Model completion to parse; None is treated as an empty string.

    Returns:
        The contents of the last answer. Returns an empty string (a failure)
        when there is no answer, or when the last answer contains a nested
        ``<answer>`` tag.
    """
    matches = re.findall(LOOSE_XML_ANSWER_LOOSE_PATTERN, text or "")
    if not matches:
        return ""  # Consider no answer as a failure
    last_answer = matches[-1]  # Last answer in the response
    if "<answer>" in last_answer:
        return ""  # Consider nested answer as a failure
    return last_answer