|
import os |
|
import argparse |
|
from typing import List |
|
import PIL |
|
import torch |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from matplotlib.backends.backend_agg import FigureCanvasAgg |
|
|
|
from .pix2seq import build_pix2seq_model |
|
from .tokenizer import get_tokenizer |
|
from .dataset import make_transforms |
|
from .data import postprocess_reactions, ReactionImageData |
|
|
|
from molscribe import MolScribe |
|
from huggingface_hub import hf_hub_download |
|
import easyocr |
|
|
|
|
|
class Reaction:
    """Inference wrapper around a pix2seq reaction-extraction model.

    Constructing an instance loads a checkpoint, builds the tokenizer, the
    model and the test-time image transform, and also creates a MolScribe
    model and an EasyOCR reader that ``predict_images`` can optionally hand
    to post-processing.
    """

    def __init__(self, model_path, device=None):
        """
        :param model_path: path of the model checkpoint.
        :param device: torch device (or device string), defaults to be CPU.
        """
        args = self._get_args()
        args.format = 'reaction'
        # NOTE(security): torch.load deserializes the checkpoint with pickle
        # (unless the installed torch restricts it via ``weights_only``), so
        # only load checkpoints that come from a trusted source.
        states = torch.load(model_path, map_location=torch.device('cpu'))
        if device is None:
            device = torch.device('cpu')
        # Normalize so that a plain string such as "cuda" also works;
        # get_ocr_model() reads ``self.device.type``.
        self.device = torch.device(device)
        self.tokenizer = get_tokenizer(args)
        self.model = self.get_model(args, self.tokenizer, self.device, states['state_dict'])
        self.transform = make_transforms('test', augment=False, debug=False)
        self.molscribe = self.get_molscribe()
        self.ocr_model = self.get_ocr_model()

    def _get_args(self):
        """Build the model hyper-parameter namespace from built-in defaults.

        The parser is evaluated against an empty argument list, so the
        process command line is never consulted.

        :return: ``argparse.Namespace`` with the defaults below plus
            ``pix2seq=True``, ``pix2seq_ckpt=None`` and ``pred_eos=True``.
        """
        parser = argparse.ArgumentParser()

        # Backbone / positional-embedding options.
        parser.add_argument('--backbone', default='resnet50', type=str,
                            help="Name of the convolutional backbone to use")
        parser.add_argument('--dilation', action='store_true',
                            help="If true, we replace stride with dilation in the last convolutional block (DC5)")
        parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                            help="Type of positional embedding to use on top of the image features")

        # Transformer options.
        parser.add_argument('--enc_layers', default=6, type=int, help="Number of encoding layers in the transformer")
        parser.add_argument('--dec_layers', default=6, type=int, help="Number of decoding layers in the transformer")
        parser.add_argument('--dim_feedforward', default=1024, type=int,
                            help="Intermediate size of the feedforward layers in the transformer blocks")
        parser.add_argument('--hidden_dim', default=256, type=int,
                            help="Size of the embeddings (dimension of the transformer)")
        parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer")
        parser.add_argument('--nheads', default=8, type=int,
                            help="Number of attention heads inside the transformer's attentions")
        parser.add_argument('--pre_norm', action='store_true')

        # Data options.
        parser.add_argument('--format', type=str, default='reaction')
        parser.add_argument('--input_size', type=int, default=1333)

        args = parser.parse_args([])
        args.pix2seq = True
        args.pix2seq_ckpt = None
        args.pred_eos = True
        return args

    def get_model(self, args, tokenizer, device, model_states):
        """Build the pix2seq model, load its weights and switch to eval mode.

        :param args: namespace produced by ``_get_args``.
        :param tokenizer: mapping indexed by ``args.format``; the selected
            entry is passed to the model builder.
        :param device: device the model is moved to.
        :param model_states: state dict taken from the checkpoint.
        :return: the model, on ``device`` and in eval mode.
        """
        def remove_prefix(state_dict):
            # NOTE(review): str.replace removes *every* 'model.' occurrence in
            # a key, not only a leading one. Kept as-is because the weights
            # are loaded with strict=False and changing the mapping could
            # silently alter which keys are matched.
            return {k.replace('model.', ''): v for k, v in state_dict.items()}

        model = build_pix2seq_model(args, tokenizer[args.format])
        # strict=False: missing or unexpected keys do not raise.
        model.load_state_dict(remove_prefix(model_states), strict=False)
        model.to(device)
        model.eval()
        return model

    def get_molscribe(self):
        """Download the MolScribe checkpoint from the Hugging Face hub and load it.

        :return: a ``MolScribe`` instance created on ``self.device``.
        """
        ckpt_path = hf_hub_download("yujieq/MolScribe", "swin_base_char_aux_1m.pth")
        molscribe = MolScribe(ckpt_path, device=self.device)
        return molscribe

    def get_ocr_model(self):
        """Create an English EasyOCR reader, using the GPU only on a CUDA device.

        :return: an ``easyocr.Reader`` instance.
        """
        reader = easyocr.Reader(['en'], gpu=(self.device.type == 'cuda'))
        return reader

    def predict_images(self, input_images: List, batch_size=16, molscribe=False, ocr=False):
        """Predict reactions for a list of images, processed in batches.

        :param input_images: list of images accepted by ``self.transform``.
        :param batch_size: number of images sent through the model at once.
        :param molscribe: if true, pass the MolScribe model to post-processing.
        :param ocr: if true, pass the OCR reader to post-processing.
        :return: list with one post-processed prediction per input image,
            in input order.
        """
        device = self.device
        tokenizer = self.tokenizer['reaction']
        predictions = []
        for idx in range(0, len(input_images), batch_size):
            batch_images = input_images[idx:idx + batch_size]
            # The transform returns (tensor, ref); ref carries the 'scale'
            # needed to decode the predicted sequence.
            images, refs = zip(*[self.transform(image) for image in batch_images])
            images = torch.stack(images, dim=0).to(device)
            with torch.no_grad():
                pred_seqs, pred_scores = self.model(images, max_len=tokenizer.max_len)
            for i, (seqs, scores) in enumerate(zip(pred_seqs, pred_scores)):
                reactions = tokenizer.sequence_to_data(seqs.tolist(), scores.tolist(), scale=refs[i]['scale'])
                reactions = postprocess_reactions(
                    reactions,
                    # ``i`` is the position inside the current batch, so the
                    # matching image must come from ``batch_images``; indexing
                    # ``input_images`` here paired every batch after the first
                    # with images from the first batch.
                    image=batch_images[i],
                    molscribe=self.molscribe if molscribe else None,
                    ocr=self.ocr_model if ocr else None
                )
                predictions.append(reactions)
        return predictions

    def predict_image(self, image, **kwargs):
        """Predict reactions for a single image.

        :param image: one image, forwarded to ``predict_images`` as a
            one-element list.
        :param kwargs: forwarded to ``predict_images``.
        :return: the prediction for ``image``.
        """
        predictions = self.predict_images([image], **kwargs)
        return predictions[0]

    def predict_image_files(self, image_files: List, **kwargs):
        """Open each file as an RGB image and predict reactions for all of them.

        :param image_files: list of image file paths.
        :param kwargs: forwarded to ``predict_images``.
        :return: list with one prediction per file, in input order.
        """
        # Import the submodule explicitly: ``import PIL`` alone does not
        # guarantee that ``PIL.Image`` is loaded.
        from PIL import Image

        input_images = []
        for path in image_files:
            # convert() returns a new, fully loaded image, so the file handle
            # can be closed as soon as the conversion is done.
            with Image.open(path) as raw_image:
                image = raw_image.convert("RGB")
            input_images.append(image)
        return self.predict_images(input_images, **kwargs)

    def predict_image_file(self, image_file: str, **kwargs):
        """Predict reactions for a single image file.

        :param image_file: path of the image file.
        :param kwargs: forwarded to ``predict_images``.
        :return: the prediction for that file.
        """
        predictions = self.predict_image_files([image_file], **kwargs)
        return predictions[0]

    def draw_predictions(self, predictions, image=None, image_file=None):
        """Render every predicted reaction on its own copy of the image.

        :param predictions: predictions for one image.
        :param image: the image the predictions belong to.
        :param image_file: path of that image; at least one of ``image`` and
            ``image_file`` has to be given.
        :return: list of RGBA arrays, one per predicted reaction.
        """
        results = []
        # NOTE(review): truthiness check kept for backward compatibility; it
        # is skipped under ``python -O``.
        assert image or image_file
        data = ReactionImageData(predictions=predictions, image=image, image_file=image_file)
        # Figure size keeps the image aspect ratio, with the longer side at 10.
        h, w = np.array([data.height, data.width]) * 10 / max(data.height, data.width)
        for r in data.pred_reactions:
            fig, ax = plt.subplots(figsize=(w, h))
            fig.tight_layout()
            canvas = FigureCanvasAgg(fig)
            ax.imshow(data.image)
            ax.axis('off')
            r.draw(ax)
            canvas.draw()
            buf = canvas.buffer_rgba()
            results.append(np.asarray(buf))
            # Close the figure so pyplot does not keep it registered.
            plt.close(fig)
        return results

    def draw_predictions_combined(self, predictions, image=None, image_file=None):
        """Render all predicted reactions stacked vertically in one figure.

        :param predictions: predictions for one image.
        :param image: the image the predictions belong to.
        :param image_file: path of that image; at least one of ``image`` and
            ``image_file`` has to be given.
        :return: a single RGBA array with one titled panel per reaction.
        """
        # NOTE(review): truthiness check kept for backward compatibility; it
        # is skipped under ``python -O``.
        assert image or image_file
        data = ReactionImageData(predictions=predictions, image=image, image_file=image_file)
        # Per-panel size keeps the image aspect ratio, longer side at 10.
        h, w = np.array([data.height, data.width]) * 10 / max(data.height, data.width)
        n = len(data.pred_reactions)
        fig, axes = plt.subplots(n, 1, figsize=(w, h * n))
        if n == 1:
            # With a single panel plt.subplots returns a bare Axes object.
            axes = [axes]
        fig.tight_layout(rect=(0.02, 0.02, 0.99, 0.99))
        canvas = FigureCanvasAgg(fig)
        for i, r in enumerate(data.pred_reactions):
            ax = axes[i]
            ax.imshow(data.image)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_title(f'reaction # {i}', fontdict={'fontweight': 'bold', 'fontsize': 14})
            r.draw(ax)
        canvas.draw()
        buf = canvas.buffer_rgba()
        result_image = np.asarray(buf)
        plt.close(fig)
        return result_image
|
|