CYF200127 commited on
Commit
ca5fbc9
·
verified ·
1 Parent(s): f1f2574

Update getReaction.py

Browse files
Files changed (1) hide show
  1. getReaction.py +79 -78
getReaction.py CHANGED
@@ -1,78 +1,79 @@
1
- import sys
2
- sys.path.append('./rxn/')
3
- import torch
4
- from rxn.reaction import Reaction
5
- import json
6
- from matplotlib import pyplot as plt
7
- import numpy as np
8
-
9
- ckpt_path = "./rxn/model/model.ckpt"
10
- model = Reaction(ckpt_path, device=torch.device('cpu'))
11
- device = torch.device('cpu')
12
-
13
- def get_reaction(image_path: str) -> list:
14
- '''Returns a list of reactions extracted from the image.'''
15
- image_file = image_path
16
- return json.dumps(model.predict_image_file(image_file, molscribe=True, ocr=True))
17
-
18
-
19
-
20
- def generate_combined_image(predictions, image_file):
21
- """
22
- 将预测的图像整合到一个对称的布局中输出。
23
- """
24
- output = model.draw_predictions(predictions, image_file=image_file)
25
- n_images = len(output)
26
- if n_images == 1:
27
- n_cols = 1
28
- elif n_images == 2:
29
- n_cols = 2
30
- else:
31
- n_cols = 3
32
- n_rows = (n_images + n_cols - 1) // n_cols # 计算需要的行数
33
-
34
- # 确保每张图像符合要求
35
- processed_images = []
36
- for img in output:
37
- if len(img.shape) == 2: # 灰度图像
38
- img = np.stack([img] * 3, axis=-1) # 转换为 RGB 格式
39
- elif img.shape[2] > 3: # RGBA 图像
40
- img = img[:, :, :3] # 只保留 RGB 通道
41
- if img.dtype == np.float32 or img.dtype == np.float64:
42
- img = (img * 255).astype(np.uint8) # 转换为 uint8
43
- processed_images.append(img)
44
- output = processed_images
45
-
46
- # 为不足的子图位置添加占位图
47
- if n_images < n_rows * n_cols:
48
- blank_image = np.ones_like(output[0]) * 255 # 生成一个白色占位图
49
- while len(output) < n_rows * n_cols:
50
- output.append(blank_image)
51
-
52
- # 创建子图画布
53
- fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))
54
-
55
- # 确保 axes 是一维数组
56
- if isinstance(axes, np.ndarray):
57
- axes = axes.flatten()
58
- else:
59
- axes = [axes] # 单个子图的情况
60
-
61
- # 绘制每张图像
62
- for idx, img in enumerate(output):
63
- ax = axes[idx]
64
- ax.imshow(img)
65
- ax.axis('off')
66
- if idx < n_images:
67
- ax.set_title(f"Reaction {idx + 1}")
68
-
69
- # 删除多余的子图
70
- for idx in range(n_images, len(axes)):
71
- fig.delaxes(axes[idx])
72
-
73
- # 保存整合图像
74
- combined_image_path = "combined_output.png"
75
- plt.tight_layout()
76
- plt.savefig(combined_image_path)
77
- plt.close(fig)
78
- return combined_image_path
 
 
1
+ import sys
2
+ sys.path.append('./rxn/')
3
+ import torch
4
+ from rxn.reaction import Reaction
5
+ import json
6
+ from matplotlib import pyplot as plt
7
+ import numpy as np
8
+
9
+ ckpt_path = "./rxn/model/model.ckpt"
10
+ model = Reaction(ckpt_path, device=torch.device('cpu'))
11
+ device = torch.device('cpu')
12
+
13
+ def get_reaction(image_path: str) -> list:
14
+ '''Returns a list of reactions extracted from the image.'''
15
+ image_file = image_path
16
+ return json.dumps(model.predict_image_file(image_file, molscribe=True, ocr=True))
17
+
18
+
19
+
20
+ def generate_combined_image(predictions, image_file):
21
+ """
22
+ 将预测的图像整合到一个对称的布局中输出。
23
+ """
24
+ output = model.draw_predictions(predictions, image_file=image_file)
25
+ n_images = len(output)
26
+ # if n_images == 1:
27
+ # n_cols = 1
28
+ # elif n_images == 2:
29
+ # n_cols = 2
30
+ # else:
31
+ # n_cols = 3
32
+ n_cols = 1
33
+ n_rows = (n_images + n_cols - 1) // n_cols # 计算需要的行数
34
+
35
+ # 确保每张图像符合要求
36
+ processed_images = []
37
+ for img in output:
38
+ if len(img.shape) == 2: # 灰度图像
39
+ img = np.stack([img] * 3, axis=-1) # 转换为 RGB 格式
40
+ elif img.shape[2] > 3: # RGBA 图像
41
+ img = img[:, :, :3] # 只保留 RGB 通道
42
+ if img.dtype == np.float32 or img.dtype == np.float64:
43
+ img = (img * 255).astype(np.uint8) # 转换为 uint8
44
+ processed_images.append(img)
45
+ output = processed_images
46
+
47
+ # 为不足的子图位置添加占位图
48
+ if n_images < n_rows * n_cols:
49
+ blank_image = np.ones_like(output[0]) * 255 # 生成一个白色占位图
50
+ while len(output) < n_rows * n_cols:
51
+ output.append(blank_image)
52
+
53
+ # 创建子图画布
54
+ fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 12 * n_rows))
55
+
56
+ # 确保 axes 是一维数组
57
+ if isinstance(axes, np.ndarray):
58
+ axes = axes.flatten()
59
+ else:
60
+ axes = [axes] # 单个子图的情况
61
+
62
+ # 绘制每张图像
63
+ for idx, img in enumerate(output):
64
+ ax = axes[idx]
65
+ ax.imshow(img)
66
+ ax.axis('off')
67
+ if idx < n_images:
68
+ ax.set_title(f"Reaction {idx + 1}",fontsize=42)
69
+
70
+ # 删除多余的子图
71
+ for idx in range(n_images, len(axes)):
72
+ fig.delaxes(axes[idx])
73
+
74
+ # 保存整合图像
75
+ combined_image_path = "combined_output.png"
76
+ plt.tight_layout()
77
+ plt.savefig(combined_image_path)
78
+ plt.close(fig)
79
+ return combined_image_path