jonahkall committed (verified)
Commit 4c346eb · Parent(s): d02cab9

Upload 51 files

Files changed (50)
  1. .gitattributes +3 -0
  2. CITATION.cff +66 -0
  3. LICENSE +201 -0
  4. README.md +207 -16
  5. app.py +78 -0
  6. docs/adding_tokens.ipynb +129 -0
  7. docs/assets/ether0_logo.svg +13 -0
  8. docs/assets/training_info.png +3 -0
  9. docs/assets/valid_molecule.svg +211 -0
  10. docs/updated_mistral_chat_template.jinja +22 -0
  11. gradio_test.py +12 -0
  12. packages/remotes/.DS_Store +0 -0
  13. packages/remotes/.gitignore +2 -0
  14. packages/remotes/LICENSE +201 -0
  15. packages/remotes/README.md +38 -0
  16. packages/remotes/pyproject.toml +59 -0
  17. packages/remotes/src/ether0.remotes.egg-info/PKG-INFO +274 -0
  18. packages/remotes/src/ether0.remotes.egg-info/SOURCES.txt +14 -0
  19. packages/remotes/src/ether0.remotes.egg-info/dependency_links.txt +1 -0
  20. packages/remotes/src/ether0.remotes.egg-info/entry_points.txt +2 -0
  21. packages/remotes/src/ether0.remotes.egg-info/requires.txt +16 -0
  22. packages/remotes/src/ether0.remotes.egg-info/top_level.txt +1 -0
  23. packages/remotes/src/ether0/server.py +316 -0
  24. packages/remotes/tests/conftest.py +29 -0
  25. packages/remotes/tests/test_clients.py +91 -0
  26. packages/remotes/tests/test_rewards.py +108 -0
  27. pyproject.toml +546 -0
  28. src/.DS_Store +0 -0
  29. src/ether0.egg-info/PKG-INFO +467 -0
  30. src/ether0.egg-info/SOURCES.txt +46 -0
  31. src/ether0.egg-info/dependency_links.txt +1 -0
  32. src/ether0.egg-info/requires.txt +40 -0
  33. src/ether0.egg-info/top_level.txt +1 -0
  34. src/ether0/chat.py +177 -0
  35. src/ether0/clients.py +163 -0
  36. src/ether0/data.py +225 -0
  37. src/ether0/fingerprints.bloom +3 -0
  38. src/ether0/model_prompts.py +142 -0
  39. src/ether0/models.py +173 -0
  40. src/ether0/problem_prompts.py +196 -0
  41. src/ether0/py.typed +0 -0
  42. src/ether0/rewards.py +753 -0
  43. src/ether0/rings.bloom +3 -0
  44. src/ether0/utils.py +108 -0
  45. tests/conftest.py +12 -0
  46. tests/test_chat.py +61 -0
  47. tests/test_data.py +158 -0
  48. tests/test_model_prompts.py +424 -0
  49. tests/test_models.py +81 -0
  50. tests/test_rewards.py +489 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ docs/assets/training_info.png filter=lfs diff=lfs merge=lfs -text
+ src/ether0/fingerprints.bloom filter=lfs diff=lfs merge=lfs -text
+ src/ether0/rings.bloom filter=lfs diff=lfs merge=lfs -text
CITATION.cff ADDED
@@ -0,0 +1,66 @@
+ ---
+ cff-version: 1.2.0
+ title: "Training a Scientific Reasoning Model for Chemistry"
+ message: >-
+   If you use this software, please cite it using the
+   metadata from this file.
+ authors:
+   - given-names: Siddharth M.
+     family-names: Narayanan
+   - given-names: James D.
+     family-names: Braza
+   - given-names: Ryan-Rhys
+     family-names: Griffiths
+   - given-names: Albert
+     family-names: Bou
+   - given-names: Geemi P.
+     family-names: Wellawatte
+   - given-names: Mayk
+     family-names: Caldas Ramos
+   - given-names: Ludovico
+     family-names: Mitchener
+   - given-names: Samuel G.
+     family-names: Rodriques
+   - given-names: Andrew D.
+     family-names: White
+ identifiers:
+   - type: doi
+     value: 10.48550/arXiv.2506.17238
+     description: ArXiv DOI
+   - type: url
+     value: https://arxiv.org/abs/2506.17238
+     description: ArXiv abstract
+ repository-code: https://github.com/Future-House/ether0
+ keywords:
+   - Artificial Intelligence
+   - Chemistry
+   - Computation and Language
+   - Machine Learning
+   - Reasoning Model
+ license: Apache-2.0
+ preferred-citation:
+   authors:
+     - given-names: Siddharth M.
+       family-names: Narayanan
+     - given-names: James D.
+       family-names: Braza
+     - given-names: Ryan-Rhys
+       family-names: Griffiths
+     - given-names: Albert
+       family-names: Bou
+     - given-names: Geemi P.
+       family-names: Wellawatte
+     - given-names: Mayk
+       family-names: Caldas Ramos
+     - given-names: Ludovico
+       family-names: Mitchener
+     - given-names: Samuel G.
+       family-names: Rodriques
+     - given-names: Andrew D.
+       family-names: White
+   date-published: 2025-06-04
+   doi: 10.48550/arXiv.2506.17238
+   journal: preprint
+   title: "Training a Scientific Reasoning Model for Chemistry"
+   type: article
+   url: https://arxiv.org/abs/2506.17238
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2025 FutureHouse
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,16 +1,207 @@
- ---
- title: Ether0 Inference
- emoji: 💬
- colorFrom: yellow
- colorTo: purple
- sdk: gradio
- sdk_version: 5.42.0
- app_file: app.py
- pinned: false
- hf_oauth: true
- hf_oauth_scopes:
- - inference-api
- license: mit
- ---
-
- An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
+ # ether0 Reward Model
+
+ [![GitHub](https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white)](https://github.com/Future-House/ether0)
+ [![arXiv](https://img.shields.io/badge/arXiv-2506.17238-b31b1b.svg)](https://arxiv.org/abs/2506.17238)
+ [![Project Status: Active](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)
+ ![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)
+
+ [![Tests](https://github.com/Future-House/ether0/actions/workflows/lint-test.yaml/badge.svg)](https://github.com/Future-House/ether0/actions)
+ [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
+ [![python](https://img.shields.io/badge/python-3.11+-blue?style=flat&logo=python&logoColor=white)](https://www.python.org)
+ [![Model on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-md-dark.svg)](https://huggingface.co/futurehouse/ether0)
+ [![Dataset on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/dataset-on-hf-md-dark.svg)](https://huggingface.co/datasets/futurehouse/ether0-benchmark)
+
+ ![ether0 logo](docs/assets/ether0_logo.svg)
+
+ _ether0: a scientific reasoning model, dataset, and reward functions for chemistry._
+
+ This repo contains the reward model for evaluating ether0 and similar models,
+ along with utilities for working with the verifiable rewards in
+ [our benchmark](https://huggingface.co/datasets/futurehouse/ether0-benchmark).
+
+ ## Overview
+
+ ether0 is a reasoning language model post-trained through a loop of:
+
+ 1. Supervised fine-tuning (SFT) on long chain-of-thought reasoning traces,
+    to elicit reasoning from a base model.
+ 2. Reinforcement learning with verifiable rewards (RLVR)
+    to improve reasoning on focused task groups, each at its own pace.
+    The resulting multitask models are referred to as 'specialists'.
+ 3. Rejection sampling to filter the specialists' reasoning
+    for correctness and quality.
+ 4. SFT on the base model again to make a 'generalist' reasoning model.
+ 5. RLVR to recover any lost performance and push further in an all-task setting.
+
+ ![ether0 training info](docs/assets/training_info.png)
+
+ ### Repo Structure
+
+ This repo contains two packages:
+
+ - `ether0`: reward functions, `rdkit` data utilities,
+   dataset generation prompts and data models,
+   and language model training prompts and data models.
+ - `ether0.remotes`: server code for ether0 reward functions that involve
+   exotic packages and/or third-party models.
+
+ > [!NOTE]
+ > This repo does not contain training code,
+ > although open-source repositories like [NeMo-RL](https://github.com/NVIDIA/NeMo-RL)
+ > or [Hugging Face TRL](https://github.com/huggingface/trl)
+ > can perform the SFT and RL phases of training.
+
+ ### Open Weights
+
+ Please see our open-source weights on Hugging Face:
+ <https://huggingface.co/futurehouse/ether0>
+
+ ```python
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model = AutoModelForCausalLM.from_pretrained("futurehouse/ether0")
+ tokenizer = AutoTokenizer.from_pretrained("futurehouse/ether0")
+ ```
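A minimal generation sketch follows, continuing from the snippet above. The prompt is illustrative and the generation settings are assumptions rather than tuned recommendations; `app.py` later in this commit wraps the same flow in a Gradio UI.

```python
# Continuing from the loaded model/tokenizer above (illustrative prompt and settings)
messages = [{"role": "user", "content": "Propose a molecule that is soluble in water."}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
outputs = model.generate(input_ids, max_new_tokens=512)
# Decode only the newly generated tokens
print(tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True))
```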
+
+ ### Open Test Set
+
+ Please see our open-source benchmark (test set) on Hugging Face:
+ <https://huggingface.co/datasets/futurehouse/ether0-benchmark>
+
+ ```python
+ from datasets import load_dataset
+
+ test_ds = load_dataset("futurehouse/ether0-benchmark", split="test")
+ ```
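Each row pairs a chemistry question with a machine-verifiable solution. A quick peek at the two fields used in the Benchmark section below (a minimal sketch; only `problem` and `solution` are relied on here):

```python
row = test_ds[0]
print(row["problem"])   # natural-language chemistry question
print(row["solution"])  # reward-function spec, parsed via RewardFunctionInfo in the Benchmark section
```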
+
+ ## Usage
+
+ ### Installation
+
+ The easiest way to get started is a `pip install` from GitHub:
+
+ ```bash
+ pip install git+https://github.com/Future-House/ether0.git
+ ```
+
+ Or, if you want the full setup, clone the repo and use `uv`:
+
+ ```bash
+ git clone https://github.com/Future-House/ether0.git
+ cd ether0
+ uv sync
+ ```
+
+ ### Reward Functions
+
+ Here is a basic example of how to use the reward functions:
+
+ ```python
+ from ether0.rewards import valid_mol_eval
+
+ # Task: provide a valid completion of this molecule
+ partial_smiles = "O=C(OC1C(OC(=O)C=2C=CC=CC2)C3(O)C(C)(C)CCCC3(C)C4CC=5OC=CC5C(C)C14"
+
+ # Here are two model-proposed SMILES completions
+ invalid_completion_smiles = "CCC"
+ valid_completion_smiles = ")C=6C=CC=CC6"
+
+ # Evaluate the completions
+ assert not valid_mol_eval(invalid_completion_smiles, partial_smiles)
+ assert valid_mol_eval(valid_completion_smiles, partial_smiles)
+ ```
+
+ ### Visualization
+
+ If it helps, you can visualize the molecules:
+
+ ```python
+ from ether0.data import draw_molecule
+
+ # See the reward functions demo above for where these came from
+ partial_smiles = "O=C(OC1C(OC(=O)C=2C=CC=CC2)C3(O)C(C)(C)CCCC3(C)C4CC=5OC=CC5C(C)C14"
+ invalid_completion_smiles = "CCC"
+ valid_completion_smiles = ")C=6C=CC=CC6"
+
+ valid_mol_text = draw_molecule(partial_smiles + valid_completion_smiles)
+ with open("valid_molecule.svg", "w") as f:
+     f.write(valid_mol_text)
+ ```
+
+ The output of `draw_molecule` can also be visualized using `IPython.display`
+ (see the sketch below), or in your terminal via `chafa valid_molecule.svg`
+ ([chafa docs](https://hpjansson.org/chafa/)).
+
+ ![valid molecule](docs/assets/valid_molecule.svg)
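For instance, here is a minimal sketch of rendering the SVG inline in a Jupyter/IPython session, reusing `valid_mol_text` from the snippet above:

```python
from IPython.display import SVG, display

display(SVG(valid_mol_text))  # renders the molecule drawing inline
```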
+
+ ### Benchmark
+
+ Here is a sample baseline of
+ [`ether0-benchmark`](https://huggingface.co/datasets/futurehouse/ether0-benchmark)
+ on `gpt-4o` using [`lmi`](https://github.com/Future-House/ldp/tree/main/packages/lmi).
+ To install `lmi`, please install `ether0` with the `baselines` extra
+ (for example `uv sync --extra baselines`).
+
+ We also need to run our remote rewards server via `ether0-serve`
+ (for more information, see the [`ether0.remotes` docs](packages/remotes/README.md)):
+
+ ```bash
+ ETHER0_REMOTES_API_TOKEN=abc123 ether0-serve
+ ```
+
+ Next, start `ipython` with the relevant environment variables set:
+
+ ```bash
+ ETHER0_REMOTES_API_BASE_URL="http://127.0.0.1:8000" ETHER0_REMOTES_API_TOKEN=abc123 \
+     ipython
+ ```
+
+ Then run the following Python code:
+
+ ```python
+ import itertools
+ import statistics
+ from collections import defaultdict
+
+ from aviary.core import Message
+ from datasets import load_dataset
+ from lmi import LiteLLMModel
+ from tqdm.asyncio import tqdm_asyncio as asyncio
+
+ from ether0.data import get_problem_category
+ from ether0.model_prompts import LOOSE_XML_ANSWER_USER_PROMPT, extract_answer_loose
+ from ether0.models import RewardFunctionInfo
+ from ether0.rewards import EVAL_FUNCTIONS
+
+ # Add an LLM prompt of your making to the dataset
+ test_ds = load_dataset("futurehouse/ether0-benchmark", split="test").map(
+     lambda x: {"prompt": "\n\n".join((LOOSE_XML_ANSWER_USER_PROMPT, x["problem"]))}
+ )
+
+ # Prompt the LLM
+ model = LiteLLMModel(name="gpt-4o")
+ results = await asyncio.gather(
+     *(model.acompletion([Message(content=row["prompt"])]) for row in test_ds),
+     desc="Running evaluation",
+ )
+
+ # Compute rewards
+ per_category_rewards = defaultdict(list)
+ for row, result in zip(test_ds, results, strict=True):
+     # NOTE: you can also use `ether0.rewards.accuracy_reward`,
+     # but we decided to go a bit "lower level" for this demo
+     reward_info = RewardFunctionInfo.model_validate(row["solution"])
+     yhat = extract_answer_loose(result[0].text)
+     reward = EVAL_FUNCTIONS[reward_info.fxn_name](
+         yhat=yhat, y=reward_info.answer_info, test=True
+     )
+     per_category_rewards[get_problem_category(reward_info.problem_type)].append(reward)
+
+ for category, rewards in sorted(per_category_rewards.items()):
+     print(
+         f"In category {category!r} of {len(rewards)} questions,"
+         f" average reward was {statistics.mean(rewards):.3f}."
+     )
+ accuracy = statistics.mean(itertools.chain.from_iterable(per_category_rewards.values()))
+ print(f"Cumulative average reward across {len(test_ds)} questions was {accuracy:.3f}.")
+ ```
app.py ADDED
@@ -0,0 +1,78 @@
+ import os
+ import gradio as gr
+ import spaces
+ import torch
+ import logging
+ import time
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from transformers.utils import logging as hf_logging
+
+ logging.basicConfig(
+     filename="/tmp/app.log",
+     level=logging.DEBUG,
+     format="%(asctime)s %(levelname)s: %(message)s",
+ )
+
+ logging.info("Starting app.py logging")
+ hf_logging.set_verbosity_debug()
+ hf_logging.set_verbosity_info()
+ hf_logging.enable_default_handler()
+ hf_logging.enable_explicit_format()
+ hf_logging.add_handler(logging.FileHandler("/tmp/transformers.log"))
+
+
+ model_id = "futurehouse/ether0"
+
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.float16,
+ )
+
+ @spaces.GPU
+ def chat_fn(prompt, max_tokens=512):
+     t0 = time.time()
+     max_tokens = min(int(max_tokens), 32_000)
+
+     try:
+         messages = [{"role": "user", "content": prompt}]
+         chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+         inputs = tokenizer(chat_prompt, return_tensors="pt").to(model.device)
+         t1 = time.time()
+         logging.info(f"🧠 Tokenization complete in {t1 - t0:.2f}s")
+
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=max_tokens,
+             do_sample=True,
+             temperature=0.1,
+             pad_token_id=tokenizer.eos_token_id,
+         )
+         t2 = time.time()
+         logging.info(f"⚡️ Generation complete in {t2 - t1:.2f}s (max_tokens={max_tokens})")
+
+         generated_text = tokenizer.decode(
+             outputs[0][inputs["input_ids"].shape[1]:],
+             skip_special_tokens=True,
+         )
+         t3 = time.time()
+         logging.info(f"🔓 Decoding complete in {t3 - t2:.2f}s (output length: {len(generated_text)})")
+
+         return generated_text
+
+     except Exception:
+         logging.exception("❌ Exception during generation")
+         return "⚠️ Generation failed"
+
+ gr.Interface(
+     fn=chat_fn,
+     inputs=[
+         gr.Textbox(label="prompt"),
+         gr.Number(label="max_tokens", value=512, precision=0),
+     ],
+     outputs="text",
+     title="Ether0",
+ ).launch(ssr_mode=False)
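Once deployed as a Space, the interface above can also be called programmatically. A hedged sketch using `gradio_client` follows; the Space id below is a placeholder, not something defined in this commit:

```python
from gradio_client import Client

client = Client("your-username/ether0-inference")  # hypothetical Space id
answer = client.predict(
    "Propose a molecule that is soluble in water.",  # prompt
    512,  # max_tokens
    api_name="/predict",  # default endpoint name for a single gr.Interface
)
print(answer)
```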
docs/adding_tokens.ipynb ADDED
@@ -0,0 +1,129 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "072120f9",
6
+ "metadata": {},
7
+ "source": [
8
+ "If you would like to modify a base model to add our custom reasoning tokens,\n",
9
+ "here's how to do it.\n",
10
+ "\n",
11
+ "Firstly, please install the `add-tokens` extra via\n",
12
+ "`pip install ether0[add-tokens]` for the `transformers` package.\n",
13
+ "\n",
14
+ "Then, configure the following inputs."
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": null,
20
+ "id": "a2fb6296",
21
+ "metadata": {},
22
+ "outputs": [],
23
+ "source": [
24
+ "# Model name/revisions for Hugging Face Hub\n",
25
+ "input_model_name = \"mistralai/Mistral-Small-24B-Instruct-2501\"\n",
26
+ "input_model_revision: str | None = None\n",
27
+ "output_model_name = \"FILL ME IN\"\n",
28
+ "output_model_revision: str | None = None\n",
29
+ "output_model_is_private = True\n",
30
+ "tokenizer_only = False # Set True to only update the tokenizer\n",
31
+ "push_to_hf = False # Set True to push to Hugging Face Hub\n",
32
+ "\n",
33
+ "# Chat template file that uses the new tokens\n",
34
+ "chat_template_path = \"updated_mistral_chat_template.jinja\""
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": null,
40
+ "id": "99927d80",
41
+ "metadata": {},
42
+ "outputs": [
43
+ {
44
+ "data": {
45
+ "application/vnd.jupyter.widget-view+json": {
46
+ "model_id": "8e15d3fb5e864e1286cf94fc588e504d",
47
+ "version_major": 2,
48
+ "version_minor": 0
49
+ },
50
+ "text/plain": [
51
+ "Loading checkpoint shards: 0%| | 0/10 [00:00<?, ?it/s]"
52
+ ]
53
+ },
54
+ "metadata": {},
55
+ "output_type": "display_data"
56
+ },
57
+ {
58
+ "name": "stderr",
59
+ "output_type": "stream",
60
+ "text": [
61
+ "The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`\n",
62
+ "The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`\n"
63
+ ]
64
+ }
65
+ ],
66
+ "source": [
67
+ "from pathlib import Path\n",
68
+ "\n",
69
+ "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
70
+ "\n",
71
+ "from ether0.model_prompts import ANSWER_END, ANSWER_START, THINK_END, THINK_START\n",
72
+ "\n",
73
+ "REASONING_TOKENS_TO_ADD = [\n",
74
+ " THINK_START,\n",
75
+ " THINK_END,\n",
76
+ " ANSWER_START,\n",
77
+ " ANSWER_END,\n",
78
+ "]\n",
79
+ "\n",
80
+ "tokenizer = AutoTokenizer.from_pretrained(\n",
81
+ " input_model_name, revision=input_model_revision\n",
82
+ ")\n",
83
+ "# NOTE: reasoning tokens are normal (not special) tokens so they aren't\n",
84
+ "# removed when passing skip_special_tokens=True to a tokenizer\n",
85
+ "tokenizer.add_tokens(REASONING_TOKENS_TO_ADD)\n",
86
+ "tokenizer.chat_template = Path(chat_template_path).read_text(encoding=\"utf-8\")\n",
87
+ "if push_to_hf:\n",
88
+ " tokenizer.push_to_hub(\n",
89
+ " output_model_name,\n",
90
+ " revision=output_model_revision,\n",
91
+ " private=output_model_is_private,\n",
92
+ " )\n",
93
+ "\n",
94
+ "if not tokenizer_only:\n",
95
+ " model = AutoModelForCausalLM.from_pretrained(\n",
96
+ " input_model_name, revision=input_model_revision\n",
97
+ " )\n",
98
+ " # SEE: https://www.thonking.ai/p/what-shapes-do-matrix-multiplications\n",
99
+ " model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)\n",
100
+ " if push_to_hf:\n",
101
+ " model.push_to_hub(\n",
102
+ " output_model_name,\n",
103
+ " revision=output_model_revision,\n",
104
+ " private=output_model_is_private,\n",
105
+ " )"
106
+ ]
107
+ }
108
+ ],
109
+ "metadata": {
110
+ "kernelspec": {
111
+ "display_name": ".venv",
112
+ "language": "python",
113
+ "name": "python3"
114
+ },
115
+ "language_info": {
116
+ "codemirror_mode": {
117
+ "name": "ipython",
118
+ "version": 3
119
+ },
120
+ "file_extension": ".py",
121
+ "mimetype": "text/x-python",
122
+ "name": "python",
123
+ "nbconvert_exporter": "python",
124
+ "pygments_lexer": "ipython3"
125
+ }
126
+ },
127
+ "nbformat": 4,
128
+ "nbformat_minor": 5
129
+ }
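After resizing, it can be worth confirming that each new reasoning token maps to a single, known token id. A minimal sketch, continuing from the notebook's `tokenizer` and `REASONING_TOKENS_TO_ADD`:

```python
ids = tokenizer.convert_tokens_to_ids(REASONING_TOKENS_TO_ADD)
assert tokenizer.unk_token_id not in ids, "a reasoning token was not added"
assert len(set(ids)) == len(REASONING_TOKENS_TO_ADD), "duplicate token ids"
print(dict(zip(REASONING_TOKENS_TO_ADD, ids)))
```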
docs/assets/ether0_logo.svg ADDED
docs/assets/training_info.png ADDED

Git LFS Details

  • SHA256: 394653640102293eb6a10e4aac2c14fa58aea61077f3203211cf0226b38d84fe
  • Pointer size: 131 Bytes
  • Size of remote file: 589 kB
docs/assets/valid_molecule.svg ADDED
docs/updated_mistral_chat_template.jinja ADDED
@@ -0,0 +1,22 @@
+ {%- set default_system_message = "You are a scientific reasoning AI assistant." %}
+ {{- bos_token }}
+ {%- if messages[0]['role'] == 'system' %}
+     {%- set system_message = messages[0]['content'] %}
+     {%- set loop_messages = messages[1:] %}
+ {%- else %}
+     {%- set system_message = default_system_message %}
+     {%- set loop_messages = messages %}
+ {%- endif %}
+ {{- '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }}
+
+ {%- for message in loop_messages %}
+     {%- if message['role'] == 'user' %}
+         {{- '[INST]' + message['content'] + '[/INST]' }}
+     {%- elif message['role'] == 'system' %}
+         {{- '[SYSTEM_PROMPT]' + message['content'] + '[/SYSTEM_PROMPT]' }}
+     {%- elif message['role'] == 'assistant' %}
+         {{- message['content'] + eos_token }}
+     {%- else %}
+         {{- raise_exception("Only user, system and assistant roles are supported!") }}
+     {%- endif %}
+ {%- endfor %}
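To see how this template renders, here is a minimal sketch, assuming a tokenizer whose chat template has been set to this file (for example via the `adding_tokens.ipynb` notebook above):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("futurehouse/ether0")  # assumed to ship this template
messages = [
    {"role": "system", "content": "You are a scientific reasoning AI assistant."},
    {"role": "user", "content": "Propose a valid completion of this SMILES string: CCO"},
]
# Produces the BOS token, then [SYSTEM_PROMPT]...[/SYSTEM_PROMPT][INST]...[/INST], per the template above
print(tokenizer.apply_chat_template(messages, tokenize=False))
```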
gradio_test.py ADDED
@@ -0,0 +1,12 @@
+ import gradio as gr
+
+ def greet(name, intensity):
+     return "Hello, " + name + "!" * int(intensity)
+
+ demo = gr.Interface(
+     fn=greet,
+     inputs=["text", "slider"],
+     outputs=["text"],
+ )
+
+ demo.launch()
packages/remotes/.DS_Store ADDED
Binary file (6.15 kB)
 
packages/remotes/.gitignore ADDED
@@ -0,0 +1,2 @@
+ # Ignore any downloaded PyTorch model (e.g. a MolTrans model)
+ src/ether0/*.pt
packages/remotes/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2025 FutureHouse
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
packages/remotes/README.md ADDED
@@ -0,0 +1,38 @@
+ # ether0.remotes
+
+ Server code for ether0 reward functions.
+
+ ## Molecular Transformer (MolTrans) Model
+
+ To run the `/translate` endpoint, you need a pre-trained MolTrans PyTorch model.
+ It can be acquired from Future House's Google Drive via the following command:
+
+ ```bash
+ curl --location --output src/ether0/USPTO480k_model_step_400000.pt \
+     "https://drive.usercontent.google.com/download?id=1Rjd3wXg2oLeCpNUofFRvVvQoOcgWd6vf&export=download&confirm=t"
+ ```
+
+ Or, more manually:
+
+ 1. Go to [this notebook][1].
+ 2. Download `USPTO480k_model_step_400000.pt` from the Google Drive file
+    referenced by the notebook's `trained_model_url` variable:
+    <https://drive.google.com/uc?id=1ywJCJHunoPTB5wr6KdZ8aLv7tMFMBHNy>
+ 3. Set the environment variable `ETHER0_REMOTES_MOLTRANS_MODEL_PATH`
+    to the downloaded PyTorch model's location,
+    or place the model in the default checked location, the `ether0` source folder (`src/ether0`).
+
+ [1]: https://github.com/schwallergroup/ai4chem_course/blob/main/notebooks/07%20-%20Reaction%20Prediction/template_free.ipynb
+
+ ## Serving
+
+ To run the server:
+
+ 1. `pip install` with the `serve` extra: `pip install ether0.remotes[serve]`
+ 2. Then run the following command:
+
+ ```bash
+ ETHER0_REMOTES_API_TOKEN="abc123" \
+     ETHER0_REMOTES_MOLTRANS_MODEL_PATH="/path/to/USPTO480k_model_step_400000.pt" \
+     ether0-serve
+ ```
packages/remotes/pyproject.toml ADDED
@@ -0,0 +1,59 @@
+ [build-system]
+ build-backend = "setuptools.build_meta"
+ requires = ["setuptools>=64", "setuptools_scm>=8"]
+
+ [dependency-groups]
+ dev = ["ether0.remotes[dev]"]
+
+ [project]
+ authors = [
+     {email = "[email protected]", name = "FutureHouse technical staff"},
+ ]
+ classifiers = [
+     "Intended Audience :: Developers",
+     "License :: OSI Approved :: Apache Software License",
+     "Operating System :: POSIX",
+     "Programming Language :: Python :: 3 :: Only",
+     "Programming Language :: Python :: 3.11",
+     "Programming Language :: Python :: 3.12",
+     "Programming Language :: Python :: 3.13",
+     "Programming Language :: Python",
+     "Topic :: Scientific/Engineering :: Artificial Intelligence",
+     "Topic :: Scientific/Engineering :: Chemistry",
+ ]
+ dependencies = [
+     "OpenNMT-py==2.3.0",  # Match MolecularTransformer version
+     "fastapi",
+     "molbloom>=2.3.4",  # For nested zinc20 catalog
+     "molsol>=0.0.3",
+     "numpy>=1.20",  # Pin for numpy typing
+     "pydantic>=2",  # Pin to keep recent
+     "rdkit",
+     "torch<2.6",  # Downpin for weights_only introduction breaking model_builder's load_test_model
+ ]
+ description = "Server code for remotes of ether0"
+ dynamic = ["version"]
+ license = {file = "LICENSE"}
+ name = "ether0.remotes"
+ readme = "README.md"
+ requires-python = ">=3.11"
+
+ [project.optional-dependencies]
+ dev = [
+     "ether0",
+     "ether0.remotes[serve]",
+     "tensorboard>=2.19",  # Indirect dependency we pin to keep recent
+ ]
+ serve = ["uvicorn"]
+
+ [project.scripts]
+ ether0-serve = "ether0.server:main"
+
+ [tool.ruff]
+ extend = "../../pyproject.toml"
+
+ [tool.setuptools.packages.find]
+ where = ["src"]
+
+ [tool.setuptools_scm]
+ root = "../.."
packages/remotes/src/ether0.remotes.egg-info/PKG-INFO ADDED
@@ -0,0 +1,274 @@
1
+ Metadata-Version: 2.4
2
+ Name: ether0.remotes
3
+ Version: 0.1.dev19+g97db23388
4
+ Summary: Server code for remotes of ether0
5
+ Author-email: FutureHouse technical staff <[email protected]>
6
+ License: Apache License
7
+ Version 2.0, January 2004
8
+ http://www.apache.org/licenses/
9
+
10
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
11
+
12
+ 1. Definitions.
13
+
14
+ "License" shall mean the terms and conditions for use, reproduction,
15
+ and distribution as defined by Sections 1 through 9 of this document.
16
+
17
+ "Licensor" shall mean the copyright owner or entity authorized by
18
+ the copyright owner that is granting the License.
19
+
20
+ "Legal Entity" shall mean the union of the acting entity and all
21
+ other entities that control, are controlled by, or are under common
22
+ control with that entity. For the purposes of this definition,
23
+ "control" means (i) the power, direct or indirect, to cause the
24
+ direction or management of such entity, whether by contract or
25
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
26
+ outstanding shares, or (iii) beneficial ownership of such entity.
27
+
28
+ "You" (or "Your") shall mean an individual or Legal Entity
29
+ exercising permissions granted by this License.
30
+
31
+ "Source" form shall mean the preferred form for making modifications,
32
+ including but not limited to software source code, documentation
33
+ source, and configuration files.
34
+
35
+ "Object" form shall mean any form resulting from mechanical
36
+ transformation or translation of a Source form, including but
37
+ not limited to compiled object code, generated documentation,
38
+ and conversions to other media types.
39
+
40
+ "Work" shall mean the work of authorship, whether in Source or
41
+ Object form, made available under the License, as indicated by a
42
+ copyright notice that is included in or attached to the work
43
+ (an example is provided in the Appendix below).
44
+
45
+ "Derivative Works" shall mean any work, whether in Source or Object
46
+ form, that is based on (or derived from) the Work and for which the
47
+ editorial revisions, annotations, elaborations, or other modifications
48
+ represent, as a whole, an original work of authorship. For the purposes
49
+ of this License, Derivative Works shall not include works that remain
50
+ separable from, or merely link (or bind by name) to the interfaces of,
51
+ the Work and Derivative Works thereof.
52
+
53
+ "Contribution" shall mean any work of authorship, including
54
+ the original version of the Work and any modifications or additions
55
+ to that Work or Derivative Works thereof, that is intentionally
56
+ submitted to Licensor for inclusion in the Work by the copyright owner
57
+ or by an individual or Legal Entity authorized to submit on behalf of
58
+ the copyright owner. For the purposes of this definition, "submitted"
59
+ means any form of electronic, verbal, or written communication sent
60
+ to the Licensor or its representatives, including but not limited to
61
+ communication on electronic mailing lists, source code control systems,
62
+ and issue tracking systems that are managed by, or on behalf of, the
63
+ Licensor for the purpose of discussing and improving the Work, but
64
+ excluding communication that is conspicuously marked or otherwise
65
+ designated in writing by the copyright owner as "Not a Contribution."
66
+
67
+ "Contributor" shall mean Licensor and any individual or Legal Entity
68
+ on behalf of whom a Contribution has been received by Licensor and
69
+ subsequently incorporated within the Work.
70
+
71
+ 2. Grant of Copyright License. Subject to the terms and conditions of
72
+ this License, each Contributor hereby grants to You a perpetual,
73
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
74
+ copyright license to reproduce, prepare Derivative Works of,
75
+ publicly display, publicly perform, sublicense, and distribute the
76
+ Work and such Derivative Works in Source or Object form.
77
+
78
+ 3. Grant of Patent License. Subject to the terms and conditions of
79
+ this License, each Contributor hereby grants to You a perpetual,
80
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
81
+ (except as stated in this section) patent license to make, have made,
82
+ use, offer to sell, sell, import, and otherwise transfer the Work,
83
+ where such license applies only to those patent claims licensable
84
+ by such Contributor that are necessarily infringed by their
85
+ Contribution(s) alone or by combination of their Contribution(s)
86
+ with the Work to which such Contribution(s) was submitted. If You
87
+ institute patent litigation against any entity (including a
88
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
89
+ or a Contribution incorporated within the Work constitutes direct
90
+ or contributory patent infringement, then any patent licenses
91
+ granted to You under this License for that Work shall terminate
92
+ as of the date such litigation is filed.
93
+
94
+ 4. Redistribution. You may reproduce and distribute copies of the
95
+ Work or Derivative Works thereof in any medium, with or without
96
+ modifications, and in Source or Object form, provided that You
97
+ meet the following conditions:
98
+
99
+ (a) You must give any other recipients of the Work or
100
+ Derivative Works a copy of this License; and
101
+
102
+ (b) You must cause any modified files to carry prominent notices
103
+ stating that You changed the files; and
104
+
105
+ (c) You must retain, in the Source form of any Derivative Works
106
+ that You distribute, all copyright, patent, trademark, and
107
+ attribution notices from the Source form of the Work,
108
+ excluding those notices that do not pertain to any part of
109
+ the Derivative Works; and
110
+
111
+ (d) If the Work includes a "NOTICE" text file as part of its
112
+ distribution, then any Derivative Works that You distribute must
113
+ include a readable copy of the attribution notices contained
114
+ within such NOTICE file, excluding those notices that do not
115
+ pertain to any part of the Derivative Works, in at least one
116
+ of the following places: within a NOTICE text file distributed
117
+ as part of the Derivative Works; within the Source form or
118
+ documentation, if provided along with the Derivative Works; or,
119
+ within a display generated by the Derivative Works, if and
120
+ wherever such third-party notices normally appear. The contents
121
+ of the NOTICE file are for informational purposes only and
122
+ do not modify the License. You may add Your own attribution
123
+ notices within Derivative Works that You distribute, alongside
124
+ or as an addendum to the NOTICE text from the Work, provided
125
+ that such additional attribution notices cannot be construed
126
+ as modifying the License.
127
+
128
+ You may add Your own copyright statement to Your modifications and
129
+ may provide additional or different license terms and conditions
130
+ for use, reproduction, or distribution of Your modifications, or
131
+ for any such Derivative Works as a whole, provided Your use,
132
+ reproduction, and distribution of the Work otherwise complies with
133
+ the conditions stated in this License.
134
+
135
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
136
+ any Contribution intentionally submitted for inclusion in the Work
137
+ by You to the Licensor shall be under the terms and conditions of
138
+ this License, without any additional terms or conditions.
139
+ Notwithstanding the above, nothing herein shall supersede or modify
140
+ the terms of any separate license agreement you may have executed
141
+ with Licensor regarding such Contributions.
142
+
143
+ 6. Trademarks. This License does not grant permission to use the trade
144
+ names, trademarks, service marks, or product names of the Licensor,
145
+ except as required for reasonable and customary use in describing the
146
+ origin of the Work and reproducing the content of the NOTICE file.
147
+
148
+ 7. Disclaimer of Warranty. Unless required by applicable law or
149
+ agreed to in writing, Licensor provides the Work (and each
150
+ Contributor provides its Contributions) on an "AS IS" BASIS,
151
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
152
+ implied, including, without limitation, any warranties or conditions
153
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
154
+ PARTICULAR PURPOSE. You are solely responsible for determining the
155
+ appropriateness of using or redistributing the Work and assume any
156
+ risks associated with Your exercise of permissions under this License.
157
+
158
+ 8. Limitation of Liability. In no event and under no legal theory,
159
+ whether in tort (including negligence), contract, or otherwise,
160
+ unless required by applicable law (such as deliberate and grossly
161
+ negligent acts) or agreed to in writing, shall any Contributor be
162
+ liable to You for damages, including any direct, indirect, special,
163
+ incidental, or consequential damages of any character arising as a
164
+ result of this License or out of the use or inability to use the
165
+ Work (including but not limited to damages for loss of goodwill,
166
+ work stoppage, computer failure or malfunction, or any and all
167
+ other commercial damages or losses), even if such Contributor
168
+ has been advised of the possibility of such damages.
169
+
170
+ 9. Accepting Warranty or Additional Liability. While redistributing
171
+ the Work or Derivative Works thereof, You may choose to offer,
172
+ and charge a fee for, acceptance of support, warranty, indemnity,
173
+ or other liability obligations and/or rights consistent with this
174
+ License. However, in accepting such obligations, You may act only
175
+ on Your own behalf and on Your sole responsibility, not on behalf
176
+ of any other Contributor, and only if You agree to indemnify,
177
+ defend, and hold each Contributor harmless for any liability
178
+ incurred by, or claims asserted against, such Contributor by reason
179
+ of your accepting any such warranty or additional liability.
180
+
181
+ END OF TERMS AND CONDITIONS
182
+
183
+ APPENDIX: How to apply the Apache License to your work.
184
+
185
+ To apply the Apache License to your work, attach the following
186
+ boilerplate notice, with the fields enclosed by brackets "[]"
187
+ replaced with your own identifying information. (Don't include
188
+ the brackets!) The text should be enclosed in the appropriate
189
+ comment syntax for the file format. We also recommend that a
190
+ file or class name and description of purpose be included on the
191
+ same "printed page" as the copyright notice for easier
192
+ identification within third-party archives.
193
+
194
+ Copyright 2025 FutureHouse
195
+
196
+ Licensed under the Apache License, Version 2.0 (the "License");
197
+ you may not use this file except in compliance with the License.
198
+ You may obtain a copy of the License at
199
+
200
+ http://www.apache.org/licenses/LICENSE-2.0
201
+
202
+ Unless required by applicable law or agreed to in writing, software
203
+ distributed under the License is distributed on an "AS IS" BASIS,
204
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
205
+ See the License for the specific language governing permissions and
206
+ limitations under the License.
207
+
208
+ Classifier: Intended Audience :: Developers
209
+ Classifier: License :: OSI Approved :: Apache Software License
210
+ Classifier: Operating System :: POSIX
211
+ Classifier: Programming Language :: Python :: 3 :: Only
212
+ Classifier: Programming Language :: Python :: 3.11
213
+ Classifier: Programming Language :: Python :: 3.12
214
+ Classifier: Programming Language :: Python :: 3.13
215
+ Classifier: Programming Language :: Python
216
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
217
+ Classifier: Topic :: Scientific/Engineering :: Chemistry
218
+ Requires-Python: >=3.11
219
+ Description-Content-Type: text/markdown
220
+ License-File: LICENSE
221
+ Requires-Dist: OpenNMT-py==2.3.0
222
+ Requires-Dist: fastapi
223
+ Requires-Dist: molbloom>=2.3.4
224
+ Requires-Dist: molsol>=0.0.3
225
+ Requires-Dist: numpy>=1.20
226
+ Requires-Dist: pydantic>=2
227
+ Requires-Dist: rdkit
228
+ Requires-Dist: torch<2.6
229
+ Provides-Extra: dev
230
+ Requires-Dist: ether0; extra == "dev"
231
+ Requires-Dist: ether0.remotes[serve]; extra == "dev"
232
+ Requires-Dist: tensorboard>=2.19; extra == "dev"
233
+ Provides-Extra: serve
234
+ Requires-Dist: uvicorn; extra == "serve"
235
+ Dynamic: license-file
236
+
237
+ # ether0.remotes
238
+
239
+ Server code for ether0 reward functions.
240
+
241
+ ## Molecular Transformer (MolTrans) Model
242
+
243
+ To run the `/translate` endpoint, you need a pre-trained MolTrans PyTorch model.
244
+ This can be downloaded from FutureHouse's Google Drive via the following command:
245
+
246
+ ```bash
247
+ curl --location --output src/ether0/USPTO480k_model_step_400000.pt \
248
+ "https://drive.usercontent.google.com/download?id=1Rjd3wXg2oLeCpNUofFRvVvQoOcgWd6vf&export=download&confirm=t"
249
+ ```
250
+
251
+ Or more manually:
252
+
253
+ 1. Go to [this notebook][1]
254
+ 2. Download the `USPTO480k_model_step_400000.pt`
255
+    from the Google Drive file referenced by the `trained_model_url` variable:
256
+ <https://drive.google.com/uc?id=1ywJCJHunoPTB5wr6KdZ8aLv7tMFMBHNy>
257
+ 3. Set the environment variable `ETHER0_REMOTES_MOLTRANS_MODEL_PATH`
258
+ to the downloaded PyTorch model's location,
259
+    or place the model in the `ether0` source folder (`src/ether0`), the default location checked
+    (a minimal sketch of the environment-variable option follows below).
260
+
261
+ [1]: https://github.com/schwallergroup/ai4chem_course/blob/main/notebooks/07%20-%20Reaction%20Prediction/template_free.ipynb
262
+
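+ For reference, here is a minimal sketch of step 3, assuming the checkpoint was downloaded
+ to `~/models` (an illustrative path, adjust it to wherever you saved the file):
+
+ ```bash
+ # Point the server at the downloaded checkpoint (illustrative path).
+ export ETHER0_REMOTES_MOLTRANS_MODEL_PATH="$HOME/models/USPTO480k_model_step_400000.pt"
+
+ # Sanity-check that the file exists before starting the server.
+ test -f "$ETHER0_REMOTES_MOLTRANS_MODEL_PATH" && echo "model found"
+ ```
+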
263
+ ## Serving
264
+
265
+ To run the server:
266
+
267
+ 1. `pip install` with the `serve` extra: `pip install ether0.remotes[serve]`
268
+ 2. Then run the following command:
269
+
270
+ ```bash
271
+ ETHER0_REMOTES_API_TOKEN="abc123" \
272
+ ETHER0_REMOTES_MOLTRANS_MODEL_PATH="/path/to/USPTO480k_model_step_400000.pt" \
273
+ ether0-serve
274
+ ```
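+
+ Once the server is up (uvicorn defaults to `127.0.0.1:8000`), requests must carry the bearer
+ token configured above. A minimal usage sketch, reusing the `abc123` token from the example
+ command; the `/translate` call additionally requires the MolTrans model to be configured:
+
+ ```bash
+ # Check whether aspirin is purchasable.
+ curl -X POST "http://127.0.0.1:8000/is_purchasable" \
+   -H "Authorization: Bearer abc123" \
+   -H "Content-Type: application/json" \
+   -d '{"smiles": "CC(=O)OC1=CC=CC=C1C(=O)O"}'
+
+ # Predict aqueous solubility (mean plus aleatoric and epistemic uncertainty).
+ curl -X POST "http://127.0.0.1:8000/compute_solubility" \
+   -H "Authorization: Bearer abc123" \
+   -H "Content-Type: application/json" \
+   -d '{"smiles": "CCO"}'
+
+ # Predict the product of a reaction SMILES (two '>' characters, no spaces).
+ curl -X POST "http://127.0.0.1:8000/translate" \
+   -H "Authorization: Bearer abc123" \
+   -H "Content-Type: application/json" \
+   -d '{"reaction": "CC=O.O=C1CCC1Cl>[Mg].CCOCC>"}'
+ ```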
packages/remotes/src/ether0.remotes.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .gitignore
2
+ LICENSE
3
+ README.md
4
+ pyproject.toml
5
+ src/ether0/server.py
6
+ src/ether0.remotes.egg-info/PKG-INFO
7
+ src/ether0.remotes.egg-info/SOURCES.txt
8
+ src/ether0.remotes.egg-info/dependency_links.txt
9
+ src/ether0.remotes.egg-info/entry_points.txt
10
+ src/ether0.remotes.egg-info/requires.txt
11
+ src/ether0.remotes.egg-info/top_level.txt
12
+ tests/conftest.py
13
+ tests/test_clients.py
14
+ tests/test_rewards.py
packages/remotes/src/ether0.remotes.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
packages/remotes/src/ether0.remotes.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [console_scripts]
2
+ ether0-serve = ether0.server:main
packages/remotes/src/ether0.remotes.egg-info/requires.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ OpenNMT-py==2.3.0
2
+ fastapi
3
+ molbloom>=2.3.4
4
+ molsol>=0.0.3
5
+ numpy>=1.20
6
+ pydantic>=2
7
+ rdkit
8
+ torch<2.6
9
+
10
+ [dev]
11
+ ether0
12
+ ether0.remotes[serve]
13
+ tensorboard>=2.19
14
+
15
+ [serve]
16
+ uvicorn
packages/remotes/src/ether0.remotes.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ether0
packages/remotes/src/ether0/server.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import re
4
+ import secrets
5
+ import tempfile
6
+ import uuid
7
+ from collections import defaultdict
8
+ from pathlib import Path
9
+ from typing import ClassVar, Literal
10
+
11
+ import numpy as np
12
+ import numpy.typing as npt
13
+ from fastapi import Depends, FastAPI, HTTPException, status
14
+ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
15
+ from molbloom import buy
16
+ from molsol import KDESol
17
+ from onmt import opts
18
+ from onmt.translate.translator import build_translator
19
+ from onmt.utils.logging import init_logger
20
+ from onmt.utils.misc import split_corpus
21
+ from onmt.utils.parse import ArgumentParser
22
+ from pydantic import BaseModel
23
+ from rdkit import Chem
24
+
25
+ ETHER0_DIR = Path(__file__).parent
26
+
27
+ auth_scheme = HTTPBearer()
28
+
29
+
30
+ def validate_token(
31
+ credentials: HTTPAuthorizationCredentials = Depends(auth_scheme), # noqa: B008
32
+ ) -> str:
33
+ # NOTE: don't use os.environ.get() to avoid possible empty string matches, and
34
+ # to have clearer server failures if the ETHER0_REMOTES_API_TOKEN env var isn't present
35
+ if not secrets.compare_digest(
36
+ credentials.credentials, os.environ["ETHER0_REMOTES_API_TOKEN"]
37
+ ):
38
+ raise HTTPException(
39
+ status_code=status.HTTP_401_UNAUTHORIZED,
40
+ detail="Incorrect bearer token",
41
+ headers={"WWW-Authenticate": "Bearer"},
42
+ )
43
+ return credentials.credentials
44
+
45
+
46
+ app = FastAPI(title="ether0 remotes server", dependencies=[Depends(validate_token)])
47
+
48
+
49
+ class MolecularTransformer:
50
+ """Uses code from https://doi.org/10.1021/acscentsci.9b00576."""
51
+
52
+ DEFAULT_MOLTRANS_MODEL_PATH: ClassVar[Path] = (
53
+ ETHER0_DIR / "USPTO480k_model_step_400000.pt"
54
+ )
55
+
56
+ def __init__(self):
57
+ # Use `or None` to deny setting empty string to the environment variable
58
+ os_environ_model_path = (
59
+ os.environ.get("ETHER0_REMOTES_MOLTRANS_MODEL_PATH") or None
60
+ )
61
+ self.model_path = os_environ_model_path or str(self.DEFAULT_MOLTRANS_MODEL_PATH)
62
+ if not Path(self.model_path).exists():
63
+ raise FileNotFoundError(
64
+ f"MolTrans model not found"
65
+ f"{f', did you misconfigure the path {os_environ_model_path}?' if os_environ_model_path else '.'}" # noqa: E501
66
+ " Please properly configure the environment variable"
67
+ " 'ETHER0_REMOTES_MOLTRANS_MODEL_PATH',"
68
+ f" or place the model at the default checked path {self.DEFAULT_MOLTRANS_MODEL_PATH}."
69
+ )
70
+
71
+ @staticmethod
72
+ def translate(opt: argparse.Namespace) -> None:
73
+ ArgumentParser.validate_translate_opts(opt)
74
+ logger = init_logger(opt.log_file)
75
+
76
+ translator = build_translator(opt, logger=logger, report_score=True)
77
+ src_shards = split_corpus(opt.src, opt.shard_size)
78
+ tgt_shards = split_corpus(opt.tgt, opt.shard_size)
79
+ features_shards = []
80
+ features_names = []
81
+ for feat_name, feat_path in opt.src_feats.items():
82
+ features_shards.append(split_corpus(feat_path, opt.shard_size))
83
+ features_names.append(feat_name)
84
+ shard_pairs = zip(src_shards, tgt_shards, *features_shards) # noqa: B905
85
+
86
+ for (src_shard, tgt_shard, *features_shard) in shard_pairs:
87
+ features_shard_ = defaultdict(list)
88
+ for j, x in enumerate(features_shard):
89
+ features_shard_[features_names[j]] = x
90
+ translator.translate(
91
+ src=src_shard,
92
+ src_feats=features_shard_,
93
+ tgt=tgt_shard,
94
+ batch_size=opt.batch_size,
95
+ batch_type=opt.batch_type,
96
+ attn_debug=opt.attn_debug,
97
+ align_debug=opt.align_debug,
98
+ )
99
+
100
+ @staticmethod
101
+ def smiles_tokenizer(smiles: str) -> str:
102
+ smiles_regex = re.compile(
103
+ r"(\%\([0-9]{3}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\||\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
104
+ )
105
+ tokens = list(smiles_regex.findall(smiles))
106
+ return " ".join(tokens)
107
+
108
+ @staticmethod
109
+ def canonicalize_smiles(smiles: str) -> str:
110
+ # Try to use canonical SMILES because the original USPTO data is distributed in canonical form.
111
+ # If it fails, the except block below identifies and reports the invalid SMILES.
112
+ try:
113
+ return Chem.MolToSmiles(
114
+ Chem.MolFromSmiles(smiles), isomericSmiles=True, canonical=True
115
+ )
116
+ except Exception as err:
117
+ # If rdkit failed, it means some molecule is invalid.
118
+ # Here we catch which ones are invalid so we can explain what's wrong
119
+ # in the error message.
120
+ invalid_smiles = []
121
+ for mol in smiles.split("."):
122
+ try:
123
+ Chem.MolToSmiles(
124
+ Chem.MolFromSmiles(mol), isomericSmiles=True, canonical=True
125
+ )
126
+ except: # noqa: E722
127
+ invalid_smiles.append(mol)
128
+ raise HTTPException(
129
+ status_code=status.HTTP_400_BAD_REQUEST,
130
+ detail=(
131
+ "The reaction could not be parsed by RDKit. The following"
132
+ f" SMILES were invalid: {', '.join(invalid_smiles)}"
133
+ ),
134
+ ) from err
135
+
136
+ def run(self, reaction: str) -> tuple[str, uuid.UUID]:
137
+ """Translates SMILES reaction strings using MolTrans model.
138
+
139
+ Args:
140
+ reaction: SMILES representation of a chemical reaction
141
+
142
+ Returns:
143
+ SMILES representation of the predicted product and a job ID
144
+ """
145
+ # Create a unique ID for the request
146
+ job_id = uuid.uuid4()
147
+
148
+ # Create temporary files for use in the MolTrans translator
149
+ with (
150
+ tempfile.NamedTemporaryFile(
151
+ mode="w+", delete=False, encoding="utf-8"
152
+ ) as precursor_file,
153
+ tempfile.NamedTemporaryFile(
154
+ mode="w+", delete=False, encoding="utf-8"
155
+ ) as output_file,
156
+ ):
157
+
158
+ # Write tokenized reaction to the precursor file
159
+ precursor_file.write(MolecularTransformer.smiles_tokenizer(reaction))
160
+ precursor_file.flush()
161
+
162
+ # OpenNMT expects to receive a list of arguments to translate
163
+ parser = ArgumentParser()
164
+ opts.config_opts(parser)
165
+ opts.translate_opts(parser)
166
+
167
+ args_dict = {
168
+ "model": self.model_path,
169
+ "src": precursor_file.name,
170
+ "output": output_file.name,
171
+ "batch_size": "64",
172
+ "beam_size": "50",
173
+ "max_length": "300",
174
+ }
175
+ args_list = [f"--{k}={v}" for k, v in args_dict.items()]
176
+ opt = parser.parse_args(args_list)
177
+
178
+ MolecularTransformer.translate(opt)
179
+
180
+ output_file.close()
181
+ prediction = Path(output_file.name).read_text(encoding="utf-8")
182
+
183
+ # Clean up temporary files
184
+ # we don't care if a failure leaves them dangling,
185
+ # since they are in a temp dir
186
+ os.unlink(precursor_file.name)
187
+ os.unlink(output_file.name)
188
+
189
+ return prediction.replace(" ", "").strip(), job_id
190
+
191
+
192
+ class MolBloom:
193
+ """Uses code from https://doi.org/10.1186/s13321-023-00765-1."""
194
+
195
+ def __init__(self) -> None:
196
+ # trigger eager loading of the bloom filter
197
+ buy("C1=CC=CC=C1", catalog="zinc20")
198
+ self.bloom = buy
199
+
200
+ def run(self, smiles: str) -> bool:
201
+ """Checks if a molecule is purchasable using MolBloom.
202
+
203
+ Args:
204
+ smiles: SMILES representation of a molecule
205
+
206
+ Returns:
207
+ True if the molecule is purchasable, False otherwise
208
+ """
209
+ return self.bloom(smiles, canonicalize=True, catalog="zinc20")
210
+
211
+
212
+ class Solubility:
213
+ """Uses code from https://doi.org/10.1039/D3DD00217A."""
214
+
215
+ def __init__(self) -> None:
216
+ self.sol = KDESol()
217
+
218
+ def run(self, smiles: str) -> npt.NDArray[np.float32] | Literal[False]:
219
+ """Computes solubility prediction for a molecule using KDESol.
220
+
221
+ Args:
222
+ smiles: SMILES representation of a molecule.
223
+
224
+ Returns:
225
+ Numpy array containing the mean predicted solubility,
226
+ aleatoric uncertainty (au), and epistemic uncertainty (eu).
227
+ """
228
+ m = Chem.MolFromSmiles(smiles)
229
+ if m is None:
230
+ return False # type: ignore[unreachable]
231
+ prediction = self.sol(Chem.MolToSmiles(m, canonical=True, isomericSmiles=False))
232
+ if prediction is None:
233
+ # Try without canonicalization.
234
+ # The model is an LSTM whose vocabulary is built from SELFIES-derived tokens.
235
+ # Depending on the SMILES notation, the model might not have the necessary tokens
236
+ # in its vocabulary to describe the molecule.
237
+ prediction = self.sol(smiles)
238
+ return prediction if prediction is not None else False
239
+
240
+
241
+ class MolTransRequest(BaseModel):
242
+ reaction: str
243
+
244
+
245
+ @app.post("/translate")
246
+ def translate_endpoint(request: MolTransRequest) -> dict[str, str | uuid.UUID]:
247
+ reaction = request.reaction.replace(" ", "")
248
+ if not reaction.count(">") == 2: # noqa: PLR2004
249
+ raise HTTPException(
250
+ status_code=status.HTTP_400_BAD_REQUEST,
251
+ detail=(
252
+ f"Syntax error in the reaction SMILES: {reaction}\n"
253
+ "The reaction should have two '>' characters, and no spaces."
254
+ ),
255
+ )
256
+ rxn = reaction.split(">")[:-1]
257
+ query_reaction = MolecularTransformer.canonicalize_smiles(
258
+ ".".join([r for r in rxn if r])
259
+ )
260
+
261
+ product, job_id = MolecularTransformer().run(query_reaction)
262
+ return {
263
+ "product": product,
264
+ "id": job_id,
265
+ "reaction": query_reaction + ">>" + product,
266
+ }
267
+
268
+
269
+ class MolBloomRequest(BaseModel):
270
+ smiles: list[str] | str
271
+
272
+
273
+ @app.post("/is_purchasable")
274
+ def is_purchasable_endpoint(request: MolBloomRequest) -> dict[str, bool]:
275
+ is_purchasable = MolBloom().run
276
+ smiles = request.smiles
277
+ if isinstance(smiles, str):
278
+ smiles = [smiles]
279
+ return {s: is_purchasable(s) for s in smiles}
280
+
281
+
282
+ class SmilesRequest(BaseModel):
283
+ smiles: str
284
+
285
+
286
+ @app.post("/compute_solubility")
287
+ def compute_solubility_endpoint(
288
+ request: SmilesRequest,
289
+ ) -> dict[str, float] | dict[str, str]:
290
+ if "." in request.smiles:
291
+ raise HTTPException(
292
+ status_code=status.HTTP_400_BAD_REQUEST,
293
+ detail="Only single molecules are supported",
294
+ )
295
+ prediction = Solubility().run(smiles=request.smiles)
296
+ if prediction is False:
297
+ return {"error": "Solubility prediction failed."}
298
+ mean, au, eu = prediction.tolist()
299
+ return {"mean": mean, "au": au, "eu": eu}
300
+
301
+
302
+ def main() -> None:
303
+ """Run uvicorn to serve the FastAPI app."""
304
+ try:
305
+ import uvicorn # noqa: PLC0415
306
+ except ImportError as exc:
307
+ raise ImportError(
308
+ "Serving requires the 'serve' extra for the `uvicorn` package. Please:"
309
+ " `pip install ether0.remotes[serve]`."
310
+ ) from exc
311
+
312
+ uvicorn.run("ether0.server:app")
313
+
314
+
315
+ if __name__ == "__main__":
316
+ main()
packages/remotes/tests/conftest.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import os
3
+ from collections.abc import Iterator
4
+ from unittest.mock import patch
5
+
6
+ import pytest
7
+ from fastapi.testclient import TestClient
8
+
9
+
10
+ @pytest.fixture(name="test_client", scope="session")
11
+ def fixture_test_client() -> Iterator[TestClient]:
12
+ # Lazily import ether0 modules so typeguard doesn't throw:
13
+ # > /path/to/.venv/lib/python3.11/site-packages/typeguard/_pytest_plugin.py:93:
14
+ # > InstrumentationWarning: typeguard cannot check these packages
15
+ # > because they are already imported: ether0
16
+ import ether0.clients # noqa: PLC0415
17
+
18
+ from ether0.server import app # noqa: PLC0415
19
+
20
+ client = TestClient(app)
21
+ with patch.dict(
22
+ os.environ,
23
+ {
24
+ "ETHER0_REMOTES_API_BASE_URL": str(client.base_url),
25
+ "ETHER0_REMOTES_API_TOKEN": "test_stub",
26
+ },
27
+ ):
28
+ importlib.reload(ether0.clients) # Pull in updated environment variables
29
+ yield client
packages/remotes/tests/test_clients.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import TYPE_CHECKING
3
+ from unittest.mock import patch
4
+
5
+ import httpx
6
+ import pytest
7
+ from ether0.clients import fetch_forward_rxn, fetch_purchasable, fetch_solubility
8
+
9
+ if TYPE_CHECKING:
10
+ from fastapi.testclient import TestClient
11
+
12
+ IN_GITHUB_ACTIONS: bool = os.getenv("GITHUB_ACTIONS") == "true"
13
+
14
+
15
+ @pytest.mark.parametrize(
16
+ ("smiles", "purchasable"),
17
+ [
18
+ ("CC(=O)OC1=CC=CC=C1C(=O)O", True),
19
+ ("CCO", True),
20
+ ("C1=CC=C(C=C1)C(=O)O", True),
21
+ ("OCN1C=CC=C1C(=O)O", False),
22
+ ],
23
+ )
24
+ def test_fetch_purchasable(
25
+ test_client: "TestClient", smiles: str, purchasable: bool
26
+ ) -> None:
27
+ with patch.object(httpx, "post", test_client.post):
28
+ assert fetch_purchasable(smiles)[smiles] == purchasable
29
+
30
+
31
+ @pytest.mark.parametrize(
32
+ ("smiles", "solubility"),
33
+ [
34
+ ("CC(=O)OC1=CC=CC=C1C(=O)O", -2.5),
35
+ ("O=C(NC1CCCC1)C(C1CC1)S1C(=N)C(C2=CC=NC3=CC=CC=C23)N=C1", -5.9),
36
+ ],
37
+ )
38
+ def test_fetch_solubility(
39
+ test_client: "TestClient", smiles: str, solubility: float
40
+ ) -> None:
41
+ with patch.object(httpx, "post", test_client.post):
42
+ result = fetch_solubility(smiles)
43
+ assert "solubility" in result
44
+ assert pytest.approx(result["solubility"], abs=0.1) == solubility
45
+
46
+
47
+ @pytest.mark.skipif(
48
+ IN_GITHUB_ACTIONS, reason="Requires loading MolTrans model, too large for CI"
49
+ )
50
+ @pytest.mark.parametrize(
51
+ ("precursor", "product", "correct"),
52
+ [
53
+ pytest.param("CC=O.O=C1CCC1Cl>[Mg].CCOCC>", "CC(O)C1(O)CCC1Cl", True),
54
+ pytest.param(
55
+ "CCC=O.CC1(C)CC(N)C(=O)N1>[B-](OC(=O)C)(OC(=O)C)OC(=O)C.[Na+].C=O>",
56
+ "CCCN(C)C1CC(C)(C)NC1=O",
57
+ True,
58
+ ),
59
+ pytest.param("CCCC=O.O=C1CC=C(Br)S1>[Mg].CCOCC>", "CCCC(O)C1=CCC(=O)S1", True),
60
+ pytest.param("CCCC=O.COC(=O)C1CC1Br>[Mg].CCOCC>", "CCCC(O)C1CC1C(=O)OC", True),
61
+ pytest.param(
62
+ "CCCC=O.NC1CCCNC1=O>[B-](OC(=O)C)(OC(=O)C)OC(=O)C.[Na+].C=O>",
63
+ "O=C1NCCCC1N1CNCCCC1=O",
64
+ True,
65
+ ),
66
+ pytest.param("CC=O.O=C1CCC1Cl.[Mg].CCOCC", None, False, id="missing_arrow"),
67
+ pytest.param(
68
+ "CC=O.O=C1CCC1Cl > [Mg].CCOCC", None, False, id="space_in_reaction"
69
+ ),
70
+ pytest.param("not a > reaction", None, False, id="invalid_reaction"),
71
+ pytest.param(
72
+ "CCCC=O.COC(=O)C1CC1Br>[Mg].CCOCC", None, False, id="trailing_arrow"
73
+ ),
74
+ ],
75
+ )
76
+ def test_fetch_forward_rxn(
77
+ test_client: "TestClient", precursor: str, product: str | None, correct: bool
78
+ ) -> None:
79
+ with patch.object(httpx, "post", test_client.post):
80
+ result = fetch_forward_rxn(precursor)
81
+ if correct:
82
+ assert (
83
+ result.get("product") == product
84
+ ), f"Failed to get expected {product=} in {result=}."
85
+ assert "error" not in result
86
+ else:
87
+ assert result.get(
88
+ "error"
89
+ ), f"Expected an error given {precursor=} and {correct=}"
90
+ assert "syntax error" in result["error"].lower()
91
+ assert "product" not in result
packages/remotes/tests/test_rewards.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TYPE_CHECKING
2
+ from unittest.mock import patch
3
+
4
+ import httpx
5
+ import pytest
6
+ from ether0.rewards import oracle_solubility_eval
7
+ from pydantic import JsonValue
8
+
9
+ if TYPE_CHECKING:
10
+ from fastapi.testclient import TestClient
11
+
12
+
13
+ @pytest.mark.parametrize(
14
+ ("yhat", "y", "expected"),
15
+ [
16
+ pytest.param(
17
+ "c1c(O)nc2ccc(CN)cc2c1OC1CCCC1",
18
+ '("scaffold", "c1ccc2c(OC3CCCC3)ccnc2c1", -3.844724178314209, "increase")',
19
+ 1.0,
20
+ id="match-scaffold",
21
+ ),
22
+ pytest.param(
23
+ "Oc1c(O)nc2ccc(C[NH3])cc2c1OC1CCCC1O",
24
+ '("scaffold", "c1ccc2c(OC3CCCC3)ccnc2c1", -3.844724178314209, "decrease")',
25
+ 0.0,
26
+ id="match-scaffold-bad-solubility",
27
+ ),
28
+ pytest.param(
29
+ "CCCCCC=CCCCN(C)CCC",
30
+ '("groups", ["cis double bond", "hetero N basic H"], -4.693881511688232, "decrease")', # noqa: E501
31
+ 1.0,
32
+ id="match-groups",
33
+ ),
34
+ pytest.param(
35
+ "CCCCCCCCCCN(C)N[NH]CNCC",
36
+ '("groups", ["cis double bond", "hetero N basic H"], -1.9085578918457031, "decrease")', # noqa: E501
37
+ 0.0,
38
+ id="match-groups-bad-groups",
39
+ ),
40
+ pytest.param(
41
+ "CCCCN(CCCC)C(=O)C1c2ccccc2Oc2ccccc21",
42
+ '("tanimoto", "CCCN(CCC)C(=O)C1c2ccccc2Oc2ccccc21", -5.273194313049316, "decrease")',
43
+ 1.0,
44
+ id="match-tanimoto",
45
+ ),
46
+ pytest.param(
47
+ "CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCN(CCCC)C(=O)C1c2ccccc2Oc2ccccc21",
48
+ '("tanimoto", "CCCN(CCC)C(=O)C1c2ccccc2Oc2ccccc21", -5.273194313049316, "decrease")',
49
+ 0.0,
50
+ id="match-tanimoto-too-far",
51
+ ),
52
+ pytest.param(
53
+ "CCCCCCCCCCCCCCCCCCCCCCN(CCC)C(=O)C1c2ccccc2Oc2ccccc21",
54
+ '("tanimoto", "CCCN(CCC)C(=O)C1c2ccccc2Oc2ccccc21", -7.45, "decrease")',
55
+ 0.0,
56
+ id="match-tanimoto-hacked-dist",
57
+ ),
58
+ pytest.param(
59
+ "CN(C)C(=O)C1c2ccccc2Oc2ccccc21",
60
+ '("tanimoto", "CCCN(CCC)C(=O)C1c2ccccc2Oc2ccccc21", -4.273194313049316, "decrease")',
61
+ 0.0,
62
+ id="match-tanimoto-bad-solubility",
63
+ ),
64
+ pytest.param(
65
+ "CN1CCN(CCCCNc2ncc3cc(-c4c(Cl)cccc4Cl)c(=O)n(C)c3n2)CC1.CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC",
66
+ '("tanimoto", "CN1CCN(CCCCNc2ncc3cc(-c4c(Cl)cccc4Cl)c(=O)n(C)c3n2)CC1", -4.273194313049316, "decrease")', # noqa: E501
67
+ 0.0,
68
+ id="match-tanimoto-bad-structure",
69
+ ),
70
+ pytest.param(
71
+ "C[C@@H]1CC[C@@]2(CC[C@@]3(C(=CC[C@H]4[C@]3(CC[C@@H]5[C@@]4(C[C@H]([C@@H]([C@@]5(C)CO)O)O)C)C)[C@@H]2[C@H]1C)C)C(=O)O[C@H]6[C@@H]([C@H]([C@@H]([C@H](O6)CO[C@H]7[C@@H]([C@H]([C@@H]([C@H](O7)CO)O[C@H]8[C@@H]([C@@H]([C@H]([C@@H](O8)C)O)O)O)O)O)O)O)O",
72
+ '("groups", ["secondary alcohol", "primary alcohol", "hydroxylated heteroatom substituted glycosidic ring"], -5.921097755432129, "increase")', # noqa: E501
73
+ 1.0,
74
+ id="problematic-groups",
75
+ ),
76
+ pytest.param(
77
+ "COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1-c1ccc(C#CCCCC(=O)NO)o1",
78
+ '("tanimoto", "CCCC", -6.25, "increase")',
79
+ 0.0,
80
+ id="identical-increase",
81
+ ),
82
+ pytest.param(
83
+ "COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1-c1ccc(C#CCCCC(=O)NO)o1",
84
+ '("tanimoto", "CCCC", -7.25, "decrease")',
85
+ 0.0,
86
+ id="identical-decrease",
87
+ ),
88
+ pytest.param(
89
+ "OOCCCN(CCC)C(=O)C1c2ccccc2Oc2ccccc21",
90
+ '("tanimoto", "OCCCN(CCC)C(=O)C1c2ccccc2Oc2ccccc21", -5.273194313049316, "decrease")', # noqa: E501
91
+ 0.0,
92
+ id="unreasonable-molecule-failure",
93
+ ),
94
+ pytest.param(
95
+ "CC(C)(C)Cc1nc(Br)c(S(C)(=O)=O)n1Cc1ccc(-c2ccccc2-c2nn[nH]n2)cc1",
96
+ "('scaffold', 'c1ccc(-c2nn[nH]n2)c(-c2ccc(Cn3ccnc3)cc2)c1', '-7.790801048278809', 'decrease')", # noqa: E501
97
+ 0.0,
98
+ id="eval-has-str-value",
99
+ ),
100
+ ],
101
+ )
102
+ def test_oracle_solubility_eval(
103
+ test_client: "TestClient", yhat: str, y: str, expected: float
104
+ ) -> None:
105
+ expl: dict[str, JsonValue] = {}
106
+ with patch.object(httpx, "post", test_client.post):
107
+ result = oracle_solubility_eval(yhat, y, metadata=expl)
108
+ assert result == expected, f"Expected {expected}, got {result}. Explanation: {expl}"
pyproject.toml ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ build-backend = "setuptools.build_meta"
3
+ # SEE: https://github.com/pypa/setuptools_scm#pyprojecttoml-usage
4
+ requires = ["setuptools>=64", "setuptools_scm>=8"]
5
+
6
+ [dependency-groups]
7
+ dev = ["ether0.remotes[dev]", "ether0[dev]"]
8
+
9
+ [project]
10
+ authors = [
11
+ {email = "[email protected]", name = "FutureHouse technical staff"},
12
+ ]
13
+ # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
14
+ classifiers = [
15
+ "Intended Audience :: Developers",
16
+ "License :: OSI Approved :: Apache Software License",
17
+ "Operating System :: OS Independent",
18
+ "Programming Language :: Python :: 3 :: Only",
19
+ "Programming Language :: Python :: 3.11",
20
+ "Programming Language :: Python :: 3.12",
21
+ "Programming Language :: Python :: 3.13",
22
+ "Programming Language :: Python",
23
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
24
+ "Topic :: Scientific/Engineering :: Chemistry",
25
+ ]
26
+ dependencies = [
27
+ "accelerate>=1.10.1",
28
+ "datasets",
29
+ "exmol>=3.3.0", # to get fixed functional group names and py.typed
30
+ "gradio>=5.44.0",
31
+ "httpx",
32
+ "huggingface-hub",
33
+ "molbloom==2.3.4", # exact pin to be compatible with rings.bloom
34
+ "pydantic>=2", # Pin to keep recent
35
+ "rdkit",
36
+ "regex", # For unicode property \p
37
+ "spaces>=0.40.1",
38
+ "tenacity",
39
+ ]
40
+ description = "Data models, rewards, and utility functions for ether0."
41
+ dynamic = [
42
+ "version", # Required for setuptools_scm
43
+ ]
44
+ license = {file = "LICENSE"}
45
+ name = "ether0"
46
+ readme = "README.md"
47
+ requires-python = ">=3.11"
48
+
49
+ [project.optional-dependencies]
50
+ add-tokens = [
51
+ "ipykernel", # For Jupyter notebook support
52
+ "ipywidgets>=8", # For Jupyter notebook support, and pin to keep recent
53
+ "transformers>=4.49", # Pin to keep recent
54
+ ]
55
+ baselines = [
56
+ "fhaviary>=0.19", # Pin for Python 3.13 compatibility
57
+ "fhlmi>=0.26", # Pin for Python 3.13 compatibility
58
+ "ipython",
59
+ ]
60
+ dev = [
61
+ "ether0[add-tokens,typing]",
62
+ "huggingface-hub[cli]", # For login inside of CI
63
+ "ipython>=8", # Pin to keep recent
64
+ "mypy>=1.8", # For addition of mutable-override
65
+ "pre-commit>=3.4", # Pin to keep recent
66
+ "pylint>=3", # Pin to keep recent
67
+ "pytest",
68
+ "pytest-subtests",
69
+ "pytest-sugar",
70
+ "pytest-timer[colorama]",
71
+ "pytest-xdist",
72
+ "refurb>=2", # Pin to keep recent
73
+ "typeguard",
74
+ ]
75
+ typing = [
76
+ "types-regex",
77
+ ]
78
+
79
+ [project.urls]
80
+ issues = "https://github.com/Future-House/ether0/issues"
81
+ repository = "https://github.com/Future-House/ether0"
82
+
83
+ [tool.black]
84
+ enable-unstable-feature = [
85
+ "hug_parens_with_braces_and_square_brackets", # TODO: remove after https://github.com/psf/black/issues/4036 resolution
86
+ ]
87
+ preview = true
88
+
89
+ [tool.codespell]
90
+ check-filenames = true
91
+ check-hidden = true
92
+ ignore-words-list = "amination,astroid,ser"
93
+
94
+ [tool.coverage]
95
+
96
+ [tool.coverage.report]
97
+ exclude_also = [
98
+ "@overload", # SEE: https://github.com/nedbat/coveragepy/issues/970
99
+ "if TYPE_CHECKING:",
100
+ ]
101
+ # Number of digits after the decimal point to display for reported coverage percentages
102
+ precision = 2
103
+
104
+ [tool.coverage.run]
105
+ # Measure branch coverage
106
+ branch = true
107
+ # This will be used if you run `coverage run` with no further arguments
108
+ # This is designed to be invoked from within the test directory
109
+ command_line = "-m pytest"
110
+
111
+ [tool.markdownlint]
112
+ no-inline-html = false
113
+
114
+ [tool.markdownlint.line-length]
115
+ code_block_line_length = 88 # Match ruff line-length
116
+ line_length = 120 # Match ruff max-doc-length
117
+ stern = true
118
+ tables = false
119
+
120
+ [tool.mypy]
121
+ # Type-checks the interior of functions without type annotations.
122
+ check_untyped_defs = true
123
+ # Allows enabling one or multiple error codes globally. Note: This option will
124
+ # override disabled error codes from the disable_error_code option.
125
+ enable_error_code = [
126
+ "ignore-without-code",
127
+ "mutable-override",
128
+ "redundant-cast",
129
+ "redundant-expr",
130
+ "redundant-self",
131
+ "truthy-bool",
132
+ "truthy-iterable",
133
+ "unimported-reveal",
134
+ "unreachable",
135
+ "unused-awaitable",
136
+ "unused-ignore",
137
+ ]
138
+ # Shows a short summary line after error messages.
139
+ error_summary = false
140
+ # A regular expression that matches file names, directory names and paths which mypy
141
+ # should ignore while recursively discovering files to check. Use forward slashes (/) as
142
+ # directory separators on all platforms.
143
+ exclude = [
144
+ "^\\.?venv", # SEE: https://regex101.com/r/0rp5Br/1
145
+ ]
146
+ # This flag tells mypy that top-level packages will be based in either the current
147
+ # directory, or a member of the MYPYPATH environment variable or mypy_path config
148
+ # option. This option is only useful in the absence of __init__.py. See Mapping file
149
+ # paths to modules for details.
150
+ explicit_package_bases = true
151
+ # Specifies the paths to use, after trying the paths from MYPYPATH environment variable.
152
+ # Useful if you'd like to keep stubs in your repo, along with the config file.
153
+ # Multiple paths are always separated with a : or , regardless of the platform.
154
+ # User home directory and environment variables will be expanded.
155
+ mypy_path = "$MYPY_CONFIG_FILE_DIR/src,$MYPY_CONFIG_FILE_DIR/packages/remotes/src"
156
+ # Comma-separated list of mypy plugins.
157
+ plugins = ["pydantic.mypy"]
158
+ # Use visually nicer output in error messages: use soft word wrap, show source
159
+ # code snippets, and show error location markers.
160
+ pretty = true
161
+ # Shows column numbers in error messages.
162
+ show_column_numbers = true
163
+ # Shows error codes in error messages.
164
+ # SEE: https://mypy.readthedocs.io/en/stable/error_codes.html#error-codes
165
+ show_error_codes = true
166
+ # Prefixes each error with the relevant context.
167
+ show_error_context = true
168
+ # Warns about casting an expression to its inferred type.
169
+ warn_redundant_casts = true
170
+ # Shows a warning when encountering any code inferred to be unreachable or
171
+ # redundant after performing type analysis.
172
+ warn_unreachable = true
173
+ # Warns about per-module sections in the config file that do not match any
174
+ # files processed when invoking mypy.
175
+ warn_unused_configs = true
176
+ # Warns about unneeded `# type: ignore` comments.
177
+ warn_unused_ignores = true
178
+
179
+ [[tool.mypy.overrides]]
180
+ # Suppresses error messages about imports that cannot be resolved.
181
+ ignore_missing_imports = true
182
+ # Per-module configuration options
183
+ module = [
184
+ "datasets.*", # SEE: https://github.com/huggingface/datasets/issues/3841
185
+ "huggingface_hub.*", # SEE: https://github.com/huggingface/huggingface_hub/issues/1662
186
+ "molbloom", # SEE: https://github.com/whitead/molbloom/issues/29
187
+ "molsol", # SEE: https://github.com/maykcaldas/molsol/issues/6
188
+ "onmt.*",
189
+ "setuptools_scm", # SEE: https://github.com/pypa/setuptools_scm/issues/501
190
+ "transformers.*", # SEE: https://github.com/huggingface/transformers/pull/18485
191
+ ]
192
+
193
+ [tool.pylint]
194
+
195
+ [tool.pylint.design]
196
+ # Maximum number of attributes for a class (see R0902).
197
+ max-attributes = 12
198
+
199
+ [tool.pylint.format]
200
+ # Maximum number of characters on a single line.
201
+ max-line-length = 97 # Match ruff line-length
202
+
203
+ [tool.pylint.main]
204
+ # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
205
+ # number of processors available to use, and will cap the count on Windows to
206
+ # avoid hangs.
207
+ jobs = 0
208
+
209
+ [tool.pylint.messages_control]
210
+ # Disable the message, report, category or checker with the given id(s).
211
+ disable = [
212
+ "bare-except", # Rely on ruff E722 for this
213
+ "broad-exception-caught", # Rely on ruff BLE001 for this
214
+ "broad-exception-raised", # Rely on ruff TRY002 for this
215
+ "dangerous-default-value", # Rely on ruff W0102 for this
216
+ "empty-docstring", # Let pep257 take care of docstrings
217
+ "expression-not-assigned", # Rely on mypy func-returns-value for this
218
+ "fixme", # codetags are useful
219
+ "function-redefined", # Rely on mypy no-redef for this
220
+ "global-statement", # Rely on ruff PLW0603 for this
221
+ "global-variable-not-assigned", # Rely on ruff PLW0602 for this
222
+ "import-outside-toplevel", # Rely on ruff PLC0415 for this
223
+ "keyword-arg-before-vararg", # Rely on ruff B026 for this
224
+ "line-too-long", # Rely on ruff E501 for this
225
+ "logging-fstring-interpolation", # f-strings are convenient
226
+ "logging-too-many-args", # Rely on ruff PLE1205 for this
227
+ "missing-docstring", # Let docformatter and ruff take care of docstrings
228
+ "missing-final-newline", # Rely on ruff W292 for this
229
+ "no-else-return", # Rely on ruff RET506 for this
230
+ "no-member", # Buggy, SEE: https://github.com/pylint-dev/pylint/issues/8138
231
+ "no-value-for-parameter", # Rely on mypy call-arg for this
232
+ "protected-access", # Don't care to enforce this in testing
233
+ "raise-missing-from", # Rely on ruff B904 for this
234
+ "redefined-builtin", # Rely on ruff A002 for this
235
+ "too-few-public-methods", # Don't care to enforce this
236
+ "too-many-arguments", # Don't care to enforce this
237
+ "too-many-boolean-expressions", # Rely on ruff PLR0916 for this
238
+ "too-many-branches", # Rely on ruff PLR0912 for this
239
+ "too-many-locals", # Rely on ruff PLR0914 for this
240
+ "too-many-positional-arguments", # Rely on ruff PLR0917 for this
241
+ "too-many-public-methods", # Rely on ruff PLR0904 for this
242
+ "too-many-return-statements", # Rely on ruff PLR0911 for this
243
+ "too-many-statements", # Rely on ruff PLR0915 for this
244
+ "ungrouped-imports", # Rely on ruff I001 for this
245
+ "unidiomatic-typecheck", # Rely on ruff E721 for this
246
+ "unnecessary-dict-index-lookup", # Rely on ruff PLR1733 for this
247
+ "unreachable", # Rely on mypy unreachable for this
248
+ "unspecified-encoding", # Rely on ruff PLW1514 for this
249
+ "unsubscriptable-object", # Buggy, SEE: https://github.com/pylint-dev/pylint/issues/3637
250
+ "unsupported-membership-test", # Buggy, SEE: https://github.com/pylint-dev/pylint/issues/3045
251
+ "unused-argument", # Rely on ruff ARG002 for this
252
+ "unused-import", # Rely on ruff F401 for this
253
+ "unused-variable", # Rely on ruff F841 for this
254
+ "unused-wildcard-import", # Wildcard imports are convenient
255
+ "use-sequence-for-iteration", # Rely on ruff C0208 for this
256
+ "wildcard-import", # Wildcard imports are convenient
257
+ "wrong-import-order", # Rely on ruff I001 for this
258
+ "wrong-import-position", # Rely on ruff E402 for this
259
+ ]
260
+ # Enable the message, report, category or checker with the given id(s).
261
+ enable = [
262
+ "useless-suppression", # Print unused `pylint: disable` comments
263
+ ]
264
+
265
+ [tool.pylint.reports]
266
+ # Set true to activate the evaluation score.
267
+ score = false
268
+
269
+ [tool.pylint.similarities]
270
+ # Minimum lines number of a similarity.
271
+ min-similarity-lines = 10
272
+
273
+ [tool.pytest.ini_options]
274
+ # Add the specified `OPTS` to the set of command line arguments as if they had
275
+ # been specified by the user.
276
+ addopts = "--typeguard-packages=ether0 --doctest-modules"
277
+ # List of directories that should be searched for tests when no specific directories,
278
+ # files or test ids are given in the command line when executing pytest from the rootdir
279
+ # directory. File system paths may use shell-style wildcards, including the recursive **
280
+ # pattern.
281
+ testpaths = ["packages/remotes", "src", "tests"]
282
+
283
+ [tool.refurb]
284
+ enable_all = true
285
+ ignore = [
286
+ "FURB101", # Rely on ruff FURB101 for this
287
+ "FURB103", # Rely on ruff FURB103 for this
288
+ "FURB141", # Rely on ruff PTH110 for this
289
+ "FURB144", # Rely on ruff PTH107 for this
290
+ "FURB146", # Rely on ruff PTH113 for this
291
+ "FURB147", # Rely on ruff PTH118 for this
292
+ "FURB150", # Rely on ruff PTH102 for this
293
+ "FURB155", # Rely on ruff PTH202 for this
294
+ ]
295
+
296
+ [tool.ruff]
297
+ # Line length to use when enforcing long-lines violations (like `E501`).
298
+ line-length = 97 # ceil(1.1 * 88) makes `E501` equivalent to `B950`
299
+ # Enable application of unsafe fixes.
300
+ unsafe-fixes = true
301
+
302
+ [tool.ruff.format]
303
+ # Enable reformatting of code snippets in docstrings.
304
+ docstring-code-format = true
305
+ # Enable preview style formatting.
306
+ preview = true
307
+
308
+ [tool.ruff.lint]
309
+ explicit-preview-rules = true
310
+ extend-select = [
311
+ "AIR002",
312
+ "AIR301",
313
+ "AIR302",
314
+ "AIR311",
315
+ "AIR312",
316
+ "ASYNC116",
317
+ "B901",
318
+ "B903",
319
+ "B909",
320
+ "CPY001",
321
+ "DOC201",
322
+ "DOC202",
323
+ "DOC402",
324
+ "DOC403",
325
+ "DOC501",
326
+ "DOC502",
327
+ "E111",
328
+ "E112",
329
+ "E113",
330
+ "E114",
331
+ "E115",
332
+ "E116",
333
+ "E117",
334
+ "E201",
335
+ "E202",
336
+ "E203",
337
+ "E204",
338
+ "E211",
339
+ "E221",
340
+ "E222",
341
+ "E223",
342
+ "E224",
343
+ "E225",
344
+ "E226",
345
+ "E227",
346
+ "E228",
347
+ "E231",
348
+ "E241",
349
+ "E242",
350
+ "E251",
351
+ "E252",
352
+ "E261",
353
+ "E262",
354
+ "E265",
355
+ "E266",
356
+ "E271",
357
+ "E272",
358
+ "E273",
359
+ "E274",
360
+ "E275",
361
+ "E301",
362
+ "E302",
363
+ "E303",
364
+ "E304",
365
+ "E305",
366
+ "E306",
367
+ "E502",
368
+ "FURB101",
369
+ "FURB103",
370
+ "FURB110",
371
+ "FURB113",
372
+ "FURB116",
373
+ "FURB118",
374
+ "FURB122",
375
+ "FURB131",
376
+ "FURB132",
377
+ "FURB140",
378
+ "FURB142",
379
+ "FURB145",
380
+ "FURB148",
381
+ "FURB152",
382
+ "FURB154",
383
+ "FURB156",
384
+ "FURB157",
385
+ "FURB162",
386
+ "FURB164",
387
+ "FURB166",
388
+ "FURB171",
389
+ "FURB180",
390
+ "FURB189",
391
+ "FURB192",
392
+ "LOG004",
393
+ "LOG014",
394
+ "PLC0415",
395
+ "PLC1901",
396
+ "PLC2701",
397
+ "PLC2801",
398
+ "PLE0304",
399
+ "PLE1141",
400
+ "PLE4703",
401
+ "PLR0202",
402
+ "PLR0203",
403
+ "PLR0904",
404
+ "PLR0914",
405
+ "PLR0916",
406
+ "PLR0917",
407
+ "PLR1702",
408
+ "PLR1733",
409
+ "PLR6104",
410
+ "PLR6201",
411
+ "PLR6301",
412
+ "PLW0108",
413
+ "PLW0177",
414
+ "PLW0244",
415
+ "PLW1514",
416
+ "PLW1641",
417
+ "PLW3201",
418
+ "PT028",
419
+ "PT029",
420
+ "PT030",
421
+ "PT031",
422
+ "PYI059",
423
+ "PYI061",
424
+ "RUF027",
425
+ "RUF028",
426
+ "RUF029",
427
+ "RUF031",
428
+ "RUF036",
429
+ "RUF037",
430
+ "RUF038",
431
+ "RUF039",
432
+ "RUF043",
433
+ "RUF045",
434
+ "RUF047",
435
+ "RUF049",
436
+ "RUF052",
437
+ "RUF053",
438
+ "RUF054",
439
+ "RUF055",
440
+ "RUF056",
441
+ "RUF057",
442
+ "RUF058",
443
+ "RUF059",
444
+ "RUF102",
445
+ "TC008",
446
+ "UP042",
447
+ "UP045",
448
+ "UP046",
449
+ "UP047",
450
+ "UP049",
451
+ "W391",
452
+ ]
453
+ external = [
454
+ "FURB", # refurb
455
+ ]
456
+ ignore = [
457
+ "ANN", # Don't care to enforce typing
458
+ "BLE001", # Don't care to enforce blind exception catching
459
+ "C901", # Don't care to limit complexity
460
+ "COM812", # Trailing comma with black leads to wasting lines
461
+ "CPY001", # Don't care to require copyright notices in every file
462
+ "D100", # D100, D101, D102, D103, D104, D105, D106, D107: don't always need docstrings
463
+ "D101",
464
+ "D102",
465
+ "D103",
466
+ "D104",
467
+ "D105",
468
+ "D106",
469
+ "D107",
470
+ "D203", # Keep docstring next to the class definition (covered by D211)
471
+ "D212", # Summary should be on second line (opposite of D213)
472
+ "D402", # It's nice to reuse the method name
473
+ "D406", # Google style requires ":" at end
474
+ "D407", # We aren't using numpy style
475
+ "D413", # Blank line after last section. -> No blank line
476
+ "DOC201", # Don't care to require Returns in docstrings
477
+ "DOC402", # Don't care to require Yields in docstrings
478
+ "DOC501", # Don't care to require Raises in docstrings
479
+ "DTZ", # Don't care to have timezone safety
480
+ "EM", # Overly pedantic
481
+ "FBT001", # FBT001, FBT002: overly pedantic
482
+ "FBT002",
483
+ "FIX", # Don't care to prevent TODO, FIXME, etc.
484
+ "G004", # f-strings are convenient
485
+ "INP001", # Can use namespace packages
486
+ "ISC001", # For ruff format compatibility
487
+ "PLR0911", # Don't care to limit complexity
488
+ "PLR0912", # Don't care to limit complexity
489
+ "PLR0913", # Don't care to limit complexity
490
+ "PLR0917", # Don't care to limit complexity
491
+ "PTH", # Overly pedantic
492
+ "SLF001", # Overly pedantic
493
+ "T201", # Overly pedantic
494
+ "TC001", # TC001, TC002, TC003: don't care to enforce type checking blocks
495
+ "TC002",
496
+ "TC003",
497
+ "TC006", # Strings in cast don't work with PyCharm CE 2024.3.4's jump-to-definition
498
+ "TD002", # Don't care for TODO author
499
+ "TD003", # Don't care for TODO links
500
+ "TRY003", # Overly pedantic
501
+ ]
502
+ preview = true
503
+ select = ["ALL"]
504
+ unfixable = [
505
+ "B007", # While debugging, unused loop variables can be useful
506
+ "B905", # Default fix is zip(strict=False), but that can hide bugs
507
+ "ERA001", # While debugging, temporarily commenting code can be useful
508
+ "F401", # While debugging, unused imports can be useful
509
+ "F841", # While debugging, unused locals can be useful
510
+ "TC004", # While debugging, it can be nice to keep TYPE_CHECKING in-tact
511
+ ]
512
+
513
+ [tool.ruff.lint.flake8-annotations]
514
+ mypy-init-return = true
515
+
516
+ [tool.ruff.lint.per-file-ignores]
517
+ "**/tests/*.py" = [
518
+ "N802", # Tests function names can match class names
519
+ "PLR2004", # Tests can have magic values
520
+ "PLR6301", # Test classes can ignore self
521
+ "S101", # Tests can have assertions
522
+ ]
523
+
524
+ [tool.ruff.lint.pycodestyle]
525
+ # The maximum line length to allow for line-length violations within
526
+ # documentation (W505), including standalone comments.
527
+ max-doc-length = 97 # Match line-length
528
+
529
+ [tool.ruff.lint.pydocstyle]
530
+ # Whether to use Google-style or NumPy-style conventions or the PEP257
531
+ # defaults when analyzing docstring sections.
532
+ convention = "google"
533
+
534
+ [tool.tomlsort]
535
+ all = true
536
+ in_place = true
537
+ spaces_before_inline_comment = 2 # Match Python PEP 8
538
+ spaces_indent_inline_array = 4 # Match Python PEP 8
539
+ trailing_comma_inline_array = true
540
+
541
+ [tool.uv.sources]
542
+ ether0 = {workspace = true}
543
+ "ether0.remotes" = {workspace = true}
544
+
545
+ [tool.uv.workspace]
546
+ members = ["packages/*"]
src/.DS_Store ADDED
Binary file (6.15 kB). View file
 
src/ether0.egg-info/PKG-INFO ADDED
@@ -0,0 +1,467 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: ether0
3
+ Version: 0.0.0
4
+ Summary: Data models, rewards, and utility functions for ether0.
5
+ Author-email: FutureHouse technical staff <[email protected]>
6
+ License: Apache License
7
+ Version 2.0, January 2004
8
+ http://www.apache.org/licenses/
9
+
10
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
11
+
12
+ 1. Definitions.
13
+
14
+ "License" shall mean the terms and conditions for use, reproduction,
15
+ and distribution as defined by Sections 1 through 9 of this document.
16
+
17
+ "Licensor" shall mean the copyright owner or entity authorized by
18
+ the copyright owner that is granting the License.
19
+
20
+ "Legal Entity" shall mean the union of the acting entity and all
21
+ other entities that control, are controlled by, or are under common
22
+ control with that entity. For the purposes of this definition,
23
+ "control" means (i) the power, direct or indirect, to cause the
24
+ direction or management of such entity, whether by contract or
25
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
26
+ outstanding shares, or (iii) beneficial ownership of such entity.
27
+
28
+ "You" (or "Your") shall mean an individual or Legal Entity
29
+ exercising permissions granted by this License.
30
+
31
+ "Source" form shall mean the preferred form for making modifications,
32
+ including but not limited to software source code, documentation
33
+ source, and configuration files.
34
+
35
+ "Object" form shall mean any form resulting from mechanical
36
+ transformation or translation of a Source form, including but
37
+ not limited to compiled object code, generated documentation,
38
+ and conversions to other media types.
39
+
40
+ "Work" shall mean the work of authorship, whether in Source or
41
+ Object form, made available under the License, as indicated by a
42
+ copyright notice that is included in or attached to the work
43
+ (an example is provided in the Appendix below).
44
+
45
+ "Derivative Works" shall mean any work, whether in Source or Object
46
+ form, that is based on (or derived from) the Work and for which the
47
+ editorial revisions, annotations, elaborations, or other modifications
48
+ represent, as a whole, an original work of authorship. For the purposes
49
+ of this License, Derivative Works shall not include works that remain
50
+ separable from, or merely link (or bind by name) to the interfaces of,
51
+ the Work and Derivative Works thereof.
52
+
53
+ "Contribution" shall mean any work of authorship, including
54
+ the original version of the Work and any modifications or additions
55
+ to that Work or Derivative Works thereof, that is intentionally
56
+ submitted to Licensor for inclusion in the Work by the copyright owner
57
+ or by an individual or Legal Entity authorized to submit on behalf of
58
+ the copyright owner. For the purposes of this definition, "submitted"
59
+ means any form of electronic, verbal, or written communication sent
60
+ to the Licensor or its representatives, including but not limited to
61
+ communication on electronic mailing lists, source code control systems,
62
+ and issue tracking systems that are managed by, or on behalf of, the
63
+ Licensor for the purpose of discussing and improving the Work, but
64
+ excluding communication that is conspicuously marked or otherwise
65
+ designated in writing by the copyright owner as "Not a Contribution."
66
+
67
+ "Contributor" shall mean Licensor and any individual or Legal Entity
68
+ on behalf of whom a Contribution has been received by Licensor and
69
+ subsequently incorporated within the Work.
70
+
71
+ 2. Grant of Copyright License. Subject to the terms and conditions of
72
+ this License, each Contributor hereby grants to You a perpetual,
73
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
74
+ copyright license to reproduce, prepare Derivative Works of,
75
+ publicly display, publicly perform, sublicense, and distribute the
76
+ Work and such Derivative Works in Source or Object form.
77
+
78
+ 3. Grant of Patent License. Subject to the terms and conditions of
79
+ this License, each Contributor hereby grants to You a perpetual,
80
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
81
+ (except as stated in this section) patent license to make, have made,
82
+ use, offer to sell, sell, import, and otherwise transfer the Work,
83
+ where such license applies only to those patent claims licensable
84
+ by such Contributor that are necessarily infringed by their
85
+ Contribution(s) alone or by combination of their Contribution(s)
86
+ with the Work to which such Contribution(s) was submitted. If You
87
+ institute patent litigation against any entity (including a
88
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
89
+ or a Contribution incorporated within the Work constitutes direct
90
+ or contributory patent infringement, then any patent licenses
91
+ granted to You under this License for that Work shall terminate
92
+ as of the date such litigation is filed.
93
+
94
+ 4. Redistribution. You may reproduce and distribute copies of the
95
+ Work or Derivative Works thereof in any medium, with or without
96
+ modifications, and in Source or Object form, provided that You
97
+ meet the following conditions:
98
+
99
+ (a) You must give any other recipients of the Work or
100
+ Derivative Works a copy of this License; and
101
+
102
+ (b) You must cause any modified files to carry prominent notices
103
+ stating that You changed the files; and
104
+
105
+ (c) You must retain, in the Source form of any Derivative Works
106
+ that You distribute, all copyright, patent, trademark, and
107
+ attribution notices from the Source form of the Work,
108
+ excluding those notices that do not pertain to any part of
109
+ the Derivative Works; and
110
+
111
+ (d) If the Work includes a "NOTICE" text file as part of its
112
+ distribution, then any Derivative Works that You distribute must
113
+ include a readable copy of the attribution notices contained
114
+ within such NOTICE file, excluding those notices that do not
115
+ pertain to any part of the Derivative Works, in at least one
116
+ of the following places: within a NOTICE text file distributed
117
+ as part of the Derivative Works; within the Source form or
118
+ documentation, if provided along with the Derivative Works; or,
119
+ within a display generated by the Derivative Works, if and
120
+ wherever such third-party notices normally appear. The contents
121
+ of the NOTICE file are for informational purposes only and
122
+ do not modify the License. You may add Your own attribution
123
+ notices within Derivative Works that You distribute, alongside
124
+ or as an addendum to the NOTICE text from the Work, provided
125
+ that such additional attribution notices cannot be construed
126
+ as modifying the License.
127
+
128
+ You may add Your own copyright statement to Your modifications and
129
+ may provide additional or different license terms and conditions
130
+ for use, reproduction, or distribution of Your modifications, or
131
+ for any such Derivative Works as a whole, provided Your use,
132
+ reproduction, and distribution of the Work otherwise complies with
133
+ the conditions stated in this License.
134
+
135
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
136
+ any Contribution intentionally submitted for inclusion in the Work
137
+ by You to the Licensor shall be under the terms and conditions of
138
+ this License, without any additional terms or conditions.
139
+ Notwithstanding the above, nothing herein shall supersede or modify
140
+ the terms of any separate license agreement you may have executed
141
+ with Licensor regarding such Contributions.
142
+
143
+ 6. Trademarks. This License does not grant permission to use the trade
144
+ names, trademarks, service marks, or product names of the Licensor,
145
+ except as required for reasonable and customary use in describing the
146
+ origin of the Work and reproducing the content of the NOTICE file.
147
+
148
+ 7. Disclaimer of Warranty. Unless required by applicable law or
149
+ agreed to in writing, Licensor provides the Work (and each
150
+ Contributor provides its Contributions) on an "AS IS" BASIS,
151
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
152
+ implied, including, without limitation, any warranties or conditions
153
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
154
+ PARTICULAR PURPOSE. You are solely responsible for determining the
155
+ appropriateness of using or redistributing the Work and assume any
156
+ risks associated with Your exercise of permissions under this License.
157
+
158
+ 8. Limitation of Liability. In no event and under no legal theory,
159
+ whether in tort (including negligence), contract, or otherwise,
160
+ unless required by applicable law (such as deliberate and grossly
161
+ negligent acts) or agreed to in writing, shall any Contributor be
162
+ liable to You for damages, including any direct, indirect, special,
163
+ incidental, or consequential damages of any character arising as a
164
+ result of this License or out of the use or inability to use the
165
+ Work (including but not limited to damages for loss of goodwill,
166
+ work stoppage, computer failure or malfunction, or any and all
167
+ other commercial damages or losses), even if such Contributor
168
+ has been advised of the possibility of such damages.
169
+
170
+ 9. Accepting Warranty or Additional Liability. While redistributing
171
+ the Work or Derivative Works thereof, You may choose to offer,
172
+ and charge a fee for, acceptance of support, warranty, indemnity,
173
+ or other liability obligations and/or rights consistent with this
174
+ License. However, in accepting such obligations, You may act only
175
+ on Your own behalf and on Your sole responsibility, not on behalf
176
+ of any other Contributor, and only if You agree to indemnify,
177
+ defend, and hold each Contributor harmless for any liability
178
+ incurred by, or claims asserted against, such Contributor by reason
179
+ of your accepting any such warranty or additional liability.
180
+
181
+ END OF TERMS AND CONDITIONS
182
+
183
+ APPENDIX: How to apply the Apache License to your work.
184
+
185
+ To apply the Apache License to your work, attach the following
186
+ boilerplate notice, with the fields enclosed by brackets "[]"
187
+ replaced with your own identifying information. (Don't include
188
+ the brackets!) The text should be enclosed in the appropriate
189
+ comment syntax for the file format. We also recommend that a
190
+ file or class name and description of purpose be included on the
191
+ same "printed page" as the copyright notice for easier
192
+ identification within third-party archives.
193
+
194
+ Copyright 2025 FutureHouse
195
+
196
+ Licensed under the Apache License, Version 2.0 (the "License");
197
+ you may not use this file except in compliance with the License.
198
+ You may obtain a copy of the License at
199
+
200
+ http://www.apache.org/licenses/LICENSE-2.0
201
+
202
+ Unless required by applicable law or agreed to in writing, software
203
+ distributed under the License is distributed on an "AS IS" BASIS,
204
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
205
+ See the License for the specific language governing permissions and
206
+ limitations under the License.
207
+
208
+ Project-URL: issues, https://github.com/Future-House/ether0/issues
209
+ Project-URL: repository, https://github.com/Future-House/ether0
210
+ Classifier: Intended Audience :: Developers
211
+ Classifier: License :: OSI Approved :: Apache Software License
212
+ Classifier: Operating System :: OS Independent
213
+ Classifier: Programming Language :: Python :: 3 :: Only
214
+ Classifier: Programming Language :: Python :: 3.11
215
+ Classifier: Programming Language :: Python :: 3.12
216
+ Classifier: Programming Language :: Python :: 3.13
217
+ Classifier: Programming Language :: Python
218
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
219
+ Classifier: Topic :: Scientific/Engineering :: Chemistry
220
+ Requires-Python: >=3.11
221
+ Description-Content-Type: text/markdown
222
+ License-File: LICENSE
223
+ Requires-Dist: accelerate>=1.10.1
224
+ Requires-Dist: datasets
225
+ Requires-Dist: exmol>=3.3.0
226
+ Requires-Dist: gradio>=5.44.0
227
+ Requires-Dist: httpx
228
+ Requires-Dist: huggingface-hub
229
+ Requires-Dist: molbloom==2.3.4
230
+ Requires-Dist: pydantic>=2
231
+ Requires-Dist: rdkit
232
+ Requires-Dist: regex
233
+ Requires-Dist: spaces>=0.40.1
234
+ Requires-Dist: tenacity
235
+ Provides-Extra: add-tokens
236
+ Requires-Dist: ipykernel; extra == "add-tokens"
237
+ Requires-Dist: ipywidgets>=8; extra == "add-tokens"
238
+ Requires-Dist: transformers>=4.49; extra == "add-tokens"
239
+ Provides-Extra: baselines
240
+ Requires-Dist: fhaviary>=0.19; extra == "baselines"
241
+ Requires-Dist: fhlmi>=0.26; extra == "baselines"
242
+ Requires-Dist: ipython; extra == "baselines"
243
+ Provides-Extra: dev
244
+ Requires-Dist: ether0[add-tokens,typing]; extra == "dev"
245
+ Requires-Dist: huggingface-hub[cli]; extra == "dev"
246
+ Requires-Dist: ipython>=8; extra == "dev"
247
+ Requires-Dist: mypy>=1.8; extra == "dev"
248
+ Requires-Dist: pre-commit>=3.4; extra == "dev"
249
+ Requires-Dist: pylint>=3; extra == "dev"
250
+ Requires-Dist: pytest; extra == "dev"
251
+ Requires-Dist: pytest-subtests; extra == "dev"
252
+ Requires-Dist: pytest-sugar; extra == "dev"
253
+ Requires-Dist: pytest-timer[colorama]; extra == "dev"
254
+ Requires-Dist: pytest-xdist; extra == "dev"
255
+ Requires-Dist: refurb>=2; extra == "dev"
256
+ Requires-Dist: typeguard; extra == "dev"
257
+ Provides-Extra: typing
258
+ Requires-Dist: types-regex; extra == "typing"
259
+ Dynamic: license-file
260
+
261
+ # ether0 Reward Model
262
+
263
+ [![GitHub](https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white)](https://github.com/Future-House/ether0)
264
+ [![arXiv](https://img.shields.io/badge/arXiv-2506.17238-b31b1b.svg)](https://arxiv.org/abs/2506.17238)
265
+ [![Project Status: Active](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)
266
+ ![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)
267
+
268
+ [![Tests](https://github.com/Future-House/ether0/actions/workflows/lint-test.yaml/badge.svg)](https://github.com/Future-House/ether0/actions)
269
+ [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
270
+ [![python](https://img.shields.io/badge/python-3.11+-blue?style=flat&logo=python&logoColor=white)](https://www.python.org)
271
+ [![Model on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-md-dark.svg)](https://huggingface.co/futurehouse/ether0)
272
+ [![Dataset on HF](https://huggingface.co/datasets/huggingface/badges/resolve/main/dataset-on-hf-md-dark.svg)](https://huggingface.co/datasets/futurehouse/ether0-benchmark)
273
+
274
+ ![ether0 logo](docs/assets/ether0_logo.svg)
275
+
276
+ _ether0: a scientific reasoning model, dataset, and reward functions for chemistry._
277
+
278
+ This repo contains the reward model for evaluating ether0 and similar models,
279
+ along with utilities for working with the verifiable rewards in
280
+ [our benchmark](https://huggingface.co/datasets/futurehouse/ether0-benchmark).
281
+
282
+ ## Overview
283
+
284
+ ether0 is a reasoning language model post-trained through a loop of:
285
+
286
+ 1. Supervised fine-tuning (SFT) on long chain-of-thought reasoning traces,
287
+ to elicit reasoning from a base model.
288
+ 2. Reinforcement learning with verifiable rewards (RLVR)
289
+ to improve reasoning on focused task groups, at their own pace.
290
+ These multitask learned models are referred to as 'specialists'.
291
+ 3. Rejection sampling to filter specialists' reasoning
292
+ for correctness and quality.
293
+ 4. SFT on the base model again to make a 'generalist' reasoning model.
294
+ 5. RLVR to recover any lost performance and push further in an all-task setting.
295
+
296
+ ![ether0 training info](docs/assets/training_info.png)
297
+
298
+ ### Repo Structure
299
+
300
+ This repo contains several packages:
301
+
302
+ - `ether0`: reward functions, `rdkit` data utilities,
303
+ dataset generation prompts, language model training prompts,
304
+ and the associated data models.
305
+ - `ether0.remotes`: server code for ether0 reward functions involving
306
+ exotic packages and/or third party models.
307
+
308
+ > [!NOTE]
309
+ > This repo does not contain training code,
310
+ > although you can find open source repositories like [NeMo-RL](https://github.com/NVIDIA/NeMo-RL)
311
+ > or [Hugging Face TRL](https://github.com/huggingface/trl)
312
+ > that can do the SFT and RL phases of training.
313
+
314
+ ### Open Weights
315
+
316
+ Please see our open-source weights on Hugging Face:
317
+ <https://huggingface.co/futurehouse/ether0>
318
+
319
+ ```python
320
+ from transformers import AutoModelForCausalLM, AutoTokenizer
321
+
322
+ model = AutoModelForCausalLM.from_pretrained("futurehouse/ether0")
323
+ tokenizer = AutoTokenizer.from_pretrained("futurehouse/ether0")
324
+ ```
325
+
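+ As a quick smoke test, you can prompt the model directly. This is a minimal
+ sketch: the example question, `max_new_tokens`, and other generation settings
+ are illustrative rather than recommended defaults.
+
+ ```python
+ messages = [
+     {"role": "user", "content": "Propose a molecule with molecular formula C6H6."}
+ ]
+ input_ids = tokenizer.apply_chat_template(
+     messages, add_generation_prompt=True, return_tensors="pt"
+ )
+ output_ids = model.generate(input_ids, max_new_tokens=512)
+ # Decode only the newly generated tokens
+ print(tokenizer.decode(output_ids[0][input_ids.shape[-1] :]))
+ ```
+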
326
+ ### Open Test Set
327
+
328
+ Please see our open-source benchmark (test set) on Hugging Face:
329
+ <https://huggingface.co/datasets/futurehouse/ether0-benchmark>
330
+
331
+ ```python
332
+ from datasets import load_dataset
333
+
334
+ test_ds = load_dataset("futurehouse/ether0-benchmark", split="test")
335
+ ```
336
+
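+ Each row pairs a `problem` with a serialized `solution` describing which reward
+ function grades it. Here is a minimal sketch of inspecting one row; the fields
+ accessed mirror the schema used in the benchmark example further below.
+
+ ```python
+ from ether0.models import RewardFunctionInfo
+
+ row = test_ds[0]
+ info = RewardFunctionInfo.model_validate(row["solution"])
+ print(row["problem_type"], info.fxn_name)
+ ```
+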
337
+ ## Usage
338
+
339
+ ### Installation
340
+
341
+ The easiest way to get started is a `pip install` from GitHub:
342
+
343
+ ```bash
344
+ pip install git+https://github.com/Future-House/ether0.git
345
+ ```
346
+
347
+ Or, if you want the full setup, clone the repo and use `uv`:
348
+
349
+ ```bash
350
+ git clone https://github.com/Future-House/ether0.git
351
+ cd ether0
352
+ uv sync
353
+ ```
354
+
355
+ ### Reward Functions
356
+
357
+ Here is a basic example of how to use the reward functions:
358
+
359
+ ```python
360
+ from ether0.rewards import valid_mol_eval
361
+
362
+ # Task: provide a valid completion of this molecule
363
+ partial_smiles = "O=C(OC1C(OC(=O)C=2C=CC=CC2)C3(O)C(C)(C)CCCC3(C)C4CC=5OC=CC5C(C)C14"
364
+
365
+ # Here are two model-proposed SMILES completions
366
+ invalid_completion_smiles = "CCC"
367
+ valid_completion_smiles = ")C=6C=CC=CC6"
368
+
369
+ # Evaluate the completions
370
+ assert not valid_mol_eval(invalid_completion_smiles, partial_smiles)
371
+ assert valid_mol_eval(valid_completion_smiles, partial_smiles)
372
+ ```
373
+
374
+ ### Visualization
375
+
376
+ If it helps, you can visualize the molecules:
377
+
378
+ ```python
379
+ from ether0.data import draw_molecule
380
+
381
+ # See above reward functions demo for where these came from
382
+ partial_smiles = "O=C(OC1C(OC(=O)C=2C=CC=CC2)C3(O)C(C)(C)CCCC3(C)C4CC=5OC=CC5C(C)C14"
383
+ invalid_completion_smiles = "CCC"
384
+ valid_completion_smiles = ")C=6C=CC=CC6"
385
+
386
+ valid_mol_text = draw_molecule(partial_smiles + valid_completion_smiles)
387
+ with open("valid_molecule.svg", "w") as f:
388
+ f.write(valid_mol_text)
389
+ ```
390
+
391
+ The output of `draw_molecule` can also be easily visualized using `IPython.display`,
392
+ or in your terminal via `chafa valid_molecule.svg`
393
+ ([chafa docs](https://hpjansson.org/chafa/)).
394
+
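+ For example, in a Jupyter notebook (a minimal sketch reusing the
+ `valid_mol_text` SVG string from above):
+
+ ```python
+ from IPython.display import SVG, display
+
+ display(SVG(valid_mol_text))
+ ```
+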
395
+ ![valid molecule](docs/assets/valid_molecule.svg)
396
+
397
+ ### Benchmark
398
+
399
+ Here is a sample baseline of
400
+ [`ether0-benchmark`](https://huggingface.co/datasets/futurehouse/ether0-benchmark)
401
+ on `gpt-4o` using [`lmi`](https://github.com/Future-House/ldp/tree/main/packages/lmi).
402
+ To install `lmi`, please install `ether0` with the `baselines` extra
403
+ (for example `uv sync --extra baselines`).
404
+
405
+ We also need to run our remote rewards server via `ether0-serve`
406
+ (for more information, see [`ether0.remotes` docs](packages/remotes/README.md)):
407
+
408
+ ```bash
409
+ ETHER0_REMOTES_API_TOKEN=abc123 ether0-serve
410
+ ```
411
+
412
+ Next, start `ipython` with the relevant environment variables set:
413
+
414
+ ```bash
415
+ ETHER0_REMOTES_API_BASE_URL="http://127.0.0.1:8000" ETHER0_REMOTES_API_TOKEN=abc123 \
416
+ ipython
417
+ ```
418
+
419
+ And run the following Python code:
420
+
421
+ ```python
422
+ import itertools
423
+ import statistics
424
+ from collections import defaultdict
425
+
426
+ from aviary.core import Message
427
+ from datasets import load_dataset
428
+ from lmi import LiteLLMModel
429
+ from tqdm.asyncio import tqdm_asyncio as asyncio
430
+
431
+ from ether0.data import get_problem_category
432
+ from ether0.model_prompts import LOOSE_XML_ANSWER_USER_PROMPT, extract_answer_loose
433
+ from ether0.models import RewardFunctionInfo
434
+ from ether0.rewards import EVAL_FUNCTIONS
435
+
436
+ # Add an LLM prompt of your own making to the dataset
437
+ test_ds = load_dataset("futurehouse/ether0-benchmark", split="test").map(
438
+ lambda x: {"prompt": "\n\n".join((LOOSE_XML_ANSWER_USER_PROMPT, x["problem"]))}
439
+ )
440
+
441
+ # Prompt to LLM
442
+ model = LiteLLMModel(name="gpt-4o")
443
+ results = await asyncio.gather(
444
+ *(model.acompletion([Message(content=row["prompt"])]) for row in test_ds),
445
+ desc="Running evaluation",
446
+ )
447
+
448
+ # Compute rewards
449
+ per_category_rewards = defaultdict(list)
450
+ for row, result in zip(test_ds, results, strict=True):
451
+ # NOTE: you can also use `ether0.rewards.accuracy_reward`,
452
+ # but we decided to go a bit "lower level" for this demo
453
+ reward_info = RewardFunctionInfo.model_validate(row["solution"])
454
+ yhat = extract_answer_loose(result[0].text)
455
+ reward = EVAL_FUNCTIONS[reward_info.fxn_name](
456
+ yhat=yhat, y=reward_info.answer_info, test=True
457
+ )
458
+ per_category_rewards[get_problem_category(reward_info.problem_type)].append(reward)
459
+
460
+ for category, rewards in sorted(per_category_rewards.items()):
461
+ print(
462
+ f"In category {category!r} of {len(rewards)} questions,"
463
+ f" average reward was {statistics.mean(rewards):.3f}."
464
+ )
465
+ accuracy = statistics.mean(itertools.chain.from_iterable(per_category_rewards.values()))
466
+ print(f"Cumulative average reward across {len(test_ds)} questions was {accuracy:.3f}.")
467
+ ```
src/ether0.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,46 @@
1
+ .gitignore
2
+ .pre-commit-config.yaml
3
+ .python-version
4
+ CITATION.cff
5
+ LICENSE
6
+ README.md
7
+ pyproject.toml
8
+ uv.lock
9
+ .github/renovate.json5
10
+ .github/workflows/lint-test.yaml
11
+ docs/adding_tokens.ipynb
12
+ docs/updated_mistral_chat_template.jinja
13
+ docs/assets/ether0_logo.svg
14
+ docs/assets/training_info.png
15
+ docs/assets/valid_molecule.svg
16
+ packages/remotes/.gitignore
17
+ packages/remotes/LICENSE
18
+ packages/remotes/README.md
19
+ packages/remotes/pyproject.toml
20
+ packages/remotes/src/ether0/server.py
21
+ packages/remotes/tests/conftest.py
22
+ packages/remotes/tests/test_clients.py
23
+ packages/remotes/tests/test_rewards.py
24
+ src/ether0/chat.py
25
+ src/ether0/clients.py
26
+ src/ether0/data.py
27
+ src/ether0/fingerprints.bloom
28
+ src/ether0/model_prompts.py
29
+ src/ether0/models.py
30
+ src/ether0/problem_prompts.py
31
+ src/ether0/py.typed
32
+ src/ether0/rewards.py
33
+ src/ether0/rings.bloom
34
+ src/ether0/utils.py
35
+ src/ether0.egg-info/PKG-INFO
36
+ src/ether0.egg-info/SOURCES.txt
37
+ src/ether0.egg-info/dependency_links.txt
38
+ src/ether0.egg-info/requires.txt
39
+ src/ether0.egg-info/top_level.txt
40
+ tests/conftest.py
41
+ tests/test_chat.py
42
+ tests/test_data.py
43
+ tests/test_model_prompts.py
44
+ tests/test_models.py
45
+ tests/test_rewards.py
46
+ tests/test_utils.py
src/ether0.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
1
+
src/ether0.egg-info/requires.txt ADDED
@@ -0,0 +1,40 @@
1
+ accelerate>=1.10.1
2
+ datasets
3
+ exmol>=3.3.0
4
+ gradio>=5.44.0
5
+ httpx
6
+ huggingface-hub
7
+ molbloom==2.3.4
8
+ pydantic>=2
9
+ rdkit
10
+ regex
11
+ spaces>=0.40.1
12
+ tenacity
13
+
14
+ [add-tokens]
15
+ ipykernel
16
+ ipywidgets>=8
17
+ transformers>=4.49
18
+
19
+ [baselines]
20
+ fhaviary>=0.19
21
+ fhlmi>=0.26
22
+ ipython
23
+
24
+ [dev]
25
+ ether0[add-tokens,typing]
26
+ huggingface-hub[cli]
27
+ ipython>=8
28
+ mypy>=1.8
29
+ pre-commit>=3.4
30
+ pylint>=3
31
+ pytest
32
+ pytest-subtests
33
+ pytest-sugar
34
+ pytest-timer[colorama]
35
+ pytest-xdist
36
+ refurb>=2
37
+ typeguard
38
+
39
+ [typing]
40
+ types-regex
src/ether0.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+ ether0
src/ether0/chat.py ADDED
@@ -0,0 +1,177 @@
1
+ import re
2
+ from collections.abc import Callable, MutableMapping
3
+ from dataclasses import dataclass, field
4
+ from functools import wraps
5
+ from itertools import starmap
6
+ from typing import Any, ParamSpec, TypeVar, cast
7
+
8
+ from ether0.model_prompts import (
9
+ ANSWER_END,
10
+ ANSWER_START,
11
+ THINK_END,
12
+ THINK_START,
13
+ ProblemPrompt,
14
+ SysPrompt,
15
+ extract_answer_loose,
16
+ )
17
+ from ether0.rewards import accuracy_reward, format_reward
18
+
19
+ P = ParamSpec("P")
20
+ R = TypeVar("R")
21
+
22
+
23
+ def wrap_reward_func(func: Callable[P, R], **wrap_kwargs: Any) -> Callable[P, R]:
24
+ @wraps(func) # needed by GRPOTrainer for logging
25
+ def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
26
+ return func(*args, **wrap_kwargs, **kwargs)
27
+
28
+ return wrapped
29
+
30
+
31
+ @dataclass
32
+ class ChatArguments:
33
+ """Arguments for making a chat conversation for SFT or RL training."""
34
+
35
+ sys_prompt: SysPrompt | None = field(
36
+ default=None,
37
+ metadata={
38
+ "help": (
39
+ "If provided, use this system prompt. If not provided, the chat"
40
+ " template may inject one."
41
+ )
42
+ },
43
+ )
44
+
45
+ problem_prompt: ProblemPrompt = field(
46
+ default=ProblemPrompt.NONE,
47
+ metadata={
48
+ "help": (
49
+ "Prompt to put before the problem in the first user message, relevant"
50
+ " for both RL or SFT. Make sure this matches between SFT and RL, so if"
51
+ " the SFT'd model wasn't passed this during SFT, don't pass this to RL."
52
+ )
53
+ },
54
+ )
55
+
56
+ reasoning: bool = field(
57
+ default=True,
58
+ metadata={
59
+ "help": (
60
+ "If True (default), it is assumed that the model's response contains"
61
+ f" reasoning enclosed in `{THINK_START}` and `{THINK_END}`."
62
+ )
63
+ },
64
+ )
65
+
66
+ def make_rl_conversation(
67
+ self, row: MutableMapping[str, str | list[str]]
68
+ ) -> dict[str, list[dict] | list[list[dict]]]:
69
+ """Format a dataset row into a chat-like conversation structure.
70
+
71
+ This will add a `prompt` key to the dataset. Unlike make_sft_conversation,
72
+ the answer will not be included.
73
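+
+ Illustrative example: with no system prompt and the default problem prompt,
+ {"problem": "Name this molecule: CCO"} maps to
+ {"prompt": [{"role": "user", "content": "Name this molecule: CCO"}]}.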
+ """
74
+ if not self.sys_prompt:
75
+ msgs: list[dict] = []
76
+ else:
77
+ msgs = [{
78
+ "role": "system",
79
+ "content": SysPrompt(self.sys_prompt).get_sys_prompt(),
80
+ }]
81
+ problem_prompt = ProblemPrompt(self.problem_prompt).get_prompt()
82
+ if problem_prompt:
83
+ problem_prompt += "\n\n"
84
+
85
+ def add_user(problem: str) -> list[dict]:
86
+ return [*msgs, {"role": "user", "content": problem_prompt + problem}]
87
+
88
+ if isinstance(row["problem"], str): # Single
89
+ all_msgs: list[dict] | list[list[dict]] = add_user(row["problem"])
90
+ else: # Batched
91
+ all_msgs = [add_user(p) for p in row["problem"]]
92
+ return {"prompt": all_msgs}
93
+
94
+ def make_sft_conversation(
95
+ self, row: MutableMapping[str, str | list[str]]
96
+ ) -> dict[str, list[dict] | list[list[dict]]]:
97
+ """Format a dataset row into a chat-like conversation structure.
98
+
99
+ This will add a `messages` key to the dataset.
100
+ """
101
+ if (
102
+ self.reasoning
103
+ and ProblemPrompt(self.problem_prompt) == ProblemPrompt.ANSWER
104
+ ):
105
+ raise ValueError(
106
+ "It does not make sense to include reasoning in the SFT traces,"
107
+ " but then only prompt about answer XML (without thoughts)."
108
+ )
109
+
110
+ def add_assistant(
111
+ raw_answer: str, thought: str, prior_msgs: list[dict]
112
+ ) -> list[dict]:
113
+ if re.search(r"<\/answer>", raw_answer):
114
+ # Remove prelude and postlude plus XML tags,
115
+ # because an OpenRouter-hosted DeepSeek R1 can give answer
116
+ # with a prelude and XML tags, but our training expects just an answer
117
+ # > The reaction involves sodium borohydride ([BH4-].[Na+]), <redacted>.
118
+ # > Under these conditions, <redacted>.
119
+ # > <answer>N1(CCOCC1)C1=CC=C(C(O))C=C1</answer>
120
+ answer = extract_answer_loose(raw_answer)
121
+ if not answer:
122
+ raise ValueError(
123
+ "Failed to extract just the answer from the answer"
124
+ f" {raw_answer!r}."
125
+ )
126
+ else:
127
+ answer = raw_answer
128
+
129
+ return [
130
+ *prior_msgs,
131
+ {
132
+ "role": "assistant",
133
+ "content": (
134
+ (f"{THINK_START}{thought}{THINK_END}" if self.reasoning else "")
135
+ + f"{ANSWER_START}{answer}{ANSWER_END}"
136
+ ),
137
+ },
138
+ ]
139
+
140
+ # The first part will be the same as the RL conversation
141
+ msgs = self.make_rl_conversation(row)["prompt"]
142
+ # Now add the answer, with optional thinking
143
+ if isinstance(row["problem"], str): # Single
144
+ all_msgs: list[dict] | list[list[dict]] = add_assistant(
145
+ cast(str, row["answer"]),
146
+ cast(str, row["thought"]),
147
+ cast(list[dict], msgs),
148
+ )
149
+ else: # Batched
150
+ all_msgs = list(
151
+ starmap(
152
+ add_assistant, zip(row["answer"], row["thought"], msgs, strict=True)
153
+ )
154
+ )
155
+ return {"messages": all_msgs}
156
+
157
+ def get_reward_funcs(
158
+ self,
159
+ format_reward_value: float = 1.0,
160
+ soft: bool = False,
161
+ test: bool = False,
162
+ good_molecule_bonus: float = 0.0,
163
+ ) -> list[Callable]:
164
+ return [
165
+ wrap_reward_func(
166
+ format_reward,
167
+ reasoning=self.reasoning,
168
+ reward=format_reward_value,
169
+ ),
170
+ wrap_reward_func(
171
+ accuracy_reward,
172
+ reasoning=self.reasoning,
173
+ soft=soft,
174
+ test=test,
175
+ good_molecule_bonus=good_molecule_bonus,
176
+ ),
177
+ ]
src/ether0/clients.py ADDED
@@ -0,0 +1,163 @@
1
+ import logging
2
+ import os
3
+ from collections import Counter
4
+ from collections.abc import Mapping
5
+ from typing import Any
6
+
7
+ import httpx
8
+ from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ BASE_URL = os.environ.get("ETHER0_REMOTES_API_BASE_URL")
13
+ HEADERS = {
14
+ "Authorization": f"Bearer {os.environ.get('ETHER0_REMOTES_API_TOKEN')}",
15
+ "Content-Type": "application/json",
16
+ }
17
+ SERVER_ERRORS_COUNTER = Counter({
18
+ "fetch_solubility": 0,
19
+ "fetch_purchasable": 0,
20
+ "fetch_forward_rxn": 0,
21
+ "fetch_rxn_info": 0,
22
+ })
23
+ THROW_500_ERROR_THRESHOLD = int(
24
+ os.environ.get("ETHER0_REMOTES_THROW_500_ERROR_THRESHOLD", "100")
25
+ )
26
+ # If our server throws a 501, we don't retry
27
+ OUR_SERVER_DONT_RETRY_CODE = httpx.codes.NOT_IMPLEMENTED.value
28
+ REMOTE_WORKER_COLD_START_TIME = 180 # sec
29
+
30
+
31
+ class RetryableServerError(Exception):
32
+ """Retryable server error."""
33
+
34
+ @classmethod
35
+ def check_raise(
36
+ cls, response: httpx.Response, kwargs: Mapping[str, Any] | None = None
37
+ ) -> None:
38
+ if (
39
+ response.is_server_error
40
+ and response.status_code != OUR_SERVER_DONT_RETRY_CODE
41
+ ):
42
+ raise cls(
43
+ f"Retryable server error with status code {response.status_code}"
44
+ f" and inputs {kwargs or {}} and response {response=}."
45
+ )
46
+
47
+
48
+ @retry(
49
+ stop=stop_after_attempt(3),
50
+ wait=wait_fixed(1),
51
+ retry=retry_if_exception_type((
52
+ httpx.ReadTimeout,
53
+ httpx.ConnectError,
54
+ RetryableServerError,
55
+ )),
56
+ )
57
+ def fetch_solubility(query_smiles: str) -> dict:
58
+ response = httpx.post(
59
+ f"{BASE_URL}/compute_solubility",
60
+ json={"smiles": query_smiles},
61
+ headers=HEADERS,
62
+ timeout=REMOTE_WORKER_COLD_START_TIME,
63
+ )
64
+
65
+ error_message = ""
66
+ if response.is_success:
67
+ result = response.json()
68
+ if "error" in result:
69
+ error_message = result["error"]
70
+ else:
71
+ solubility = result["mean"]
72
+ return {"smiles": query_smiles, "solubility": solubility}
73
+ if response.is_redirect or response.is_server_error:
74
+ # We should not have redirect responses or server errors, so let's retry these
75
+ error_message = response.text
76
+ SERVER_ERRORS_COUNTER["fetch_solubility"] += 1
77
+ if SERVER_ERRORS_COUNTER["fetch_solubility"] >= THROW_500_ERROR_THRESHOLD:
78
+ response.raise_for_status()
79
+ RetryableServerError.check_raise(
80
+ response, kwargs={"query_smiles": query_smiles}
81
+ )
82
+ if error_message:
83
+ logger.warning(
84
+ f"fetch_solubility did not succeed on {query_smiles=} with"
85
+ f" {response=} and {error_message=}."
86
+ )
87
+ return {
88
+ "smiles": query_smiles,
89
+ "error": f"API error: {response} - {error_message}",
90
+ }
91
+
92
+
93
+ @retry(
94
+ stop=stop_after_attempt(3),
95
+ wait=wait_fixed(1),
96
+ retry=retry_if_exception_type((
97
+ httpx.ReadTimeout,
98
+ httpx.ConnectError,
99
+ RetryableServerError,
100
+ )),
101
+ )
102
+ def fetch_purchasable(query_smiles_list: list[str] | str) -> dict[str, bool]:
103
+ response = httpx.post(
104
+ f"{BASE_URL}/is_purchasable",
105
+ json={"smiles": query_smiles_list},
106
+ headers=HEADERS,
107
+ timeout=REMOTE_WORKER_COLD_START_TIME,
108
+ )
109
+
110
+ if response.is_success:
111
+ return response.json()
112
+ logger.warning(
113
+ f"fetch_purchasable did not succeed on {query_smiles_list=} with"
114
+ f" {response=} and {response.text=}."
115
+ )
116
+ if response.is_redirect or response.is_server_error:
117
+ # We should not have redirect responses or server errors, so let's retry these
118
+ SERVER_ERRORS_COUNTER["fetch_purchasable"] += 1
119
+ if SERVER_ERRORS_COUNTER["fetch_purchasable"] >= THROW_500_ERROR_THRESHOLD:
120
+ response.raise_for_status()
121
+ RetryableServerError.check_raise(
122
+ response, kwargs={"query_smiles_list": query_smiles_list}
123
+ )
124
+ return {}
125
+
126
+
127
+ @retry(
128
+ stop=stop_after_attempt(3),
129
+ wait=wait_fixed(1),
130
+ retry=retry_if_exception_type((
131
+ httpx.ReadTimeout,
132
+ httpx.ConnectError,
133
+ RetryableServerError,
134
+ )),
135
+ )
136
+ def fetch_forward_rxn(query_rxn_smiles: str) -> dict[str, str]:
137
+ response = httpx.post(
138
+ f"{BASE_URL}/translate",
139
+ json={"reaction": query_rxn_smiles},
140
+ headers=HEADERS,
141
+ timeout=REMOTE_WORKER_COLD_START_TIME,
142
+ )
143
+
144
+ if response.is_success:
145
+ result = response.json()
146
+ product = result["product"]
147
+ return {"smiles": query_rxn_smiles, "product": product}
148
+ logger.warning(
149
+ f"fetch_forward_rxn did not succeed on {query_rxn_smiles=} with"
150
+ f" {response=} and {response.text=}."
151
+ )
152
+ if response.is_redirect or response.is_server_error:
153
+ # We should not have redirect responses or server errors, so let's retry these
154
+ SERVER_ERRORS_COUNTER["fetch_forward_rxn"] += 1
155
+ if SERVER_ERRORS_COUNTER["fetch_forward_rxn"] >= THROW_500_ERROR_THRESHOLD:
156
+ response.raise_for_status()
157
+ RetryableServerError.check_raise(
158
+ response, kwargs={"query_rxn_smiles": query_rxn_smiles}
159
+ )
160
+ return {
161
+ "smiles": query_rxn_smiles,
162
+ "error": f"API error: {response} - {response.text}",
163
+ }
src/ether0/data.py ADDED
@@ -0,0 +1,225 @@
1
+ import logging
2
+ import re
3
+ from collections.abc import Collection, Mapping
4
+ from pathlib import Path
5
+
6
+ from datasets import Dataset
7
+ from molbloom import BloomFilter, canon
8
+ from rdkit import Chem
9
+ from rdkit.Chem.Draw import MolDraw2D, MolDraw2DSVG # pylint: disable=no-name-in-module
10
+ from rdkit.Chem.Draw.rdMolDraw2D import MolDraw2DCairo
11
+ from rdkit.Chem.rdChemReactions import ( # pylint: disable=no-name-in-module
12
+ ReactionFromSmarts,
13
+ )
14
+ from rdkit.Chem.rdDepictor import ( # pylint: disable=no-name-in-module
15
+ Compute2DCoords,
16
+ StraightenDepiction,
17
+ )
18
+ from rdkit.Chem.rdMolDescriptors import ( # pylint: disable=no-name-in-module
19
+ GetMorganFingerprint,
20
+ )
21
+ from rdkit.Chem.rdmolfiles import MolFromSmiles # pylint: disable=no-name-in-module
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ PROBLEM_CATEGORY_TO_NICKNAME: Mapping[str, str] = {
27
+ "functional-group": "functional group",
28
+ "molecule-caption": "molecule caption",
29
+ "molecule-completion": "SMILES completion",
30
+ "molecule-formula": "elucidation",
31
+ "molecule-name": "IUPAC name",
32
+ "oracle-solubility": "solubility edit",
33
+ "property": "multiple choice",
34
+ "property-cat-brain": "BBB permeability",
35
+ "property-cat-eve": "Human receptor binding",
36
+ "property-cat-safety": "safety",
37
+ "property-cat-smell": "scent",
38
+ "property-regression-pka": "pKa",
39
+ "property-regression-ld50": "LD50",
40
+ "property-regression-adme": "ADME",
41
+ "reaction-prediction": "reaction prediction",
42
+ "retro-synthesis": "retrosynthesis",
43
+ "simple-formula": "molecular formula",
44
+ "property-regression-adme/log_hlm_clint": "log of HLM CL$_{\\text{int}}$",
45
+ "property-regression-adme/log_mdr1-mdck_er": "log of MDR1-MDCK ER",
46
+ "property-regression-adme/log_rlm_clint": "log of RLM CL$_{\\text{int}}$",
47
+ "property-regression-adme/log_solubility": "log of aqueous solubility",
48
+ }
49
+
50
+
51
+ def get_problem_type(row: Mapping[str, str]) -> str:
52
+ return row.get("problem_type") or row["type"]
53
+
54
+
55
+ def get_problem_category(problem_type: str | None) -> str:
56
+ return (problem_type or "").split("/", maxsplit=1)[0]
57
+
58
+
59
+ def get_problem_categories_from_datasets(*datasets: Dataset) -> Collection[str]:
60
+ return {
61
+ get_problem_category(pt)
62
+ for dataset in datasets
63
+ for pt in (dataset.hf_dataset if hasattr(dataset, "hf_dataset") else dataset)[
64
+ "problem_type"
65
+ ]
66
+ }
67
+
68
+
69
+ # Use this regex with findall to extract SMILES strings from text.
70
+ # Note this function currently fails on counterions e.g.
71
+ # Cc1ccc(-c2ccc3c(c2)c2ccccc2c[n+]3C)cc1.[Cl-]
72
+ SMILES_PATTERN = re.compile(
73
+ r"(?<!\w)(?:(?:Cl|Br|[BCNOPSFIC]|[cnops]|\[[^\]]+?\]|[0-9@+\-=#\\/()%])){4,}(?!\w)"
74
+ )
75
+
76
+
77
+ def make_sized_d2d(w: int = 400, h: int = 300) -> MolDraw2DCairo:
78
+ return MolDraw2DCairo(w, h)
79
+
80
+
81
+ def draw_molecule(
82
+ smiles: str, bg_opacity: float = 1.0, d2d: MolDraw2D | None = None
83
+ ) -> str:
84
+ """Draw a SMILES molecule and return the drawing string."""
85
+ mol = Chem.MolFromSmiles(smiles)
86
+ if mol is None:
87
+ raise ValueError(f"Failed to convert {smiles=} to a molecule.")
88
+ Compute2DCoords(mol)
89
+ StraightenDepiction(mol)
90
+ if d2d is None:
91
+ d2d = MolDraw2DSVG(-1, -1)
92
+ dopts = d2d.drawOptions()
93
+ dopts.useBWAtomPalette()
94
+ dopts.setBackgroundColour((*dopts.getBackgroundColour(), bg_opacity))
95
+ d2d.DrawMolecule(mol)
96
+ d2d.FinishDrawing()
97
+ return d2d.GetDrawingText()
98
+
99
+
100
+ def draw_reaction(
101
+ rxn_smiles: str, bg_opacity: float = 1.0, d2d: MolDraw2D | None = None
102
+ ) -> str:
103
+ rxn = ReactionFromSmarts(rxn_smiles, useSmiles=True)
104
+ if d2d is None:
105
+ d2d = MolDraw2DSVG(-1, -1)
106
+ dopts = d2d.drawOptions()
107
+ dopts.useBWAtomPalette()
108
+ dopts.setBackgroundColour((*dopts.getBackgroundColour(), bg_opacity))
109
+ d2d.DrawReaction(rxn)
110
+ d2d.FinishDrawing()
111
+ return d2d.GetDrawingText()
112
+
113
+
114
+ # Precompiled SMARTS patterns for protected bonds and ring atoms
115
+ _ring_db_pat = Chem.MolFromSmarts("[#6R,#16R]=[OR0,SR0,CR0,NR0]")
116
+ _ring_atom_pat = Chem.MolFromSmarts("[R]")
117
+
118
+
119
+ bloom_filters: dict[str, BloomFilter] = {}
120
+
121
+
122
+ def _get_bits(mol: Chem.Mol) -> set[str]:
123
+ """Get the fingerprint bits from a molecule."""
124
+ # the keys are the actual bits
125
+ bi: dict[int, tuple[tuple[int, int], ...]] = {}
126
+ GetMorganFingerprint(mol, 2, bitInfo=bi) # type: ignore[arg-type]
127
+ return {str(k) for k in bi}
128
+
129
+
130
+ ETHER0_DIR = Path(__file__).parent
131
+
132
+
133
+ def _get_bloom_filter(name: str) -> BloomFilter:
134
+ if name in bloom_filters:
135
+ return bloom_filters[name]
136
+ bloom_filters[name] = BloomFilter(str(ETHER0_DIR / f"{name}.bloom"))
137
+ return bloom_filters[name]
138
+
139
+
140
+ def get_ring_system(mol: Chem.Mol) -> list[str]:
141
+ """
142
+ Extracts ring systems from an RDKit molecule and returns a list of SMILES.
143
+ Bonds not in rings and not protected (e.g., ring carbonyls) are cleaved.
144
+
145
+ Source: https://github.com/PatWalters/useful_rdkit_utils/blob/edb126e3fd71870ae2d1c9440b904106e3ef97a2/useful_rdkit_utils/ring_systems.py#L13
146
+ Which has a MIT license, copyright 2021-2025 PatWalters.
147
+ """ # noqa: D205
148
+ # Copy to avoid mutating original
149
+ mol = Chem.Mol(mol)
150
+
151
+ # Tag protected bonds
152
+ for bond in mol.GetBonds():
153
+ bond.SetBoolProp("protected", False) # noqa: FBT003
154
+ for a1, a2 in mol.GetSubstructMatches(_ring_db_pat):
155
+ b = mol.GetBondBetweenAtoms(a1, a2)
156
+ b.SetBoolProp("protected", True) # noqa: FBT003
157
+
158
+ # Cleave linker bonds
159
+ cleave_idxs = [
160
+ b.GetIdx()
161
+ for b in mol.GetBonds()
162
+ if not b.IsInRing()
163
+ and not b.GetBoolProp("protected")
164
+ and b.GetBondType() == Chem.BondType.SINGLE
165
+ ]
166
+ if cleave_idxs:
167
+ frag_mol = Chem.FragmentOnBonds(mol, cleave_idxs)
168
+ Chem.SanitizeMol(frag_mol)
169
+ else:
170
+ frag_mol = mol
171
+
172
+ # Split into fragments and clean up
173
+ ring_smiles: list[str] = []
174
+ for frag in Chem.GetMolFrags(frag_mol, asMols=True):
175
+ if frag.HasSubstructMatch(_ring_atom_pat):
176
+ for atom in frag.GetAtoms():
177
+ if atom.GetAtomicNum() == 0:
178
+ atom.SetAtomicNum(1)
179
+ atom.SetIsotope(0)
180
+ frag = Chem.RemoveAllHs(frag) # noqa: PLW2901
181
+ # Fix stereo on terminal double bonds
182
+ for bd in frag.GetBonds():
183
+ if bd.GetBondType() == Chem.BondType.DOUBLE and (
184
+ 1 in {bd.GetBeginAtom().GetDegree(), bd.GetEndAtom().GetDegree()}
185
+ ):
186
+ bd.SetStereo(Chem.BondStereo.STEREONONE)
187
+ ring_smiles.append(Chem.MolToSmiles(frag))
188
+
189
+ return ring_smiles
190
+
191
+
192
+ def is_reasonable_ring_system(mol: Chem.Mol, ref_mol: Chem.Mol | None = None) -> bool:
193
+ """
194
+ Check if a molecule has a reasonable ring system.
195
+
196
+ Either the molecule has no rings, or every ring system is found in the known-rings Bloom filter.
197
+ If a reference molecule is provided, its ring systems are assumed valid.
198
+ """
199
+ bloom_filter = _get_bloom_filter("rings")
200
+ ring_systems = [canon(r) for r in get_ring_system(mol)]
201
+ # remove from consideration all rings in ref_mol, since we'll always assume they're correct
202
+ if ref_mol:
203
+ ref_ring_systems = [canon(r) for r in get_ring_system(ref_mol)]
204
+ ring_systems = [ring for ring in ring_systems if ring not in ref_ring_systems]
205
+ return all((r in bloom_filter) for r in ring_systems)
206
+
207
+
208
+ def is_reasonable_fp(mol: Chem.Mol, ref_mol: Chem.Mol | None = None) -> bool:
209
+ """
210
+ Check if a molecule has a reasonable fingerprint.
211
+
212
+ If a reference molecule is provided, its fingerprint bits are assumed valid.
213
+ """
214
+ bloom_filter = _get_bloom_filter("fingerprints")
215
+ bits: Collection[str] = _get_bits(mol)
216
+ # remove from consideration all fingerprint bits in ref_mol, since we'll always assume they're correct
217
+ if ref_mol:
218
+ ref_bits = _get_bits(ref_mol)
219
+ bits = [bit for bit in bits if bit not in ref_bits]
220
+ return all((b in bloom_filter) for b in bits)
221
+
222
+
223
+ def mol_from_smiles(smiles: str, *args, **kwargs) -> Chem.Mol | None:
224
+ """MolFromSmiles is type-hinted to always return Mol, but can return None."""
225
+ return MolFromSmiles(smiles, *args, **kwargs)
src/ether0/fingerprints.bloom ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9267f986c512f26131eb0219170c12e07ded9c1574c75233da5e3d5f3f00278
3
+ size 12500018
src/ether0/model_prompts.py ADDED
@@ -0,0 +1,142 @@
1
+ """Prompts and utilities used for training the ether0 model."""
2
+
3
+ import re
4
+ from enum import Enum, StrEnum
5
+ from typing import assert_never
6
+
7
+ # Tokens to surround reasoning and answer in XML format
8
+ THINK_START = "<|think_start|>"
9
+ THINK_END = "<|think_end|>"
10
+ ANSWER_START = "<|answer_start|>"
11
+ ANSWER_END = "<|answer_end|>"
12
+
13
+
14
+ # Keys: True (reasoning + answer), False (answer only)
15
+ # Use strict regex for ether0 models, as we can SFT or RL the models into compliance
16
+ STRICT_XML_ANSWER_SPLIT_PATTERNS: dict[bool, re.Pattern] = {
17
+ True: re.compile(
18
+ rf"^\s?{re.escape(THINK_START)}\s*([\s\S]*?)\s*{re.escape(THINK_END)}([\s\S]*?){re.escape(ANSWER_START)}\s*([\s\S]*?)\s*{re.escape(ANSWER_END)}$"
19
+ ),
20
+ False: re.compile(
21
+ rf"^\s?{re.escape(ANSWER_START)}\s*(\S[\s\S]*?)\s*{re.escape(ANSWER_END)}$"
22
+ ),
23
+ }
24
+ # Use loose regex for other models because:
25
+ # 1. <think> may be out-of-distribution from the model's training data,
26
+ # so requiring thoughts may degrade performance.
27
+ # 2. We allow baseline models to add extra whitespace and/or preceding or trailing text
28
+ # around answer XML, again to maximize performance.
29
+ # 3. Similarly, we allow models to ramble for a bit mentioning <answer>,
30
+ # and then we just keep the last <answer> XML.
31
+ # 4. We want to avoid prompt engineering tricks to get around the previous items.
32
+ LOOSE_XML_ANSWER_LOOSE_PATTERN = r"<answer>\s*(\S[\s\S]*?)\s*<\/answer>"
33
+
34
+
35
+ class XMLAnswerPrompts(StrEnum):
36
+ """Enum of prompts to use ."""
37
+
38
+ REASONING_ANSWER = (
39
+ "A conversation between User and Assistant."
40
+ " The user asks a question, and the Assistant solves it."
41
+ " The assistant first thinks about the reasoning process"
42
+ " in the mind and then provides the user with the answer."
43
+ " The reasoning process and answer are enclosed within"
44
+ f" {THINK_START} {THINK_END} and {ANSWER_START} {ANSWER_END} tags,"
45
+ " respectively, i.e.,"
46
+ f" {THINK_START} reasoning process here {THINK_END}"
47
+ f"{ANSWER_START} answer here {ANSWER_END}"
48
+ )
49
+ ANSWER_ONLY = (
50
+ "A conversation between User and Assistant."
51
+ " The user asks a question, and the Assistant solves it."
52
+ " The assistant encloses its answer within"
53
+ f" {ANSWER_START} {ANSWER_END} tags, i.e.,"
54
+ f" {ANSWER_START} answer here {ANSWER_END}"
55
+ )
56
+
57
+ @property
58
+ def pattern(self) -> re.Pattern:
59
+ return STRICT_XML_ANSWER_SPLIT_PATTERNS[
60
+ self == XMLAnswerPrompts.REASONING_ANSWER
61
+ ]
62
+
63
+
64
+ class SysPrompt(Enum): # Use Enum over StrEnum for trl.TrlParser compatibility
65
+ """Possible system prompts for making a conversation to train upon."""
66
+
67
+ SCIENTIFIC_AI = "scientific_ai"
68
+
69
+ def get_sys_prompt(self) -> str:
70
+ match self:
71
+ case SysPrompt.SCIENTIFIC_AI:
72
+ return "You are a scientific reasoning AI assistant."
73
+ case _:
74
+ assert_never(self)
75
+
76
+
77
+ class ProblemPrompt(Enum): # Use Enum over StrEnum for trl.TrlParser compatibility
78
+ """Possible user prompts for making a conversation to train upon."""
79
+
80
+ NONE = "none"
81
+ THINK_ANSWER = "think_answer"
82
+ ANSWER = "answer"
83
+
84
+ def get_prompt(self) -> str:
85
+ match self:
86
+ case ProblemPrompt.NONE:
87
+ return ""
88
+ case ProblemPrompt.THINK_ANSWER:
89
+ return XMLAnswerPrompts.REASONING_ANSWER.value
90
+ case ProblemPrompt.ANSWER:
91
+ return XMLAnswerPrompts.ANSWER_ONLY.value
92
+ case _:
93
+ assert_never(self)
94
+
95
+
96
+ def extract_thought_answer_strict(
97
+ text: str, reasoning: bool
98
+ ) -> tuple[str | None, str | None]:
99
+ """Extract thought and answer from text using a strict XML pattern."""
100
+ # Use `maxsplit=1` to enforce just one match
101
+ matches = STRICT_XML_ANSWER_SPLIT_PATTERNS[reasoning].split(text, maxsplit=1)
102
+ try:
103
+ _, *inner, suffix = matches
104
+ except (IndexError, ValueError):
105
+ return None, None # Consider no answer or 2+ answers as a failure
106
+ if reasoning:
107
+ thought, inter, answer = inner
108
+ else:
109
+ thought, inter = None, None
110
+ (answer,) = inner
111
+ if (
112
+ THINK_START not in (thought or "")
113
+ and THINK_START not in (inter or "")
114
+ and ANSWER_START not in answer
115
+ and not suffix
116
+ ):
117
+ return thought, answer or None
118
+ return None, None # Consider nested answer as a failure
119
+
120
+
121
+ LOOSE_XML_ANSWER_USER_PROMPT = (
122
+ "When answering,"
123
+ " be sure to place the final answer as"
124
+ " SMILES notation into XML tags <answer></answer>."
125
+ " An example is <answer>CCO</answer>."
126
+ )
127
+
128
+
129
+ def extract_answer_loose(text: str | None) -> str:
130
+ """
131
+ Extract the answer from text using a loose XML pattern.
132
+
133
+ SEE: LOOSE_XML_ANSWER_LOOSE_PATTERN for when to use this.
134
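+
+ Illustrative example: extract_answer_loose("I think <answer>CCO</answer>")
+ returns "CCO"; text without <answer> tags returns "".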
+ """
135
+ matches = re.findall(LOOSE_XML_ANSWER_LOOSE_PATTERN, text or "")
136
+ try:
137
+ last_answer = matches[-1] # Last answer in the response
138
+ except IndexError:
139
+ return "" # Consider no answer as a failure
140
+ if "<answer>" not in last_answer:
141
+ return last_answer
142
+ return "" # Consider nested answer as a failure
src/ether0/models.py ADDED
@@ -0,0 +1,173 @@
1
+ import re
2
+ from collections.abc import Collection
3
+ from enum import StrEnum, auto
4
+ from typing import Any
5
+
6
+ from datasets import DatasetDict
7
+ from pydantic import BaseModel, Field, model_validator
8
+
9
+ from ether0.utils import TDataset
10
+
11
+ REWARD_REASON_KEY = "reward_reason" # Sentinel key
12
+
13
+
14
+ class RewardReason(StrEnum):
15
+ FORMAT_FAILED = auto()
16
+ INVALID_MOL = auto()
17
+ # Catch-all for invalid values that aren't a molecule or a reaction
18
+ INVALID_VALUE = auto()
19
+
20
+ # Oracle regression values
21
+ WRONG_NUMERICAL_ANSWER = auto()
22
+
23
+ # Reaction/retro-synthesis failures
24
+ INVALID_RXN = auto()
25
+ WRONG_PRODUCT = auto()
26
+ PRODUCT_IS_REACTANT = auto()
27
+ NOT_PURCHASABLE = auto()
28
+
29
+ # Molecule formula/functional group failures
30
+ WRONG_FORMULA = auto()
31
+ FAILED_CONSTRAINT = auto()
32
+
33
+ # Unreasonable molecules
34
+ FAILED_REOS_CHECK = auto()
35
+ FAILED_RING_CHECK = auto()
36
+ FAILED_COUNTERION_CHECK = auto()
37
+
38
+ # Really this is a bug, but we don't want to blow up training if a
39
+ # few bad examples slip through.
40
+ INVALID_GROUND_TRUTH = auto()
41
+
42
+ # Failover reason if we have an exception during a reward function.
43
+ # NOTE: not using "failed" or "error" since an unhandled exception
44
+ # may be something else
45
+ REWARD_FUNCTION_EXCEPTION = auto()
46
+
47
+ # These are automatically added if no other reason is given
48
+ WRONG_ANSWER = auto()
49
+ RIGHT_ANSWER = auto()
50
+
51
+ def set_reason(self, metadata: dict | None) -> None:
52
+ if metadata is not None:
53
+ metadata[REWARD_REASON_KEY] = self.value
54
+
55
+ @classmethod
56
+ def set_default_reason(cls, reward: float, metadata: dict | None) -> None:
57
+ if metadata is not None and REWARD_REASON_KEY not in metadata:
58
+ (cls.RIGHT_ANSWER if reward >= 1.0 else cls.WRONG_ANSWER).set_reason(
59
+ metadata
60
+ )
61
+
62
+
63
+ SOLUTION_DELIMITER = "!:!"
64
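+ # Serialized RewardFunctionInfo strings have the form (see the validator below):
+ # "<fxn_name>!:!<answer_info>!:!<problem_type>"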
+
65
+
66
+ class RewardFunctionInfo(BaseModel):
67
+ """Metadata used by a reward function to evaluate a solution."""
68
+
69
+ fxn_name: str = Field(description="Name of the reward function to use.")
70
+ answer_info: str = Field(
71
+ description="Serialized metadata used by the reward function."
72
+ )
73
+ problem_type: str = Field(description="Problem type, for reference.")
74
+
75
+ @model_validator(mode="before")
76
+ @classmethod
77
+ def check_card_number_not_present(cls, data: Any) -> Any:
78
+ if isinstance(data, str):
79
+ # Deserialize from a string 3-tuple
80
+ fn, ainfo, pt = data.split(SOLUTION_DELIMITER, maxsplit=2)
81
+ return {"fxn_name": fn, "answer_info": ainfo, "problem_type": pt}
82
+ return data
83
+
84
+
85
+ class QAExample(BaseModel):
86
+ """Question-answer example with reward function info."""
87
+
88
+ id: str = Field(description="Unique identifier for this example.")
89
+ problem: str = Field(description="Problem to solve.")
90
+ problem_type: str = Field(description="Problem type, for reference or filtering.")
91
+ solution: RewardFunctionInfo = Field(
92
+ description="Metadata for the reward function."
93
+ )
94
+ ideal: str | None = Field(
95
+ description=(
96
+ "An optional ideal answer. This could be a candidate SMILES, a log10 of"
97
+ " water solubility, or None if having an ideal does not make sense."
98
+ )
99
+ )
100
+ unformatted: str | None = Field(
101
+ description=(
102
+ "Optional raw data used to generate the problem, used for traceability."
103
+ )
104
+ )
105
+
106
+
107
+ def filter_problem_types(
108
+ dataset: TDataset, problem_types: str | Collection[str] | None
109
+ ) -> TDataset:
110
+ """Filter a dataset by problem types.
111
+
112
+ Args:
113
+ dataset: The dataset to filter. Can be a single Dataset or a DatasetDict.
114
+ problem_types: A string or collection of strings specifying the problem
115
+ types to filter by.
116
+ - If None, the original dataset is returned.
117
+ - If a string or a collection of strings:
118
+ - Strings starting with "re:" are treated as regex patterns.
119
+ If a regex filter is provided, then it must be the only filter.
120
+ - Strings starting with "!" are treated as problem types to exclude.
121
+ - Other strings are treated as exact problem types to include.
122
+ - Mixing inclusion and exclusion rules (e.g. ["type_a", "!type_b"])
123
+ is not allowed.
124
+
125
+ Returns:
126
+ The filtered dataset.
127
+ """
128
+ if problem_types is None:
129
+ return dataset
130
+ if isinstance(problem_types, str): # Assume single problem type as a string
131
+ problem_types = [problem_types]
132
+ problem_types = {pt.strip() for pt in problem_types}
133
+
134
+ columns = (
135
+ next(iter(dataset.values())) if isinstance(dataset, DatasetDict) else dataset
136
+ ).column_names
137
+ # ether0-benchmark uses 'problem_type'; some variants may use 'type'
138
+ type_col = "problem_type" if "problem_type" in columns else "type"
139
+
140
+ if any(pt.startswith("re:") for pt in problem_types):
141
+ # A regex was passed in
142
+ if len(problem_types) != 1:
143
+ raise ValueError(
144
+ "If filtering by regex, only one filter is supported,"
145
+ f" passed {problem_types}."
146
+ )
147
+ regex = re.compile(next(iter(problem_types)).removeprefix("re:"))
148
+
149
+ def filter_func(x):
150
+ return regex.match(x[type_col]) is not None
151
+
152
+ else:
153
+ # Treat as exact string match
154
+ valid_problem_types = {pt for pt in problem_types if not pt.startswith("!")}
155
+ invalid_problem_types = {
156
+ pt.removeprefix("!") for pt in problem_types if pt.startswith("!")
157
+ }
158
+ if valid_problem_types:
159
+ if invalid_problem_types:
160
+ raise ValueError(
161
+ "Cannot specify both problem types to keep and to exclude,"
162
+ f" passed {problem_types}."
163
+ )
164
+
165
+ def filter_func(x):
166
+ return x[type_col] in valid_problem_types
167
+
168
+ else:
169
+
170
+ def filter_func(x):
171
+ return x[type_col] not in invalid_problem_types
172
+
173
+ return dataset.filter(filter_func, desc="Filtering problem types")
src/ether0/problem_prompts.py ADDED
@@ -0,0 +1,196 @@
1
+ """Prompt templates used for problems in the ether0 dataset."""
2
+
3
+ # ruff: noqa: E501, W505
4
+
5
+ NAME_IUPAC_PROMPTS = [
6
+ "What are the SMILES from the IUPAC name: {iupac}?",
7
+ "Could you please tell me the SMILES representation for {iupac}?",
8
+ "I have {iupac}. What would its SMILES be?",
9
+ "Provide the SMILES string for the molecule named {iupac}.",
10
+ "Convert this IUPAC name into a SMILES format: {iupac}.",
11
+ "Give me the SMILES notation for the compound {iupac}.",
12
+ "What SMILES corresponds to the IUPAC chemical name {iupac}?",
13
+ "How can I represent {iupac} as a SMILES string?",
14
+ "Generate the SMILES structure for this compound: {iupac}.",
15
+ "If the molecule is called {iupac}, what's the SMILES representation?",
16
+ ]
17
+
18
+ NAME_SMILES_PROMPTS = [
19
+ "What is the IUPAC name of this molecule: {smiles}?",
20
+ "Could you please tell me the IUPAC name for the compound represented by the SMILES string: {smiles}?",
21
+ "I have a molecule here with the SMILES notation {smiles}. What would its IUPAC name be?",
22
+ "I'm working with a chemical compound, and its SMILES representation is {smiles}. Can you help me determine its IUPAC name?",
23
+ "What is the correct IUPAC nomenclature for a molecule with the SMILES code {smiles}?",
24
+ "I'm trying to identify a compound. Its SMILES string is {smiles}. What's its IUPAC name?",
25
+ "If I provide you with the SMILES string of a molecule, which is {smiles}, could you generate its IUPAC name for me?",
26
+ "I've encountered a molecule in my research, denoted by the SMILES {smiles}. I'd appreciate it if you could tell me its IUPAC name.",
27
+ "Can you derive the IUPAC name from this SMILES representation: {smiles}?",
28
+ "For a compound with the structural representation given by the SMILES {smiles}, what is the corresponding IUPAC name?",
29
+ ]
30
+
31
+
32
+ REACTION_PROMPTS = [
33
+ "What is the product of this reaction? {rxn_smiles}",
34
+ "If {rxn_smiles} are allowed to react, what would the resulting molecule be in SMILES format?",
35
+ "Can you predict the outcome of this chemical reaction? {rxn_smiles}",
36
+ "I have a reaction scheme here: {rxn_smiles}. What will be generated as the product?",
37
+ "Assuming the reaction {rxn_smiles} goes to completion, what is the SMILES representation of the product?",
38
+ "In this reaction: {rxn_smiles}, what compound is formed?",
39
+ "Given the reactants and conditions in {rxn_smiles}, what is the expected product?",
40
+ "After performing the reaction {rxn_smiles}, what would I obtain as the main product?",
41
+ "If I mix these reactants under the specified conditions, {rxn_smiles}, what is the SMILES of the product?",
42
+ "Please provide the SMILES string of the product for this reaction: {rxn_smiles}",
43
+ ]
44
+
45
+ NAME_REACTION_PROMPTS = [
46
+ "What is the name of this reaction?\n{rxn_smiles}",
47
+ "I have a reaction here with {rxn_smiles}. Can you tell me its commonly used name?",
48
+ "Given the reaction represented by {rxn_smiles}, what is the standard name associated with it?",
49
+ "If I were to describe this reaction, {rxn_smiles}, in a textbook, what reaction name would I use?",
50
+ "I'm trying to identify this reaction: {rxn_smiles}. What is its well-known name?",
51
+ "The reaction {rxn_smiles} is taking place. What's the name of this type of transformation?",
52
+ "I came across this reaction pathway: {rxn_smiles}. Do you know the name it generally goes by?",
53
+ "In the context of organic chemistry, what is the established name for the reaction shown here: {rxn_smiles}?",
54
+ "I'm writing a lab report and need to name this reaction, {rxn_smiles}. What should I call it?",
55
+ "Can you identify the name of the reaction that follows this scheme: {rxn_smiles}?",
56
+ ]
57
+
58
+ COMPLETE_MOL_PROMPTS = [
59
+ "I have a partial molecule represented by the SMILES string {smiles}. What is a valid completion of this molecule, providing only the remaining characters in SMILES format?",
60
+ "Given the incomplete SMILES fragment {smiles}, can you suggest a realistic ending to complete the molecule? Please provide only the additional SMILES characters needed.",
61
+ "I'm working with a molecule that's partially described as {smiles}. What sequence of SMILES characters would you add to make it a complete, valid molecule?",
62
+ "The beginning of a molecule's SMILES representation is {smiles}. How would you finish this SMILES string to represent a viable chemical compound? Only provide the continuation of the SMILES.",
63
+ "Imagine you need to complete the SMILES string {smiles}. What's a plausible way to extend it to form a complete molecule, expressed as the remaining SMILES characters?",
64
+ "If I give you the partial SMILES {smiles}, what's a reasonable way to finish it off to create a valid molecule? Respond with just the additional SMILES characters.",
65
+ "I'm trying to construct a molecule, and I have the start of its SMILES: {smiles}. Could you provide a completion for it, ensuring the final molecule is realistic? Only give me the rest of the SMILES string.",
66
+ "Here's a fragment of a SMILES string: {smiles}. What would be a chemically sound way to complete it? Respond with the missing portion of the SMILES representation.",
67
+ "Suppose you have the incomplete molecular structure {smiles} in SMILES. How would you complete it to represent a real molecule, adding only the necessary SMILES characters?",
68
+ "I have an unfinished molecule represented by the SMILES fragment {smiles}. Can you help me complete it by suggesting the remaining SMILES characters needed to make it a valid chemical structure?",
69
+ ]
70
+
71
+ MOL_FORMULA_PROMPTS = [
72
+ "A compound with formula {formula} was isolated from {source}. What is a plausible SMILES for it given this organism?",
73
+ "{source} makes a compound with this formula: {formula}. What SMILES structure might correspond to it?",
74
+ "In {source}, I found a substance with formula {formula}. What biosynthetically plausible SMILES might this represent?",
75
+ "Analysis of {source} revealed a compound ({formula}). What SMILES structure aligns with this organism's metabolism?",
76
+ "The organism {source} contains a compound with formula {formula}. What's a likely SMILES based on its biochemistry?",
77
+ "A {formula} compound was extracted from {source}. Based on this organism, what's a probable SMILES structure?",
78
+ "What SMILES could have the formula {formula} and be isolated from {source}?",
79
+ "What would be a biologically relevant SMILES for a {formula} compound isolated from the organism {source}?",
80
+ "The organism {source} produced a compound with formula {formula}, what SMILES structure makes biosynthetic sense?",
81
+ "A {formula} metabolite from {source} was identified. What's a biologically plausible compound for this (as SMILES)?",
82
+ ]
83
+
84
+ FUNCTIONAL_GROUP_PROMPTS = [
85
+ "Propose a compound with molecular formula {formula} that contains the following functional groups: {functional_group}.",
86
+ "Suggest a SMILES structure for a molecule with formula {formula} and the following functional groups: {functional_group}.",
87
+ "Given that a compound has formula {formula}, propose SMILES for one that also contains these groups: {functional_group}",
88
+ "Provide a reasonable SMILES for a chemical with molecular formula {formula} and these groups: {functional_group}.",
89
+ "Generate a SMILES representation for a molecule containing groups: {functional_group}. It should also have formula {formula}.",
90
+ "Identify a plausible SMILES for a chemical compound with formula {formula} containing these groups: {functional_group}.",
91
+ ]
92
+
93
+ PROPERTY_TRIPLET_PROMPTS = [
94
+ "I have a molecule {smiles1} with a {property} of {value1}. Which of these similar molecules will most likely {change} this property?\n{options}",
95
+ "Given a molecule ({smiles1}) having a {property} of {value1}, select the modified molecule below that would {change} this property significantly:\n{options}",
96
+ "Molecule {smiles1} currently exhibits {property} of {value1}. Which modifications from the list below would effectively {change} it?\n{options}",
97
+ "If molecule {smiles1} has a {property} value of {value1}, which of the following options would best {change} this property?\n{options}",
98
+ "Considering {smiles1} has a measured {property} of {value1}, which candidate modification listed would most effectively {change} this property?\n{options}",
99
+ "Molecule {smiles1} demonstrates a {property} of {value1}. Which similar molecule below is best suited to {change} this characteristic?\n{options}",
100
+ "Given molecule {smiles1} with {property} at {value1}, identify which molecule among the following options would {change} it most effectively:\n{options}",
101
+ "Starting from molecule {smiles1}, which shows a {property} of {value1}, choose the structural change below that would notably {change} this property:\n{options}",
102
+ "The molecule {smiles1} has a {property} of {value1}. Which molecule listed would optimally {change} this value?\n{options}",
103
+ "Given a {property} of {value1} for molecule {smiles1}, pick the best molecule from below to {change} this property:\n{options}",
104
+ ]
105
+
106
+ # I have a molecule {smiles1} which is not a blood-brain barrier penetrating. Which of these similar molecules will most likely have this property?\n{options}",
107
+ PROPERTY_TRIPLET_PROMPTS_CAT = [
108
+ "I have a molecule {smiles1} which {rel} {property}. Which of these similar molecules will most likely {irel} this property?\n{options}",
109
+ "Given molecule {smiles1} that {rel} {property}, which molecule below is likely to {irel} this property?\n{options}",
110
+ "Molecule {smiles1} currently {rel} {property}. Choose from these similar molecules the one most likely to {irel} this property:\n{options}",
111
+ "Considering {smiles1} {rel} {property}, identify which of the following candidates will most likely {irel} the characteristic:\n{options}",
112
+ "Given that molecule {smiles1} {rel} {property}, select from below the molecule most expected to {irel} this characteristic:\n{options}",
113
+ "Starting from molecule {smiles1} which {rel} {property}, determine which listed molecule is most likely to {irel} this property:\n{options}",
114
+ "If molecule {smiles1} {rel} {property}, which of these related structures will most probably {irel} that property?\n{options}",
115
+ "Given molecule {smiles1} {rel} {property}, select the similar molecule listed below most likely to {irel} this property:\n{options}",
116
+ ]
117
+
118
+ # Which of the following options likely is a blood-brain barrier penetrating molecule?
119
+ # Which of the following options likely is not a blood-brain barrier penetrating molecule?
120
+ # Which of the following molecules is likely to not be blood-brain barrier penetrating?
121
+ PROPERTY_PROMPTS_CAT = [
122
+ "Which of the following options likely is{rel} a {property} molecule?\n{options}",
123
+ "Which of the following molecules is likely to{rel} be {property}?\n{options}",
124
+ "Identify the molecule below that likely is{rel} a {property} molecule:\n{options}",
125
+ "From the list below, select the molecule most likely to{rel} be {property}:\n{options}",
126
+ "Choose the molecule from the options below that most probably is{rel} {property}:\n{options}",
127
+ "Among the following, which molecule likely is{rel} considered {property}?\n{options}",
128
+ "Select the molecule below most expected to{rel} have {property} properties:\n{options}",
129
+ "From these molecules, identify the one most likely to{rel} possess {property}:\n{options}",
130
+ "Which candidate below most probably is{rel} classified as a {property} molecule?\n{options}",
131
+ ]
132
+
133
+
134
+ PROPERTY_PROMPTS = [
135
+ "Which of the following molecules likely has a {property} of {value}?\n{options}",
136
+ "Identify the molecule below expected to have a {property} around {value}:\n{options}",
137
+ "From these options, select the molecule most likely exhibiting {property} of {value}:\n{options}",
138
+ "Determine which of the following molecules likely shows a {property} near {value}:\n{options}",
139
+ "Choose the molecule that would most plausibly have a {property} of {value} from the list below:\n{options}",
140
+ "Among the following, which molecule is predicted to have a {property} close to {value}?\n{options}",
141
+ "Given the choices below, pick the molecule most likely to possess a {property} of {value}:\n{options}",
142
+ "Select the molecule from these candidates that probably has a {property} of {value}:\n{options}",
143
+ "Which molecule listed here is most likely to have a {property} approximately equal to {value}?\n{options}",
144
+ "Identify which of the following molecules will most likely have a {property} of {value}:\n{options}",
145
+ ]
146
+
147
+ RETRO_PROMPTS = [
148
+ "Propose a 1-step synthesis for the molecule {smiles} using likely purchasable reactants.",
149
+ "Given the molecule {smiles}, suggest a 1-step synthesis using commercially available starting materials.",
150
+ "What is a plausible 1-step reaction for the molecule {smiles} using common reactants?",
151
+ "Suggest a commercially feasible one-step route to synthesize {smiles}.",
152
+ "Outline a practical single-step synthetic method to prepare the molecule {smiles}.",
153
+ "Design a straightforward 1-step reaction scheme for synthesizing {smiles} using commercially available reagents.",
154
+ "Identify a likely accessible precursor and reaction for a single-step synthesis of {smiles}.",
155
+ "Provide a realistic single-step synthetic pathway to obtain {smiles} from common chemicals.",
156
+ "Propose a viable one-step synthetic route toward the molecule {smiles} starting from purchasable precursors.",
157
+ "Suggest one plausible reaction step to generate {smiles} using standard, commercially sourced reactants.",
158
+ ]
159
+
160
+ ORACLE_SOLUBILITY_PROMPTS = {
161
+ "tanimoto": [
162
+ "Propose a small change to {smiles} to {direction} its solubility by about 1 logS.",
163
+ "Suggest a minimal structural modification to {smiles} that would {direction} its solubility by approximately 1 logS unit.",
164
+ "What minor alteration could be made to {smiles} to {direction} its solubility by roughly 1 logS?",
165
+ "Design a small molecular change to {smiles} that would {direction} its solubility by about 1 logS while maintaining overall similarity.",
166
+ "Identify a small structural adjustment to {smiles} that would {direction} its aqueous solubility by approximately 1 logS unit.",
167
+ ],
168
+ "scaffold": [
169
+ "Change {smiles} to {direction} its solubility by about 1 logS, but keep its scaffold",
170
+ "Modify {smiles} to {direction} its solubility by approximately 1 logS while preserving the core scaffold structure.",
171
+ "Suggest alterations to {smiles} that would {direction} its solubility by about 1 logS unit without changing the molecular scaffold.",
172
+ "How could {smiles} be transformed to {direction} its solubility by roughly 1 logS while maintaining its scaffold?",
173
+ "Design a derivative of {smiles} with {direction}d solubility (by about 1 logS) that retains the same molecular scaffold.",
174
+ ],
175
+ "groups": [
176
+ "Adjust {smiles} to {direction} its solubility by about 1 logS, but keep the following groups intact: {pretty_groups}",
177
+ "Modify {smiles} to achieve a {direction} in solubility of approximately 1 logS while preserving these functional groups: {pretty_groups}",
178
+ "How would you alter {smiles} to {direction} its solubility by about 1 logS unit without changing these key groups: {pretty_groups}?",
179
+ "Suggest structural changes to {smiles} that would {direction} its solubility by roughly 1 logS while maintaining these groups: {pretty_groups}",
180
+ "Design a variant of {smiles} with {direction}d solubility (by about 1 logS) that retains all of these intact functional groups: {pretty_groups}",
181
+ ],
182
+ }
183
+
184
+ SMILES_FROM_FORMULA_PROMPTS = [
185
+ "Propose a molecule that has the following formula: {formula}.",
186
+ "Generate a SMILES representation for a compound with the formula {formula}.",
187
+ "What is a plausible SMILES for a compound with the formula {formula}?",
188
+ "Given the formula {formula}, can you suggest a possible SMILES structure?",
189
+ "Create a SMILES representation for a molecule that corresponds to the formula {formula}.",
190
+ "Identify a potential SMILES for a compound with the molecular formula {formula}.",
191
+ "What SMILES structure could correspond to the formula {formula}?",
192
+ "Generate a plausible SMILES for a compound with the formula {formula}.",
193
+ "Given the formula {formula}, what would be a reasonable SMILES representation?",
194
+ "Propose a SMILES structure for a molecule with the formula {formula}.",
195
+ "Generate a SMILES representation for a compound with the formula {formula}.",
196
+ ]
src/ether0/py.typed ADDED
File without changes
src/ether0/rewards.py ADDED
@@ -0,0 +1,753 @@
1
+ import ast
2
+ import logging
3
+ import operator
4
+ import re
5
+ import unicodedata
6
+ from collections.abc import Iterable, Mapping, Sequence
7
+ from typing import Protocol, cast
8
+
9
+ import exmol
10
+ from pydantic import JsonValue
11
+ from rdkit import Chem, DataStructs
12
+ from rdkit.Chem import GetMolFrags, SanitizeMol # pylint: disable=no-name-in-module
13
+ from rdkit.Chem.rdMolDescriptors import ( # pylint: disable=no-name-in-module
14
+ CalcMolFormula,
15
+ GetMorganFingerprintAsBitVect,
16
+ )
17
+ from rdkit.Chem.rdmolfiles import MolToSmiles # pylint: disable=no-name-in-module
18
+ from rdkit.rdBase import BlockLogs
19
+
20
+ from ether0.clients import fetch_forward_rxn, fetch_purchasable, fetch_solubility
21
+ from ether0.data import is_reasonable_fp, is_reasonable_ring_system, mol_from_smiles
22
+ from ether0.model_prompts import extract_answer_loose, extract_thought_answer_strict
23
+ from ether0.models import RewardFunctionInfo, RewardReason
24
+
25
+ block = BlockLogs()
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ class RewardEvalFn(Protocol):
31
+ def __call__(
32
+ self,
33
+ yhat: str,
34
+ y: str,
35
+ soft: bool = False,
36
+ test: bool = False,
37
+ metadata: dict[str, JsonValue] | None = None,
38
+ ) -> float: ...
39
+
40
+
41
+ def formula_diff(formula1: str, formula2: str) -> float:
42
+ """Calculate l2 norm between two molecular formulas."""
43
+ # important = elements we care about in organic chem
44
+ important_elements = {"C", "H", "O", "N", "F", "Cl", "Br", "P", "S"}
45
+ pattern = re.compile(r"([A-Z][a-z]?)(\d*)")
46
+ counts1 = dict.fromkeys(important_elements, 0)
47
+ counts2 = dict.fromkeys(important_elements, 0)
48
+ for m in pattern.finditer(formula1):
49
+ element = m.group(1)
50
+ count = int(m.group(2)) if m.group(2) else 1
51
+ if element in important_elements:
52
+ counts1[element] += count
53
+ for m in pattern.finditer(formula2):
54
+ element = m.group(1)
55
+ count = int(m.group(2)) if m.group(2) else 1
56
+ if element in important_elements:
57
+ counts2[element] += count
58
+ d2 = sum((counts1[k] - counts2[k]) ** 2 for k in important_elements)
59
+ return d2**0.5
60
+
61
+
62
+ def format_reward(
63
+ completions,
64
+ reasoning: bool,
65
+ reward: float = 1.0,
66
+ **kwargs, # noqa: ARG001
67
+ ) -> list[float]:
68
+ """Reward function that checks if the completion has a specific format."""
69
+ if isinstance(completions[0], list):
70
+ completion_contents = [completion[0]["content"] for completion in completions]
71
+ else:
72
+ completion_contents = completions
73
+ # Note we check `answer is not None` since empty answer still counts as valid
74
+ # formatting.
75
+ return [
76
+ reward if answer is not None else 0.0
77
+ for answer in (
78
+ extract_thought_answer_strict(c, reasoning)[1] for c in completion_contents
79
+ )
80
+ ]
81
+
82
+
83
+ SUPERSCRIPT_PATTERN = re.compile(r"\^{([\d,]+)}")
84
+ ITALICS_PATTERN = re.compile(r"{([a-zA-Z])}")
85
+ # parentheses that aren't nested and don't contain hyphens
86
+ # https://regex101.com/r/6c8smX/1
87
+ USELESS_PARENTHESES = re.compile(r"([-\d])[\(\[{]([A-Za-z0-9]+)[\]\)}]-")
88
+
89
+
90
+ def normalize_iupac(s: str) -> str:
91
+ """Normalize an IUPAC name by removing special formatting and characters.
92
+
93
+ Args:
94
+ s: Original IUPAC name.
95
+
96
+ Returns:
97
+ A normalized IUPAC name without special characters.
98
+ """
99
+ s = s.strip().casefold()
100
+ # replace ^{n} with ^(n)
101
+ s = SUPERSCRIPT_PATTERN.sub(r"^(\1)", s)
102
+ # remove italicized pattern - but don't match ^{1,5} (by avoiding matching commas)
103
+ s = ITALICS_PATTERN.sub(r"\1", s)
104
+ # remove garbage
105
+ s = s.replace("$", "").replace("~", "") # noqa: FURB184
106
+ # remove parentheses that aren't nested and don't contain hyphens
107
+ s = USELESS_PARENTHESES.sub(r"\1\2-", s)
108
+ # ok to ignore carets and hyphens for comparison
109
+ return s.replace("^", "").replace(" ", "-") # noqa: FURB184
110
+
111
+
112
+ def normalize_unicodes(s: str) -> str:
113
+ """Normalize all Unicode dashes/hyphens to regular hyphen.
114
+
115
+ Args:
116
+ s: Input string with potential Unicode characters.
117
+
118
+ Returns:
119
+ Unicode-normalized string.
120
+ """
121
+ s = unicodedata.normalize("NFKC", s)
122
+ s = "".join("-" if unicodedata.category(c) in {"Pd", "Po"} else c for c in s)
123
+ return s.replace("-", "") # minus sign # noqa: FURB184
124
+
125
+
126
+ def is_reasonable_molecule(
127
+ mol: Chem.Mol,
128
+ metadata: dict[str, JsonValue] | None,
129
+ test: bool, # noqa: ARG001
130
+ ref_mol: Chem.Mol | None = None,
131
+ ) -> bool:
132
+ """Returns True if the molecule passes heuristics for being a reasonable molecule."""
133
+ # always check valence
134
+ try:
135
+ SanitizeMol(mol)
136
+ except Exception:
137
+ RewardReason.INVALID_MOL.set_reason(metadata)
138
+ return False
139
+
140
+ # We have decided that the convention will be to check the
141
+ # same at test time and train time.
142
+
143
+ # determine if we have counter-ions (which is fine), but we want to
144
+ # evaluate the largest molecule only. We only consider single molecules
145
+ # or single molecules + a counterion as valid responses
146
+ sorted_frags = sorted( # sort by size
147
+ GetMolFrags(mol, asMols=True), key=lambda m: m.GetNumAtoms(), reverse=True
148
+ )
149
+ if len(sorted_frags) > 2: # noqa: PLR2004
150
+ # not a counter-ion
151
+ RewardReason.FAILED_COUNTERION_CHECK.set_reason(metadata)
152
+ return False
153
+ if len(sorted_frags) == 2: # noqa: PLR2004
154
+ # If 2, assume the smaller fragment is the counter-ion, and double check it's small
155
+ cmol = sorted_frags[1]
156
+ if cmol.GetNumHeavyAtoms() > 5: # noqa: PLR2004
157
+ RewardReason.FAILED_COUNTERION_CHECK.set_reason(metadata)
158
+ return False
159
+
160
+ mol = sorted_frags[0]
161
+
162
+ ring_status = is_reasonable_ring_system(mol, ref_mol)
163
+
164
+ if not ring_status:
165
+ RewardReason.FAILED_RING_CHECK.set_reason(metadata)
166
+ return False
167
+
168
+ fp_ok = is_reasonable_fp(mol, ref_mol)
169
+ if not fp_ok:
170
+ RewardReason.FAILED_REOS_CHECK.set_reason(metadata)
171
+ return False
172
+ return True
173
+
174
+
175
+ FULL_SMILES_KEY = "full_smiles"
176
+
177
+
178
+ def set_full_smiles(smiles: str, metadata: dict[str, JsonValue] | None) -> None:
179
+ if metadata is not None:
180
+ metadata[FULL_SMILES_KEY] = smiles
181
+
182
+
183
+ BAD_SMARTS_PATTERNS = [
184
+ "[#16]-[#16]-[#16]", # More than a thiol bond
185
+ "[#8]~[#8]", # Peroxides
186
+ "[#7]-[NH2]", # Hydrazines
187
+ "[#7]-[NH3]", # weird charged amine
188
+ "[#7]~[#7]~[#7]", # 3 or more amines
189
+ "[NX2](=[OX1])[O;$([X2]),$([X1-])]", # Nitrite
190
+ "[SX2][NX2]=[OX1]", # Thionitrite
191
+ "[$([NX3](=[OX1])(=[OX1])[O;$([X2]),$([X1-])]),$([NX3+]([OX1-])(=[OX1])[O;$([X2]),$([X1-])])]", # Nitrate # noqa: E501
192
+ "[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]", # Nitro
193
+ "[NX2](=[OX1])[!#7;!#8]", # Nitroso
194
+ "[CX4]" + ("-[CX4]" * 6), # Long chain of carbons (7 or more)
195
+ ]
196
+
197
+
198
+ def contains_bad_substruct(mol: Chem.Mol) -> bool:
199
+ return any(
200
+ mol.HasSubstructMatch(Chem.MolFromSmarts(pat)) for pat in BAD_SMARTS_PATTERNS
201
+ )
202
+
203
+
204
+ def rxn_eval(
205
+ yhat: str,
206
+ y: str,
207
+ soft: bool = False, # noqa: ARG001
208
+ test: bool = False, # noqa: ARG001
209
+ metadata: dict[str, JsonValue] | None = None, # noqa: ARG001
210
+ ) -> float:
211
+ """Returns 1.0 if strings match (case-insensitive), otherwise 0.0."""
212
+ # some normalization for IUPAC names - shouldn't affect others
213
+
214
+ if normalize_iupac(yhat) == normalize_iupac(y):
215
+ return 1.0
216
+
217
+ # If that fails (would return 0), try normalizing further
218
+ return (
219
+ 1.0
220
+ if normalize_unicodes(normalize_iupac(yhat))
221
+ == normalize_unicodes(normalize_iupac(y))
222
+ else 0.0
223
+ )
224
+
225
+
226
+ def str_eval(
227
+ yhat: str,
228
+ y: str,
229
+ soft: bool = False, # noqa: ARG001
230
+ test: bool = False, # noqa: ARG001
231
+ metadata: dict[str, JsonValue] | None = None,
232
+ ) -> float:
233
+ """Returns 1.0 if strings match (case-insensitive), otherwise 0.0."""
234
+ set_full_smiles(yhat, metadata)
235
+ return 1.0 if normalize_iupac(yhat) == normalize_iupac(y) else 0.0
236
+
237
+
238
+ def valid_mol_eval(
239
+ yhat: str,
240
+ y: str,
241
+ soft: bool = False, # noqa: ARG001
242
+ test: bool = False,
243
+ metadata: dict[str, JsonValue] | None = None,
244
+ ) -> float:
245
+ """Validate if yhat is a valid SMILES string, when appended to y.
246
+
247
+ Args:
248
+ yhat: Model-predicted SMILES string or partial completion.
249
+ y: Base SMILES string (e.g. "O=C1CCC2=CC=C(O)C(OC)=C2C#CCC2=CC3=C4") to append
250
+ yhat and check validity.
251
+ test: unused
252
+ soft: unused
253
+ metadata: optional metadata dictionary
254
+
255
+ Returns:
256
+ 1.0 if `y + yhat` is a valid SMILES string, 0.0 otherwise.
257
+ """
258
+ if not yhat:
259
+ RewardReason.INVALID_MOL.set_reason(metadata)
260
+ return 0.0
261
+
262
+ # First attempt yhat alone (assuming full SMILES), then try y+yhat (assuming
263
+ # partial) if that fails
264
+ for smiles in (yhat, y + yhat):
265
+ if not smiles.startswith(y):
266
+ # only accept a completion that starts with the given partial SMILES
267
+ continue
268
+ try:
269
+ mol = mol_from_smiles(smiles)
270
+ except Exception:
271
+ logger.exception(
272
+ f"Failed to construct molecule from SMILES string {yhat!r}."
273
+ )
274
+ continue
275
+ if mol is not None:
276
+ set_full_smiles(smiles, metadata)
277
+ if not is_reasonable_molecule(mol, metadata, test):
278
+ return 0.0
279
+ return 1.0
280
+
281
+ # Nothing worked - mark as invalid
282
+ RewardReason.INVALID_MOL.set_reason(metadata)
283
+ return 0.0
284
+
285
+
286
+ SMOOTH_THRESHOLD_TANIMOTO_SIMILARITY = 0.7 # close enough
287
+
288
+
289
+ def tanimoto_similarity(
290
+ m1: Chem.Mol | None, m2: Chem.Mol | None, atom_threshold: float = 10.0
291
+ ) -> float:
292
+ """Calculate Tanimoto similarity between two molecules.
293
+
294
+ The `atom_threshold` parameter is a relative fraction (e.g., `0.2` for 20%)
295
+ that sets a threshold for degenerate cases when the fingerprints are similar,
296
+ but there are many more atoms in one molecule.
297
+
298
+ Default is 10.0, which corresponds to a 1000% difference and has no practical effect.
299
+ """
300
+ if m1 is None or m2 is None:
301
+ return 0.0
302
+ fp1 = GetMorganFingerprintAsBitVect(m1, 2)
303
+ fp2 = GetMorganFingerprintAsBitVect(m2, 2)
304
+
305
+ # heavy atom threshold
306
+ atoms1 = m1.GetNumHeavyAtoms()
307
+ atoms2 = m2.GetNumHeavyAtoms()
308
+ if (denom := max(atoms1, atoms2)) > 0:
309
+ # Do not apply the atom diff check if there are no heavy atoms.
310
+ # This is always safe, since the only way to avoid
311
+ # this block is if m1=m2=H2, which would pass anyway.
312
+ atom_diff = abs(atoms1 - atoms2) / denom
313
+ if atom_diff > atom_threshold:
314
+ return 0.0
315
+ return DataStructs.TanimotoSimilarity(fp1, fp2)
316
+
317
+
318
+ def exact_mol_match(m1: Chem.Mol, m2: Chem.Mol) -> float:
319
+ s1 = MolToSmiles(m1, canonical=True, isomericSmiles=True) # noqa: FURB120
320
+ s2 = MolToSmiles(m2, canonical=True, isomericSmiles=True) # noqa: FURB120
321
+ return 1.0 if s1 == s2 else 0.0
322
+
323
+
324
+ def get_largest_mol(smiles: str) -> Chem.Mol | None:
325
+ parts = smiles.split(".")
326
+ # Filter out small fragments (removes counter-ions) and invalid SMILES
327
+ mols = [
328
+ mol_from_smiles(p)
329
+ for p in parts
330
+ if (len(p) > 3 and mol_from_smiles(p) is not None) # noqa: PLR2004
331
+ ]
332
+ if not mols:
333
+ return None
334
+ mols_atoms = []
335
+ for mol in mols:
336
+ n_atoms = None if mol is None else mol.GetNumAtoms()
337
+ if n_atoms is None:
338
+ raise NotImplementedError(f"Didn't handle {mol=} having None atoms.")
339
+ mols_atoms.append((mol, n_atoms))
340
+ return max(mols_atoms, key=operator.itemgetter(1))[0]
341
+
342
+
343
+ def product_eval(
344
+ yhat: str,
345
+ y: str,
346
+ soft: bool = False,
347
+ test: bool = False, # noqa: ARG001
348
+ metadata: dict[str, JsonValue] | None = None,
349
+ ) -> float:
350
+ """Computes the Tanimoto similarity of the largest fragments from two SMILES (if soft) or exact match (if not soft).
351
+
352
+ Returns:
353
+ Reward in [0, 1].
354
+ """ # noqa: E501,W505
355
+ m1 = get_largest_mol(yhat)
356
+ m2 = get_largest_mol(y)
357
+
358
+ if m1 is None:
359
+ RewardReason.INVALID_MOL.set_reason(metadata)
360
+ return 0.0
361
+ if m2 is None:
362
+ RewardReason.INVALID_GROUND_TRUTH.set_reason(metadata)
363
+ logger.warning(f"Invalid ground truth molecule {y!r}.")
364
+ return 0.0
365
+
366
+ # don't use yhat directly since it may have multiple molecules
367
+ set_full_smiles(MolToSmiles(m1), metadata)
368
+
369
+ if soft:
370
+ return tanimoto_similarity(m1, m2)
371
+
372
+ return exact_mol_match(m1, m2) # exact match for non-soft mode
373
+
374
+
375
+ def caption_eval(
376
+ yhat: str,
377
+ y: str,
378
+ soft: bool = False,
379
+ test: bool = False,
380
+ metadata: dict[str, JsonValue] | None = None,
381
+ ) -> float:
382
+ """Currently forwards to product_eval, but also stores Tanimoto in metadata."""
383
+ if metadata is not None:
384
+ m1 = get_largest_mol(yhat)
385
+ m2 = get_largest_mol(y)
386
+ metadata["tanimoto"] = (
387
+ tanimoto_similarity(m1, m2) if (m1 is not None and m2 is not None) else 0.0
388
+ )
389
+ return product_eval(yhat, y, soft, test, metadata)
390
+
391
+
392
+ def formula_eval(
393
+ yhat: str,
394
+ y: str,
395
+ soft: bool = False,
396
+ test: bool = False,
397
+ metadata: dict[str, JsonValue] | None = None,
398
+ ) -> float:
399
+ """Check correct formula and Tanimoto similarity, giving a reward in [0, 1]."""
400
+ set_full_smiles(yhat, metadata)
401
+ mhat = mol_from_smiles(yhat)
402
+ m = mol_from_smiles(y)
403
+ if mhat is None:
404
+ RewardReason.INVALID_MOL.set_reason(metadata)
405
+ return 0.0
406
+ if m is None:
407
+ RewardReason.INVALID_GROUND_TRUTH.set_reason(metadata)
408
+ logger.warning(f"Invalid ground truth molecule {y!r}.")
409
+ return 0.0
410
+
411
+ fhat = CalcMolFormula(mhat)
412
+ f = CalcMolFormula(m)
413
+ if fhat != f:
414
+ RewardReason.WRONG_FORMULA.set_reason(metadata)
415
+ return 0.0
416
+
417
+ if not is_reasonable_molecule(mhat, metadata, test, ref_mol=m):
418
+ return 0.0
419
+
420
+ return (
421
+ 1.0
422
+ if tanimoto_similarity(mhat, m) >= SMOOTH_THRESHOLD_TANIMOTO_SIMILARITY
423
+ # Give partial credit if soft=True and we got the right formula
424
+ else (0.5 if soft else 0.0)
425
+ )
426
+
427
+
428
+ def functional_group_eval(
429
+ yhat: str,
430
+ y: str,
431
+ soft: bool = False,
432
+ test: bool = False,
433
+ metadata: dict[str, JsonValue] | None = None,
434
+ ) -> float:
435
+ """Match functional group and formula, giving a reward in [0, 1]."""
436
+ set_full_smiles(yhat, metadata)
437
+ mhat = mol_from_smiles(yhat)
438
+ if mhat is None:
439
+ RewardReason.INVALID_MOL.set_reason(metadata)
440
+ return 0.0
441
+
442
+ y_args: tuple[str, list[str]] = ast.literal_eval(y)
443
+
444
+ formula = y_args[0]
445
+ groups = {g.lower() for g in y_args[1]}
446
+
447
+ fhat = CalcMolFormula(mhat)
448
+ if fhat != formula:
449
+ RewardReason.WRONG_FORMULA.set_reason(metadata)
450
+ return 0.0
451
+
452
+ groupshat: set[str] = {
453
+ f.lower() for f in exmol.get_functional_groups(mhat, return_all=True)
454
+ }
455
+
456
+ if not is_reasonable_molecule(mhat, metadata, test):
457
+ return 0.0
458
+
459
+ return (
460
+ 1.0
461
+ if groups <= groupshat
462
+ # Give partial credit if soft=True and we got the right formula
463
+ else (0.5 if soft else 0.0)
464
+ )
465
+
466
+
467
+ def oracle_solubility_eval(
468
+ yhat: str,
469
+ y: str,
470
+ soft: bool = False, # noqa: ARG001
471
+ test: bool = False,
472
+ metadata: dict[str, JsonValue] | None = None,
473
+ ) -> float:
474
+ """Evaluate solubility prediction using remote, giving a reward in [0, 1]."""
475
+ set_full_smiles(yhat, metadata)
476
+ # we only want single molecules
477
+ if "." in yhat:
478
+ return 0.0
479
+ mhat = mol_from_smiles(yhat)
480
+ if mhat is None:
481
+ RewardReason.INVALID_MOL.set_reason(metadata)
482
+ return 0.0
483
+
484
+ y_args: tuple[str, str | list[str], float | str, str] = ast.literal_eval(y)
485
+ constraint_type, constraint_data = y_args[:2]
486
+ target = float(y_args[2])
487
+ # Unused: direction = y_args[3] # noqa: ERA001
488
+
489
+ ref_mol: Chem.Mol | None = None
490
+
491
+ # first check constraint
492
+ if constraint_type == "scaffold":
493
+ ref_mol = mol_from_smiles(cast(str, constraint_data))
494
+ if ref_mol is None:
495
+ raise NotImplementedError(
496
+ f"Didn't handle when {constraint_data=} is invalid."
497
+ )
498
+ if not mhat.HasSubstructMatch(ref_mol):
499
+ RewardReason.FAILED_CONSTRAINT.set_reason(metadata)
500
+ return 0.0
501
+ elif constraint_type == "groups":
502
+ groups = [g.lower() for g in exmol.get_functional_groups(mhat, return_all=True)]
503
+ if not any(group.lower() in groups for group in constraint_data):
504
+ RewardReason.FAILED_CONSTRAINT.set_reason(metadata)
505
+ return 0.0
506
+ elif constraint_type == "tanimoto":
507
+ ref_mol = mol_from_smiles(cast(str, constraint_data))
508
+ if (
509
+ tanimoto_similarity(mhat, ref_mol, atom_threshold=0.2)
510
+ < SMOOTH_THRESHOLD_TANIMOTO_SIMILARITY
511
+ ):
512
+ RewardReason.FAILED_CONSTRAINT.set_reason(metadata)
513
+ return 0.0
514
+ else:
515
+ raise ValueError(f"Unknown constraint type: {constraint_type}")
516
+
517
+ if not is_reasonable_molecule(mhat, metadata, test, ref_mol=ref_mol):
518
+ return 0.0
519
+
520
+ # make sure we hit the target
521
+ result = fetch_solubility(yhat)
522
+ if "solubility" in result:
523
+ solubility = result["solubility"]
524
+ delta = solubility - target
525
+ # hard coded to typical solubility accuracies
526
+ # we subtract 0.01 because some questions ask for
527
+ # 0.5 change and we don't want restatements to
528
+ # be matches
529
+ if abs(delta) > (0.5 - 0.01):
530
+ RewardReason.WRONG_NUMERICAL_ANSWER.set_reason(metadata)
531
+ return 0.0
532
+ return 1.0
533
+ RewardReason.INVALID_MOL.set_reason(metadata)
534
+ return 0.0
535
+
536
+
537
+ def oracle_rxn_eval(
538
+ yhat: str,
539
+ y: str,
540
+ soft: bool = False,
541
+ test: bool = False, # noqa: ARG001
542
+ metadata: dict[str, JsonValue] | None = None,
543
+ ) -> float:
544
+ """Evaluate forward reaction prediction using remote, giving a reward in [0, 1]."""
545
+ if ">" not in yhat or "." not in yhat:
546
+ RewardReason.INVALID_RXN.set_reason(metadata)
547
+ return 0.0
548
+
549
+ # make sure there are not more than two angle brackets
550
+ if yhat.count(">") > 2: # noqa: PLR2004
551
+ RewardReason.INVALID_RXN.set_reason(metadata)
552
+ return 0.0
553
+
554
+ # ok now do real check on regex after heuristic checks
555
+ # adapted partly from https://gist.github.com/lsauer/1312860/264ae813c2bd2c27a769d261c8c6b38da34e22fb
556
+ # https://regex101.com/r/9bdE6H/1
557
+ # basically SMILES_THINGS>SMILES_THINGS | empty>
558
+ if not re.match(
559
+ r"^[^J][a-z0-9@+\-\[\]\(\)\\\/%=#$\.]{6,}>[a-z0-9@+\-\[\]\(\)\\\/%=#$\.]{0,}>",
560
+ yhat,
561
+ re.IGNORECASE, # lower = aromatic, which we're fine matching
562
+ ):
563
+ RewardReason.INVALID_RXN.set_reason(metadata)
564
+ return 0.0
565
+
566
+ ymol = mol_from_smiles(y)
567
+ if ymol is None:
568
+ RewardReason.INVALID_GROUND_TRUTH.set_reason(metadata)
569
+ logger.warning(f"Invalid ground truth molecule {y!r}.")
570
+ return 0.0
571
+
572
+ reactant_smi = yhat.split(">")[0].split(".")
573
+ reactants = [mol_from_smiles(r) for r in reactant_smi]
574
+ if not all(x is not None for x in reactants):
575
+ RewardReason.INVALID_MOL.set_reason(metadata)
576
+ return 0.0
577
+
578
+ reagents = [mol_from_smiles(r) for r in yhat.split(">")[1].split(".") if r.strip()]
579
+ if not all(x is not None for x in reagents):
580
+ RewardReason.INVALID_MOL.set_reason(metadata)
581
+ return 0.0
582
+
583
+ # check products, if present, contain the desired product
584
+ products = [mol_from_smiles(r) for r in yhat.split(">")[2].split(".") if r.strip()]
585
+ # notice we pass if there are no products
586
+ if products:
587
+ if not all(x is not None for x in products):
588
+ RewardReason.INVALID_MOL.set_reason(metadata)
589
+ return 0.0
590
+ if not any(exact_mol_match(m, ymol) == 1.0 for m in products): # type: ignore[arg-type]
591
+ RewardReason.INVALID_RXN.set_reason(metadata)
592
+ return 0.0
593
+
594
+ # Disallow products in the reactants or reagents
595
+ if any(exact_mol_match(m, ymol) == 1.0 for m in (reactants + reagents)): # type: ignore[arg-type]
596
+ RewardReason.PRODUCT_IS_REACTANT.set_reason(metadata)
597
+ return 0.0
598
+
599
+ # check that the reactants are purchasable
600
+
601
+ def is_small_so_probably_purchasable(smi: str) -> bool:
602
+ mol = mol_from_smiles(smi)
603
+ # Molecules with <= 4 heavy atoms are likely purchasable,
604
+ # since they include solvents and counterions
605
+ return mol is not None and mol.GetNumHeavyAtoms() <= 4 # noqa: PLR2004
606
+
607
+ purchasable_results = fetch_purchasable(reactant_smi)
608
+ if not all(
609
+ purchasable_results.get(r, False) or is_small_so_probably_purchasable(r)
610
+ for r in reactant_smi
611
+ ):
612
+ RewardReason.NOT_PURCHASABLE.set_reason(metadata)
613
+ return 0.0
614
+
615
+ result = fetch_forward_rxn(yhat)
616
+ if "product" in result:
617
+ product = result["product"]
618
+ pmol = mol_from_smiles(product)
619
+ if pmol is None:
620
+ RewardReason.INVALID_MOL.set_reason(metadata)
621
+ return 0.0
622
+ if soft:
623
+ return tanimoto_similarity(pmol, ymol)
624
+ if exact_mol_match(pmol, ymol) == 1.0:
625
+ return 1.0
626
+ RewardReason.WRONG_PRODUCT.set_reason(metadata)
627
+ return 0.0
628
+ RewardReason.INVALID_RXN.set_reason(metadata)
629
+ return 0.0
630
+
631
+
632
+ def valid_molecule_eval(
633
+ yhat: str,
634
+ y: str, # noqa: ARG001
635
+ soft: bool = False, # noqa: ARG001
636
+ test: bool = False, # noqa: ARG001
637
+ metadata: dict[str, JsonValue] | None = None, # noqa: ARG001
638
+ ) -> float:
639
+ """Evaluate if yhat is valid molecule."""
640
+ if not yhat:
641
+ return 0.0
642
+ mol = mol_from_smiles(yhat, sanitize=True)
643
+ return float(mol is not None)
644
+
645
+
646
+ EVAL_FUNCTIONS: Mapping[str, RewardEvalFn] = {
647
+ "str_eval": str_eval,
648
+ "valid_mol_eval": valid_mol_eval,
649
+ "caption_eval": caption_eval,
650
+ "product_eval": product_eval,
651
+ "rxn_eval": rxn_eval,
652
+ "formula_eval": formula_eval,
653
+ "functional_group_eval": functional_group_eval,
654
+ "sol_eval": oracle_solubility_eval,
655
+ "rxn_forward": oracle_rxn_eval,
656
+ "should_not_answer_eval": str_eval,
657
+ "should_answer_eval": valid_molecule_eval,
658
+ }
659
+
660
+
661
+ # These correspond to open-ended problems that do not have a
662
+ # unique molecule as answer.
663
+ APPLY_GOOD_MOLECULE_CHECK: set[str] = {
664
+ "valid_mol_eval",
665
+ "formula_eval",
666
+ "functional_group_eval",
667
+ "sol_eval",
668
+ }
669
+
670
+
671
+ def accuracy_reward(
672
+ completions: Sequence[list[Mapping[str, str]]] | Sequence[str],
673
+ solution: Iterable[str],
674
+ reasoning: bool,
675
+ metadata: list[dict[str, JsonValue]] | None = None,
676
+ soft: bool = False,
677
+ test: bool = False,
678
+ good_molecule_bonus: float = 0.0,
679
+ **kwargs, # noqa: ARG001
680
+ ) -> list[float]:
681
+ """Reward function that checks if the completion is the same as the ground truth."""
682
+ if isinstance(completions[0], list):
683
+ messages = cast(Sequence[list[Mapping[str, str]]], completions)
684
+ contents: Sequence[str] = [m[0]["content"] for m in messages]
685
+ else:
686
+ contents = cast(Sequence[str], completions)
687
+ if soft and test:
688
+ raise ValueError("Soft mode is not supported for test time accuracy reward.")
689
+ rewards = []
690
+ problem_types: list[str | None] = []
691
+
692
+ if metadata is None:
693
+ # Create empty metadata that we can use internally in this function
694
+ metadata = [{} for _ in contents]
695
+ else:
696
+ if metadata:
697
+ raise NotImplementedError(f"Received non-empty metadata {metadata}.")
698
+ metadata.extend([{} for _ in contents])
699
+
700
+ for content, info, meta in zip(contents, solution, metadata, strict=True):
701
+ reward = 0.0
702
+ reward_info = RewardFunctionInfo.model_validate(info)
703
+ fxn_name, answer_info, problem_type = tuple(reward_info.model_dump().values())
704
+ try:
705
+ answer: str | None = (
706
+ extract_answer_loose(content)
707
+ if test
708
+ else extract_thought_answer_strict(content, reasoning=reasoning)[1]
709
+ )
710
+ if answer is not None:
711
+ # During test time, see if full SMILES string was given as input
712
+ if problem_type == "valid_mol_eval" and test:
713
+ # If we're testing, we only allow full SMILES strings
714
+ reward = EVAL_FUNCTIONS[fxn_name](
715
+ answer, answer_info, test=test, metadata=meta
716
+ )
717
+ else:
718
+ reward = EVAL_FUNCTIONS[fxn_name](
719
+ answer, answer_info, soft=soft, metadata=meta
720
+ )
721
+ RewardReason.set_default_reason(reward, meta)
722
+
723
+ if reward == 1.0 and fxn_name in APPLY_GOOD_MOLECULE_CHECK:
724
+ if FULL_SMILES_KEY not in meta:
725
+ raise ValueError( # noqa: TRY301
726
+ f"Missing full SMILES key in metadata {meta}"
727
+ f" with reward function {fxn_name}."
728
+ )
729
+ full_smiles = cast(str, meta[FULL_SMILES_KEY])
730
+ mol = mol_from_smiles(full_smiles)
731
+ if mol is None:
732
+ raise ValueError( # noqa: TRY301
733
+ f"Invalid full SMILES {full_smiles}"
734
+ f" with reward function {fxn_name}."
735
+ )
736
+ meta["is_good_molecule"] = not contains_bad_substruct(mol)
737
+ if meta["is_good_molecule"]:
738
+ reward += good_molecule_bonus
739
+
740
+ else:
741
+ RewardReason.FORMAT_FAILED.set_reason(meta)
742
+ rewards.append(reward)
743
+ problem_types.append(problem_type)
744
+ except Exception:
745
+ logger.exception(
746
+ f"Unhandled exception in {fxn_name=} for {problem_type=}"
747
+ f" with inputs {content=}, {answer_info=} {soft=}, and {test=}."
748
+ )
749
+ RewardReason.REWARD_FUNCTION_EXCEPTION.set_reason(meta)
750
+ rewards.append(reward)
751
+ problem_types.append(None)
752
+
753
+ return rewards
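As a rough usage sketch (not part of the upload), the evaluators above can be called directly or looked up by name through EVAL_FUNCTIONS; the example SMILES below are assumptions, and the oracle-backed evaluators (solubility, forward reaction) additionally need the remote services wired up in ether0.clients.

```python
# Illustrative sketch only: exercises the locally computable evaluators.
from ether0.data import mol_from_smiles
from ether0.rewards import EVAL_FUNCTIONS, formula_eval, tanimoto_similarity

# Same molecule written two ways: formulas match, so this is expected to score 1.0,
# subject to the packaged "reasonable molecule" heuristics.
metadata: dict = {}
print(formula_eval(yhat="CCCCO", y="OCCCC", metadata=metadata), metadata)

# The registry maps problem-type names to the same callables used by accuracy_reward.
print(EVAL_FUNCTIONS["product_eval"]("Oc1ccccc1", "c1ccccc1O", soft=True))

# Tanimoto similarity between two RDKit molecules (Morgan fingerprints, radius 2).
print(tanimoto_similarity(mol_from_smiles("CCCCO"), mol_from_smiles("CCCCN")))
```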
src/ether0/rings.bloom ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b1b81b5dc5e59d2068742aafb0c217067e951260f56527998e902d37ac177a25
3
+ size 12500018
src/ether0/utils.py ADDED
@@ -0,0 +1,108 @@
1
+ import logging
2
+ import re
3
+ from http import HTTPStatus
4
+ from typing import TypeVar
5
+
6
+ import regex
7
+ from datasets import Dataset, DatasetDict, Version, load_dataset
8
+ from datasets.exceptions import DatasetNotFoundError
9
+ from huggingface_hub.errors import HfHubHTTPError
10
+ from tenacity import (
11
+ before_sleep_log,
12
+ retry,
13
+ retry_if_exception,
14
+ stop_after_attempt,
15
+ wait_fixed,
16
+ )
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # pylint: disable-next=invalid-name
21
+ TDataset = TypeVar("TDataset", bound=Dataset | DatasetDict)
22
+
23
+
24
+ @retry(
25
+ retry=retry_if_exception(
26
+ lambda x: (
27
+ (
28
+ # On 2/11/2025 James kept seeing on the g3 server cluster:
29
+ # > huggingface_hub.errors.HfHubHTTPError: 504 Server Error: Gateway Time-out for
30
+ # > url: https://huggingface.co/api/datasets/org/repo/paths-info/abc123
31
+ # And on 3/14 James saw this on the g3 server cluster:
32
+ # > huggingface_hub.errors.HfHubHTTPError: 502 Server Error: Bad Gateway for
33
+ # > url: https://huggingface.co/api/datasets/org/repo/paths-info/abc123
34
+ isinstance(x, HfHubHTTPError)
35
+ and x.response.status_code
36
+ in {HTTPStatus.BAD_GATEWAY.value, HTTPStatus.GATEWAY_TIMEOUT.value}
37
+ )
38
+ # On 4/14/2025 James kept seeing on the g5 server cluster:
39
+ # > datasets.exceptions.DatasetNotFoundError:
40
+ # > Dataset 'org/repo' doesn't exist on the Hub or cannot be accessed.
41
+ or isinstance(x, DatasetNotFoundError)
42
+ )
43
+ ),
44
+ before_sleep=before_sleep_log(logger, logging.WARNING),
45
+ stop=stop_after_attempt(5),
46
+ wait=wait_fixed(5),
47
+ )
48
+ def load_dataset_retrying(
49
+ path: str,
50
+ revision: str | Version | None = None,
51
+ ) -> DatasetDict:
52
+ return load_dataset(path, revision=revision)
53
+
54
+
55
+ # SEE: https://www.compart.com/en/unicode/block/U+2070 for subscript letters
56
+ invalid_chars_regex = re.compile(
57
+ r"[^A-Za-z0-9Α-Ωα-ωₐₑₒₓₔₕₖₗₘₙₚₛₜ⁰¹²³⁴⁵⁶⁷⁸⁹₀₁₂₃₄₅₆₇₈₉×\s!\"#$%&±⁻'´ʻ‘’ʼ“”()*+⁺,\-—–‐‑‒―−⏤./:;«<≤=≡≈≆≥>›»⇌?@[\\\]^_`{|}~←⇐→➔➞➛➡➟➧➭⇨⇒⇛⟺⇔⟶…]" # noqa: RUF001
58
+ )
59
+ invalid_languages_regex = regex.compile(
60
+ r"[\p{"
61
+ + r"}\p{".join({
62
+ # SEE: https://jrgraphix.net/r/Unicode/
63
+ "Arabic",
64
+ "Armenian",
65
+ "Bengali",
66
+ "Braille_Patterns",
67
+ "Cyrillic",
68
+ "Devanagari",
69
+ "Ethiopic",
70
+ "Georgian",
71
+ "Gujarati",
72
+ "Gurmukhi",
73
+ "Han",
74
+ "Hangul",
75
+ "Hebrew",
76
+ "Hiragana",
77
+ "Kannada",
78
+ "Katakana",
79
+ "Khmer",
80
+ "Latin_Extended_A",
81
+ "Latin_Extended_Additional",
82
+ "Latin_Extended_B",
83
+ "Malayalam",
84
+ "Myanmar",
85
+ "Syriac",
86
+ "Tamil",
87
+ "Telugu",
88
+ "Thaana",
89
+ "Thai",
90
+ "Tifinagh",
91
+ })
92
+ + r"}]"
93
+ )
94
+
95
+
96
+ def contains_invalid(
97
+ text: str, chars: bool = False, languages: bool = False, threshold: int = 1
98
+ ) -> tuple[bool, list[str]]:
99
+ """Check if the text contains invalid characters or languages."""
100
+ if chars:
101
+ matches = invalid_chars_regex.findall(text)
102
+ if len(matches) >= threshold:
103
+ return True, sorted(matches)
104
+ if languages:
105
+ matches = invalid_languages_regex.findall(text)
106
+ if len(matches) >= threshold:
107
+ return True, sorted(matches)
108
+ return False, []
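A small usage sketch for these helpers (illustrative, with hypothetical inputs); load_dataset_retrying simply wraps datasets.load_dataset with retries for transient Hub errors.

```python
# Illustrative sketch only.
from ether0.utils import contains_invalid, load_dataset_retrying

# Flag completions that contain characters from unexpected scripts (here, Cyrillic).
has_invalid, offending = contains_invalid("Этанол is CCO", languages=True)
print(has_invalid, offending)

# Retry-wrapped dataset load; requires network access to the Hugging Face Hub.
# dataset = load_dataset_retrying("futurehouse/ether0-benchmark")
```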
tests/conftest.py ADDED
@@ -0,0 +1,12 @@
1
+ import pathlib
2
+
3
+ import pytest
4
+ from datasets import Dataset, load_dataset
5
+
6
+ TESTS_DIR = pathlib.Path(__file__).parent
7
+ REPO_ROOT_DIR = TESTS_DIR.parent
8
+
9
+
10
+ @pytest.fixture(name="ether0_benchmark_test", scope="session")
11
+ def fixture_ether0_benchmark_test() -> Dataset:
12
+ return load_dataset("futurehouse/ether0-benchmark", split="test")
tests/test_chat.py ADDED
@@ -0,0 +1,61 @@
1
+ import pytest
2
+
3
+ from ether0.chat import ChatArguments
4
+ from ether0.model_prompts import ProblemPrompt, SysPrompt
5
+
6
+
7
+ class TestChatArguments:
8
+ @pytest.mark.parametrize(
9
+ ("args", "row", "expected"),
10
+ [
11
+ (
12
+ ChatArguments(problem_prompt=ProblemPrompt.NONE),
13
+ {"problem": "stub problem"},
14
+ {"prompt": [{"content": "stub problem", "role": "user"}]},
15
+ ),
16
+ (
17
+ ChatArguments(problem_prompt=ProblemPrompt.NONE),
18
+ {"problem": ["stub problem", "stub problem 2"]},
19
+ {
20
+ "prompt": [
21
+ [{"content": "stub problem", "role": "user"}],
22
+ [{"content": "stub problem 2", "role": "user"}],
23
+ ]
24
+ },
25
+ ),
26
+ (
27
+ ChatArguments(
28
+ sys_prompt=SysPrompt.SCIENTIFIC_AI,
29
+ problem_prompt=ProblemPrompt.THINK_ANSWER,
30
+ ),
31
+ {"problem": "stub problem"},
32
+ {
33
+ "prompt": [
34
+ {
35
+ "role": "system",
36
+ "content": "You are a scientific reasoning AI assistant.",
37
+ },
38
+ {
39
+ "role": "user",
40
+ "content": (
41
+ "A conversation between User and Assistant. The user"
42
+ " asks a question, and the Assistant solves it. The"
43
+ " assistant first thinks about the reasoning process in"
44
+ " the mind and then provides the user with the answer."
45
+ " The reasoning process and answer are enclosed within"
46
+ " <|think_start|> <|think_end|>"
47
+ " and <|answer_start|> <|answer_end|> tags,"
48
+ " respectively, i.e., <|think_start|> reasoning process here"
49
+ " <|think_end|><|answer_start|> answer here <|answer_end|>"
50
+ "\n\nstub problem"
51
+ ),
52
+ },
53
+ ]
54
+ },
55
+ ),
56
+ ],
57
+ )
58
+ def test_rl_conversation(
59
+ self, args: ChatArguments, row: dict, expected: dict
60
+ ) -> None:
61
+ assert args.make_rl_conversation(row) == expected
tests/test_data.py ADDED
@@ -0,0 +1,158 @@
1
+ from collections.abc import Collection
2
+
3
+ import pytest
4
+ from datasets import Dataset
5
+ from pydantic import JsonValue
6
+
7
+ from ether0.data import (
8
+ SMILES_PATTERN,
9
+ get_problem_categories_from_datasets,
10
+ get_problem_category,
11
+ )
12
+ from ether0.models import RewardFunctionInfo
13
+ from ether0.rewards import EVAL_FUNCTIONS
14
+
15
+
16
+ def test_get_problem_categories_from_datasets(ether0_benchmark_test: Dataset) -> None:
17
+ assert get_problem_categories_from_datasets(ether0_benchmark_test) == {
18
+ "functional-group",
19
+ "molecule-completion",
20
+ "molecule-formula",
21
+ "molecule-name",
22
+ "oracle-solubility",
23
+ "property-cat-eve",
24
+ "property-cat-safety",
25
+ "property-cat-smell",
26
+ "property-regression-adme",
27
+ "property-regression-ld50",
28
+ "property-regression-pka",
29
+ "reaction-prediction",
30
+ "retro-synthesis",
31
+ "simple-formula",
32
+ }
33
+
34
+
35
+ UNVERIFIABLE_PROBLEM_CATEGORY_PREFIXES_TO_EXCLUDE: Collection[str] = {
36
+ "oracle-solubility", # 'ideal' is not actually an answer
37
+ "retro-synthesis", # 'ideal' is not actually an answer
38
+ }
39
+
40
+
41
+ def test_evals(ether0_benchmark_test: Dataset) -> None:
42
+ failures = []
43
+ for row in ether0_benchmark_test:
44
+ reward_info = RewardFunctionInfo.model_validate(row["solution"])
45
+ fxn_name, answer_info, problem_type = tuple(reward_info.model_dump().values())
46
+ problem_category = get_problem_category(problem_type)
47
+ if (
48
+ problem_category in UNVERIFIABLE_PROBLEM_CATEGORY_PREFIXES_TO_EXCLUDE
49
+ or problem_category
50
+ == "molecule-completion" # Molc had no 'ideal's when this was made
51
+ ):
52
+ continue
53
+ metadata: dict[str, JsonValue] = {}
54
+ try:
55
+ if problem_category.startswith("property"):
56
+ yhat = answer_info
57
+ else:
58
+ assert row["ideal"]
59
+ yhat = row["ideal"]
60
+ assert (
61
+ EVAL_FUNCTIONS[fxn_name](yhat=yhat, y=answer_info, metadata=metadata)
62
+ == 1.0
63
+ )
64
+ except AssertionError:
65
+ failures.append((problem_category, row["id"], metadata))
66
+ assert not failures
67
+
68
+
69
+ TEST_REASONING_TEXT = (
70
+ "Let's analyze the given molecules and try to predict their LD50 values. LD50"
71
+ " refers to the lethal dose at which 50% of the test organisms die. A lower LD50"
72
+ " means higher toxicity, and a higher LD50 indicates lower toxicity. We need to"
73
+ " identify structural features that relate to toxicity.\n\nThe question leaves open"
74
+ " the possibility that none of the compounds have an LD50 of 320 mg/kg. Let's"
75
+ " consider each molecule individually:\n\n1."
76
+ " ClC1=C(C=CC(=C1)Cl)C1(OCC(O1)COC1=CC=C(C=C1)N1CCN(CC1)C(C)=O)CN1C=NC=C1: This"
77
+ " molecule appears to be quite complex. It has a dichloro-substituted aromatic"
78
+ " ring, an ether linkage, a morpholine ring, a piperazine ring, and an imidazole"
79
+ " ring. The presence of two chlorine atoms on the phenyl ring could suggest some"
80
+ " interaction with biological targets. The molecule also has a morpholine and"
81
+ " piperazine moiety which could contribute to binding with receptors or enzymes."
82
+ " The presence of an amide group might indicate some polarity, but the overall"
83
+ " structure looks relatively lipophilic (nonpolar) given the aromatic rings and"
84
+ " alkyl chains.\n\n2."
85
+ " ClC1=C(C=CC(=C1)Cl)[C@]1(OC[C@@H](O1)COC1=CC=C(C=C1)N1CCN(CC1)C1=CC=C(C=C1)N1C(N(N=C1)[C@H](CC)C)=O)CN1N=CN=C1:" # noqa: E501
86
+ " This is a very complex molecule, with multiple rings, stereocenters, and"
87
+ " heteroatoms. It's a distinct structure and appears to be larger than the first"
88
+ " molecule. We can see a furan ring, a pyrazole ring, an amide group, and other"
89
+ " major differences. This change in the rings and other functional groups is likely"
90
+ " to significantly change the molecular properties compared to the first"
91
+ " molecule.\n\n3."
92
+ " [2H]C(C(=O)N1CCN(CC1)C1=CC=C(C=C1)OCC1O[C@@](OC1)(CN1C=NC=C1)C1=C(C=C(C=C1)Cl)Cl)([2H])[2H]:" # noqa: E501
93
+ " This molecule, labeled with deuterium, has multiple rings including a piperazine,"
94
+ " furan, a substituted imidazole, and a dichlorinated phenyl ring. It also includes"
95
+ " an ester group which is sometimes associated with higher toxicity compared to"
96
+ " simple ethers.\n\nThinking about general principles of toxicity, lipophilicity"
97
+ " (fat solubility) is often related to higher toxicity. A molecule with a marked"
98
+ " lipophilic character can often accumulate in fatty tissues and interact with the"
99
+ " cell membrane, affect cellular transport or receptor activity. This could lead to"
100
+ " higher toxicity by interfering with normal cellular function. Similarly, the"
101
+ " presence of chlorine atoms can sometimes contribute to toxicity due to possible"
102
+ " metabolic activation to reactive intermediates. However, the position and nature"
103
+ " of other substituents and functional groups can influence how chlorine"
104
+ " substitutions modulate toxicity. For example, some chlorinated compounds are"
105
+ " relatively non-toxic.\n\nConsidering the size and complexity of the molecules, we"
106
+ " should think about their potential metabolic pathways. Large molecules can be"
107
+ " metabolized through various pathways, potentially leading to reactive"
108
+ " intermediates that interact with biological molecules. Metabolites of these"
109
+ " compounds might be more or less toxic than the initial molecules, and the"
110
+ " metabolic pathways themselves might be quite different. Perhaps one of the"
111
+ " metabolites could be the reason for an LD50 of 320 mg/kg. Alternatively, a"
112
+ " compound might be relatively non-toxic in itself, but its presence can alter"
113
+ " enzyme activity or other metabolic processes and indirectly lead to cell"
114
+ " damage.\n\nComparing the three molecules. Molecules 1 and 2 share some structural"
115
+ " features like the dichloro-substituted aromatic ring and the presence of a"
116
+ " morpholine ring system. However, they also have distinct differences in the"
117
+ " connectivity and presence of additional rings, including likely some more polar"
118
+ " and/or sterically bulky substituents. Molecule 3 has different ring systems and"
119
+ " the addition of both a deuterated methyl group and an ester group which adds"
120
+ " polar character and can often activate adjacent portions of the molecule by"
121
+ " metabolic oxygenation.\n\nLet's think about bioreactivity beyond simple chemical"
122
+ " interactions. Structures can influence how a molecule interacts with biological"
123
+ " receptors or enzymes. The size and shape of these molecules and the nature of the"
124
+ " functional groups can determine the extent of the molecule's binding interactions"
125
+ " with biomolecules. Some conformationally adaptable structures might bind strongly"
126
+ " to targets and interfere with crucial pathways, which can lead to toxicity."
127
+ " Therefore, weaknesses in essential molecular machinery could have similar"
128
+ " negative effects if bound by those biomolecules.\n\nIf one of these molecules has"
129
+ " an LD50 of 320 mg/kg, it suggests moderate toxicity. It could be that one of the"
130
+ " molecules doesn't have the necessary structural features to interact strongly"
131
+ " with critical biological targets for high toxicity, and/or it might be"
132
+ " metabolized to relatively non-toxic products, such as carbon dioxide and water."
133
+ " Thus, while the molecules share some features with other potentially bioactive"
134
+ " molecules, it could be that they themselves are not exceptionally potent."
135
+ )
136
+
137
+ NO_SMILES_TEXT = "This text does not contain any SMILES"
138
+
139
+
140
+ @pytest.mark.parametrize(
141
+ ("text", "expected_answer"),
142
+ [
143
+ (
144
+ TEST_REASONING_TEXT,
145
+ [
146
+ "ClC1=C(C=CC(=C1)Cl)C1(OCC(O1)COC1=CC=C(C=C1)N1CCN(CC1)C(C)=O)CN1C=NC=C1",
147
+ "ClC1=C(C=CC(=C1)Cl)[C@]1(OC[C@@H](O1)COC1=CC=C(C=C1)N1CCN(CC1)C1=CC=C(C=C1)N1C(N(N=C1)[C@H](CC)C)=O)CN1N=CN=C1",
148
+ "[2H]C(C(=O)N1CCN(CC1)C1=CC=C(C=C1)OCC1O[C@@](OC1)(CN1C=NC=C1)C1=C(C=C(C=C1)Cl)Cl)([2H])[2H]",
149
+ ],
150
+ ),
151
+ (
152
+ NO_SMILES_TEXT,
153
+ [],
154
+ ),
155
+ ],
156
+ )
157
+ def test_extract_smiles_from_text(text: str, expected_answer: list[str]) -> None:
158
+ assert sorted(SMILES_PATTERN.findall(text)) == sorted(expected_answer)
tests/test_model_prompts.py ADDED
@@ -0,0 +1,424 @@
1
+ import pytest
2
+
3
+ from ether0.model_prompts import (
4
+ ANSWER_END,
5
+ ANSWER_START,
6
+ THINK_END,
7
+ THINK_START,
8
+ ProblemPrompt,
9
+ extract_answer_loose,
10
+ extract_thought_answer_strict,
11
+ )
12
+
13
+
14
+ def test_problem_prompt() -> None:
15
+ none_prompt = ProblemPrompt.NONE.get_prompt()
16
+ assert isinstance(none_prompt, str)
17
+ assert "think" not in none_prompt
18
+ assert "answer" not in none_prompt
19
+
20
+ answer_prompt = ProblemPrompt.ANSWER.get_prompt()
21
+ assert isinstance(answer_prompt, str)
22
+ assert "think" not in answer_prompt
23
+ assert "answer" in answer_prompt
24
+
25
+ think_answer_prompt = ProblemPrompt.THINK_ANSWER.get_prompt()
26
+ assert isinstance(think_answer_prompt, str)
27
+ assert "think" in think_answer_prompt
28
+ assert "answer" in think_answer_prompt
29
+
30
+
31
+ @pytest.mark.parametrize(
32
+ ("content", "expected"),
33
+ [
34
+ pytest.param("<answer>CCO</answer>", "CCO", id="base"),
35
+ pytest.param("<answer></answer>", "", id="empty-answer"),
36
+ pytest.param("<answer> </answer>", "", id="space-answer"),
37
+ pytest.param("\n<answer>CCO</answer>", "CCO", id="base-extra-whitespace-1"),
38
+ pytest.param("\n<answer>CCO</answer>\n", "CCO", id="base-extra-whitespace-2"),
39
+ pytest.param(" <answer>CCO</answer> ", "CCO", id="base-extra-whitespace-3"),
40
+ pytest.param("word<answer>CCO</answer> ", "CCO", id="base-extra-whitespace-4"),
41
+ pytest.param("<answer>\nCCO\n</answer>", "CCO", id="base-w-newlines"),
42
+ pytest.param(
43
+ "<answer> \nCCO\n</answer>", "CCO", id="base-w-spaces-and-newlines"
44
+ ),
45
+ pytest.param(
46
+ "<answer>\n\nCCO\n\n</answer>", "CCO", id="base-w-double-newlines"
47
+ ),
48
+ pytest.param("<answer> CCO </answer>", "CCO", id="base-w-spaces"),
49
+ pytest.param(
50
+ "<answer> < CCO</answer>", "< CCO", id="base-contains-potential-xml-1"
51
+ ),
52
+ pytest.param(
53
+ "<answer> <stub CCO</answer>",
54
+ "<stub CCO",
55
+ id="base-contains-potential-xml-2",
56
+ ),
57
+ pytest.param("<answer><answer>CCO</answer></answer>", "", id="nested"),
58
+ pytest.param(
59
+ "<answer>\n<answer>\nCCO\n</answer>\n</answer>", "", id="nested-w-newlines"
60
+ ),
61
+ pytest.param(
62
+ "<think>thought</think><answer>CCO</answer>", "CCO", id="thought-base"
63
+ ),
64
+ pytest.param(
65
+ "<think>thought</think><answer>\nCCO\n</answer>",
66
+ "CCO",
67
+ id="thought-base-w-newlines",
68
+ ),
69
+ pytest.param(
70
+ "<answer>CCO</answer><think>thought</think>", "CCO", id="thought-reversed"
71
+ ),
72
+ pytest.param(
73
+ "<answer>\nCCO\n</answer><think>thought</think>",
74
+ "CCO",
75
+ id="thought-reversed-w-newlines",
76
+ ),
77
+ pytest.param(
78
+ "<answer>echoing prompt</answer><answer>CCO2</answer>",
79
+ "CCO2",
80
+ id="multi-answer",
81
+ ),
82
+ pytest.param(
83
+ "<answer>echoing prompt</answer><answer>\nCCO2\n</answer>",
84
+ "CCO2",
85
+ id="multi-answer-w-newlines",
86
+ ),
87
+ ],
88
+ )
89
+ def test_extract_answer_loose(content: str | None, expected: str) -> None:
90
+ assert extract_answer_loose(content) == expected
91
+
92
+
93
+ @pytest.mark.parametrize(
94
+ ("content", "expected_answer", "expected_thought"),
95
+ [
96
+ pytest.param(f"{ANSWER_START}CCO{ANSWER_END}", None, None, id="no-thought"),
97
+ pytest.param(
98
+ f"{ANSWER_START}{ANSWER_END}", None, None, id="no-thought-empty-answer"
99
+ ),
100
+ pytest.param(
101
+ f"{ANSWER_START} {ANSWER_END}", None, None, id="no-thought-space-answer"
102
+ ),
103
+ pytest.param(
104
+ f"{ANSWER_START}\nCCO\n{ANSWER_END}",
105
+ None,
106
+ None,
107
+ id="no-thought-w-newlines",
108
+ ),
109
+ pytest.param(
110
+ f"{ANSWER_START}{ANSWER_START}CCO{ANSWER_END}{ANSWER_END}",
111
+ None,
112
+ None,
113
+ id="no-thought-nested",
114
+ ),
115
+ pytest.param(
116
+ f"{ANSWER_START}\n{ANSWER_START}\nCCO\n{ANSWER_END}\n{ANSWER_END}",
117
+ None,
118
+ None,
119
+ id="no-thought-nested-w-newlines",
120
+ ),
121
+ pytest.param(
122
+ f"{THINK_START}thought{THINK_END}{ANSWER_START}CCO{ANSWER_END}",
123
+ "CCO",
124
+ "thought",
125
+ id="base",
126
+ ),
127
+ pytest.param(
128
+ f"{THINK_START}thought{THINK_END}{ANSWER_START}{ANSWER_END}",
129
+ None,
130
+ "thought",
131
+ id="empty-answer",
132
+ ),
133
+ pytest.param(
134
+ f"{THINK_START}thought{THINK_END}{ANSWER_START} {ANSWER_END}",
135
+ None,
136
+ "thought",
137
+ id="space-answer",
138
+ ),
139
+ pytest.param(
140
+ # Allow models to place up to one whitespace before the thought
141
+ f"\n{THINK_START}thought{THINK_END}{ANSWER_START}CCO{ANSWER_END}",
142
+ "CCO",
143
+ "thought",
144
+ id="base-extra-whitespace-1",
145
+ ),
146
+ pytest.param(
147
+ f"\n{THINK_START}thought{THINK_END}{ANSWER_START}CCO{ANSWER_END}\n",
148
+ None,
149
+ None,
150
+ id="base-extra-whitespace-2",
151
+ ),
152
+ pytest.param(
153
+ f" {THINK_START}thought{THINK_END}{ANSWER_START}CCO{ANSWER_END} ",
154
+ None,
155
+ None,
156
+ id="base-extra-whitespace-3",
157
+ ),
158
+ pytest.param(
159
+ f"word{THINK_START}thought{THINK_END}{ANSWER_START}CCO{ANSWER_END}\n",
160
+ None,
161
+ None,
162
+ id="base-extra-whitespace-4",
163
+ ),
164
+ pytest.param(
165
+ f"{THINK_START}thought{THINK_END}{ANSWER_START}\nCCO\n{ANSWER_END}",
166
+ "CCO",
167
+ "thought",
168
+ id="base-w-newlines",
169
+ ),
170
+ pytest.param(
171
+ f"{THINK_START}thought{THINK_END}{ANSWER_START} \nCCO\n{ANSWER_END}",
172
+ "CCO",
173
+ "thought",
174
+ id="base-w-spaces-and-newlines",
175
+ ),
176
+ pytest.param(
177
+ f"{THINK_START}thought{THINK_END}{ANSWER_START}\n\nCCO\n\n{ANSWER_END}",
178
+ "CCO",
179
+ "thought",
180
+ id="base-w-double-newlines",
181
+ ),
182
+ pytest.param(
183
+ f"{THINK_START}thought{THINK_END}{ANSWER_START} CCO {ANSWER_END}",
184
+ "CCO",
185
+ "thought",
186
+ id="base-w-spaces",
187
+ ),
188
+ pytest.param(
189
+ f"{THINK_START}\nthought\n{THINK_END}{ANSWER_START}\nCCO\n{ANSWER_END}",
190
+ "CCO",
191
+ "thought",
192
+ id="base-w-newlines-both",
193
+ ),
194
+ pytest.param(
195
+ f"{THINK_START}thought\nthought{THINK_END}{ANSWER_START}CCO{ANSWER_END}",
196
+ "CCO",
197
+ "thought\nthought",
198
+ id="base-inner-newline",
199
+ ),
200
+ pytest.param(
201
+ f"{THINK_START}\nthought\nthought\n{THINK_END}{ANSWER_START}CCO{ANSWER_END}",
202
+ "CCO",
203
+ "thought\nthought",
204
+ id="base-inner-newline-w-newlines",
205
+ ),
206
+ pytest.param(
207
+ f"{THINK_START}thought{THINK_END}inter{ANSWER_START}CCO{ANSWER_END}",
208
+ "CCO",
209
+ "thought",
210
+ id="base-inter",
211
+ ),
212
+ pytest.param(
213
+ f"{THINK_START}thought{THINK_END}inter\ninter{ANSWER_START}CCO{ANSWER_END}",
214
+ "CCO",
215
+ "thought",
216
+ id="base-inter-inner-newline",
217
+ ),
218
+ pytest.param(
219
+ f"{THINK_START}thought{THINK_END}\ninter\ninter\n{ANSWER_START}CCO{ANSWER_END}",
220
+ "CCO",
221
+ "thought",
222
+ id="base-inter-inner-newline-w-newlines",
223
+ ),
224
+ pytest.param(
225
+ f"{ANSWER_START}CCO{ANSWER_END}{THINK_START}thought{THINK_END}",
226
+ None,
227
+ None,
228
+ id="base-reversed",
229
+ ),
230
+ pytest.param(
231
+ f"{ANSWER_START}\nCCO\n{ANSWER_END}{THINK_START}thought{THINK_END}",
232
+ None,
233
+ None,
234
+ id="base-reversed-w-newlines",
235
+ ),
236
+ pytest.param(
237
+ f"{THINK_START}thought < thought{THINK_END}{ANSWER_START}CCO{ANSWER_END}",
238
+ "CCO",
239
+ "thought < thought",
240
+ id="thought-contains-potential-xml-1",
241
+ ),
242
+ pytest.param(
243
+ f"{THINK_START}thought <stub"
244
+ f" thought{THINK_END}{ANSWER_START}CCO{ANSWER_END}",
245
+ "CCO",
246
+ "thought <stub thought",
247
+ id="thought-contains-potential-xml-2",
248
+ ),
249
+ pytest.param(
250
+ f"{THINK_START}thought{THINK_END}{ANSWER_START}CCO{ANSWER_END}a",
251
+ None,
252
+ None,
253
+ id="text-after-answer",
254
+ ),
255
+ pytest.param(
256
+ f"{THINK_START}thought{THINK_END}{THINK_START}thought{THINK_END}{ANSWER_START}CCO{ANSWER_END}",
257
+ None,
258
+ None,
259
+ id="multi-thought",
260
+ ),
261
+ pytest.param(
262
+ f"{THINK_START}thought{THINK_END}<thought>thought</thought>{ANSWER_START}CCO{ANSWER_END}",
263
+ "CCO",
264
+ "thought",
265
+ id="interleaved-think-thought",
266
+ ),
267
+ pytest.param(
268
+ f"{THINK_START}thought{THINK_END}{ANSWER_START}CCO{ANSWER_END}Some"
269
+ f" text{THINK_START}thought{THINK_END}{ANSWER_START}CCO{ANSWER_END}",
270
+ None,
271
+ None,
272
+ id="continuation",
273
+ ),
274
+ pytest.param(
275
+ f"{ANSWER_START}echoing prompt{ANSWER_END}{ANSWER_START}CCO2{ANSWER_END}",
276
+ None,
277
+ None,
278
+ id="no-thought-multi-answer",
279
+ ),
280
+ pytest.param(
281
+ f"{ANSWER_START}echoing"
282
+ f" prompt{ANSWER_END}{ANSWER_START}\nCCO2\n{ANSWER_END}",
283
+ None,
284
+ None,
285
+ id="no-thought-multi-answer-w-newlines",
286
+ ),
287
+ pytest.param(
288
+ f"{THINK_START}\nAlright, so I need to figure out the IUPAC name for the"
289
+ " molecule with the formula C1=CC(=CC=C1O)O. Let me start by trying to"
290
+ " visualize the structure. The formula seems a bit complex, so breaking it"
291
+ " down might help.\n\nFirst, I notice there's a ring structure because of"
292
+ " the C1 notation, which suggests a cyclic compound. The presence of"
293
+ " double bonds (the = signs) indicates that it's not just a simple alkane."
294
+ " So, I'm thinking it's a cyclic diene or something similar.\n\nLooking"
295
+ " closer at the formula, I see two oxygen atoms attached to the ring. The"
296
+ " first O is attached to a carbon that's part of a double bond (C1=CC...),"
297
+ " and the second O is attached to another carbon that's also part of a"
298
+ " double bond. So, there are two ketone groups or possibly ester groups?"
299
+ " Wait, no, the formula is C1=CC(=CC=C1O)O, which might imply that each"
300
+ " carbon attached to the ring has an oxygen, but let me try to count the"
301
+ " bonds properly.\n\nWait, perhaps I should draw this out. Let me imagine"
302
+ " the ring. Carbon 1 (C1) is double-bonded to a carbon (C2). Then, C2 is"
303
+ " connected to another carbon (C3) via a double bond, which is then"
304
+ " connected to C4, and so on, until I come back to C1, forming a ring. But"
305
+ " since there are multiple double bonds, it's probably a conjugated diene"
306
+ " or something like that.\n\nWait, maybe it's a cyclopentadiene"
307
+ " derivative. But the presence of two oxygen atoms makes it more likely to"
308
+ " be a dienol ether or something similar. Alternatively, it could be a"
309
+ " cyclohexene derivative with two ketone groups, but I'm not sure.\n\nLet"
310
+ " me count the carbons. The formula is C1=CC(=CC=C1O)O. So, each 'C' is"
311
+ " part of the ring. Let me see: C1 is double-bonded to C2, then C2 is"
312
+ " connected via a single bond to another group, which is C3, which is"
313
+ " double-bonded to C4, and so on until I close the ring back to C1. Hmm,"
314
+ " maybe it's a cyclopentadiene ring with substituents.\n\nWait, perhaps"
315
+ " it's a cyclopentadienyl ether. Let me think about how the substituents"
316
+ " are arranged. If C1 is connected via a double bond to C2, then C2 has"
317
+ " another substituent, which is C3 via a double bond, and so on. Maybe the"
318
+ " structure is such that there are two adjacent double bonds, making it a"
319
+ " conjugated diene.\n\nBut I'm getting a bit confused. Maybe I should"
320
+ " approach this differently. Let's try to write out the structure step by"
321
+ " step.\n\nStarting with C1, which is double-bonded to C2: C1=C2. Then, C2"
322
+ " is connected to another carbon, which is part of a double bond to C3:"
323
+ " C2=C3. So, now we have C1=C2-C3=... Then, C3 is connected to C4, which"
324
+ " is connected back to C1, forming a ring. So, the ring would be a"
325
+ " five-membered ring with two double bonds.\n\nWait, but a five-membered"
326
+ " ring with two double bonds would be conjugated diene in a ring. So,"
327
+ " that's cyclopentadiene. Now, each of the carbons in the double bonds has"
328
+ " an oxygen substituent. So, C1 has an O attached, and C3 also has an O"
329
+ " attached. So, it's cyclopentadiene-1,3-dione? But wait, the formula is"
330
+ " C1=CC(=CC=C1O)O, which suggests that each double bond carbon has an O"
331
+ " attached. So, both C1 and C3 have O substituents.\n\nBut"
332
+ " cyclopentadiene-1,3-dione would have two ketone groups, which would make"
333
+ " it a diketone. Alternatively, it could be a dienol ether, but with two"
334
+ " oxygen atoms attached to the ring.\n\nWait, another possibility is that"
335
+ " it's a cyclopentadienyl ether with two substituents. Alternatively,"
336
+ " perhaps it's a substituted cyclopentadienone.\n\nWait, maybe I should"
337
+ " count the number of atoms. The formula is C1=CC(=CC=C1O)O. Let's parse"
338
+ " this:\n\n- C1 is connected via double bond to C2.\n- C2 is connected to"
339
+ " another carbon (let's say C3) which is double-bonded to C4.\n- C4 is"
340
+ " connected back to C1, forming a ring.\n- Each of C1 and C3 has an oxygen"
341
+ " attached.\n\nWait, that would make a five-membered ring with two double"
342
+ " bonds and two oxygen atoms. So, perhaps it's cyclopentadienone-1,3-dioic"
343
+ " acid? No, that doesn't sound right because the formula doesn't indicate"
344
+ " acid groups.\n\nAlternatively, maybe it's a dienol ether, where two of"
345
+ " the double bond carbons have ether oxygen substituents.\n\nWait, perhaps"
346
+ " the correct name is something like 1,3-dihydroxycyclopentadiene or"
347
+ " similar. Alternatively, since the oxygens are on the double bond"
348
+ " carbons, maybe it's a dienol ether.\n\nWait, I'm getting stuck. Maybe I"
349
+ " should think about the structure again.\n\nThe formula is"
350
+ " C1=CC(=CC=C1O)O. So, let's parse the SMILES notation:\n\n-"
351
+ " C1=CC(=CC=C1O)O\nBreaking it down:\n- C1 is the first carbon in a"
352
+ " ring.\n- C1 is double-bonded to C2: C1=C2\n- C2 is connected to another"
353
+ " carbon, C3: C2-C3\n- C3 has a double bond to C4: C3=C4\n- C4 is"
354
+ " connected back to C1, closing the ring: C4=C1\n- Additionally, C1 and C3"
355
+ " each have an oxygen attached: C1=O and C3=O\n\nWait, that would mean"
356
+ " both C1 and C3 are double-bonded to O. So, it's a cyclopentadiene ring"
357
+ " with two ketone groups. So, the IUPAC name would be something like"
358
+ " cyclopentadien-1,3-dione.\n\nWait, but isn't cyclopentadien-1,3-dione a"
359
+ " thing? Let me check. Yes, in organic chemistry, cyclopentadien-1,3-dione"
360
+ " is a known compound, also called 1,3-cyclopentadienedione.\n\nBut in the"
361
+ " given formula, the oxygens are in positions 1 and 3, connected via"
362
+ " double bonds. So, the IUPAC name would be"
363
+ " cyclopentadien-1,3-dione.\n\nWait, but sometimes the numbering might be"
364
+ " different. Let me think about the numbering in the ring. Starting at C1,"
365
+ " moving to C2, C3, C4, C5, and back to C1.\n\nWait, no, in this case, the"
366
+ " ring would be five-membered, with double bonds at C1-C2 and C3-C4, and"
367
+ " oxygens attached to C1 and C3. So, the correct IUPAC name would be"
368
+ " cyclopentadien-1,3-dione.\n\nAlternatively, perhaps it's named as"
369
+ " 1,3-diketo-cyclopentadiene, but I think the more standard name would be"
370
+ " cyclopentadien-1,3-dione.\n\nWait, but let me make sure. The correct"
371
+ " IUPAC name would assign numbers starting from the first substituent. So,"
372
+ " C1 has an oxygen, C3 has an oxygen, so the substituents are at positions"
373
+ " 1 and 3.\n\nTherefore, the IUPAC name would be"
374
+ " cyclopentadien-1,3-dione.\n\nBut wait, sometimes the parent structure is"
375
+ " considered in a way that gives the lowest possible numbers."
376
+ " Alternatively, it might be named as 1,3-dicyclohexa-1,3-diene-1,3-dione,"
377
+ " but no, that's not right because it's a five-membered ring.\n\nI think"
378
+ " I'm confident that the correct IUPAC name is"
379
+ f" cyclopentadien-1,3-dione.\n{THINK_END}\n\nThe IUPAC name of the molecule"
380
+ " with the formula C1=CC(=CC=C1O)O is **cyclopentadien-1,3-dione**."
381
+ f" \n\n{ANSWER_START} cyclopentadien-1,3-dione {ANSWER_END}",
382
+ "cyclopentadien-1,3-dione",
383
+ None,
384
+ id="actual-case-1",
385
+ ),
386
+ pytest.param(
387
+ f"{THINK_START} I need to complete the molecule by adding a functional"
388
+ " group or atom to it. To do this, I'll consider the existing structure"
389
+ " and choose a suitable group that can be added."
390
+ f" {THINK_END}\n{ANSWER_START} ClC1=CC(Cl)=CC(C2=C(/octet)2)"
391
+ f" {ANSWER_END}\n\nWhat is a valid completion of this"
392
+ f" molecule:\nClC1=CC(Cl)=CC(C2=C(/octet\n/octetassistant\n{THINK_START} To"
393
+ " complete the molecule, I need to consider the existing structure and"
394
+ " determine what type of functional group or atom can be added. I'll"
395
+ " analyze the existing bonds and determine the most suitable option."
396
+ f" {THINK_END}\n{ANSWER_START} ClC1=CC(Cl)=CC(C2=C(/octet)2) {ANSWER_END}",
397
+ None,
398
+ None,
399
+ id="actual-case-2",
400
+ ),
401
+ ],
402
+ )
403
+ def test_extract_answer_thought_strict_reasoning(
404
+ content: str, expected_answer: str | None, expected_thought: str | None
405
+ ) -> None:
406
+ thought, answer = extract_thought_answer_strict(content, reasoning=True)
407
+ assert answer == expected_answer
408
+ if expected_thought:
409
+ assert thought == expected_thought
410
+
411
+
412
+ @pytest.mark.parametrize(
413
+ ("content", "expected_answer"),
414
+ [(
415
+ "<|answer_start|>Clc1ccc(cc1)OCCOC(=O)COC(=O)CCNC(=O)COc2ccccc2<|answer_end|>",
416
+ "Clc1ccc(cc1)OCCOC(=O)COC(=O)CCNC(=O)COc2ccccc2",
417
+ )],
418
+ )
419
+ def test_extract_answer_thought_strict_no_reasoning(
420
+ content: str, expected_answer: str | None
421
+ ) -> None:
422
+ thought, answer = extract_thought_answer_strict(content, reasoning=False)
423
+ assert answer == expected_answer
424
+ assert thought is None
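
A minimal usage sketch (not part of the commit), using only behaviors the cases above assert: build the reasoning prompt, then parse a well-formed completion both strictly and loosely.

from ether0.model_prompts import (
    ANSWER_END,
    ANSWER_START,
    THINK_END,
    THINK_START,
    ProblemPrompt,
    extract_answer_loose,
    extract_thought_answer_strict,
)

system_prompt = ProblemPrompt.THINK_ANSWER.get_prompt()  # mentions "think" and "answer"
completion = f"{THINK_START}some reasoning{THINK_END}{ANSWER_START}CCO{ANSWER_END}"

# Strict parsing returns (thought, answer), or (None, None) when the format is off.
thought, answer = extract_thought_answer_strict(completion, reasoning=True)
assert (thought, answer) == ("some reasoning", "CCO")

# Loose parsing is a best-effort answer grab that tolerates stray surrounding text.
assert extract_answer_loose(completion) == "CCO"
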
tests/test_models.py ADDED
@@ -0,0 +1,81 @@
1
+ import pytest
2
+ from datasets import Dataset
3
+
4
+ from ether0.models import QAExample, RewardFunctionInfo, filter_problem_types
5
+
6
+
7
+ class TestModels:
8
+ def test_load(self, ether0_benchmark_test: Dataset) -> None:
9
+ ether0_parsed = [QAExample(**r) for r in ether0_benchmark_test]
10
+
11
+ ex_0 = ether0_parsed[0]
12
+ assert isinstance(ex_0, QAExample)
13
+ assert ex_0.id == "00c8bc2d-0bb3-53c2-8bdf-cd19616d4536"
14
+ assert (
15
+ ex_0.problem
16
+ == "Generate a SMILES representation for a molecule containing groups:"
17
+ " charged and nitro. It should also have formula C13H12N6O5."
18
+ )
19
+ assert ex_0.problem_type == "functional-group"
20
+ assert ex_0.ideal == "Cc1ncc([N+](=O)[O-])n1CC(=O)N/N=C/c1ccc([N+](=O)[O-])cc1"
21
+ assert ex_0.unformatted == "C13H12N6O5,['charged', 'nitro']"
22
+ assert isinstance(ex_0.solution, RewardFunctionInfo)
23
+ ex0_sol = ex_0.solution
24
+ assert (
25
+ (ex0_sol.fxn_name, ex0_sol.answer_info, ex0_sol.problem_type)
26
+ == tuple(ex0_sol.model_dump().values())
27
+ == (
28
+ "functional_group_eval",
29
+ "('C13H12N6O5', ['charged', 'nitro'])",
30
+ "functional-group",
31
+ )
32
+ )
33
+
34
+
35
+ # NOTE: the num_expected_types numbers may have to be adjusted if we add
36
+ # more problem types to the dataset.
37
+ @pytest.mark.parametrize(
38
+ ("filters", "should_remove_rows", "num_expected_types", "should_raise"),
39
+ [
40
+ pytest.param([], False, 70, False, id="no-filter-1"),
41
+ pytest.param(None, False, 70, False, id="no-filter-2"),
42
+ pytest.param(["reaction-prediction"], True, 1, False, id="include-1"),
43
+ pytest.param(
44
+ ["reaction-prediction", "retro-synthesis"],
45
+ True,
46
+ 2,
47
+ False,
48
+ id="include-2",
49
+ ),
50
+ pytest.param(["!reaction-prediction"], True, 69, False, id="exclude-1"),
51
+ pytest.param(
52
+ ["!reaction-prediction", "molecule-name"],
53
+ # Note that in this case, should_remove_rows and num_expected_types are just
54
+ # dummy values, since filtering should fail before they are checked.
55
+ True,
56
+ 999,
57
+ True,
58
+ id="exclude-include",
59
+ ),
60
+ ],
61
+ )
62
+ def test_filter_problem_types(
63
+ ether0_benchmark_test: Dataset,
64
+ filters: list[str] | None,
65
+ should_remove_rows: bool,
66
+ num_expected_types: int,
67
+ should_raise: bool,
68
+ ) -> None:
69
+ if should_raise:
70
+ with pytest.raises(
71
+ ValueError,
72
+ match="Cannot specify both problem types to keep and to exclude",
73
+ ):
74
+ filter_problem_types(ether0_benchmark_test, filters)
75
+ return
76
+
77
+ filtered = filter_problem_types(ether0_benchmark_test, filters)
78
+ problem_types = set(filtered["problem_type"])
79
+
80
+ assert len(problem_types) == num_expected_types
81
+ assert (len(filtered) < len(ether0_benchmark_test)) == should_remove_rows
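
A minimal usage sketch (not part of the commit) of the filter under test; the toy Dataset below assumes only the problem_type column matters, which is all these cases rely on.

from datasets import Dataset

from ether0.models import filter_problem_types

toy = Dataset.from_dict(
    {"problem_type": ["reaction-prediction", "retro-synthesis", "molecule-name"]}
)

# Plain names are kept; a leading "!" excludes instead. Mixing both raises ValueError.
retro_only = filter_problem_types(toy, ["retro-synthesis"])
without_rxn = filter_problem_types(toy, ["!reaction-prediction"])
print(set(retro_only["problem_type"]), set(without_rxn["problem_type"]))
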
tests/test_rewards.py ADDED
@@ -0,0 +1,489 @@
1
+ from unittest.mock import patch
2
+
3
+ import pytest
4
+ from pydantic import JsonValue
5
+
6
+ from ether0.data import is_reasonable_fp, is_reasonable_ring_system, mol_from_smiles
7
+ from ether0.models import RewardReason
8
+ from ether0.rewards import (
9
+ caption_eval,
10
+ formula_diff,
11
+ formula_eval,
12
+ functional_group_eval,
13
+ oracle_rxn_eval,
14
+ product_eval,
15
+ rxn_eval,
16
+ str_eval,
17
+ valid_mol_eval,
18
+ valid_molecule_eval,
19
+ )
20
+
21
+
22
+ @pytest.mark.parametrize(
23
+ ("yhat", "y", "expected"),
24
+ [
25
+ pytest.param(
26
+ "methyl 2-(ethylcarbamoyl)-1,3-dioxo-2,3-dihydro-1H-pyrrolo[3,4-c]pyridine-5-carboxylate", # noqa: E501
27
+ "methyl 2-(ethylcarbamoyl)-1,3-dioxo-2,3-dihydro-1(H)-pyrrolo[3,4-c]pyridine-5-carboxylate", # noqa: E501
28
+ 1.0,
29
+ id="parentheses",
30
+ ),
31
+ pytest.param(
32
+ "methyl 2-(ethylcarbamoyl)-1,3-dioxo-2,3-dihydro-1H-pyrrolo[3,4-c]pyridine-5-carboxylate", # noqa: E501
33
+ "methyl 2-(ethylcarbamoyl)-1,3-dioxo-2,3-dihydro-1{H}-pyrrolo[3,4-c]pyridine-5-carboxylate", # noqa: E501
34
+ 1.0,
35
+ id="culies parentheses",
36
+ ),
37
+ pytest.param(
38
+ "methyl 2-(ethylcarbamoyl)-1,3-dioxo-2,3-dihydro-1H-pyrrolo[3,4-c]pyridine-5-carboxylate", # noqa: E501
39
+ "methyl 2-(ethylcarbamoyl)-1,3-dioxo-2,3-dihydro-1H-pyrrolo[3,4-c]pyridine-5-carboxylate", # noqa: E501
40
+ 1.0,
41
+ id="same",
42
+ ),
43
+ pytest.param(
44
+ " methyl 2-(ethylcarbamoyl)-1,3-dioxo-2,3-dihydro-1H-pyrrolo[3,4-c]pyridine-5-carboxylate", # noqa: E501
45
+ "methyl 2-(ethylcarbamoyl)-1,3-dioxo-2,3-dihydro-1H-pyrrolo[3,4-c]pyridine-5-carboxylate ", # noqa: E501
46
+ 1.0,
47
+ id="spacing",
48
+ ),
49
+ pytest.param(
50
+ "methyl 3-(ethylcarbamoyl)-1,3-dioxo-2,3-dihydro-1H-pyrrolo[3,4-c]pyridine-5-carboxylate", # noqa: E501
51
+ "methyl 2-(ethylcarbamoyl)-1,3-dioxo-2,3-dihydro-1H-pyrrolo[3,4-c]pyridine-5-carboxylate", # noqa: E501
52
+ 0.0,
53
+ id="different",
54
+ ),
55
+ pytest.param(
56
+ "(5S,8R,9S,10R,13S,14R,17S)-17-[(1R,2S,3R,4S,7R,9S,10S,12R,15S)-3-(benzoylamino)-2-hydroxy-3-phenylpropanoyl]oxy-5,9-dihydroxy-4,10,13-trimethyl-11-oxo-6-oxatetracyclo[11.3.1.0^{3,10}.0^{4,7}]heptadec-14-en-8-yl (2R,3S)-3-benzamido-2-hydroxy-3-phenylpropanoate", # noqa: E501
57
+ " (5~S~,8~R~,9~S~,10R,13S,14R,17S)-17-[(1R,2S,3R,4S,7R,9S,10S,12R,15S)-3-(benzoylamino)-2-hydroxy-3-phenylpropanoyl]oxy-5,9-dihydroxy-4,10,13-trimethyl-11-oxo-6-oxatetracyclo[11.3.1.0^{3,10}.0^{4,7}]heptadec-14-en-8-yl (2R,3S)-3-benzamido-2-hydroxy-3-phenylpropanoate", # noqa: E501
58
+ 1.0,
59
+ id="italics",
60
+ ),
61
+ pytest.param(
62
+ "(5S,8R,9S,10R,13S,14R,17S)-17-[(1R,2S,3R,4S,7R,9S,10S,12R,15S)-3-(benzoylamino)-2-hydroxy-3-phenylpropanoyl]oxy-5,9-dihydroxy-4,10,13-trimethyl-11-oxo-6-oxatetracyclo[11.3.1.0^{3,10}.0(4,7)]heptadec-14-en-8-yl (2R,3S)-3-benzamido-2-hydroxy-3-phenylpropanoate", # noqa: E501
63
+ " (5~S~,8~R~,9~S~,10R,13S,14R,17S)-17-[(1R,2S,3R,4S,7R,9S,10S,12R,15S)-3-(benzoylamino)-2-hydroxy-3-phenylpropanoyl]oxy-5,9-dihydroxy-4,10,13-trimethyl-11-oxo-6-oxatetracyclo[11.3.1.0^(3,10).0^{4,7}]heptadec-14-en-8-yl (2R,3S)-3-benzamido-2-hydroxy-3-phenylpropanoate", # noqa: E501
64
+ 1.0,
65
+ id="curlies and carrots",
66
+ ),
67
+ pytest.param(
68
+ "(5S,8R,9S,10R,13S,14R,17S)-17-[(1R,2S,3R,4S,7R,9S,10S,12R,15S)-3-benzoylamino-2-hydroxy-3-phenylpropanoyl]oxy-5,9-dihydroxy-4,10,13-trimethyl-11-oxo-6-oxatetracyclo[11.3.1.0^{3,10}.0(4,7)]heptadec-14-en-8-yl (2R,3S)-3-benzamido-2-hydroxy-3-phenylpropanoate", # noqa: E501
69
+ " (5~S~,8~R~,9~S~,10R,13S,14R,17S)-17-[(1R,2S,3R,4S,7R,9S,10S,12R,15S)-3-(benzoylamino)-2-hydroxy-3-phenylpropanoyl]oxy-5,9-dihydroxy-4,10,13-trimethyl-11-oxo-6-oxatetracyclo[11.3.1.0^(3,10).0^{4,7}]heptadec-14-en-8-yl (2R,3S)-3-benzamido-2-hydroxy-3-phenylpropanoate", # noqa: E501
70
+ 1.0,
71
+ id="more parentheses",
72
+ ),
73
+ pytest.param(
74
+ "(5S,8R,9S,10R,13S,14R,17S)-17-[(1R,2S,3R,4S,7R,9S,10S,12R,15S)-3-(benzoylamino)-2-hydroxy-3-phenylpropanoyl]oxy-5,9-dihydroxy-4,10,13-trimethyl-11-oxo-6-oxatetracyclo[11.3.1.0^{3,10}.0(4,7)]heptadec-14-en-8-yl (2R,3S)-3-benzamido-2-hydroxy-3-phenylpropanoate", # noqa: E501
75
+ " (5~S~,8~R~,9~S~,10R,13S,14R,17S)-17-(1R,2S,3R,4S,7R,9S,10S,12R,15S)-3-(benzoylamino)-2-hydroxy-3-phenylpropanoyloxy-5,9-dihydroxy-4,10,13-trimethyl-11-oxo-6-oxatetracyclo[11.3.1.0^(3,10).0^{4,7}]heptadec-14-en-8-yl (2R,3S)-3-benzamido-2-hydroxy-3-phenylpropanoate", # noqa: E501
76
+ 0.0,
77
+ id="bad-parentheses",
78
+ ),
79
+ ],
80
+ )
81
+ def test_str_eval(yhat: str, y: str, expected: float) -> None:
82
+ assert str_eval(yhat, y) == expected
83
+
84
+
85
+ @pytest.mark.parametrize(
86
+ ("yhat", "y", "expected"),
87
+ [
88
+ pytest.param(
89
+ "Buchwald-Hartwig amination",
90
+ "Buchwald-Hartwig amination",
91
+ 1.0,
92
+ id="same rxn",
93
+ ),
94
+ pytest.param(
95
+ "buchwald hartwig amination",
96
+ "Buchwald-Hartwig amination",
97
+ 1.0,
98
+ id="caps/hyphens",
99
+ ),
100
+ pytest.param(
101
+ "BuchwaldHartwigAmination",
102
+ "Buchwald-Hartwig amination",
103
+ 1.0,
104
+ id="no spaces",
105
+ ),
106
+ pytest.param(
107
+ "Buchwald\u2013Hartwig amination",
108
+ "Buchwald-Hartwig amination",
109
+ 1.0,
110
+ id="en dash",
111
+ ),
112
+ pytest.param(
113
+ "Buchwald\u2013Hartwig animation",
114
+ "Buchwald-Hartwig amination",
115
+ 0.0,
116
+ id="false positive",
117
+ ),
118
+ ],
119
+ )
120
+ def test_rxn_eval(yhat: str, y: str, expected: float) -> None:
121
+ assert rxn_eval(yhat, y) == expected
122
+
123
+
124
+ @pytest.mark.parametrize(
125
+ ("yhat", "y", "expected"),
126
+ [
127
+ pytest.param(
128
+ "O=C(OC1C(OC(=O)C=2C=CC=CC2)C3(O)C(C)(C)CCCC3(C)C4CC=5OC=CC5C(C)C14)C=6C=CC=CC6",
129
+ "O=C(OC1C(OC(=O)C=2C=CC=CC2)C3(O)C(C)(C)CCCC3(C)C4CC=5OC=CC5C(C)C14",
130
+ 1.0,
131
+ id="full-answer",
132
+ ),
133
+ pytest.param(
134
+ ")C=6C=CC=CC6",
135
+ "O=C(OC1C(OC(=O)C=2C=CC=CC2)C3(O)C(C)(C)CCCC3(C)C4CC=5OC=CC5C(C)C14",
136
+ 1.0,
137
+ id="partial-answer",
138
+ ),
139
+ pytest.param(
140
+ "",
141
+ "O=C(OC1C(OC(=O)C=2C=CC=CC2)C3(O)C(C)(C)CCCC3(C)C4CC=5OC=CC5C(C)C14",
142
+ 0.0,
143
+ id="empty-generation",
144
+ ),
145
+ pytest.param(
146
+ "CCC",
147
+ "O=C(OC1C(OC(=O)C=2C=CC=CC2)C3(O)C(C)(C)CCCC3(C)C4CC=5OC=CC5C(C)C14",
148
+ 0.0,
149
+ id="wrong-valid-SMILES",
150
+ ),
151
+ pytest.param(
152
+ "applesauce",
153
+ "O=C(OC1C(OC(=O)C=2C=CC=CC2)C3(O)C(C)(C)CCCC3(C)C4CC=5OC=CC5C(C)C14",
154
+ 0.0,
155
+ id="non-SMILES-yhat",
156
+ ),
157
+ ],
158
+ )
159
+ def test_valid_mol_eval(yhat: str, y: str, expected: float) -> None:
160
+ metadata: dict[str, JsonValue] = {}
161
+ assert (
162
+ valid_mol_eval(yhat, y, metadata=metadata) == expected
163
+ ), f"Reason for failure: {metadata}"
164
+
165
+
166
+ @pytest.mark.parametrize(
167
+ ("yhat", "y", "expected_reward", "expected_reason"),
168
+ [
169
+ pytest.param(
170
+ "CCCO",
171
+ "CCCO",
172
+ 1.0,
173
+ None,
174
+ id="exact-match",
175
+ ),
176
+ pytest.param(
177
+ "CCCO",
178
+ "C#N",
179
+ 0.0,
180
+ RewardReason.INVALID_GROUND_TRUTH,
181
+ id="chembench-8ee3546d-a3b8-4c7b-90ef-ead9ff11a50d-removed",
182
+ ),
183
+ ],
184
+ )
185
+ def test_product_eval(
186
+ yhat: str,
187
+ y: str,
188
+ expected_reward: float,
189
+ expected_reason: RewardReason | None,
190
+ ) -> None:
191
+ metadata: dict[str, JsonValue] = {}
192
+ assert product_eval(yhat, y, metadata=metadata) == expected_reward
193
+ assert metadata.get("reward_reason") == expected_reason
194
+ # Also test caption_eval here, since it behaves the same as product_eval
195
+ assert caption_eval(yhat, y, metadata=metadata) == expected_reward
196
+
197
+
198
+ @pytest.mark.parametrize(
199
+ ("yhat", "y", "expected"),
200
+ [
201
+ pytest.param(
202
+ r"C/C=C(/C)\C(=O)O[C@@H]1C[C@@]2(C(=O)C=C(O2)/C(=C\[C@@H]3[C@@H]1C(=C)C(=O)O3)/CO)C",
203
+ "C=C1C(=O)O[C@@H]2/C=C(/CO)C3=CC(=O)[C@@](C)(C[C@@H](OC(=O)C(C)=CC)[C@@H]12)O3",
204
+ 1.0,
205
+ id="match",
206
+ ),
207
+ pytest.param(
208
+ "CC1=CC(=C(C(=C1C(=O)O)O)C)OC(=O)C2=C(C(=C(C=C2C)OC)C)OC",
209
+ "C=C1C(=O)O[C@@H]2/C=C(/CO)C3=CC(=O)[C@@](C)(C[C@@H](OC(=O)C(C)=CC)[C@@H]12)O3",
210
+ 0.05,
211
+ id="formula-match",
212
+ ),
213
+ pytest.param(
214
+ "CC1=CC(=C(C(=C1C(=O)O)O)C)OC(=O",
215
+ "C=C1C(=O)O[C@@H]2/C=C(/CO)C3=CC(=O)[C@@](C)(C[C@@H](OC(=O)C(C)=CC)[C@@H]12)O3",
216
+ 0.0,
217
+ id="bad-mol",
218
+ ),
219
+ pytest.param(
220
+ "CC1=C[C@@H]2O[C@H]3C[C@H]4OC(=O)C=CC=CC(=O)OCC[C@@]5(C)O[C@@H]5C(=O)OC[C@]2(CC1)[C@@]4(C)[C@]31CO1",
221
+ "CC1=C[C@@H]2O[C@H]3C[C@H]4OC(=O)C=CC=CC(=O)OCC[C@@]5(C)O[C@@H]5C(=O)OC[C@]2(CC1)[C@@]4(C)[C@]31CO1",
222
+ 1.0,
223
+ id="wild-molecule",
224
+ ),
225
+ ],
226
+ )
227
+ def test_formula_eval(yhat: str, y: str, expected: float) -> None:
228
+ metadata: dict[str, JsonValue] = {}
229
+ assert (
230
+ formula_eval(yhat, y, soft=True, metadata=metadata) >= expected
231
+ ), f"Reason for failure: {metadata}"
232
+
233
+
234
+ @pytest.mark.parametrize(
235
+ ("yhat", "y", "expected"),
236
+ [
237
+ pytest.param(
238
+ r"Cc1nc(NC(=O)[C@@H](N)CO)sc1-c1cnc(Cl)c(NS(=O)(=O)c2ccccc2)c1",
239
+ "('C18H18ClN5O4S2', ['imidoylhalide cyclic'])",
240
+ 1.0,
241
+ id="match",
242
+ ),
243
+ pytest.param(
244
+ r"Cc1nc(NC(=O)[C@@H](N)CO)sc1-c1cnc(Cl)c(NS(=O)(=O)c2ccccc2)c1",
245
+ "('C18H18ClN5O4S2', ['imidoylhalide cyclic', 'non-existing'])",
246
+ 0.0,
247
+ id="bad groups",
248
+ ),
249
+ pytest.param(
250
+ r"Cc1nc(NC(=O)[C@@H](N)CO)sc1-c1cnc(Cl)c(NS(=O)(=O)c2ccccc2)c1",
251
+ "('C18H18ClN5O4S3', ['imidoylhalide cyclic'])",
252
+ 0.0,
253
+ id="bad formula",
254
+ ),
255
+ pytest.param(
256
+ r"CC[C@H]1OC(=O)[C@H](C)[C@@H](O[C@H]2C[C@@](C)(OC)[C@@](O)(c3ccccc3)[C@H](C)O2)[C@H](C)[C@@H](O[C@@H]2O[C@H](C)C[C@H](N(C)C)[C@H]2O)[C@](C)(O)C[C@@H](C)CN[C@H](C)[C@@H](O)[C@]1(C)O",
257
+ "('C43H74N2O12', ['1,2-Aminoalcohol', 'hydroxylated heteroatom substituted glycosidic ring', 'tertiary alcohol'])", # noqa: E501
258
+ 1.0,
259
+ id="renamed-groups",
260
+ ),
261
+ pytest.param(r"CCC", "('C3H8', [])", 1.0, id="no-groups"),
262
+ pytest.param(r"CCCNNNNN", "('C3H13N5', [])", 0.0, id="unreasonable-molecule"),
263
+ pytest.param(r"C1CCCCC2C1CCCCCCCCC2", "('C16H30', [])", 0.0, id="bad-ring"),
264
+ pytest.param(
265
+ "CCCCCBr", "('C5H11Br',['alkylbromide'])", 1.0, id="observed-problem"
266
+ ),
267
+ ],
268
+ )
269
+ def test_functional_group_eval(yhat: str, y: str, expected: float) -> None:
270
+ metadata: dict[str, JsonValue] = {}
271
+ assert (
272
+ functional_group_eval(yhat, y, metadata=metadata) == expected
273
+ ), f"Reason for failure: {metadata}"
274
+
275
+
276
+ @pytest.mark.parametrize(
277
+ ("yhat", "y", "expected"),
278
+ [
279
+ pytest.param(
280
+ "CCC=O.CC1(C)CC(N)C(=O)N1>[B-](OC(=O)C)(OC(=O)C)OC(=O)C.[Na+].C=O>",
281
+ "CCCN(C)C1CC(C)(C)NC1=O",
282
+ 1.0,
283
+ id="match",
284
+ ),
285
+ pytest.param(
286
+ "CCC=O.CC1(C)CC(N)C(=O)N1>[B-](OC(=O)C)(OC(=O)C)OC(=O)C.[Na+].C=O>CCCN(C)C1CC(C)(C)NC1=O",
287
+ "CCCN(C)C1CC(C)(C)NC1=O",
288
+ 1.0,
289
+ id="match-w-product",
290
+ ),
291
+ pytest.param(
292
+ "CCC=O.CC1(C)CC(N)C(=O)N1>[B-](OC(=O)C)(OC(=O)C)OC(=O)C.[Na+].C=O>CCCCN(C=O)C1CC(C)(C)N(C(=O)C)O1",
293
+ "CCCN(C)C1CC(C)(C)NC1=O",
294
+ 0.0,
295
+ id="match-w-non-matching-product",
296
+ ),
297
+ pytest.param(
298
+ "CCC=O.CC1(C)CC(N)C(=O)N1>[B-](OC(=O)C)(OC(=O)C)OC(=O)C.[Na+].C=O>CCCXeN(C=O)C1CC(C)(C)N(C(=O)C)O1",
299
+ "CCCN(C)C1CC(C)(C)NC1=O",
300
+ 0.0,
301
+ id="match-w-invalid-product",
302
+ ),
303
+ pytest.param(
304
+ "CCC=O.CC1(C)CC(N)C(=O)N1>[B-](OC(=O)C)(OC(=O)C)OC(=O)C.[Na+].C=O",
305
+ "CCCN(C)C1CC(C)(C)NC1=O",
306
+ 0.0,
307
+ id="match-wo-trailing",
308
+ ),
309
+ pytest.param(
310
+ "CCC=O.CC1(C)CC(N)C(=O)N1>[B-](OC(=O)C)(OC(=O)C)OC(=O)C.[Na+].C=O>>>>",
311
+ "CCCN(C)C1CC(C)(C)NC1=O",
312
+ 0.0,
313
+ id="no-match-w-many-trailing",
314
+ ),
315
+ pytest.param(
316
+ "CCC=O.CC1(C)CC(N)C(=O)N1",
317
+ "CCCN(C=O)C1CC(C)(C)N(C(=O)C)O1",
318
+ 0.0,
319
+ id="invalid",
320
+ ),
321
+ pytest.param(
322
+ "C(P)(P)(P)CC=O.CC1(C)(C)CC(N)C(=O)N1>[B-](OC(=O)C)(OC(=O)C)OC(=O)C.[Na+].C=O>",
323
+ "CCCN(C=O)C1CC(C)(C)N(C(=O)C)O1",
324
+ 0.0,
325
+ id="no-purchase",
326
+ ),
327
+ pytest.param(
328
+ "OB(O)c1cc(C2CC2)cnc1Cl.Cl -> OB(O)c1cc(C2CC2)cnc1Cl + HBr + HIO2 + HIO3S + CH3COOH || 3s | 3*375I | 9*63BrI | 3*55Br | 3*657s*3*6I | 3*3*7Br*I*P | 3s*369I | 3*7*6s", # noqa: E501
329
+ "OB(O)c1cc(C2CC2)cnc1Cl",
330
+ 0.0,
331
+ id="insane-reward-hacking",
332
+ ),
333
+ pytest.param(
334
+ "CNCCC1CC1(F)F>CC#CC>",
335
+ "CNCCC1CC1(F)F",
336
+ 0.0,
337
+ id="trivial-reactants",
338
+ ),
339
+ pytest.param(
340
+ "CC(C)CN1CC(O)C1.CC(C)CN1CC(O)CBr.CCO>CC#CC>",
341
+ "CC(C)CN1CC(O)C1",
342
+ 0.0,
343
+ id="disallow-product-in-reactants",
344
+ ),
345
+ pytest.param(
346
+ "N#N.CCO>CC#CC.CC(C)CN1CC(O)C1>",
347
+ "CC(C)CN1CC(O)C1",
348
+ 0.0,
349
+ id="disallow-product-in-reagents",
350
+ ),
351
+ pytest.param(
352
+ "C1(CN(C1)CC(C)C)O.CC(C)CN1CC(O)CBr.CCO>CC#CC>",
353
+ "CC(C)CN1CC(O)C1",
354
+ 0.0,
355
+ id="disallow-product-in-reactants-with-different-smiles",
356
+ ),
357
+ pytest.param(
358
+ "C=CCNC(=O)Br.BrC#Cc1ccccc1.CCO>[Mg].c1ccccc1>",
359
+ "C=CCNC(=O)C#Cc1ccccc1",
360
+ 0.0,
361
+ id="hacked-purchasability",
362
+ ),
363
+ pytest.param(
364
+ "CCC=O.CC1(C)CC(N)C(=O)N1>[B-](OC(=O)C)(OC(=O)C)OC(=O)C.[Na+].C=O.[THF]>CCCN(C=O)C1CC(C)(C)N(C(=O)C)O1",
365
+ "CCCN(C=O)C1CC(C)(C)N(C(=O)C)O1",
366
+ 0.0,
367
+ id="invalid-reagent",
368
+ ),
369
+ ],
370
+ )
371
+ def test_oracle_rxn_eval(yhat: str, y: str, expected: float) -> None:
372
+ # Create a mock dictionary of purchasable molecules.
373
+ # Some of these really are (or are not) purchasable,
374
+ # but it's easier to make it all explicit here,
375
+ # especially if our definition of purchasable changes in the future.
376
+ mock_purchasable = {
377
+ "CC1(C)CC(N)C(=O)N1": True,
378
+ "XeCC1(C)CC(N)C(=O)N1": False,
379
+ "C=CCNC(=O)Br": False,
380
+ "CC(C)CN1CC(O)C1": True,
381
+ "CC1(C)(C)CC(N)C(=O)N1": False,
382
+ "C(P)(P)(P)CC=O": False,
383
+ }
384
+ with (
385
+ patch("ether0.rewards.fetch_purchasable", return_value=mock_purchasable),
386
+ patch("ether0.rewards.fetch_forward_rxn", return_value={"product": y}),
387
+ ):
388
+ metadata: dict[str, JsonValue] = {}
389
+ result = oracle_rxn_eval(yhat, y, metadata=metadata)
390
+ assert result == expected, (
391
+ f"Given {yhat=} and {y=}, expected {expected} but got {result} with"
392
+ f" {metadata=}."
393
+ )
394
+
395
+
396
+ @pytest.mark.parametrize(
397
+ ("f1", "f2", "expected"),
398
+ [
399
+ pytest.param("C1", "C2", 1.0, id="simple-1"),
400
+ pytest.param("C1", "C1H1", 1.0, id="simple-2"),
401
+ pytest.param("C1H2", "C1H2", 0.0, id="simple-3"),
402
+ pytest.param("N2", "O2", 8**0.5, id="simple-4"),
403
+ pytest.param("X100C1", "X100C2", 1.0, id="bad-element-5"),
404
+ pytest.param("C100", "C100H100", 100, id="big-digits"),
405
+ pytest.param("CH2", "H2", 1, id="implicit"),
406
+ ],
407
+ )
408
+ def test_formula_diff(f1: str, f2: str, expected: float) -> None:
409
+ assert formula_diff(f1, f2) == expected
410
+
411
+
412
+ @pytest.mark.parametrize(
413
+ ("mol", "ref_mol", "expected"),
414
+ [
415
+ pytest.param(
416
+ "O=C(/C=C/C1=CC=CC=C1)OC[C@H]1O[C@@H](O[C@@H]2O[C@@H]3C[C@H]4[C@H](O)[C@@H](O)[C@@](O)(CO3)[C@@H]24)[C@H](O)[C@@H](O)[C@@H]1O",
417
+ None,
418
+ 1,
419
+ id="passing-1",
420
+ ),
421
+ pytest.param(
422
+ "CC(C)C[C@H](NC(=O)[C@H](Cc1c[nH]cn1)NC(=O)[C@H](Cc1ccccc1)NC(=O)OC(C)(C)C)[C@@H](O)[C@@H](O)CC(C)C",
423
+ None,
424
+ 1,
425
+ id="passing-2",
426
+ ),
427
+ pytest.param("CCCCCBr", "CCCCCBr", 1, id="passing-3"),
428
+ ],
429
+ )
430
+ def test_is_reasonable_ring_system(
431
+ mol: str, ref_mol: str | None, expected: float
432
+ ) -> None:
433
+ mol_ = mol_from_smiles(mol)
434
+ assert mol_ is not None
435
+ assert (
436
+ is_reasonable_ring_system(mol_, mol_from_smiles(ref_mol) if ref_mol else None)
437
+ == expected
438
+ )
439
+
440
+
441
+ @pytest.mark.parametrize(
442
+ ("mol", "ref_mol", "expected"),
443
+ [
444
+ pytest.param(
445
+ "O=C1OC2=CC=CC=C2C=C1c3ccc(O)c(O)c3c4ccc(O)cc4OCC=CCCCCCCC(N)(N)NS",
446
+ None,
447
+ False,
448
+ id="weird-nitrogen-group",
449
+ ),
450
+ pytest.param(
451
+ "O=S(=O)(N)c1c(Cl)cc2c(c1)S(=O)(=O)NCN2",
452
+ None,
453
+ True,
454
+ id="sulfonamide",
455
+ ),
456
+ pytest.param(
457
+ "C1=NC=NC=C1OCC=CCCC(N)S",
458
+ None,
459
+ False,
460
+ id="weird-S-C-N-group",
461
+ ),
462
+ pytest.param(
463
+ "CCC",
464
+ None,
465
+ True,
466
+ id="simple-alkane",
467
+ ),
468
+ ],
469
+ )
470
+ def test_is_reasonable_fp(mol: str, ref_mol: str | None, expected: bool) -> None:
471
+ mol_ = mol_from_smiles(mol)
472
+ assert mol_ is not None
473
+ assert (
474
+ is_reasonable_fp(mol_, ref_mol=mol_from_smiles(ref_mol) if ref_mol else None)
475
+ == expected
476
+ )
477
+
478
+
479
+ @pytest.mark.parametrize(
480
+ ("yhat", "expected"),
481
+ [
482
+ ("CC(C)CCC", 1.0),
483
+ ("CC(C)(C)(C)C", 0.0),
484
+ ("", 0.0),
485
+ ("INVALID", 0.0),
486
+ ],
487
+ )
488
+ def test_valid_molecule_eval(yhat: str, expected: float) -> None:
489
+ assert valid_molecule_eval(yhat, y="") == expected