import re
from collections.abc import Callable, MutableMapping
from dataclasses import dataclass, field
from functools import wraps
from itertools import starmap
from typing import Any, ParamSpec, TypeVar, cast

from ether0.model_prompts import (
    ANSWER_END,
    ANSWER_START,
    THINK_END,
    THINK_START,
    ProblemPrompt,
    SysPrompt,
    extract_answer_loose,
)
from ether0.rewards import accuracy_reward, format_reward

P = ParamSpec("P")
R = TypeVar("R")


def wrap_reward_func(func: Callable[P, R], **wrap_kwargs: Any) -> Callable[P, R]:
    """Partially apply keyword arguments to a reward function."""

    @wraps(func)  # needed by GRPOTrainer for logging
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
        return func(*args, **wrap_kwargs, **kwargs)

    return wrapped
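
# A minimal usage sketch of wrap_reward_func (hedged: `reasoning` and `reward`
# are the keyword arguments that get_reward_funcs below passes to the imported
# `format_reward`; the variable name here is illustrative):
#
#     strict_format = wrap_reward_func(format_reward, reasoning=True, reward=1.0)
#     # Every call now forwards the pre-bound keyword arguments, i.e. it is
#     # equivalent to format_reward(..., reasoning=True, reward=1.0), while
#     # @wraps preserves format_reward's __name__ for GRPOTrainer's reward logs.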

@dataclass
class ChatArguments:
    """Arguments for making a chat conversation for SFT or RL training."""

    sys_prompt: SysPrompt | None = field(
        default=None,
        metadata={
            "help": (
                "If provided, use this system prompt. If not provided, the chat"
                " template may inject one."
            )
        },
    )
    problem_prompt: ProblemPrompt = field(
        default=ProblemPrompt.NONE,
        metadata={
            "help": (
                "Prompt to put before the problem in the first user message,"
                " relevant for both RL and SFT. Make sure this matches between SFT"
                " and RL: if the SFT'd model wasn't passed this during SFT, don't"
                " pass it to RL."
            )
        },
    )
    reasoning: bool = field(
        default=True,
        metadata={
            "help": (
                "If True (default), it is assumed that the model's response contains"
                f" reasoning enclosed in `{THINK_START}` and `{THINK_END}`."
            )
        },
    )

    def make_rl_conversation(
        self, row: MutableMapping[str, str | list[str]]
    ) -> dict[str, list[dict] | list[list[dict]]]:
        """Format a dataset row into a chat-like conversation structure.

        This will add a `prompt` key to the dataset. Unlike make_sft_conversation,
        the answer will not be included.
        """
        if not self.sys_prompt:
            msgs: list[dict] = []
        else:
            msgs = [{
                "role": "system",
                "content": SysPrompt(self.sys_prompt).get_sys_prompt(),
            }]
        problem_prompt = ProblemPrompt(self.problem_prompt).get_prompt()
        if problem_prompt:
            problem_prompt += "\n\n"

        def add_user(problem: str) -> list[dict]:
            return [*msgs, {"role": "user", "content": problem_prompt + problem}]

        if isinstance(row["problem"], str):  # Single row
            all_msgs: list[dict] | list[list[dict]] = add_user(row["problem"])
        else:  # Batched rows
            all_msgs = [add_user(p) for p in row["problem"]]
        return {"prompt": all_msgs}

    def make_sft_conversation(
        self, row: MutableMapping[str, str | list[str]]
    ) -> dict[str, list[dict] | list[list[dict]]]:
        """Format a dataset row into a chat-like conversation structure.

        This will add a `messages` key to the dataset.
        """
        if (
            self.reasoning
            and ProblemPrompt(self.problem_prompt) == ProblemPrompt.ANSWER
        ):
            raise ValueError(
                "It does not make sense to include reasoning in the SFT traces"
                " but then only prompt about answer XML (without thoughts)."
            )

        def add_assistant(
            raw_answer: str, thought: str, prior_msgs: list[dict]
        ) -> list[dict]:
            if re.search(r"</answer>", raw_answer):
                # Remove prelude and postlude plus XML tags, because an
                # OpenRouter-hosted DeepSeek R1 can give an answer with a prelude
                # and XML tags, but our training expects just an answer:
                # > The reaction involves sodium borohydride ([BH4-].[Na+]), <redacted>.
                # > Under these conditions, <redacted>.
                # > <answer>N1(CCOCC1)C1=CC=C(C(O))C=C1</answer>
                answer = extract_answer_loose(raw_answer)
                if not answer:
                    raise ValueError(
                        "Failed to extract just the answer from the answer"
                        f" {raw_answer!r}."
                    )
            else:
                answer = raw_answer
            return [
                *prior_msgs,
                {
                    "role": "assistant",
                    "content": (
                        (f"{THINK_START}{thought}{THINK_END}" if self.reasoning else "")
                        + f"{ANSWER_START}{answer}{ANSWER_END}"
                    ),
                },
            ]

        # The first part will be the same as the RL conversation
        msgs = self.make_rl_conversation(row)["prompt"]
        # Now add the answer, with optional thinking
        if isinstance(row["problem"], str):  # Single row
            all_msgs: list[dict] | list[list[dict]] = add_assistant(
                cast(str, row["answer"]),
                cast(str, row["thought"]),
                cast(list[dict], msgs),
            )
        else:  # Batched rows
            all_msgs = list(
                starmap(
                    add_assistant, zip(row["answer"], row["thought"], msgs, strict=True)
                )
            )
        return {"messages": all_msgs}

    def get_reward_funcs(
        self,
        format_reward_value: float = 1.0,
        soft: bool = False,
        test: bool = False,
        good_molecule_bonus: float = 0.0,
    ) -> list[Callable]:
        """Build the format and accuracy reward functions for RL training."""
        return [
            wrap_reward_func(
                format_reward,
                reasoning=self.reasoning,
                reward=format_reward_value,
            ),
            wrap_reward_func(
                accuracy_reward,
                reasoning=self.reasoning,
                soft=soft,
                test=test,
                good_molecule_bonus=good_molecule_bonus,
            ),
        ]
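
# A hedged sketch of wiring these rewards into TRL's GRPOTrainer (the model
# name and dataset are placeholders; other trainer configuration is elided):
#
#     from trl import GRPOTrainer
#
#     chat_args = ChatArguments()
#     trainer = GRPOTrainer(
#         model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model name
#         reward_funcs=chat_args.get_reward_funcs(soft=True),
#         train_dataset=train_dataset.map(chat_args.make_rl_conversation),
#     )
#     trainer.train()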