Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| import torchvision.transforms.v2 as T | |
| import torch.nn.functional as F | |
| from .utils import expand_mask | |
class LoadCLIPSegModels:
    """Node that loads the CLIPSeg processor and segmentation model.

    The class uses the ``INPUT_TYPES`` / ``RETURN_TYPES`` / ``FUNCTION`` /
    ``CATEGORY`` attribute layout (this looks like the ComfyUI custom-node
    convention -- confirm against the host application). It takes no inputs
    and emits a single ``CLIP_SEG`` value.
    """

    # Must be a classmethod: the specification is requested from the class
    # itself, without an instance, and ``s`` receives the class object.
    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification; this node declares no inputs."""
        return {
            "required": {},
        }

    RETURN_TYPES = ("CLIP_SEG",)
    FUNCTION = "execute"
    CATEGORY = "essentials/segmentation"

    def execute(self):
        """Load the processor and model from the pretrained checkpoint.

        Returns:
            A 1-tuple whose only element is the pair ``(processor, model)``.
        """
        # Imported here rather than at module level, so `transformers` is
        # only needed once this node actually runs.
        from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

        # Both halves must come from the same checkpoint.
        checkpoint = "CIDAS/clipseg-rd64-refined"
        processor = CLIPSegProcessor.from_pretrained(checkpoint)
        model = CLIPSegForImageSegmentation.from_pretrained(checkpoint)

        return ((processor, model),)
class ApplyCLIPSeg:
    """Node that builds a mask for each image from a text prompt via CLIPSeg.

    The class uses the ``INPUT_TYPES`` / ``RETURN_TYPES`` / ``FUNCTION`` /
    ``CATEGORY`` attribute layout (this looks like the ComfyUI custom-node
    convention -- confirm against the host application).
    """

    # Must be a classmethod: the specification is requested from the class
    # itself, without an instance, and ``s`` receives the class object.
    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification (value types, defaults and ranges)."""
        return {
            "required": {
                "clip_seg": ("CLIP_SEG",),
                "image": ("IMAGE",),
                "prompt": ("STRING", {"multiline": False, "default": ""}),
                "threshold": ("FLOAT", {"default": 0.4, "min": 0.0, "max": 1.0, "step": 0.05}),
                "smooth": ("INT", {"default": 9, "min": 0, "max": 32, "step": 1}),
                "dilate": ("INT", {"default": 0, "min": -32, "max": 32, "step": 1}),
                "blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1}),
            },
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "execute"
    CATEGORY = "essentials/segmentation"

    def execute(self, image, clip_seg, prompt, threshold, smooth, dilate, blur):
        """Segment every image in the batch with the prompt and post-process.

        Args:
            image: Tensor iterated along its first axis, one model call per
                entry. Values are scaled by 255 and clamped to 0..255, so
                they are presumably floats in [0, 1]. ``shape[1]`` and
                ``shape[2]`` are used as the output size -- assumes a
                (batch, height, width, channels) layout, TODO confirm.
            clip_seg: Pair ``(processor, model)``.
            prompt: Text passed to the processor for every image.
            threshold: Pixels whose sigmoid score exceeds this are kept.
            smooth: Gaussian kernel size applied to the thresholded mask;
                even values are bumped to the next odd value, 0 disables.
            dilate: Amount forwarded to ``expand_mask``; 0 disables. The
                widget range allows negatives -- presumably these shrink
                the mask, verify against ``expand_mask``.
            blur: Gaussian kernel size applied to the float mask; even
                values are bumped to the next odd value, 0 disables.

        Returns:
            A 1-tuple containing the mask tensor resized to
            ``(image.shape[1], image.shape[2])``.
        """
        processor, model = clip_seg

        # The processor is fed uint8 numpy arrays, one per batch entry.
        imagenp = image.mul(255).clamp(0, 255).byte().cpu().numpy()

        outputs = []
        # Inference only: the scores are thresholded to booleans right away,
        # so recording an autograd graph would just waste memory.
        with torch.no_grad():
            for i in imagenp:
                inputs = processor(text=prompt, images=[i], return_tensors="pt")
                out = model(**inputs)
                # NOTE(review): this indexing assumes `logits` carries a
                # leading batch axis of size 1 -- confirm for the installed
                # `transformers` version.
                out = out.logits.unsqueeze(1)
                out = torch.sigmoid(out[0][0])
                out = (out > threshold)
                outputs.append(out)

        del imagenp  # release the byte copy before further allocations

        outputs = torch.stack(outputs, dim=0)

        if smooth > 0:
            # The kernel size has to be odd.
            if smooth % 2 == 0:
                smooth += 1
            # NOTE(review): the mask is still boolean here (the float
            # conversion happens below). The original order is kept on
            # purpose -- verify how torchvision treats non-float input.
            outputs = T.functional.gaussian_blur(outputs, smooth)

        outputs = outputs.float()

        if dilate != 0:
            outputs = expand_mask(outputs, dilate, True)

        if blur > 0:
            # The kernel size has to be odd.
            if blur % 2 == 0:
                blur += 1
            outputs = T.functional.gaussian_blur(outputs, blur)

        # Resize to the original size. `interpolate` wants a channel axis,
        # so one is added for the call and removed afterwards.
        # NOTE(review): bicubic interpolation can overshoot, so values may
        # fall slightly outside [0, 1]; clamp downstream if that matters.
        outputs = F.interpolate(
            outputs.unsqueeze(1),
            size=(image.shape[1], image.shape[2]),
            mode='bicubic',
        ).squeeze(1)

        return (outputs,)
# Lookup tables for the segmentation nodes. Both dicts share the same keys:
# the first maps each key to the implementing class, the second to the
# human-readable title shown for it.
# NOTE(review): presumably consumed by the package's node-registration code,
# which is not visible here.
SEG_CLASS_MAPPINGS = {
    "ApplyCLIPSeg+": ApplyCLIPSeg,
    "LoadCLIPSegModels+": LoadCLIPSegModels,
}

SEG_NAME_MAPPINGS = {
    "ApplyCLIPSeg+": "🔧 Apply CLIPSeg",
    "LoadCLIPSegModels+": "🔧 Load CLIPSeg Models",
}