Upload 52 files
This view is limited to 50 files because it contains too many changes.
- gradio_demo/app.py +545 -0
- gradio_demo/app_generateOne.py +529 -0
- gradio_demo/character_template.py +62 -0
- src/efficientvit/__init__.py +0 -0
- src/efficientvit/apps/__init__.py +0 -0
- src/efficientvit/apps/data_provider/__init__.py +7 -0
- src/efficientvit/apps/data_provider/augment/__init__.py +6 -0
- src/efficientvit/apps/data_provider/augment/bbox.py +30 -0
- src/efficientvit/apps/data_provider/augment/color_aug.py +84 -0
- src/efficientvit/apps/data_provider/base.py +223 -0
- src/efficientvit/apps/data_provider/random_resolution/__init__.py +7 -0
- src/efficientvit/apps/data_provider/random_resolution/_data_loader.py +1598 -0
- src/efficientvit/apps/data_provider/random_resolution/_data_worker.py +377 -0
- src/efficientvit/apps/data_provider/random_resolution/controller.py +94 -0
- src/efficientvit/apps/setup.py +141 -0
- src/efficientvit/apps/trainer/__init__.py +6 -0
- src/efficientvit/apps/trainer/base.py +297 -0
- src/efficientvit/apps/trainer/run_config.py +121 -0
- src/efficientvit/apps/utils/__init__.py +12 -0
- src/efficientvit/apps/utils/dist.py +73 -0
- src/efficientvit/apps/utils/ema.py +50 -0
- src/efficientvit/apps/utils/export.py +47 -0
- src/efficientvit/apps/utils/init.py +68 -0
- src/efficientvit/apps/utils/lr.py +48 -0
- src/efficientvit/apps/utils/metric.py +37 -0
- src/efficientvit/apps/utils/misc.py +111 -0
- src/efficientvit/apps/utils/opt.py +31 -0
- src/efficientvit/models/__init__.py +0 -0
- src/efficientvit/models/efficientvit/__init__.py +8 -0
- src/efficientvit/models/efficientvit/backbone.py +372 -0
- src/efficientvit/models/efficientvit/cls.py +174 -0
- src/efficientvit/models/efficientvit/sam.py +653 -0
- src/efficientvit/models/efficientvit/seg.py +355 -0
- src/efficientvit/models/nn/__init__.py +8 -0
- src/efficientvit/models/nn/act.py +30 -0
- src/efficientvit/models/nn/drop.py +98 -0
- src/efficientvit/models/nn/norm.py +157 -0
- src/efficientvit/models/nn/ops.py +585 -0
- src/efficientvit/models/utils/__init__.py +7 -0
- src/efficientvit/models/utils/list.py +57 -0
- src/efficientvit/models/utils/network.py +77 -0
- src/efficientvit/models/utils/random.py +73 -0
- src/efficientvit/sam_model_zoo.py +53 -0
- src/ip_adapter/attention_processor.py +424 -0
- src/ip_adapter/resampler.py +120 -0
- src/ip_adapter/utils.py +5 -0
- src/pipelines/instantid_pipeline.py +720 -0
- src/pipelines/instantid_single_pieline.py +772 -0
- src/pipelines/lora_pipeline.py +681 -0
- src/prompt_attention/p2p_attention.py +148 -0
gradio_demo/app.py
ADDED
@@ -0,0 +1,545 @@
import sys
sys.path.append('./')
import gradio as gr
import random
import numpy as np
from gradio_demo.character_template import character_man, lorapath_man
from gradio_demo.character_template import character_woman, lorapath_woman
from gradio_demo.character_template import styles, lorapath_styles
import torch
import os
from typing import Tuple, List
import copy
import argparse
from diffusers.utils import load_image
import cv2
from PIL import Image
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from controlnet_aux import OpenposeDetector
from controlnet_aux.open_pose.body import Body

try:
    from inference.models import YOLOWorld
    from src.efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
    from src.efficientvit.sam_model_zoo import create_sam_model
    import supervision as sv
except:
    print("YoloWorld can not be load")

try:
    from groundingdino.models import build_model
    from groundingdino.util import box_ops
    from groundingdino.util.slconfig import SLConfig
    from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
    from groundingdino.util.inference import annotate, predict
    from segment_anything import build_sam, SamPredictor
    import groundingdino.datasets.transforms as T
except:
    print("groundingdino can not be load")

from src.pipelines.lora_pipeline import LoraMultiConceptPipeline
from src.prompt_attention.p2p_attention import AttentionReplace
from diffusers import ControlNetModel, StableDiffusionXLPipeline
from src.pipelines.lora_pipeline import revise_regionally_controlnet_forward

CHARACTER_MAN_NAMES = list(character_man.keys())
CHARACTER_WOMAN_NAMES = list(character_woman.keys())
STYLE_NAMES = list(styles.keys())
MAX_SEED = np.iinfo(np.int32).max

### Description
title = r"""
<h1 align="center">OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models</h1>
"""

description = r"""
<b>Official 🤗 Gradio demo</b> for <a href='https://github.com/' target='_blank'><b>OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models</b></a>.<br>

How to use:<br>
1. Select two characters.
2. Enter a text prompt as done in normal text-to-image models.
3. Click the <b>Submit</b> button to start customizing.
4. Enjoy the generated image😊!
"""

article = r"""
---
📝 **Citation**
<br>
If our work is helpful for your research or applications, please cite us via:
```bibtex
@article{,
title={OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models},
author={},
journal={},
year={}
}
```
"""

tips = r"""
### Usage tips of OMG
1. Input text prompts to describe a man and a woman
"""

css = '''
.gradio-container {width: 85% !important}
'''

def sample_image(pipe,
                 input_prompt,
                 input_neg_prompt=None,
                 generator=None,
                 concept_models=None,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 controller=None,
                 stage=None,
                 region_masks=None,
                 lora_list = None,
                 styleL=None,
                 **extra_kargs
):

    spatial_condition = extra_kargs.pop('spatial_condition')
    if spatial_condition is not None:
        spatial_condition_input = [spatial_condition] * len(input_prompt)
    else:
        spatial_condition_input = None

    images = pipe(
        prompt=input_prompt,
        concept_models=concept_models,
        negative_prompt=input_neg_prompt,
        generator=generator,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        cross_attention_kwargs={"scale": 0.8},
        controller=controller,
        stage=stage,
        region_masks=region_masks,
        lora_list=lora_list,
        styleL=styleL,
        image=spatial_condition_input,
        **extra_kargs).images

    return images

def load_image_yoloworld(image_source) -> Tuple[np.array, torch.Tensor]:
    image = np.asarray(image_source)
    return image

def load_image_dino(image_source) -> Tuple[np.array, torch.Tensor]:
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image = np.asarray(image_source)
    image_transformed, _ = transform(image_source, None)
    return image, image_transformed

def predict_mask(segmentmodel, sam, image, TEXT_PROMPT, segmentType, confidence = 0.2, threshold = 0.5):
    if segmentType=='GroundingDINO':
        image_source, image = load_image_dino(image)
        boxes, logits, phrases = predict(
            model=segmentmodel,
            image=image,
            caption=TEXT_PROMPT,
            box_threshold=0.3,
            text_threshold=0.25
        )
        sam.set_image(image_source)
        H, W, _ = image_source.shape
        boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])

        transformed_boxes = sam.transform.apply_boxes_torch(boxes_xyxy, image_source.shape[:2]).cuda()
        masks, _, _ = sam.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        masks=masks[0].squeeze(0)
    else:
        image_source = load_image_yoloworld(image)
        segmentmodel.set_classes([TEXT_PROMPT])
        results = segmentmodel.infer(image_source, confidence=confidence)
        detections = sv.Detections.from_inference(results).with_nms(
            class_agnostic=True, threshold=threshold
        )
        masks = None
        if len(detections) != 0:
            print(TEXT_PROMPT + " detected!")
            sam.set_image(image_source, image_format="RGB")
            masks, _, _ = sam.predict(box=detections.xyxy[0], multimask_output=False)
            masks = torch.from_numpy(masks.squeeze())

    return masks

def prepare_text(prompt, region_prompts):
    '''
    Args:
        prompt_entity: [subject1]-*-[attribute1]-*-[Location1]|[subject2]-*-[attribute2]-*-[Location2]|[global text]
    Returns:
        full_prompt: subject1, attribute1 and subject2, attribute2, global text
        context_prompt: subject1 and subject2, global text
        entity_collection: [(subject1, attribute1), Location1]
    '''
    region_collection = []

    regions = region_prompts.split('|')

    for region in regions:
        if region == '':
            break
        prompt_region, neg_prompt_region = region.split('-*-')
        prompt_region = prompt_region.replace('[', '').replace(']', '')
        neg_prompt_region = neg_prompt_region.replace('[', '').replace(']', '')

        region_collection.append((prompt_region, neg_prompt_region))
    return (prompt, region_collection)


def build_model_sd(pretrained_model, controlnet_path, device, prompts):
    controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16).to(device)
    pipe = LoraMultiConceptPipeline.from_pretrained(
        pretrained_model, controlnet=controlnet, torch_dtype=torch.float16, variant="fp16").to(device)
    controller = AttentionReplace(prompts, 50, cross_replace_steps={"default_": 1.}, self_replace_steps=0.4, tokenizer=pipe.tokenizer, device=device, dtype=torch.float16, width=1024//32, height=1024//32)
    revise_regionally_controlnet_forward(pipe.unet, controller)
    pipe_concept = StableDiffusionXLPipeline.from_pretrained(pretrained_model, torch_dtype=torch.float16,
                                                             variant="fp16").to(device)
    return pipe, controller, pipe_concept

def build_model_lora(pipe_concept, lora_paths, style_path, condition, args):
    pipe_list = []
    if condition == "Human pose":
        controlnet = ControlNetModel.from_pretrained(args.openpose_checkpoint, torch_dtype=torch.float16).to(device)
        pipe_concept.controlnet = controlnet
    elif condition == "Canny Edge":
        controlnet = ControlNetModel.from_pretrained(args.canny_checkpoint, torch_dtype=torch.float16).to(device)
        pipe_concept.controlnet = controlnet
    elif condition == "Depth":
        controlnet = ControlNetModel.from_pretrained(args.depth_checkpoint, torch_dtype=torch.float16).to(device)
        pipe_concept.controlnet = controlnet

    if style_path is not None and os.path.exists(style_path):
        pipe_concept.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')

    for lora_path in lora_paths.split('|'):
        adapter_name = lora_path.split('/')[-1].split('.')[0]
        pipe_concept.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name=adapter_name)
        pipe_concept.enable_xformers_memory_efficient_attention()
        pipe_list.append(adapter_name)
    return pipe_list

def build_yolo_segment_model(sam_path, device):
    yolo_world = YOLOWorld(model_id="yolo_world/l")
    sam = EfficientViTSamPredictor(
        create_sam_model(name="xl1", weight_url=sam_path).to(device).eval()
    )
    return yolo_world, sam

def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    args = SLConfig.fromfile(ckpt_config_filename)
    model = build_model(args)
    args.device = device

    checkpoint = torch.load(os.path.join(repo_id, filename), map_location='cpu')
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    print("Model loaded from {} \n => {}".format(filename, log))
    _ = model.eval()
    return model

def build_dino_segment_model(ckpt_repo_id, sam_checkpoint):
    ckpt_filenmae = "groundingdino_swinb_cogcoor.pth"
    ckpt_config_filename = os.path.join(ckpt_repo_id, "GroundingDINO_SwinB.cfg.py")
    groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filenmae, ckpt_config_filename)
    sam = build_sam(checkpoint=sam_checkpoint)
    sam.cuda()
    sam_predictor = SamPredictor(sam)
    return groundingdino_model, sam_predictor



def main(device, segment_type):
    pipe, controller, pipe_concept = build_model_sd(args.pretrained_sdxl_model, args.openpose_checkpoint, device, prompts_tmp)

    if segment_type == 'GroundingDINO':
        detect_model, sam = build_dino_segment_model(args.dino_checkpoint, args.sam_checkpoint)
    else:
        detect_model, sam = build_yolo_segment_model(args.efficientViT_checkpoint, device)

    resolution_list = ["1440*728",
                       "1344*768",
                       "1216*832",
                       "1152*896",
                       "1024*1024",
                       "896*1152",
                       "832*1216",
                       "768*1344",
                       "728*1440"]

    condition_list = ["None",
                      "Human pose",
                      "Canny Edge",
                      "Depth"]

    depth_estimator = DPTForDepthEstimation.from_pretrained(args.dpt_checkpoint).to("cuda")
    feature_extractor = DPTFeatureExtractor.from_pretrained(args.dpt_checkpoint)
    body_model = Body(args.pose_detector_checkpoint)
    openpose = OpenposeDetector(body_model)

    def remove_tips():
        return gr.update(visible=False)

    def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        return seed

    def get_humanpose(img):
        openpose_image = openpose(img)
        return openpose_image

    def get_cannyedge(image):
        image = np.array(image)
        image = cv2.Canny(image, 100, 200)
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)
        canny_image = Image.fromarray(image)
        return canny_image

    def get_depth(image):
        image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
        with torch.no_grad(), torch.autocast("cuda"):
            depth_map = depth_estimator(image).predicted_depth

        depth_map = torch.nn.functional.interpolate(
            depth_map.unsqueeze(1),
            size=(1024, 1024),
            mode="bicubic",
            align_corners=False,
        )
        depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_map = (depth_map - depth_min) / (depth_max - depth_min)
        image = torch.cat([depth_map] * 3, dim=1)
        image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
        image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
        return image

    def generate_image(prompt1, prompt2, prompt3, prompt4, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, condition_img2, condition_img3, condition_img4, style):
        try:
            path1 = lorapath_man[man]
            path2 = lorapath_woman[woman]
            pipe_concept.unload_lora_weights()
            pipe_list = build_model_lora(pipe_concept, path1 + "|" + path2, lorapath_styles[style], condition, args)

            if lorapath_styles[style] is not None and os.path.exists(lorapath_styles[style]):
                styleL = True
            else:
                styleL = False

            input_list = [prompt1, prompt2, prompt3, prompt4]
            condition_list = [condition_img1, condition_img2, condition_img3, condition_img4]
            output_list = []

            width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])

            kwargs = {
                'height': height,
                'width': width,
            }

            for prompt, condition_img in zip(input_list, condition_list):
                if prompt!='':
                    input_prompt = []
                    p = '{prompt}, 35mm photograph, film, professional, 4k, highly detailed.'
                    if styleL:
                        p = styles[style] + p
                    input_prompt.append([p.replace("{prompt}", prompt), p.replace("{prompt}", prompt)])
                    input_prompt.append([(styles[style] + local_prompt1, character_man.get(man)[1]), (styles[style] + local_prompt2, character_woman.get(woman)[1])])

                    if condition == 'Human pose' and condition_img is not None:
                        spatial_condition = get_humanpose(condition_img).resize((width, height))
                    elif condition == 'Canny Edge' and condition_img is not None:
                        spatial_condition = get_cannyedge(condition_img).resize((width, height))
                    elif condition == 'Depth' and condition_img is not None:
                        spatial_condition = get_depth(condition_img).resize((width, height))
                    else:
                        spatial_condition = None

                    kwargs['spatial_condition'] = spatial_condition

                    controller.reset()
                    image = sample_image(
                        pipe,
                        input_prompt=input_prompt,
                        concept_models=pipe_concept,
                        input_neg_prompt=[negative_prompt] * len(input_prompt),
                        generator=torch.Generator(device).manual_seed(seed),
                        controller=controller,
                        stage=1,
                        lora_list=pipe_list,
                        styleL=styleL,
                        **kwargs)

                    controller.reset()
                    if pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
                        mask1 = predict_mask(detect_model, sam, image[0], 'man', args.segment_type, confidence=0.15,
                                             threshold=0.5)
                    else:
                        mask1 = None

                    if pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
                        mask2 = predict_mask(detect_model, sam, image[0], 'woman', args.segment_type, confidence=0.15,
                                             threshold=0.5)
                    else:
                        mask2 = None

                    if mask1 is None and mask2 is None:
                        output_list.append(image[1])
                    else:
                        image = sample_image(
                            pipe,
                            input_prompt=input_prompt,
                            concept_models=pipe_concept,
                            input_neg_prompt=[negative_prompt] * len(input_prompt),
                            generator=torch.Generator(device).manual_seed(seed),
                            controller=controller,
                            stage=2,
                            region_masks=[mask1, mask2],
                            lora_list=pipe_list,
                            styleL=styleL,
                            **kwargs)
                        output_list.append(image[1])
                else:
                    output_list.append(None)
            return output_list
        except:
            print("error")
            return None, None, None, None

    def get_local_value_man(input):
        return character_man[input][0]

    def get_local_value_woman(input):
        return character_woman[input][0]


    with gr.Blocks(css=css) as demo:
        # description
        gr.Markdown(title)
        gr.Markdown(description)

        with gr.Row():
            gallery = gr.Image(label="Generated Images", height=512, width=512)
            gallery2 = gr.Image(label="Generated Images", height=512, width=512)
            gallery3 = gr.Image(label="Generated Images", height=512, width=512)
            gallery4 = gr.Image(label="Generated Images", height=512, width=512)
            usage_tips = gr.Markdown(label="Usage tips of OMG", value=tips, visible=False)

        with gr.Row():
            condition_img1 = gr.Image(label="Input condition", height=128, width=128)
            condition_img2 = gr.Image(label="Input condition", height=128, width=128)
            condition_img3 = gr.Image(label="Input condition", height=128, width=128)
            condition_img4 = gr.Image(label="Input condition", height=128, width=128)

        # character choose
        with gr.Row():
            man = gr.Dropdown(label="Character 1 selection", choices=CHARACTER_MAN_NAMES, value="Harry Potter (identifier: Harry Potter)")
            woman = gr.Dropdown(label="Character 2 selection", choices=CHARACTER_WOMAN_NAMES, value="Hermione Granger (identifier: Hermione Granger)")
            resolution = gr.Dropdown(label="Image Resolution (width*height)", choices=resolution_list, value="1024*1024")
            condition = gr.Dropdown(label="Input condition type", choices=condition_list, value="None")
            style = gr.Dropdown(label="style", choices=STYLE_NAMES, value="None")

        with gr.Row():
            local_prompt1 = gr.Textbox(label="Character1_prompt",
                                       info="Describe the Character 1, this prompt should include the identifier of character 1",
                                       value="Close-up photo of the Harry Potter, 35mm photograph, film, professional, 4k, highly detailed.")
            local_prompt2 = gr.Textbox(label="Character2_prompt",
                                       info="Describe the Character 2, this prompt should include the identifier of character2",
                                       value="Close-up photo of the Hermione Granger, 35mm photograph, film, professional, 4k, highly detailed.")

        man.change(get_local_value_man, man, local_prompt1)
        woman.change(get_local_value_woman, woman, local_prompt2)

        # prompt
        with gr.Column():
            prompt = gr.Textbox(label="Prompt 1",
                                info="Give a simple prompt to describe the first image content",
                                placeholder="Required",
                                value="close-up shot, photography, the cool man and beautiful woman as they accidentally discover a mysterious island while on vacation by the sea, facing the camera smiling")
            prompt2 = gr.Textbox(label="Prompt 2",
                                 info="Give a simple prompt to describe the second image content",
                                 placeholder="optional",
                                 value="")
            prompt3 = gr.Textbox(label="Prompt 3",
                                 info="Give a simple prompt to describe the third image content",
                                 placeholder="optional",
                                 value="")
            prompt4 = gr.Textbox(label="Prompt 4",
                                 info="Give a simple prompt to describe the fourth image content",
                                 placeholder="optional",
                                 value="")

        with gr.Accordion(open=False, label="Advanced Options"):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )
            negative_prompt = gr.Textbox(label="Negative Prompt",
                                         placeholder="noisy, blurry, soft, deformed, ugly",
                                         value="noisy, blurry, soft, deformed, ugly")
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

        submit = gr.Button("Submit", variant="primary")

        submit.click(
            fn=remove_tips,
            outputs=usage_tips,
        ).then(
            fn=randomize_seed_fn,
            inputs=[seed, randomize_seed],
            outputs=seed,
            queue=False,
            api_name=False,
        ).then(
            fn=generate_image,
            inputs=[prompt, prompt2, prompt3, prompt4, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, condition_img2, condition_img3, condition_img4, style],
            outputs=[gallery, gallery2, gallery3, gallery4]
        )
    demo.launch(server_name='0.0.0.0', server_port=7861, debug=True)

def parse_args():
    parser = argparse.ArgumentParser('', add_help=False)
    parser.add_argument('--pretrained_sdxl_model', default='./checkpoint/stable-diffusion-xl-base-1.0', type=str)
    parser.add_argument('--openpose_checkpoint', default='./checkpoint/controlnet-openpose-sdxl-1.0', type=str)
    parser.add_argument('--canny_checkpoint', default='./checkpoint/controlnet-canny-sdxl-1.0', type=str)
    parser.add_argument('--depth_checkpoint', default='./checkpoint/controlnet-depth-sdxl-1.0', type=str)
    parser.add_argument('--efficientViT_checkpoint', default='./checkpoint/sam/xl1.pt', type=str)
    parser.add_argument('--dino_checkpoint', default='./checkpoint/GroundingDINO', type=str)
    parser.add_argument('--sam_checkpoint', default='./checkpoint/sam/sam_vit_h_4b8939.pth', type=str)
    parser.add_argument('--dpt_checkpoint', default='./checkpoint/dpt-hybrid-midas', type=str)
    parser.add_argument('--pose_detector_checkpoint', default='./checkpoint/ControlNet/annotator/ckpts/body_pose_model.pth', type=str)
    parser.add_argument('--prompt', default='Close-up photo of the cool man and beautiful woman in surprised expressions as they accidentally discover a mysterious island while on vacation by the sea, 35mm photograph, film, professional, 4k, highly detailed.', type=str)
    parser.add_argument('--negative_prompt', default='noisy, blurry, soft, deformed, ugly', type=str)
    parser.add_argument('--seed', default=22, type=int)
    parser.add_argument('--suffix', default='', type=str)
    parser.add_argument('--segment_type', default='yoloworld', help='GroundingDINO or yoloworld', type=str)
    return parser.parse_args()

if __name__ == '__main__':
    args = parse_args()

    prompts = [args.prompt]*2
    prompts_tmp = copy.deepcopy(prompts)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    main(device, args.segment_type)
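The heavy lifting in `generate_image` above is the detect-then-segment step: YOLO-World (or GroundingDINO) locates the "man"/"woman" subjects in the stage-1 image and EfficientViT-SAM turns the detections into the region masks used for stage-2 resampling. Below is a minimal standalone sketch of that step, reusing only the module-level helpers defined in this file; the checkpoint path is the `--efficientViT_checkpoint` default from `parse_args`, and `example.jpg` is a hypothetical input file, so both are assumptions about your local setup rather than anything this diff guarantees.

```python
# Sketch: reuse the demo's detector + segmenter to get a subject mask.
import sys
sys.path.append('./')  # run from the repo root, as app.py itself assumes

import torch
from PIL import Image

from gradio_demo.app import build_yolo_segment_model, predict_mask

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# './checkpoint/sam/xl1.pt' mirrors the --efficientViT_checkpoint default above.
detect_model, sam = build_yolo_segment_model('./checkpoint/sam/xl1.pt', device)

image = Image.open('example.jpg').convert('RGB')  # hypothetical test image
mask = predict_mask(detect_model, sam, image, 'man', 'yoloworld',
                    confidence=0.15, threshold=0.5)  # same values app.py passes
print(None if mask is None else mask.shape)  # H x W mask tensor when detected, else None
```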
gradio_demo/app_generateOne.py
ADDED
@@ -0,0 +1,529 @@
import sys
sys.path.append('./')
import gradio as gr
import random
import numpy as np
from gradio_demo.character_template import character_man, lorapath_man
from gradio_demo.character_template import character_woman, lorapath_woman
from gradio_demo.character_template import styles, lorapath_styles
import torch
import os
from typing import Tuple, List
import copy
import argparse
from diffusers.utils import load_image
import cv2
from PIL import Image
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from controlnet_aux import OpenposeDetector
from controlnet_aux.open_pose.body import Body

try:
    from inference.models import YOLOWorld
    from src.efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
    from src.efficientvit.sam_model_zoo import create_sam_model
    import supervision as sv
except:
    print("YoloWorld can not be load")

try:
    from groundingdino.models import build_model
    from groundingdino.util import box_ops
    from groundingdino.util.slconfig import SLConfig
    from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
    from groundingdino.util.inference import annotate, predict
    from segment_anything import build_sam, SamPredictor
    import groundingdino.datasets.transforms as T
except:
    print("groundingdino can not be load")

from src.pipelines.lora_pipeline import LoraMultiConceptPipeline
from src.prompt_attention.p2p_attention import AttentionReplace
from diffusers import ControlNetModel, StableDiffusionXLPipeline
from src.pipelines.lora_pipeline import revise_regionally_controlnet_forward

CHARACTER_MAN_NAMES = list(character_man.keys())
CHARACTER_WOMAN_NAMES = list(character_woman.keys())
STYLE_NAMES = list(styles.keys())
MAX_SEED = np.iinfo(np.int32).max

### Description
title = r"""
<h1 align="center">OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models</h1>
"""

description = r"""
<b>Official 🤗 Gradio demo</b> for <a href='https://github.com/' target='_blank'><b>OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models</b></a>.<br>

How to use:<br>
1. Select two characters.
2. Enter a text prompt as done in normal text-to-image models.
3. Click the <b>Submit</b> button to start customizing.
4. Enjoy the generated image😊!
"""

article = r"""
---
📝 **Citation**
<br>
If our work is helpful for your research or applications, please cite us via:
```bibtex
@article{,
title={OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models},
author={},
journal={},
year={}
}
```
"""

tips = r"""
### Usage tips of OMG
1. Input text prompts to describe a man and a woman
"""

css = '''
.gradio-container {width: 85% !important}
'''

def sample_image(pipe,
                 input_prompt,
                 input_neg_prompt=None,
                 generator=None,
                 concept_models=None,
                 num_inference_steps=50,
                 guidance_scale=7.5,
                 controller=None,
                 stage=None,
                 region_masks=None,
                 lora_list = None,
                 styleL=None,
                 **extra_kargs
):

    spatial_condition = extra_kargs.pop('spatial_condition')
    if spatial_condition is not None:
        spatial_condition_input = [spatial_condition] * len(input_prompt)
    else:
        spatial_condition_input = None

    images = pipe(
        prompt=input_prompt,
        concept_models=concept_models,
        negative_prompt=input_neg_prompt,
        generator=generator,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        cross_attention_kwargs={"scale": 0.8},
        controller=controller,
        stage=stage,
        region_masks=region_masks,
        lora_list=lora_list,
        styleL=styleL,
        image=spatial_condition_input,
        **extra_kargs).images

    return images

def load_image_yoloworld(image_source) -> Tuple[np.array, torch.Tensor]:
    image = np.asarray(image_source)
    return image

def load_image_dino(image_source) -> Tuple[np.array, torch.Tensor]:
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image = np.asarray(image_source)
    image_transformed, _ = transform(image_source, None)
    return image, image_transformed

def predict_mask(segmentmodel, sam, image, TEXT_PROMPT, segmentType, confidence = 0.2, threshold = 0.5):
    if segmentType=='GroundingDINO':
        image_source, image = load_image_dino(image)
        boxes, logits, phrases = predict(
            model=segmentmodel,
            image=image,
            caption=TEXT_PROMPT,
            box_threshold=0.3,
            text_threshold=0.25
        )
        sam.set_image(image_source)
        H, W, _ = image_source.shape
        boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])

        transformed_boxes = sam.transform.apply_boxes_torch(boxes_xyxy, image_source.shape[:2]).cuda()
        masks, _, _ = sam.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
        masks=masks[0].squeeze(0)
    else:
        image_source = load_image_yoloworld(image)
        segmentmodel.set_classes([TEXT_PROMPT])
        results = segmentmodel.infer(image_source, confidence=confidence)
        detections = sv.Detections.from_inference(results).with_nms(
            class_agnostic=True, threshold=threshold
        )
        masks = None
        if len(detections) != 0:
            print(TEXT_PROMPT + " detected!")
            sam.set_image(image_source, image_format="RGB")
            masks, _, _ = sam.predict(box=detections.xyxy[0], multimask_output=False)
            masks = torch.from_numpy(masks.squeeze())

    return masks

def prepare_text(prompt, region_prompts):
    '''
    Args:
        prompt_entity: [subject1]-*-[attribute1]-*-[Location1]|[subject2]-*-[attribute2]-*-[Location2]|[global text]
    Returns:
        full_prompt: subject1, attribute1 and subject2, attribute2, global text
        context_prompt: subject1 and subject2, global text
        entity_collection: [(subject1, attribute1), Location1]
    '''
    region_collection = []

    regions = region_prompts.split('|')

    for region in regions:
        if region == '':
            break
        prompt_region, neg_prompt_region = region.split('-*-')
        prompt_region = prompt_region.replace('[', '').replace(']', '')
        neg_prompt_region = neg_prompt_region.replace('[', '').replace(']', '')

        region_collection.append((prompt_region, neg_prompt_region))
    return (prompt, region_collection)


def build_model_sd(pretrained_model, controlnet_path, device, prompts):
    controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16).to(device)
    pipe = LoraMultiConceptPipeline.from_pretrained(
        pretrained_model, controlnet=controlnet, torch_dtype=torch.float16, variant="fp16").to(device)
    controller = AttentionReplace(prompts, 50, cross_replace_steps={"default_": 1.}, self_replace_steps=0.4, tokenizer=pipe.tokenizer, device=device, dtype=torch.float16, width=1024//32, height=1024//32)
    revise_regionally_controlnet_forward(pipe.unet, controller)
    pipe_concept = StableDiffusionXLPipeline.from_pretrained(pretrained_model, torch_dtype=torch.float16,
                                                             variant="fp16").to(device)
    return pipe, controller, pipe_concept

def build_model_lora(pipe_concept, lora_paths, style_path, condition, args):
    pipe_list = []
    if condition == "Human pose":
        controlnet = ControlNetModel.from_pretrained(args.openpose_checkpoint, torch_dtype=torch.float16).to(device)
        pipe_concept.controlnet = controlnet
    elif condition == "Canny Edge":
        controlnet = ControlNetModel.from_pretrained(args.canny_checkpoint, torch_dtype=torch.float16).to(device)
        pipe_concept.controlnet = controlnet
    elif condition == "Depth":
        controlnet = ControlNetModel.from_pretrained(args.depth_checkpoint, torch_dtype=torch.float16).to(device)
        pipe_concept.controlnet = controlnet

    if style_path is not None and os.path.exists(style_path):
        pipe_concept.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')

    for lora_path in lora_paths.split('|'):
        adapter_name = lora_path.split('/')[-1].split('.')[0]
        pipe_concept.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name=adapter_name)
        pipe_concept.enable_xformers_memory_efficient_attention()
        pipe_list.append(adapter_name)
    return pipe_list

def build_yolo_segment_model(sam_path, device):
    yolo_world = YOLOWorld(model_id="yolo_world/l")
    sam = EfficientViTSamPredictor(
        create_sam_model(name="xl1", weight_url=sam_path).to(device).eval()
    )
    return yolo_world, sam

def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    args = SLConfig.fromfile(ckpt_config_filename)
    model = build_model(args)
    args.device = device

    checkpoint = torch.load(os.path.join(repo_id, filename), map_location='cpu')
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    print("Model loaded from {} \n => {}".format(filename, log))
    _ = model.eval()
    return model

def build_dino_segment_model(ckpt_repo_id, sam_checkpoint):
    ckpt_filenmae = "groundingdino_swinb_cogcoor.pth"
    ckpt_config_filename = os.path.join(ckpt_repo_id, "GroundingDINO_SwinB.cfg.py")
    groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filenmae, ckpt_config_filename)
    sam = build_sam(checkpoint=sam_checkpoint)
    sam.cuda()
    sam_predictor = SamPredictor(sam)
    return groundingdino_model, sam_predictor



def main(device, segment_type):
    pipe, controller, pipe_concept = build_model_sd(args.pretrained_sdxl_model, args.openpose_checkpoint, device, prompts_tmp)

    if segment_type == 'GroundingDINO':
        detect_model, sam = build_dino_segment_model(args.dino_checkpoint, args.sam_checkpoint)
    else:
        detect_model, sam = build_yolo_segment_model(args.efficientViT_checkpoint, device)

    resolution_list = ["1440*728",
                       "1344*768",
                       "1216*832",
                       "1152*896",
                       "1024*1024",
                       "896*1152",
                       "832*1216",
                       "768*1344",
                       "728*1440"]

    condition_list = ["None",
                      "Human pose",
                      "Canny Edge",
                      "Depth"]

    depth_estimator = DPTForDepthEstimation.from_pretrained(args.dpt_checkpoint).to("cuda")
    feature_extractor = DPTFeatureExtractor.from_pretrained(args.dpt_checkpoint)
    body_model = Body(args.pose_detector_checkpoint)
    openpose = OpenposeDetector(body_model)

    def remove_tips():
        return gr.update(visible=False)

    def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        return seed

    def get_humanpose(img):
        openpose_image = openpose(img)
        return openpose_image

    def get_cannyedge(image):
        image = np.array(image)
        image = cv2.Canny(image, 100, 200)
        image = image[:, :, None]
        image = np.concatenate([image, image, image], axis=2)
        canny_image = Image.fromarray(image)
        return canny_image

    def get_depth(image):
        image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
        with torch.no_grad(), torch.autocast("cuda"):
            depth_map = depth_estimator(image).predicted_depth

        depth_map = torch.nn.functional.interpolate(
            depth_map.unsqueeze(1),
            size=(1024, 1024),
            mode="bicubic",
            align_corners=False,
        )
        depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
        depth_map = (depth_map - depth_min) / (depth_max - depth_min)
        image = torch.cat([depth_map] * 3, dim=1)
        image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
        image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
        return image

    def generate_image(prompt1, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, style):
        try:
            path1 = lorapath_man[man]
            path2 = lorapath_woman[woman]
            pipe_concept.unload_lora_weights()
            pipe_list = build_model_lora(pipe_concept, path1 + "|" + path2, lorapath_styles[style], condition, args)

            if lorapath_styles[style] is not None and os.path.exists(lorapath_styles[style]):
                styleL = True
            else:
                styleL = False

            input_list = [prompt1]
            condition_list = [condition_img1]
            output_list = []

            width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])

            kwargs = {
                'height': height,
                'width': width,
            }

            for prompt, condition_img in zip(input_list, condition_list):
                if prompt!='':
                    input_prompt = []
                    p = '{prompt}, 35mm photograph, film, professional, 4k, highly detailed.'
                    if styleL:
                        p = styles[style] + p
                    input_prompt.append([p.replace("{prompt}", prompt), p.replace("{prompt}", prompt)])
                    input_prompt.append([(styles[style] + local_prompt1, character_man.get(man)[1]), (styles[style] + local_prompt2, character_woman.get(woman)[1])])

                    if condition == 'Human pose' and condition_img is not None:
                        spatial_condition = get_humanpose(condition_img).resize((width, height))
                    elif condition == 'Canny Edge' and condition_img is not None:
                        spatial_condition = get_cannyedge(condition_img).resize((width, height))
                    elif condition == 'Depth' and condition_img is not None:
                        spatial_condition = get_depth(condition_img).resize((width, height))
                    else:
                        spatial_condition = None

                    kwargs['spatial_condition'] = spatial_condition
                    controller.reset()
                    image = sample_image(
                        pipe,
                        input_prompt=input_prompt,
                        concept_models=pipe_concept,
                        input_neg_prompt=[negative_prompt] * len(input_prompt),
                        generator=torch.Generator(device).manual_seed(seed),
                        controller=controller,
                        stage=1,
                        lora_list=pipe_list,
                        styleL=styleL,
                        **kwargs)

                    controller.reset()
                    if pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
                        mask1 = predict_mask(detect_model, sam, image[0], 'man', args.segment_type, confidence=0.15,
                                             threshold=0.5)
                    else:
                        mask1 = None

                    if pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
                        mask2 = predict_mask(detect_model, sam, image[0], 'woman', args.segment_type, confidence=0.15,
                                             threshold=0.5)
                    else:
                        mask2 = None

                    if mask1 is None and mask2 is None:
                        output_list.append(image[1])
                    else:
                        image = sample_image(
                            pipe,
                            input_prompt=input_prompt,
                            concept_models=pipe_concept,
                            input_neg_prompt=[negative_prompt] * len(input_prompt),
                            generator=torch.Generator(device).manual_seed(seed),
                            controller=controller,
                            stage=2,
                            region_masks=[mask1, mask2],
                            lora_list=pipe_list,
                            styleL=styleL,
                            **kwargs)
                        output_list.append(image[1])
                else:
                    output_list.append(None)
            output_list.append(spatial_condition)
            return output_list
        except:
            print("error")
            return

    def get_local_value_man(input):
        return character_man[input][0]

    def get_local_value_woman(input):
        return character_woman[input][0]


    with gr.Blocks(css=css) as demo:
        # description
        gr.Markdown(title)
        gr.Markdown(description)

        with gr.Row():
            gallery = gr.Image(label="Generated Images", height=512, width=512)
            gen_condition = gr.Image(label="Spatial Condition", height=512, width=512)
            usage_tips = gr.Markdown(label="Usage tips of OMG", value=tips, visible=False)

        with gr.Row():
            condition_img1 = gr.Image(label="Input an RGB image for condition", height=128, width=128)

        # character choose
        with gr.Row():
            man = gr.Dropdown(label="Character 1 selection", choices=CHARACTER_MAN_NAMES, value="Harry Potter (identifier: Harry Potter)")
            woman = gr.Dropdown(label="Character 2 selection", choices=CHARACTER_WOMAN_NAMES, value="Hermione Granger (identifier: Hermione Granger)")
            resolution = gr.Dropdown(label="Image Resolution (width*height)", choices=resolution_list, value="1024*1024")
            condition = gr.Dropdown(label="Input condition type", choices=condition_list, value="None")
            style = gr.Dropdown(label="style", choices=STYLE_NAMES, value="None")

        with gr.Row():
            local_prompt1 = gr.Textbox(label="Character1_prompt",
                                       info="Describe the Character 1, this prompt should include the identifier of character 1",
                                       value="Close-up photo of the Harry Potter, 35mm photograph, film, professional, 4k, highly detailed.")
            local_prompt2 = gr.Textbox(label="Character2_prompt",
                                       info="Describe the Character 2, this prompt should include the identifier of character2",
                                       value="Close-up photo of the Hermione Granger, 35mm photograph, film, professional, 4k, highly detailed.")

        man.change(get_local_value_man, man, local_prompt1)
        woman.change(get_local_value_woman, woman, local_prompt2)

        # prompt
        with gr.Column():
            prompt = gr.Textbox(label="Prompt 1",
                                info="Give a simple prompt to describe the first image content",
                                placeholder="Required",
                                value="close-up shot, photography, the cool man and beautiful woman as they accidentally discover a mysterious island while on vacation by the sea, facing the camera smiling")


        with gr.Accordion(open=False, label="Advanced Options"):
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )
            negative_prompt = gr.Textbox(label="Negative Prompt",
                                         placeholder="noisy, blurry, soft, deformed, ugly",
                                         value="noisy, blurry, soft, deformed, ugly")
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

        submit = gr.Button("Submit", variant="primary")

        submit.click(
            fn=remove_tips,
            outputs=usage_tips,
        ).then(
            fn=randomize_seed_fn,
            inputs=[seed, randomize_seed],
            outputs=seed,
            queue=False,
            api_name=False,
        ).then(
            fn=generate_image,
            inputs=[prompt, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, style],
            outputs=[gallery, gen_condition]
        )
    demo.launch(server_name='0.0.0.0', server_port=7861, debug=True)

def parse_args():
    parser = argparse.ArgumentParser('', add_help=False)
    parser.add_argument('--pretrained_sdxl_model', default='./checkpoint/stable-diffusion-xl-base-1.0', type=str)
    parser.add_argument('--openpose_checkpoint', default='./checkpoint/controlnet-openpose-sdxl-1.0', type=str)
    parser.add_argument('--canny_checkpoint', default='./checkpoint/controlnet-canny-sdxl-1.0', type=str)
    parser.add_argument('--depth_checkpoint', default='./checkpoint/controlnet-depth-sdxl-1.0', type=str)
    parser.add_argument('--efficientViT_checkpoint', default='./checkpoint/sam/xl1.pt', type=str)
    parser.add_argument('--dino_checkpoint', default='./checkpoint/GroundingDINO', type=str)
    parser.add_argument('--sam_checkpoint', default='./checkpoint/sam/sam_vit_h_4b8939.pth', type=str)
    parser.add_argument('--dpt_checkpoint', default='./checkpoint/dpt-hybrid-midas', type=str)
    parser.add_argument('--pose_detector_checkpoint', default='./checkpoint/ControlNet/annotator/ckpts/body_pose_model.pth', type=str)
    parser.add_argument('--prompt', default='Close-up photo of the cool man and beautiful woman in surprised expressions as they accidentally discover a mysterious island while on vacation by the sea, 35mm photograph, film, professional, 4k, highly detailed.', type=str)
    parser.add_argument('--negative_prompt', default='noisy, blurry, soft, deformed, ugly', type=str)
    parser.add_argument('--seed', default=22, type=int)
    parser.add_argument('--suffix', default='', type=str)
    parser.add_argument('--segment_type', default='yoloworld', help='GroundingDINO or yoloworld', type=str)
    return parser.parse_args()

if __name__ == '__main__':
    args = parse_args()

    prompts = [args.prompt]*2
    prompts_tmp = copy.deepcopy(prompts)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    main(device, args.segment_type)
gradio_demo/character_template.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
character_list_man = [
|
| 2 |
+
{
|
| 3 |
+
"name": "Harry Potter (identifier: Harry Potter)",
|
| 4 |
+
"prompt": "Close-up photo of the Harry Potter, 35mm photograph, film, professional, 4k, highly detailed.",
|
| 5 |
+
"negative_prompt": "noisy, blurry, soft, deformed, ugly",
|
| 6 |
+
"path": "./checkpoint/lora/Harry_Potter.safetensors",
|
| 7 |
+
},
|
| 8 |
+
{
|
| 9 |
+
"name": "Chris Evans (identifier: Chris Evans)",
|
| 10 |
+
"prompt": "Close-up photo of the Chris Evans, 35mm photograph, film, professional, 4k, highly detailed.",
|
| 11 |
+
"negative_prompt": "noisy, blurry, soft, deformed, ugly",
|
| 12 |
+
"path": "./checkpoint/lora/chris-evans.safetensors",
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"name": "Jordan Torres (identifier: jordan_torres)",
|
| 16 |
+
"prompt": "Close-up photo of the jordan_torres man, 35mm photograph, film, professional, 4k, highly detailed.",
|
| 17 |
+
"negative_prompt": "noisy, blurry, soft, deformed, ugly",
|
| 18 |
+
"path": "./checkpoint/lora/jordan_torres_v2_xl.safetensors",
|
| 19 |
+
},
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
character_list_woman = [
|
| 23 |
+
{
|
| 24 |
+
"name": "Hermione Granger (identifier: Hermione Granger)",
|
| 25 |
+
"prompt": "Close-up photo of the Hermione Granger, 35mm photograph, film, professional, 4k, highly detailed.",
|
| 26 |
+
"negative_prompt": "noisy, blurry, soft, deformed, ugly",
|
| 27 |
+
"path": "./checkpoint/lora/Hermione_Granger.safetensors",
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"name": "Taylor Swift (identifier: TaylorSwift)",
|
| 31 |
+
"prompt": "Close-up photo of the TaylorSwift, 35mm photograph, film, professional, 4k, highly detailed.",
|
| 32 |
+
"negative_prompt": "noisy, blurry, soft, deformed, ugly",
|
| 33 |
+
"path": "./checkpoint/lora/TaylorSwiftSDXL.safetensors",
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"name": "Keira Knightley (identifier: ohwx woman)",
|
| 37 |
+
"prompt": "Close-up photo of the ohwx woman, 35mm photograph, film, professional, 4k, highly detailed.",
|
| 38 |
+
"negative_prompt": "noisy, blurry, soft, deformed, ugly",
|
| 39 |
+
"path": "./checkpoint/lora/keira_lora_sdxl_v1-000008.safetensors",
|
| 40 |
+
},
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
style_list = [
|
| 44 |
+
{
|
| 45 |
+
"name": "None",
|
| 46 |
+
"prompt": "",
|
| 47 |
+
"path": "",
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"name": "Anime sketch style",
|
| 51 |
+
"prompt": "Pencil_Sketch:1.2, messy lines, greyscale, traditional media, sketch, ",
|
| 52 |
+
"path": "./checkpoint/style/Anime_Sketch_SDXL.safetensors",
|
| 53 |
+
}
|
| 54 |
+
]
|
| 55 |
+
|
| 56 |
+
character_man = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in character_list_man}
|
| 57 |
+
character_woman = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in character_list_woman}
|
| 58 |
+
styles = {k["name"]: (k["prompt"]) for k in style_list}
|
| 59 |
+
|
| 60 |
+
lorapath_man = {k["name"]: (k["path"]) for k in character_list_man}
|
| 61 |
+
lorapath_woman = {k["name"]: (k["path"]) for k in character_list_woman}
|
| 62 |
+
lorapath_styles = {k["name"]: (k["path"]) for k in style_list}
|
src/efficientvit/__init__.py
ADDED
|
File without changes
|
src/efficientvit/apps/__init__.py
ADDED
|
File without changes
|
src/efficientvit/apps/data_provider/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
|
| 2 |
+
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
|
| 3 |
+
# International Conference on Computer Vision (ICCV), 2023
|
| 4 |
+
|
| 5 |
+
from .augment import *
|
| 6 |
+
from .base import *
|
| 7 |
+
from .random_resolution import *
|
src/efficientvit/apps/data_provider/augment/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
|
| 2 |
+
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
|
| 3 |
+
# International Conference on Computer Vision (ICCV), 2023
|
| 4 |
+
|
| 5 |
+
from .bbox import *
|
| 6 |
+
from .color_aug import *
|
src/efficientvit/apps/data_provider/augment/bbox.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
|
| 2 |
+
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
|
| 3 |
+
# International Conference on Computer Vision (ICCV), 2023
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
__all__ = ["rand_bbox"]
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def rand_bbox(
|
| 11 |
+
h: int,
|
| 12 |
+
w: int,
|
| 13 |
+
lam: float,
|
| 14 |
+
rand_func: callable = np.random.uniform,
|
| 15 |
+
) -> tuple[int, int, int, int]:
|
| 16 |
+
"""randomly sample bbox, used in cutmix"""
|
| 17 |
+
cut_rat = np.sqrt(1.0 - lam)
|
| 18 |
+
cut_w = w * cut_rat
|
| 19 |
+
cut_h = h * cut_rat
|
| 20 |
+
|
| 21 |
+
# uniform
|
| 22 |
+
cx = rand_func(0, w)
|
| 23 |
+
cy = rand_func(0, h)
|
| 24 |
+
|
| 25 |
+
bbx1 = int(np.clip(cx - cut_w / 2, 0, w))
|
| 26 |
+
bby1 = int(np.clip(cy - cut_h / 2, 0, h))
|
| 27 |
+
bbx2 = int(np.clip(cx + cut_w / 2, 0, w))
|
| 28 |
+
bby2 = int(np.clip(cy + cut_h / 2, 0, h))
|
| 29 |
+
|
| 30 |
+
return bbx1, bby1, bbx2, bby2
|
src/efficientvit/apps/data_provider/augment/color_aug.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
|
| 2 |
+
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
|
| 3 |
+
# International Conference on Computer Vision (ICCV), 2023
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torchvision.transforms as transforms
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from timm.data.auto_augment import rand_augment_transform
|
| 9 |
+
|
| 10 |
+
__all__ = ["ColorAug", "RandAug"]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ImageAug:
|
| 14 |
+
def aug_image(self, image: Image.Image) -> Image.Image:
|
| 15 |
+
raise NotImplementedError
|
| 16 |
+
|
| 17 |
+
def __call__(
|
| 18 |
+
self, feed_dict: dict or np.ndarray or Image.Image
|
| 19 |
+
) -> dict or np.ndarray or Image.Image:
|
| 20 |
+
if isinstance(feed_dict, dict):
|
| 21 |
+
output_dict = feed_dict
|
| 22 |
+
image = feed_dict[self.key]
|
| 23 |
+
else:
|
| 24 |
+
output_dict = None
|
| 25 |
+
image = feed_dict
|
| 26 |
+
is_ndarray = isinstance(image, np.ndarray)
|
| 27 |
+
if is_ndarray:
|
| 28 |
+
image = Image.fromarray(image)
|
| 29 |
+
|
| 30 |
+
image = self.aug_image(image)
|
| 31 |
+
|
| 32 |
+
if is_ndarray:
|
| 33 |
+
image = np.array(image)
|
| 34 |
+
|
| 35 |
+
if output_dict is None:
|
| 36 |
+
return image
|
| 37 |
+
else:
|
| 38 |
+
output_dict[self.key] = image
|
| 39 |
+
return output_dict
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class ColorAug(transforms.ColorJitter, ImageAug):
|
| 43 |
+
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, key="data"):
|
| 44 |
+
super().__init__(
|
| 45 |
+
brightness=brightness,
|
| 46 |
+
contrast=contrast,
|
| 47 |
+
saturation=saturation,
|
| 48 |
+
hue=hue,
|
| 49 |
+
)
|
| 50 |
+
self.key = key
|
| 51 |
+
|
| 52 |
+
def aug_image(self, image: Image.Image) -> Image.Image:
|
| 53 |
+
return transforms.ColorJitter.forward(self, image)
|
| 54 |
+
|
| 55 |
+
def forward(
|
| 56 |
+
self, feed_dict: dict or np.ndarray or Image.Image
|
| 57 |
+
) -> dict or np.ndarray or Image.Image:
|
| 58 |
+
return ImageAug.__call__(self, feed_dict)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class RandAug(ImageAug):
|
| 62 |
+
def __init__(
|
| 63 |
+
self, config: dict[str, any], mean: tuple[float, float, float], key="data"
|
| 64 |
+
):
|
| 65 |
+
n = config.get("n", 2)
|
| 66 |
+
m = config.get("m", 9)
|
| 67 |
+
mstd = config.get("mstd", 1.0)
|
| 68 |
+
inc = config.get("inc", 1)
|
| 69 |
+
tpct = config.get("tpct", 0.45)
|
| 70 |
+
config_str = f"rand-n{n}-m{m}-mstd{mstd}-inc{inc}"
|
| 71 |
+
|
| 72 |
+
aa_params = dict(
|
| 73 |
+
translate_pct=tpct,
|
| 74 |
+
img_mean=tuple([min(255, round(255 * x)) for x in mean]),
|
| 75 |
+
interpolation=Image.BICUBIC,
|
| 76 |
+
)
|
| 77 |
+
self.aug_op = rand_augment_transform(config_str, aa_params)
|
| 78 |
+
self.key = key
|
| 79 |
+
|
| 80 |
+
def aug_image(self, image: Image.Image) -> Image.Image:
|
| 81 |
+
return self.aug_op(image)
|
| 82 |
+
|
| 83 |
+
def __repr__(self):
|
| 84 |
+
return self.aug_op.__repr__()
|
src/efficientvit/apps/data_provider/base.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
|
| 2 |
+
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
|
| 3 |
+
# International Conference on Computer Vision (ICCV), 2023
|
| 4 |
+
|
| 5 |
+
import copy
|
| 6 |
+
import warnings
|
| 7 |
+
|
| 8 |
+
import torch.utils.data
|
| 9 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 10 |
+
|
| 11 |
+
from src.efficientvit.apps.data_provider.random_resolution import RRSController
|
| 12 |
+
from src.efficientvit.models.utils import val2tuple
|
| 13 |
+
|
| 14 |
+
__all__ = ["parse_image_size", "random_drop_data", "DataProvider"]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def parse_image_size(size: int or str) -> tuple[int, int]:
|
| 18 |
+
if isinstance(size, str):
|
| 19 |
+
size = [int(val) for val in size.split("-")]
|
| 20 |
+
return size[0], size[1]
|
| 21 |
+
else:
|
| 22 |
+
return val2tuple(size, 2)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def random_drop_data(dataset, drop_size: int, seed: int, keys=("samples",)):
|
| 26 |
+
g = torch.Generator()
|
| 27 |
+
g.manual_seed(seed) # set random seed before sampling validation set
|
| 28 |
+
rand_indexes = torch.randperm(len(dataset), generator=g).tolist()
|
| 29 |
+
|
| 30 |
+
dropped_indexes = rand_indexes[:drop_size]
|
| 31 |
+
remaining_indexes = rand_indexes[drop_size:]
|
| 32 |
+
|
| 33 |
+
dropped_dataset = copy.deepcopy(dataset)
|
| 34 |
+
for key in keys:
|
| 35 |
+
setattr(
|
| 36 |
+
dropped_dataset,
|
| 37 |
+
key,
|
| 38 |
+
[getattr(dropped_dataset, key)[idx] for idx in dropped_indexes],
|
| 39 |
+
)
|
| 40 |
+
setattr(dataset, key, [getattr(dataset, key)[idx] for idx in remaining_indexes])
|
| 41 |
+
return dataset, dropped_dataset
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class DataProvider:
|
| 45 |
+
data_keys = ("samples",)
|
| 46 |
+
mean_std = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}
|
| 47 |
+
SUB_SEED = 937162211 # random seed for sampling subset
|
| 48 |
+
VALID_SEED = 2147483647 # random seed for the validation set
|
| 49 |
+
|
| 50 |
+
name: str
|
| 51 |
+
|
| 52 |
+
def __init__(
|
| 53 |
+
self,
|
| 54 |
+
train_batch_size: int,
|
| 55 |
+
test_batch_size: int or None,
|
| 56 |
+
valid_size: int or float or None,
|
| 57 |
+
n_worker: int,
|
| 58 |
+
image_size: int or list[int] or str or list[str],
|
| 59 |
+
num_replicas: int or None = None,
|
| 60 |
+
rank: int or None = None,
|
| 61 |
+
train_ratio: float or None = None,
|
| 62 |
+
drop_last: bool = False,
|
| 63 |
+
):
|
| 64 |
+
warnings.filterwarnings("ignore")
|
| 65 |
+
super().__init__()
|
| 66 |
+
|
| 67 |
+
# batch_size & valid_size
|
| 68 |
+
self.train_batch_size = train_batch_size
|
| 69 |
+
self.test_batch_size = test_batch_size or self.train_batch_size
|
| 70 |
+
self.valid_size = valid_size
|
| 71 |
+
|
| 72 |
+
# image size
|
| 73 |
+
if isinstance(image_size, list):
|
| 74 |
+
self.image_size = [parse_image_size(size) for size in image_size]
|
| 75 |
+
self.image_size.sort() # e.g., 160 -> 224
|
| 76 |
+
RRSController.IMAGE_SIZE_LIST = copy.deepcopy(self.image_size)
|
| 77 |
+
self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size[-1]
|
| 78 |
+
else:
|
| 79 |
+
self.image_size = parse_image_size(image_size)
|
| 80 |
+
RRSController.IMAGE_SIZE_LIST = [self.image_size]
|
| 81 |
+
self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size
|
| 82 |
+
|
| 83 |
+
# distributed configs
|
| 84 |
+
self.num_replicas = num_replicas
|
| 85 |
+
self.rank = rank
|
| 86 |
+
|
| 87 |
+
# build datasets
|
| 88 |
+
train_dataset, val_dataset, test_dataset = self.build_datasets()
|
| 89 |
+
|
| 90 |
+
if train_ratio is not None and train_ratio < 1.0:
|
| 91 |
+
assert 0 < train_ratio < 1
|
| 92 |
+
_, train_dataset = random_drop_data(
|
| 93 |
+
train_dataset,
|
| 94 |
+
int(train_ratio * len(train_dataset)),
|
| 95 |
+
self.SUB_SEED,
|
| 96 |
+
self.data_keys,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# build data loader
|
| 100 |
+
self.train = self.build_dataloader(
|
| 101 |
+
train_dataset, train_batch_size, n_worker, drop_last=drop_last, train=True
|
| 102 |
+
)
|
| 103 |
+
self.valid = self.build_dataloader(
|
| 104 |
+
val_dataset, test_batch_size, n_worker, drop_last=False, train=False
|
| 105 |
+
)
|
| 106 |
+
self.test = self.build_dataloader(
|
| 107 |
+
test_dataset, test_batch_size, n_worker, drop_last=False, train=False
|
| 108 |
+
)
|
| 109 |
+
if self.valid is None:
|
| 110 |
+
self.valid = self.test
|
| 111 |
+
self.sub_train = None
|
| 112 |
+
|
| 113 |
+
@property
|
| 114 |
+
def data_shape(self) -> tuple[int, ...]:
|
| 115 |
+
return 3, self.active_image_size[0], self.active_image_size[1]
|
| 116 |
+
|
| 117 |
+
def build_valid_transform(self, image_size: tuple[int, int] or None = None) -> any:
|
| 118 |
+
raise NotImplementedError
|
| 119 |
+
|
| 120 |
+
def build_train_transform(self, image_size: tuple[int, int] or None = None) -> any:
|
| 121 |
+
raise NotImplementedError
|
| 122 |
+
|
| 123 |
+
def build_datasets(self) -> tuple[any, any, any]:
|
| 124 |
+
raise NotImplementedError
|
| 125 |
+
|
| 126 |
+
def build_dataloader(
|
| 127 |
+
self,
|
| 128 |
+
dataset: any or None,
|
| 129 |
+
batch_size: int,
|
| 130 |
+
n_worker: int,
|
| 131 |
+
drop_last: bool,
|
| 132 |
+
train: bool,
|
| 133 |
+
):
|
| 134 |
+
if dataset is None:
|
| 135 |
+
return None
|
| 136 |
+
if isinstance(self.image_size, list) and train:
|
| 137 |
+
from efficientvit.apps.data_provider.random_resolution._data_loader import \
|
| 138 |
+
RRSDataLoader
|
| 139 |
+
|
| 140 |
+
dataloader_class = RRSDataLoader
|
| 141 |
+
else:
|
| 142 |
+
dataloader_class = torch.utils.data.DataLoader
|
| 143 |
+
if self.num_replicas is None:
|
| 144 |
+
return dataloader_class(
|
| 145 |
+
dataset=dataset,
|
| 146 |
+
batch_size=batch_size,
|
| 147 |
+
shuffle=True,
|
| 148 |
+
num_workers=n_worker,
|
| 149 |
+
pin_memory=True,
|
| 150 |
+
drop_last=drop_last,
|
| 151 |
+
)
|
| 152 |
+
else:
|
| 153 |
+
sampler = DistributedSampler(dataset, self.num_replicas, self.rank)
|
| 154 |
+
return dataloader_class(
|
| 155 |
+
dataset=dataset,
|
| 156 |
+
batch_size=batch_size,
|
| 157 |
+
sampler=sampler,
|
| 158 |
+
num_workers=n_worker,
|
| 159 |
+
pin_memory=True,
|
| 160 |
+
drop_last=drop_last,
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
def set_epoch(self, epoch: int) -> None:
|
| 164 |
+
RRSController.set_epoch(epoch, len(self.train))
|
| 165 |
+
if isinstance(self.train.sampler, DistributedSampler):
|
| 166 |
+
self.train.sampler.set_epoch(epoch)
|
| 167 |
+
|
| 168 |
+
def assign_active_image_size(self, new_size: int or tuple[int, int]) -> None:
|
| 169 |
+
self.active_image_size = val2tuple(new_size, 2)
|
| 170 |
+
new_transform = self.build_valid_transform(self.active_image_size)
|
| 171 |
+
# change the transform of the valid and test set
|
| 172 |
+
self.valid.dataset.transform = self.test.dataset.transform = new_transform
|
| 173 |
+
|
| 174 |
+
def sample_val_dataset(self, train_dataset, valid_transform) -> tuple[any, any]:
|
| 175 |
+
if self.valid_size is not None:
|
| 176 |
+
if 0 < self.valid_size < 1:
|
| 177 |
+
valid_size = int(self.valid_size * len(train_dataset))
|
| 178 |
+
else:
|
| 179 |
+
assert self.valid_size >= 1
|
| 180 |
+
valid_size = int(self.valid_size)
|
| 181 |
+
train_dataset, val_dataset = random_drop_data(
|
| 182 |
+
train_dataset,
|
| 183 |
+
valid_size,
|
| 184 |
+
self.VALID_SEED,
|
| 185 |
+
self.data_keys,
|
| 186 |
+
)
|
| 187 |
+
val_dataset.transform = valid_transform
|
| 188 |
+
else:
|
| 189 |
+
val_dataset = None
|
| 190 |
+
return train_dataset, val_dataset
|
| 191 |
+
|
| 192 |
+
def build_sub_train_loader(self, n_samples: int, batch_size: int) -> any:
|
| 193 |
+
# used for resetting BN running statistics
|
| 194 |
+
if self.sub_train is None:
|
| 195 |
+
self.sub_train = {}
|
| 196 |
+
if self.active_image_size in self.sub_train:
|
| 197 |
+
return self.sub_train[self.active_image_size]
|
| 198 |
+
|
| 199 |
+
# construct dataset and dataloader
|
| 200 |
+
train_dataset = copy.deepcopy(self.train.dataset)
|
| 201 |
+
if n_samples < len(train_dataset):
|
| 202 |
+
_, train_dataset = random_drop_data(
|
| 203 |
+
train_dataset,
|
| 204 |
+
n_samples,
|
| 205 |
+
self.SUB_SEED,
|
| 206 |
+
self.data_keys,
|
| 207 |
+
)
|
| 208 |
+
RRSController.ACTIVE_SIZE = self.active_image_size
|
| 209 |
+
train_dataset.transform = self.build_train_transform(
|
| 210 |
+
image_size=self.active_image_size
|
| 211 |
+
)
|
| 212 |
+
data_loader = self.build_dataloader(
|
| 213 |
+
train_dataset, batch_size, self.train.num_workers, True, False
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
# pre-fetch data
|
| 217 |
+
self.sub_train[self.active_image_size] = [
|
| 218 |
+
data
|
| 219 |
+
for data in data_loader
|
| 220 |
+
for _ in range(max(1, n_samples // len(train_dataset)))
|
| 221 |
+
]
|
| 222 |
+
|
| 223 |
+
return self.sub_train[self.active_image_size]
|
src/efficientvit/apps/data_provider/random_resolution/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Random resolution data loader compatible with multi-processing and distributed training.
|
| 2 |
+
|
| 3 |
+
Replace Pytorch's DataLoader with RRSDataLoader to support random resolution
|
| 4 |
+
at the training time, resolution sampling is controlled by RRSController
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .controller import *
|
src/efficientvit/apps/data_provider/random_resolution/_data_loader.py
ADDED
|
@@ -0,0 +1,1598 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
r"""This file is based on torch/utils/data/data_loader.py
|
| 2 |
+
|
| 3 |
+
Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter
|
| 4 |
+
|
| 5 |
+
To support these two classes, in `./_utils` we define many utility methods and
|
| 6 |
+
functions to be run in multiprocessing. E.g., the data loading worker loop is
|
| 7 |
+
in `./_utils/worker.py`.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import functools
|
| 11 |
+
import itertools
|
| 12 |
+
import logging
|
| 13 |
+
import multiprocessing as python_multiprocessing
|
| 14 |
+
import os
|
| 15 |
+
import queue
|
| 16 |
+
import threading
|
| 17 |
+
import warnings
|
| 18 |
+
from typing import (Any, Callable, Generic, Iterable, List, Optional, Sequence,
|
| 19 |
+
TypeVar, Union)
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.distributed as dist
|
| 23 |
+
import torch.multiprocessing as multiprocessing
|
| 24 |
+
import torch.utils.data.graph_settings
|
| 25 |
+
from torch._utils import ExceptionWrapper
|
| 26 |
+
from torch.utils.data import (BatchSampler, Dataset, IterableDataset,
|
| 27 |
+
IterDataPipe, MapDataPipe, RandomSampler,
|
| 28 |
+
Sampler, SequentialSampler, _utils)
|
| 29 |
+
from torch.utils.data.datapipes.datapipe import (
|
| 30 |
+
_IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper)
|
| 31 |
+
|
| 32 |
+
from ._data_worker import _worker_loop
|
| 33 |
+
|
| 34 |
+
__all__ = ["RRSDataLoader"]
|
| 35 |
+
|
| 36 |
+
T_co = TypeVar("T_co", covariant=True)
|
| 37 |
+
T = TypeVar("T")
|
| 38 |
+
_worker_init_fn_t = Callable[[int], None]
|
| 39 |
+
|
| 40 |
+
# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
|
| 41 |
+
# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
|
| 42 |
+
# See https://github.com/python/mypy/issues/3737.
|
| 43 |
+
_collate_fn_t = Callable[[List[T]], Any]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# These functions used to be defined in this file. However, it was moved to
|
| 47 |
+
# _utils/collate.py. Although it is rather hard to access this from user land
|
| 48 |
+
# (one has to explicitly directly `import torch.utils.data.dataloader`), there
|
| 49 |
+
# probably is user code out there using it. This aliasing maintains BC in this
|
| 50 |
+
# aspect.
|
| 51 |
+
default_collate: _collate_fn_t = _utils.collate.default_collate
|
| 52 |
+
default_convert = _utils.collate.default_convert
|
| 53 |
+
|
| 54 |
+
get_worker_info = _utils.worker.get_worker_info
|
| 55 |
+
|
| 56 |
+
logger = logging.getLogger(__name__)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class _DatasetKind:
|
| 60 |
+
Map = 0
|
| 61 |
+
Iterable = 1
|
| 62 |
+
|
| 63 |
+
@staticmethod
|
| 64 |
+
def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
|
| 65 |
+
if kind == _DatasetKind.Map:
|
| 66 |
+
return _utils.fetch._MapDatasetFetcher(
|
| 67 |
+
dataset, auto_collation, collate_fn, drop_last
|
| 68 |
+
)
|
| 69 |
+
else:
|
| 70 |
+
return _utils.fetch._IterableDatasetFetcher(
|
| 71 |
+
dataset, auto_collation, collate_fn, drop_last
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class _InfiniteConstantSampler(Sampler):
|
| 76 |
+
r"""Analogous to ``itertools.repeat(None, None)``.
|
| 77 |
+
Used as sampler for :class:`~torch.utils.data.IterableDataset`.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
data_source (Dataset): dataset to sample from
|
| 81 |
+
"""
|
| 82 |
+
|
| 83 |
+
def __init__(self):
|
| 84 |
+
super().__init__(None)
|
| 85 |
+
|
| 86 |
+
def __iter__(self):
|
| 87 |
+
while True:
|
| 88 |
+
yield None
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _get_distributed_settings():
|
| 92 |
+
if dist.is_available() and dist.is_initialized():
|
| 93 |
+
return dist.get_world_size(), dist.get_rank()
|
| 94 |
+
else:
|
| 95 |
+
return 1, 0
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
|
| 99 |
+
global_worker_id = worker_id
|
| 100 |
+
info = torch.utils.data.get_worker_info()
|
| 101 |
+
assert info is not None
|
| 102 |
+
total_workers = info.num_workers
|
| 103 |
+
datapipe = info.dataset
|
| 104 |
+
assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
|
| 105 |
+
# To distribute elements across distributed process evenly, we should shard data on distributed
|
| 106 |
+
# processes first then shard on worker processes
|
| 107 |
+
total_workers *= world_size
|
| 108 |
+
global_worker_id = global_worker_id * world_size + rank_id
|
| 109 |
+
# For BC, use default SHARDING_PRIORITIES
|
| 110 |
+
torch.utils.data.graph_settings.apply_sharding(
|
| 111 |
+
datapipe, total_workers, global_worker_id
|
| 112 |
+
)
|
| 113 |
+
if worker_init_fn is not None:
|
| 114 |
+
worker_init_fn(worker_id)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _share_dist_seed(generator, pg):
|
| 118 |
+
_shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator)
|
| 119 |
+
if isinstance(pg, dist.ProcessGroup):
|
| 120 |
+
dist.broadcast(_shared_seed, src=0, group=pg)
|
| 121 |
+
return _shared_seed.item()
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class RRSDataLoader(Generic[T_co]):
|
| 125 |
+
r"""
|
| 126 |
+
Data loader. Combines a dataset and a sampler, and provides an iterable over
|
| 127 |
+
the given dataset.
|
| 128 |
+
|
| 129 |
+
The :class:`~torch.utils.data.DataLoader` supports both map-style and
|
| 130 |
+
iterable-style datasets with single- or multi-process loading, customizing
|
| 131 |
+
loading order and optional automatic batching (collation) and memory pinning.
|
| 132 |
+
|
| 133 |
+
See :py:mod:`torch.utils.data` documentation page for more details.
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
dataset (Dataset): dataset from which to load the data.
|
| 137 |
+
batch_size (int, optional): how many samples per batch to load
|
| 138 |
+
(default: ``1``).
|
| 139 |
+
shuffle (bool, optional): set to ``True`` to have the data reshuffled
|
| 140 |
+
at every epoch (default: ``False``).
|
| 141 |
+
sampler (Sampler or Iterable, optional): defines the strategy to draw
|
| 142 |
+
samples from the dataset. Can be any ``Iterable`` with ``__len__``
|
| 143 |
+
implemented. If specified, :attr:`shuffle` must not be specified.
|
| 144 |
+
batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
|
| 145 |
+
returns a batch of indices at a time. Mutually exclusive with
|
| 146 |
+
:attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
|
| 147 |
+
and :attr:`drop_last`.
|
| 148 |
+
num_workers (int, optional): how many subprocesses to use for data
|
| 149 |
+
loading. ``0`` means that the data will be loaded in the main process.
|
| 150 |
+
(default: ``0``)
|
| 151 |
+
collate_fn (Callable, optional): merges a list of samples to form a
|
| 152 |
+
mini-batch of Tensor(s). Used when using batched loading from a
|
| 153 |
+
map-style dataset.
|
| 154 |
+
pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
|
| 155 |
+
into device/CUDA pinned memory before returning them. If your data elements
|
| 156 |
+
are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
|
| 157 |
+
see the example below.
|
| 158 |
+
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
|
| 159 |
+
if the dataset size is not divisible by the batch size. If ``False`` and
|
| 160 |
+
the size of dataset is not divisible by the batch size, then the last batch
|
| 161 |
+
will be smaller. (default: ``False``)
|
| 162 |
+
timeout (numeric, optional): if positive, the timeout value for collecting a batch
|
| 163 |
+
from workers. Should always be non-negative. (default: ``0``)
|
| 164 |
+
worker_init_fn (Callable, optional): If not ``None``, this will be called on each
|
| 165 |
+
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
|
| 166 |
+
input, after seeding and before data loading. (default: ``None``)
|
| 167 |
+
generator (torch.Generator, optional): If not ``None``, this RNG will be used
|
| 168 |
+
by RandomSampler to generate random indexes and multiprocessing to generate
|
| 169 |
+
`base_seed` for workers. (default: ``None``)
|
| 170 |
+
prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
|
| 171 |
+
in advance by each worker. ``2`` means there will be a total of
|
| 172 |
+
2 * num_workers batches prefetched across all workers. (default value depends
|
| 173 |
+
on the set value for num_workers. If value of num_workers=0 default is ``None``.
|
| 174 |
+
Otherwise if value of num_workers>0 default is ``2``).
|
| 175 |
+
persistent_workers (bool, optional): If ``True``, the data loader will not shutdown
|
| 176 |
+
the worker processes after a dataset has been consumed once. This allows to
|
| 177 |
+
maintain the workers `Dataset` instances alive. (default: ``False``)
|
| 178 |
+
pin_memory_device (str, optional): the data loader will copy Tensors
|
| 179 |
+
into device pinned memory before returning them if pin_memory is set to true.
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
.. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
|
| 183 |
+
cannot be an unpicklable object, e.g., a lambda function. See
|
| 184 |
+
:ref:`multiprocessing-best-practices` on more details related
|
| 185 |
+
to multiprocessing in PyTorch.
|
| 186 |
+
|
| 187 |
+
.. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
|
| 188 |
+
When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
|
| 189 |
+
it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
|
| 190 |
+
rounding depending on :attr:`drop_last`, regardless of multi-process loading
|
| 191 |
+
configurations. This represents the best guess PyTorch can make because PyTorch
|
| 192 |
+
trusts user :attr:`dataset` code in correctly handling multi-process
|
| 193 |
+
loading to avoid duplicate data.
|
| 194 |
+
|
| 195 |
+
However, if sharding results in multiple workers having incomplete last batches,
|
| 196 |
+
this estimate can still be inaccurate, because (1) an otherwise complete batch can
|
| 197 |
+
be broken into multiple ones and (2) more than one batch worth of samples can be
|
| 198 |
+
dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
|
| 199 |
+
cases in general.
|
| 200 |
+
|
| 201 |
+
See `Dataset Types`_ for more details on these two types of datasets and how
|
| 202 |
+
:class:`~torch.utils.data.IterableDataset` interacts with
|
| 203 |
+
`Multi-process data loading`_.
|
| 204 |
+
|
| 205 |
+
.. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and
|
| 206 |
+
:ref:`data-loading-randomness` notes for random seed related questions.
|
| 207 |
+
"""
|
| 208 |
+
|
| 209 |
+
dataset: Dataset[T_co]
|
| 210 |
+
batch_size: Optional[int]
|
| 211 |
+
num_workers: int
|
| 212 |
+
pin_memory: bool
|
| 213 |
+
drop_last: bool
|
| 214 |
+
timeout: float
|
| 215 |
+
sampler: Union[Sampler, Iterable]
|
| 216 |
+
pin_memory_device: str
|
| 217 |
+
prefetch_factor: Optional[int]
|
| 218 |
+
_iterator: Optional["_BaseDataLoaderIter"]
|
| 219 |
+
__initialized = False
|
| 220 |
+
|
| 221 |
+
def __init__(
|
| 222 |
+
self,
|
| 223 |
+
dataset: Dataset[T_co],
|
| 224 |
+
batch_size: Optional[int] = 1,
|
| 225 |
+
shuffle: Optional[bool] = None,
|
| 226 |
+
sampler: Union[Sampler, Iterable, None] = None,
|
| 227 |
+
batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
|
| 228 |
+
num_workers: int = 0,
|
| 229 |
+
collate_fn: Optional[_collate_fn_t] = None,
|
| 230 |
+
pin_memory: bool = False,
|
| 231 |
+
drop_last: bool = False,
|
| 232 |
+
timeout: float = 0,
|
| 233 |
+
worker_init_fn: Optional[_worker_init_fn_t] = None,
|
| 234 |
+
multiprocessing_context=None,
|
| 235 |
+
generator=None,
|
| 236 |
+
*,
|
| 237 |
+
prefetch_factor: Optional[int] = None,
|
| 238 |
+
persistent_workers: bool = False,
|
| 239 |
+
pin_memory_device: str = ""
|
| 240 |
+
):
|
| 241 |
+
torch._C._log_api_usage_once("python.data_loader")
|
| 242 |
+
|
| 243 |
+
if num_workers < 0:
|
| 244 |
+
raise ValueError(
|
| 245 |
+
"num_workers option should be non-negative; "
|
| 246 |
+
"use num_workers=0 to disable multiprocessing."
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
if timeout < 0:
|
| 250 |
+
raise ValueError("timeout option should be non-negative")
|
| 251 |
+
|
| 252 |
+
if num_workers == 0 and prefetch_factor is not None:
|
| 253 |
+
raise ValueError(
|
| 254 |
+
"prefetch_factor option could only be specified in multiprocessing."
|
| 255 |
+
"let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None."
|
| 256 |
+
)
|
| 257 |
+
elif num_workers > 0 and prefetch_factor is None:
|
| 258 |
+
prefetch_factor = 2
|
| 259 |
+
elif prefetch_factor is not None and prefetch_factor < 0:
|
| 260 |
+
raise ValueError("prefetch_factor option should be non-negative")
|
| 261 |
+
|
| 262 |
+
if persistent_workers and num_workers == 0:
|
| 263 |
+
raise ValueError("persistent_workers option needs num_workers > 0")
|
| 264 |
+
|
| 265 |
+
self.dataset = dataset
|
| 266 |
+
self.num_workers = num_workers
|
| 267 |
+
self.prefetch_factor = prefetch_factor
|
| 268 |
+
self.pin_memory = pin_memory
|
| 269 |
+
self.pin_memory_device = pin_memory_device
|
| 270 |
+
self.timeout = timeout
|
| 271 |
+
self.worker_init_fn = worker_init_fn
|
| 272 |
+
self.multiprocessing_context = multiprocessing_context
|
| 273 |
+
|
| 274 |
+
# Adds forward compatibilities so classic DataLoader can work with DataPipes:
|
| 275 |
+
# _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
|
| 276 |
+
if isinstance(self.dataset, IterDataPipe):
|
| 277 |
+
self.dataset = _IterDataPipeSerializationWrapper(self.dataset)
|
| 278 |
+
elif isinstance(self.dataset, MapDataPipe):
|
| 279 |
+
self.dataset = _MapDataPipeSerializationWrapper(self.dataset)
|
| 280 |
+
|
| 281 |
+
# Arg-check dataset related before checking samplers because we want to
|
| 282 |
+
# tell users that iterable-style datasets are incompatible with custom
|
| 283 |
+
# samplers first, so that they don't learn that this combo doesn't work
|
| 284 |
+
# after spending time fixing the custom sampler errors.
|
| 285 |
+
if isinstance(dataset, IterableDataset):
|
| 286 |
+
self._dataset_kind = _DatasetKind.Iterable
|
| 287 |
+
# NOTE [ Custom Samplers and IterableDataset ]
|
| 288 |
+
#
|
| 289 |
+
# `IterableDataset` does not support custom `batch_sampler` or
|
| 290 |
+
# `sampler` since the key is irrelevant (unless we support
|
| 291 |
+
# generator-style dataset one day...).
|
| 292 |
+
#
|
| 293 |
+
# For `sampler`, we always create a dummy sampler. This is an
|
| 294 |
+
# infinite sampler even when the dataset may have an implemented
|
| 295 |
+
# finite `__len__` because in multi-process data loading, naive
|
| 296 |
+
# settings will return duplicated data (which may be desired), and
|
| 297 |
+
# thus using a sampler with length matching that of dataset will
|
| 298 |
+
# cause data lost (you may have duplicates of the first couple
|
| 299 |
+
# batches, but never see anything afterwards). Therefore,
|
| 300 |
+
# `Iterabledataset` always uses an infinite sampler, an instance of
|
| 301 |
+
# `_InfiniteConstantSampler` defined above.
|
| 302 |
+
#
|
| 303 |
+
# A custom `batch_sampler` essentially only controls the batch size.
|
| 304 |
+
# However, it is unclear how useful it would be since an iterable-style
|
| 305 |
+
# dataset can handle that within itself. Moreover, it is pointless
|
| 306 |
+
# in multi-process data loading as the assignment order of batches
|
| 307 |
+
# to workers is an implementation detail so users can not control
|
| 308 |
+
# how to batchify each worker's iterable. Thus, we disable this
|
| 309 |
+
# option. If this turns out to be useful in future, we can re-enable
|
| 310 |
+
# this, and support custom samplers that specify the assignments to
|
| 311 |
+
# specific workers.
|
| 312 |
+
if isinstance(dataset, IterDataPipe):
|
| 313 |
+
if shuffle is not None:
|
| 314 |
+
dataset = torch.utils.data.graph_settings.apply_shuffle_settings(
|
| 315 |
+
dataset, shuffle=shuffle
|
| 316 |
+
)
|
| 317 |
+
# We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default.
|
| 318 |
+
elif shuffle not in {False, None}:
|
| 319 |
+
raise ValueError(
|
| 320 |
+
"DataLoader with IterableDataset: expected unspecified "
|
| 321 |
+
"shuffle option, but got shuffle={}".format(shuffle)
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
if sampler is not None:
|
| 325 |
+
# See NOTE [ Custom Samplers and IterableDataset ]
|
| 326 |
+
raise ValueError(
|
| 327 |
+
"DataLoader with IterableDataset: expected unspecified "
|
| 328 |
+
"sampler option, but got sampler={}".format(sampler)
|
| 329 |
+
)
|
| 330 |
+
elif batch_sampler is not None:
|
| 331 |
+
# See NOTE [ Custom Samplers and IterableDataset ]
|
| 332 |
+
raise ValueError(
|
| 333 |
+
"DataLoader with IterableDataset: expected unspecified "
|
| 334 |
+
"batch_sampler option, but got batch_sampler={}".format(
|
| 335 |
+
batch_sampler
|
| 336 |
+
)
|
| 337 |
+
)
|
| 338 |
+
else:
|
| 339 |
+
shuffle = bool(shuffle)
|
| 340 |
+
self._dataset_kind = _DatasetKind.Map
|
| 341 |
+
|
| 342 |
+
if sampler is not None and shuffle:
|
| 343 |
+
raise ValueError("sampler option is mutually exclusive with " "shuffle")
|
| 344 |
+
|
| 345 |
+
if batch_sampler is not None:
|
| 346 |
+
# auto_collation with custom batch_sampler
|
| 347 |
+
if batch_size != 1 or shuffle or sampler is not None or drop_last:
|
| 348 |
+
raise ValueError(
|
| 349 |
+
"batch_sampler option is mutually exclusive "
|
| 350 |
+
"with batch_size, shuffle, sampler, and "
|
| 351 |
+
"drop_last"
|
| 352 |
+
)
|
| 353 |
+
batch_size = None
|
| 354 |
+
drop_last = False
|
| 355 |
+
elif batch_size is None:
|
| 356 |
+
# no auto_collation
|
| 357 |
+
if drop_last:
|
| 358 |
+
raise ValueError(
|
| 359 |
+
"batch_size=None option disables auto-batching "
|
| 360 |
+
"and is mutually exclusive with drop_last"
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
if sampler is None: # give default samplers
|
| 364 |
+
if self._dataset_kind == _DatasetKind.Iterable:
|
| 365 |
+
# See NOTE [ Custom Samplers and IterableDataset ]
|
| 366 |
+
sampler = _InfiniteConstantSampler()
|
| 367 |
+
else: # map-style
|
| 368 |
+
if shuffle:
|
| 369 |
+
sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type]
|
| 370 |
+
else:
|
| 371 |
+
sampler = SequentialSampler(dataset) # type: ignore[arg-type]
|
| 372 |
+
|
| 373 |
+
if batch_size is not None and batch_sampler is None:
|
| 374 |
+
# auto_collation without custom batch_sampler
|
| 375 |
+
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
|
| 376 |
+
|
| 377 |
+
self.batch_size = batch_size
|
| 378 |
+
self.drop_last = drop_last
|
| 379 |
+
self.sampler = sampler
|
| 380 |
+
self.batch_sampler = batch_sampler
|
| 381 |
+
self.generator = generator
|
| 382 |
+
|
| 383 |
+
if collate_fn is None:
|
| 384 |
+
if self._auto_collation:
|
| 385 |
+
collate_fn = _utils.collate.default_collate
|
| 386 |
+
else:
|
| 387 |
+
collate_fn = _utils.collate.default_convert
|
| 388 |
+
|
| 389 |
+
self.collate_fn = collate_fn
|
| 390 |
+
self.persistent_workers = persistent_workers
|
| 391 |
+
|
| 392 |
+
self.__initialized = True
|
| 393 |
+
self._IterableDataset_len_called = (
|
| 394 |
+
None # See NOTE [ IterableDataset and __len__ ]
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
self._iterator = None
|
| 398 |
+
|
| 399 |
+
self.check_worker_number_rationality()
|
| 400 |
+
|
| 401 |
+
torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined]
|
| 402 |
+
|
| 403 |
+
def _get_iterator(self) -> "_BaseDataLoaderIter":
|
| 404 |
+
if self.num_workers == 0:
|
| 405 |
+
return _SingleProcessDataLoaderIter(self)
|
| 406 |
+
else:
|
| 407 |
+
self.check_worker_number_rationality()
|
| 408 |
+
return _MultiProcessingDataLoaderIter(self)
|
| 409 |
+
|
| 410 |
+
@property
|
| 411 |
+
def multiprocessing_context(self):
|
| 412 |
+
return self.__multiprocessing_context
|
| 413 |
+
|
| 414 |
+
@multiprocessing_context.setter
|
| 415 |
+
def multiprocessing_context(self, multiprocessing_context):
|
| 416 |
+
if multiprocessing_context is not None:
|
| 417 |
+
if self.num_workers > 0:
|
| 418 |
+
if isinstance(multiprocessing_context, str):
|
| 419 |
+
valid_start_methods = multiprocessing.get_all_start_methods()
|
| 420 |
+
if multiprocessing_context not in valid_start_methods:
|
| 421 |
+
raise ValueError(
|
| 422 |
+
(
|
| 423 |
+
"multiprocessing_context option "
|
| 424 |
+
"should specify a valid start method in {!r}, but got "
|
| 425 |
+
"multiprocessing_context={!r}"
|
| 426 |
+
).format(valid_start_methods, multiprocessing_context)
|
| 427 |
+
)
|
| 428 |
+
                    multiprocessing_context = multiprocessing.get_context(
                        multiprocessing_context
                    )

                if not isinstance(
                    multiprocessing_context, python_multiprocessing.context.BaseContext
                ):
                    raise TypeError(
                        (
                            "multiprocessing_context option should be a valid context "
                            "object or a string specifying the start method, but got "
                            "multiprocessing_context={}"
                        ).format(multiprocessing_context)
                    )
            else:
                raise ValueError(
                    (
                        "multiprocessing_context can only be used with "
                        "multi-process loading (num_workers > 0), but got "
                        "num_workers={}"
                    ).format(self.num_workers)
                )

        self.__multiprocessing_context = multiprocessing_context

    def __setattr__(self, attr, val):
        if self.__initialized and attr in (
            "batch_size",
            "batch_sampler",
            "sampler",
            "drop_last",
            "dataset",
            "persistent_workers",
        ):
            raise ValueError(
                "{} attribute should not be set after {} is "
                "initialized".format(attr, self.__class__.__name__)
            )

        super().__setattr__(attr, val)

    # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
    # since '_BaseDataLoaderIter' references 'DataLoader'.
    def __iter__(self) -> "_BaseDataLoaderIter":
        # When using a single worker the returned iterator should be
        # created everytime to avoid reseting its state
        # However, in the case of a multiple workers iterator
        # the iterator is only created once in the lifetime of the
        # DataLoader object so that workers can be reused
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()

    @property
    def _auto_collation(self):
        return self.batch_sampler is not None

    @property
    def _index_sampler(self):
        # The actual sampler used for generating indices for `_DatasetFetcher`
        # (see _utils/fetch.py) to read data at each time. This would be
        # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
        # We can't change `.sampler` and `.batch_sampler` attributes for BC
        # reasons.
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler

    def __len__(self) -> int:
        if self._dataset_kind == _DatasetKind.Iterable:
            # NOTE [ IterableDataset and __len__ ]
            #
            # For `IterableDataset`, `__len__` could be inaccurate when one naively
            # does multi-processing data loading, since the samples will be duplicated.
            # However, no real use case should be actually using that behavior, so
            # it should count as a user error. We should generally trust user
            # code to do the proper thing (e.g., configure each replica differently
            # in `__iter__`), and give us the correct `__len__` if they choose to
            # implement it (this will still throw if the dataset does not implement
            # a `__len__`).
            #
            # To provide a further warning, we track if `__len__` was called on the
            # `DataLoader`, save the returned value in `self._len_called`, and warn
            # if the iterator ends up yielding more than this number of samples.

            # Cannot statically verify that dataset is Sized
            length = self._IterableDataset_len_called = len(self.dataset)  # type: ignore[assignment, arg-type]
            if (
                self.batch_size is not None
            ):  # IterableDataset doesn't allow custom sampler or batch_sampler
                from math import ceil

                if self.drop_last:
                    length = length // self.batch_size
                else:
                    length = ceil(length / self.batch_size)
            return length
        else:
            return len(self._index_sampler)

    def check_worker_number_rationality(self):
        # This function check whether the dataloader's worker number is rational based on
        # current system's resource. Current rule is that if the number of workers this
        # Dataloader will create is bigger than the number of logical cpus that is allowed to
        # use, than we will pop up a warning to let user pay attention.
        #
        # eg. If current system has 2 physical CPUs with 16 cores each. And each core support 2
        #     threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current
        #     DataLoader process can use half of them which is 32, then the rational max number of
        #     worker that initiated from this process is 32.
        #     Now, let's say the created DataLoader has num_works = 40, which is bigger than 32.
        #     So the warning message is triggered to notify the user to lower the worker number if
        #     necessary.
        #
        #
        # [Note] Please note that this function repects `cpuset` only when os.sched_getaffinity is
        #        available (available in most of Linux system, but not OSX and Windows).
        #        When os.sched_getaffinity is not available, os.cpu_count() is called instead, but
        #        it doesn't repect cpuset.
        #        We don't take threading into account since each worker process is single threaded
        #        at this time.
        #
        #        We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc)
        #        other than `torch.set_num_threads` to 1 in the worker process, if the passing
        #        in functions use 3rd party modules that rely on those threading flags to determine
        #        how many thread to create (eg. numpy, etc), then it is caller's responsibility to
        #        set those flags correctly.
        def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):

            suggested_max_worker_msg = (
                (
                    (
                        "Our suggested max number of worker in current system is {}{}, which is smaller "
                        "than what this DataLoader is going to create."
                    ).format(
                        num_worker_suggest,
                        (
                            ""
                            if cpuset_checked
                            else " (`cpuset` is not taken into account)"
                        ),
                    )
                )
                if num_worker_suggest is not None
                else (
                    "DataLoader is not able to compute a suggested max number of worker in current system."
                )
            )

            warn_msg = (
                "This DataLoader will create {} worker processes in total. {} "
                "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, "
                "lower the worker number to avoid potential slowness/freeze if necessary."
            ).format(num_worker_created, suggested_max_worker_msg)
            return warn_msg

        if not self.num_workers or self.num_workers == 0:
            return

        # try to compute a suggested max number of worker based on system's resource
        max_num_worker_suggest = None
        cpuset_checked = False
        if hasattr(os, "sched_getaffinity"):
            try:
                max_num_worker_suggest = len(os.sched_getaffinity(0))
                cpuset_checked = True
            except Exception:
                pass
        if max_num_worker_suggest is None:
            # os.cpu_count() could return Optional[int]
            # get cpu count first and check None in order to satify mypy check
            cpu_count = os.cpu_count()
            if cpu_count is not None:
                max_num_worker_suggest = cpu_count

        if max_num_worker_suggest is None:
            warnings.warn(
                _create_warning_msg(
                    max_num_worker_suggest, self.num_workers, cpuset_checked
                )
            )
            return

        if self.num_workers > max_num_worker_suggest:
            warnings.warn(
                _create_warning_msg(
                    max_num_worker_suggest, self.num_workers, cpuset_checked
                )
            )

class _BaseDataLoaderIter:
    def __init__(self, loader: RRSDataLoader) -> None:
        self._dataset = loader.dataset
        self._shared_seed = None
        self._pg = None
        if isinstance(self._dataset, IterDataPipe):
            if dist.is_available() and dist.is_initialized():
                self._pg = dist.new_group(backend="gloo")
            self._shared_seed = _share_dist_seed(loader.generator, self._pg)
            shared_rng = torch.Generator()
            shared_rng.manual_seed(self._shared_seed)
            self._dataset = torch.utils.data.graph_settings.apply_random_seed(
                self._dataset, shared_rng
            )
        self._dataset_kind = loader._dataset_kind
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler
        self._num_workers = loader.num_workers
        ws, rank = _get_distributed_settings()
        self._world_size = ws
        self._rank = rank
        # for other backends, pin_memory_device need to set. if not set
        # default behaviour is CUDA device. if pin_memory_device is selected
        # and pin_memory is not set, the default behaviour false.
        if len(loader.pin_memory_device) == 0:
            self._pin_memory = loader.pin_memory and torch.cuda.is_available()
            self._pin_memory_device = None
        else:
            if not loader.pin_memory:
                warn_msg = (
                    "pin memory device is set and pin_memory flag is not used then device pinned memory won't be used"
                    "please set pin_memory to true, if you need to use the device pin memory"
                )
                warnings.warn(warn_msg)

            self._pin_memory = loader.pin_memory
            self._pin_memory_device = loader.pin_memory_device
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler)
        self._base_seed = (
            torch.empty((), dtype=torch.int64)
            .random_(generator=loader.generator)
            .item()
        )
        self._persistent_workers = loader.persistent_workers
        self._num_yielded = 0
        self._profile_name = "enumerate(DataLoader)#{}.__next__".format(
            self.__class__.__name__
        )

    def __iter__(self) -> "_BaseDataLoaderIter":
        return self

    def _reset(self, loader, first_iter=False):
        self._sampler_iter = iter(self._index_sampler)
        self._num_yielded = 0
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        if isinstance(self._dataset, IterDataPipe):
            self._shared_seed = _share_dist_seed(loader.generator, self._pg)
            shared_rng = torch.Generator()
            shared_rng.manual_seed(self._shared_seed)
            self._dataset = torch.utils.data.graph_settings.apply_random_seed(
                self._dataset, shared_rng
            )

    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

    def _next_data(self):
        raise NotImplementedError

    def __next__(self) -> Any:
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                self._reset()  # type: ignore[call-arg]
            data = self._next_data()
            self._num_yielded += 1
            if (
                self._dataset_kind == _DatasetKind.Iterable
                and self._IterableDataset_len_called is not None
                and self._num_yielded > self._IterableDataset_len_called
            ):
                warn_msg = (
                    "Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                    "samples have been fetched. "
                ).format(
                    self._dataset, self._IterableDataset_len_called, self._num_yielded
                )
                if self._num_workers > 0:
                    warn_msg += (
                        "For multiprocessing data-loading, this could be caused by not properly configuring the "
                        "IterableDataset replica at each worker. Please see "
                        "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples."
                    )
                warnings.warn(warn_msg)
            return data

    def __len__(self) -> int:
        return len(self._index_sampler)

    def __getstate__(self):
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)


class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super().__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
        #   Taking care of distributed sharding
        if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
            # For BC, use default SHARDING_PRIORITIES
            torch.utils.data.graph_settings.apply_sharding(
                self._dataset, self._world_size, self._rank
            )

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind,
            self._dataset,
            self._auto_collation,
            self._collate_fn,
            self._drop_last,
        )

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
        return data


class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""

    # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
    #
    # Preliminary:
    #
    # Our data model looks like this (queues are indicated with curly brackets):
    #
    #                main process                              ||
    #                     |                                    ||
    #               {index_queue}                              ||
    #                     |                                    ||
    #              worker processes                            ||     DATA
    #                     |                                    ||
    #            {worker_result_queue}                         ||     FLOW
    #                     |                                    ||
    #      pin_memory_thread of main process                   ||   DIRECTION
    #                     |                                    ||
    #               {data_queue}                               ||
    #                     |                                    ||
    #                data output                               \/
    #
    # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
    #      `pin_memory=False`.
    #
    #
    # Terminating multiprocessing logic requires very careful design. In
    # particular, we need to make sure that
    #
    #   1. The iterator gracefully exits the workers when its last reference is
    #      gone or it is depleted.
    #
    #      In this case, the workers should be gracefully exited because the
    #      main process may still need to continue to run, and we want cleaning
    #      up code in the workers to be executed (e.g., releasing GPU memory).
    #      Naturally, we implement the shutdown logic in `__del__` of
    #      DataLoaderIterator.
    #
    #      We delay the discussion on the logic in this case until later.
    #
    #   2. The iterator exits the workers when the loader process and/or worker
    #      processes exits normally or with error.
    #
    #      We set all workers and `pin_memory_thread` to have `daemon=True`.
    #
    #      You may ask, why can't we make the workers non-daemonic, and
    #      gracefully exit using the same logic as we have in `__del__` when the
    #      iterator gets deleted (see 1 above)?
    #
    #      First of all, `__del__` is **not** guaranteed to be called when
    #      interpreter exits. Even if it is called, by the time it executes,
    #      many Python core library resources may alreay be freed, and even
    #      simple things like acquiring an internal lock of a queue may hang.
    #      Therefore, in this case, we actually need to prevent `__del__` from
    #      being executed, and rely on the automatic termination of daemonic
    #      children.
    #
    #      Thus, we register an `atexit` hook that sets a global flag
    #      `_utils.python_exit_status`. Since `atexit` hooks are executed in the
    #      reverse order of registration, we are guaranteed that this flag is
    #      set before library resources we use are freed (which, at least in
    #      CPython, is done via an `atexit` handler defined in
    #      `multiprocessing/util.py`
    #      https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362
    #      registered when an object requiring this mechanism is first
    #      created, e.g., `mp.Queue`
    #      https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103
    #      https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29
    #      )
    #
    #      So in `__del__`, we check if `_utils.python_exit_status` is set or
    #      `None` (freed), and perform no-op if so.
    #
    #      However, simply letting library clean-up codes run can also be bad,
    #      because such codes (i.e., `multiprocessing.util._exit_function()`)
    #      include join putting threads for `mp.Queue`, which can be blocking.
    #      Hence, the main process putting threads are called with
    #      `cancel_join_thread` at creation. See later section
    #      [ 3b. A process won't hang when putting into a queue; ]
    #      for more details.
    #
    #      Here are two example cases where library clean-up codes can run
    #      before `__del__` is called:
    #
    #        1. If we hold onto a reference to the iterator, it more often
    #           than not tries to do `multiprocessing` library cleaning before
    #           clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666)
    #           and thus prevents our cleaning-up code to run first.
    #
    #        2. A similar issue araises when a `DataLoader` is used in a subprocess.
    #           When a process ends, it shuts the all its daemonic children
    #           down with a SIGTERM (instead of joining them without a timeout).
    #           Simiarly for threads, but by a different mechanism. This fact,
    #           together with a few implementation details of multiprocessing, forces
    #           us to make workers daemonic. All of our problems arise when a
    #           DataLoader is used in a subprocess, and are caused by multiprocessing
    #           code which looks more or less like this:
    #
    #               try:
    #                   your_function_using_a_dataloader()
    #               finally:
    #                   multiprocessing.util._exit_function()
    #
    #           The joining/termination mentioned above happens inside
    #           `_exit_function()`. Now, if `your_function_using_a_dataloader()`
    #           throws, the stack trace stored in the exception will prevent the
    #           frame which uses `DataLoaderIter` to be freed. If the frame has any
    #           reference to the `DataLoaderIter` (e.g., in a method of the iter),
    #           its `__del__`, which starts the shutdown procedure, will not be
    #           called. That, in turn, means that workers aren't notified. Attempting
    #           to join in `_exit_function` will then result in a hang.
    #
    #           For context, `_exit_function` is also registered as an `atexit` call.
    #           So it is unclear to me (@ssnl) why this is needed in a finally block.
    #           The code dates back to 2008 and there is no comment on the original
    #           PEP 371 or patch https://bugs.python.org/issue3050 (containing both
    #           the finally block and the `atexit` registration) that explains this.
    #
    #
    #      Finally, another choice is to just shutdown workers with logic in 1
    #      above whenever we see an error in `next`. This isn't ideal because
    #        a. It prevents users from using try-catch to resume data loading.
    #        b. It doesn't prevent hanging if users have references to the
    #           iterator.
    #
    #   3. All processes exit if any of them die unexpectedly by fatal signals.
    #
    #      As shown above, the workers are set as daemonic children of the main
    #      process. However, automatic cleaning-up of such child processes only
    #      happens if the parent process exits gracefully (e.g., not via fatal
    #      signals like SIGKILL). So we must ensure that each process will exit
    #      even the process that should send/receive data to/from it were
    #      killed, i.e.,
    #
    #        a. A process won't hang when getting from a queue.
    #
    #           Even with carefully designed data dependencies (i.e., a `put()`
    #           always corresponding to a `get()`), hanging on `get()` can still
    #           happen when data in queue is corrupted (e.g., due to
    #           `cancel_join_thread` or unexpected exit).
    #
    #           For child exit, we set a timeout whenever we try to get data
    #           from `data_queue`, and check the workers' status on each timeout
    #           and error.
    #           See `_DataLoaderiter._get_batch()` and
    #           `_DataLoaderiter._try_get_data()` for details.
    #
    #           Additionally, for child exit on non-Windows platforms, we also
    #           register a SIGCHLD handler (which is supported on Windows) on
    #           the main process, which checks if any of the workers fail in the
    #           (Python) handler. This is more efficient and faster in detecting
    #           worker failures, compared to only using the above mechanism.
    #           See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
    #
    #           For `.get()` calls where the sender(s) is not the workers, we
    #           guard them with timeouts, and check the status of the sender
    #           when timeout happens:
    #             + in the workers, the `_utils.worker.ManagerWatchdog` class
    #               checks the status of the main process.
    #             + if `pin_memory=True`, when getting from `pin_memory_thread`,
    #               check `pin_memory_thread` status periodically until `.get()`
    #               returns or see that `pin_memory_thread` died.
    #
    #        b. A process won't hang when putting into a queue;
    #
    #           We use `mp.Queue` which has a separate background thread to put
    #           objects from an unbounded buffer array. The background thread is
    #           daemonic and usually automatically joined when the process
    #           *exits*.
    #
    #           In case that the receiver has ended abruptly while
    #           reading from the pipe, the join will hang forever. The usual
    #           solution for this in Python is calling `q.cancel_join_thread`,
    #           which prevents automatically joining it when finalizing
    #           (exiting).
    #
    #           Nonetheless, `cancel_join_thread` must only be called when the
    #           queue is **not** going to be read from or write into by another
    #           process, because it may hold onto a lock or leave corrupted data
    #           in the queue, leading other readers/writers to hang.
    #
    #           Hence,
    #             + For worker processes, we only do so (for their output
    #               queues, i.e., `worker_result_queue`) before exiting.
    #             + For `pin_memory_thread`, its output queue `data_queue` is a
    #               `queue.Queue` that does blocking `put` if the queue is full.
    #               So there is no above problem, but as a result, in
    #               `_pin_memory_loop`, we do need to wrap the `put` in a loop
    #               that breaks not only upon success, but also when the main
    #               process stops reading, i.e., is shutting down.
    #             + For loader process, we `cancel_join_thread()` for all
    #               `_index_queues` because the whole purpose of workers and
    #               `pin_memory_thread` is to serve the loader process. If
    #               loader process is already exiting, we don't really care if
    #               the queues are corrupted.
    #
    #
    # Now let's get back to 1:
    #   how we gracefully exit the workers when the last reference to the
    #   iterator is gone.
    #
    # To achieve this, we implement the following logic along with the design
    # choices mentioned above:
    #
    # `workers_done_event`:
    #   A `multiprocessing.Event` shared among the main process and all worker
    #   processes. This is used to signal the workers that the iterator is
    #   shutting down. After it is set, they will not send processed data to
    #   queues anymore, and only wait for the final `None` before exiting.
    #   `done_event` isn't strictly needed. I.e., we can just check for `None`
    #   from the input queue, but it allows us to skip wasting resources
    #   processing data if we are already shutting down.
    #
    # `pin_memory_thread_done_event`:
    #   A `threading.Event` for a similar purpose to that of
    #   `workers_done_event`, but is for the `pin_memory_thread`. The reason
    #   that separate events are needed is that `pin_memory_thread` reads from
    #   the output queue of the workers. But the workers, upon seeing that
    #   `workers_done_event` is set, only wants to see the final `None`, and is
    #   not required to flush all data in the output queue (e.g., it may call
    #   `cancel_join_thread` on that queue if its `IterableDataset` iterator
    #   happens to exhaust coincidentally, which is out of the control of the
    #   main process). Thus, since we will exit `pin_memory_thread` before the
    #   workers (see below), two separete events are used.
    #
    # NOTE: In short, the protocol is that the main process will set these
    #       `done_event`s and then the corresponding processes/threads a `None`,
    #       and that they may exit at any time after receiving the `None`.
    #
    # NOTE: Using `None` as the final signal is valid, since normal data will
    #       always be a 2-tuple with the 1st element being the index of the data
    #       transferred (different from dataset index/key), and the 2nd being
    #       either the dataset key or the data sample (depending on which part
    #       of the data model the queue is at).
    #
    # [ worker processes ]
    #   While loader process is alive:
    #     Get from `index_queue`.
    #       If get anything else,
    #          Check `workers_done_event`.
    #            If set, continue to next iteration
    #                    i.e., keep getting until see the `None`, then exit.
    #            Otherwise, process data:
    #                If is fetching from an `IterableDataset` and the iterator
    #                    is exhausted, send an `_IterableDatasetStopIteration`
    #                    object to signal iteration end. The main process, upon
    #                    receiving such an object, will send `None` to this
    #                    worker and not use the corresponding `index_queue`
    #                    anymore.
    #       If timed out,
    #          No matter `workers_done_event` is set (still need to see `None`)
    #          or not, must continue to next iteration.
    #   (outside loop)
    #   If `workers_done_event` is set,  (this can be False with `IterableDataset`)
    #     `data_queue.cancel_join_thread()`.  (Everything is ending here:
    #                                          main process won't read from it;
    #                                          other workers will also call
    #                                          `cancel_join_thread`.)
    #
    # [ pin_memory_thread ]
    #   # No need to check main thread. If this thread is alive, the main loader
    #   # thread must be alive, because this thread is set as daemonic.
    #   While `pin_memory_thread_done_event` is not set:
    #     Get from `index_queue`.
    #       If timed out, continue to get in the next iteration.
    #       Otherwise, process data.
    #       While `pin_memory_thread_done_event` is not set:
    #         Put processed data to `data_queue` (a `queue.Queue` with blocking put)
    #         If timed out, continue to put in the next iteration.
    #         Otherwise, break, i.e., continuing to the out loop.
    #
    #   NOTE: we don't check the status of the main thread because
    #           1. if the process is killed by fatal signal, `pin_memory_thread`
    #              ends.
    #           2. in other cases, either the cleaning-up in __del__ or the
    #              automatic exit of daemonic thread will take care of it.
    #              This won't busy-wait either because `.get(timeout)` does not
    #              busy-wait.
    #
    # [ main process ]
    #   In the DataLoader Iter's `__del__`
    #     b. Exit `pin_memory_thread`
    #          i.   Set `pin_memory_thread_done_event`.
    #          ii   Put `None` in `worker_result_queue`.
    #          iii. Join the `pin_memory_thread`.
    #          iv.  `worker_result_queue.cancel_join_thread()`.
    #
    #     c. Exit the workers.
    #          i.   Set `workers_done_event`.
    #          ii.  Put `None` in each worker's `index_queue`.
    #          iii. Join the workers.
    #          iv.  Call `.cancel_join_thread()` on each worker's `index_queue`.
    #
    #        NOTE: (c) is better placed after (b) because it may leave corrupted
    #              data in `worker_result_queue`, which `pin_memory_thread`
    #              reads from, in which case the `pin_memory_thread` can only
    #              happen at timeing out, which is slow. Nonetheless, same thing
    #              happens if a worker is killed by signal at unfortunate times,
    #              but in other cases, we are better off having a non-corrupted
    #              `worker_result_queue` for `pin_memory_thread`.
    #
    #   NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
    #         can be omitted
    #
    # NB: `done_event`s isn't strictly needed. E.g., we can just check for
    #     `None` from `index_queue`, but it allows us to skip wasting resources
    #     processing indices already in `index_queue` if we are already shutting
    #     down.

    def __init__(self, loader):
        super().__init__(loader)

        self._prefetch_factor = loader.prefetch_factor

        assert self._num_workers > 0
        assert self._prefetch_factor > 0

        if loader.multiprocessing_context is None:
            multiprocessing_context = multiprocessing
        else:
            multiprocessing_context = loader.multiprocessing_context

        self._worker_init_fn = loader.worker_init_fn

        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
        #   Additional worker init function will take care of sharding in MP and Distributed
        if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
            self._worker_init_fn = functools.partial(
                _sharding_worker_init_fn,
                self._worker_init_fn,
                self._world_size,
                self._rank,
            )

        # No certainty which module multiprocessing_context is
        self._worker_result_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
        self._worker_pids_set = False
        self._shutdown = False
        self._workers_done_event = multiprocessing_context.Event()

        self._index_queues = []
        self._workers = []
        for i in range(self._num_workers):
            # No certainty which module multiprocessing_context is
            index_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
            # Need to `cancel_join_thread` here!
            # See sections (2) and (3b) above.
            index_queue.cancel_join_thread()
            w = multiprocessing_context.Process(
                target=_worker_loop,
                args=(
                    self._dataset_kind,
                    self._dataset,
                    index_queue,
                    self._worker_result_queue,
                    self._workers_done_event,
                    self._auto_collation,
                    self._collate_fn,
                    self._drop_last,
                    self._base_seed,
                    self._worker_init_fn,
                    i,
                    self._num_workers,
                    self._persistent_workers,
                    self._shared_seed,
                ),
            )
            w.daemon = True
            # NB: Process.start() actually take some time as it needs to
            #     start a process and pass the arguments over via a pipe.
            #     Therefore, we only add a worker to self._workers list after
            #     it started, so that we do not call .join() if program dies
            #     before it starts, and __del__ tries to join but will get:
            #     AssertionError: can only join a started process.
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)

        if self._pin_memory:
            self._pin_memory_thread_done_event = threading.Event()

            # Queue is not type-annotated
            self._data_queue = queue.Queue()  # type: ignore[var-annotated]
            if self._pin_memory_device == "xpu":
                current_device = torch.xpu.current_device()  # type: ignore[attr-defined]
            else:
                current_device = torch.cuda.current_device()  # choose cuda for default
            pin_memory_thread = threading.Thread(
                target=_utils.pin_memory._pin_memory_loop,
                args=(
                    self._worker_result_queue,
                    self._data_queue,
                    current_device,
                    self._pin_memory_thread_done_event,
                    self._pin_memory_device,
                ),
            )
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register
            # pin_memory_thread once it is started.
            self._pin_memory_thread = pin_memory_thread
        else:
            self._data_queue = self._worker_result_queue

        # In some rare cases, persistent workers (daemonic processes)
        # would be terminated before `__del__` of iterator is invoked
        # when main process exits
        # It would cause failure when pin_memory_thread tries to read
        # corrupted data from worker_result_queue
        # atexit is used to shutdown thread and child processes in the
        # right sequence before main process exits
        if self._persistent_workers and self._pin_memory:
            import atexit

            for w in self._workers:
                atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)

        # .pid can be None only before process is spawned (not the case, so ignore)
        _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))  # type: ignore[misc]
        _utils.signal_handling._set_SIGCHLD_handler()
        self._worker_pids_set = True
        self._reset(loader, first_iter=True)

    def _reset(self, loader, first_iter=False):
        super()._reset(loader, first_iter)
        self._send_idx = 0  # idx of the next task to be sent to workers
        self._rcvd_idx = 0  # idx of the next task to be returned in __next__
        # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
        # map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
        #                  \ (worker_id, data)   if data is already fetched (out-of-order)
        self._task_info = {}
        self._tasks_outstanding = (
            0  # always equal to count(v for v in task_info.values() if len(v) == 1)
        )
        # A list of booleans representing whether each worker still has work to
        # do, i.e., not having exhausted its iterable dataset object. It always
        # contains all `True`s if not using an iterable-style dataset
        # (i.e., if kind != Iterable).
        # Not that this indicates that a worker still has work to do *for this epoch*.
        # It does not mean that a worker is dead. In case of `_persistent_workers`,
        # the worker will be reset to available in the next epoch.
        self._workers_status = [True for i in range(self._num_workers)]
        # Reset the worker queue cycle so it resumes next epoch at worker 0
        self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
        # We resume the prefetching in case it was enabled
        if not first_iter:
            for idx in range(self._num_workers):
                self._index_queues[idx].put(
                    _utils.worker._ResumeIteration(self._shared_seed)
                )
            resume_iteration_cnt = self._num_workers
            while resume_iteration_cnt > 0:
                return_idx, return_data = self._get_data()
                if isinstance(return_idx, _utils.worker._ResumeIteration):
                    assert return_data is None
                    resume_iteration_cnt -= 1
        # prime the prefetch loop
        for _ in range(self._prefetch_factor * self._num_workers):
            self._try_put_index()

    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        # Tries to fetch data from `self._data_queue` once for a given timeout.
        # This can also be used as inner loop of fetching without timeout, with
        # the sender status as the loop condition.
        #
        # This raises a `RuntimeError` if any worker died expectedly. This error
        # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
        # (only for non-Windows platforms), or the manual check below on errors
        # and timeouts.
        #
        # Returns a 2-tuple:
        #   (bool: whether successfully get data, any: data if successful else None)
        try:
            data = self._data_queue.get(timeout=timeout)
            return (True, data)
        except Exception as e:
            # At timeout and error, we manually check whether any worker has
            # failed. Note that this is the only mechanism for Windows to detect
            # worker failures.
            failed_workers = []
            for worker_id, w in enumerate(self._workers):
                if self._workers_status[worker_id] and not w.is_alive():
                    failed_workers.append(w)
                    self._mark_worker_as_unavailable(worker_id)
            if len(failed_workers) > 0:
                pids_str = ", ".join(str(w.pid) for w in failed_workers)
                raise RuntimeError(
                    "DataLoader worker (pid(s) {}) exited unexpectedly".format(pids_str)
                ) from e
            if isinstance(e, queue.Empty):
                return (False, None)
            import errno
            import tempfile

            try:
                # Raise an exception if we are this close to the FDs limit.
                # Apparently, trying to open only one file is not a sufficient
                # test.
                # See NOTE [ DataLoader on Linux and open files limit ]
                fds_limit_margin = 10
                fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
            except OSError as e:
                if e.errno == errno.EMFILE:
                    raise RuntimeError(
                        "Too many open files. Communication with the"
                        " workers is no longer possible. Please increase the"
                        " limit using `ulimit -n` in the shell or change the"
                        " sharing strategy by calling"
                        " `torch.multiprocessing.set_sharing_strategy('file_system')`"
                        " at the beginning of your code"
                    ) from None
            raise

    # NOTE [ DataLoader on Linux and open files limit ]
    #
    # On Linux when DataLoader is used with multiprocessing we pass the data between
    # the root process and the workers through SHM files. We remove those files from
    # the filesystem as soon as they are created and keep them alive by
    # passing around their file descriptors through AF_UNIX sockets. (See
    # docs/source/multiprocessing.rst and 'Multiprocessing Technical Notes` in
    # the wiki (https://github.com/pytorch/pytorch/wiki).)
    #
    # This sometimes leads us to exceeding the open files limit. When that happens,
    # and the offending file descriptor is coming over a socket, the `socket` Python
    # package silently strips the file descriptor from the message, setting only the
    # `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that
    # it _indicates that some control data were discarded due to lack of space in
    # the buffer for ancillary data_). This might reflect the C implementation of
    # AF_UNIX sockets.
    #
    # This behaviour can be reproduced with the script and instructions at the
    # bottom of this note.
    #
    # When that happens, the standard Python `multiprocessing` (and not
    # `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata`
    #
    # Sometimes, instead of the FD being stripped, you may get an `OSError:
    # Too many open files`, both in the script below and in DataLoader. However,
    # this is rare and seems to be nondeterministic.
    #
    #
    #   #!/usr/bin/env python3
    #   import sys
    #   import socket
    #   import os
    #   import array
    #   import shutil
    #   import socket
    #
    #
    #   if len(sys.argv) != 4:
    #       print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)")
    #       sys.exit(1)
    #
    #   if __name__ == '__main__':
    #       dirname = sys.argv[1]
    #       sock_path = dirname + "/sock"
    #       iterations = int(sys.argv[2])
    #       def dummy_path(i):
    #           return dirname + "/" + str(i) + ".dummy"
    #
    #
    #       if sys.argv[3] == 'send':
    #           while not os.path.exists(sock_path):
    #               pass
    #           client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
    #           client.connect(sock_path)
    #           for i in range(iterations):
    #               fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT)
    #               ancdata = array.array('i', [fd])
    #               msg = bytes([i % 256])
    #               print("Sending fd ", fd, " (iteration #", i, ")")
    #               client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)])
    #
    #
    #       else:
    #           assert sys.argv[3] == 'recv'
    #
    #           if os.path.exists(dirname):
    #               raise Exception("Directory exists")
    #
    #           os.mkdir(dirname)
    #
    #           print("Opening socket...")
    #           server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
    #           server.bind(sock_path)
    #
    #           print("Listening...")
    #           for i in range(iterations):
    #               a = array.array('i')
    #               msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize))
    #               assert(len(ancdata) == 1)
    #               cmsg_level, cmsg_type, cmsg_data = ancdata[0]
    #               a.frombytes(cmsg_data)
    #               print("Received fd ", a[0], " (iteration #", i, ")")
    #
    #           shutil.rmtree(dirname)
    #
    # Steps to reproduce:
    #
    # 1. Run two shells and set lower file descriptor limit in the receiving one:
    # (shell1) ulimit -n 1020
    # (shell2) ulimit -n 1022
    #
    # 2. Run the script above with the `recv` option in the first shell
    # (shell1) ./test_socket.py sock_tmp 1017 recv
    #
    # 3. Run the script with the `send` option in the second shell:
    # (shell2) ./test_socket.py sock_tmp 1017 send

    def _get_data(self):
        # Fetches data from `self._data_queue`.
        #
        # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
        # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
        # in a loop. This is the only mechanism to detect worker failures for
        # Windows. For other platforms, a SIGCHLD handler is also used for
        # worker failure detection.
        #
        # If `pin_memory=True`, we also need check if `pin_memory_thread` had
        # died at timeouts.
        if self._timeout > 0:
            success, data = self._try_get_data(self._timeout)
            if success:
                return data
            else:
                raise RuntimeError(
                    "DataLoader timed out after {} seconds".format(self._timeout)
                )
        elif self._pin_memory:
            while self._pin_memory_thread.is_alive():
                success, data = self._try_get_data()
                if success:
                    return data
            else:
                # while condition is false, i.e., pin_memory_thread died.
                raise RuntimeError("Pin memory thread exited unexpectedly")
            # In this case, `self._data_queue` is a `queue.Queue`,. But we don't
            # need to call `.task_done()` because we don't use `.join()`.
        else:
            while True:
                success, data = self._try_get_data()
                if success:
                    return data

    def _next_data(self):
        while True:
            # If the worker responsible for `self._rcvd_idx` has already ended
            # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
            # we try to advance `self._rcvd_idx` to find the next valid index.
            #
            # This part needs to run in the loop because both the `self._get_data()`
            # call and `_IterableDatasetStopIteration` check below can mark
            # extra worker(s) as dead.
            while self._rcvd_idx < self._send_idx:
                info = self._task_info[self._rcvd_idx]
                worker_id = info[0]
                if (
                    len(info) == 2 or self._workers_status[worker_id]
                ):  # has data or is still active
                    break
                del self._task_info[self._rcvd_idx]
                self._rcvd_idx += 1
            else:
                # no valid `self._rcvd_idx` is found (i.e., didn't break)
                if not self._persistent_workers:
                    self._shutdown_workers()
                raise StopIteration

            # Now `self._rcvd_idx` is the batch index we want to fetch

            # Check if the next sample has already been generated
            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data)

            assert not self._shutdown and self._tasks_outstanding > 0
            idx, data = self._get_data()
            self._tasks_outstanding -= 1
            if self._dataset_kind == _DatasetKind.Iterable:
                # Check for _IterableDatasetStopIteration
                if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                    if self._persistent_workers:
                        self._workers_status[data.worker_id] = False
                    else:
                        self._mark_worker_as_unavailable(data.worker_id)
                    self._try_put_index()
                    continue

            if idx != self._rcvd_idx:
                # store out-of-order samples
                self._task_info[idx] += (data,)
            else:
                del self._task_info[idx]
                return self._process_data(data)

    def _try_put_index(self):
        assert self._tasks_outstanding < self._prefetch_factor * self._num_workers

        try:
            index = self._next_index()
        except StopIteration:
            return
        for _ in range(self._num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
                break
        else:
            # not found (i.e., didn't break)
            return

        self._index_queues[worker_queue_idx].put((self._send_idx, index))
        self._task_info[self._send_idx] = (worker_queue_idx,)
        self._tasks_outstanding += 1
        self._send_idx += 1

    def _process_data(self, data):
        self._rcvd_idx += 1
        self._try_put_index()
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data

    def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
        # Mark a worker as having finished its work e.g., due to
        # exhausting an `IterableDataset`. This should be used only when this
        # `_MultiProcessingDataLoaderIter` is going to continue running.

        assert self._workers_status[worker_id] or (
            self._persistent_workers and shutdown
        )

        # Signal termination to that specific worker.
        q = self._index_queues[worker_id]
        # Indicate that no more data will be put on this queue by the current
        # process.
        q.put(None)

        # Note that we don't actually join the worker here, nor do we remove the
        # worker's pid from C side struct because (1) joining may be slow, and
        # (2) since we don't join, the worker may still raise error, and we
        # prefer capturing those, rather than ignoring them, even though they
        # are raised after the worker has finished its job.
        # Joinning is deferred to `_shutdown_workers`, which it is called when
        # all workers finish their jobs (e.g., `IterableDataset` replicas) or
        # when this iterator is garbage collected.

        self._workers_status[worker_id] = False

        assert self._workers_done_event.is_set() == shutdown

    def _shutdown_workers(self):
        # Called when shutting down this `_MultiProcessingDataLoaderIter`.
        # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
        # the logic of this function.
        if (
            _utils is None
            or _utils.python_exit_status is True
            or _utils.python_exit_status is None
        ):
            # See (2) of the note. If Python is shutting down, do no-op.
            return
        # Normal exit when last reference is gone / iterator is depleted.
        # See (1) and the second half of the note.
        if not self._shutdown:
            self._shutdown = True
            try:
                # Normal exit when last reference is gone / iterator is depleted.
                # See (1) and the second half of the note.

                # Exit `pin_memory_thread` first because exiting workers may leave
                # corrupted data in `worker_result_queue` which `pin_memory_thread`
                # reads from.
                if hasattr(self, "_pin_memory_thread"):
                    # Use hasattr in case error happens before we set the attribute.
                    self._pin_memory_thread_done_event.set()
                    # Send something to pin_memory_thread in case it is waiting
                    # so that it can wake up and check `pin_memory_thread_done_event`
                    self._worker_result_queue.put((None, None))
                    self._pin_memory_thread.join()
                    self._worker_result_queue.cancel_join_thread()
                    self._worker_result_queue.close()

                # Exit workers now.
                self._workers_done_event.set()
                for worker_id in range(len(self._workers)):
                    # Get number of workers from `len(self._workers)` instead of
                    # `self._num_workers` in case we error before starting all
                    # workers.
                    # If we are using workers_status with persistent_workers
                    # we have to shut it down because the worker is paused
                    if self._persistent_workers or self._workers_status[worker_id]:
                        self._mark_worker_as_unavailable(worker_id, shutdown=True)
                for w in self._workers:
                    # We should be able to join here, but in case anything went
                    # wrong, we set a timeout and if the workers fail to join,
                    # they are killed in the `finally` block.
                    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
                for q in self._index_queues:
                    q.cancel_join_thread()
                    q.close()
            finally:
                # Even though all this function does is putting into queues that
                # we have called `cancel_join_thread` on, weird things can
                # happen when a worker is killed by a signal, e.g., hanging in
                # `Event.set()`. So we need to guard this with SIGCHLD handler,
                # and remove pids from the C side data structure only at the
                # end.
                #
                if self._worker_pids_set:
                    _utils.signal_handling._remove_worker_pids(id(self))
                    self._worker_pids_set = False
                for w in self._workers:
                    if w.is_alive():
                        # Existing mechanisms try to make the workers exit
                        # peacefully, but in case that we unfortunately reach
                        # here, which we shouldn't, (e.g., pytorch/pytorch#39570),
                        # we kill the worker.
                        w.terminate()

    # staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter`
    @staticmethod
    def _clean_up_worker(w):
        try:
            w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
        finally:
            if w.is_alive():
                w.terminate()

    def __del__(self):
        self._shutdown_workers()
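
# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the torch-derived file above).
# A minimal, hedged usage example of the loader machinery defined in this
# module, assuming RRSDataLoader keeps torch.utils.data.DataLoader's
# constructor signature (which the attributes used above suggest).
# `_ToyDataset` is a hypothetical stand-in dataset introduced only for this
# sketch.
if __name__ == "__main__":
    import torch.utils.data as _data

    class _ToyDataset(_data.Dataset):
        def __len__(self):
            return 8

        def __getitem__(self, idx):
            # Return a small constant tensor so the example needs no real data.
            return torch.full((3, 32, 32), float(idx))

    _loader = RRSDataLoader(
        _ToyDataset(),
        batch_size=4,
        num_workers=2,            # exercises _MultiProcessingDataLoaderIter
        persistent_workers=True,  # iterator is reused across epochs (see __iter__)
        pin_memory=False,
    )
    for _epoch in range(2):
        for _batch in _loader:
            # Each batch is produced via _BaseDataLoaderIter.__next__ above.
            pass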
src/efficientvit/apps/data_provider/random_resolution/_data_worker.py
ADDED
@@ -0,0 +1,377 @@
r""""This file is based on torch/utils/data/_utils/worker.py
|
| 2 |
+
|
| 3 |
+
Contains definitions of the methods used by the _BaseDataLoaderIter workers.
|
| 4 |
+
These **needs** to be in global scope since Py2 doesn't support serializing
|
| 5 |
+
static methods.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import queue
|
| 10 |
+
import random
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from typing import TYPE_CHECKING, Optional, Union
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from torch._utils import ExceptionWrapper
|
| 16 |
+
from torch.utils.data._utils import (HAS_NUMPY, IS_WINDOWS,
|
| 17 |
+
MP_STATUS_CHECK_INTERVAL, signal_handling)
|
| 18 |
+
|
| 19 |
+
if TYPE_CHECKING:
|
| 20 |
+
from torch.utils.data import Dataset
|
| 21 |
+
|
| 22 |
+
from .controller import RRSController
|
| 23 |
+
|
| 24 |
+
if IS_WINDOWS:
|
| 25 |
+
import ctypes
|
| 26 |
+
from ctypes.wintypes import BOOL, DWORD, HANDLE
|
| 27 |
+
|
| 28 |
+
# On Windows, the parent ID of the worker process remains unchanged when the manager process
|
| 29 |
+
# is gone, and the only way to check it through OS is to let the worker have a process handle
|
| 30 |
+
# of the manager and ask if the process status has changed.
|
| 31 |
+
class ManagerWatchdog:
|
| 32 |
+
def __init__(self):
|
| 33 |
+
self.manager_pid = os.getppid()
|
| 34 |
+
|
| 35 |
+
# mypy cannot detect this code is windows only
|
| 36 |
+
self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore[attr-defined]
|
| 37 |
+
self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
|
| 38 |
+
self.kernel32.OpenProcess.restype = HANDLE
|
| 39 |
+
self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
|
| 40 |
+
self.kernel32.WaitForSingleObject.restype = DWORD
|
| 41 |
+
|
| 42 |
+
# Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
|
| 43 |
+
SYNCHRONIZE = 0x00100000
|
| 44 |
+
self.manager_handle = self.kernel32.OpenProcess(
|
| 45 |
+
SYNCHRONIZE, 0, self.manager_pid
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
if not self.manager_handle:
|
| 49 |
+
raise ctypes.WinError(ctypes.get_last_error()) # type: ignore[attr-defined]
|
| 50 |
+
|
| 51 |
+
self.manager_dead = False
|
| 52 |
+
|
| 53 |
+
def is_alive(self):
|
| 54 |
+
if not self.manager_dead:
|
| 55 |
+
# Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
|
| 56 |
+
self.manager_dead = (
|
| 57 |
+
self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
|
| 58 |
+
)
|
| 59 |
+
return not self.manager_dead
|
| 60 |
+
|
| 61 |
+
else:
|
| 62 |
+
|
| 63 |
+
class ManagerWatchdog: # type: ignore[no-redef]
|
| 64 |
+
def __init__(self):
|
| 65 |
+
self.manager_pid = os.getppid()
|
| 66 |
+
self.manager_dead = False
|
| 67 |
+
|
| 68 |
+
def is_alive(self):
|
| 69 |
+
if not self.manager_dead:
|
| 70 |
+
self.manager_dead = os.getppid() != self.manager_pid
|
| 71 |
+
return not self.manager_dead
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
_worker_info = None
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class WorkerInfo:
|
| 78 |
+
id: int
|
| 79 |
+
num_workers: int
|
| 80 |
+
seed: int
|
| 81 |
+
dataset: "Dataset"
|
| 82 |
+
__initialized = False
|
| 83 |
+
|
| 84 |
+
def __init__(self, **kwargs):
|
| 85 |
+
for k, v in kwargs.items():
|
| 86 |
+
setattr(self, k, v)
|
| 87 |
+
self.__keys = tuple(kwargs.keys())
|
| 88 |
+
self.__initialized = True
|
| 89 |
+
|
| 90 |
+
def __setattr__(self, key, val):
|
| 91 |
+
if self.__initialized:
|
| 92 |
+
raise RuntimeError(
|
| 93 |
+
"Cannot assign attributes to {} objects".format(self.__class__.__name__)
|
| 94 |
+
)
|
| 95 |
+
return super().__setattr__(key, val)
|
| 96 |
+
|
| 97 |
+
def __repr__(self):
|
| 98 |
+
items = []
|
| 99 |
+
for k in self.__keys:
|
| 100 |
+
items.append("{}={}".format(k, getattr(self, k)))
|
| 101 |
+
return "{}({})".format(self.__class__.__name__, ", ".join(items))
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def get_worker_info() -> Optional[WorkerInfo]:
|
| 105 |
+
r"""Returns the information about the current
|
| 106 |
+
:class:`~torch.utils.data.DataLoader` iterator worker process.
|
| 107 |
+
|
| 108 |
+
When called in a worker, this returns an object guaranteed to have the
|
| 109 |
+
following attributes:
|
| 110 |
+
|
| 111 |
+
* :attr:`id`: the current worker id.
|
| 112 |
+
* :attr:`num_workers`: the total number of workers.
|
| 113 |
+
* :attr:`seed`: the random seed set for the current worker. This value is
|
| 114 |
+
determined by main process RNG and the worker id. See
|
| 115 |
+
:class:`~torch.utils.data.DataLoader`'s documentation for more details.
|
| 116 |
+
* :attr:`dataset`: the copy of the dataset object in **this** process. Note
|
| 117 |
+
that this will be a different object in a different process than the one
|
| 118 |
+
in the main process.
|
| 119 |
+
|
| 120 |
+
When called in the main process, this returns ``None``.
|
| 121 |
+
|
| 122 |
+
.. note::
|
| 123 |
+
When used in a :attr:`worker_init_fn` passed over to
|
| 124 |
+
:class:`~torch.utils.data.DataLoader`, this method can be useful to
|
| 125 |
+
set up each worker process differently, for instance, using ``worker_id``
|
| 126 |
+
to configure the ``dataset`` object to only read a specific fraction of a
|
| 127 |
+
sharded dataset, or use ``seed`` to seed other libraries used in dataset
|
| 128 |
+
code.
|
| 129 |
+
"""
|
| 130 |
+
return _worker_info
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
r"""Dummy class used to signal the end of an IterableDataset"""
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@dataclass(frozen=True)
|
| 137 |
+
class _IterableDatasetStopIteration:
|
| 138 |
+
worker_id: int
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
r"""Dummy class used to resume the fetching when worker reuse is enabled"""
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@dataclass(frozen=True)
|
| 145 |
+
class _ResumeIteration:
|
| 146 |
+
seed: Optional[int] = None
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# The function `_generate_state` is adapted from `numpy.random.SeedSequence`
|
| 150 |
+
# from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx
|
| 151 |
+
# It's MIT licensed, here is the copyright:
|
| 152 |
+
|
| 153 |
+
# Copyright (c) 2015 Melissa E. O'Neill
|
| 154 |
+
# Copyright (c) 2019 NumPy Developers
|
| 155 |
+
#
|
| 156 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 157 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 158 |
+
# in the Software without restriction, including without limitation the rights
|
| 159 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 160 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 161 |
+
# furnished to do so, subject to the following conditions:
|
| 162 |
+
#
|
| 163 |
+
# The above copyright notice and this permission notice shall be included in
|
| 164 |
+
# all copies or substantial portions of the Software.
|
| 165 |
+
#
|
| 166 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 167 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 168 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 169 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 170 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 171 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 172 |
+
# SOFTWARE.
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
# This function generates an array of int32 as the seed for
|
| 176 |
+
# `numpy.random`, in order to prevent state collision due to same
|
| 177 |
+
# seed and algorithm for `numpy.random` and `random` modules.
|
| 178 |
+
def _generate_state(base_seed, worker_id):
|
| 179 |
+
INIT_A = 0x43B0D7E5
|
| 180 |
+
MULT_A = 0x931E8875
|
| 181 |
+
INIT_B = 0x8B51F9DD
|
| 182 |
+
MULT_B = 0x58F38DED
|
| 183 |
+
MIX_MULT_L = 0xCA01F9DD
|
| 184 |
+
MIX_MULT_R = 0x4973F715
|
| 185 |
+
XSHIFT = 4 * 8 // 2
|
| 186 |
+
MASK32 = 0xFFFFFFFF
|
| 187 |
+
|
| 188 |
+
entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0]
|
| 189 |
+
pool = [0] * 4
|
| 190 |
+
|
| 191 |
+
hash_const_A = INIT_A
|
| 192 |
+
|
| 193 |
+
def hash(value):
|
| 194 |
+
nonlocal hash_const_A
|
| 195 |
+
value = (value ^ hash_const_A) & MASK32
|
| 196 |
+
hash_const_A = (hash_const_A * MULT_A) & MASK32
|
| 197 |
+
value = (value * hash_const_A) & MASK32
|
| 198 |
+
value = (value ^ (value >> XSHIFT)) & MASK32
|
| 199 |
+
return value
|
| 200 |
+
|
| 201 |
+
def mix(x, y):
|
| 202 |
+
result_x = (MIX_MULT_L * x) & MASK32
|
| 203 |
+
result_y = (MIX_MULT_R * y) & MASK32
|
| 204 |
+
result = (result_x - result_y) & MASK32
|
| 205 |
+
result = (result ^ (result >> XSHIFT)) & MASK32
|
| 206 |
+
return result
|
| 207 |
+
|
| 208 |
+
# Add in the entropy to the pool.
|
| 209 |
+
for i in range(len(pool)):
|
| 210 |
+
pool[i] = hash(entropy[i])
|
| 211 |
+
|
| 212 |
+
# Mix all bits together so late bits can affect earlier bits.
|
| 213 |
+
for i_src in range(len(pool)):
|
| 214 |
+
for i_dst in range(len(pool)):
|
| 215 |
+
if i_src != i_dst:
|
| 216 |
+
pool[i_dst] = mix(pool[i_dst], hash(pool[i_src]))
|
| 217 |
+
|
| 218 |
+
hash_const_B = INIT_B
|
| 219 |
+
state = []
|
| 220 |
+
for i_dst in range(4):
|
| 221 |
+
data_val = pool[i_dst]
|
| 222 |
+
data_val = (data_val ^ hash_const_B) & MASK32
|
| 223 |
+
hash_const_B = (hash_const_B * MULT_B) & MASK32
|
| 224 |
+
data_val = (data_val * hash_const_B) & MASK32
|
| 225 |
+
data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32
|
| 226 |
+
state.append(data_val)
|
| 227 |
+
return state
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def _worker_loop(
|
| 231 |
+
dataset_kind,
|
| 232 |
+
dataset,
|
| 233 |
+
index_queue,
|
| 234 |
+
data_queue,
|
| 235 |
+
done_event,
|
| 236 |
+
auto_collation,
|
| 237 |
+
collate_fn,
|
| 238 |
+
drop_last,
|
| 239 |
+
base_seed,
|
| 240 |
+
init_fn,
|
| 241 |
+
worker_id,
|
| 242 |
+
num_workers,
|
| 243 |
+
persistent_workers,
|
| 244 |
+
shared_seed,
|
| 245 |
+
):
|
| 246 |
+
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
|
| 247 |
+
# logic of this function.
|
| 248 |
+
|
| 249 |
+
try:
|
| 250 |
+
# Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
|
| 251 |
+
# module's handlers are executed after Python returns from C low-level
|
| 252 |
+
# handlers, likely when the same fatal signal had already happened
|
| 253 |
+
# again.
|
| 254 |
+
# https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
|
| 255 |
+
signal_handling._set_worker_signal_handlers()
|
| 256 |
+
|
| 257 |
+
torch.set_num_threads(1)
|
| 258 |
+
seed = base_seed + worker_id
|
| 259 |
+
random.seed(seed)
|
| 260 |
+
torch.manual_seed(seed)
|
| 261 |
+
if HAS_NUMPY:
|
| 262 |
+
np_seed = _generate_state(base_seed, worker_id)
|
| 263 |
+
import numpy as np
|
| 264 |
+
|
| 265 |
+
np.random.seed(np_seed)
|
| 266 |
+
|
| 267 |
+
from torch.utils.data import IterDataPipe
|
| 268 |
+
from torch.utils.data.graph_settings import apply_random_seed
|
| 269 |
+
|
| 270 |
+
shared_rng = torch.Generator()
|
| 271 |
+
if isinstance(dataset, IterDataPipe):
|
| 272 |
+
assert shared_seed is not None
|
| 273 |
+
shared_rng.manual_seed(shared_seed)
|
| 274 |
+
dataset = apply_random_seed(dataset, shared_rng)
|
| 275 |
+
|
| 276 |
+
global _worker_info
|
| 277 |
+
_worker_info = WorkerInfo(
|
| 278 |
+
id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
from torch.utils.data import _DatasetKind
|
| 282 |
+
|
| 283 |
+
init_exception = None
|
| 284 |
+
|
| 285 |
+
try:
|
| 286 |
+
if init_fn is not None:
|
| 287 |
+
init_fn(worker_id)
|
| 288 |
+
|
| 289 |
+
fetcher = _DatasetKind.create_fetcher(
|
| 290 |
+
dataset_kind, dataset, auto_collation, collate_fn, drop_last
|
| 291 |
+
)
|
| 292 |
+
except Exception:
|
| 293 |
+
init_exception = ExceptionWrapper(
|
| 294 |
+
where="in DataLoader worker process {}".format(worker_id)
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
# When using Iterable mode, some worker can exit earlier than others due
|
| 298 |
+
# to the IterableDataset behaving differently for different workers.
|
| 299 |
+
# When such things happen, an `_IterableDatasetStopIteration` object is
|
| 300 |
+
# sent over to the main process with the ID of this worker, so that the
|
| 301 |
+
# main process won't send more tasks to this worker, and will send
|
| 302 |
+
# `None` to this worker to properly exit it.
|
| 303 |
+
#
|
| 304 |
+
# Note that we cannot set `done_event` from a worker as it is shared
|
| 305 |
+
# among all processes. Instead, we set the `iteration_end` flag to
|
| 306 |
+
# signify that the iterator is exhausted. When either `done_event` or
|
| 307 |
+
# `iteration_end` is set, we skip all processing step and just wait for
|
| 308 |
+
# `None`.
|
| 309 |
+
iteration_end = False
|
| 310 |
+
|
| 311 |
+
watchdog = ManagerWatchdog()
|
| 312 |
+
|
| 313 |
+
while watchdog.is_alive():
|
| 314 |
+
try:
|
| 315 |
+
r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
|
| 316 |
+
except queue.Empty:
|
| 317 |
+
continue
|
| 318 |
+
if isinstance(r, _ResumeIteration):
|
| 319 |
+
# Acknowledge the main process
|
| 320 |
+
data_queue.put((r, None))
|
| 321 |
+
iteration_end = False
|
| 322 |
+
|
| 323 |
+
if isinstance(dataset, IterDataPipe):
|
| 324 |
+
assert r.seed is not None
|
| 325 |
+
shared_rng.manual_seed(r.seed)
|
| 326 |
+
dataset = apply_random_seed(dataset, shared_rng)
|
| 327 |
+
|
| 328 |
+
# Recreate the fetcher for worker-reuse policy
|
| 329 |
+
fetcher = _DatasetKind.create_fetcher(
|
| 330 |
+
dataset_kind, dataset, auto_collation, collate_fn, drop_last
|
| 331 |
+
)
|
| 332 |
+
continue
|
| 333 |
+
elif r is None:
|
| 334 |
+
# Received the final signal
|
| 335 |
+
assert done_event.is_set() or iteration_end
|
| 336 |
+
break
|
| 337 |
+
elif done_event.is_set() or iteration_end:
|
| 338 |
+
# `done_event` is set. But I haven't received the final signal
|
| 339 |
+
# (None) yet. I will keep continuing until get it, and skip the
|
| 340 |
+
# processing steps.
|
| 341 |
+
continue
|
| 342 |
+
idx, index = r
|
| 343 |
+
""" Added """
|
| 344 |
+
RRSController.sample_resolution(batch_id=idx)
|
| 345 |
+
""" Added """
|
| 346 |
+
data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
|
| 347 |
+
if init_exception is not None:
|
| 348 |
+
data = init_exception
|
| 349 |
+
init_exception = None
|
| 350 |
+
else:
|
| 351 |
+
try:
|
| 352 |
+
data = fetcher.fetch(index)
|
| 353 |
+
except Exception as e:
|
| 354 |
+
if (
|
| 355 |
+
isinstance(e, StopIteration)
|
| 356 |
+
and dataset_kind == _DatasetKind.Iterable
|
| 357 |
+
):
|
| 358 |
+
data = _IterableDatasetStopIteration(worker_id)
|
| 359 |
+
# Set `iteration_end`
|
| 360 |
+
# (1) to save future `next(...)` calls, and
|
| 361 |
+
# (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
|
| 362 |
+
iteration_end = True
|
| 363 |
+
else:
|
| 364 |
+
# It is important that we don't store exc_info in a variable.
|
| 365 |
+
# `ExceptionWrapper` does the correct thing.
|
| 366 |
+
# See NOTE [ Python Traceback Reference Cycle Problem ]
|
| 367 |
+
data = ExceptionWrapper(
|
| 368 |
+
where="in DataLoader worker process {}".format(worker_id)
|
| 369 |
+
)
|
| 370 |
+
data_queue.put((idx, data))
|
| 371 |
+
del data, idx, index, r # save memory
|
| 372 |
+
except KeyboardInterrupt:
|
| 373 |
+
# Main process will raise KeyboardInterrupt anyways.
|
| 374 |
+
pass
|
| 375 |
+
if done_event.is_set():
|
| 376 |
+
data_queue.cancel_join_thread()
|
| 377 |
+
data_queue.close()
|
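The only functional change relative to the stock torch/utils/data/_utils/worker.py loop is the `RRSController.sample_resolution(batch_id=idx)` call marked `""" Added """` above. A minimal sketch of why that hook is deterministic across processes; the candidate resolution list below is an assumption for illustration, not part of this commit:

# CHOICE_LIST is rebuilt identically in every process from the epoch seed, so
# indexing it with the batch id yields the same resolution no matter which
# worker fetches that batch.
from src.efficientvit.apps.data_provider.random_resolution.controller import RRSController

RRSController.IMAGE_SIZE_LIST = [(160, 160), (192, 192), (224, 224)]  # assumed candidates
RRSController.set_epoch(epoch=0, batch_per_epoch=1000)
RRSController.sample_resolution(batch_id=42)   # same call the worker loop makes
print(RRSController.ACTIVE_SIZE)               # resolution applied to batch 42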
src/efficientvit/apps/data_provider/random_resolution/controller.py
ADDED
@@ -0,0 +1,94 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import copy

import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F

from src.efficientvit.models.utils import torch_random_choices

__all__ = [
    "RRSController",
    "get_interpolate",
    "MyRandomResizedCrop",
]


class RRSController:
    ACTIVE_SIZE = (224, 224)
    IMAGE_SIZE_LIST = [(224, 224)]

    CHOICE_LIST = None

    @staticmethod
    def get_candidates() -> list[tuple[int, int]]:
        return copy.deepcopy(RRSController.IMAGE_SIZE_LIST)

    @staticmethod
    def sample_resolution(batch_id: int) -> None:
        RRSController.ACTIVE_SIZE = RRSController.CHOICE_LIST[batch_id]

    @staticmethod
    def set_epoch(epoch: int, batch_per_epoch: int) -> None:
        g = torch.Generator()
        g.manual_seed(epoch)
        RRSController.CHOICE_LIST = torch_random_choices(
            RRSController.get_candidates(),
            g,
            batch_per_epoch,
        )


def get_interpolate(name: str) -> F.InterpolationMode:
    mapping = {
        "nearest": F.InterpolationMode.NEAREST,
        "bilinear": F.InterpolationMode.BILINEAR,
        "bicubic": F.InterpolationMode.BICUBIC,
        "box": F.InterpolationMode.BOX,
        "hamming": F.InterpolationMode.HAMMING,
        "lanczos": F.InterpolationMode.LANCZOS,
    }
    if name in mapping:
        return mapping[name]
    elif name == "random":
        return torch_random_choices(
            [
                F.InterpolationMode.NEAREST,
                F.InterpolationMode.BILINEAR,
                F.InterpolationMode.BICUBIC,
                F.InterpolationMode.BOX,
                F.InterpolationMode.HAMMING,
                F.InterpolationMode.LANCZOS,
            ],
        )
    else:
        raise NotImplementedError


class MyRandomResizedCrop(transforms.RandomResizedCrop):
    def __init__(
        self,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation: str = "random",
    ):
        super(MyRandomResizedCrop, self).__init__(224, scale, ratio)
        self.interpolation = interpolation

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        i, j, h, w = self.get_params(img, list(self.scale), list(self.ratio))
        target_size = RRSController.ACTIVE_SIZE
        return F.resized_crop(
            img, i, j, h, w, list(target_size), get_interpolate(self.interpolation)
        )

    def __repr__(self) -> str:
        format_string = self.__class__.__name__
        format_string += f"(\n\tsize={RRSController.get_candidates()},\n"
        format_string += f"\tscale={tuple(round(s, 4) for s in self.scale)},\n"
        format_string += f"\tratio={tuple(round(r, 4) for r in self.ratio)},\n"
        format_string += f"\tinterpolation={self.interpolation})"
        return format_string
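A hedged usage sketch of how the controller and the transform cooperate; the candidate sizes and the dummy image tensor are assumptions, and the repository root is assumed to be on PYTHONPATH:

import torch
from src.efficientvit.apps.data_provider.random_resolution.controller import (
    MyRandomResizedCrop, RRSController)

RRSController.IMAGE_SIZE_LIST = [(192, 192), (224, 224), (256, 256)]  # assumed candidates
RRSController.set_epoch(epoch=3, batch_per_epoch=500)   # one resolution per batch, seeded by epoch

crop = MyRandomResizedCrop(interpolation="random")
RRSController.sample_resolution(batch_id=0)             # normally done inside the worker loop
img = torch.rand(3, 320, 320)                           # dummy image
print(crop(img).shape)                                  # spatial size == RRSController.ACTIVE_SIZE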
src/efficientvit/apps/setup.py
ADDED
@@ -0,0 +1,141 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import os
import time
from copy import deepcopy

import torch.backends.cudnn
import torch.distributed
import torch.nn as nn

from src.efficientvit.apps.data_provider import DataProvider
from src.efficientvit.apps.trainer.run_config import RunConfig
from src.efficientvit.apps.utils import (dist_init, dump_config,
                                         get_dist_local_rank, get_dist_rank,
                                         get_dist_size, init_modules, is_master,
                                         load_config, partial_update_config,
                                         zero_last_gamma)
from src.efficientvit.models.utils import (build_kwargs_from_config,
                                           load_state_dict_from_file)

__all__ = [
    "save_exp_config",
    "setup_dist_env",
    "setup_seed",
    "setup_exp_config",
    "setup_data_provider",
    "setup_run_config",
    "init_model",
]


def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None:
    if not is_master():
        return
    dump_config(exp_config, os.path.join(path, name))


def setup_dist_env(gpu: str or None = None) -> None:
    if gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    if not torch.distributed.is_initialized():
        dist_init()
    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(get_dist_local_rank())


def setup_seed(manual_seed: int, resume: bool) -> None:
    if resume:
        manual_seed = int(time.time())
    manual_seed = get_dist_rank() + manual_seed
    torch.manual_seed(manual_seed)
    torch.cuda.manual_seed_all(manual_seed)


def setup_exp_config(
    config_path: str, recursive=True, opt_args: dict or None = None
) -> dict:
    # load config
    if not os.path.isfile(config_path):
        raise ValueError(config_path)

    fpaths = [config_path]
    if recursive:
        extension = os.path.splitext(config_path)[1]
        while os.path.dirname(config_path) != config_path:
            config_path = os.path.dirname(config_path)
            fpath = os.path.join(config_path, "default" + extension)
            if os.path.isfile(fpath):
                fpaths.append(fpath)
        fpaths = fpaths[::-1]

    default_config = load_config(fpaths[0])
    exp_config = deepcopy(default_config)
    for fpath in fpaths[1:]:
        partial_update_config(exp_config, load_config(fpath))
    # update config via args
    if opt_args is not None:
        partial_update_config(exp_config, opt_args)

    return exp_config


def setup_data_provider(
    exp_config: dict,
    data_provider_classes: list[type[DataProvider]],
    is_distributed: bool = True,
) -> DataProvider:
    dp_config = exp_config["data_provider"]
    dp_config["num_replicas"] = get_dist_size() if is_distributed else None
    dp_config["rank"] = get_dist_rank() if is_distributed else None
    dp_config["test_batch_size"] = (
        dp_config.get("test_batch_size", None) or dp_config["base_batch_size"] * 2
    )
    dp_config["batch_size"] = dp_config["train_batch_size"] = dp_config[
        "base_batch_size"
    ]

    data_provider_lookup = {
        provider.name: provider for provider in data_provider_classes
    }
    data_provider_class = data_provider_lookup[dp_config["dataset"]]

    data_provider_kwargs = build_kwargs_from_config(dp_config, data_provider_class)
    data_provider = data_provider_class(**data_provider_kwargs)
    return data_provider


def setup_run_config(exp_config: dict, run_config_cls: type[RunConfig]) -> RunConfig:
    exp_config["run_config"]["init_lr"] = (
        exp_config["run_config"]["base_lr"] * get_dist_size()
    )

    run_config = run_config_cls(**exp_config["run_config"])

    return run_config


def init_model(
    network: nn.Module,
    init_from: str or None = None,
    backbone_init_from: str or None = None,
    rand_init="trunc_normal",
    last_gamma=None,
) -> None:
    # initialization
    init_modules(network, init_type=rand_init)
    # zero gamma of last bn in each block
    if last_gamma is not None:
        zero_last_gamma(network, last_gamma)

    # load weight
    if init_from is not None and os.path.isfile(init_from):
        network.load_state_dict(load_state_dict_from_file(init_from))
        print(f"Loaded init from {init_from}")
    elif backbone_init_from is not None and os.path.isfile(backbone_init_from):
        network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from))
        print(f"Loaded backbone init from {backbone_init_from}")
    else:
        print(f"Random init ({rand_init}) with last gamma {last_gamma}")
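A sketch of the typical call order for these helpers in a training script launched with torchrun; the YAML path is a placeholder, and the provider and run-config classes come from the rest of the repository:

from src.efficientvit.apps import setup

setup.setup_dist_env(gpu="0,1,2,3")            # pins visible GPUs, initializes the process group
setup.setup_seed(manual_seed=0, resume=False)  # rank-offset seeding for torch / cuda
config = setup.setup_exp_config("configs/example.yaml", recursive=True)  # placeholder path
# provider = setup.setup_data_provider(config, [SomeDataProviderClass])
# run_config = setup.setup_run_config(config, RunConfig)
# setup.init_model(network, rand_init="trunc_normal", last_gamma=0)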
src/efficientvit/apps/trainer/__init__.py
ADDED
@@ -0,0 +1,6 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

from .base import *
from .run_config import *
src/efficientvit/apps/trainer/base.py
ADDED
@@ -0,0 +1,297 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import os

import torch
import torch.nn as nn

from src.efficientvit.apps.data_provider import DataProvider, parse_image_size
from src.efficientvit.apps.trainer.run_config import RunConfig
from src.efficientvit.apps.utils import (EMA, dist_barrier, get_dist_local_rank,
                                         is_master)
from src.efficientvit.models.nn.norm import reset_bn
from src.efficientvit.models.utils import is_parallel, load_state_dict_from_file

__all__ = ["Trainer"]


class Trainer:
    def __init__(self, path: str, model: nn.Module, data_provider: DataProvider):
        self.path = os.path.realpath(os.path.expanduser(path))
        self.model = model.cuda()
        self.data_provider = data_provider

        self.ema = None

        self.checkpoint_path = os.path.join(self.path, "checkpoint")
        self.logs_path = os.path.join(self.path, "logs")
        for path in [self.path, self.checkpoint_path, self.logs_path]:
            os.makedirs(path, exist_ok=True)

        self.best_val = 0.0
        self.start_epoch = 0

    @property
    def network(self) -> nn.Module:
        return self.model.module if is_parallel(self.model) else self.model

    @property
    def eval_network(self) -> nn.Module:
        if self.ema is None:
            model = self.model
        else:
            model = self.ema.shadows
        model = model.module if is_parallel(model) else model
        return model

    def write_log(self, log_str, prefix="valid", print_log=True, mode="a") -> None:
        if is_master():
            fout = open(os.path.join(self.logs_path, f"{prefix}.log"), mode)
            fout.write(log_str + "\n")
            fout.flush()
            fout.close()
            if print_log:
                print(log_str)

    def save_model(
        self,
        checkpoint=None,
        only_state_dict=True,
        epoch=0,
        model_name=None,
    ) -> None:
        if is_master():
            if checkpoint is None:
                if only_state_dict:
                    checkpoint = {"state_dict": self.network.state_dict()}
                else:
                    checkpoint = {
                        "state_dict": self.network.state_dict(),
                        "epoch": epoch,
                        "best_val": self.best_val,
                        "optimizer": self.optimizer.state_dict(),
                        "lr_scheduler": self.lr_scheduler.state_dict(),
                        "ema": self.ema.state_dict() if self.ema is not None else None,
                        "scaler": self.scaler.state_dict() if self.fp16 else None,
                    }

            model_name = model_name or "checkpoint.pt"

            latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
            model_path = os.path.join(self.checkpoint_path, model_name)
            with open(latest_fname, "w") as _fout:
                _fout.write(model_path + "\n")
            torch.save(checkpoint, model_path)

    def load_model(self, model_fname=None) -> None:
        latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
        if model_fname is None and os.path.exists(latest_fname):
            with open(latest_fname, "r") as fin:
                model_fname = fin.readline()
                if len(model_fname) > 0 and model_fname[-1] == "\n":
                    model_fname = model_fname[:-1]
        try:
            if model_fname is None:
                model_fname = f"{self.checkpoint_path}/checkpoint.pt"
            elif not os.path.exists(model_fname):
                model_fname = f"{self.checkpoint_path}/{os.path.basename(model_fname)}"
                if not os.path.exists(model_fname):
                    model_fname = f"{self.checkpoint_path}/checkpoint.pt"
            print(f"=> loading checkpoint {model_fname}")
            checkpoint = load_state_dict_from_file(model_fname, False)
        except Exception:
            self.write_log(f"fail to load checkpoint from {self.checkpoint_path}")
            return

        # load checkpoint
        self.network.load_state_dict(checkpoint["state_dict"], strict=False)
        log = []
        if "epoch" in checkpoint:
            self.start_epoch = checkpoint["epoch"] + 1
            self.run_config.update_global_step(self.start_epoch)
            log.append(f"epoch={self.start_epoch - 1}")
        if "best_val" in checkpoint:
            self.best_val = checkpoint["best_val"]
            log.append(f"best_val={self.best_val:.2f}")
        if "optimizer" in checkpoint:
            self.optimizer.load_state_dict(checkpoint["optimizer"])
            log.append("optimizer")
        if "lr_scheduler" in checkpoint:
            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            log.append("lr_scheduler")
        if "ema" in checkpoint and self.ema is not None:
            self.ema.load_state_dict(checkpoint["ema"])
            log.append("ema")
        if "scaler" in checkpoint and self.fp16:
            self.scaler.load_state_dict(checkpoint["scaler"])
            log.append("scaler")
        self.write_log("Loaded: " + ", ".join(log))

    """ validate """

    def reset_bn(
        self,
        network: nn.Module or None = None,
        subset_size: int = 16000,
        subset_batch_size: int = 100,
        data_loader=None,
        progress_bar=False,
    ) -> None:
        network = network or self.network
        if data_loader is None:
            data_loader = []
            for data in self.data_provider.build_sub_train_loader(
                subset_size, subset_batch_size
            ):
                if isinstance(data, list):
                    data_loader.append(data[0])
                elif isinstance(data, dict):
                    data_loader.append(data["data"])
                elif isinstance(data, torch.Tensor):
                    data_loader.append(data)
                else:
                    raise NotImplementedError

        network.eval()
        reset_bn(
            network,
            data_loader,
            sync=True,
            progress_bar=progress_bar,
        )

    def _validate(self, model, data_loader, epoch) -> dict[str, any]:
        raise NotImplementedError

    def validate(
        self, model=None, data_loader=None, is_test=True, epoch=0
    ) -> dict[str, any]:
        model = model or self.eval_network
        if data_loader is None:
            if is_test:
                data_loader = self.data_provider.test
            else:
                data_loader = self.data_provider.valid

        model.eval()
        return self._validate(model, data_loader, epoch)

    def multires_validate(
        self,
        model=None,
        data_loader=None,
        is_test=True,
        epoch=0,
        eval_image_size=None,
    ) -> dict[str, dict[str, any]]:
        eval_image_size = eval_image_size or self.run_config.eval_image_size
        eval_image_size = eval_image_size or self.data_provider.image_size
        model = model or self.eval_network

        if not isinstance(eval_image_size, list):
            eval_image_size = [eval_image_size]

        output_dict = {}
        for r in eval_image_size:
            self.data_provider.assign_active_image_size(parse_image_size(r))
            if self.run_config.reset_bn:
                self.reset_bn(
                    network=model,
                    subset_size=self.run_config.reset_bn_size,
                    subset_batch_size=self.run_config.reset_bn_batch_size,
                    progress_bar=True,
                )
            output_dict[f"r{r}"] = self.validate(model, data_loader, is_test, epoch)
        return output_dict

    """ training """

    def prep_for_training(
        self, run_config: RunConfig, ema_decay: float or None = None, fp16=False
    ) -> None:
        self.run_config = run_config
        self.model = nn.parallel.DistributedDataParallel(
            self.model.cuda(),
            device_ids=[get_dist_local_rank()],
            static_graph=True,
        )

        self.run_config.global_step = 0
        self.run_config.batch_per_epoch = len(self.data_provider.train)
        assert self.run_config.batch_per_epoch > 0, "Training set is empty"

        # build optimizer
        self.optimizer, self.lr_scheduler = self.run_config.build_optimizer(self.model)

        if ema_decay is not None:
            self.ema = EMA(self.network, ema_decay)

        # fp16
        self.fp16 = fp16
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)

    def sync_model(self):
        print("Sync model")
        self.save_model(model_name="sync.pt")
        dist_barrier()
        checkpoint = torch.load(
            os.path.join(self.checkpoint_path, "sync.pt"), map_location="cpu"
        )
        dist_barrier()
        if is_master():
            os.remove(os.path.join(self.checkpoint_path, "sync.pt"))
        dist_barrier()

        # load checkpoint
        self.network.load_state_dict(checkpoint["state_dict"], strict=False)
        if "optimizer" in checkpoint:
            self.optimizer.load_state_dict(checkpoint["optimizer"])
        if "lr_scheduler" in checkpoint:
            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        if "ema" in checkpoint and self.ema is not None:
            self.ema.load_state_dict(checkpoint["ema"])
        if "scaler" in checkpoint and self.fp16:
            self.scaler.load_state_dict(checkpoint["scaler"])

    def before_step(self, feed_dict: dict[str, any]) -> dict[str, any]:
        for key in feed_dict:
            if isinstance(feed_dict[key], torch.Tensor):
                feed_dict[key] = feed_dict[key].cuda()
        return feed_dict

    def run_step(self, feed_dict: dict[str, any]) -> dict[str, any]:
        raise NotImplementedError

    def after_step(self) -> None:
        self.scaler.unscale_(self.optimizer)
        # gradient clip
        if self.run_config.grad_clip is not None:
            torch.nn.utils.clip_grad_value_(
                self.model.parameters(), self.run_config.grad_clip
            )
        # update
        self.scaler.step(self.optimizer)
        self.scaler.update()

        self.lr_scheduler.step()
        self.run_config.step()
        # update ema
        if self.ema is not None:
            self.ema.step(self.network, self.run_config.global_step)

    def _train_one_epoch(self, epoch: int) -> dict[str, any]:
        raise NotImplementedError

    def train_one_epoch(self, epoch: int) -> dict[str, any]:
        self.model.train()

        self.data_provider.set_epoch(epoch)

        train_info_dict = self._train_one_epoch(epoch)

        return train_info_dict

    def train(self) -> None:
        raise NotImplementedError
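An illustrative subclass (not part of this commit): Trainer supplies checkpointing, EMA, AMP scaling, and BN-reset plumbing, so a concrete trainer mainly fills in the abstract hooks such as run_step; ClsTrainer and the "data"/"label" keys below are assumptions.

import torch


class ClsTrainer(Trainer):
    def run_step(self, feed_dict):
        # forward under autocast when fp16 is enabled, backward through the GradScaler;
        # gradient clipping and optimizer/EMA updates happen in after_step()
        with torch.autocast(device_type="cuda", enabled=self.fp16):
            output = self.model(feed_dict["data"])
            loss = torch.nn.functional.cross_entropy(output, feed_dict["label"])
        self.scaler.scale(loss).backward()
        return {"loss": loss}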
src/efficientvit/apps/trainer/run_config.py
ADDED
@@ -0,0 +1,121 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import json

import numpy as np
import torch.nn as nn

from src.efficientvit.apps.utils import CosineLRwithWarmup, build_optimizer

__all__ = ["Scheduler", "RunConfig"]


class Scheduler:
    PROGRESS = 0


class RunConfig:
    n_epochs: int
    init_lr: float
    warmup_epochs: int
    warmup_lr: float
    lr_schedule_name: str
    lr_schedule_param: dict
    optimizer_name: str
    optimizer_params: dict
    weight_decay: float
    no_wd_keys: list
    grad_clip: float  # allow none to turn off grad clipping
    reset_bn: bool
    reset_bn_size: int
    reset_bn_batch_size: int
    eval_image_size: list  # allow none to use image_size in data_provider

    @property
    def none_allowed(self):
        return ["grad_clip", "eval_image_size"]

    def __init__(self, **kwargs):  # arguments must be passed as kwargs
        for k, val in kwargs.items():
            setattr(self, k, val)

        # check that all relevant configs are there
        annotations = {}
        for clas in type(self).mro():
            if hasattr(clas, "__annotations__"):
                annotations.update(clas.__annotations__)
        for k, k_type in annotations.items():
            assert hasattr(
                self, k
            ), f"Key {k} with type {k_type} required for initialization."
            attr = getattr(self, k)
            if k in self.none_allowed:
                k_type = (k_type, type(None))
            assert isinstance(
                attr, k_type
            ), f"Key {k} must be type {k_type}, provided={attr}."

        self.global_step = 0
        self.batch_per_epoch = 1

    def build_optimizer(self, network: nn.Module) -> tuple[any, any]:
        r"""require setting 'batch_per_epoch' before building optimizer & lr_scheduler"""
        param_dict = {}
        for name, param in network.named_parameters():
            if param.requires_grad:
                opt_config = [self.weight_decay, self.init_lr]
                if self.no_wd_keys is not None and len(self.no_wd_keys) > 0:
                    if np.any([key in name for key in self.no_wd_keys]):
                        opt_config[0] = 0
                opt_key = json.dumps(opt_config)
                param_dict[opt_key] = param_dict.get(opt_key, []) + [param]

        net_params = []
        for opt_key, param_list in param_dict.items():
            wd, lr = json.loads(opt_key)
            net_params.append({"params": param_list, "weight_decay": wd, "lr": lr})

        optimizer = build_optimizer(
            net_params, self.optimizer_name, self.optimizer_params, self.init_lr
        )
        # build lr scheduler
        if self.lr_schedule_name == "cosine":
            decay_steps = []
            for epoch in self.lr_schedule_param.get("step", []):
                decay_steps.append(epoch * self.batch_per_epoch)
            decay_steps.append(self.n_epochs * self.batch_per_epoch)
            decay_steps.sort()
            lr_scheduler = CosineLRwithWarmup(
                optimizer,
                self.warmup_epochs * self.batch_per_epoch,
                self.warmup_lr,
                decay_steps,
            )
        else:
            raise NotImplementedError
        return optimizer, lr_scheduler

    def update_global_step(self, epoch, batch_id=0) -> None:
        self.global_step = epoch * self.batch_per_epoch + batch_id
        Scheduler.PROGRESS = self.progress

    @property
    def progress(self) -> float:
        warmup_steps = self.warmup_epochs * self.batch_per_epoch
        steps = max(0, self.global_step - warmup_steps)
        return steps / (self.n_epochs * self.batch_per_epoch)

    def step(self) -> None:
        self.global_step += 1
        Scheduler.PROGRESS = self.progress

    def get_remaining_epoch(self, epoch, post=True) -> int:
        return self.n_epochs + self.warmup_epochs - epoch - int(post)

    def epoch_format(self, epoch: int) -> str:
        epoch_format = f"%.{len(str(self.n_epochs))}d"
        epoch_format = f"[{epoch_format}/{epoch_format}]"
        epoch_format = epoch_format % (epoch + 1 - self.warmup_epochs, self.n_epochs)
        return epoch_format
src/efficientvit/apps/utils/__init__.py
ADDED
@@ -0,0 +1,12 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

from .dist import *
from .ema import *
from .export import *
from .init import *
from .lr import *
from .metric import *
from .misc import *
from .opt import *
src/efficientvit/apps/utils/dist.py
ADDED
@@ -0,0 +1,73 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import os

import torch
import torch.distributed

from src.efficientvit.models.utils.list import list_mean, list_sum

__all__ = [
    "dist_init",
    "get_dist_rank",
    "get_dist_size",
    "is_master",
    "dist_barrier",
    "get_dist_local_rank",
    "sync_tensor",
]


def dist_init() -> None:
    try:
        torch.distributed.init_process_group(backend="nccl")
        assert torch.distributed.is_initialized()
    except Exception:
        # use torchpack
        from torchpack import distributed as dist

        dist.init()
        os.environ["RANK"] = f"{dist.rank()}"
        os.environ["WORLD_SIZE"] = f"{dist.size()}"
        os.environ["LOCAL_RANK"] = f"{dist.local_rank()}"


def get_dist_rank() -> int:
    return int(os.environ["RANK"])


def get_dist_size() -> int:
    return int(os.environ["WORLD_SIZE"])


def is_master() -> bool:
    return get_dist_rank() == 0


def dist_barrier() -> None:
    torch.distributed.barrier()


def get_dist_local_rank() -> int:
    return int(os.environ["LOCAL_RANK"])


def sync_tensor(
    tensor: torch.Tensor or float, reduce="mean"
) -> torch.Tensor or list[torch.Tensor]:
    if not isinstance(tensor, torch.Tensor):
        tensor = torch.Tensor(1).fill_(tensor).cuda()
    tensor_list = [torch.empty_like(tensor) for _ in range(get_dist_size())]
    torch.distributed.all_gather(tensor_list, tensor.contiguous(), async_op=False)
    if reduce == "mean":
        return list_mean(tensor_list)
    elif reduce == "sum":
        return list_sum(tensor_list)
    elif reduce == "cat":
        return torch.cat(tensor_list, dim=0)
    elif reduce == "root":
        return tensor_list[0]
    else:
        return tensor_list
src/efficientvit/apps/utils/ema.py
ADDED
@@ -0,0 +1,50 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import copy
import math

import torch
import torch.nn as nn

from src.efficientvit.models.utils import is_parallel

__all__ = ["EMA"]


def update_ema(
    ema: nn.Module, new_state_dict: dict[str, torch.Tensor], decay: float
) -> None:
    for k, v in ema.state_dict().items():
        if v.dtype.is_floating_point:
            v -= (1.0 - decay) * (v - new_state_dict[k].detach())


class EMA:
    def __init__(self, model: nn.Module, decay: float, warmup_steps=2000):
        self.shadows = copy.deepcopy(
            model.module if is_parallel(model) else model
        ).eval()
        self.decay = decay
        self.warmup_steps = warmup_steps

        for p in self.shadows.parameters():
            p.requires_grad = False

    def step(self, model: nn.Module, global_step: int) -> None:
        with torch.no_grad():
            msd = (model.module if is_parallel(model) else model).state_dict()
            update_ema(
                self.shadows,
                msd,
                self.decay * (1 - math.exp(-global_step / self.warmup_steps)),
            )

    def state_dict(self) -> dict[float, dict[str, torch.Tensor]]:
        return {self.decay: self.shadows.state_dict()}

    def load_state_dict(self, state_dict: dict[float, dict[str, torch.Tensor]]) -> None:
        for decay in state_dict:
            if decay == self.decay:
                self.shadows.load_state_dict(state_dict[decay])
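A minimal sketch of how this EMA is driven from a training loop (the toy model and step count are assumptions; the Trainer above does the same thing in prep_for_training and after_step):

import torch

model = torch.nn.Linear(8, 2)
ema = EMA(model, decay=0.9998)
for step in range(1, 101):
    # ... forward/backward and optimizer.step() on `model` would go here ...
    ema.step(model, global_step=step)   # effective decay warms up via 1 - exp(-step / 2000)
eval_model = ema.shadows                # averaged weights used by Trainer.eval_network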
src/efficientvit/apps/utils/export.py
ADDED
@@ -0,0 +1,47 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import io
import os

import onnx
import torch
import torch.nn as nn
from onnxsim import simplify as simplify_func

__all__ = ["export_onnx"]


def export_onnx(
    model: nn.Module, export_path: str, sample_inputs: any, simplify=True, opset=11
) -> None:
    """Export a model to a platform-specific onnx format.

    Args:
        model: a torch.nn.Module object.
        export_path: export location.
        sample_inputs: Any.
        simplify: a flag to turn on onnx-simplifier
        opset: int
    """
    model.eval()

    buffer = io.BytesIO()
    with torch.no_grad():
        torch.onnx.export(model, sample_inputs, buffer, opset_version=opset)
        buffer.seek(0, 0)
        if simplify:
            onnx_model = onnx.load_model(buffer)
            onnx_model, success = simplify_func(onnx_model)
            assert success
            new_buffer = io.BytesIO()
            onnx.save(onnx_model, new_buffer)
            buffer = new_buffer
            buffer.seek(0, 0)

    if buffer.getbuffer().nbytes > 0:
        save_dir = os.path.dirname(export_path)
        os.makedirs(save_dir, exist_ok=True)
        with open(export_path, "wb") as f:
            f.write(buffer.read())
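A usage sketch for export_onnx; the toy model, input shape, and output path are placeholders, and the onnx / onnxsim packages must be installed:

import torch

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
export_onnx(
    model,
    "assets/export/demo.onnx",          # placeholder path; parent dirs are created
    torch.randn(1, 3, 224, 224),        # sample input used for tracing
    simplify=True,
    opset=11,
)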
src/efficientvit/apps/utils/init.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm

__all__ = ["init_modules", "zero_last_gamma"]


def init_modules(model: nn.Module or list[nn.Module], init_type="trunc_normal") -> None:
    _DEFAULT_INIT_PARAM = {"trunc_normal": 0.02}

    if isinstance(model, list):
        for sub_module in model:
            init_modules(sub_module, init_type)
    else:
        init_params = init_type.split("@")
        init_params = float(init_params[1]) if len(init_params) > 1 else None

        if init_type.startswith("trunc_normal"):
            init_func = lambda param: nn.init.trunc_normal_(
                param, std=(init_params or _DEFAULT_INIT_PARAM["trunc_normal"])
            )
        else:
            raise NotImplementedError

        for m in model.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
                init_func(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Embedding):
                init_func(m.weight)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            else:
                weight = getattr(m, "weight", None)
                bias = getattr(m, "bias", None)
                if isinstance(weight, torch.nn.Parameter):
                    init_func(weight)
                if isinstance(bias, torch.nn.Parameter):
                    bias.data.zero_()


def zero_last_gamma(model: nn.Module, init_val=0) -> None:
    import efficientvit.models.nn.ops as ops

    for m in model.modules():
        if isinstance(m, ops.ResidualBlock) and isinstance(
            m.shortcut, ops.IdentityLayer
        ):
            if isinstance(m.main, (ops.DSConv, ops.MBConv, ops.FusedMBConv)):
                parent_module = m.main.point_conv
            elif isinstance(m.main, ops.ResBlock):
                parent_module = m.main.conv2
            elif isinstance(m.main, ops.ConvLayer):
                parent_module = m.main
            elif isinstance(m.main, (ops.LiteMLA)):
                parent_module = m.main.proj
            else:
                parent_module = None
            if parent_module is not None:
                norm = getattr(parent_module, "norm", None)
                if norm is not None:
                    nn.init.constant_(norm.weight, init_val)
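A minimal usage sketch for the initializer above; the toy nn.Sequential model and the import path are illustrative only (the path depends on how the repo is rooted):

import torch.nn as nn

from src.efficientvit.apps.utils.init import init_modules  # assumed import path

net = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Conv2d(16, 8, 1),
)

# default: trunc_normal with std 0.02 on conv/linear weights, ones/zeros on norm layers
init_modules(net)

# the "@" suffix overrides the std, here trunc_normal with std 0.01
init_modules(net, init_type="trunc_normal@0.01")

# a list of modules is handled recursively as well
init_modules([net[0], net[3]], init_type="trunc_normal")

zero_last_gamma is meant to run on a full EfficientViT model after init_modules: it zeroes the last normalization gamma inside each residual branch so every residual block starts out as an identity mapping.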
src/efficientvit/apps/utils/lr.py
ADDED
@@ -0,0 +1,48 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import math

import torch

from src.efficientvit.models.utils.list import val2list

__all__ = ["CosineLRwithWarmup"]


class CosineLRwithWarmup(torch.optim.lr_scheduler._LRScheduler):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: int,
        warmup_lr: float,
        decay_steps: int or list[int],
        last_epoch: int = -1,
    ) -> None:
        self.warmup_steps = warmup_steps
        self.warmup_lr = warmup_lr
        self.decay_steps = val2list(decay_steps)
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> list[float]:
        if self.last_epoch < self.warmup_steps:
            return [
                (base_lr - self.warmup_lr) * (self.last_epoch + 1) / self.warmup_steps
                + self.warmup_lr
                for base_lr in self.base_lrs
            ]
        else:
            current_steps = self.last_epoch - self.warmup_steps
            decay_steps = [0] + self.decay_steps
            idx = len(decay_steps) - 2
            for i, decay_step in enumerate(decay_steps[:-1]):
                if decay_step <= current_steps < decay_steps[i + 1]:
                    idx = i
                    break
            current_steps -= decay_steps[idx]
            decay_step = decay_steps[idx + 1] - decay_steps[idx]
            return [
                0.5 * base_lr * (1 + math.cos(math.pi * current_steps / decay_step))
                for base_lr in self.base_lrs
            ]
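A minimal usage sketch, assuming the scheduler is importable as below: warm up linearly for warmup_steps optimizer steps, then follow a cosine decay over the remaining steps (one decay segment here; passing a list of step counts would create several cosine segments):

import torch
import torch.nn as nn

from src.efficientvit.apps.utils.lr import CosineLRwithWarmup  # assumed import path

model = nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

total_steps, warmup_steps = 1000, 50
scheduler = CosineLRwithWarmup(
    optimizer,
    warmup_steps=warmup_steps,
    warmup_lr=1e-4,                          # lr used at the first warmup step
    decay_steps=total_steps - warmup_steps,  # one cosine segment after warmup
)

for _ in range(total_steps):
    optimizer.step()   # the actual training step would go here
    scheduler.step()   # per-step schedule: linear warmup, then cosine decay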
src/efficientvit/apps/utils/metric.py
ADDED
@@ -0,0 +1,37 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import torch

from src.efficientvit.apps.utils.dist import sync_tensor

__all__ = ["AverageMeter"]


class AverageMeter:
    """Computes and stores the average and current value."""

    def __init__(self, is_distributed=True):
        self.is_distributed = is_distributed
        self.sum = 0
        self.count = 0

    def _sync(self, val: torch.Tensor or int or float) -> torch.Tensor or int or float:
        return sync_tensor(val, reduce="sum") if self.is_distributed else val

    def update(self, val: torch.Tensor or int or float, delta_n=1):
        self.count += self._sync(delta_n)
        self.sum += self._sync(val * delta_n)

    def get_count(self) -> torch.Tensor or int or float:
        return (
            self.count.item()
            if isinstance(self.count, torch.Tensor) and self.count.numel() == 1
            else self.count
        )

    @property
    def avg(self):
        avg = -1 if self.count == 0 else self.sum / self.count
        return avg.item() if isinstance(avg, torch.Tensor) and avg.numel() == 1 else avg
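A minimal single-process usage sketch (distributed syncing disabled), assuming the import path below:

import torch

from src.efficientvit.apps.utils.metric import AverageMeter  # assumed import path

loss_meter = AverageMeter(is_distributed=False)

for step in range(3):
    batch_size = 8
    loss = torch.tensor(0.5 + 0.1 * step)        # stand-in for a batch-mean loss
    loss_meter.update(loss, delta_n=batch_size)  # weight the mean by the batch size

print(loss_meter.get_count())  # 24
print(loss_meter.avg)          # ~0.6, returned as a plain Python float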
src/efficientvit/apps/utils/misc.py
ADDED
@@ -0,0 +1,111 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import os

import yaml

__all__ = [
    "parse_with_yaml",
    "parse_unknown_args",
    "partial_update_config",
    "resolve_and_load_config",
    "load_config",
    "dump_config",
]


def parse_with_yaml(config_str: str) -> str or dict:
    try:
        # add space manually for dict
        if "{" in config_str and "}" in config_str and ":" in config_str:
            out_str = config_str.replace(":", ": ")
        else:
            out_str = config_str
        return yaml.safe_load(out_str)
    except ValueError:
        # return raw string if parsing fails
        return config_str


def parse_unknown_args(unknown: list) -> dict:
    """Parse unknown args."""
    index = 0
    parsed_dict = {}
    while index < len(unknown):
        key, val = unknown[index], unknown[index + 1]
        index += 2
        if not key.startswith("--"):
            continue
        key = key[2:]

        # try parsing with either dot notation or full yaml notation
        # Note that the vanilla case "--key value" will be parsed the same
        if "." in key:
            # key == a.b.c, val == val --> parsed_dict[a][b][c] = val
            keys = key.split(".")
            dict_to_update = parsed_dict
            for key in keys[:-1]:
                if not (
                    key in dict_to_update and isinstance(dict_to_update[key], dict)
                ):
                    dict_to_update[key] = {}
                dict_to_update = dict_to_update[key]
            dict_to_update[keys[-1]] = parse_with_yaml(
                val
            )  # so we can parse lists, bools, etc...
        else:
            parsed_dict[key] = parse_with_yaml(val)
    return parsed_dict


def partial_update_config(config: dict, partial_config: dict) -> dict:
    for key in partial_config:
        if (
            key in config
            and isinstance(partial_config[key], dict)
            and isinstance(config[key], dict)
        ):
            partial_update_config(config[key], partial_config[key])
        else:
            config[key] = partial_config[key]
    return config


def resolve_and_load_config(path: str, config_name="config.yaml") -> dict:
    path = os.path.realpath(os.path.expanduser(path))
    if os.path.isdir(path):
        config_path = os.path.join(path, config_name)
    else:
        config_path = path
    if os.path.isfile(config_path):
        pass
    else:
        raise Exception(f"Cannot find a valid config at {path}")
    config = load_config(config_path)
    return config


class SafeLoaderWithTuple(yaml.SafeLoader):
    """A yaml safe loader with python tuple loading capabilities."""

    def construct_python_tuple(self, node):
        return tuple(self.construct_sequence(node))


SafeLoaderWithTuple.add_constructor(
    "tag:yaml.org,2002:python/tuple", SafeLoaderWithTuple.construct_python_tuple
)


def load_config(filename: str) -> dict:
    """Load a yaml file."""
    filename = os.path.realpath(os.path.expanduser(filename))
    return yaml.load(open(filename), Loader=SafeLoaderWithTuple)


def dump_config(config: dict, filename: str) -> None:
    """Dump a config file"""
    filename = os.path.realpath(os.path.expanduser(filename))
    yaml.dump(config, open(filename, "w"), sort_keys=False)
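A minimal sketch of how these helpers compose with argparse, assuming the import path below: unknown "--a.b value" flags become nested dict overrides that are merged into a base config:

import argparse

from src.efficientvit.apps.utils.misc import (  # assumed import path
    parse_unknown_args, partial_update_config)

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="config.yaml")
args, unknown = parser.parse_known_args(
    ["--config", "exp.yaml", "--run.lr", "0.1", "--run.bs", "64", "--ema", "true"]
)

overrides = parse_unknown_args(unknown)
# -> {"run": {"lr": 0.1, "bs": 64}, "ema": True}

base = {"run": {"lr": 0.01, "bs": 32, "epochs": 100}, "ema": False}
config = partial_update_config(base, overrides)
# nested keys are merged in place; untouched keys such as "epochs" are preserved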
src/efficientvit/apps/utils/opt.py
ADDED
@@ -0,0 +1,31 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import torch

__all__ = ["REGISTERED_OPTIMIZER_DICT", "build_optimizer"]

# register optimizer here
# name: optimizer, kwargs with default values
REGISTERED_OPTIMIZER_DICT: dict[str, tuple[type, dict[str, any]]] = {
    "sgd": (torch.optim.SGD, {"momentum": 0.9, "nesterov": True}),
    "adam": (torch.optim.Adam, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}),
    "adamw": (
        torch.optim.AdamW,
        {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False},
    ),
}


def build_optimizer(
    net_params, optimizer_name: str, optimizer_params: dict or None, init_lr: float
) -> torch.optim.Optimizer:
    optimizer_class, default_params = REGISTERED_OPTIMIZER_DICT[optimizer_name]
    optimizer_params = optimizer_params or {}

    for key in default_params:
        if key in optimizer_params:
            default_params[key] = optimizer_params[key]
    optimizer = optimizer_class(net_params, init_lr, **default_params)
    return optimizer
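A minimal usage sketch, assuming the import path below. Note that only keys already present in the registered defaults can be overridden; anything else in optimizer_params is silently ignored:

import torch.nn as nn

from src.efficientvit.apps.utils.opt import build_optimizer  # assumed import path

model = nn.Linear(128, 10)

optimizer = build_optimizer(
    net_params=model.parameters(),
    optimizer_name="adamw",
    optimizer_params={"betas": (0.9, 0.95)},  # overrides the registered (0.9, 0.999)
    init_lr=1e-3,
)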
src/efficientvit/models/__init__.py
ADDED
File without changes
src/efficientvit/models/efficientvit/__init__.py
ADDED
@@ -0,0 +1,8 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

from .backbone import *
from .cls import *
from .sam import *
from .seg import *
src/efficientvit/models/efficientvit/backbone.py
ADDED
@@ -0,0 +1,372 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import torch
import torch.nn as nn

from src.efficientvit.models.nn import (ConvLayer, DSConv, EfficientViTBlock,
                                        FusedMBConv, IdentityLayer, MBConv,
                                        OpSequential, ResBlock, ResidualBlock)
from src.efficientvit.models.utils import build_kwargs_from_config

__all__ = [
    "EfficientViTBackbone",
    "efficientvit_backbone_b0",
    "efficientvit_backbone_b1",
    "efficientvit_backbone_b2",
    "efficientvit_backbone_b3",
    "EfficientViTLargeBackbone",
    "efficientvit_backbone_l0",
    "efficientvit_backbone_l1",
    "efficientvit_backbone_l2",
    "efficientvit_backbone_l3",
]


class EfficientViTBackbone(nn.Module):
    def __init__(
        self,
        width_list: list[int],
        depth_list: list[int],
        in_channels=3,
        dim=32,
        expand_ratio=4,
        norm="bn2d",
        act_func="hswish",
    ) -> None:
        super().__init__()

        self.width_list = []
        # input stem
        self.input_stem = [
            ConvLayer(
                in_channels=3,
                out_channels=width_list[0],
                stride=2,
                norm=norm,
                act_func=act_func,
            )
        ]
        for _ in range(depth_list[0]):
            block = self.build_local_block(
                in_channels=width_list[0],
                out_channels=width_list[0],
                stride=1,
                expand_ratio=1,
                norm=norm,
                act_func=act_func,
            )
            self.input_stem.append(ResidualBlock(block, IdentityLayer()))
        in_channels = width_list[0]
        self.input_stem = OpSequential(self.input_stem)
        self.width_list.append(in_channels)

        # stages
        self.stages = []
        for w, d in zip(width_list[1:3], depth_list[1:3]):
            stage = []
            for i in range(d):
                stride = 2 if i == 0 else 1
                block = self.build_local_block(
                    in_channels=in_channels,
                    out_channels=w,
                    stride=stride,
                    expand_ratio=expand_ratio,
                    norm=norm,
                    act_func=act_func,
                )
                block = ResidualBlock(block, IdentityLayer() if stride == 1 else None)
                stage.append(block)
                in_channels = w
            self.stages.append(OpSequential(stage))
            self.width_list.append(in_channels)

        for w, d in zip(width_list[3:], depth_list[3:]):
            stage = []
            block = self.build_local_block(
                in_channels=in_channels,
                out_channels=w,
                stride=2,
                expand_ratio=expand_ratio,
                norm=norm,
                act_func=act_func,
                fewer_norm=True,
            )
            stage.append(ResidualBlock(block, None))
            in_channels = w

            for _ in range(d):
                stage.append(
                    EfficientViTBlock(
                        in_channels=in_channels,
                        dim=dim,
                        expand_ratio=expand_ratio,
                        norm=norm,
                        act_func=act_func,
                    )
                )
            self.stages.append(OpSequential(stage))
            self.width_list.append(in_channels)
        self.stages = nn.ModuleList(self.stages)

    @staticmethod
    def build_local_block(
        in_channels: int,
        out_channels: int,
        stride: int,
        expand_ratio: float,
        norm: str,
        act_func: str,
        fewer_norm: bool = False,
    ) -> nn.Module:
        if expand_ratio == 1:
            block = DSConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                use_bias=(True, False) if fewer_norm else False,
                norm=(None, norm) if fewer_norm else norm,
                act_func=(act_func, None),
            )
        else:
            block = MBConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                expand_ratio=expand_ratio,
                use_bias=(True, True, False) if fewer_norm else False,
                norm=(None, None, norm) if fewer_norm else norm,
                act_func=(act_func, act_func, None),
            )
        return block

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        output_dict = {"input": x}
        output_dict["stage0"] = x = self.input_stem(x)
        for stage_id, stage in enumerate(self.stages, 1):
            output_dict["stage%d" % stage_id] = x = stage(x)
        output_dict["stage_final"] = x
        return output_dict


def efficientvit_backbone_b0(**kwargs) -> EfficientViTBackbone:
    backbone = EfficientViTBackbone(
        width_list=[8, 16, 32, 64, 128],
        depth_list=[1, 2, 2, 2, 2],
        dim=16,
        **build_kwargs_from_config(kwargs, EfficientViTBackbone),
    )
    return backbone


def efficientvit_backbone_b1(**kwargs) -> EfficientViTBackbone:
    backbone = EfficientViTBackbone(
        width_list=[16, 32, 64, 128, 256],
        depth_list=[1, 2, 3, 3, 4],
        dim=16,
        **build_kwargs_from_config(kwargs, EfficientViTBackbone),
    )
    return backbone


def efficientvit_backbone_b2(**kwargs) -> EfficientViTBackbone:
    backbone = EfficientViTBackbone(
        width_list=[24, 48, 96, 192, 384],
        depth_list=[1, 3, 4, 4, 6],
        dim=32,
        **build_kwargs_from_config(kwargs, EfficientViTBackbone),
    )
    return backbone


def efficientvit_backbone_b3(**kwargs) -> EfficientViTBackbone:
    backbone = EfficientViTBackbone(
        width_list=[32, 64, 128, 256, 512],
        depth_list=[1, 4, 6, 6, 9],
        dim=32,
        **build_kwargs_from_config(kwargs, EfficientViTBackbone),
    )
    return backbone


class EfficientViTLargeBackbone(nn.Module):
    def __init__(
        self,
        width_list: list[int],
        depth_list: list[int],
        block_list: list[str] or None = None,
        expand_list: list[float] or None = None,
        fewer_norm_list: list[bool] or None = None,
        in_channels=3,
        qkv_dim=32,
        norm="bn2d",
        act_func="gelu",
    ) -> None:
        super().__init__()
        block_list = block_list or ["res", "fmb", "fmb", "mb", "att"]
        expand_list = expand_list or [1, 4, 4, 4, 6]
        fewer_norm_list = fewer_norm_list or [False, False, False, True, True]

        self.width_list = []
        self.stages = []
        # stage 0
        stage0 = [
            ConvLayer(
                in_channels=3,
                out_channels=width_list[0],
                stride=2,
                norm=norm,
                act_func=act_func,
            )
        ]
        for _ in range(depth_list[0]):
            block = self.build_local_block(
                block=block_list[0],
                in_channels=width_list[0],
                out_channels=width_list[0],
                stride=1,
                expand_ratio=expand_list[0],
                norm=norm,
                act_func=act_func,
                fewer_norm=fewer_norm_list[0],
            )
            stage0.append(ResidualBlock(block, IdentityLayer()))
        in_channels = width_list[0]
        self.stages.append(OpSequential(stage0))
        self.width_list.append(in_channels)

        for stage_id, (w, d) in enumerate(zip(width_list[1:], depth_list[1:]), start=1):
            stage = []
            block = self.build_local_block(
                block=(
                    "mb"
                    if block_list[stage_id] not in ["mb", "fmb"]
                    else block_list[stage_id]
                ),
                in_channels=in_channels,
                out_channels=w,
                stride=2,
                expand_ratio=expand_list[stage_id] * 4,
                norm=norm,
                act_func=act_func,
                fewer_norm=fewer_norm_list[stage_id],
            )
            stage.append(ResidualBlock(block, None))
            in_channels = w

            for _ in range(d):
                if block_list[stage_id].startswith("att"):
                    stage.append(
                        EfficientViTBlock(
                            in_channels=in_channels,
                            dim=qkv_dim,
                            expand_ratio=expand_list[stage_id],
                            scales=(3,) if block_list[stage_id] == "att@3" else (5,),
                            norm=norm,
                            act_func=act_func,
                        )
                    )
                else:
                    block = self.build_local_block(
                        block=block_list[stage_id],
                        in_channels=in_channels,
                        out_channels=in_channels,
                        stride=1,
                        expand_ratio=expand_list[stage_id],
                        norm=norm,
                        act_func=act_func,
                        fewer_norm=fewer_norm_list[stage_id],
                    )
                    block = ResidualBlock(block, IdentityLayer())
                    stage.append(block)
            self.stages.append(OpSequential(stage))
            self.width_list.append(in_channels)
        self.stages = nn.ModuleList(self.stages)

    @staticmethod
    def build_local_block(
        block: str,
        in_channels: int,
        out_channels: int,
        stride: int,
        expand_ratio: float,
        norm: str,
        act_func: str,
        fewer_norm: bool = False,
    ) -> nn.Module:
        if block == "res":
            block = ResBlock(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                use_bias=(True, False) if fewer_norm else False,
                norm=(None, norm) if fewer_norm else norm,
                act_func=(act_func, None),
            )
        elif block == "fmb":
            block = FusedMBConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                expand_ratio=expand_ratio,
                use_bias=(True, False) if fewer_norm else False,
                norm=(None, norm) if fewer_norm else norm,
                act_func=(act_func, None),
            )
        elif block == "mb":
            block = MBConv(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                expand_ratio=expand_ratio,
                use_bias=(True, True, False) if fewer_norm else False,
                norm=(None, None, norm) if fewer_norm else norm,
                act_func=(act_func, act_func, None),
            )
        else:
            raise ValueError(block)
        return block

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        output_dict = {"input": x}
        for stage_id, stage in enumerate(self.stages):
            output_dict["stage%d" % stage_id] = x = stage(x)
        output_dict["stage_final"] = x
        return output_dict


def efficientvit_backbone_l0(**kwargs) -> EfficientViTLargeBackbone:
    backbone = EfficientViTLargeBackbone(
        width_list=[32, 64, 128, 256, 512],
        depth_list=[1, 1, 1, 4, 4],
        **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
    )
    return backbone


def efficientvit_backbone_l1(**kwargs) -> EfficientViTLargeBackbone:
    backbone = EfficientViTLargeBackbone(
        width_list=[32, 64, 128, 256, 512],
        depth_list=[1, 1, 1, 6, 6],
        **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
    )
    return backbone


def efficientvit_backbone_l2(**kwargs) -> EfficientViTLargeBackbone:
    backbone = EfficientViTLargeBackbone(
        width_list=[32, 64, 128, 256, 512],
        depth_list=[1, 2, 2, 8, 8],
        **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
    )
    return backbone


def efficientvit_backbone_l3(**kwargs) -> EfficientViTLargeBackbone:
    backbone = EfficientViTLargeBackbone(
        width_list=[64, 128, 256, 512, 1024],
        depth_list=[1, 2, 2, 8, 8],
        **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
    )
    return backbone
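A minimal forward-pass sketch, assuming the import path below: the backbone returns a dict of multi-scale feature maps rather than a single tensor, keyed by stage name.

import torch

from src.efficientvit.models.efficientvit.backbone import \
    efficientvit_backbone_b1  # assumed import path

backbone = efficientvit_backbone_b1()
x = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    feed_dict = backbone(x)

for name in ["input", "stage0", "stage1", "stage2", "stage3", "stage4", "stage_final"]:
    print(name, tuple(feed_dict[name].shape))
# stage0 comes out at stride 2 and each later stage halves the resolution again,
# with channel widths following width_list=[16, 32, 64, 128, 256]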
src/efficientvit/models/efficientvit/cls.py
ADDED
@@ -0,0 +1,174 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import torch
import torch.nn as nn

from src.efficientvit.models.efficientvit.backbone import (
    EfficientViTBackbone, EfficientViTLargeBackbone)
from src.efficientvit.models.nn import ConvLayer, LinearLayer, OpSequential
from src.efficientvit.models.utils import build_kwargs_from_config

__all__ = [
    "EfficientViTCls",
    ######################
    "efficientvit_cls_b0",
    "efficientvit_cls_b1",
    "efficientvit_cls_b2",
    "efficientvit_cls_b3",
    ######################
    "efficientvit_cls_l1",
    "efficientvit_cls_l2",
    "efficientvit_cls_l3",
]


class ClsHead(OpSequential):
    def __init__(
        self,
        in_channels: int,
        width_list: list[int],
        n_classes=1000,
        dropout=0.0,
        norm="bn2d",
        act_func="hswish",
        fid="stage_final",
    ):
        ops = [
            ConvLayer(in_channels, width_list[0], 1, norm=norm, act_func=act_func),
            nn.AdaptiveAvgPool2d(output_size=1),
            LinearLayer(
                width_list[0], width_list[1], False, norm="ln", act_func=act_func
            ),
            LinearLayer(width_list[1], n_classes, True, dropout, None, None),
        ]
        super().__init__(ops)

        self.fid = fid

    def forward(self, feed_dict: dict[str, torch.Tensor]) -> torch.Tensor:
        x = feed_dict[self.fid]
        return OpSequential.forward(self, x)


class EfficientViTCls(nn.Module):
    def __init__(
        self, backbone: EfficientViTBackbone or EfficientViTLargeBackbone, head: ClsHead
    ) -> None:
        super().__init__()
        self.backbone = backbone
        self.head = head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feed_dict = self.backbone(x)
        output = self.head(feed_dict)
        return output


def efficientvit_cls_b0(**kwargs) -> EfficientViTCls:
    from efficientvit.models.efficientvit.backbone import \
        efficientvit_backbone_b0

    backbone = efficientvit_backbone_b0(**kwargs)

    head = ClsHead(
        in_channels=128,
        width_list=[1024, 1280],
        **build_kwargs_from_config(kwargs, ClsHead),
    )
    model = EfficientViTCls(backbone, head)
    return model


def efficientvit_cls_b1(**kwargs) -> EfficientViTCls:
    from efficientvit.models.efficientvit.backbone import \
        efficientvit_backbone_b1

    backbone = efficientvit_backbone_b1(**kwargs)

    head = ClsHead(
        in_channels=256,
        width_list=[1536, 1600],
        **build_kwargs_from_config(kwargs, ClsHead),
    )
    model = EfficientViTCls(backbone, head)
    return model


def efficientvit_cls_b2(**kwargs) -> EfficientViTCls:
    from efficientvit.models.efficientvit.backbone import \
        efficientvit_backbone_b2

    backbone = efficientvit_backbone_b2(**kwargs)

    head = ClsHead(
        in_channels=384,
        width_list=[2304, 2560],
        **build_kwargs_from_config(kwargs, ClsHead),
    )
    model = EfficientViTCls(backbone, head)
    return model


def efficientvit_cls_b3(**kwargs) -> EfficientViTCls:
    from efficientvit.models.efficientvit.backbone import \
        efficientvit_backbone_b3

    backbone = efficientvit_backbone_b3(**kwargs)

    head = ClsHead(
        in_channels=512,
        width_list=[2304, 2560],
        **build_kwargs_from_config(kwargs, ClsHead),
    )
    model = EfficientViTCls(backbone, head)
    return model


def efficientvit_cls_l1(**kwargs) -> EfficientViTCls:
    from efficientvit.models.efficientvit.backbone import \
        efficientvit_backbone_l1

    backbone = efficientvit_backbone_l1(**kwargs)

    head = ClsHead(
        in_channels=512,
        width_list=[3072, 3200],
        act_func="gelu",
        **build_kwargs_from_config(kwargs, ClsHead),
    )
    model = EfficientViTCls(backbone, head)
    return model


def efficientvit_cls_l2(**kwargs) -> EfficientViTCls:
    from efficientvit.models.efficientvit.backbone import \
        efficientvit_backbone_l2

    backbone = efficientvit_backbone_l2(**kwargs)

    head = ClsHead(
        in_channels=512,
        width_list=[3072, 3200],
        act_func="gelu",
        **build_kwargs_from_config(kwargs, ClsHead),
    )
    model = EfficientViTCls(backbone, head)
    return model


def efficientvit_cls_l3(**kwargs) -> EfficientViTCls:
    from efficientvit.models.efficientvit.backbone import \
        efficientvit_backbone_l3

    backbone = efficientvit_backbone_l3(**kwargs)

    head = ClsHead(
        in_channels=1024,
        width_list=[6144, 6400],
        act_func="gelu",
        **build_kwargs_from_config(kwargs, ClsHead),
    )
    model = EfficientViTCls(backbone, head)
    return model
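A minimal classification sketch, assuming the import path below. Note that the factory functions import the backbone via the top-level efficientvit package name, so that package (or src/ as the project root) must be importable at runtime; n_classes and dropout are routed to ClsHead by build_kwargs_from_config.

import torch

from src.efficientvit.models.efficientvit.cls import \
    efficientvit_cls_b1  # assumed import path

model = efficientvit_cls_b1(n_classes=1000, dropout=0.1)
model.eval()

x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    logits = model(x)   # the head reads feed_dict["stage_final"] from the backbone

print(logits.shape)     # torch.Size([2, 1000])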
src/efficientvit/models/efficientvit/sam.py
ADDED
@@ -0,0 +1,653 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
|
| 2 |
+
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
|
| 3 |
+
# International Conference on Computer Vision (ICCV), 2023
|
| 4 |
+
|
| 5 |
+
import copy
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import torchvision.transforms as transforms
|
| 12 |
+
from segment_anything import SamAutomaticMaskGenerator
|
| 13 |
+
from segment_anything.modeling import (MaskDecoder, PromptEncoder,
|
| 14 |
+
TwoWayTransformer)
|
| 15 |
+
from segment_anything.modeling.mask_decoder import MaskDecoder
|
| 16 |
+
from segment_anything.modeling.prompt_encoder import PromptEncoder
|
| 17 |
+
from segment_anything.utils.amg import build_all_layer_point_grids
|
| 18 |
+
from segment_anything.utils.transforms import ResizeLongestSide
|
| 19 |
+
from torchvision.transforms.functional import resize, to_pil_image
|
| 20 |
+
|
| 21 |
+
from src.efficientvit.models.efficientvit.backbone import (
|
| 22 |
+
EfficientViTBackbone, EfficientViTLargeBackbone)
|
| 23 |
+
from src.efficientvit.models.nn import (ConvLayer, DAGBlock, FusedMBConv,
|
| 24 |
+
IdentityLayer, MBConv, OpSequential,
|
| 25 |
+
ResBlock, ResidualBlock, UpSampleLayer,
|
| 26 |
+
build_norm)
|
| 27 |
+
from src.efficientvit.models.utils import build_kwargs_from_config, get_device
|
| 28 |
+
|
| 29 |
+
__all__ = [
|
| 30 |
+
"SamPad",
|
| 31 |
+
"SamResize",
|
| 32 |
+
"SamNeck",
|
| 33 |
+
"EfficientViTSamImageEncoder",
|
| 34 |
+
"EfficientViTSam",
|
| 35 |
+
"EfficientViTSamPredictor",
|
| 36 |
+
"EfficientViTSamAutomaticMaskGenerator",
|
| 37 |
+
"efficientvit_sam_l0",
|
| 38 |
+
"efficientvit_sam_l1",
|
| 39 |
+
"efficientvit_sam_l2",
|
| 40 |
+
"efficientvit_sam_xl0",
|
| 41 |
+
"efficientvit_sam_xl1",
|
| 42 |
+
]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class SamPad:
|
| 46 |
+
def __init__(self, size: int, fill: float = 0, pad_mode="corner") -> None:
|
| 47 |
+
self.size = size
|
| 48 |
+
self.fill = fill
|
| 49 |
+
self.pad_mode = pad_mode
|
| 50 |
+
|
| 51 |
+
def __call__(self, image: torch.Tensor) -> torch.Tensor:
|
| 52 |
+
h, w = image.shape[-2:]
|
| 53 |
+
th, tw = self.size, self.size
|
| 54 |
+
assert th >= h and tw >= w
|
| 55 |
+
if self.pad_mode == "corner":
|
| 56 |
+
image = F.pad(image, (0, tw - w, 0, th - h), value=self.fill)
|
| 57 |
+
else:
|
| 58 |
+
raise NotImplementedError
|
| 59 |
+
return image
|
| 60 |
+
|
| 61 |
+
def __repr__(self) -> str:
|
| 62 |
+
return f"{type(self).__name__}(size={self.size},mode={self.pad_mode},fill={self.fill})"
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class SamResize:
|
| 66 |
+
def __init__(self, size: int) -> None:
|
| 67 |
+
self.size = size
|
| 68 |
+
|
| 69 |
+
def __call__(self, image: np.ndarray) -> np.ndarray:
|
| 70 |
+
h, w, _ = image.shape
|
| 71 |
+
long_side = max(h, w)
|
| 72 |
+
if long_side != self.size:
|
| 73 |
+
return self.apply_image(image)
|
| 74 |
+
else:
|
| 75 |
+
return image
|
| 76 |
+
|
| 77 |
+
def apply_image(self, image: np.ndarray) -> np.ndarray:
|
| 78 |
+
"""
|
| 79 |
+
Expects a numpy array with shape HxWxC in uint8 format.
|
| 80 |
+
"""
|
| 81 |
+
target_size = self.get_preprocess_shape(
|
| 82 |
+
image.shape[0], image.shape[1], self.size
|
| 83 |
+
)
|
| 84 |
+
return np.array(resize(to_pil_image(image), target_size))
|
| 85 |
+
|
| 86 |
+
@staticmethod
|
| 87 |
+
def get_preprocess_shape(
|
| 88 |
+
oldh: int, oldw: int, long_side_length: int
|
| 89 |
+
) -> tuple[int, int]:
|
| 90 |
+
"""
|
| 91 |
+
Compute the output size given input size and target long side length.
|
| 92 |
+
"""
|
| 93 |
+
scale = long_side_length * 1.0 / max(oldh, oldw)
|
| 94 |
+
newh, neww = oldh * scale, oldw * scale
|
| 95 |
+
neww = int(neww + 0.5)
|
| 96 |
+
newh = int(newh + 0.5)
|
| 97 |
+
return (newh, neww)
|
| 98 |
+
|
| 99 |
+
def __repr__(self) -> str:
|
| 100 |
+
return f"{type(self).__name__}(size={self.size})"
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class SamNeck(DAGBlock):
|
| 104 |
+
def __init__(
|
| 105 |
+
self,
|
| 106 |
+
fid_list: list[str],
|
| 107 |
+
in_channel_list: list[int],
|
| 108 |
+
head_width: int,
|
| 109 |
+
head_depth: int,
|
| 110 |
+
expand_ratio: float,
|
| 111 |
+
middle_op: str,
|
| 112 |
+
out_dim: int = 256,
|
| 113 |
+
norm="bn2d",
|
| 114 |
+
act_func="gelu",
|
| 115 |
+
):
|
| 116 |
+
inputs = {}
|
| 117 |
+
for fid, in_channel in zip(fid_list, in_channel_list):
|
| 118 |
+
inputs[fid] = OpSequential(
|
| 119 |
+
[
|
| 120 |
+
ConvLayer(in_channel, head_width, 1, norm=norm, act_func=None),
|
| 121 |
+
UpSampleLayer(size=(64, 64)),
|
| 122 |
+
]
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
middle = []
|
| 126 |
+
for _ in range(head_depth):
|
| 127 |
+
if middle_op == "mb":
|
| 128 |
+
block = MBConv(
|
| 129 |
+
head_width,
|
| 130 |
+
head_width,
|
| 131 |
+
expand_ratio=expand_ratio,
|
| 132 |
+
norm=norm,
|
| 133 |
+
act_func=(act_func, act_func, None),
|
| 134 |
+
)
|
| 135 |
+
elif middle_op == "fmb":
|
| 136 |
+
block = FusedMBConv(
|
| 137 |
+
head_width,
|
| 138 |
+
head_width,
|
| 139 |
+
expand_ratio=expand_ratio,
|
| 140 |
+
norm=norm,
|
| 141 |
+
act_func=(act_func, None),
|
| 142 |
+
)
|
| 143 |
+
elif middle_op == "res":
|
| 144 |
+
block = ResBlock(
|
| 145 |
+
head_width,
|
| 146 |
+
head_width,
|
| 147 |
+
expand_ratio=expand_ratio,
|
| 148 |
+
norm=norm,
|
| 149 |
+
act_func=(act_func, None),
|
| 150 |
+
)
|
| 151 |
+
else:
|
| 152 |
+
raise NotImplementedError
|
| 153 |
+
middle.append(ResidualBlock(block, IdentityLayer()))
|
| 154 |
+
middle = OpSequential(middle)
|
| 155 |
+
|
| 156 |
+
outputs = {
|
| 157 |
+
"sam_encoder": OpSequential(
|
| 158 |
+
[
|
| 159 |
+
ConvLayer(
|
| 160 |
+
head_width,
|
| 161 |
+
out_dim,
|
| 162 |
+
1,
|
| 163 |
+
use_bias=True,
|
| 164 |
+
norm=None,
|
| 165 |
+
act_func=None,
|
| 166 |
+
),
|
| 167 |
+
]
|
| 168 |
+
)
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
super(SamNeck, self).__init__(
|
| 172 |
+
inputs, "add", None, middle=middle, outputs=outputs
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class EfficientViTSamImageEncoder(nn.Module):
|
| 177 |
+
def __init__(
|
| 178 |
+
self, backbone: EfficientViTBackbone or EfficientViTLargeBackbone, neck: SamNeck
|
| 179 |
+
):
|
| 180 |
+
super().__init__()
|
| 181 |
+
self.backbone = backbone
|
| 182 |
+
self.neck = neck
|
| 183 |
+
|
| 184 |
+
self.norm = build_norm("ln2d", 256)
|
| 185 |
+
|
| 186 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 187 |
+
feed_dict = self.backbone(x)
|
| 188 |
+
feed_dict = self.neck(feed_dict)
|
| 189 |
+
|
| 190 |
+
output = feed_dict["sam_encoder"]
|
| 191 |
+
output = self.norm(output)
|
| 192 |
+
return output
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class EfficientViTSam(nn.Module):
|
| 196 |
+
mask_threshold: float = 0.0
|
| 197 |
+
image_format: str = "RGB"
|
| 198 |
+
|
| 199 |
+
def __init__(
|
| 200 |
+
self,
|
| 201 |
+
image_encoder: EfficientViTSamImageEncoder,
|
| 202 |
+
prompt_encoder: PromptEncoder,
|
| 203 |
+
mask_decoder: MaskDecoder,
|
| 204 |
+
image_size: tuple[int, int] = (1024, 512),
|
| 205 |
+
) -> None:
|
| 206 |
+
super().__init__()
|
| 207 |
+
self.image_encoder = image_encoder
|
| 208 |
+
self.prompt_encoder = prompt_encoder
|
| 209 |
+
self.mask_decoder = mask_decoder
|
| 210 |
+
|
| 211 |
+
self.image_size = image_size
|
| 212 |
+
|
| 213 |
+
self.transform = transforms.Compose(
|
| 214 |
+
[
|
| 215 |
+
SamResize(self.image_size[1]),
|
| 216 |
+
transforms.ToTensor(),
|
| 217 |
+
transforms.Normalize(
|
| 218 |
+
mean=[123.675 / 255, 116.28 / 255, 103.53 / 255],
|
| 219 |
+
std=[58.395 / 255, 57.12 / 255, 57.375 / 255],
|
| 220 |
+
),
|
| 221 |
+
SamPad(self.image_size[1]),
|
| 222 |
+
]
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
def postprocess_masks(
|
| 226 |
+
self,
|
| 227 |
+
masks: torch.Tensor,
|
| 228 |
+
input_size: tuple[int, ...],
|
| 229 |
+
original_size: tuple[int, ...],
|
| 230 |
+
) -> torch.Tensor:
|
| 231 |
+
masks = F.interpolate(
|
| 232 |
+
masks,
|
| 233 |
+
(self.image_size[0], self.image_size[0]),
|
| 234 |
+
mode="bilinear",
|
| 235 |
+
align_corners=False,
|
| 236 |
+
)
|
| 237 |
+
masks = masks[..., : input_size[0], : input_size[1]]
|
| 238 |
+
masks = F.interpolate(
|
| 239 |
+
masks, original_size, mode="bilinear", align_corners=False
|
| 240 |
+
)
|
| 241 |
+
return masks
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class EfficientViTSamPredictor:
|
| 245 |
+
def __init__(self, sam_model: EfficientViTSam) -> None:
|
| 246 |
+
self.model = sam_model
|
| 247 |
+
self.reset_image()
|
| 248 |
+
|
| 249 |
+
@property
|
| 250 |
+
def transform(self):
|
| 251 |
+
return self
|
| 252 |
+
|
| 253 |
+
@property
|
| 254 |
+
def device(self):
|
| 255 |
+
return get_device(self.model)
|
| 256 |
+
|
| 257 |
+
def reset_image(self) -> None:
|
| 258 |
+
self.is_image_set = False
|
| 259 |
+
self.features = None
|
| 260 |
+
self.original_size = None
|
| 261 |
+
self.input_size = None
|
| 262 |
+
|
| 263 |
+
def apply_coords(self, coords: np.ndarray, im_size=None) -> np.ndarray:
|
| 264 |
+
old_h, old_w = self.original_size
|
| 265 |
+
new_h, new_w = self.input_size
|
| 266 |
+
coords = copy.deepcopy(coords).astype(float)
|
| 267 |
+
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
| 268 |
+
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
| 269 |
+
return coords
|
| 270 |
+
|
| 271 |
+
def apply_boxes(self, boxes: np.ndarray, im_size=None) -> np.ndarray:
|
| 272 |
+
boxes = self.apply_coords(boxes.reshape(-1, 2, 2))
|
| 273 |
+
return boxes.reshape(-1, 4)
|
| 274 |
+
|
| 275 |
+
@torch.inference_mode()
|
| 276 |
+
def set_image(self, image: np.ndarray, image_format: str = "RGB") -> None:
|
| 277 |
+
assert image_format in [
|
| 278 |
+
"RGB",
|
| 279 |
+
"BGR",
|
| 280 |
+
], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
|
| 281 |
+
if image_format != self.model.image_format:
|
| 282 |
+
image = image[..., ::-1]
|
| 283 |
+
|
| 284 |
+
self.reset_image()
|
| 285 |
+
|
| 286 |
+
self.original_size = image.shape[:2]
|
| 287 |
+
self.input_size = ResizeLongestSide.get_preprocess_shape(
|
| 288 |
+
*self.original_size, long_side_length=self.model.image_size[0]
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
torch_data = (
|
| 292 |
+
self.model.transform(image).unsqueeze(dim=0).to(get_device(self.model))
|
| 293 |
+
)
|
| 294 |
+
self.features = self.model.image_encoder(torch_data)
|
| 295 |
+
self.is_image_set = True
|
| 296 |
+
|
| 297 |
+
def predict(
|
| 298 |
+
self,
|
| 299 |
+
point_coords: np.ndarray or None = None,
|
| 300 |
+
point_labels: np.ndarray or None = None,
|
| 301 |
+
box: np.ndarray or None = None,
|
| 302 |
+
mask_input: np.ndarray or None = None,
|
| 303 |
+
multimask_output: bool = True,
|
| 304 |
+
return_logits: bool = False,
|
| 305 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 306 |
+
"""
|
| 307 |
+
Predict masks for the given input prompts, using the currently set image.
|
| 308 |
+
|
| 309 |
+
Arguments:
|
| 310 |
+
point_coords (np.ndarray or None): A Nx2 array of point prompts to the
|
| 311 |
+
model. Each point is in (X,Y) in pixels.
|
| 312 |
+
point_labels (np.ndarray or None): A length N array of labels for the
|
| 313 |
+
point prompts. 1 indicates a foreground point and 0 indicates a
|
| 314 |
+
background point.
|
| 315 |
+
box (np.ndarray or None): A length 4 array given a box prompt to the
|
| 316 |
+
model, in XYXY format.
|
| 317 |
+
mask_input (np.ndarray): A low resolution mask input to the model, typically
|
| 318 |
+
coming from a previous prediction iteration. Has form 1xHxW, where
|
| 319 |
+
for SAM, H=W=256.
|
| 320 |
+
multimask_output (bool): If true, the model will return three masks.
|
| 321 |
+
For ambiguous input prompts (such as a single click), this will often
|
| 322 |
+
produce better masks than a single prediction. If only a single
|
| 323 |
+
mask is needed, the model's predicted quality score can be used
|
| 324 |
+
to select the best mask. For non-ambiguous prompts, such as multiple
|
| 325 |
+
input prompts, multimask_output=False can give better results.
|
| 326 |
+
return_logits (bool): If true, returns un-thresholded masks logits
|
| 327 |
+
instead of a binary mask.
|
| 328 |
+
|
| 329 |
+
Returns:
|
| 330 |
+
(np.ndarray): The output masks in CxHxW format, where C is the
|
| 331 |
+
number of masks, and (H, W) is the original image size.
|
| 332 |
+
(np.ndarray): An array of length C containing the model's
|
| 333 |
+
predictions for the quality of each mask.
|
| 334 |
+
(np.ndarray): An array of shape CxHxW, where C is the number
|
| 335 |
+
of masks and H=W=256. These low resolution logits can be passed to
|
| 336 |
+
a subsequent iteration as mask input.
|
| 337 |
+
"""
|
| 338 |
+
if not self.is_image_set:
|
| 339 |
+
raise RuntimeError(
|
| 340 |
+
"An image must be set with .set_image(...) before mask prediction."
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
device = get_device(self.model)
|
| 344 |
+
# Transform input prompts
|
| 345 |
+
coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
|
| 346 |
+
if point_coords is not None:
|
| 347 |
+
assert (
|
| 348 |
+
point_labels is not None
|
| 349 |
+
), "point_labels must be supplied if point_coords is supplied."
|
| 350 |
+
point_coords = self.apply_coords(point_coords)
|
| 351 |
+
coords_torch = torch.as_tensor(
|
| 352 |
+
point_coords, dtype=torch.float, device=device
|
| 353 |
+
)
|
| 354 |
+
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
|
| 355 |
+
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
|
| 356 |
+
if box is not None:
|
| 357 |
+
box = self.apply_boxes(box)
|
| 358 |
+
box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
|
| 359 |
+
box_torch = box_torch[None, :]
|
| 360 |
+
if mask_input is not None:
|
| 361 |
+
mask_input_torch = torch.as_tensor(
|
| 362 |
+
mask_input, dtype=torch.float, device=device
|
| 363 |
+
)
|
| 364 |
+
mask_input_torch = mask_input_torch[None, :, :, :]
|
| 365 |
+
|
| 366 |
+
masks, iou_predictions, low_res_masks = self.predict_torch(
|
| 367 |
+
coords_torch,
|
| 368 |
+
labels_torch,
|
| 369 |
+
box_torch,
|
| 370 |
+
mask_input_torch,
|
| 371 |
+
multimask_output,
|
| 372 |
+
return_logits=return_logits,
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
masks = masks[0].detach().cpu().numpy()
|
| 376 |
+
iou_predictions = iou_predictions[0].detach().cpu().numpy()
|
| 377 |
+
low_res_masks = low_res_masks[0].detach().cpu().numpy()
|
| 378 |
+
return masks, iou_predictions, low_res_masks
|
| 379 |
+
|
| 380 |
+
@torch.inference_mode()
|
| 381 |
+
def predict_torch(
|
| 382 |
+
self,
|
| 383 |
+
point_coords: torch.Tensor or None = None,
|
| 384 |
+
point_labels: torch.Tensor or None = None,
|
| 385 |
+
boxes: torch.Tensor or None = None,
|
| 386 |
+
mask_input: torch.Tensor or None = None,
|
| 387 |
+
multimask_output: bool = True,
|
| 388 |
+
return_logits: bool = False,
|
| 389 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 390 |
+
"""
|
| 391 |
+
Predict masks for the given input prompts, using the currently set image.
|
| 392 |
+
Input prompts are batched torch tensors and are expected to already be
|
| 393 |
+
transformed to the input frame using ResizeLongestSide.
|
| 394 |
+
|
| 395 |
+
Arguments:
|
| 396 |
+
point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
|
| 397 |
+
model. Each point is in (X,Y) in pixels.
|
| 398 |
+
point_labels (torch.Tensor or None): A BxN array of labels for the
|
| 399 |
+
point prompts. 1 indicates a foreground point and 0 indicates a
|
| 400 |
+
background point.
|
| 401 |
+
box (np.ndarray or None): A Bx4 array given a box prompt to the
|
| 402 |
+
model, in XYXY format.
|
| 403 |
          mask_input (np.ndarray): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form Bx1xHxW, where
            for SAM, H=W=256. Masks returned by a previous iteration of the
            predict method do not need further transformation.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded masks logits
            instead of a binary mask.

        Returns:
          (torch.Tensor): The output masks in BxCxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (torch.Tensor): An array of shape BxC containing the model's
            predictions for the quality of each mask.
          (torch.Tensor): An array of shape BxCxHxW, where C is the number
            of masks and H=W=256. These low res logits can be passed to
            a subsequent iteration as mask input.
        """
        if not self.is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )

        if point_coords is not None:
            points = (point_coords, point_labels)
        else:
            points = None

        # Embed prompts
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=mask_input,
        )

        # Predict masks
        low_res_masks, iou_predictions = self.model.mask_decoder(
            image_embeddings=self.features,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )

        # Upscale the masks to the original image resolution
        masks = self.model.postprocess_masks(
            low_res_masks, self.input_size, self.original_size
        )

        if not return_logits:
            masks = masks > self.model.mask_threshold

        return masks, iou_predictions, low_res_masks


class EfficientViTSamAutomaticMaskGenerator(SamAutomaticMaskGenerator):
    def __init__(
        self,
        model: EfficientViTSam,
        points_per_side: int or None = 32,
        points_per_batch: int = 64,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.95,
        stability_score_offset: float = 1.0,
        box_nms_thresh: float = 0.7,
        crop_n_layers: int = 0,
        crop_nms_thresh: float = 0.7,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: list[np.ndarray] or None = None,
        min_mask_region_area: int = 0,
        output_mode: str = "binary_mask",
    ) -> None:
        assert (points_per_side is None) != (
            point_grids is None
        ), "Exactly one of points_per_side or point_grid must be provided."
        if points_per_side is not None:
            self.point_grids = build_all_layer_point_grids(
                points_per_side,
                crop_n_layers,
                crop_n_points_downscale_factor,
            )
        elif point_grids is not None:
            self.point_grids = point_grids
        else:
            raise ValueError("Can't have both points_per_side and point_grid be None.")

        assert output_mode in [
            "binary_mask",
            "uncompressed_rle",
            "coco_rle",
        ], f"Unknown output_mode {output_mode}."
        if output_mode == "coco_rle":
            from pycocotools import mask as mask_utils  # type: ignore # noqa: F401

        if min_mask_region_area > 0:
            import cv2  # type: ignore # noqa: F401

        self.predictor = EfficientViTSamPredictor(model)
        self.points_per_batch = points_per_batch
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_thresh = stability_score_thresh
        self.stability_score_offset = stability_score_offset
        self.box_nms_thresh = box_nms_thresh
        self.crop_n_layers = crop_n_layers
        self.crop_nms_thresh = crop_nms_thresh
        self.crop_overlap_ratio = crop_overlap_ratio
        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
        self.min_mask_region_area = min_mask_region_area
        self.output_mode = output_mode


def build_efficientvit_sam(
    image_encoder: EfficientViTSamImageEncoder, image_size: int
) -> EfficientViTSam:
    return EfficientViTSam(
        image_encoder=image_encoder,
        prompt_encoder=PromptEncoder(
            embed_dim=256,
            image_embedding_size=(64, 64),
            input_image_size=(1024, 1024),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=256,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=256,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        image_size=(1024, image_size),
    )


def efficientvit_sam_l0(image_size: int = 512, **kwargs) -> EfficientViTSam:
    from efficientvit.models.efficientvit.backbone import efficientvit_backbone_l0

    backbone = efficientvit_backbone_l0(**kwargs)

    neck = SamNeck(
        fid_list=["stage4", "stage3", "stage2"],
        in_channel_list=[512, 256, 128],
        head_width=256,
        head_depth=4,
        expand_ratio=1,
        middle_op="fmb",
    )

    image_encoder = EfficientViTSamImageEncoder(backbone, neck)
    return build_efficientvit_sam(image_encoder, image_size)


def efficientvit_sam_l1(image_size: int = 512, **kwargs) -> EfficientViTSam:
    from efficientvit.models.efficientvit.backbone import efficientvit_backbone_l1

    backbone = efficientvit_backbone_l1(**kwargs)

    neck = SamNeck(
        fid_list=["stage4", "stage3", "stage2"],
        in_channel_list=[512, 256, 128],
        head_width=256,
        head_depth=8,
        expand_ratio=1,
        middle_op="fmb",
    )

    image_encoder = EfficientViTSamImageEncoder(backbone, neck)
    return build_efficientvit_sam(image_encoder, image_size)


def efficientvit_sam_l2(image_size: int = 512, **kwargs) -> EfficientViTSam:
    from efficientvit.models.efficientvit.backbone import efficientvit_backbone_l2

    backbone = efficientvit_backbone_l2(**kwargs)

    neck = SamNeck(
        fid_list=["stage4", "stage3", "stage2"],
        in_channel_list=[512, 256, 128],
        head_width=256,
        head_depth=12,
        expand_ratio=1,
        middle_op="fmb",
    )

    image_encoder = EfficientViTSamImageEncoder(backbone, neck)
    return build_efficientvit_sam(image_encoder, image_size)


def efficientvit_sam_xl0(image_size: int = 1024, **kwargs) -> EfficientViTSam:
    from efficientvit.models.efficientvit.backbone import EfficientViTLargeBackbone

    backbone = EfficientViTLargeBackbone(
        width_list=[32, 64, 128, 256, 512, 1024],
        depth_list=[0, 1, 1, 2, 3, 3],
        block_list=["res", "fmb", "fmb", "fmb", "att@3", "att@3"],
        expand_list=[1, 4, 4, 4, 4, 6],
        fewer_norm_list=[False, False, False, False, True, True],
        **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
    )

    neck = SamNeck(
        fid_list=["stage5", "stage4", "stage3"],
        in_channel_list=[1024, 512, 256],
        head_width=256,
        head_depth=6,
        expand_ratio=4,
        middle_op="fmb",
    )

    image_encoder = EfficientViTSamImageEncoder(backbone, neck)
    return build_efficientvit_sam(image_encoder, image_size)


def efficientvit_sam_xl1(image_size: int = 1024, **kwargs) -> EfficientViTSam:
    from src.efficientvit.models.efficientvit.backbone import EfficientViTLargeBackbone

    backbone = EfficientViTLargeBackbone(
        width_list=[32, 64, 128, 256, 512, 1024],
        depth_list=[1, 2, 2, 4, 6, 6],
        block_list=["res", "fmb", "fmb", "fmb", "att@3", "att@3"],
        expand_list=[1, 4, 4, 4, 4, 6],
        fewer_norm_list=[False, False, False, False, True, True],
        **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
    )

    neck = SamNeck(
        fid_list=["stage5", "stage4", "stage3"],
        in_channel_list=[1024, 512, 256],
        head_width=256,
        head_depth=12,
        expand_ratio=4,
        middle_op="fmb",
    )

    image_encoder = EfficientViTSamImageEncoder(backbone, neck)
    return build_efficientvit_sam(image_encoder, image_size)
|
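The builders above only assemble the architecture; what follows is a minimal, illustrative sketch of the prompt-based predictor, not part of the diff. The checkpoint name is a placeholder, and the array types follow the docstring of predict above (if this variant expects torch tensors rather than NumPy arrays, convert accordingly); the generator's generate entry point is assumed to be inherited from SamAutomaticMaskGenerator.

import numpy as np
import torch

sam = efficientvit_sam_l0(image_size=512)
# sam.load_state_dict(torch.load("efficientvit_sam_l0.pt", map_location="cpu"))  # hypothetical weight file
sam = sam.eval()

predictor = EfficientViTSamPredictor(sam)
image = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for a real RGB image
predictor.set_image(image)

masks, iou_predictions, low_res_masks = predictor.predict(
    point_coords=np.array([[320, 240]]),  # one foreground click (x, y)
    point_labels=np.array([1]),
    multimask_output=True,
)
best = int(iou_predictions.reshape(-1).argmax())  # keep the mask the model scores highest

# Prompt-free, whole-image masks go through the generator instead:
# EfficientViTSamAutomaticMaskGenerator(sam).generate(image)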
src/efficientvit/models/efficientvit/seg.py
ADDED
|
@@ -0,0 +1,355 @@
| 1 |
+
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
|
| 2 |
+
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
|
| 3 |
+
# International Conference on Computer Vision (ICCV), 2023
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
from src.efficientvit.models.efficientvit.backbone import (
|
| 9 |
+
EfficientViTBackbone, EfficientViTLargeBackbone)
|
| 10 |
+
from src.efficientvit.models.nn import (ConvLayer, DAGBlock, FusedMBConv,
|
| 11 |
+
IdentityLayer, MBConv, OpSequential,
|
| 12 |
+
ResidualBlock, UpSampleLayer)
|
| 13 |
+
from src.efficientvit.models.utils import build_kwargs_from_config
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"EfficientViTSeg",
|
| 17 |
+
"efficientvit_seg_b0",
|
| 18 |
+
"efficientvit_seg_b1",
|
| 19 |
+
"efficientvit_seg_b2",
|
| 20 |
+
"efficientvit_seg_b3",
|
| 21 |
+
"efficientvit_seg_l1",
|
| 22 |
+
"efficientvit_seg_l2",
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class SegHead(DAGBlock):
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
fid_list: list[str],
|
| 30 |
+
in_channel_list: list[int],
|
| 31 |
+
stride_list: list[int],
|
| 32 |
+
head_stride: int,
|
| 33 |
+
head_width: int,
|
| 34 |
+
head_depth: int,
|
| 35 |
+
expand_ratio: float,
|
| 36 |
+
middle_op: str,
|
| 37 |
+
final_expand: float or None,
|
| 38 |
+
n_classes: int,
|
| 39 |
+
dropout=0,
|
| 40 |
+
norm="bn2d",
|
| 41 |
+
act_func="hswish",
|
| 42 |
+
):
|
| 43 |
+
inputs = {}
|
| 44 |
+
for fid, in_channel, stride in zip(fid_list, in_channel_list, stride_list):
|
| 45 |
+
factor = stride // head_stride
|
| 46 |
+
if factor == 1:
|
| 47 |
+
inputs[fid] = ConvLayer(
|
| 48 |
+
in_channel, head_width, 1, norm=norm, act_func=None
|
| 49 |
+
)
|
| 50 |
+
else:
|
| 51 |
+
inputs[fid] = OpSequential(
|
| 52 |
+
[
|
| 53 |
+
ConvLayer(in_channel, head_width, 1, norm=norm, act_func=None),
|
| 54 |
+
UpSampleLayer(factor=factor),
|
| 55 |
+
]
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
middle = []
|
| 59 |
+
for _ in range(head_depth):
|
| 60 |
+
if middle_op == "mbconv":
|
| 61 |
+
block = MBConv(
|
| 62 |
+
head_width,
|
| 63 |
+
head_width,
|
| 64 |
+
expand_ratio=expand_ratio,
|
| 65 |
+
norm=norm,
|
| 66 |
+
act_func=(act_func, act_func, None),
|
| 67 |
+
)
|
| 68 |
+
elif middle_op == "fmbconv":
|
| 69 |
+
block = FusedMBConv(
|
| 70 |
+
head_width,
|
| 71 |
+
head_width,
|
| 72 |
+
expand_ratio=expand_ratio,
|
| 73 |
+
norm=norm,
|
| 74 |
+
act_func=(act_func, None),
|
| 75 |
+
)
|
| 76 |
+
else:
|
| 77 |
+
raise NotImplementedError
|
| 78 |
+
middle.append(ResidualBlock(block, IdentityLayer()))
|
| 79 |
+
middle = OpSequential(middle)
|
| 80 |
+
|
| 81 |
+
outputs = {
|
| 82 |
+
"segout": OpSequential(
|
| 83 |
+
[
|
| 84 |
+
(
|
| 85 |
+
None
|
| 86 |
+
if final_expand is None
|
| 87 |
+
else ConvLayer(
|
| 88 |
+
head_width,
|
| 89 |
+
head_width * final_expand,
|
| 90 |
+
1,
|
| 91 |
+
norm=norm,
|
| 92 |
+
act_func=act_func,
|
| 93 |
+
)
|
| 94 |
+
),
|
| 95 |
+
ConvLayer(
|
| 96 |
+
head_width * (final_expand or 1),
|
| 97 |
+
n_classes,
|
| 98 |
+
1,
|
| 99 |
+
use_bias=True,
|
| 100 |
+
dropout=dropout,
|
| 101 |
+
norm=None,
|
| 102 |
+
act_func=None,
|
| 103 |
+
),
|
| 104 |
+
]
|
| 105 |
+
)
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
super(SegHead, self).__init__(
|
| 109 |
+
inputs, "add", None, middle=middle, outputs=outputs
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class EfficientViTSeg(nn.Module):
|
| 114 |
+
def __init__(
|
| 115 |
+
self, backbone: EfficientViTBackbone or EfficientViTLargeBackbone, head: SegHead
|
| 116 |
+
) -> None:
|
| 117 |
+
super().__init__()
|
| 118 |
+
self.backbone = backbone
|
| 119 |
+
self.head = head
|
| 120 |
+
|
| 121 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 122 |
+
feed_dict = self.backbone(x)
|
| 123 |
+
feed_dict = self.head(feed_dict)
|
| 124 |
+
|
| 125 |
+
return feed_dict["segout"]
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def efficientvit_seg_b0(dataset: str, **kwargs) -> EfficientViTSeg:
|
| 129 |
+
from efficientvit.models.efficientvit.backbone import \
|
| 130 |
+
efficientvit_backbone_b0
|
| 131 |
+
|
| 132 |
+
backbone = efficientvit_backbone_b0(**kwargs)
|
| 133 |
+
|
| 134 |
+
if dataset == "cityscapes":
|
| 135 |
+
head = SegHead(
|
| 136 |
+
fid_list=["stage4", "stage3", "stage2"],
|
| 137 |
+
in_channel_list=[128, 64, 32],
|
| 138 |
+
stride_list=[32, 16, 8],
|
| 139 |
+
head_stride=8,
|
| 140 |
+
head_width=32,
|
| 141 |
+
head_depth=1,
|
| 142 |
+
expand_ratio=4,
|
| 143 |
+
middle_op="mbconv",
|
| 144 |
+
final_expand=4,
|
| 145 |
+
n_classes=19,
|
| 146 |
+
**build_kwargs_from_config(kwargs, SegHead),
|
| 147 |
+
)
|
| 148 |
+
else:
|
| 149 |
+
raise NotImplementedError
|
| 150 |
+
model = EfficientViTSeg(backbone, head)
|
| 151 |
+
return model
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def efficientvit_seg_b1(dataset: str, **kwargs) -> EfficientViTSeg:
|
| 155 |
+
from efficientvit.models.efficientvit.backbone import \
|
| 156 |
+
efficientvit_backbone_b1
|
| 157 |
+
|
| 158 |
+
backbone = efficientvit_backbone_b1(**kwargs)
|
| 159 |
+
|
| 160 |
+
if dataset == "cityscapes":
|
| 161 |
+
head = SegHead(
|
| 162 |
+
fid_list=["stage4", "stage3", "stage2"],
|
| 163 |
+
in_channel_list=[256, 128, 64],
|
| 164 |
+
stride_list=[32, 16, 8],
|
| 165 |
+
head_stride=8,
|
| 166 |
+
head_width=64,
|
| 167 |
+
head_depth=3,
|
| 168 |
+
expand_ratio=4,
|
| 169 |
+
middle_op="mbconv",
|
| 170 |
+
final_expand=4,
|
| 171 |
+
n_classes=19,
|
| 172 |
+
**build_kwargs_from_config(kwargs, SegHead),
|
| 173 |
+
)
|
| 174 |
+
elif dataset == "ade20k":
|
| 175 |
+
head = SegHead(
|
| 176 |
+
fid_list=["stage4", "stage3", "stage2"],
|
| 177 |
+
in_channel_list=[256, 128, 64],
|
| 178 |
+
stride_list=[32, 16, 8],
|
| 179 |
+
head_stride=8,
|
| 180 |
+
head_width=64,
|
| 181 |
+
head_depth=3,
|
| 182 |
+
expand_ratio=4,
|
| 183 |
+
middle_op="mbconv",
|
| 184 |
+
final_expand=None,
|
| 185 |
+
n_classes=150,
|
| 186 |
+
**build_kwargs_from_config(kwargs, SegHead),
|
| 187 |
+
)
|
| 188 |
+
else:
|
| 189 |
+
raise NotImplementedError
|
| 190 |
+
model = EfficientViTSeg(backbone, head)
|
| 191 |
+
return model
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def efficientvit_seg_b2(dataset: str, **kwargs) -> EfficientViTSeg:
|
| 195 |
+
from efficientvit.models.efficientvit.backbone import \
|
| 196 |
+
efficientvit_backbone_b2
|
| 197 |
+
|
| 198 |
+
backbone = efficientvit_backbone_b2(**kwargs)
|
| 199 |
+
|
| 200 |
+
if dataset == "cityscapes":
|
| 201 |
+
head = SegHead(
|
| 202 |
+
fid_list=["stage4", "stage3", "stage2"],
|
| 203 |
+
in_channel_list=[384, 192, 96],
|
| 204 |
+
stride_list=[32, 16, 8],
|
| 205 |
+
head_stride=8,
|
| 206 |
+
head_width=96,
|
| 207 |
+
head_depth=3,
|
| 208 |
+
expand_ratio=4,
|
| 209 |
+
middle_op="mbconv",
|
| 210 |
+
final_expand=4,
|
| 211 |
+
n_classes=19,
|
| 212 |
+
**build_kwargs_from_config(kwargs, SegHead),
|
| 213 |
+
)
|
| 214 |
+
elif dataset == "ade20k":
|
| 215 |
+
head = SegHead(
|
| 216 |
+
fid_list=["stage4", "stage3", "stage2"],
|
| 217 |
+
in_channel_list=[384, 192, 96],
|
| 218 |
+
stride_list=[32, 16, 8],
|
| 219 |
+
head_stride=8,
|
| 220 |
+
head_width=96,
|
| 221 |
+
head_depth=3,
|
| 222 |
+
expand_ratio=4,
|
| 223 |
+
middle_op="mbconv",
|
| 224 |
+
final_expand=None,
|
| 225 |
+
n_classes=150,
|
| 226 |
+
**build_kwargs_from_config(kwargs, SegHead),
|
| 227 |
+
)
|
| 228 |
+
else:
|
| 229 |
+
raise NotImplementedError
|
| 230 |
+
model = EfficientViTSeg(backbone, head)
|
| 231 |
+
return model
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def efficientvit_seg_b3(dataset: str, **kwargs) -> EfficientViTSeg:
|
| 235 |
+
from efficientvit.models.efficientvit.backbone import \
|
| 236 |
+
efficientvit_backbone_b3
|
| 237 |
+
|
| 238 |
+
backbone = efficientvit_backbone_b3(**kwargs)
|
| 239 |
+
|
| 240 |
+
if dataset == "cityscapes":
|
| 241 |
+
head = SegHead(
|
| 242 |
+
fid_list=["stage4", "stage3", "stage2"],
|
| 243 |
+
in_channel_list=[512, 256, 128],
|
| 244 |
+
stride_list=[32, 16, 8],
|
| 245 |
+
head_stride=8,
|
| 246 |
+
head_width=128,
|
| 247 |
+
head_depth=3,
|
| 248 |
+
expand_ratio=4,
|
| 249 |
+
middle_op="mbconv",
|
| 250 |
+
final_expand=4,
|
| 251 |
+
n_classes=19,
|
| 252 |
+
**build_kwargs_from_config(kwargs, SegHead),
|
| 253 |
+
)
|
| 254 |
+
elif dataset == "ade20k":
|
| 255 |
+
head = SegHead(
|
| 256 |
+
fid_list=["stage4", "stage3", "stage2"],
|
| 257 |
+
in_channel_list=[512, 256, 128],
|
| 258 |
+
stride_list=[32, 16, 8],
|
| 259 |
+
head_stride=8,
|
| 260 |
+
head_width=128,
|
| 261 |
+
head_depth=3,
|
| 262 |
+
expand_ratio=4,
|
| 263 |
+
middle_op="mbconv",
|
| 264 |
+
final_expand=None,
|
| 265 |
+
n_classes=150,
|
| 266 |
+
**build_kwargs_from_config(kwargs, SegHead),
|
| 267 |
+
)
|
| 268 |
+
else:
|
| 269 |
+
raise NotImplementedError
|
| 270 |
+
model = EfficientViTSeg(backbone, head)
|
| 271 |
+
return model
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def efficientvit_seg_l1(dataset: str, **kwargs) -> EfficientViTSeg:
|
| 275 |
+
from efficientvit.models.efficientvit.backbone import \
|
| 276 |
+
efficientvit_backbone_l1
|
| 277 |
+
|
| 278 |
+
backbone = efficientvit_backbone_l1(**kwargs)
|
| 279 |
+
|
| 280 |
+
if dataset == "cityscapes":
|
| 281 |
+
head = SegHead(
|
| 282 |
+
fid_list=["stage4", "stage3", "stage2"],
|
| 283 |
+
in_channel_list=[512, 256, 128],
|
| 284 |
+
stride_list=[32, 16, 8],
|
| 285 |
+
head_stride=8,
|
| 286 |
+
head_width=256,
|
| 287 |
+
head_depth=3,
|
| 288 |
+
expand_ratio=1,
|
| 289 |
+
middle_op="fmbconv",
|
| 290 |
+
final_expand=None,
|
| 291 |
+
n_classes=19,
|
| 292 |
+
act_func="gelu",
|
| 293 |
+
**build_kwargs_from_config(kwargs, SegHead),
|
| 294 |
+
)
|
| 295 |
+
elif dataset == "ade20k":
|
| 296 |
+
head = SegHead(
|
| 297 |
+
fid_list=["stage4", "stage3", "stage2"],
|
| 298 |
+
in_channel_list=[512, 256, 128],
|
| 299 |
+
stride_list=[32, 16, 8],
|
| 300 |
+
head_stride=8,
|
| 301 |
+
head_width=128,
|
| 302 |
+
head_depth=3,
|
| 303 |
+
expand_ratio=4,
|
| 304 |
+
middle_op="fmbconv",
|
| 305 |
+
final_expand=8,
|
| 306 |
+
n_classes=150,
|
| 307 |
+
act_func="gelu",
|
| 308 |
+
**build_kwargs_from_config(kwargs, SegHead),
|
| 309 |
+
)
|
| 310 |
+
else:
|
| 311 |
+
raise NotImplementedError
|
| 312 |
+
model = EfficientViTSeg(backbone, head)
|
| 313 |
+
return model
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def efficientvit_seg_l2(dataset: str, **kwargs) -> EfficientViTSeg:
|
| 317 |
+
from efficientvit.models.efficientvit.backbone import \
|
| 318 |
+
efficientvit_backbone_l2
|
| 319 |
+
|
| 320 |
+
backbone = efficientvit_backbone_l2(**kwargs)
|
| 321 |
+
|
| 322 |
+
if dataset == "cityscapes":
|
| 323 |
+
head = SegHead(
|
| 324 |
+
fid_list=["stage4", "stage3", "stage2"],
|
| 325 |
+
in_channel_list=[512, 256, 128],
|
| 326 |
+
stride_list=[32, 16, 8],
|
| 327 |
+
head_stride=8,
|
| 328 |
+
head_width=256,
|
| 329 |
+
head_depth=5,
|
| 330 |
+
expand_ratio=1,
|
| 331 |
+
middle_op="fmbconv",
|
| 332 |
+
final_expand=None,
|
| 333 |
+
n_classes=19,
|
| 334 |
+
act_func="gelu",
|
| 335 |
+
**build_kwargs_from_config(kwargs, SegHead),
|
| 336 |
+
)
|
| 337 |
+
elif dataset == "ade20k":
|
| 338 |
+
head = SegHead(
|
| 339 |
+
fid_list=["stage4", "stage3", "stage2"],
|
| 340 |
+
in_channel_list=[512, 256, 128],
|
| 341 |
+
stride_list=[32, 16, 8],
|
| 342 |
+
head_stride=8,
|
| 343 |
+
head_width=128,
|
| 344 |
+
head_depth=3,
|
| 345 |
+
expand_ratio=4,
|
| 346 |
+
middle_op="fmbconv",
|
| 347 |
+
final_expand=8,
|
| 348 |
+
n_classes=150,
|
| 349 |
+
act_func="gelu",
|
| 350 |
+
**build_kwargs_from_config(kwargs, SegHead),
|
| 351 |
+
)
|
| 352 |
+
else:
|
| 353 |
+
raise NotImplementedError
|
| 354 |
+
model = EfficientViTSeg(backbone, head)
|
| 355 |
+
return model
|
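For the segmentation builders above, a small inference sketch (dummy input, no checkpoint loading; the final bilinear upsample is the caller's choice, since SegHead emits logits at its stride-8 working resolution):

import torch
import torch.nn.functional as F

seg = efficientvit_seg_b1("cityscapes").eval()   # 19 Cityscapes classes
x = torch.randn(1, 3, 512, 1024)                 # dummy RGB batch
with torch.no_grad():
    logits = seg(x)                              # roughly [1, 19, 64, 128] given head_stride=8
    logits = F.interpolate(logits, size=x.shape[-2:], mode="bilinear", align_corners=False)
    pred = logits.argmax(dim=1)                  # per-pixel class indices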
src/efficientvit/models/nn/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

from .act import *
from .drop import *
from .norm import *
from .ops import *
|
src/efficientvit/models/nn/act.py
ADDED
|
@@ -0,0 +1,30 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

from functools import partial

import torch.nn as nn

from src.efficientvit.models.utils import build_kwargs_from_config

__all__ = ["build_act"]


# register activation function here
REGISTERED_ACT_DICT: dict[str, type] = {
    "relu": nn.ReLU,
    "relu6": nn.ReLU6,
    "hswish": nn.Hardswish,
    "silu": nn.SiLU,
    "gelu": partial(nn.GELU, approximate="tanh"),
}


def build_act(name: str, **kwargs) -> nn.Module or None:
    if name in REGISTERED_ACT_DICT:
        act_cls = REGISTERED_ACT_DICT[name]
        args = build_kwargs_from_config(kwargs, act_cls)
        return act_cls(**args)
    else:
        return None
|
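Activations are looked up by name in REGISTERED_ACT_DICT, so any layer that takes an act_func string can use a newly registered class without further changes; a brief sketch:

import torch.nn as nn

act = build_act("hswish")        # -> nn.Hardswish()
missing = build_act("unknown")   # unregistered names return None, i.e. no activation

REGISTERED_ACT_DICT["mish"] = nn.Mish   # register a new activation under the name "mish"
mish = build_act("mish")                # now usable anywhere an act_func string is accepted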
src/efficientvit/models/nn/drop.py
ADDED
|
@@ -0,0 +1,98 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import numpy as np
import torch
import torch.nn as nn

from src.efficientvit.apps.trainer.run_config import Scheduler
from src.efficientvit.models.nn.ops import IdentityLayer, ResidualBlock
from src.efficientvit.models.utils import build_kwargs_from_config

__all__ = ["apply_drop_func"]


def apply_drop_func(network: nn.Module, drop_config: dict[str, any] or None) -> None:
    if drop_config is None:
        return

    drop_lookup_table = {
        "droppath": apply_droppath,
    }

    drop_func = drop_lookup_table[drop_config["name"]]
    drop_kwargs = build_kwargs_from_config(drop_config, drop_func)

    drop_func(network, **drop_kwargs)


def apply_droppath(
    network: nn.Module,
    drop_prob: float,
    linear_decay=True,
    scheduled=True,
    skip=0,
) -> None:
    all_valid_blocks = []
    for m in network.modules():
        for name, sub_module in m.named_children():
            if isinstance(sub_module, ResidualBlock) and isinstance(
                sub_module.shortcut, IdentityLayer
            ):
                all_valid_blocks.append((m, name, sub_module))
    all_valid_blocks = all_valid_blocks[skip:]
    for i, (m, name, sub_module) in enumerate(all_valid_blocks):
        prob = (
            drop_prob * (i + 1) / len(all_valid_blocks) if linear_decay else drop_prob
        )
        new_module = DropPathResidualBlock(
            sub_module.main,
            sub_module.shortcut,
            sub_module.post_act,
            sub_module.pre_norm,
            prob,
            scheduled,
        )
        m._modules[name] = new_module


class DropPathResidualBlock(ResidualBlock):
    def __init__(
        self,
        main: nn.Module,
        shortcut: nn.Module or None,
        post_act=None,
        pre_norm: nn.Module or None = None,
        ######################################
        drop_prob: float = 0,
        scheduled=True,
    ):
        super().__init__(main, shortcut, post_act, pre_norm)

        self.drop_prob = drop_prob
        self.scheduled = scheduled

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if (
            not self.training
            or self.drop_prob == 0
            or not isinstance(self.shortcut, IdentityLayer)
        ):
            return ResidualBlock.forward(self, x)
        else:
            drop_prob = self.drop_prob
            if self.scheduled:
                drop_prob *= np.clip(Scheduler.PROGRESS, 0, 1)
            keep_prob = 1 - drop_prob

            shape = (x.shape[0],) + (1,) * (x.ndim - 1)
            random_tensor = keep_prob + torch.rand(
                shape, dtype=x.dtype, device=x.device
            )
            random_tensor.floor_()  # binarize

            res = self.forward_main(x) / keep_prob * random_tensor + self.shortcut(x)
            if self.post_act:
                res = self.post_act(res)
            return res
|
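A hedged usage sketch: apply_drop_func rewrites every ResidualBlock whose shortcut is an IdentityLayer into a DropPathResidualBlock, with per-block drop probabilities that grow linearly with depth when linear_decay is set. The config keys mirror the parameters of apply_droppath; the backbone constructor is just a stand-in for any network built from such blocks.

from src.efficientvit.models.efficientvit.backbone import efficientvit_backbone_b1

net = efficientvit_backbone_b1()
apply_drop_func(net, {"name": "droppath", "drop_prob": 0.1, "linear_decay": True, "skip": 1})
# In training mode each wrapped block now drops its main branch stochastically
# (scaled by Scheduler.PROGRESS when scheduled=True); in eval mode it behaves
# exactly like the original residual block.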
src/efficientvit/models/nn/norm.py
ADDED
|
@@ -0,0 +1,157 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm

from src.efficientvit.models.utils import build_kwargs_from_config

__all__ = ["LayerNorm2d", "build_norm", "reset_bn", "set_norm_eps"]


class LayerNorm2d(nn.LayerNorm):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x - torch.mean(x, dim=1, keepdim=True)
        out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps)
        if self.elementwise_affine:
            out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
        return out


# register normalization function here
REGISTERED_NORM_DICT: dict[str, type] = {
    "bn2d": nn.BatchNorm2d,
    "ln": nn.LayerNorm,
    "ln2d": LayerNorm2d,
}


def build_norm(name="bn2d", num_features=None, **kwargs) -> nn.Module or None:
    if name in ["ln", "ln2d"]:
        kwargs["normalized_shape"] = num_features
    else:
        kwargs["num_features"] = num_features
    if name in REGISTERED_NORM_DICT:
        norm_cls = REGISTERED_NORM_DICT[name]
        args = build_kwargs_from_config(kwargs, norm_cls)
        return norm_cls(**args)
    else:
        return None


def reset_bn(
    model: nn.Module,
    data_loader: list,
    sync=True,
    progress_bar=False,
) -> None:
    import copy

    import torch.nn.functional as F
    from tqdm import tqdm

    from efficientvit.apps.utils import AverageMeter, is_master, sync_tensor
    from efficientvit.models.utils import get_device, list_join

    bn_mean = {}
    bn_var = {}

    tmp_model = copy.deepcopy(model)
    for name, m in tmp_model.named_modules():
        if isinstance(m, _BatchNorm):
            bn_mean[name] = AverageMeter(is_distributed=False)
            bn_var[name] = AverageMeter(is_distributed=False)

            def new_forward(bn, mean_est, var_est):
                def lambda_forward(x):
                    x = x.contiguous()
                    if sync:
                        batch_mean = (
                            x.mean(0, keepdim=True)
                            .mean(2, keepdim=True)
                            .mean(3, keepdim=True)
                        )  # 1, C, 1, 1
                        batch_mean = sync_tensor(batch_mean, reduce="cat")
                        batch_mean = torch.mean(batch_mean, dim=0, keepdim=True)

                        batch_var = (x - batch_mean) * (x - batch_mean)
                        batch_var = (
                            batch_var.mean(0, keepdim=True)
                            .mean(2, keepdim=True)
                            .mean(3, keepdim=True)
                        )
                        batch_var = sync_tensor(batch_var, reduce="cat")
                        batch_var = torch.mean(batch_var, dim=0, keepdim=True)
                    else:
                        batch_mean = (
                            x.mean(0, keepdim=True)
                            .mean(2, keepdim=True)
                            .mean(3, keepdim=True)
                        )  # 1, C, 1, 1
                        batch_var = (x - batch_mean) * (x - batch_mean)
                        batch_var = (
                            batch_var.mean(0, keepdim=True)
                            .mean(2, keepdim=True)
                            .mean(3, keepdim=True)
                        )

                    batch_mean = torch.squeeze(batch_mean)
                    batch_var = torch.squeeze(batch_var)

                    mean_est.update(batch_mean.data, x.size(0))
                    var_est.update(batch_var.data, x.size(0))

                    # bn forward using calculated mean & var
                    _feature_dim = batch_mean.shape[0]
                    return F.batch_norm(
                        x,
                        batch_mean,
                        batch_var,
                        bn.weight[:_feature_dim],
                        bn.bias[:_feature_dim],
                        False,
                        0.0,
                        bn.eps,
                    )

                return lambda_forward

            m.forward = new_forward(m, bn_mean[name], bn_var[name])

    # skip if there is no batch normalization layers in the network
    if len(bn_mean) == 0:
        return

    tmp_model.eval()
    with torch.no_grad():
        with tqdm(
            total=len(data_loader),
            desc="reset bn",
            disable=not progress_bar or not is_master(),
        ) as t:
            for images in data_loader:
                images = images.to(get_device(tmp_model))
                tmp_model(images)
                t.set_postfix(
                    {
                        "bs": images.size(0),
                        "res": list_join(images.shape[-2:], "x"),
                    }
                )
                t.update()

    for name, m in model.named_modules():
        if name in bn_mean and bn_mean[name].count > 0:
            feature_dim = bn_mean[name].avg.size(0)
            assert isinstance(m, _BatchNorm)
            m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg)
            m.running_var.data[:feature_dim].copy_(bn_var[name].avg)


def set_norm_eps(model: nn.Module, eps: float or None = None) -> None:
    for m in model.modules():
        if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)):
            if eps is not None:
                m.eps = eps
|
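Normalization goes through the same name-based registry. LayerNorm2d normalizes each spatial position across the channel dimension (a channels-first layer norm), which is what the "ln2d" key selects; a brief sketch:

import torch

norm = build_norm("ln2d", num_features=64)   # -> LayerNorm2d(normalized_shape=64)
x = torch.randn(2, 64, 32, 32)
y = norm(x)                                  # same shape; zero mean / unit variance over channels at each pixel

bn = build_norm("bn2d", num_features=64)     # -> nn.BatchNorm2d(64)
set_norm_eps(bn, eps=1e-3)                   # walks submodules and overrides eps in place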
src/efficientvit/models/nn/ops.py
ADDED
|
@@ -0,0 +1,585 @@
| 1 |
+
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
|
| 2 |
+
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
|
| 3 |
+
# International Conference on Computer Vision (ICCV), 2023
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch.cuda.amp import autocast
|
| 9 |
+
|
| 10 |
+
from src.efficientvit.models.nn.act import build_act
|
| 11 |
+
from src.efficientvit.models.nn.norm import build_norm
|
| 12 |
+
from src.efficientvit.models.utils import (get_same_padding, list_sum, resize,
|
| 13 |
+
val2list, val2tuple)
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"ConvLayer",
|
| 17 |
+
"UpSampleLayer",
|
| 18 |
+
"LinearLayer",
|
| 19 |
+
"IdentityLayer",
|
| 20 |
+
"DSConv",
|
| 21 |
+
"MBConv",
|
| 22 |
+
"FusedMBConv",
|
| 23 |
+
"ResBlock",
|
| 24 |
+
"LiteMLA",
|
| 25 |
+
"EfficientViTBlock",
|
| 26 |
+
"ResidualBlock",
|
| 27 |
+
"DAGBlock",
|
| 28 |
+
"OpSequential",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
#################################################################################
|
| 33 |
+
# Basic Layers #
|
| 34 |
+
#################################################################################
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class ConvLayer(nn.Module):
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
in_channels: int,
|
| 41 |
+
out_channels: int,
|
| 42 |
+
kernel_size=3,
|
| 43 |
+
stride=1,
|
| 44 |
+
dilation=1,
|
| 45 |
+
groups=1,
|
| 46 |
+
use_bias=False,
|
| 47 |
+
dropout=0,
|
| 48 |
+
norm="bn2d",
|
| 49 |
+
act_func="relu",
|
| 50 |
+
):
|
| 51 |
+
super(ConvLayer, self).__init__()
|
| 52 |
+
|
| 53 |
+
padding = get_same_padding(kernel_size)
|
| 54 |
+
padding *= dilation
|
| 55 |
+
|
| 56 |
+
self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None
|
| 57 |
+
self.conv = nn.Conv2d(
|
| 58 |
+
in_channels,
|
| 59 |
+
out_channels,
|
| 60 |
+
kernel_size=(kernel_size, kernel_size),
|
| 61 |
+
stride=(stride, stride),
|
| 62 |
+
padding=padding,
|
| 63 |
+
dilation=(dilation, dilation),
|
| 64 |
+
groups=groups,
|
| 65 |
+
bias=use_bias,
|
| 66 |
+
)
|
| 67 |
+
self.norm = build_norm(norm, num_features=out_channels)
|
| 68 |
+
self.act = build_act(act_func)
|
| 69 |
+
|
| 70 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 71 |
+
if self.dropout is not None:
|
| 72 |
+
x = self.dropout(x)
|
| 73 |
+
x = self.conv(x)
|
| 74 |
+
if self.norm:
|
| 75 |
+
x = self.norm(x)
|
| 76 |
+
if self.act:
|
| 77 |
+
x = self.act(x)
|
| 78 |
+
return x
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class UpSampleLayer(nn.Module):
|
| 82 |
+
def __init__(
|
| 83 |
+
self,
|
| 84 |
+
mode="bicubic",
|
| 85 |
+
size: int or tuple[int, int] or list[int] or None = None,
|
| 86 |
+
factor=2,
|
| 87 |
+
align_corners=False,
|
| 88 |
+
):
|
| 89 |
+
super(UpSampleLayer, self).__init__()
|
| 90 |
+
self.mode = mode
|
| 91 |
+
self.size = val2list(size, 2) if size is not None else None
|
| 92 |
+
self.factor = None if self.size is not None else factor
|
| 93 |
+
self.align_corners = align_corners
|
| 94 |
+
|
| 95 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 96 |
+
if (
|
| 97 |
+
self.size is not None and tuple(x.shape[-2:]) == self.size
|
| 98 |
+
) or self.factor == 1:
|
| 99 |
+
return x
|
| 100 |
+
return resize(x, self.size, self.factor, self.mode, self.align_corners)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class LinearLayer(nn.Module):
|
| 104 |
+
def __init__(
|
| 105 |
+
self,
|
| 106 |
+
in_features: int,
|
| 107 |
+
out_features: int,
|
| 108 |
+
use_bias=True,
|
| 109 |
+
dropout=0,
|
| 110 |
+
norm=None,
|
| 111 |
+
act_func=None,
|
| 112 |
+
):
|
| 113 |
+
super(LinearLayer, self).__init__()
|
| 114 |
+
|
| 115 |
+
self.dropout = nn.Dropout(dropout, inplace=False) if dropout > 0 else None
|
| 116 |
+
self.linear = nn.Linear(in_features, out_features, use_bias)
|
| 117 |
+
self.norm = build_norm(norm, num_features=out_features)
|
| 118 |
+
self.act = build_act(act_func)
|
| 119 |
+
|
| 120 |
+
def _try_squeeze(self, x: torch.Tensor) -> torch.Tensor:
|
| 121 |
+
if x.dim() > 2:
|
| 122 |
+
x = torch.flatten(x, start_dim=1)
|
| 123 |
+
return x
|
| 124 |
+
|
| 125 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 126 |
+
x = self._try_squeeze(x)
|
| 127 |
+
if self.dropout:
|
| 128 |
+
x = self.dropout(x)
|
| 129 |
+
x = self.linear(x)
|
| 130 |
+
if self.norm:
|
| 131 |
+
x = self.norm(x)
|
| 132 |
+
if self.act:
|
| 133 |
+
x = self.act(x)
|
| 134 |
+
return x
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class IdentityLayer(nn.Module):
|
| 138 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 139 |
+
return x
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
#################################################################################
|
| 143 |
+
# Basic Blocks #
|
| 144 |
+
#################################################################################
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class DSConv(nn.Module):
|
| 148 |
+
def __init__(
|
| 149 |
+
self,
|
| 150 |
+
in_channels: int,
|
| 151 |
+
out_channels: int,
|
| 152 |
+
kernel_size=3,
|
| 153 |
+
stride=1,
|
| 154 |
+
use_bias=False,
|
| 155 |
+
norm=("bn2d", "bn2d"),
|
| 156 |
+
act_func=("relu6", None),
|
| 157 |
+
):
|
| 158 |
+
super(DSConv, self).__init__()
|
| 159 |
+
|
| 160 |
+
use_bias = val2tuple(use_bias, 2)
|
| 161 |
+
norm = val2tuple(norm, 2)
|
| 162 |
+
act_func = val2tuple(act_func, 2)
|
| 163 |
+
|
| 164 |
+
self.depth_conv = ConvLayer(
|
| 165 |
+
in_channels,
|
| 166 |
+
in_channels,
|
| 167 |
+
kernel_size,
|
| 168 |
+
stride,
|
| 169 |
+
groups=in_channels,
|
| 170 |
+
norm=norm[0],
|
| 171 |
+
act_func=act_func[0],
|
| 172 |
+
use_bias=use_bias[0],
|
| 173 |
+
)
|
| 174 |
+
self.point_conv = ConvLayer(
|
| 175 |
+
in_channels,
|
| 176 |
+
out_channels,
|
| 177 |
+
1,
|
| 178 |
+
norm=norm[1],
|
| 179 |
+
act_func=act_func[1],
|
| 180 |
+
use_bias=use_bias[1],
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 184 |
+
x = self.depth_conv(x)
|
| 185 |
+
x = self.point_conv(x)
|
| 186 |
+
return x
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class MBConv(nn.Module):
|
| 190 |
+
def __init__(
|
| 191 |
+
self,
|
| 192 |
+
in_channels: int,
|
| 193 |
+
out_channels: int,
|
| 194 |
+
kernel_size=3,
|
| 195 |
+
stride=1,
|
| 196 |
+
mid_channels=None,
|
| 197 |
+
expand_ratio=6,
|
| 198 |
+
use_bias=False,
|
| 199 |
+
norm=("bn2d", "bn2d", "bn2d"),
|
| 200 |
+
act_func=("relu6", "relu6", None),
|
| 201 |
+
):
|
| 202 |
+
super(MBConv, self).__init__()
|
| 203 |
+
|
| 204 |
+
use_bias = val2tuple(use_bias, 3)
|
| 205 |
+
norm = val2tuple(norm, 3)
|
| 206 |
+
act_func = val2tuple(act_func, 3)
|
| 207 |
+
mid_channels = mid_channels or round(in_channels * expand_ratio)
|
| 208 |
+
|
| 209 |
+
self.inverted_conv = ConvLayer(
|
| 210 |
+
in_channels,
|
| 211 |
+
mid_channels,
|
| 212 |
+
1,
|
| 213 |
+
stride=1,
|
| 214 |
+
norm=norm[0],
|
| 215 |
+
act_func=act_func[0],
|
| 216 |
+
use_bias=use_bias[0],
|
| 217 |
+
)
|
| 218 |
+
self.depth_conv = ConvLayer(
|
| 219 |
+
mid_channels,
|
| 220 |
+
mid_channels,
|
| 221 |
+
kernel_size,
|
| 222 |
+
stride=stride,
|
| 223 |
+
groups=mid_channels,
|
| 224 |
+
norm=norm[1],
|
| 225 |
+
act_func=act_func[1],
|
| 226 |
+
use_bias=use_bias[1],
|
| 227 |
+
)
|
| 228 |
+
self.point_conv = ConvLayer(
|
| 229 |
+
mid_channels,
|
| 230 |
+
out_channels,
|
| 231 |
+
1,
|
| 232 |
+
norm=norm[2],
|
| 233 |
+
act_func=act_func[2],
|
| 234 |
+
use_bias=use_bias[2],
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 238 |
+
x = self.inverted_conv(x)
|
| 239 |
+
x = self.depth_conv(x)
|
| 240 |
+
x = self.point_conv(x)
|
| 241 |
+
return x
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class FusedMBConv(nn.Module):
|
| 245 |
+
def __init__(
|
| 246 |
+
self,
|
| 247 |
+
in_channels: int,
|
| 248 |
+
out_channels: int,
|
| 249 |
+
kernel_size=3,
|
| 250 |
+
stride=1,
|
| 251 |
+
mid_channels=None,
|
| 252 |
+
expand_ratio=6,
|
| 253 |
+
groups=1,
|
| 254 |
+
use_bias=False,
|
| 255 |
+
norm=("bn2d", "bn2d"),
|
| 256 |
+
act_func=("relu6", None),
|
| 257 |
+
):
|
| 258 |
+
super().__init__()
|
| 259 |
+
use_bias = val2tuple(use_bias, 2)
|
| 260 |
+
norm = val2tuple(norm, 2)
|
| 261 |
+
act_func = val2tuple(act_func, 2)
|
| 262 |
+
|
| 263 |
+
mid_channels = mid_channels or round(in_channels * expand_ratio)
|
| 264 |
+
|
| 265 |
+
self.spatial_conv = ConvLayer(
|
| 266 |
+
in_channels,
|
| 267 |
+
mid_channels,
|
| 268 |
+
kernel_size,
|
| 269 |
+
stride,
|
| 270 |
+
groups=groups,
|
| 271 |
+
use_bias=use_bias[0],
|
| 272 |
+
norm=norm[0],
|
| 273 |
+
act_func=act_func[0],
|
| 274 |
+
)
|
| 275 |
+
self.point_conv = ConvLayer(
|
| 276 |
+
mid_channels,
|
| 277 |
+
out_channels,
|
| 278 |
+
1,
|
| 279 |
+
use_bias=use_bias[1],
|
| 280 |
+
norm=norm[1],
|
| 281 |
+
act_func=act_func[1],
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 285 |
+
x = self.spatial_conv(x)
|
| 286 |
+
x = self.point_conv(x)
|
| 287 |
+
return x
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class ResBlock(nn.Module):
|
| 291 |
+
def __init__(
|
| 292 |
+
self,
|
| 293 |
+
in_channels: int,
|
| 294 |
+
out_channels: int,
|
| 295 |
+
kernel_size=3,
|
| 296 |
+
stride=1,
|
| 297 |
+
mid_channels=None,
|
| 298 |
+
expand_ratio=1,
|
| 299 |
+
use_bias=False,
|
| 300 |
+
norm=("bn2d", "bn2d"),
|
| 301 |
+
act_func=("relu6", None),
|
| 302 |
+
):
|
| 303 |
+
super().__init__()
|
| 304 |
+
use_bias = val2tuple(use_bias, 2)
|
| 305 |
+
norm = val2tuple(norm, 2)
|
| 306 |
+
act_func = val2tuple(act_func, 2)
|
| 307 |
+
|
| 308 |
+
mid_channels = mid_channels or round(in_channels * expand_ratio)
|
| 309 |
+
|
| 310 |
+
self.conv1 = ConvLayer(
|
| 311 |
+
in_channels,
|
| 312 |
+
mid_channels,
|
| 313 |
+
kernel_size,
|
| 314 |
+
stride,
|
| 315 |
+
use_bias=use_bias[0],
|
| 316 |
+
norm=norm[0],
|
| 317 |
+
act_func=act_func[0],
|
| 318 |
+
)
|
| 319 |
+
self.conv2 = ConvLayer(
|
| 320 |
+
mid_channels,
|
| 321 |
+
out_channels,
|
| 322 |
+
kernel_size,
|
| 323 |
+
1,
|
| 324 |
+
use_bias=use_bias[1],
|
| 325 |
+
norm=norm[1],
|
| 326 |
+
act_func=act_func[1],
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 330 |
+
x = self.conv1(x)
|
| 331 |
+
x = self.conv2(x)
|
| 332 |
+
return x
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
class LiteMLA(nn.Module):
|
| 336 |
+
r"""Lightweight multi-scale linear attention"""
|
| 337 |
+
|
| 338 |
+
def __init__(
|
| 339 |
+
self,
|
| 340 |
+
in_channels: int,
|
| 341 |
+
out_channels: int,
|
| 342 |
+
heads: int or None = None,
|
| 343 |
+
heads_ratio: float = 1.0,
|
| 344 |
+
dim=8,
|
| 345 |
+
use_bias=False,
|
| 346 |
+
norm=(None, "bn2d"),
|
| 347 |
+
act_func=(None, None),
|
| 348 |
+
kernel_func="relu",
|
| 349 |
+
scales: tuple[int, ...] = (5,),
|
| 350 |
+
eps=1.0e-15,
|
| 351 |
+
):
|
| 352 |
+
super(LiteMLA, self).__init__()
|
| 353 |
+
self.eps = eps
|
| 354 |
+
heads = heads or int(in_channels // dim * heads_ratio)
|
| 355 |
+
|
| 356 |
+
total_dim = heads * dim
|
| 357 |
+
|
| 358 |
+
use_bias = val2tuple(use_bias, 2)
|
| 359 |
+
norm = val2tuple(norm, 2)
|
| 360 |
+
act_func = val2tuple(act_func, 2)
|
| 361 |
+
|
| 362 |
+
self.dim = dim
|
| 363 |
+
self.qkv = ConvLayer(
|
| 364 |
+
in_channels,
|
| 365 |
+
3 * total_dim,
|
| 366 |
+
1,
|
| 367 |
+
use_bias=use_bias[0],
|
| 368 |
+
norm=norm[0],
|
| 369 |
+
act_func=act_func[0],
|
| 370 |
+
)
|
| 371 |
+
self.aggreg = nn.ModuleList(
|
| 372 |
+
[
|
| 373 |
+
nn.Sequential(
|
| 374 |
+
nn.Conv2d(
|
| 375 |
+
3 * total_dim,
|
| 376 |
+
3 * total_dim,
|
| 377 |
+
scale,
|
| 378 |
+
padding=get_same_padding(scale),
|
| 379 |
+
groups=3 * total_dim,
|
| 380 |
+
bias=use_bias[0],
|
| 381 |
+
),
|
| 382 |
+
nn.Conv2d(
|
| 383 |
+
3 * total_dim,
|
| 384 |
+
3 * total_dim,
|
| 385 |
+
1,
|
| 386 |
+
groups=3 * heads,
|
| 387 |
+
bias=use_bias[0],
|
| 388 |
+
),
|
| 389 |
+
)
|
| 390 |
+
for scale in scales
|
| 391 |
+
]
|
| 392 |
+
)
|
| 393 |
+
self.kernel_func = build_act(kernel_func, inplace=False)
|
| 394 |
+
|
| 395 |
+
self.proj = ConvLayer(
|
| 396 |
+
total_dim * (1 + len(scales)),
|
| 397 |
+
out_channels,
|
| 398 |
+
1,
|
| 399 |
+
use_bias=use_bias[1],
|
| 400 |
+
norm=norm[1],
|
| 401 |
+
act_func=act_func[1],
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
@autocast(enabled=False)
|
| 405 |
+
def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
|
| 406 |
+
B, _, H, W = list(qkv.size())
|
| 407 |
+
|
| 408 |
+
if qkv.dtype == torch.float16:
|
| 409 |
+
qkv = qkv.float()
|
| 410 |
+
|
| 411 |
+
qkv = torch.reshape(
|
| 412 |
+
qkv,
|
| 413 |
+
(
|
| 414 |
+
B,
|
| 415 |
+
-1,
|
| 416 |
+
3 * self.dim,
|
| 417 |
+
H * W,
|
| 418 |
+
),
|
| 419 |
+
)
|
| 420 |
+
qkv = torch.transpose(qkv, -1, -2)
|
| 421 |
+
q, k, v = (
|
| 422 |
+
qkv[..., 0 : self.dim],
|
| 423 |
+
qkv[..., self.dim : 2 * self.dim],
|
| 424 |
+
qkv[..., 2 * self.dim :],
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
# lightweight linear attention
|
| 428 |
+
q = self.kernel_func(q)
|
| 429 |
+
k = self.kernel_func(k)
|
| 430 |
+
|
| 431 |
+
# linear matmul
|
| 432 |
+
trans_k = k.transpose(-1, -2)
|
| 433 |
+
|
| 434 |
+
v = F.pad(v, (0, 1), mode="constant", value=1)
|
| 435 |
+
kv = torch.matmul(trans_k, v)
|
| 436 |
+
out = torch.matmul(q, kv)
|
| 437 |
+
out = out[..., :-1] / (out[..., -1:] + self.eps)
|
| 438 |
+
|
| 439 |
+
out = torch.transpose(out, -1, -2)
|
| 440 |
+
out = torch.reshape(out, (B, -1, H, W))
|
| 441 |
+
return out
|
| 442 |
+
|
| 443 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 444 |
+
# generate multi-scale q, k, v
|
| 445 |
+
qkv = self.qkv(x)
|
| 446 |
+
multi_scale_qkv = [qkv]
|
| 447 |
+
for op in self.aggreg:
|
| 448 |
+
multi_scale_qkv.append(op(qkv))
|
| 449 |
+
multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)
|
| 450 |
+
|
| 451 |
+
out = self.relu_linear_att(multi_scale_qkv)
|
| 452 |
+
out = self.proj(out)
|
| 453 |
+
|
| 454 |
+
return out
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
class EfficientViTBlock(nn.Module):
|
| 458 |
+
def __init__(
|
| 459 |
+
self,
|
| 460 |
+
in_channels: int,
|
| 461 |
+
heads_ratio: float = 1.0,
|
| 462 |
+
dim=32,
|
| 463 |
+
expand_ratio: float = 4,
|
| 464 |
+
scales=(5,),
|
| 465 |
+
norm="bn2d",
|
| 466 |
+
act_func="hswish",
|
| 467 |
+
):
|
| 468 |
+
super(EfficientViTBlock, self).__init__()
|
| 469 |
+
self.context_module = ResidualBlock(
|
| 470 |
+
LiteMLA(
|
| 471 |
+
in_channels=in_channels,
|
| 472 |
+
out_channels=in_channels,
|
| 473 |
+
heads_ratio=heads_ratio,
|
| 474 |
+
dim=dim,
|
| 475 |
+
norm=(None, norm),
|
| 476 |
+
scales=scales,
|
| 477 |
+
),
|
| 478 |
+
IdentityLayer(),
|
| 479 |
+
)
|
| 480 |
+
local_module = MBConv(
|
| 481 |
+
in_channels=in_channels,
|
| 482 |
+
out_channels=in_channels,
|
| 483 |
+
expand_ratio=expand_ratio,
|
| 484 |
+
use_bias=(True, True, False),
|
| 485 |
+
norm=(None, None, norm),
|
| 486 |
+
act_func=(act_func, act_func, None),
|
| 487 |
+
)
|
| 488 |
+
self.local_module = ResidualBlock(local_module, IdentityLayer())
|
| 489 |
+
|
| 490 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 491 |
+
x = self.context_module(x)
|
| 492 |
+
x = self.local_module(x)
|
| 493 |
+
return x
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
#################################################################################
|
| 497 |
+
# Functional Blocks #
|
| 498 |
+
#################################################################################
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
class ResidualBlock(nn.Module):
|
| 502 |
+
def __init__(
|
| 503 |
+
self,
|
| 504 |
+
main: nn.Module or None,
|
| 505 |
+
shortcut: nn.Module or None,
|
| 506 |
+
post_act=None,
|
| 507 |
+
pre_norm: nn.Module or None = None,
|
| 508 |
+
):
|
| 509 |
+
super(ResidualBlock, self).__init__()
|
| 510 |
+
|
| 511 |
+
self.pre_norm = pre_norm
|
| 512 |
+
self.main = main
|
| 513 |
+
self.shortcut = shortcut
|
| 514 |
+
self.post_act = build_act(post_act)
|
| 515 |
+
|
| 516 |
+
def forward_main(self, x: torch.Tensor) -> torch.Tensor:
|
| 517 |
+
if self.pre_norm is None:
|
| 518 |
+
return self.main(x)
|
| 519 |
+
else:
|
| 520 |
+
return self.main(self.pre_norm(x))
|
| 521 |
+
|
| 522 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 523 |
+
if self.main is None:
|
| 524 |
+
res = x
|
| 525 |
+
elif self.shortcut is None:
|
| 526 |
+
res = self.forward_main(x)
|
| 527 |
+
else:
|
| 528 |
+
res = self.forward_main(x) + self.shortcut(x)
|
| 529 |
+
if self.post_act:
|
| 530 |
+
res = self.post_act(res)
|
| 531 |
+
return res
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
class DAGBlock(nn.Module):
|
| 535 |
+
def __init__(
|
| 536 |
+
self,
|
| 537 |
+
inputs: dict[str, nn.Module],
|
| 538 |
+
merge: str,
|
| 539 |
+
post_input: nn.Module or None,
|
| 540 |
+
middle: nn.Module,
|
| 541 |
+
outputs: dict[str, nn.Module],
|
| 542 |
+
):
|
| 543 |
+
super(DAGBlock, self).__init__()
|
| 544 |
+
|
| 545 |
+
self.input_keys = list(inputs.keys())
|
| 546 |
+
self.input_ops = nn.ModuleList(list(inputs.values()))
|
| 547 |
+
self.merge = merge
|
| 548 |
+
self.post_input = post_input
|
| 549 |
+
|
| 550 |
+
self.middle = middle
|
| 551 |
+
|
| 552 |
+
self.output_keys = list(outputs.keys())
|
| 553 |
+
self.output_ops = nn.ModuleList(list(outputs.values()))
|
| 554 |
+
|
| 555 |
+
def forward(self, feature_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
| 556 |
+
feat = [
|
| 557 |
+
op(feature_dict[key]) for key, op in zip(self.input_keys, self.input_ops)
|
| 558 |
+
]
|
| 559 |
+
if self.merge == "add":
|
| 560 |
+
feat = list_sum(feat)
|
| 561 |
+
elif self.merge == "cat":
|
| 562 |
+
feat = torch.concat(feat, dim=1)
|
| 563 |
+
else:
|
| 564 |
+
raise NotImplementedError
|
| 565 |
+
if self.post_input is not None:
|
| 566 |
+
feat = self.post_input(feat)
|
| 567 |
+
feat = self.middle(feat)
|
| 568 |
+
for key, op in zip(self.output_keys, self.output_ops):
|
| 569 |
+
feature_dict[key] = op(feat)
|
| 570 |
+
return feature_dict
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
class OpSequential(nn.Module):
|
| 574 |
+
def __init__(self, op_list: list[nn.Module or None]):
|
| 575 |
+
super(OpSequential, self).__init__()
|
| 576 |
+
valid_op_list = []
|
| 577 |
+
for op in op_list:
|
| 578 |
+
if op is not None:
|
| 579 |
+
valid_op_list.append(op)
|
| 580 |
+
self.op_list = nn.ModuleList(valid_op_list)
|
| 581 |
+
|
| 582 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 583 |
+
for op in self.op_list:
|
| 584 |
+
x = op(x)
|
| 585 |
+
return x
|
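The centrepiece of this file is LiteMLA.relu_linear_att, which sidesteps the quadratic attention matrix by multiplying K-transpose with V first; padding V with a column of ones lets the same matmul accumulate the normalizer that softmax attention needs a separate reduction for. A standalone re-derivation on plain tensors (shapes are illustrative, independent of the module above):

import torch
import torch.nn.functional as F

B, N, d = 2, 4096, 8                        # batch, tokens (H*W), per-head dim
q = torch.relu(torch.randn(B, N, d))        # ReLU kernel feature maps, as in LiteMLA
k = torch.relu(torch.randn(B, N, d))
v = torch.randn(B, N, d)

v1 = F.pad(v, (0, 1), value=1.0)            # append a ones column -> (B, N, d + 1)
kv = k.transpose(-1, -2) @ v1               # (B, d, d + 1): cost O(N * d^2), not O(N^2)
out = q @ kv                                # (B, N, d + 1)
out = out[..., :-1] / (out[..., -1:] + 1e-15)   # last column holds sum_j q.k_j, the normalizer
# Each row of `out` is sum_j (q.k_j) v_j / sum_j (q.k_j), matching relu_linear_att above.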
src/efficientvit/models/utils/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

from .list import *
from .network import *
from .random import *
|
src/efficientvit/models/utils/list.py
ADDED
|
@@ -0,0 +1,57 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

__all__ = [
    "list_sum",
    "list_mean",
    "weighted_list_sum",
    "list_join",
    "val2list",
    "val2tuple",
    "squeeze_list",
]


def list_sum(x: list) -> any:
    return x[0] if len(x) == 1 else x[0] + list_sum(x[1:])


def list_mean(x: list) -> any:
    return list_sum(x) / len(x)


def weighted_list_sum(x: list, weights: list) -> any:
    assert len(x) == len(weights)
    return (
        x[0] * weights[0]
        if len(x) == 1
        else x[0] * weights[0] + weighted_list_sum(x[1:], weights[1:])
    )


def list_join(x: list, sep="\t", format_str="%s") -> str:
    return sep.join([format_str % val for val in x])


def val2list(x: list or tuple or any, repeat_time=1) -> list:
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x for _ in range(repeat_time)]


def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple:
    x = val2list(x)

    # repeat elements if necessary
    if len(x) > 0:
        x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]

    return tuple(x)


def squeeze_list(x: list or None) -> list or any:
    if x is not None and len(x) == 1:
        return x[0]
    else:
        return x
|
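A short sketch of the list helpers above (values are illustrative only; assumes the repo root is on PYTHONPATH):

from src.efficientvit.models.utils.list import list_mean, squeeze_list, val2list, val2tuple

print(val2list(3, repeat_time=2))                   # [3, 3]
print(val2tuple((7,), min_len=3))                   # (7, 7, 7) -- repeats the idx_repeat element
print(val2tuple([1, 2], min_len=4, idx_repeat=0))   # (1, 1, 1, 2)
print(list_mean([1.0, 2.0, 3.0]))                   # 2.0
print(squeeze_list([5]))                            # 5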
src/efficientvit/models/utils/network.py
ADDED
@@ -0,0 +1,77 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import os
from inspect import signature

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
    "is_parallel",
    "get_device",
    "get_same_padding",
    "resize",
    "build_kwargs_from_config",
    "load_state_dict_from_file",
]


def is_parallel(model: nn.Module) -> bool:
    return isinstance(
        model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
    )


def get_device(model: nn.Module) -> torch.device:
    return model.parameters().__next__().device


def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]:
    if isinstance(kernel_size, tuple):
        return tuple([get_same_padding(ks) for ks in kernel_size])
    else:
        assert kernel_size % 2 > 0, "kernel size should be odd number"
        return kernel_size // 2


def resize(
    x: torch.Tensor,
    size: any or None = None,
    scale_factor: list[float] or None = None,
    mode: str = "bicubic",
    align_corners: bool or None = False,
) -> torch.Tensor:
    if mode in {"bilinear", "bicubic"}:
        return F.interpolate(
            x,
            size=size,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
        )
    elif mode in {"nearest", "area"}:
        return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode)
    else:
        raise NotImplementedError(f"resize(mode={mode}) not implemented.")


def build_kwargs_from_config(config: dict, target_func: callable) -> dict[str, any]:
    valid_keys = list(signature(target_func).parameters)
    kwargs = {}
    for key in config:
        if key in valid_keys:
            kwargs[key] = config[key]
    return kwargs


def load_state_dict_from_file(
    file: str, only_state_dict=True
) -> dict[str, torch.Tensor]:
    file = os.path.realpath(os.path.expanduser(file))
    checkpoint = torch.load(file, map_location="cpu")
    if only_state_dict and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    return checkpoint
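A sketch of get_same_padding and build_kwargs_from_config from above (the example config keys are made up for illustration):

import torch.nn as nn
from src.efficientvit.models.utils.network import build_kwargs_from_config, get_same_padding

print(get_same_padding(3))        # 1
print(get_same_padding((5, 3)))   # (2, 1)

# Keep only the keys that nn.Linear actually accepts; "unused_key" is dropped.
cfg = {"in_features": 8, "out_features": 4, "unused_key": "ignored"}
layer = nn.Linear(**build_kwargs_from_config(cfg, nn.Linear))
print(layer)                      # Linear(in_features=8, out_features=4, bias=True)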
src/efficientvit/models/utils/random.py
ADDED
@@ -0,0 +1,73 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import numpy as np
import torch

__all__ = [
    "torch_randint",
    "torch_random",
    "torch_shuffle",
    "torch_uniform",
    "torch_random_choices",
]


def torch_randint(
    low: int, high: int, generator: torch.Generator or None = None
) -> int:
    """uniform: [low, high)"""
    if low == high:
        return low
    else:
        assert low < high
        return int(torch.randint(low=low, high=high, generator=generator, size=(1,)))


def torch_random(generator: torch.Generator or None = None) -> float:
    """uniform distribution on the interval [0, 1)"""
    return float(torch.rand(1, generator=generator))


def torch_shuffle(
    src_list: list[any], generator: torch.Generator or None = None
) -> list[any]:
    rand_indexes = torch.randperm(len(src_list), generator=generator).tolist()
    return [src_list[i] for i in rand_indexes]


def torch_uniform(
    low: float, high: float, generator: torch.Generator or None = None
) -> float:
    """uniform distribution on the interval [low, high)"""
    rand_val = torch_random(generator)
    return (high - low) * rand_val + low


def torch_random_choices(
    src_list: list[any],
    generator: torch.Generator or None = None,
    k=1,
    weight_list: list[float] or None = None,
) -> any or list:
    if weight_list is None:
        rand_idx = torch.randint(
            low=0, high=len(src_list), generator=generator, size=(k,)
        )
        out_list = [src_list[i] for i in rand_idx]
    else:
        assert len(weight_list) == len(src_list)
        accumulate_weight_list = np.cumsum(weight_list)

        out_list = []
        for _ in range(k):
            val = torch_uniform(0, accumulate_weight_list[-1], generator)
            active_id = 0
            for i, weight_val in enumerate(accumulate_weight_list):
                active_id = i
                if weight_val > val:
                    break
            out_list.append(src_list[active_id])

    return out_list[0] if k == 1 else out_list
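A sketch of the seeded helpers above; torch_random_choices walks the cumulative weights when weight_list is given (the weights below are illustrative only):

import torch
from src.efficientvit.models.utils.random import torch_random_choices, torch_shuffle

g = torch.Generator().manual_seed(0)
print(torch_shuffle([1, 2, 3, 4], generator=g))                  # reproducible permutation
print(torch_random_choices(["a", "b", "c"], generator=g, k=1))   # single item when k == 1
print(torch_random_choices(["a", "b", "c"], generator=g, k=4,
                           weight_list=[0.7, 0.2, 0.1]))         # list of 4 weighted draws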
src/efficientvit/sam_model_zoo.py
ADDED
@@ -0,0 +1,53 @@
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

from src.efficientvit.models.efficientvit import (EfficientViTSam,
                                                  efficientvit_sam_l0,
                                                  efficientvit_sam_l1,
                                                  efficientvit_sam_l2,
                                                  efficientvit_sam_xl0,
                                                  efficientvit_sam_xl1)
from src.efficientvit.models.nn.norm import set_norm_eps
from src.efficientvit.models.utils import load_state_dict_from_file

__all__ = ["create_sam_model"]


REGISTERED_SAM_MODEL: dict[str, str] = {
    "l0": "assets/checkpoints/sam/l0.pt",
    "l1": "assets/checkpoints/sam/l1.pt",
    "l2": "assets/checkpoints/sam/l2.pt",
    "xl0": "assets/checkpoints/sam/xl0.pt",
    "xl1": "assets/checkpoints/sam/xl1.pt",
}


def create_sam_model(
    name: str, pretrained=True, weight_url: str or None = None, **kwargs
) -> EfficientViTSam:
    model_dict = {
        "l0": efficientvit_sam_l0,
        "l1": efficientvit_sam_l1,
        "l2": efficientvit_sam_l2,
        "xl0": efficientvit_sam_xl0,
        "xl1": efficientvit_sam_xl1,
    }

    model_id = name.split("-")[0]
    if model_id not in model_dict:
        raise ValueError(
            f"Do not find {name} in the model zoo. List of models: {list(model_dict.keys())}"
        )
    else:
        model = model_dict[model_id](**kwargs)
        set_norm_eps(model, 1e-6)

    if pretrained:
        weight_url = weight_url or REGISTERED_SAM_MODEL.get(name, None)
        if weight_url is None:
            raise ValueError(f"Do not find the pretrained weight of {name}.")
        else:
            weight = load_state_dict_from_file(weight_url)
            model.load_state_dict(weight)
    return model
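A sketch of building a model from this zoo; the registered checkpoint paths are relative to the repo root and must exist locally, so pretrained=False is used here to skip weight loading:

from src.efficientvit.sam_model_zoo import create_sam_model

model = create_sam_model("l0", pretrained=False)   # architecture only, norm eps already set to 1e-6
model.eval()
# With downloaded weights at the registered default path:
# model = create_sam_model("xl1", pretrained=True)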
src/ip_adapter/attention_processor.py
ADDED
@@ -0,0 +1,424 @@
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import xformers
    import xformers.ops

    xformers_available = True
except Exception as e:
    xformers_available = False


class AttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapater.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        if xformers_available:
            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
        else:
            attention_probs = attn.get_attention_scores(query, key, attention_mask)
            hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = attn.head_to_batch_dim(ip_key)
        ip_value = attn.head_to_batch_dim(ip_value)

        if xformers_available:
            ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
        else:
            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

    def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
        # TODO attention_mask
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states


class AttnProcessor2_0(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapater for PyTorch 2.0.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        ip_hidden_states = F.scaled_dot_product_attention(
            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )
        with torch.no_grad():
            self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
            # print(self.attn_map.shape)

        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        ip_hidden_states = ip_hidden_states.to(query.dtype)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
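A sketch of how these processors are typically wired into a diffusers UNet, mirroring the upstream IP-Adapter setup (not part of this file): self-attention layers keep AttnProcessor2_0, cross-attention layers get IPAttnProcessor2_0, and the image tokens are concatenated after the text tokens so that num_tokens can split them off again inside __call__. Assumes a loaded UNet2DConditionModel named unet.

from src.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0

attn_procs = {}
for name in unet.attn_processors.keys():
    if name.endswith("attn1.processor"):
        # self-attention: no image branch needed
        attn_procs[name] = AttnProcessor2_0()
        continue
    # cross-attention: pick the hidden size of the enclosing block
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    attn_procs[name] = IPAttnProcessor2_0(
        hidden_size=hidden_size,
        cross_attention_dim=unet.config.cross_attention_dim,
        scale=1.0,      # weight of the image prompt
        num_tokens=4,   # 16 for the "plus" resampler variant
    )
unet.set_attn_processor(attn_procs)
# encoder_hidden_states fed to the UNet is then torch.cat([text_embeds, image_tokens], dim=1)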
src/ip_adapter/resampler.py
ADDED
@@ -0,0 +1,120 @@
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
import math

import torch
import torch.nn as nn


# FFN
def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
    x = x.reshape(bs, heads, length, -1)
    return x


class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)


class Resampler(nn.Module):
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
    ):
        super().__init__()

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)
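A shape sketch for the Resampler above: it compresses a variable-length image feature sequence into a fixed number of query tokens. The dimensions below are illustrative only (CLIP-like patch features projected into a wider context):

import torch
from src.ip_adapter.resampler import Resampler

resampler = Resampler(dim=1024, depth=4, dim_head=64, heads=12,
                      num_queries=16, embedding_dim=1280, output_dim=2048, ff_mult=4)
image_feats = torch.randn(2, 257, 1280)   # (batch, patch tokens, embedding_dim)
tokens = resampler(image_feats)
print(tokens.shape)                        # torch.Size([2, 16, 2048])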
src/ip_adapter/utils.py
ADDED
@@ -0,0 +1,5 @@
import torch.nn.functional as F


def is_torch2_available():
    return hasattr(F, "scaled_dot_product_attention")
src/pipelines/instantid_pipeline.py
ADDED
@@ -0,0 +1,720 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)

from diffusers.utils.import_utils import is_invisible_watermark_available

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import (
    FromSingleFileMixin,
    IPAdapterMixin,
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
)
from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    USE_PEFT_BACKEND,
    deprecate,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput


if is_invisible_watermark_available():
    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker

from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers import StableDiffusionXLControlNetPipeline
from PIL import Image
from torchvision.transforms.functional import to_tensor
from einops import rearrange
from torch import einsum
import math
from torchvision.utils import save_image
from diffusers.utils import load_image
import cv2

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class RegionControlNet_AttnProcessor:
    def __init__(self, attention_op=None, controller=None, place_in_unet=None):
        self.attention_op = attention_op
        self.controller = controller
        self.place_in_unet = place_in_unet

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
        **cross_attention_kwargs
    ) -> torch.Tensor:
        residual = hidden_states

        args = () if USE_PEFT_BACKEND else (scale,)

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states, *args)

        is_cross = True
        if encoder_hidden_states is None:
            is_cross = False
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        attention_probs = self.controller(attention_probs, is_cross, self.place_in_unet)
        hidden_states = torch.bmm(attention_probs, value)

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


def revise_regionally_controlnet_forward(unet, controller):
    def change_forward(unet, count, place_in_unet):
        for name, layer in unet.named_children():
            if layer.__class__.__name__ == 'Attention':
                layer.set_processor(RegionControlNet_AttnProcessor(controller=controller, place_in_unet=place_in_unet))
                if 'attn2' in name:
                    count += 1
            else:
                count = change_forward(layer, count, place_in_unet)
        return count

    # use this to ensure the order
    cross_attention_idx = change_forward(unet.down_blocks, 0, "down")
    cross_attention_idx = change_forward(unet.mid_block, cross_attention_idx, "up")
    cross_attention_idx = change_forward(unet.up_blocks, cross_attention_idx, "mid")
    print(f'Number of attention layer registered {cross_attention_idx}')
    controller.num_att_layers = cross_attention_idx*2


class InstantidMultiConceptPipeline(StableDiffusionXLControlNetPipeline):
    # leave controlnet out on purpose because it iterates with unet
    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
    _optional_components = [
        "tokenizer",
        "tokenizer_2",
        "text_encoder",
        "text_encoder_2",
        "feature_extractor",
        "image_encoder",
    ]
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet: UNet2DConditionModel,
        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
        scheduler: KarrasDiffusionSchedulers,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
        feature_extractor: CLIPImageProcessor = None,
        image_encoder: CLIPVisionModelWithProjection = None,
    ):
        if isinstance(controlnet, (list, tuple)):
            controlnet = MultiControlNetModel(controlnet)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
        self.control_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
        )
        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()

        if add_watermarker:
            self.watermark = StableDiffusionXLWatermarker()
        else:
            self.watermark = None

        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        image: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        guess_mode: bool = False,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
        original_size: Tuple[int, int] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Tuple[int, int] = None,
        negative_original_size: Optional[Tuple[int, int]] = None,
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        negative_target_size: Optional[Tuple[int, int]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        controller=None,
        concept_models=None,
        indices_to_alter=None,
        face_app=None,
        stage=None,
        region_masks=None,
        **kwargs,
    ):
        # revise_regionally_controlnet_forward(self.unet, controller)
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
            control_guidance_start, control_guidance_end = (
                mult * [control_guidance_start],
                mult * [control_guidance_end],
            )

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            image,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
            controlnet_conditioning_scale,
            control_guidance_start,
            control_guidance_end,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs

        # 2. Define call parameters
        batch_size = 2

        device = self._execution_device

        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

        global_pool_conditions = (
            controlnet.config.global_pool_conditions
            if isinstance(controlnet, ControlNetModel)
            else controlnet.nets[0].config.global_pool_conditions
        )
        guess_mode = guess_mode or global_pool_conditions

        # 3.1 Encode input prompt
        text_encoder_lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )

        global_prompt = prompt[0]
        global_negative_prompt = negative_prompt
        region_prompts = [pt[0] for pt in prompt[1]]
        region_negative_prompts = [pt[1] for pt in prompt[1]]
        ref_images = [pt[2] for pt in prompt[1]]

        concat_prompts = global_prompt + region_prompts
        concat_negative_prompts = global_negative_prompt + region_negative_prompts

        (
            concat_prompt_embeds,
            concat_negative_prompt_embeds,
            concat_pooled_prompt_embeds,
            concat_negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            concat_prompts,
            prompt_2,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            concat_negative_prompts,
            negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )

        prompt_embeds = concat_prompt_embeds[:2]
        negative_prompt_embeds = concat_negative_prompt_embeds[:2]
        pooled_prompt_embeds = concat_pooled_prompt_embeds[:2]
        negative_pooled_prompt_embeds = concat_negative_pooled_prompt_embeds[:2]

        region_prompt_embeds_list = []
        region_add_text_embeds_list = []
        for region_prompt_embeds, region_negative_prompt_embeds, region_pooled_prompt_embeds, region_negative_pooled_prompt_embeds in zip(concat_prompt_embeds[2:], concat_negative_prompt_embeds[2:], concat_pooled_prompt_embeds[2:], concat_negative_pooled_prompt_embeds[2:]):
            region_prompt_embeds_list.append(
                torch.concat([region_negative_prompt_embeds.unsqueeze(0), region_prompt_embeds.unsqueeze(0)], dim=0).to(concept_models._execution_device))
            region_add_text_embeds_list.append(
                torch.concat([region_negative_pooled_prompt_embeds.unsqueeze(0), region_pooled_prompt_embeds.unsqueeze(0)], dim=0).to(concept_models._execution_device))

        if stage == 2:
            mask_list = [mask.float().to(dtype=prompt_embeds.dtype, device=device) if mask is not None else None for mask in region_masks]
            image_embedding_list = get_face_embedding(face_app, ref_images)
            image_prompt_image_emb_list = []
            for image_embeds in image_embedding_list:
                prompt_image_emb = concept_models._encode_prompt_image_emb(image_embeds,
                                                                           concept_models._execution_device,
                                                                           num_images_per_prompt,
                                                                           concept_models.unet.dtype,
                                                                           True)
                image_prompt_image_emb_list.append(prompt_image_emb)

        # 4. Prepare image
        if isinstance(controlnet, ControlNetModel) and image is not None:
            image = self.prepare_image(
                image=image,
                width=width,
                height=height,
                batch_size=1 * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
            height, width = image.shape[-2:]
        elif isinstance(controlnet, MultiControlNetModel) and image is not None:
            images = []

            for image_ in image:
                image_ = self.prepare_image(
                    image=image_,
                    width=width,
                    height=height,
                    batch_size=batch_size * num_images_per_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    device=device,
                    dtype=controlnet.dtype,
                    do_classifier_free_guidance=self.do_classifier_free_guidance,
                    guess_mode=guess_mode,
                )

                images.append(image_)

            image = images
            height, width = image[0].shape[-2:]
        else:
            height = height or self.unet.config.sample_size * self.vae_scale_factor
            width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        self._num_timesteps = len(timesteps)

        # 6. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size // 2 * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6.1 repeat latent
        latents = torch.cat([latents, latents.clone()])

        # 6.5 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
| 453 |
+
timestep_cond = self.get_guidance_scale_embedding(
|
| 454 |
+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
| 455 |
+
).to(device=device, dtype=latents.dtype)
|
| 456 |
+
|
| 457 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 458 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 459 |
+
|
| 460 |
+
# 7.1 Create tensor stating which controlnets to keep
|
| 461 |
+
controlnet_keep = []
|
| 462 |
+
for i in range(len(timesteps)):
|
| 463 |
+
keeps = [
|
| 464 |
+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
| 465 |
+
for s, e in zip(control_guidance_start, control_guidance_end)
|
| 466 |
+
]
|
| 467 |
+
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
| 468 |
+
|
| 469 |
+
# 7.2 Prepare added time ids & embeddings
|
| 470 |
+
if isinstance(image, list):
|
| 471 |
+
original_size = original_size or image[0].shape[-2:]
|
| 472 |
+
else:
|
| 473 |
+
original_size = original_size or (height, width)
|
| 474 |
+
target_size = target_size or (height, width)
|
| 475 |
+
|
| 476 |
+
add_text_embeds = pooled_prompt_embeds
|
| 477 |
+
if self.text_encoder_2 is None:
|
| 478 |
+
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
| 479 |
+
else:
|
| 480 |
+
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
|
| 481 |
+
|
| 482 |
+
add_time_ids = self._get_add_time_ids(
|
| 483 |
+
original_size,
|
| 484 |
+
crops_coords_top_left,
|
| 485 |
+
target_size,
|
| 486 |
+
dtype=prompt_embeds.dtype,
|
| 487 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
add_time_ids_list = []
|
| 491 |
+
region_add_time_ids = concept_models._get_add_time_ids(original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim)
|
| 492 |
+
for _ in range(len(prompt[1])):
|
| 493 |
+
add_time_ids_list.append(torch.concat([region_add_time_ids, region_add_time_ids], dim=0).to(concept_models._execution_device))
|
| 494 |
+
|
| 495 |
+
if negative_original_size is not None and negative_target_size is not None:
|
| 496 |
+
negative_add_time_ids = self._get_add_time_ids(
|
| 497 |
+
negative_original_size,
|
| 498 |
+
negative_crops_coords_top_left,
|
| 499 |
+
negative_target_size,
|
| 500 |
+
dtype=prompt_embeds.dtype,
|
| 501 |
+
text_encoder_projection_dim=text_encoder_projection_dim,
|
| 502 |
+
)
|
| 503 |
+
else:
|
| 504 |
+
negative_add_time_ids = add_time_ids
|
| 505 |
+
|
| 506 |
+
if self.do_classifier_free_guidance:
|
| 507 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 508 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
| 509 |
+
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
| 510 |
+
|
| 511 |
+
prompt_embeds = prompt_embeds.to(device)
|
| 512 |
+
add_text_embeds = add_text_embeds.to(device)
|
| 513 |
+
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
| 514 |
+
|
| 515 |
+
# 8. Denoising loop
|
| 516 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 517 |
+
is_unet_compiled = is_compiled_module(self.unet)
|
| 518 |
+
is_controlnet_compiled = is_compiled_module(self.controlnet)
|
| 519 |
+
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
|
| 520 |
+
# hyper-parameters
|
| 521 |
+
scale_range = np.linspace(1, 0.5, len(self.scheduler.timesteps))
|
| 522 |
+
|
| 523 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 524 |
+
for i, t in enumerate(timesteps):
|
| 525 |
+
# Relevant thread:
|
| 526 |
+
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
|
| 527 |
+
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
|
| 528 |
+
torch._inductor.cudagraph_mark_step_begin()
|
| 529 |
+
# expand the latents if we are doing classifier free guidance
|
| 530 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
| 531 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 532 |
+
|
| 533 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
| 534 |
+
|
| 535 |
+
# controlnet(s) inference
|
| 536 |
+
if guess_mode and self.do_classifier_free_guidance:
|
| 537 |
+
# Infer ControlNet only for the conditional batch.
|
| 538 |
+
control_model_input = latents
|
| 539 |
+
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
| 540 |
+
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
| 541 |
+
controlnet_added_cond_kwargs = {
|
| 542 |
+
"text_embeds": add_text_embeds.chunk(2)[1],
|
| 543 |
+
"time_ids": add_time_ids.chunk(2)[1],
|
| 544 |
+
}
|
| 545 |
+
else:
|
| 546 |
+
control_model_input = latent_model_input
|
| 547 |
+
controlnet_prompt_embeds = prompt_embeds
|
| 548 |
+
controlnet_added_cond_kwargs = added_cond_kwargs
|
| 549 |
+
|
| 550 |
+
if isinstance(controlnet_keep[i], list):
|
| 551 |
+
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
| 552 |
+
else:
|
| 553 |
+
controlnet_cond_scale = controlnet_conditioning_scale
|
| 554 |
+
if isinstance(controlnet_cond_scale, list):
|
| 555 |
+
controlnet_cond_scale = controlnet_cond_scale[0]
|
| 556 |
+
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
# predict the noise residual
|
| 560 |
+
noise_pred = self.unet(
|
| 561 |
+
latent_model_input,
|
| 562 |
+
t,
|
| 563 |
+
encoder_hidden_states=prompt_embeds,
|
| 564 |
+
timestep_cond=timestep_cond,
|
| 565 |
+
cross_attention_kwargs=self.cross_attention_kwargs,
|
| 566 |
+
added_cond_kwargs=added_cond_kwargs,
|
| 567 |
+
return_dict=False,
|
| 568 |
+
)[0]
|
| 569 |
+
|
| 570 |
+
if i > 15 and stage == 2:
|
| 571 |
+
region_mask = self.get_region_mask(mask_list, noise_pred.shape[2], noise_pred.shape[3])
|
| 572 |
+
edit_noise = torch.concat([noise_pred[1:2], noise_pred[3:4]], dim=0)
|
| 573 |
+
new_noise_pred = torch.zeros_like(edit_noise)
|
| 574 |
+
new_noise_pred[:, :, region_mask == 0] = edit_noise[:, :, region_mask == 0]
|
| 575 |
+
replace_ratio = 1.0
|
| 576 |
+
new_noise_pred[:, :, region_mask != 0] = (1 - replace_ratio) * edit_noise[:, :, region_mask != 0]
|
| 577 |
+
|
| 578 |
+
for region_prompt_embeds, region_add_text_embeds, region_add_time_ids, concept_mask, region_prompt, region_prompt_image_emb in zip(region_prompt_embeds_list, region_add_text_embeds_list, add_time_ids_list, mask_list, region_prompts, image_prompt_image_emb_list):
|
| 579 |
+
if concept_mask is not None:
|
| 580 |
+
concept_mask = F.interpolate(concept_mask.unsqueeze(0).unsqueeze(0),
|
| 581 |
+
size=(noise_pred.shape[2], noise_pred.shape[3]),
|
| 582 |
+
mode='nearest').squeeze().to(dtype=noise_pred.dtype, device=concept_models._execution_device)
|
| 583 |
+
|
| 584 |
+
region_latent_model_input = latent_model_input[3:4].clone().to(concept_models._execution_device)
|
| 585 |
+
|
| 586 |
+
region_latent_model_input = torch.cat([region_latent_model_input] * 2)
|
| 587 |
+
region_added_cond_kwargs = {"text_embeds": region_add_text_embeds,
|
| 588 |
+
"time_ids": region_add_time_ids}
|
| 589 |
+
|
| 590 |
+
if image is not None:
|
| 591 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
| 592 |
+
region_latent_model_input,
|
| 593 |
+
t,
|
| 594 |
+
encoder_hidden_states=region_prompt_image_emb,
|
| 595 |
+
controlnet_cond=image,
|
| 596 |
+
conditioning_scale=cond_scale,
|
| 597 |
+
guess_mode=guess_mode,
|
| 598 |
+
added_cond_kwargs=region_added_cond_kwargs,
|
| 599 |
+
return_dict=False,
|
| 600 |
+
)
|
| 601 |
+
|
| 602 |
+
if guess_mode and self.do_classifier_free_guidance:
|
| 603 |
+
# Infered ControlNet only for the conditional batch.
|
| 604 |
+
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
| 605 |
+
# add 0 to the unconditional batch to keep it unchanged.
|
| 606 |
+
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in
|
| 607 |
+
down_block_res_samples]
|
| 608 |
+
mid_block_res_sample = torch.cat(
|
| 609 |
+
[torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
|
| 610 |
+
|
| 611 |
+
else:
|
| 612 |
+
down_block_res_samples = None
|
| 613 |
+
mid_block_res_sample = None
|
| 614 |
+
|
| 615 |
+
region_encoder_hidden_states = torch.cat([region_prompt_embeds, region_prompt_image_emb], dim=1)
|
| 616 |
+
|
| 617 |
+
region_noise_pred = concept_models.unet(
|
| 618 |
+
region_latent_model_input,
|
| 619 |
+
t,
|
| 620 |
+
encoder_hidden_states=region_encoder_hidden_states,
|
| 621 |
+
cross_attention_kwargs=None,
|
| 622 |
+
down_block_additional_residuals=down_block_res_samples,
|
| 623 |
+
mid_block_additional_residual=mid_block_res_sample,
|
| 624 |
+
added_cond_kwargs=region_added_cond_kwargs,
|
| 625 |
+
return_dict=False,
|
| 626 |
+
)[0]
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
new_noise_pred = new_noise_pred.to(concept_models._execution_device)
|
| 630 |
+
new_noise_pred[:, :, concept_mask==1] += replace_ratio * (region_noise_pred[:, :, concept_mask==1] / (concept_mask.reshape(1, 1, *concept_mask.shape)[:, :, concept_mask==1].to(region_noise_pred.device)))
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
new_noise_pred = new_noise_pred.to(noise_pred.device)
|
| 634 |
+
noise_pred[1, :, :, :] = new_noise_pred[0]
|
| 635 |
+
noise_pred[3, :, :, :] = new_noise_pred[1]
|
| 636 |
+
|
| 637 |
+
if self.do_classifier_free_guidance:
|
| 638 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 639 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 640 |
+
|
| 641 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 642 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 643 |
+
|
| 644 |
+
if callback_on_step_end is not None:
|
| 645 |
+
callback_kwargs = {}
|
| 646 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 647 |
+
callback_kwargs[k] = locals()[k]
|
| 648 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 649 |
+
|
| 650 |
+
latents = callback_outputs.pop("latents", latents)
|
| 651 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 652 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 653 |
+
|
| 654 |
+
# call the callback, if provided
|
| 655 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 656 |
+
progress_bar.update()
|
| 657 |
+
if callback is not None and i % callback_steps == 0:
|
| 658 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
| 659 |
+
callback(step_idx, t, latents)
|
| 660 |
+
|
| 661 |
+
# manually for max memory savings
|
| 662 |
+
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
| 663 |
+
self.upcast_vae()
|
| 664 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
| 665 |
+
|
| 666 |
+
if not output_type == "latent":
|
| 667 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
| 668 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
| 669 |
+
|
| 670 |
+
if needs_upcasting:
|
| 671 |
+
self.upcast_vae()
|
| 672 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
| 673 |
+
|
| 674 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
| 675 |
+
|
| 676 |
+
# cast back to fp16 if needed
|
| 677 |
+
if needs_upcasting:
|
| 678 |
+
self.vae.to(dtype=torch.float16)
|
| 679 |
+
else:
|
| 680 |
+
image = latents
|
| 681 |
+
|
| 682 |
+
if not output_type == "latent":
|
| 683 |
+
# apply watermark if available
|
| 684 |
+
if self.watermark is not None:
|
| 685 |
+
image = self.watermark.apply_watermark(image)
|
| 686 |
+
|
| 687 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 688 |
+
|
| 689 |
+
# Offload all models
|
| 690 |
+
self.maybe_free_model_hooks()
|
| 691 |
+
|
| 692 |
+
if not return_dict:
|
| 693 |
+
return (image,)
|
| 694 |
+
|
| 695 |
+
return StableDiffusionXLPipelineOutput(images=image)
|
| 696 |
+
|
| 697 |
+
def check_image(self, image, prompt, prompt_embeds):
|
| 698 |
+
pass
|
| 699 |
+
|
| 700 |
+
def get_region_mask(self, mask_list, feat_height, feat_width):
|
| 701 |
+
exclusive_mask = torch.zeros((feat_height, feat_width))
|
| 702 |
+
for mask in mask_list:
|
| 703 |
+
if mask is not None:
|
| 704 |
+
mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(feat_height, feat_width),
|
| 705 |
+
mode='nearest').squeeze().to(dtype=exclusive_mask.dtype, device=exclusive_mask.device)
|
| 706 |
+
exclusive_mask = ((mask == 1) | (exclusive_mask == 1)).to(dtype=mask.dtype)
|
| 707 |
+
return exclusive_mask
|
| 708 |
+
|
| 709 |
+
def get_face_embedding(face_app, ref_images):
|
| 710 |
+
emb_list = []
|
| 711 |
+
for img_path in ref_images:
|
| 712 |
+
face_image = load_image(img_path)
|
| 713 |
+
|
| 714 |
+
# prepare face emb
|
| 715 |
+
face_info = face_app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
|
| 716 |
+
face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * x['bbox'][3] - x['bbox'][1])[0] # only use the maximum face
|
| 717 |
+
face_emb = face_info['embedding']
|
| 718 |
+
emb_list.append(face_emb)
|
| 719 |
+
# face_kps = draw_kps(face_image, face_info['kps'])
|
| 720 |
+
return emb_list
|
src/pipelines/instantid_single_pieline.py
ADDED
|
@@ -0,0 +1,772 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The InstantX Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import cv2
|
| 19 |
+
import math
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import PIL.Image
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
|
| 26 |
+
from diffusers.image_processor import PipelineImageInput
|
| 27 |
+
|
| 28 |
+
from diffusers.models import ControlNetModel
|
| 29 |
+
|
| 30 |
+
from diffusers.utils import (
|
| 31 |
+
deprecate,
|
| 32 |
+
logging,
|
| 33 |
+
replace_example_docstring,
|
| 34 |
+
)
|
| 35 |
+
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
|
| 36 |
+
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
| 37 |
+
|
| 38 |
+
from diffusers import StableDiffusionXLControlNetPipeline
|
| 39 |
+
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
| 40 |
+
from diffusers.utils.import_utils import is_xformers_available
|
| 41 |
+
|
| 42 |
+
from src.ip_adapter.resampler import Resampler
|
| 43 |
+
from src.ip_adapter.utils import is_torch2_available
|
| 44 |
+
|
| 45 |
+
if is_torch2_available():
|
| 46 |
+
from src.ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
|
| 47 |
+
else:
|
| 48 |
+
from src.ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor
|
| 49 |
+
|
| 50 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 51 |
+
|
| 52 |
+
EXAMPLE_DOC_STRING = """
|
| 53 |
+
Examples:
|
| 54 |
+
```py
|
| 55 |
+
>>> # !pip install opencv-python transformers accelerate insightface
|
| 56 |
+
>>> import diffusers
|
| 57 |
+
>>> from diffusers.utils import load_image
|
| 58 |
+
>>> from diffusers.models import ControlNetModel
|
| 59 |
+
|
| 60 |
+
>>> import cv2
|
| 61 |
+
>>> import torch
|
| 62 |
+
>>> import numpy as np
|
| 63 |
+
>>> from PIL import Image
|
| 64 |
+
|
| 65 |
+
>>> from insightface.app import FaceAnalysis
|
| 66 |
+
>>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
|
| 67 |
+
|
| 68 |
+
>>> # download 'antelopev2' under ./models
|
| 69 |
+
>>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
| 70 |
+
>>> app.prepare(ctx_id=0, det_size=(640, 640))
|
| 71 |
+
|
| 72 |
+
>>> # download models under ./checkpoints
|
| 73 |
+
>>> face_adapter = f'./checkpoints/ip-adapter.bin'
|
| 74 |
+
>>> controlnet_path = f'./checkpoints/ControlNetModel'
|
| 75 |
+
|
| 76 |
+
>>> # load IdentityNet
|
| 77 |
+
>>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
|
| 78 |
+
|
| 79 |
+
>>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
|
| 80 |
+
... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
|
| 81 |
+
... )
|
| 82 |
+
>>> pipe.cuda()
|
| 83 |
+
|
| 84 |
+
>>> # load adapter
|
| 85 |
+
>>> pipe.load_ip_adapter_instantid(face_adapter)
|
| 86 |
+
|
| 87 |
+
>>> prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
|
| 88 |
+
>>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured"
|
| 89 |
+
|
| 90 |
+
>>> # load an image
|
| 91 |
+
>>> image = load_image("your-example.jpg")
|
| 92 |
+
|
| 93 |
+
>>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1]
|
| 94 |
+
>>> face_emb = face_info['embedding']
|
| 95 |
+
>>> face_kps = draw_kps(face_image, face_info['kps'])
|
| 96 |
+
|
| 97 |
+
>>> pipe.set_ip_adapter_scale(0.8)
|
| 98 |
+
|
| 99 |
+
>>> # generate image
|
| 100 |
+
>>> image = pipe(
|
| 101 |
+
... prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8
|
| 102 |
+
... ).images[0]
|
| 103 |
+
```
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
|
| 108 |
+
stickwidth = 4
|
| 109 |
+
limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
|
| 110 |
+
kps = np.array(kps)
|
| 111 |
+
|
| 112 |
+
w, h = image_pil.size
|
| 113 |
+
out_img = np.zeros([h, w, 3])
|
| 114 |
+
|
| 115 |
+
for i in range(len(limbSeq)):
|
| 116 |
+
index = limbSeq[i]
|
| 117 |
+
color = color_list[index[0]]
|
| 118 |
+
|
| 119 |
+
x = kps[index][:, 0]
|
| 120 |
+
y = kps[index][:, 1]
|
| 121 |
+
length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
|
| 122 |
+
angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
|
| 123 |
+
polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0,
|
| 124 |
+
360, 1)
|
| 125 |
+
out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
|
| 126 |
+
out_img = (out_img * 0.6).astype(np.uint8)
|
| 127 |
+
|
| 128 |
+
for idx_kp, kp in enumerate(kps):
|
| 129 |
+
color = color_list[idx_kp]
|
| 130 |
+
x, y = kp
|
| 131 |
+
out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
|
| 132 |
+
|
| 133 |
+
out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
|
| 134 |
+
return out_img_pil
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class InstantidSingleConceptPipeline(StableDiffusionXLControlNetPipeline):
|
| 138 |
+
|
| 139 |
+
def cuda(self, dtype=torch.float16, use_xformers=False):
|
| 140 |
+
self.to('cuda', dtype)
|
| 141 |
+
|
| 142 |
+
if hasattr(self, 'image_proj_model'):
|
| 143 |
+
self.image_proj_model.to(self.unet.device).to(self.unet.dtype)
|
| 144 |
+
|
| 145 |
+
if use_xformers:
|
| 146 |
+
if is_xformers_available():
|
| 147 |
+
import xformers
|
| 148 |
+
from packaging import version
|
| 149 |
+
|
| 150 |
+
xformers_version = version.parse(xformers.__version__)
|
| 151 |
+
if xformers_version == version.parse("0.0.16"):
|
| 152 |
+
logger.warn(
|
| 153 |
+
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
| 154 |
+
)
|
| 155 |
+
self.enable_xformers_memory_efficient_attention()
|
| 156 |
+
else:
|
| 157 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
| 158 |
+
|
| 159 |
+
def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5):
|
| 160 |
+
self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens)
|
| 161 |
+
self.set_ip_adapter(model_ckpt, num_tokens, scale)
|
| 162 |
+
|
| 163 |
+
def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16):
|
| 164 |
+
|
| 165 |
+
image_proj_model = Resampler(
|
| 166 |
+
dim=1280,
|
| 167 |
+
depth=4,
|
| 168 |
+
dim_head=64,
|
| 169 |
+
heads=20,
|
| 170 |
+
num_queries=num_tokens,
|
| 171 |
+
embedding_dim=image_emb_dim,
|
| 172 |
+
output_dim=self.unet.config.cross_attention_dim,
|
| 173 |
+
ff_mult=4,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
image_proj_model.eval()
|
| 177 |
+
|
| 178 |
+
self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype)
|
| 179 |
+
state_dict = torch.load(model_ckpt, map_location="cpu")
|
| 180 |
+
if 'image_proj' in state_dict:
|
| 181 |
+
state_dict = state_dict["image_proj"]
|
| 182 |
+
self.image_proj_model.load_state_dict(state_dict)
|
| 183 |
+
|
| 184 |
+
self.image_proj_model_in_features = image_emb_dim
|
| 185 |
+
|
| 186 |
+
def set_ip_adapter(self, model_ckpt, num_tokens, scale):
|
| 187 |
+
|
| 188 |
+
unet = self.unet
|
| 189 |
+
attn_procs = {}
|
| 190 |
+
for name in unet.attn_processors.keys():
|
| 191 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
| 192 |
+
if name.startswith("mid_block"):
|
| 193 |
+
hidden_size = unet.config.block_out_channels[-1]
|
| 194 |
+
elif name.startswith("up_blocks"):
|
| 195 |
+
block_id = int(name[len("up_blocks.")])
|
| 196 |
+
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
| 197 |
+
elif name.startswith("down_blocks"):
|
| 198 |
+
block_id = int(name[len("down_blocks.")])
|
| 199 |
+
hidden_size = unet.config.block_out_channels[block_id]
|
| 200 |
+
if cross_attention_dim is None:
|
| 201 |
+
attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype)
|
| 202 |
+
else:
|
| 203 |
+
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size,
|
| 204 |
+
cross_attention_dim=cross_attention_dim,
|
| 205 |
+
scale=scale,
|
| 206 |
+
num_tokens=num_tokens).to(unet.device, dtype=unet.dtype)
|
| 207 |
+
unet.set_attn_processor(attn_procs)
|
| 208 |
+
|
| 209 |
+
state_dict = torch.load(model_ckpt, map_location="cpu")
|
| 210 |
+
ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
|
| 211 |
+
if 'ip_adapter' in state_dict:
|
| 212 |
+
state_dict = state_dict['ip_adapter']
|
| 213 |
+
ip_layers.load_state_dict(state_dict)
|
| 214 |
+
|
| 215 |
+
def set_ip_adapter_scale(self, scale):
|
| 216 |
+
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
| 217 |
+
for attn_processor in unet.attn_processors.values():
|
| 218 |
+
if isinstance(attn_processor, IPAttnProcessor):
|
| 219 |
+
attn_processor.scale = scale
|
| 220 |
+
|
| 221 |
+
def _encode_prompt_image_emb(self, prompt_image_emb, device, num_images_per_prompt, dtype,
|
| 222 |
+
do_classifier_free_guidance):
|
| 223 |
+
|
| 224 |
+
if isinstance(prompt_image_emb, torch.Tensor):
|
| 225 |
+
prompt_image_emb = prompt_image_emb.clone().detach()
|
| 226 |
+
else:
|
| 227 |
+
prompt_image_emb = torch.tensor(prompt_image_emb)
|
| 228 |
+
|
| 229 |
+
prompt_image_emb = prompt_image_emb.to(device=device, dtype=dtype)
|
| 230 |
+
prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features])
|
| 231 |
+
|
| 232 |
+
if do_classifier_free_guidance:
|
| 233 |
+
prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0)
|
| 234 |
+
else:
|
| 235 |
+
prompt_image_emb = torch.cat([prompt_image_emb], dim=0)
|
| 236 |
+
|
| 237 |
+
prompt_image_emb = self.image_proj_model(prompt_image_emb)
|
| 238 |
+
|
| 239 |
+
bs_embed, seq_len, _ = prompt_image_emb.shape
|
| 240 |
+
prompt_image_emb = prompt_image_emb.repeat(1, num_images_per_prompt, 1)
|
| 241 |
+
prompt_image_emb = prompt_image_emb.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
| 242 |
+
|
| 243 |
+
return prompt_image_emb
|
| 244 |
+
|
| 245 |
+
@torch.no_grad()
|
| 246 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 247 |
+
def __call__(
|
| 248 |
+
self,
|
| 249 |
+
prompt: Union[str, List[str]] = None,
|
| 250 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 251 |
+
image: PipelineImageInput = None,
|
| 252 |
+
height: Optional[int] = None,
|
| 253 |
+
width: Optional[int] = None,
|
| 254 |
+
num_inference_steps: int = 50,
|
| 255 |
+
guidance_scale: float = 5.0,
|
| 256 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 257 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 258 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 259 |
+
eta: float = 0.0,
|
| 260 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 261 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 262 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 263 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 264 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 265 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 266 |
+
image_embeds: Optional[torch.FloatTensor] = None,
|
| 267 |
+
output_type: Optional[str] = "pil",
|
| 268 |
+
return_dict: bool = True,
|
| 269 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 270 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
| 271 |
+
guess_mode: bool = False,
|
| 272 |
+
control_guidance_start: Union[float, List[float]] = 0.0,
|
| 273 |
+
control_guidance_end: Union[float, List[float]] = 1.0,
|
| 274 |
+
original_size: Tuple[int, int] = None,
|
| 275 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
| 276 |
+
target_size: Tuple[int, int] = None,
|
| 277 |
+
negative_original_size: Optional[Tuple[int, int]] = None,
|
| 278 |
+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
| 279 |
+
negative_target_size: Optional[Tuple[int, int]] = None,
|
| 280 |
+
clip_skip: Optional[int] = None,
|
| 281 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 282 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 283 |
+
|
| 284 |
+
# IP adapter
|
| 285 |
+
ip_adapter_scale=None,
|
| 286 |
+
|
| 287 |
+
**kwargs,
|
| 288 |
+
):
|
| 289 |
+
r"""
|
| 290 |
+
The call function to the pipeline for generation.
|
| 291 |
+
|
| 292 |
+
Args:
|
| 293 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 294 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
| 295 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
| 296 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
| 297 |
+
used in both text-encoders.
|
| 298 |
+
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
|
| 299 |
+
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
| 300 |
+
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
|
| 301 |
+
specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
|
| 302 |
+
accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
|
| 303 |
+
and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
|
| 304 |
+
`init`, images must be passed as a list such that each element of the list can be correctly batched for
|
| 305 |
+
input to a single ControlNet.
|
| 306 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 307 |
+
The height in pixels of the generated image. Anything below 512 pixels won't work well for
|
| 308 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
| 309 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
| 310 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 311 |
+
The width in pixels of the generated image. Anything below 512 pixels won't work well for
|
| 312 |
+
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
| 313 |
+
and checkpoints that are not specifically fine-tuned on low resolutions.
|
| 314 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 315 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 316 |
+
expense of slower inference.
|
| 317 |
+
guidance_scale (`float`, *optional*, defaults to 5.0):
|
| 318 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
| 319 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
| 320 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 321 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
| 322 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
| 323 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
| 324 |
+
The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
|
| 325 |
+
and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
|
| 326 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 327 |
+
The number of images to generate per prompt.
|
| 328 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 329 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
| 330 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
| 331 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 332 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 333 |
+
generation deterministic.
|
| 334 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 335 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
| 336 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 337 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 338 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 339 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
| 340 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
| 341 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 342 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
| 343 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
| 344 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 345 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
| 346 |
+
not provided, pooled text embeddings are generated from `prompt` input argument.
|
| 347 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 348 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
|
| 349 |
+
weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
|
| 350 |
+
argument.
|
| 351 |
+
image_embeds (`torch.FloatTensor`, *optional*):
|
| 352 |
+
Pre-generated image embeddings.
|
| 353 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 354 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
| 355 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 356 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 357 |
+
plain tuple.
|
| 358 |
+
cross_attention_kwargs (`dict`, *optional*):
|
| 359 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
| 360 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 361 |
+
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
| 362 |
+
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
|
| 363 |
+
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
|
| 364 |
+
the corresponding scale as a list.
|
| 365 |
+
guess_mode (`bool`, *optional*, defaults to `False`):
|
| 366 |
+
The ControlNet encoder tries to recognize the content of the input image even if you remove all
|
| 367 |
+
prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
|
| 368 |
+
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
| 369 |
+
The percentage of total steps at which the ControlNet starts applying.
|
| 370 |
+
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
| 371 |
+
The percentage of total steps at which the ControlNet stops applying.
|
| 372 |
+
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
| 373 |
+
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
| 374 |
+
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
| 375 |
+
explained in section 2.2 of
|
| 376 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
| 377 |
+
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
| 378 |
+
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
| 379 |
+
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
| 380 |
+
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
| 381 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
| 382 |
+
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
| 383 |
+
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
| 384 |
+
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
| 385 |
+
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
| 386 |
+
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
| 387 |
+
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
| 388 |
+
micro-conditioning as explained in section 2.2 of
|
| 389 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
| 390 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
| 391 |
+
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
| 392 |
+
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
|
| 393 |
+
micro-conditioning as explained in section 2.2 of
|
| 394 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
| 395 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
| 396 |
+
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
| 397 |
+
To negatively condition the generation process based on a target image resolution. It should be as same
|
| 398 |
+
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
| 399 |
+
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
| 400 |
+
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
| 401 |
+
clip_skip (`int`, *optional*):
|
| 402 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 403 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 404 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 405 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 406 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 407 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 408 |
+
`callback_on_step_end_tensor_inputs`.
|
| 409 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 410 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 411 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 412 |
+
`._callback_tensor_inputs` attribute of your pipeine class.
|
| 413 |
+
|
| 414 |
+
Examples:
|
| 415 |
+
|
| 416 |
+
Returns:
|
| 417 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 418 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
|
| 419 |
+
otherwise a `tuple` is returned containing the output images.
|
| 420 |
+
"""
|
| 421 |
+
|
| 422 |
+
callback = kwargs.pop("callback", None)
|
| 423 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
| 424 |
+
|
| 425 |
+
if callback is not None:
|
| 426 |
+
deprecate(
|
| 427 |
+
"callback",
|
| 428 |
+
"1.0.0",
|
| 429 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
| 430 |
+
)
|
| 431 |
+
if callback_steps is not None:
|
| 432 |
+
deprecate(
|
| 433 |
+
"callback_steps",
|
| 434 |
+
"1.0.0",
|
| 435 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
| 439 |
+
|
| 440 |
+
# align format for control guidance
|
| 441 |
+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
| 442 |
+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
| 443 |
+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
| 444 |
+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
| 445 |
+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
| 446 |
+
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
|
| 447 |
+
control_guidance_start, control_guidance_end = (
|
| 448 |
+
mult * [control_guidance_start],
|
| 449 |
+
mult * [control_guidance_end],
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
# 0. set ip_adapter_scale
|
| 453 |
+
if ip_adapter_scale is not None:
|
| 454 |
+
self.set_ip_adapter_scale(ip_adapter_scale)
|
| 455 |
+
|
| 456 |
+
# 1. Check inputs. Raise error if not correct
|
| 457 |
+
self.check_inputs(
|
| 458 |
+
prompt,
|
| 459 |
+
prompt_2,
|
| 460 |
+
image,
|
| 461 |
+
callback_steps,
|
| 462 |
+
negative_prompt,
|
| 463 |
+
negative_prompt_2,
|
| 464 |
+
prompt_embeds,
|
| 465 |
+
negative_prompt_embeds,
|
| 466 |
+
pooled_prompt_embeds,
|
| 467 |
+
negative_pooled_prompt_embeds,
|
| 468 |
+
controlnet_conditioning_scale,
|
| 469 |
+
control_guidance_start,
|
| 470 |
+
control_guidance_end,
|
| 471 |
+
callback_on_step_end_tensor_inputs,
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
self._guidance_scale = guidance_scale
|
| 475 |
+
self._clip_skip = clip_skip
|
| 476 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
| 477 |
+
|
| 478 |
+
# 2. Define call parameters
|
| 479 |
+
if prompt is not None and isinstance(prompt, str):
|
| 480 |
+
batch_size = 1
|
| 481 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 482 |
+
batch_size = len(prompt)
|
| 483 |
+
else:
|
| 484 |
+
batch_size = prompt_embeds.shape[0]
|
| 485 |
+
|
| 486 |
+
device = self._execution_device
|
| 487 |
+
|
| 488 |
+
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
| 489 |
+
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
|
| 490 |
+
|
| 491 |
+
global_pool_conditions = (
|
| 492 |
+
controlnet.config.global_pool_conditions
|
| 493 |
+
if isinstance(controlnet, ControlNetModel)
|
| 494 |
+
else controlnet.nets[0].config.global_pool_conditions
|
| 495 |
+
)
|
| 496 |
+
guess_mode = guess_mode or global_pool_conditions
|
| 497 |
+
|
| 498 |
+
# 3.1 Encode input prompt
|
| 499 |
+
text_encoder_lora_scale = (
|
| 500 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
| 501 |
+
)
|
| 502 |
+
(
|
| 503 |
+
prompt_embeds,
|
| 504 |
+
negative_prompt_embeds,
|
| 505 |
+
pooled_prompt_embeds,
|
| 506 |
+
negative_pooled_prompt_embeds,
|
| 507 |
+
) = self.encode_prompt(
|
| 508 |
+
prompt,
|
| 509 |
+
prompt_2,
|
| 510 |
+
device,
|
| 511 |
+
num_images_per_prompt,
|
| 512 |
+
self.do_classifier_free_guidance,
|
| 513 |
+
negative_prompt,
|
| 514 |
+
negative_prompt_2,
|
| 515 |
+
prompt_embeds=prompt_embeds,
|
| 516 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 517 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 518 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
| 519 |
+
lora_scale=text_encoder_lora_scale,
|
| 520 |
+
clip_skip=self.clip_skip,
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
# 3.2 Encode image prompt
|
| 524 |
+
prompt_image_emb = self._encode_prompt_image_emb(image_embeds,
|
| 525 |
+
device,
|
| 526 |
+
num_images_per_prompt,
|
| 527 |
+
self.unet.dtype,
|
| 528 |
+
self.do_classifier_free_guidance)
|
| 529 |
+
|
| 530 |
+
# 4. Prepare image
|
| 531 |
+
if isinstance(controlnet, ControlNetModel):
|
| 532 |
+
image = self.prepare_image(
|
| 533 |
+
image=image,
|
| 534 |
+
width=width,
|
| 535 |
+
height=height,
|
| 536 |
+
batch_size=batch_size * num_images_per_prompt,
|
| 537 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 538 |
+
device=device,
|
| 539 |
+
dtype=controlnet.dtype,
|
| 540 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 541 |
+
guess_mode=guess_mode,
|
| 542 |
+
)
|
| 543 |
+
height, width = image.shape[-2:]
|
| 544 |
+
elif isinstance(controlnet, MultiControlNetModel):
|
| 545 |
+
images = []
|
| 546 |
+
|
| 547 |
+
for image_ in image:
|
| 548 |
+
image_ = self.prepare_image(
|
| 549 |
+
image=image_,
|
| 550 |
+
width=width,
|
| 551 |
+
height=height,
|
| 552 |
+
batch_size=batch_size * num_images_per_prompt,
|
| 553 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 554 |
+
device=device,
|
| 555 |
+
dtype=controlnet.dtype,
|
| 556 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 557 |
+
guess_mode=guess_mode,
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
images.append(image_)
|
| 561 |
+
|
| 562 |
+
image = images
|
| 563 |
+
height, width = image[0].shape[-2:]
|
| 564 |
+
else:
|
| 565 |
+
assert False
|
| 566 |
+
|
| 567 |
+
# 5. Prepare timesteps
|
| 568 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 569 |
+
timesteps = self.scheduler.timesteps
|
| 570 |
+
self._num_timesteps = len(timesteps)
|
| 571 |
+
|
| 572 |
+
# 6. Prepare latent variables
|
| 573 |
+
num_channels_latents = self.unet.config.in_channels
|
| 574 |
+
latents = self.prepare_latents(
|
| 575 |
+
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6.5 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7.1 Create tensor stating which controlnets to keep
        controlnet_keep = []
        for i in range(len(timesteps)):
            keeps = [
                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

        # 7.2 Prepare added time ids & embeddings
        if isinstance(image, list):
            original_size = original_size or image[0].shape[-2:]
        else:
            original_size = original_size or image.shape[-2:]
        target_size = target_size or (height, width)

        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        add_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )

        if negative_original_size is not None and negative_target_size is not None:
            negative_add_time_ids = self._get_add_time_ids(
                negative_original_size,
                negative_crops_coords_top_left,
                negative_target_size,
                dtype=prompt_embeds.dtype,
                text_encoder_projection_dim=text_encoder_projection_dim,
            )
        else:
            negative_add_time_ids = add_time_ids

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
        encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1)

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        is_unet_compiled = is_compiled_module(self.unet)
        is_controlnet_compiled = is_compiled_module(self.controlnet)
        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Relevant thread:
                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
                    torch._inductor.cudagraph_mark_step_begin()
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

                # controlnet(s) inference
                if guess_mode and self.do_classifier_free_guidance:
                    # Infer ControlNet only for the conditional batch.
                    control_model_input = latents
                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                    controlnet_added_cond_kwargs = {
                        "text_embeds": add_text_embeds.chunk(2)[1],
                        "time_ids": add_time_ids.chunk(2)[1],
                    }
                else:
                    control_model_input = latent_model_input
                    controlnet_prompt_embeds = prompt_embeds
                    controlnet_added_cond_kwargs = added_cond_kwargs

                if isinstance(controlnet_keep[i], list):
                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                else:
                    controlnet_cond_scale = controlnet_conditioning_scale
                    if isinstance(controlnet_cond_scale, list):
                        controlnet_cond_scale = controlnet_cond_scale[0]
                    cond_scale = controlnet_cond_scale * controlnet_keep[i]

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=prompt_image_emb,
                    controlnet_cond=image,
                    conditioning_scale=cond_scale,
                    guess_mode=guess_mode,
                    added_cond_kwargs=controlnet_added_cond_kwargs,
                    return_dict=False,
                )

                if guess_mode and self.do_classifier_free_guidance:
                    # Inferred ControlNet only for the conditional batch.
                    # To apply the output of ControlNet to both the unconditional and conditional batches,
                    # add 0 to the unconditional batch to keep it unchanged.
                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=encoder_hidden_states,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        if not output_type == "latent":
            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
            if needs_upcasting:
                self.upcast_vae()
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
        else:
            image = latents

        if not output_type == "latent":
            # apply watermark if available
            if self.watermark is not None:
                image = self.watermark.apply_watermark(image)

            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)
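The `controlnet_keep` schedule above turns the ControlNet conditioning off outside the `[control_guidance_start, control_guidance_end]` window (its entry multiplies the conditioning scale), and classifier-free guidance then mixes the unconditional and conditional noise predictions. The following is a minimal standalone sketch of those two steps, not part of the uploaded files; the step count, window fractions, and tensor shapes are illustrative assumptions.

import torch

num_steps = 10
control_guidance_start, control_guidance_end = [0.0], [0.5]

# keep ControlNet active only inside the configured fraction of the schedule
controlnet_keep = []
for i in range(num_steps):
    keeps = [
        1.0 - float(i / num_steps < s or (i + 1) / num_steps > e)
        for s, e in zip(control_guidance_start, control_guidance_end)
    ]
    controlnet_keep.append(keeps[0])
print(controlnet_keep)  # 1.0 for the first half of the steps, 0.0 afterwards

# classifier-free guidance on a dummy prediction: uncond + scale * (text - uncond)
guidance_scale = 5.0
noise_pred = torch.randn(2, 4, 64, 64)  # [uncond, text] stacked on the batch dim
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)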
src/pipelines/lora_pipeline.py
ADDED
@@ -0,0 +1,681 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)

from diffusers.utils.import_utils import is_invisible_watermark_available

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import (
    FromSingleFileMixin,
    IPAdapterMixin,
    StableDiffusionXLLoraLoaderMixin,
    TextualInversionLoaderMixin,
)
from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    XFormersAttnProcessor,
)
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    USE_PEFT_BACKEND,
    deprecate,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput


if is_invisible_watermark_available():
    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker

from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers import StableDiffusionXLControlNetPipeline
from PIL import Image
from torchvision.transforms.functional import to_tensor
from einops import rearrange
from torch import einsum
import math
from torchvision.utils import save_image

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

class RegionControlNet_AttnProcessor:
    def __init__(self, attention_op=None, controller=None, place_in_unet=None):
        self.attention_op = attention_op
        self.controller = controller
        self.place_in_unet = place_in_unet

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
        **cross_attention_kwargs
    ) -> torch.Tensor:
        residual = hidden_states

        args = () if USE_PEFT_BACKEND else (scale,)

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states, *args)

        is_cross = True
        if encoder_hidden_states is None:
            is_cross = False
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        attention_probs = self.controller(attention_probs, is_cross, self.place_in_unet)
        hidden_states = torch.bmm(attention_probs, value)

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


def revise_regionally_controlnet_forward(unet, controller):
    def change_forward(unet, count, place_in_unet):
        for name, layer in unet.named_children():
            if layer.__class__.__name__ == 'Attention':
                layer.set_processor(RegionControlNet_AttnProcessor(controller=controller, place_in_unet=place_in_unet))
                if 'attn2' in name:
                    count += 1
            else:
                count = change_forward(layer, count, place_in_unet)
        return count

    # use this to ensure the order
    cross_attention_idx = change_forward(unet.down_blocks, 0, "down")
    cross_attention_idx = change_forward(unet.mid_block, cross_attention_idx, "up")
    cross_attention_idx = change_forward(unet.up_blocks, cross_attention_idx, "mid")
    print(f'Number of attention layer registered {cross_attention_idx}')
    controller.num_att_layers = cross_attention_idx*2

class LoraMultiConceptPipeline(StableDiffusionXLControlNetPipeline):
    # leave controlnet out on purpose because it iterates with unet
    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
    _optional_components = [
        "tokenizer",
        "tokenizer_2",
        "text_encoder",
        "text_encoder_2",
        "feature_extractor",
        "image_encoder",
    ]
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet: UNet2DConditionModel,
        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
        scheduler: KarrasDiffusionSchedulers,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
        feature_extractor: CLIPImageProcessor = None,
        image_encoder: CLIPVisionModelWithProjection = None
    ):
        if isinstance(controlnet, (list, tuple)):
            controlnet = MultiControlNetModel(controlnet)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
        self.control_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
        )
        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()

        if add_watermarker:
            self.watermark = StableDiffusionXLWatermarker()
        else:
            self.watermark = None

        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        image: PipelineImageInput = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
        guess_mode: bool = False,
        control_guidance_start: Union[float, List[float]] = 0.0,
        control_guidance_end: Union[float, List[float]] = 1.0,
        original_size: Tuple[int, int] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Tuple[int, int] = None,
        negative_original_size: Optional[Tuple[int, int]] = None,
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        negative_target_size: Optional[Tuple[int, int]] = None,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        controller=None,
        concept_models=None,
        stage=None,
        region_masks=None,
        lora_list=None,
        styleL=None,
        **kwargs,
    ):
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
            control_guidance_start, control_guidance_end = (
                mult * [control_guidance_start],
                mult * [control_guidance_end],
            )

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs

        # 2. Define call parameters
        batch_size = 2

        device = self._execution_device

        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)

        global_pool_conditions = (
            controlnet.config.global_pool_conditions
            if isinstance(controlnet, ControlNetModel)
            else controlnet.nets[0].config.global_pool_conditions
        )
        guess_mode = guess_mode or global_pool_conditions

        # 3.1 Encode input prompt
        text_encoder_lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )

        global_prompt = prompt[0]
        global_negative_prompt = negative_prompt
        region_prompts = [pt[0] for pt in prompt[1]]
        region_negative_prompts = [pt[1] for pt in prompt[1]]

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            global_prompt,
            prompt_2,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            global_negative_prompt,
            negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )

        region_prompt_embeds_list = []
        region_add_text_embeds_list = []
        for lora_param, region_prompt, region_negative_prompt in zip(lora_list, region_prompts, region_negative_prompts):
            if styleL:
                concept_models.set_adapters([lora_param, "style"], adapter_weights=[0.7, 0.5])
            else:
                concept_models.set_adapters(lora_param)
            region_prompt_embeds, region_negative_prompt_embeds, region_pooled_prompt_embeds, region_negative_pooled_prompt_embeds = concept_models.encode_prompt(
                prompt=region_prompt, device=concept_models._execution_device, num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=region_negative_prompt, lora_scale=text_encoder_lora_scale
            )
            region_prompt_embeds_list.append(torch.concat([region_negative_prompt_embeds, region_prompt_embeds], dim=0).to(concept_models._execution_device))
            region_add_text_embeds_list.append(torch.concat([region_negative_pooled_prompt_embeds, region_pooled_prompt_embeds], dim=0).to(concept_models._execution_device))

        if stage==2:
            mask_list = [mask.float().to(dtype=prompt_embeds.dtype, device=device) if mask is not None else None for mask in region_masks]

        # 4. Prepare image
        if isinstance(controlnet, ControlNetModel) and image is not None:
            image = self.prepare_image(
                image=image,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
            height, width = image.shape[-2:]
        elif isinstance(controlnet, MultiControlNetModel) and image is not None:
            images = []

            for image_ in image:
                image_ = self.prepare_image(
                    image=image_,
                    width=width,
                    height=height,
                    batch_size=batch_size * num_images_per_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    device=device,
                    dtype=controlnet.dtype,
                    do_classifier_free_guidance=self.do_classifier_free_guidance,
                    guess_mode=guess_mode,
                )

                images.append(image_)

            image = images
            height, width = image[0].shape[-2:]
        else:
            height = height or self.unet.config.sample_size * self.vae_scale_factor
            width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        self._num_timesteps = len(timesteps)

        # 6. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size//2 * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6.1 repeat latent
        latents = torch.cat([latents, latents.clone()])

        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7.1 Create tensor stating which controlnets to keep
        controlnet_keep = []
        for i in range(len(timesteps)):
            keeps = [
                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

        # 7.2 Prepare added time ids & embeddings
        if isinstance(image, list):
            original_size = original_size or image[0].shape[-2:]
        else:
            original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        add_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )

        add_time_ids_list = []
        for _ in lora_list:
            region_add_time_ids = concept_models._get_add_time_ids(original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim)
            add_time_ids_list.append(torch.concat([region_add_time_ids, region_add_time_ids], dim=0).to(concept_models._execution_device))

        if negative_original_size is not None and negative_target_size is not None:
            negative_add_time_ids = self._get_add_time_ids(
                negative_original_size,
                negative_crops_coords_top_left,
                negative_target_size,
                dtype=prompt_embeds.dtype,
                text_encoder_projection_dim=text_encoder_projection_dim,
            )
        else:
            negative_add_time_ids = add_time_ids

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        is_unet_compiled = is_compiled_module(self.unet)
        is_controlnet_compiled = is_compiled_module(self.controlnet)
        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
        # hyper-parameters
        scale_range = np.linspace(1, 0.5, len(self.scheduler.timesteps))

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Relevant thread:
                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
                    torch._inductor.cudagraph_mark_step_begin()
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

                # controlnet(s) inference
                if guess_mode and self.do_classifier_free_guidance:
                    # Infer ControlNet only for the conditional batch.
                    control_model_input = latents
                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                    controlnet_added_cond_kwargs = {
                        "text_embeds": add_text_embeds.chunk(2)[1],
                        "time_ids": add_time_ids.chunk(2)[1],
                    }
                else:
                    control_model_input = latent_model_input
                    controlnet_prompt_embeds = prompt_embeds
                    controlnet_added_cond_kwargs = added_cond_kwargs

                if isinstance(controlnet_keep[i], list):
                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
                else:
                    controlnet_cond_scale = controlnet_conditioning_scale
                    if isinstance(controlnet_cond_scale, list):
                        controlnet_cond_scale = controlnet_cond_scale[0]
                    cond_scale = controlnet_cond_scale * controlnet_keep[i]

                if image is not None:
                    down_block_res_samples, mid_block_res_sample = self.controlnet(
                        control_model_input,
                        t,
                        encoder_hidden_states=controlnet_prompt_embeds,
                        controlnet_cond=image,
                        conditioning_scale=cond_scale,
                        guess_mode=guess_mode,
                        added_cond_kwargs=controlnet_added_cond_kwargs,
                        return_dict=False,
                    )

                    if guess_mode and self.do_classifier_free_guidance:
                        # Inferred ControlNet only for the conditional batch.
                        # To apply the output of ControlNet to both the unconditional and conditional batches,
                        # add 0 to the unconditional batch to keep it unchanged.
                        down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                        mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

                else:
                    down_block_res_samples = None
                    mid_block_res_sample = None

                # predict the noise residual
                if image is not None:
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        timestep_cond=timestep_cond,
                        cross_attention_kwargs=self.cross_attention_kwargs,
                        down_block_additional_residuals=down_block_res_samples,
                        mid_block_additional_residual=mid_block_res_sample,
                        added_cond_kwargs=added_cond_kwargs,
                        return_dict=False,
                    )[0]
                else:
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        timestep_cond=timestep_cond,
                        cross_attention_kwargs=self.cross_attention_kwargs,
                        added_cond_kwargs=added_cond_kwargs,
                        return_dict=False,
                    )[0]

                if i > 15 and stage == 2:
                    region_mask = self.get_region_mask(mask_list, noise_pred.shape[2], noise_pred.shape[3])
                    edit_noise = torch.concat([noise_pred[1:2], noise_pred[3:4]], dim=0)
                    new_noise_pred = torch.zeros_like(edit_noise)
                    new_noise_pred[:, :, region_mask == 0] = edit_noise[:, :, region_mask == 0]
                    replace_ratio = 1.0
                    new_noise_pred[:, :, region_mask != 0] = (1 - replace_ratio) * edit_noise[:, :, region_mask != 0]

                    for region_prompt_embeds, region_add_text_embeds, region_add_time_ids, concept_mask, region_prompt, lora_param in zip(region_prompt_embeds_list, region_add_text_embeds_list, add_time_ids_list, mask_list, region_prompts, lora_list):
                        if concept_mask is not None:
                            concept_mask = F.interpolate(concept_mask.unsqueeze(0).unsqueeze(0),
                                                         size=(noise_pred.shape[2], noise_pred.shape[3]),
                                                         mode='nearest').squeeze().to(dtype=noise_pred.dtype, device=concept_models._execution_device)

                            region_latent_model_input = latent_model_input[3:4].clone().to(concept_models._execution_device)

                            region_latent_model_input = torch.cat([region_latent_model_input] * 2)
                            region_added_cond_kwargs = {"text_embeds": region_add_text_embeds,
                                                        "time_ids": region_add_time_ids}
                            if styleL:
                                concept_models.set_adapters([lora_param, "style"], adapter_weights=[0.7, 0.5])
                            else:
                                concept_models.set_adapters(lora_param)
                            region_noise_pred = concept_models.unet(
                                region_latent_model_input,
                                t,
                                encoder_hidden_states=region_prompt_embeds,
                                cross_attention_kwargs={'scale': 0.8},
                                added_cond_kwargs=region_added_cond_kwargs,
                                return_dict=False,
                            )[0]

                            new_noise_pred = new_noise_pred.to(concept_models._execution_device)
                            new_noise_pred[:, :, concept_mask==1] += replace_ratio * (region_noise_pred[:, :, concept_mask==1] / (concept_mask.reshape(1, 1, *concept_mask.shape)[:, :, concept_mask==1].to(region_noise_pred.device)))

                    new_noise_pred = new_noise_pred.to(noise_pred.device)
                    noise_pred[1, :, :, :] = new_noise_pred[0]
                    noise_pred[3, :, :, :] = new_noise_pred[1]

                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        # manually for max memory savings
        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
            self.upcast_vae()
            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
        if stage==2:
            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
        if not output_type == "latent":
            # make sure the VAE is in float32 mode, as it overflows in float16
            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

            if needs_upcasting:
                self.upcast_vae()
                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]

            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
        else:
            image = latents

        if not output_type == "latent":
            # apply watermark if available
            if self.watermark is not None:
                image = self.watermark.apply_watermark(image)

            image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)

    def check_image(self, image, prompt, prompt_embeds):
        pass

    def get_region_mask(self, mask_list, feat_height, feat_width):
        exclusive_mask = torch.zeros((feat_height, feat_width))
        for mask in mask_list:
            if mask is not None:
                mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(feat_height, feat_width),
                                     mode='nearest').squeeze().to(dtype=exclusive_mask.dtype, device=exclusive_mask.device)
                exclusive_mask = ((mask == 1) | (exclusive_mask == 1)).to(dtype=mask.dtype)
        return exclusive_mask
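In stage 2 the pipeline above swaps the globally predicted noise inside each concept's region for the noise predicted under that concept's LoRA adapter, using `get_region_mask` to build the union of the per-concept masks. The snippet below is a minimal standalone sketch of that masking logic on dummy tensors; it is not part of the uploaded files, and the feature resolution, mask layout, and stand-in "LoRA UNet" outputs are illustrative assumptions.

import torch

feat_h, feat_w = 32, 32
# two hypothetical binary concept masks at feature resolution
mask_a = torch.zeros(feat_h, feat_w)
mask_a[:, :16] = 1
mask_b = torch.zeros(feat_h, feat_w)
mask_b[:, 16:] = 1
mask_list = [mask_a, mask_b]

# union of all concept regions, as in get_region_mask
region_mask = torch.zeros(feat_h, feat_w)
for m in mask_list:
    region_mask = ((m == 1) | (region_mask == 1)).float()

# keep the global prediction outside the regions, replace it inside
global_noise = torch.randn(1, 4, feat_h, feat_w)
new_noise = torch.zeros_like(global_noise)
new_noise[:, :, region_mask == 0] = global_noise[:, :, region_mask == 0]

replace_ratio = 1.0  # fully replace region noise, matching the pipeline default
for m in mask_list:
    # stand-in for the per-concept LoRA UNet prediction
    region_noise = torch.randn(1, 4, feat_h, feat_w)
    new_noise[:, :, m == 1] += replace_ratio * region_noise[:, :, m == 1]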
src/prompt_attention/p2p_attention.py
ADDED
@@ -0,0 +1,148 @@
from typing import Optional, Union, Tuple, List, Callable, Dict
import torch
import torch.nn.functional as nnf
import numpy as np
import abc
import src.prompt_attention.p2p_utils as p2p_utils
import src.prompt_attention.seq_aligner as seq_aligner


class AttentionControl(abc.ABC):

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @property
    def num_uncond_att_layers(self):
        # return self.num_att_layers if self.low_resource else 0
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if self.low_resource:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                h = attn.shape[0]
                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self, low_resource=False, width=None, height=None):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0
        self.low_resource = low_resource
        self.width = width
        self.height = height

class AttentionStore(AttentionControl):

    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        # if attn.shape[1] <= att_size * 64:
        return attn

    def between_steps(self):
        if self.save_global_store:
            if len(self.attention_store) == 0:
                self.attention_store = self.step_store
            else:
                for key in self.attention_store:
                    for i in range(len(self.attention_store[key])):
                        self.attention_store[key][i] += self.step_store[key][i]
        else:
            self.attention_store = self.step_store
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in
                             self.attention_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def __init__(self, width, height, low_resolution=False, save_global_store=False):
        super(AttentionStore, self).__init__(low_resolution, width, height)
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.save_global_store = save_global_store

class AttentionControlEdit(AttentionStore, abc.ABC):
    def __init__(self, prompts, num_steps: int,
                 cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
                 self_replace_steps: Union[float, Tuple[float, float]],
                 local_blend=None, width=None, height=None, tokenizer=None, device=None):
        super(AttentionControlEdit, self).__init__(width, height)
        self.batch_size = len(prompts)
        self.cross_replace_alpha = p2p_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps,
                                                                            tokenizer).to(device)
        if type(self_replace_steps) is float:
            self_replace_steps = 0, self_replace_steps
        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
        self.local_blend = local_blend

    def step_callback(self, x_t):
        print("step_callback")
        if self.local_blend is not None:
            x_t = self.local_blend(x_t, self.attention_store)
        return x_t

    def replace_self_attention(self, attn_base, att_replace):
        if att_replace.shape[2] <= self.width * self.height:
            return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
        else:
            return att_replace

    @abc.abstractmethod
    def replace_cross_attention(self, attn_base, att_replace):
        raise NotImplementedError

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
        if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
            h = attn.shape[0] // (self.batch_size)
            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
            attn_base, attn_replace = attn[0], attn[1:]
            if is_cross:
                alpha_words = self.cross_replace_alpha[self.cur_step]
                attn_replace_new = self.replace_cross_attention(attn_base, attn_replace) * alpha_words + (
                        1 - alpha_words) * attn_replace
                attn[1:] = attn_replace_new
            else:
                attn[1:] = self.replace_self_attention(attn_base, attn_replace)
            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
        return attn

class AttentionReplace(AttentionControlEdit):
    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, width, height,
                 local_blend = None, tokenizer=None, device=None, dtype=None):
        super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, width, height, tokenizer=tokenizer, device=device)
        self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(dtype=dtype, device=device)

    def replace_cross_attention(self, attn_base, att_replace):
        return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
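`AttentionReplace.replace_cross_attention` above projects the base prompt's cross-attention maps onto the edited prompts' tokens through a token-to-token mapper matrix via a single einsum. The following is a small standalone shape check of that contraction with random tensors; it is not part of the uploaded files, and the head, pixel, token, and batch sizes are arbitrary assumptions.

import torch

heads, pixels, base_tokens, new_tokens, batch = 8, 64, 77, 77, 2
attn_base = torch.rand(heads, pixels, base_tokens)   # h p w
mapper = torch.rand(batch, base_tokens, new_tokens)  # b w n
replaced = torch.einsum('hpw,bwn->bhpn', attn_base, mapper)
print(replaced.shape)  # torch.Size([2, 8, 64, 77])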