---
license: apache-2.0
---

Use like:

```py
import pathlib

import torch

BASE_MODEL_ID     = "Qwen/Qwen-Image"                        # base pipeline (assumed; substitute your base model ID)
CACHE_ROOT        = pathlib.Path("qwen-image-int8-quanto")   # where we store INT8 modules
TRANSFORMER_DIR   = CACHE_ROOT / "qwen_image_transformer_int8"
TEXT_ENCODER_DIR  = CACHE_ROOT / "qwen_text_encoder_int8"

def load_quantized_modules(transformer_dir: pathlib.Path, text_encoder_dir: pathlib.Path):
    """
    Load quantized modules (we saved them with the exact filenames the loaders expect).
    """
    tr = torch.load(str(transformer_dir / 'pytorch_model.bin'), weights_only=False)
    te = torch.load(str(text_encoder_dir / 'pytorch_model.bin'), weights_only=False)
    return tr, te

def build_pipe(cls, transformer_dir: pathlib.Path, text_encoder_dir: pathlib.Path):
    """
    Build a pipeline of class `cls` by loading the quantized modules from disk.
    Fresh module instances each time avoids offload-hook/state reuse hangs.
    """
    transformer, text_encoder = load_quantized_modules(transformer_dir, text_encoder_dir)
    pipe = cls.from_pretrained(
        BASE_MODEL_ID,
        transformer=transformer,
        text_encoder=text_encoder,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
        low_cpu_mem_usage=True,
    )
    pipe.enable_model_cpu_offload()  # keep modules on CPU; move each to GPU only for its forward pass
    pipe.set_progress_bar_config(disable=False)
    return pipe
```
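
For example, with a recent `diffusers` release that includes `QwenImagePipeline` and with `optimum-quanto` installed (needed so the pickled quantized modules can be deserialized), a minimal text-to-image run might look like this; the repo ID in the comment is a placeholder:

```py
from diffusers import QwenImagePipeline

# Fetch the INT8 modules into CACHE_ROOT first, e.g.:
# from huggingface_hub import snapshot_download
# snapshot_download("<this-repo-id>", local_dir=CACHE_ROOT)

pipe = build_pipe(QwenImagePipeline, TRANSFORMER_DIR, TEXT_ENCODER_DIR)

image = pipe(
    prompt="A lighthouse on a cliff at sunset, oil painting",
    num_inference_steps=30,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("lighthouse.png")
```

For reference, modules in this layout can be produced with `optimum-quanto` roughly as follows (a sketch of the save side, not necessarily the exact script used for this repo):

```py
from optimum.quanto import freeze, qint8, quantize

def save_quantized_module(module, out_dir: pathlib.Path):
    """Quantize a module's weights to INT8, freeze them, and pickle the whole module."""
    quantize(module, weights=qint8)
    freeze(module)
    out_dir.mkdir(parents=True, exist_ok=True)
    torch.save(module, out_dir / "pytorch_model.bin")
```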