CYF200127 commited on
Commit
32c244d
·
verified ·
1 Parent(s): 3dbdaf8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -11
app.py CHANGED
@@ -8,7 +8,9 @@ from rxn.reaction import Reaction
8
  from rdkit import Chem
9
  from rdkit.Chem import rdChemReactions
10
  from rdkit.Chem import Draw
11
-
 
 
12
 
13
  PROMPT_DIR = "prompts/"
14
  ckpt_path = "./rxn/model/model.ckpt"
@@ -19,7 +21,7 @@ PROMPT_NAMES = {
19
  "2_RxnOCR.txt": "Reaction Image Parsing Workflow",
20
  }
21
  example_diagram = "examples/exp.png"
22
- rdkit_image = "examples/image.webp"
23
 
24
  def list_prompt_files_with_names():
25
  """
@@ -106,10 +108,10 @@ def process_chem_image(image, selected_task):
106
  return "\n\n".join(detailed_reactions), smiles_output, combined_image_path, example_diagram, json_file_path
107
 
108
 
109
-
110
  prompts_with_names = list_prompt_files_with_names()
111
 
112
-
113
  examples = [
114
 
115
  ["examples/reaction1.png", "Reaction Image Parsing Workflow"],
@@ -167,7 +169,7 @@ with gr.Blocks() as demo:
167
  @gr.render(inputs = smiles_output) # 使用gr.render修饰器绑定输入和渲染逻辑
168
  def show_split(inputs): # 定义处理和展示分割文本的函数
169
  if not inputs or isinstance(inputs, str) and inputs.strip() == "": # 检查输入文本是否为空
170
- return gr.Textbox(label= "SMILES of Reaction i"), gr.Image(value=rdkit_image, label= "RDKit Image of Reaction i")
171
  else:
172
  # 假设输入是逗号分隔的 SMILES 字符串
173
  smiles_list = inputs.split(",")
@@ -175,14 +177,65 @@ with gr.Blocks() as demo:
175
  components = [] # 初始化一个组件列表,用于存放每个 SMILES 对应的 Textbox 组件
176
  for i, smiles in enumerate(smiles_list):
177
  smiles.replace('"', '').replace("'", "").replace("[", "").replace("]", "")
178
- reaction = rdChemReactions.ReactionFromSmarts(smiles)
179
- if reaction:
180
- img = Draw.ReactionToImage(reaction)
181
- components.append(gr.Textbox(value=smiles,label= f"SMILES of Reaction {i + 1} ", show_copy_button=True, interactive=False))
182
- components.append(gr.Image(value=img,label= f"RDKit Image of Reaction {i + 1} "))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  return components # 返回包含所有 SMILES Textbox 组件的列表
184
 
185
- download_json = gr.File(label="Download JSON File")
186
 
187
 
188
 
 
8
  from rdkit import Chem
9
  from rdkit.Chem import rdChemReactions
10
  from rdkit.Chem import Draw
11
+ from rdkit.Chem import AllChem
12
+ from rdkit.Chem.Draw import rdMolDraw2D
13
+ import cairosvg
14
 
15
  PROMPT_DIR = "prompts/"
16
  ckpt_path = "./rxn/model/model.ckpt"
 
21
  "2_RxnOCR.txt": "Reaction Image Parsing Workflow",
22
  }
23
  example_diagram = "examples/exp.png"
24
+ rdkit_image = "examples/rdkit.png"
25
 
26
  def list_prompt_files_with_names():
27
  """
 
108
  return "\n\n".join(detailed_reactions), smiles_output, combined_image_path, example_diagram, json_file_path
109
 
110
 
111
+ # 获取 prompts 和友好名字
112
  prompts_with_names = list_prompt_files_with_names()
113
 
114
+ # 示例数据:图像路径 + 任务选项
115
  examples = [
116
 
117
  ["examples/reaction1.png", "Reaction Image Parsing Workflow"],
 
169
  @gr.render(inputs = smiles_output) # 使用gr.render修饰器绑定输入和渲染逻辑
170
  def show_split(inputs): # 定义处理和展示分割文本的函数
171
  if not inputs or isinstance(inputs, str) and inputs.strip() == "": # 检查输入文本是否为空
172
+ return gr.Textbox(label= "SMILES of Reaction i"), gr.Image(value=rdkit_image, label= "RDKit Image of Reaction i",height=100)
173
  else:
174
  # 假设输入是逗号分隔的 SMILES 字符串
175
  smiles_list = inputs.split(",")
 
177
  components = [] # 初始化一个组件列表,用于存放每个 SMILES 对应的 Textbox 组件
178
  for i, smiles in enumerate(smiles_list):
179
  smiles.replace('"', '').replace("'", "").replace("[", "").replace("]", "")
180
+ rxn = rdChemReactions.ReactionFromSmarts(smiles, useSmiles=True)
181
+
182
+ if rxn:
183
+
184
+ new_rxn = AllChem.ChemicalReaction()
185
+ for mol in rxn.GetReactants():
186
+ mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))
187
+ new_rxn.AddReactantTemplate(mol)
188
+ for mol in rxn.GetProducts():
189
+ mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))
190
+ new_rxn.AddProductTemplate(mol)
191
+
192
+ rxn = new_rxn
193
+
194
+ def atom_mapping_remover(rxn):
195
+ for reactant in rxn.GetReactants():
196
+ for atom in reactant.GetAtoms():
197
+ atom.SetAtomMapNum(0)
198
+ for product in rxn.GetProducts():
199
+ for atom in product.GetAtoms():
200
+ atom.SetAtomMapNum(0)
201
+ return rxn
202
+
203
+ atom_mapping_remover(rxn)
204
+
205
+ reactant1 = rxn.GetReactantTemplate(0)
206
+ print(reactant1.GetNumBonds)
207
+ reactant2 = rxn.GetReactantTemplate(1) if rxn.GetNumReactantTemplates() > 1 else None
208
+
209
+ if reactant1.GetNumBonds() > 0:
210
+ bond_length_reference = Draw.MeanBondLength(reactant1)
211
+ elif reactant2 and reactant2.GetNumBonds() > 0:
212
+ bond_length_reference = Draw.MeanBondLength(reactant2)
213
+ else:
214
+ bond_length_reference = 1.0
215
+
216
+
217
+ drawer = rdMolDraw2D.MolDraw2DSVG(-1, -1)
218
+ dopts = drawer.drawOptions()
219
+ dopts.padding = 0.1
220
+ dopts.includeRadicals = True
221
+ Draw.SetACS1996Mode(dopts, bond_length_reference*0.55)
222
+ dopts.bondLineWidth = 1.5
223
+ drawer.DrawReaction(rxn)
224
+ drawer.FinishDrawing()
225
+ svg_content = drawer.GetDrawingText()
226
+ svg_file = f"reaction{i+1}.svg"
227
+ with open(svg_file, "w") as f:
228
+ f.write(svg_content)
229
+ png_file = f"reaction_{i+1}.png"
230
+ cairosvg.svg2png(url=svg_file, write_to=png_file)
231
+
232
+
233
+
234
+ components.append(gr.Textbox(value=smiles,label= f"Reaction {i + 1} SMILES", show_copy_button=True, interactive=False))
235
+ components.append(gr.Image(value=png_file,label= f"Reaction {i + 1} RDKit Image"))
236
  return components # 返回包含所有 SMILES Textbox 组件的列表
237
 
238
+ download_json = gr.File(label="Download JSON File",)
239
 
240
 
241