# Uploaded by jamino30 via huggingface_hub (commit 28ac920, verified).
import torch
import torch.nn as nn
from torchvision import transforms
from safetensors.torch import load_file
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from model import U2Net
def _pick_device():
    """Return the preferred torch device: CUDA first, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    if torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')


# Module-wide device used by preprocessing, model placement and weight loading.
device = _pick_device()
def preprocess_image(image_path, size=(512, 512)):
    """Load an image file and turn it into a normalized, batched model input.

    Args:
        image_path: Path to the image file. Any mode PIL can read is converted
            to RGB first.
        size: ``(height, width)`` passed to ``transforms.Resize``. Defaults to
            ``(512, 512)``, the resolution the rest of this script assumes.

    Returns:
        A float tensor with a leading batch axis of size 1, normalized with the
        standard ImageNet channel mean/std and moved to the module-level
        ``device``.
    """
    preprocess = transforms.Compose([
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # The context manager releases the file handle. PIL would otherwise keep it
    # open after convert() for multi-frame formats such as GIF/TIFF.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')
    return preprocess(rgb).unsqueeze(0).to(device)
def run_inference(model, image_path, threshold=None):
    """Run the segmentation model on one image and return an 8-bit mask.

    Args:
        model: Network called on the output of ``preprocess_image``. Only the
            first element of the sequence it returns is used (presumably the
            fused U2Net output -- TODO confirm against ``model.U2Net``).
        image_path: Path to the input image, forwarded to ``preprocess_image``.
        threshold: Optional cut-off applied to the min-max normalized
            prediction, so it is relative to this image's own value range.
            When given, pixels strictly above it become 255 and all others 0.
            When ``None``, the normalized map is scaled to 0-255.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` holding values in [0, 255].
    """
    input_img = preprocess_image(image_path)
    with torch.no_grad():
        d1, *_ = model(input_img)
    pred = torch.sigmoid(d1)[0].cpu().numpy()
    # If the model emits (B, 1, H, W), indexing the batch leaves (1, H, W).
    # Drop that singleton channel axis so callers (e.g. imshow) get a 2-D map.
    if pred.ndim == 3 and pred.shape[0] == 1:
        pred = pred[0]

    lo, hi = pred.min(), pred.max()
    span = hi - lo
    if span > 0:
        pred = (pred - lo) / span
    # A constant map has no range to stretch. Dividing by zero would produce NaN
    # and an undefined uint8 cast, so keep the sigmoid probabilities (already
    # in [0, 1]) as they are.

    if threshold is not None:
        return (pred > threshold).astype(np.uint8) * 255
    return (pred * 255).astype(np.uint8)
def overlay_segmentation(original_image, binary_mask, alpha=0.5):
    """Blend a segmentation mask into the red channel of an image.

    Args:
        original_image: Either a path to the source image, or an already-loaded
            ``(H, W, C)`` ``np.ndarray`` whose height and width match the mask.
            A path is opened with PIL, converted to RGB and bilinearly resized
            to the mask's size.
        binary_mask: 2-D mask with values in [0, 255] (e.g. the output of
            ``run_inference``). A leading singleton axis ``(1, H, W)`` is also
            accepted.
        alpha: Weight of the red mask layer. ``0`` returns the image unchanged.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with the same shape as the base image.

    Raises:
        ValueError: If an array image is not 3-D or does not match the mask's
            height and width.
    """
    mask = np.asarray(binary_mask)
    if mask.ndim == 3 and mask.shape[0] == 1:
        mask = mask[0]
    height, width = mask.shape

    if isinstance(original_image, np.ndarray):
        base = original_image
        if base.ndim != 3 or base.shape[:2] != (height, width):
            raise ValueError(
                f'image shape {base.shape} does not match mask shape {(height, width)}'
            )
    else:
        # The context manager guarantees the file handle is released.
        with Image.open(original_image) as img:
            # PIL sizes are (width, height), numpy shapes are (height, width).
            base = np.array(img.convert('RGB').resize((width, height), Image.BILINEAR))

    overlay = np.zeros_like(base)
    overlay[:, :, 0] = mask
    blended = (1 - alpha) * base + alpha * overlay
    # Clip before the cast: an alpha outside [0, 1] could leave values outside
    # [0, 255], and their uint8 conversion is undefined. Values already in
    # range are truncated exactly as before.
    return np.clip(blended, 0, 255).astype(np.uint8)
if __name__ == '__main__':
    import os

    # --- configuration ---
    model_path = '../testing/u2net-duts-msra.safetensors'
    filename = input('Filename: ').strip()
    image_path = f'../content_images/{filename}'
    # ---

    # Fail fast on a bad path before paying for model construction and weight loading.
    if not os.path.isfile(image_path):
        raise SystemExit(f'Image not found: {image_path}')

    model = U2Net().to(device)
    # Wrapped before loading, presumably because the checkpoint keys were saved
    # from a DataParallel model ("module." prefix) -- TODO confirm.
    model = nn.DataParallel(model)
    model.load_state_dict(load_file(model_path, device=device.type))
    model.eval()

    mask = run_inference(model, image_path)
    mask_with_threshold = run_inference(model, image_path, threshold=0.7)

    # Convert to RGB like the other helpers, so grayscale/RGBA inputs are not
    # rendered through a colormap or with an alpha channel.
    with Image.open(image_path) as src:
        original = src.convert('RGB').resize((512, 512))

    fig = plt.figure(figsize=(10, 10))
    gs = GridSpec(2, 2, figure=fig, wspace=0, hspace=0)
    # 2x2 grid: row 0 = image, soft mask; row 1 = overlay, thresholded mask.
    images = [
        original,
        mask,
        overlay_segmentation(image_path, mask_with_threshold),
        mask_with_threshold,
    ]
    for i, img in enumerate(images):
        ax = fig.add_subplot(gs[i // 2, i % 2])
        if i % 2:
            # Right column holds uint8 masks. Pin the display range so imshow
            # does not auto-stretch (e.g. an all-zero mask).
            ax.imshow(img, cmap='gray', vmin=0, vmax=255)
        else:
            ax.imshow(img)
        ax.axis('off')
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.savefig('../testing/inference-output.jpg', format='jpg', bbox_inches='tight', pad_inches=0)
    plt.close(fig)