|
from fastai.vision.all import * |
|
# Register the '.webp' extension in fastai's `image_extensions` collection;
# presumably so that get_image_files() (used by MultiCaReClassifier below) also
# collects WebP images -- confirm against the fastai version in use.
image_extensions.add('.webp')
|
|
|
class MultiCaReClassifier():
    '''Hierarchical classifier for medical images.

    A first model predicts the image type (e.g. radiology, pathology). Further models
    refine that prediction (e.g. radiology subtype, anatomical region and view), each
    one being applied only to the images for which its parent model predicted the
    class it depends on. All the work is triggered from the constructor; results are
    available in the `data` attribute (a pandas DataFrame) afterwards.'''

    def __init__(self, image_folder, models_root = 'MultiCaReClassifier/models', save_path = '', add_multiclass_columns = False):
        '''Classify every image found in `image_folder`.

        image_folder (str): folder containing all the input images.
        models_root (str): folder containing one sub-folder per classification model.
            If it does not exist, the relative 'models' folder is used instead.
        save_path (str): path to save the inference table as csv. Nothing is saved if empty.
        add_multiclass_columns (bool): if True, multiclass columns will be added to the
            dataframe based on the multilabel column ('label_list').'''

        # Joining with '' guarantees a trailing path separator.
        self.image_folder = os.path.join(image_folder, '')
        self.models_root = models_root
        self.save_path = save_path
        self.add_multiclass_columns = add_multiclass_columns

        # Model name -> class names, in the order of the model outputs.
        # Name syntax: ':' separates a parent model from the class that triggers the
        # child model, '.' does the same for a further refinement, and '~' separates
        # the triggering class from the name of the property predicted by the model.
        # Misspelled labels ('ophtalmic_angiography', 'ig_endoscopy') are kept as they
        # are here and corrected later in apply_postprocessing().
        # The key order matters: it defines the order of the labels in 'label_list'.
        self.label_dict = {
            "image_type:radiology~anatomical_region:axial_region": ["abdomen", "breast", "head", "neck", "pelvis", "thorax"],
            "image_type:radiology~anatomical_region:lower_limb": ["ankle", "foot", "hip", "knee", "lower_leg", "thigh"],
            "image_type:radiology~anatomical_view": ["axial", "frontal", "intravascular", "oblique", "occlusal", "panoramic", "periapical", "sagittal", "transabdominal", "transesophageal", "transthoracic", "transvaginal"],
            "image_type:endoscopy": ["airway_endoscopy", "arthroscopy", "ig_endoscopy", "other_endoscopy"],
            "image_type:electrography": ["eeg", "ekg", "emg"],
            "image_type:ophthalmic_imaging": ["autofluorescence", "b_scan", "fundus_photograph", "gonioscopy", "oct", "ophtalmic_angiography", "slit_lamp_photograph"],
            "image_type:radiology~anatomical_region:upper_limb": ["elbow", "forearm", "hand", "shoulder", "upper_arm", "wrist"],
            "image_type:radiology~anatomical_region": ["axial_region", "lower_limb", "upper_limb", "whole_body"],
            "image_type:radiology~main": ["ct", "mri", "pet", "scintigraphy", "spect", "tractography", "ultrasound", "x_ray"],
            "image_type:pathology": ["acid_fast", "alcian_blue", "congo_red", "fish", "giemsa", "gram", "h&e", "immunostaining", "masson_trichrome", "methenamine_silver", "methylene_blue", "papanicolaou", "pas", "van_gieson"],
            "image_type:radiology~anatomical_region:axial_region.thorax": ["cardiac_image", "other_thoracic_image"],
            "image_type:medical_photograph": ["oral_photograph", "other_medical_photograph", "skin_photograph"],
            "image_type": ["chart", "electrography", "endoscopy", "medical_photograph", "ophthalmic_imaging", "pathology", "radiology"]
        }

        self.image_paths = get_image_files(self.image_folder)

        # One (initially empty) column per model whose folder is actually available.
        models_dir = self._resolve_models_dir()
        available_models = [name for name in self.label_dict
                            if os.path.isdir(os.path.join(models_dir, name.replace(':', '_')))]
        self.data = pd.DataFrame(columns=available_models)
        self.data['image_path'] = self.image_paths

        self.predict_image_classes()

    def predict_image_classes(self):
        '''Method used to get the predictions for each image.

        Models are run level by level (a level being the number of ':' / '.' separated
        parts of the model name), so the predictions of a parent model are always
        available before its child models run. The loop stops at the first level that
        has no model. The table is then postprocessed and, if `save_path` is set,
        written to csv.'''

        model_order = 1
        while True:
            level_models = [name for name in self.label_dict
                            if len(re.split(r'[:.]', name)) == model_order]
            if not level_models:
                break
            for model_name in level_models:
                self._add_predictions(model_name)
            model_order += 1

        self.apply_postprocessing()

        if self.save_path:
            self.data.to_csv(self.save_path, index=False)

    def apply_postprocessing(self):
        '''Method used to postprocess the predictions.

        Collapses every prediction column into a single 'label_list' column, normalizes
        some labels, adds compound classes, optionally adds the multiclass columns and
        finally removes the auxiliary classes from 'label_list'.'''

        # Gather the predictions of all the models in one list per image. Cells without
        # a prediction (models not applicable to the image) are not strings and are dropped.
        columns_to_flatten = [c for c in self.data.columns if c.startswith('image_type')]
        self.data['label_list'] = self.data[columns_to_flatten].values.tolist()
        self.data['label_list'] = self.data['label_list'].apply(lambda x: [item for item in x if isinstance(item, (str, np.str_))])
        self.data.drop(columns_to_flatten, axis = 1, inplace = True)

        # Ultrasound views are merged into one class and misspelled labels are corrected.
        replacement_dict = {'transesophageal': 'ultrasound_view', 'transthoracic': 'ultrasound_view', 'transabdominal': 'ultrasound_view',
                            'transvaginal': 'ultrasound_view', 'ophtalmic_angiography': 'ophthalmic_angiography', 'ig_endoscopy': 'gi_endoscopy'}
        self.data['label_list'] = self.data['label_list'].apply(lambda x: [replacement_dict.get(item, item) for item in x])

        self.data['label_list'] = self.data['label_list'].apply(self._add_compound_classes)

        # The multiclass columns need some auxiliary classes (e.g. 'ultrasound_view'),
        # so they are generated before those classes are removed below.
        if self.add_multiclass_columns:
            self._generate_multiclass_columns()

        auxiliary_classes = {'axial_region', 'cardiac_image', 'other_thoracic_image', 'intravascular', 'ultrasound_view'}
        self.data['label_list'] = self.data['label_list'].apply(lambda x: [item for item in x if item not in auxiliary_classes])

    def _resolve_models_dir(self):
        '''Method used to get the folder that contains one sub-folder per model.

        `models_root` is returned when it is an existing directory. Otherwise the
        relative 'models' folder is returned, which is where the models were read from
        before `models_root` was taken into account (so existing setups keep working).'''

        if self.models_root and os.path.isdir(self.models_root):
            return self.models_root
        return 'models'

    def _identify_upper_model(self, model_name):
        '''Method used to identify the corresponding upper model of a given model.

        The upper model name is everything before the last ':' or '.' of `model_name`.
        Returns None if the name contains neither (i.e. it is the root model).'''

        index = max(self._search_last_match(model_name, ':'),
                    self._search_last_match(model_name, '.'))
        return model_name[:index] if index != -1 else None

    def _search_last_match(self, string, character):
        '''Method used to find the last mention of a character in a string.

        Returns the index of the last occurrence, or -1 if there is none.'''

        return string.rfind(character)

    def _add_predictions(self, model_name):
        '''Method used to add all the predictions of a given model to the outcome dataframe.

        Only images without a prediction for this model are considered and, for child
        models, only those for which the upper model predicted the class this model
        depends on. Nothing is done if the model (or its upper model) has no column in
        the dataframe, which happens when its folder was not found.'''

        if model_name not in self.data.columns:
            return

        pending = self.data[model_name].isnull()

        upper_model = self._identify_upper_model(model_name)
        if upper_model is not None:
            if upper_model not in self.data.columns:
                # Without the upper model there is no way to know which images apply.
                return
            # Class of the upper model that triggers this one, e.g. 'radiology' for
            # 'image_type:radiology~main' or 'thorax' for '...:axial_region.thorax'.
            condition_class = model_name.split(':')[-1].split('~')[0].split('.')[-1]
            pending = pending & (self.data[upper_model] == condition_class)

        imgs = self.data[pending].image_path.values
        if len(imgs) == 0:
            return

        labels = np.array(self.label_dict[model_name])
        device = 'cpu'

        # The checkpoint is looked up at <models dir>/<model folder>/model.pth.
        checkpoint_file = os.path.join(model_name.replace(':', '_'), 'model')
        # The labelling function is a dummy: these dataloaders are only needed to build the learner.
        dls = ImageDataLoaders.from_path_func('', imgs, lambda x: '0', item_tfms=Resize((224,224), method='squish'))
        learn = vision_learner(dls, resnet50, n_out=len(labels), model_dir=self._resolve_models_dir()).to_fp16()
        learn.load(checkpoint_file, device=device, weights_only=False)
        test_dl = learn.dls.test_dl(imgs, device=device)
        probs, _ = learn.get_preds(dl=test_dl)
        self.data.loc[pending, model_name] = labels[probs.argmax(axis=1)]

    def _add_compound_classes(self, input_class_list):
        '''This method is used to add compound classes to the label list if the corresponding component classes are present.

        Returns a new list; `input_class_list` itself is not modified.'''

        compound_class_dicts = [
            {'compound_class': 'echocardiogram', 'components': ['ultrasound', 'cardiac_image']},
            {'compound_class': 'ivus', 'components': ['ultrasound', 'intravascular']},
            {'compound_class': 'mammography', 'components': ['x_ray', 'breast']}
        ]

        class_list = list(input_class_list)
        for dct in compound_class_dicts:
            has_all_components = all(cls in class_list for cls in dct['components'])
            if has_all_components and dct['compound_class'] not in class_list:
                class_list.append(dct['compound_class'])

        return class_list

    def _generate_multiclass_columns(self):
        '''Method used to generate the multiclass columns based on the label list.

        Each new column contains, for every image, the label from 'label_list' that
        belongs to the list of classes of that column ('' if there is none).'''

        multiclass_columns = {
            'image_type': ['chart', 'radiology', 'pathology', 'medical_photograph', 'ophthalmic_imaging', 'endoscopy', 'electrography'],
            'image_subtype': ['chart',
                              'ct', 'mri', 'x_ray', 'pet', 'spect', 'scintigraphy', 'ultrasound', 'tractography',
                              'acid_fast', 'alcian_blue', 'congo_red', 'fish', 'giemsa', 'gram', 'h&e', 'immunostaining', 'masson_trichrome', 'methenamine_silver', 'methylene_blue', 'papanicolaou', 'pas', 'van_gieson',
                              'skin_photograph', 'oral_photograph', 'other_medical_photograph',
                              'b_scan', 'autofluorescence', 'fundus_photograph', 'gonioscopy', 'oct', 'ophthalmic_angiography', 'slit_lamp_photograph',
                              'gi_endoscopy', 'airway_endoscopy', 'other_endoscopy', 'arthroscopy',
                              'eeg', 'emg', 'ekg'],
            'radiology_region': ['abdomen', 'breast', 'head', 'neck', 'pelvis', 'thorax',
                                 'lower_limb', 'upper_limb', 'whole_body'],
            'radiology_region_granular': ['abdomen', 'breast', 'head', 'neck', 'pelvis', 'thorax',
                                          'ankle', 'foot', 'hip', 'knee', 'lower_leg', 'thigh',
                                          'elbow', 'forearm', 'hand', 'shoulder', 'upper_arm', 'wrist',
                                          'whole_body'],
            'radiology_view': ['axial', 'frontal', 'sagittal', 'oblique',
                               'occlusal', 'panoramic', 'periapical', 'intravascular', 'ultrasound_view'],
        }

        for column_name, column_classes in multiclass_columns.items():
            self.data[column_name] = self.data['label_list'].apply(
                lambda x, column_classes=column_classes: self._get_column_label(x, column_classes))

    def _get_column_label(self, column_list, label_list):
        '''Method used to get the label from a relevant list that is present in the predictions of a given image.

        column_list: labels predicted for the image.
        label_list: labels that are relevant for the column being generated.

        Returns the last item of `column_list` that is included in `label_list`, or ''
        if there is none.'''

        matches = [column for column in column_list if column in label_list]
        return matches[-1] if matches else ''