|
|
|
import os |
|
import pandas as pd |
|
import torch |
|
from model import MicrographCleaner |
|
from dataset import InferenceMicrographDataset, decode_array |
|
from inference_utils import sliding_window_inference |
|
import matplotlib.pyplot as plt |
|
import tqdm |
|
|
|
|
|
def main(
    checkpoint_path: str = 'final_checkpoint.ckpt',
    test_csv: str = 'test.csv',
    output_dir: str = 'predictions',
    window_size: int = 512,
    threshold: float = 0.5,
    overlap: float = 0.5,
) -> None:
    """Run sliding-window inference over the test set and save comparison plots.

    For each distinct image id yielded by the dataset, the model prediction is
    computed with ``sliding_window_inference``, the padding reported by the
    dataset is cropped off, the result is thresholded into a boolean mask, and
    a two-panel figure (original image / predicted mask) is written to
    ``<output_dir>/<image_id>_prediction.png``.

    Called without arguments, this uses the same paths and settings the script
    has always used.

    Args:
        checkpoint_path: Checkpoint file passed to
            ``MicrographCleaner.load_from_checkpoint`` (loaded onto the CPU).
        test_csv: CSV file read with ``pandas.read_csv``; each row must have
            an ``'image'`` column that ``decode_array`` can decode.
        output_dir: Directory for the saved figures; created if missing.
        window_size: Window size given to both the dataset and the
            sliding-window inference helper.
        threshold: Predictions strictly greater than this value are
            treated as foreground in the saved mask.
        overlap: Overlap setting forwarded to ``sliding_window_inference``.
    """
    os.makedirs(output_dir, exist_ok=True)

    model = MicrographCleaner.load_from_checkpoint(checkpoint_path, map_location='cpu')
    # One eval() call is enough; it stays in effect for the whole loop.
    model.eval()

    test_df = pd.read_csv(test_csv)
    test_dataset = InferenceMicrographDataset(test_df, window_size=window_size)

    # Ids already processed, so each image is predicted and plotted only once.
    seen_ids = set()

    with torch.inference_mode():
        # Index-based access is kept on purpose: the dataset is only known to
        # support len() and integer indexing, and idx is reused below.
        for idx in tqdm.tqdm(range(len(test_dataset))):
            image, image_id, (pad_h, pad_w) = test_dataset[idx]

            if image_id in seen_ids:
                continue
            seen_ids.add(image_id)

            pred = sliding_window_inference(
                model,
                image,
                window_size=window_size,
                overlap=overlap,
            )

            # Crop the padding off the last two axes. The guards matter:
            # a slice of ``:-0`` would yield an empty result.
            if pad_h > 0:
                pred = pred[..., :-pad_h, :]
            if pad_w > 0:
                pred = pred[..., :-pad_w]

            # NOTE(review): [0] keeps the first entry along the leading axis,
            # presumably a channel or batch axis -- confirm against
            # sliding_window_inference.
            pred_mask = (pred > threshold).cpu().numpy()[0]

            # NOTE(review): assumes dataset index idx corresponds to row idx
            # of test_df -- confirm against InferenceMicrographDataset.
            orig_image = decode_array(test_df.iloc[idx]['image'])

            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
            try:
                ax1.imshow(orig_image, cmap='gray')
                ax1.set_title('Original Image')
                ax1.axis('off')

                ax2.imshow(pred_mask, cmap='gray')
                ax2.set_title('Predicted Mask')
                ax2.axis('off')

                fig.tight_layout()
                fig.savefig(os.path.join(output_dir, f'{image_id}_prediction.png'))
            finally:
                # Close this specific figure even when plotting or saving
                # fails, so figures do not pile up over a long run.
                plt.close(fig)
|
|
|
|
|
# Script entry point: run inference only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()