Commit 4bfb360 · 1 parent: 4d20c2f
Commit message: vllm

Files changed:
- app.py +30 -51
- app_naive.py +160 -0
- requirements.txt +2 -1
- serve/README.md +63 -0
- serve/gpt_model.py +369 -0
- serve/gpu_executor.py +201 -0
- serve/llm.py +267 -0
- serve/llm_engine.py +671 -0
- serve/model_runner.py +1223 -0
- serve/sample_c2i.py +97 -0
- serve/sampler.py +868 -0
- serve/worker.py +349 -0
app.py
CHANGED
@@ -8,12 +8,12 @@ torch.backends.cudnn.allow_tf32 = True
 torch.set_float32_matmul_precision('high')
 setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
 setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
-
+from vllm import SamplingParams
 import time
 import argparse
 from tokenizer_image.vq_model import VQ_models
-from models.gpt import GPT_models
-from models.generate import generate
+# from models.generate import generate
+from serve.llm import LLM
 
 device = "cuda"
 
@@ -38,46 +38,16 @@ def load_model(args):
     del checkpoint
     print(f"image tokenizer is loaded")
 
-    # create and load gpt model
-    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
-    latent_size = image_size // args.downsample_size
-    gpt_model = GPT_models[args.gpt_model](
-        vocab_size=args.codebook_size,
-        block_size=latent_size ** 2,
-        num_classes=args.num_classes,
-        cls_token_num=args.cls_token_num,
-        model_type=args.gpt_type,
-    ).to(device=device, dtype=precision)
-
-    checkpoint = torch.load(f"{ckpt_folder}{gpt_ckpt}", map_location="cpu")
-    if args.from_fsdp: # fspd
-        model_weight = checkpoint
-    elif "model" in checkpoint: # ddp
-        model_weight = checkpoint["model"]
-    elif "module" in checkpoint: # deepspeed
-        model_weight = checkpoint["module"]
-    elif "state_dict" in checkpoint:
-        model_weight = checkpoint["state_dict"]
-    else:
-        raise Exception("please check model weight")
-    # if 'freqs_cis' in model_weight:
-    #     model_weight.pop('freqs_cis')
-    gpt_model.load_state_dict(model_weight, strict=False)
-    gpt_model.eval()
-    del checkpoint
+    # Create an LLM.
+    args.image_size = image_size
+    args.gpt_ckpt = f"{ckpt_folder}{gpt_ckpt}"
+    llm = LLM(
+        args=args,
+        model='serve/fake_json/{}.json'.format(args.gpt_model),
+        gpu_memory_utilization=0.6,
+        skip_tokenizer_init=True)
     print(f"gpt model is loaded")
-
-    if args.compile:
-        print(f"compiling the model...")
-        gpt_model = torch.compile(
-            gpt_model,
-            mode="reduce-overhead",
-            fullgraph=True
-        ) # requires PyTorch 2.0 (optional)
-    else:
-        print(f"no need to compile model in demo")
-
-    return vq_model, gpt_model, image_size
+    return vq_model, llm, image_size
 
 
 def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):
@@ -85,20 +55,29 @@ def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):
     latent_size = image_size // args.downsample_size
     # Labels to condition the model with (feel free to change):
     class_labels = [class_label for _ in range(n)]
-    c_indices = torch.tensor(class_labels, device=device)
     qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size]
 
+    prompt_token_ids = [[cind] for cind in class_labels]
+    if cfg_scale > 1.0:
+        prompt_token_ids.extend([[args.num_classes] for _ in range(len(prompt_token_ids))])
+
+    # Create a sampling params object.
+    sampling_params = SamplingParams(
+        temperature=temperature, top_p=top_p, top_k=top_k,
+        max_tokens=latent_size ** 2)
+
     t1 = time.time()
     torch.manual_seed(seed)
-    index_sample = generate(
-        gpt_model, c_indices, latent_size ** 2,
-        cfg_scale=cfg_scale, cfg_interval=args.cfg_interval,
-        temperature=temperature, top_k=top_k,
-        top_p=top_p, sample_logits=True,
-        )
+    outputs = llm.generate(
+        prompt_token_ids=prompt_token_ids,
+        sampling_params=sampling_params,
+        use_tqdm=False)
     sampling_time = time.time() - t1
     print(f"gpt sampling takes about {sampling_time:.2f} seconds.")
 
+    index_sample = torch.tensor([output.outputs[0].token_ids for output in outputs], device=device)
+    if args.cfg_scale > 1.0:
+        index_sample = index_sample[:len(class_labels)]
     t2 = time.time()
     samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
     decoder_time = time.time() - t2
@@ -110,7 +89,7 @@ def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):
 
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL")
+parser.add_argument("--gpt-model", type=str, default="GPT-XL")
 parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
 parser.add_argument("--from-fsdp", action='store_true')
 parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
@@ -129,7 +108,7 @@ parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
 parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
 args = parser.parse_args()
 
-vq_model, gpt_model, image_size = load_model(args)
+vq_model, llm, image_size = load_model(args)
 
 with gr.Blocks() as demo:
     gr.Markdown("<h1 style='text-align: center'>Autoregressive Model Beats Diffusion: Llama for Scalable Image Generation</h1>")
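The new path implements classifier-free guidance at the request level: each class label becomes a one-token prompt, and when cfg_scale > 1.0 a matching batch of "null class" prompts (token id args.num_classes) is appended, with only the first half of the generated sequences kept for decoding. Below is a minimal sketch of that pairing, assuming the custom serve/sampler.py combines each conditional/unconditional pair with the standard CFG formula (the sampler itself is not reproduced in full on this page); the concrete values are illustrative only.

```python
# A minimal, standalone sketch (not the repo's code): how the doubled prompt list
# maps to classifier-free guidance. n, num_classes and class_label are made-up values.
import torch

def cfg_logits(cond: torch.Tensor, uncond: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # Standard classifier-free guidance combination of paired logits.
    return uncond + cfg_scale * (cond - uncond)

n, num_classes, cfg_scale = 4, 1000, 4.0
class_label = 250  # one ImageNet class id, for illustration
prompt_token_ids = [[class_label] for _ in range(n)]
if cfg_scale > 1.0:
    # Mirror app.py: append one "null class" prompt per conditional prompt.
    prompt_token_ids += [[num_classes] for _ in range(n)]

# The first n requests are conditional, the last n unconditional; app.py keeps
# only the first n generated token sequences after sampling.
cond, uncond = torch.randn(n, 16384), torch.randn(n, 16384)
guided = cfg_logits(cond, uncond, cfg_scale)
print(len(prompt_token_ids), guided.shape)  # 8, torch.Size([4, 16384])
```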
app_naive.py
ADDED
@@ -0,0 +1,160 @@
from PIL import Image
import gradio as gr
from imagenet_en_cn import IMAGENET_1K_CLASSES
from huggingface_hub import hf_hub_download
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)

import time
import argparse
from tokenizer_image.vq_model import VQ_models
from models.gpt import GPT_models
from models.generate import generate

device = "cuda"

model2ckpt = {
    "GPT-XL": ("vq_ds16_c2i.pt", "c2i_XL_384.pt", 384),
    "GPT-B": ("vq_ds16_c2i.pt", "c2i_B_256.pt", 256),
}

def load_model(args):
    ckpt_folder = "./"
    vq_ckpt, gpt_ckpt, image_size = model2ckpt[args.gpt_model]
    hf_hub_download(repo_id="FoundationVision/LlamaGen", filename=vq_ckpt, local_dir=ckpt_folder)
    hf_hub_download(repo_id="FoundationVision/LlamaGen", filename=gpt_ckpt, local_dir=ckpt_folder)
    # create and load model
    vq_model = VQ_models[args.vq_model](
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim)
    vq_model.to(device)
    vq_model.eval()
    checkpoint = torch.load(f"{ckpt_folder}{vq_ckpt}", map_location="cpu")
    vq_model.load_state_dict(checkpoint["model"])
    del checkpoint
    print(f"image tokenizer is loaded")

    # create and load gpt model
    precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
    latent_size = image_size // args.downsample_size
    gpt_model = GPT_models[args.gpt_model](
        vocab_size=args.codebook_size,
        block_size=latent_size ** 2,
        num_classes=args.num_classes,
        cls_token_num=args.cls_token_num,
        model_type=args.gpt_type,
    ).to(device=device, dtype=precision)

    checkpoint = torch.load(f"{ckpt_folder}{gpt_ckpt}", map_location="cpu")
    if args.from_fsdp: # fspd
        model_weight = checkpoint
    elif "model" in checkpoint: # ddp
        model_weight = checkpoint["model"]
    elif "module" in checkpoint: # deepspeed
        model_weight = checkpoint["module"]
    elif "state_dict" in checkpoint:
        model_weight = checkpoint["state_dict"]
    else:
        raise Exception("please check model weight")
    # if 'freqs_cis' in model_weight:
    #     model_weight.pop('freqs_cis')
    gpt_model.load_state_dict(model_weight, strict=False)
    gpt_model.eval()
    del checkpoint
    print(f"gpt model is loaded")

    if args.compile:
        print(f"compiling the model...")
        gpt_model = torch.compile(
            gpt_model,
            mode="reduce-overhead",
            fullgraph=True
        ) # requires PyTorch 2.0 (optional)
    else:
        print(f"no need to compile model in demo")

    return vq_model, gpt_model, image_size


def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):
    n = 4
    latent_size = image_size // args.downsample_size
    # Labels to condition the model with (feel free to change):
    class_labels = [class_label for _ in range(n)]
    c_indices = torch.tensor(class_labels, device=device)
    qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size]

    t1 = time.time()
    torch.manual_seed(seed)
    index_sample = generate(
        gpt_model, c_indices, latent_size ** 2,
        cfg_scale=cfg_scale, cfg_interval=args.cfg_interval,
        temperature=temperature, top_k=top_k,
        top_p=top_p, sample_logits=True,
        )
    sampling_time = time.time() - t1
    print(f"gpt sampling takes about {sampling_time:.2f} seconds.")

    t2 = time.time()
    samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
    decoder_time = time.time() - t2
    print(f"decoder takes about {decoder_time:.2f} seconds.")
    # Convert to PIL.Image format:
    samples = samples.mul(127.5).add_(128.0).clamp_(0, 255).permute(0, 2, 3, 1).to("cpu", torch.uint8).numpy()
    samples = [Image.fromarray(sample) for sample in samples]
    return samples


parser = argparse.ArgumentParser()
parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL")
parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
parser.add_argument("--from-fsdp", action='store_true')
parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
parser.add_argument("--compile", action='store_true', default=False)
parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
parser.add_argument("--num-classes", type=int, default=1000)
parser.add_argument("--cfg-scale", type=float, default=4.0)
parser.add_argument("--cfg-interval", type=float, default=-1)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--top-k", type=int, default=2000, help="top-k value to sample with")
parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
args = parser.parse_args()

vq_model, gpt_model, image_size = load_model(args)

with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center'>Autoregressive Model Beats Diffusion: Llama for Scalable Image Generation</h1>")

    with gr.Tabs():
        with gr.TabItem('Generate'):
            with gr.Row():
                with gr.Column():
                    # with gr.Row():
                    #     image_size = gr.Radio(choices=[384], value=384, label='Peize Model Resolution')
                    with gr.Row():
                        i1k_class = gr.Dropdown(
                            list(IMAGENET_1K_CLASSES.values()),
                            value='Eskimo dog, husky [爱斯基摩犬,哈士奇]',
                            type="index", label='ImageNet-1K Class'
                        )
                    cfg_scale = gr.Slider(minimum=1, maximum=25, step=0.1, value=4.0, label='Classifier-free Guidance Scale')
                    top_k = gr.Slider(minimum=1, maximum=16384, step=1, value=4000, label='Top-K')
                    top_p = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label="Top-P")
                    temperature = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label='Temperature')
                    seed = gr.Slider(minimum=0, maximum=1000, step=1, value=42, label='Seed')
                    # seed = gr.Number(value=0, label='Seed')
                    button = gr.Button("Generate", variant="primary")
                with gr.Column():
                    output = gr.Gallery(label='Generated Images', height=700)
            button.click(infer, inputs=[cfg_scale, top_k, top_p, temperature, i1k_class, seed], outputs=[output])

demo.queue()
demo.launch(debug=True)
requirements.txt
CHANGED
@@ -1 +1,2 @@
-
+vllm==0.4.1
+torchvision==0.17.1
serve/README.md
ADDED
@@ -0,0 +1,63 @@
## Serving by vLLM

### Install
```
pip install vllm==0.4.1
pip install torchvision==0.17.1
```

### Demo
```
cd ${THIS_REPO_ROOT}
python3 autoregressive/serve/sample_c2i.py --vq-ckpt /path/to/vq_ds16size16384dim8.pt --gpt-ckpt /path/to/GPT-B/checkpoints/1500000.pt --gpt-model GPT-B
```

### Comparison (A100)

Model | params | baseline (s) | vLLM (s) | speed-up ratio
--- |:---:|:---:|:---:|:---:
GPT-B | 100M | 7.80 | 2.39 | 326 %
GPT-L | 300M | 13.72 | 3.48 | 380 %
GPT-XL | 700M | 19.76 | 4.84 | 408 %
GPT-XXL | 1.4B | 26.38 | 6.36 | 414 %
GPT-3B | 3.1B | - | - | -

```
### GPT-B
# 7.80 seconds
python3 autoregressive/sample/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/2024-04-24-20-56-19/002-GPT-B/checkpoints/1500000.pt

# 2.39 seconds
python3 autoregressive/serve/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/2024-04-24-20-56-19/002-GPT-B/checkpoints/1500000.pt


### GPT-L
# 13.72 seconds
python3 autoregressive/sample/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/2024-04-27-14-27-57/011-GPT-L/checkpoints/1500000.pt --gpt-model GPT-L

# 3.48 seconds
python3 autoregressive/serve/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/2024-04-27-14-27-57/011-GPT-L/checkpoints/1500000.pt --gpt-model GPT-L


### GPT-XL
# 19.76 seconds
python3 autoregressive/sample/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/2024-05-05-13-15-40/000-GPT-XL/checkpoints/1500000.pt --gpt-model GPT-XL

# 4.84 seconds
python3 autoregressive/serve/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/2024-05-05-13-15-40/000-GPT-XL/checkpoints/1500000.pt --gpt-model GPT-XL


### GPT-XXL
# 26.38 seconds
python3 autoregressive/sample/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/20240506150815-GPT-XXXL/0125000/consolidated.pth --from-fsdp --gpt-model GPT-XXXL

# 6.36 seconds
python3 autoregressive/serve/sample_c2i.py --vq-ckpt /mnt/bn/foundation-lq/peize.sun/models/vq_ds16size16384dim8.pt --gpt-ckpt /mnt/bn/foundation-lq/peize.sun/vqgan_arnold/20240506150815-GPT-XXXL/0125000/consolidated.pth --from-fsdp --gpt-model GPT-XXXL

```

In the 3B model, head size 100 is not supported by PagedAttention; supported head sizes are [64, 80, 96, 112, 128, 256].
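The head-size note above follows directly from the model widths defined in serve/gpt_model.py (head_dim = dim // n_head). A quick check of the configs shipped in this commit, as a small illustrative script:

```python
# Head dimension per GPT config (dim // n_head); values taken from serve/gpt_model.py.
configs = {
    "GPT-B":   (768, 12),
    "GPT-L":   (1024, 16),
    "GPT-XL":  (1280, 20),
    "GPT-XXL": (1536, 24),
    "GPT-3B":  (3200, 32),
}
supported = {64, 80, 96, 112, 128, 256}  # head sizes PagedAttention accepts, as quoted above
for name, (dim, n_head) in configs.items():
    head_dim = dim // n_head
    status = "ok" if head_dim in supported else "unsupported"
    print(f"{name}: head_dim={head_dim} ({status})")
# GPT-3B gives 3200 // 32 = 100, hence the missing vLLM timings in the table.
```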
serve/gpt_model.py
ADDED
@@ -0,0 +1,369 @@
from dataclasses import dataclass
from typing import Optional, List

import torch
import torch.nn as nn

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput

from vllm.attention import AttentionMetadata
from vllm.attention import Attention as pagedAttention

from vllm.model_executor.layers.logits_processor import LogitsProcessor
from serve.sampler import Sampler

def find_multiple(n: int, k: int):
    if n % k == 0:
        return n
    return n + k - (n % k)

@dataclass
class ModelArgs:
    dim: int = 4096
    n_layer: int = 32
    n_head: int = 32
    n_kv_head: Optional[int] = None
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    rope_base: float = 10000
    norm_eps: float = 1e-5
    initializer_range: float = 0.02

    num_classes: int = 1000
    class_dropout_prob: float = 0.1
    model_type: str = 'c2i'
    cfg_scale: float = 4.0

    vocab_size: int = 16384
    cls_token_num: int = 1
    block_size: int = 256
    max_batch_size: int = 32
    max_seq_len: int = 2048


#################################################################################
#                      Embedding Layers for Class Labels                        #
#################################################################################
class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    # def token_drop(self, labels, force_drop_ids=None):
    #     """
    #     Drops labels to enable classifier-free guidance.
    #     """
    #     if force_drop_ids is None:
    #         drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
    #     else:
    #         drop_ids = force_drop_ids == 1
    #     labels = torch.where(drop_ids, self.num_classes, labels)
    #     return labels

    # def forward(self, labels, train, force_drop_ids=None):
    def forward(self, labels):
        # use_dropout = self.dropout_prob > 0
        # if (train and use_dropout) or (force_drop_ids is not None):
        #     labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings


#################################################################################
#                                  GPT Model                                    #
#################################################################################
# class RMSNorm(torch.nn.Module):
#     def __init__(self, dim: int, eps: float = 1e-5):
#         super().__init__()
#         self.eps = eps
#         self.weight = nn.Parameter(torch.ones(dim))

#     def _norm(self, x):
#         return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

#     def forward(self, x):
#         output = self._norm(x.float()).type_as(x)
#         return output * self.weight


class FeedForward(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        hidden_dim = 4 * config.dim
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if config.ffn_dim_multiplier is not None:
            hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)
        hidden_dim = find_multiple(hidden_dim, config.multiple_of)

        # self.w1 = nn.Linear(config.dim, hidden_dim, bias=False)
        # self.w3 = nn.Linear(config.dim, hidden_dim, bias=False)
        self.w_merged = nn.Linear(config.dim, hidden_dim * 2, bias=False)
        self.act_fn = SiluAndMul()

        self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)
        # self.ffn_dropout = nn.Dropout(config.ffn_dropout_p)

    # def forward(self, x):
    #     return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))

    def forward(self, x):
        x = self.w_merged(x)
        x = self.act_fn(x)
        x = self.w2(x)
        # return self.ffn_dropout(x)
        return x


class Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0
        self.dim = config.dim
        self.head_dim = config.dim // config.n_head
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head
        total_kv_dim = (self.n_head + 2 * self.n_kv_head) * self.head_dim

        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_kv_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)

        # pagedAttention
        self.attn = pagedAttention(self.n_head,
                                   self.head_dim,
                                   self.head_dim**-0.5,
                                   num_kv_heads=self.n_kv_head,
                                   )

        # 2d rotary pos embedding
        grid_size = int(config.block_size ** 0.5)
        assert grid_size * grid_size == config.block_size
        freqs_cis = precompute_freqs_cis_2d(grid_size, config.dim // config.n_head, config.rope_base, config.cls_token_num)
        self.register_buffer('freqs_cis', freqs_cis)

    def forward(
        self,
        x: torch.Tensor,
        positions: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ):
        kv_size = self.n_kv_head * self.head_dim
        xq, xk, xv = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        xq = xq.view(*xq.shape[:-1], 1, self.n_head, self.head_dim)
        xk = xk.view(*xk.shape[:-1], 1, self.n_kv_head, self.head_dim)
        freqs_cis = self.freqs_cis[positions].unsqueeze(1)
        xq = apply_rotary_emb_bs(xq, freqs_cis)
        xk = apply_rotary_emb_bs(xk, freqs_cis)
        xq = xq.flatten(1)
        xk = xk.flatten(1)

        output = self.attn(xq, xk, xv, kv_cache, attn_metadata)
        output = self.wo(output)

        return output


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)

    def forward(self, x: torch.Tensor, positions: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata):
        h = x + self.attention(self.attention_norm(x), positions, kv_cache, attn_metadata)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class Transformer(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.n_layer = config.n_layer
        self.block_size = config.block_size
        self.num_classes = config.num_classes
        self.model_type = config.model_type
        self.cls_token_num = config.cls_token_num
        self.cfg_scale = config.cfg_scale
        if self.model_type == 'c2i':
            self.cls_embedding = LabelEmbedder(config.num_classes, config.dim, config.class_dropout_prob)
        else:
            raise Exception("vllm only supports c2i now, please check model type")
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)

        self.layers = torch.nn.ModuleList()
        for layer_id in range(config.n_layer):
            self.layers.append(TransformerBlock(config))

        # output layer
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.sampler = Sampler(config.cfg_scale)

    def forward(
        self,
        input_ids: torch.Tensor=None,
        positions: torch.Tensor=None,
        kv_caches: List[torch.Tensor]=None,
        attn_metadata: AttentionMetadata=None,
    ):
        # if positions.max() == 0: # prefill in inference
        #     token_embeddings = self.cls_embedding(input_ids)
        # else: # decode_n_tokens(kv cache) in inference
        #     token_embeddings = self.tok_embeddings(input_ids)
        cond_ids = torch.clamp(input_ids, max=self.num_classes)
        token_embeddings = self.cls_embedding(cond_ids) * (positions.max() == 0) + \
            self.tok_embeddings(input_ids) * (positions.max() != 0)

        hh = token_embeddings
        # transformer blocks
        for layer_id, layer in enumerate(self.layers):
            hh = layer(hh, positions, kv_caches[layer_id], attn_metadata)

        # output layers
        hh = self.norm(hh)
        return hh

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.output.weight, hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def custom_load_state_dict(self, model_weights):
        model_weights = model_weights.copy()
        for layer_id in range(len(self.layers)):
            branch1 = f'layers.{layer_id}.feed_forward.w1.weight'
            branch3 = f'layers.{layer_id}.feed_forward.w3.weight'
            branch_merged = f'layers.{layer_id}.feed_forward.w_merged.weight'
            model_weights[branch_merged] = torch.cat(
                [model_weights[branch1], model_weights[branch3]], dim=0
            )
            model_weights.pop(branch1)
            model_weights.pop(branch3)

        if 'freqs_cis' in model_weights:
            model_weights.pop('freqs_cis')

        self.load_state_dict(model_weights, strict=False)


#################################################################################
#                     Rotary Positional Embedding Functions                     #
#################################################################################
# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, cls_token_num=120):
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)  # (seq_len, head_dim // 2)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)  # (cls_token_num+seq_len, head_dim // 2, 2)
    cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache])  # (cls_token_num+seq_len, head_dim // 2, 2)
    return cond_cache


def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120):
    # split the dimension into half, one for x and one for y
    half_dim = n_elem // 2
    freqs = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim))
    t = torch.arange(grid_size, device=freqs.device)
    freqs = torch.outer(t, freqs)  # (grid_size, head_dim // 2)
    freqs_grid = torch.concat([
        freqs[:, None, :].expand(-1, grid_size, -1),
        freqs[None, :, :].expand(grid_size, -1, -1),
    ], dim=-1)  # (grid_size, grid_size, head_dim // 2)
    cache_grid = torch.stack([torch.cos(freqs_grid), torch.sin(freqs_grid)], dim=-1)  # (grid_size, grid_size, head_dim // 2, 2)
    cache = cache_grid.flatten(0, 1)
    cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache])  # (cls_token_num+grid_size**2, head_dim // 2, 2)
    return cond_cache


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
    # x: (bs, seq_len, n_head, head_dim)
    # freqs_cis (seq_len, head_dim // 2, 2)
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)  # (bs, seq_len, n_head, head_dim//2, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)  # (1, seq_len, 1, head_dim//2, 2)
    x_out2 = torch.stack([
        xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
        xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
    ], dim=-1)
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)


def apply_rotary_emb_bs(x: torch.Tensor, freqs_cis: torch.Tensor):
    # x: (bs, seq_len, n_head, head_dim)
    # freqs_cis (seq_len, head_dim // 2, 2)
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)  # (bs, seq_len, n_head, head_dim//2, 2)
    freqs_cis = freqs_cis.view(xshaped.size(0), xshaped.size(1), 1, xshaped.size(3), 2)  # (bs, seq_len, 1, head_dim//2, 2)
    x_out2 = torch.stack([
        xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
        xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
    ], dim=-1)
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)


#################################################################################
#                                GPT Configs                                    #
#################################################################################
### text-conditional
def GPT_7B(**kwargs):
    return Transformer(ModelArgs(n_layer=32, n_head=32, dim=4096, **kwargs))  # 6.6B

def GPT_3B(**kwargs):
    return Transformer(ModelArgs(n_layer=24, n_head=32, dim=3200, **kwargs))  # 3.1B

def GPT_1B(**kwargs):
    return Transformer(ModelArgs(n_layer=22, n_head=32, dim=2048, **kwargs))  # 1.2B

### class-conditional
def GPT_XXXL(**kwargs):
    return Transformer(ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs))  # 3.9B

def GPT_XXL(**kwargs):
    return Transformer(ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs))  # 1.4B

def GPT_XL(**kwargs):
    return Transformer(ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs))  # 775M

def GPT_L(**kwargs):
    return Transformer(ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs))  # 343M

def GPT_B(**kwargs):
    return Transformer(ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs))  # 111M


GPT_models = {
    'GPT-B': GPT_B, 'GPT-L': GPT_L, 'GPT-XL': GPT_XL, 'GPT-XXL': GPT_XXL, 'GPT-XXXL': GPT_XXXL,
    'GPT-1B': GPT_1B, 'GPT-3B': GPT_3B, 'GPT-7B': GPT_7B,
}
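One design choice worth noting above: the serving model fuses the original gate/up FFN projections (w1, w3) into a single w_merged linear so that vLLM's SiluAndMul activation can split its output in half, and custom_load_state_dict rewrites existing checkpoints accordingly. Below is a small self-contained sketch of that conversion with toy shapes (made-up tensors, not the real checkpoint keys), checking that the fused path matches the original two-matmul formulation:

```python
# Toy illustration of the w1/w3 -> w_merged fusion done by custom_load_state_dict.
import torch
import torch.nn.functional as F

dim, hidden_dim = 8, 16
w1 = torch.randn(hidden_dim, dim)      # original gate projection (feed_forward.w1.weight)
w3 = torch.randn(hidden_dim, dim)      # original up projection (feed_forward.w3.weight)
w_merged = torch.cat([w1, w3], dim=0)  # (2*hidden_dim, dim), the fused w_merged.weight

x = torch.randn(2, dim)
gate_up = x @ w_merged.t()                    # single fused matmul
gate, up = gate_up.split(hidden_dim, dim=-1)  # SiluAndMul splits the last dim in half
out_fused = F.silu(gate) * up
out_reference = F.silu(x @ w1.t()) * (x @ w3.t())
assert torch.allclose(out_fused, out_reference, atol=1e-6)
```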
serve/gpu_executor.py
ADDED
@@ -0,0 +1,201 @@
from typing import Dict, List, Set, Tuple, Optional, Set
import argparse

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         SpeculativeConfig, VisionLanguageConfig)
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                        make_async)

logger = init_logger(__name__)


class GPUExecutor(ExecutorBase):
    def __init__(
        self,
        args: argparse.ArgumentParser,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        speculative_config: Optional[SpeculativeConfig],
    ) -> None:
        self.args = args
        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.vision_language_config = vision_language_config
        self.speculative_config = speculative_config

        self._init_executor()

    def _init_executor(self) -> None:
        """Initialize the worker and load the model.

        If speculative decoding is enabled, we instead create the speculative
        worker.
        """
        if self.speculative_config is None:
            self._init_non_spec_worker()
        else:
            self._init_spec_worker()

    def _init_non_spec_worker(self):
        # Lazy import the Worker to avoid importing torch.cuda/xformers
        # before CUDA_VISIBLE_DEVICES is set in the Worker
        # from vllm.worker.worker import Worker
        from serve.worker import Worker

        assert self.parallel_config.world_size == 1, (
            "GPUExecutor only supports single GPU.")

        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())
        self.driver_worker = Worker(
            model_config=self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=0,
            rank=0,
            distributed_init_method=distributed_init_method,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
            is_driver_worker=True,
        )
        self.driver_worker.init_device()
        self.driver_worker.load_model(self.args)

    def _init_spec_worker(self):
        """Initialize a SpecDecodeWorker, using a draft model for proposals.
        """
        assert self.speculative_config is not None

        from vllm.spec_decode.multi_step_worker import MultiStepWorker
        from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
        from vllm.worker.worker import Worker

        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())

        target_worker = Worker(
            model_config=self.model_config,
            parallel_config=self.parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=0,
            rank=0,
            distributed_init_method=distributed_init_method,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
            is_driver_worker=True,
        )

        draft_worker = MultiStepWorker(
            model_config=self.speculative_config.draft_model_config,
            parallel_config=self.speculative_config.draft_parallel_config,
            scheduler_config=self.scheduler_config,
            device_config=self.device_config,
            cache_config=self.cache_config,
            load_config=self.load_config,
            local_rank=0,
            rank=0,
            distributed_init_method=distributed_init_method,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
            is_driver_worker=True,
        )

        spec_decode_worker = SpecDecodeWorker.from_workers(
            proposer_worker=draft_worker, scorer_worker=target_worker)

        assert self.parallel_config.world_size == 1, (
            "GPUExecutor only supports single GPU.")

        self.driver_worker = spec_decode_worker

        # Load model handled in spec decode worker.
        self.driver_worker.init_device()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker.
        """
        return self.driver_worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
        """Initialize the KV cache by invoking the underlying worker.
        """
        # NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log in the engine level, but work
        # remains to abstract away the device for non-GPU configurations.
        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                    f"# CPU blocks: {num_cpu_blocks}")

        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        num_lookahead_slots: int,
    ) -> List[SamplerOutput]:
        output = self.driver_worker.execute_model(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            num_lookahead_slots=num_lookahead_slots,
        )
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:
        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
        return self.driver_worker.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return self.driver_worker.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.driver_worker.list_loras()

    def check_health(self) -> None:
        # GPUExecutor will always be healthy as long as
        # it's running.
        return


class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):

    async def execute_model_async(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
        output = await make_async(self.driver_worker.execute_model)(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy)
        return output
serve/llm.py
ADDED
@@ -0,0 +1,267 @@
# Modified from:
#   vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py
from typing import List, Optional, Union
import argparse

import torch
from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from vllm.engine.arg_utils import EngineArgs
# from vllm.engine.llm_engine import LLMEngine
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter

from serve.llm_engine import LLMEngine


class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    NOTE: This class is intended to be used for offline inference. For online
    serving, use the `AsyncLLMEngine` class instead.
    NOTE: For the comprehensive list of arguments, see `EngineArgs`.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode.
        disable_custom_all_reduce: See ParallelConfig
    """

    def __init__(
        self,
        args: argparse.ArgumentParser,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        enforce_eager: bool = False,
        max_context_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        **kwargs,
    ) -> None:
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            **kwargs,
        )
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS, args=args)
        self.request_counter = Counter()

    def get_tokenizer(
            self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
        return self.llm_engine.tokenizer.tokenizer

    def set_tokenizer(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ) -> None:
        self.llm_engine.tokenizer.tokenizer = tokenizer

    def generate(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        NOTE: This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: A list of prompts to generate completions for.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: A list of token IDs for the prompts. If None, we
                use the tokenizer to convert the prompts to token IDs.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            multi_modal_data: Multi modal data.

        Returns:
            A list of `RequestOutput` objects containing the generated
|
| 169 |
+
completions in the same order as the input prompts.
|
| 170 |
+
"""
|
| 171 |
+
if prompts is None and prompt_token_ids is None:
|
| 172 |
+
raise ValueError("Either prompts or prompt_token_ids must be "
|
| 173 |
+
"provided.")
|
| 174 |
+
if self.llm_engine.model_config.skip_tokenizer_init \
|
| 175 |
+
and prompts is not None:
|
| 176 |
+
raise ValueError("prompts must be None if skip_tokenizer_init "
|
| 177 |
+
"is True")
|
| 178 |
+
if isinstance(prompts, str):
|
| 179 |
+
# Convert a single prompt to a list.
|
| 180 |
+
prompts = [prompts]
|
| 181 |
+
if (prompts is not None and prompt_token_ids is not None
|
| 182 |
+
and len(prompts) != len(prompt_token_ids)):
|
| 183 |
+
raise ValueError("The lengths of prompts and prompt_token_ids "
|
| 184 |
+
"must be the same.")
|
| 185 |
+
|
| 186 |
+
if prompts is not None:
|
| 187 |
+
num_requests = len(prompts)
|
| 188 |
+
else:
|
| 189 |
+
assert prompt_token_ids is not None
|
| 190 |
+
num_requests = len(prompt_token_ids)
|
| 191 |
+
|
| 192 |
+
if sampling_params is None:
|
| 193 |
+
# Use default sampling params.
|
| 194 |
+
sampling_params = SamplingParams()
|
| 195 |
+
|
| 196 |
+
elif isinstance(sampling_params,
|
| 197 |
+
list) and len(sampling_params) != num_requests:
|
| 198 |
+
raise ValueError("The lengths of prompts and sampling_params "
|
| 199 |
+
"must be the same.")
|
| 200 |
+
if multi_modal_data:
|
| 201 |
+
multi_modal_data.data = multi_modal_data.data.to(torch.float16)
|
| 202 |
+
|
| 203 |
+
# Add requests to the engine.
|
| 204 |
+
for i in range(num_requests):
|
| 205 |
+
prompt = prompts[i] if prompts is not None else None
|
| 206 |
+
token_ids = None if prompt_token_ids is None else prompt_token_ids[i]
|
| 207 |
+
self._add_request(
|
| 208 |
+
prompt,
|
| 209 |
+
sampling_params[i]
|
| 210 |
+
if isinstance(sampling_params, list) else sampling_params,
|
| 211 |
+
token_ids,
|
| 212 |
+
lora_request=lora_request,
|
| 213 |
+
# Get ith image while maintaining the batch dim.
|
| 214 |
+
multi_modal_data=MultiModalData(
|
| 215 |
+
type=multi_modal_data.type,
|
| 216 |
+
data=multi_modal_data.data[i].unsqueeze(0))
|
| 217 |
+
if multi_modal_data else None,
|
| 218 |
+
)
|
| 219 |
+
return self._run_engine(use_tqdm)
|
| 220 |
+
|
| 221 |
+
def _add_request(
|
| 222 |
+
self,
|
| 223 |
+
prompt: Optional[str],
|
| 224 |
+
sampling_params: SamplingParams,
|
| 225 |
+
prompt_token_ids: Optional[List[int]],
|
| 226 |
+
lora_request: Optional[LoRARequest] = None,
|
| 227 |
+
multi_modal_data: Optional[MultiModalData] = None,
|
| 228 |
+
) -> None:
|
| 229 |
+
request_id = str(next(self.request_counter))
|
| 230 |
+
self.llm_engine.add_request(request_id,
|
| 231 |
+
prompt,
|
| 232 |
+
sampling_params,
|
| 233 |
+
prompt_token_ids,
|
| 234 |
+
lora_request=lora_request,
|
| 235 |
+
multi_modal_data=multi_modal_data)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
|
| 239 |
+
# Initialize tqdm.
|
| 240 |
+
if use_tqdm:
|
| 241 |
+
num_requests = self.llm_engine.get_num_unfinished_requests()
|
| 242 |
+
pbar = tqdm(
|
| 243 |
+
total=num_requests,
|
| 244 |
+
desc="Processed prompts",
|
| 245 |
+
dynamic_ncols=True,
|
| 246 |
+
postfix=f"Generation Speed: {0:.2f} toks/s",
|
| 247 |
+
)
|
| 248 |
+
# Run the engine.
|
| 249 |
+
outputs: List[RequestOutput] = []
|
| 250 |
+
while self.llm_engine.has_unfinished_requests():
|
| 251 |
+
step_outputs = self.llm_engine.step()
|
| 252 |
+
for output in step_outputs:
|
| 253 |
+
if output.finished:
|
| 254 |
+
outputs.append(output)
|
| 255 |
+
if use_tqdm:
|
| 256 |
+
total_toks += (sum(
|
| 257 |
+
len(stp.token_ids) for stp in output.outputs))
|
| 258 |
+
spd = total_toks / pbar.format_dict["elapsed"]
|
| 259 |
+
pbar.postfix = f"Generation Speed: {spd:.2f} toks/s"
|
| 260 |
+
pbar.update(1)
|
| 261 |
+
if use_tqdm:
|
| 262 |
+
pbar.close()
|
| 263 |
+
# Sort the outputs by request ID.
|
| 264 |
+
# This is necessary because some requests may be finished earlier than
|
| 265 |
+
# its previous requests.
|
| 266 |
+
outputs = sorted(outputs, key=lambda x: int(x.request_id))
|
| 267 |
+
return outputs
|
serve/llm_engine.py
ADDED
|
@@ -0,0 +1,671 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from:
|
| 2 |
+
# vLLM: https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py
|
| 3 |
+
import time
|
| 4 |
+
from typing import Iterable, List, Optional, Type, Union
|
| 5 |
+
import argparse
|
| 6 |
+
|
| 7 |
+
from transformers import GenerationConfig, PreTrainedTokenizer
|
| 8 |
+
|
| 9 |
+
import vllm
|
| 10 |
+
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
|
| 11 |
+
LoRAConfig, ModelConfig, ParallelConfig,
|
| 12 |
+
SchedulerConfig, SpeculativeConfig,
|
| 13 |
+
VisionLanguageConfig)
|
| 14 |
+
from vllm.core.scheduler import Scheduler, SchedulerOutputs
|
| 15 |
+
from vllm.engine.arg_utils import EngineArgs
|
| 16 |
+
from vllm.engine.metrics import StatLogger, Stats
|
| 17 |
+
from vllm.engine.output_processor.interfaces import (
|
| 18 |
+
SequenceGroupOutputProcessor)
|
| 19 |
+
from vllm.engine.output_processor.stop_checker import StopChecker
|
| 20 |
+
from vllm.engine.output_processor.util import create_output_by_sequence_group
|
| 21 |
+
from vllm.engine.ray_utils import initialize_ray_cluster
|
| 22 |
+
from vllm.executor.executor_base import ExecutorBase
|
| 23 |
+
from vllm.logger import init_logger
|
| 24 |
+
from vllm.lora.request import LoRARequest
|
| 25 |
+
from vllm.outputs import RequestOutput
|
| 26 |
+
from vllm.sampling_params import SamplingParams
|
| 27 |
+
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
|
| 28 |
+
SequenceGroup)
|
| 29 |
+
from vllm.transformers_utils.detokenizer import Detokenizer
|
| 30 |
+
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
|
| 31 |
+
get_tokenizer_group)
|
| 32 |
+
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
| 33 |
+
usage_message)
|
| 34 |
+
from vllm.utils import Counter
|
| 35 |
+
|
| 36 |
+
logger = init_logger(__name__)
|
| 37 |
+
_LOCAL_LOGGING_INTERVAL_SEC = 5
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _load_generation_config_dict(model_config: ModelConfig):
|
| 41 |
+
try:
|
| 42 |
+
return GenerationConfig.from_pretrained(
|
| 43 |
+
model_config.model,
|
| 44 |
+
revision=model_config.revision,
|
| 45 |
+
).to_diff_dict()
|
| 46 |
+
except OSError:
|
| 47 |
+
# Not found.
|
| 48 |
+
return {}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class LLMEngine:
|
| 52 |
+
"""An LLM engine that receives requests and generates texts.
|
| 53 |
+
|
| 54 |
+
This is the main class for the vLLM engine. It receives requests
|
| 55 |
+
from clients and generates texts from the LLM. It includes a tokenizer, a
|
| 56 |
+
language model (possibly distributed across multiple GPUs), and GPU memory
|
| 57 |
+
space allocated for intermediate states (aka KV cache). This class utilizes
|
| 58 |
+
iteration-level scheduling and efficient memory management to maximize the
|
| 59 |
+
serving throughput.
|
| 60 |
+
|
| 61 |
+
The `LLM` class wraps this class for offline batched inference and the
|
| 62 |
+
`AsyncLLMEngine` class wraps this class for online serving.
|
| 63 |
+
|
| 64 |
+
NOTE: The config arguments are derived from the `EngineArgs` class. For the
|
| 65 |
+
comprehensive list of arguments, see `EngineArgs`.
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
model_config: The configuration related to the LLM model.
|
| 69 |
+
cache_config: The configuration related to the KV cache memory
|
| 70 |
+
management.
|
| 71 |
+
parallel_config: The configuration related to distributed execution.
|
| 72 |
+
scheduler_config: The configuration related to the request scheduler.
|
| 73 |
+
device_config: The configuration related to the device.
|
| 74 |
+
lora_config (Optional): The configuration related to serving multi-LoRA.
|
| 75 |
+
vision_language_config (Optional): The configuration related to vision
|
| 76 |
+
language models.
|
| 77 |
+
speculative_config (Optional): The configuration related to speculative
|
| 78 |
+
decoding.
|
| 79 |
+
executor_class: The model executor class for managing distributed
|
| 80 |
+
execution.
|
| 81 |
+
log_stats: Whether to log statistics.
|
| 82 |
+
usage_context: Specified entry point, used for usage info collection
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
def __init__(
|
| 86 |
+
self,
|
| 87 |
+
args: argparse.ArgumentParser,
|
| 88 |
+
model_config: ModelConfig,
|
| 89 |
+
cache_config: CacheConfig,
|
| 90 |
+
parallel_config: ParallelConfig,
|
| 91 |
+
scheduler_config: SchedulerConfig,
|
| 92 |
+
device_config: DeviceConfig,
|
| 93 |
+
load_config: LoadConfig,
|
| 94 |
+
lora_config: Optional[LoRAConfig],
|
| 95 |
+
vision_language_config: Optional[VisionLanguageConfig],
|
| 96 |
+
speculative_config: Optional[SpeculativeConfig],
|
| 97 |
+
decoding_config: Optional[DecodingConfig],
|
| 98 |
+
executor_class: Type[ExecutorBase],
|
| 99 |
+
log_stats: bool,
|
| 100 |
+
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
| 101 |
+
) -> None:
|
| 102 |
+
logger.info(
|
| 103 |
+
f"Initializing an LLM engine (v{vllm.__version__}) with config: "
|
| 104 |
+
f"model={model_config.model!r}, "
|
| 105 |
+
f"speculative_config={speculative_config!r}, "
|
| 106 |
+
f"tokenizer={model_config.tokenizer!r}, "
|
| 107 |
+
f"skip_tokenizer_init={model_config.skip_tokenizer_init}, "
|
| 108 |
+
f"tokenizer_mode={model_config.tokenizer_mode}, "
|
| 109 |
+
f"revision={model_config.revision}, "
|
| 110 |
+
f"tokenizer_revision={model_config.tokenizer_revision}, "
|
| 111 |
+
f"trust_remote_code={model_config.trust_remote_code}, "
|
| 112 |
+
f"dtype={model_config.dtype}, "
|
| 113 |
+
f"max_seq_len={model_config.max_model_len}, "
|
| 114 |
+
f"download_dir={load_config.download_dir!r}, "
|
| 115 |
+
f"load_format={load_config.load_format}, "
|
| 116 |
+
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
| 117 |
+
f"disable_custom_all_reduce="
|
| 118 |
+
f"{parallel_config.disable_custom_all_reduce}, "
|
| 119 |
+
f"quantization={model_config.quantization}, "
|
| 120 |
+
f"enforce_eager={model_config.enforce_eager}, "
|
| 121 |
+
f"kv_cache_dtype={cache_config.cache_dtype}, "
|
| 122 |
+
f"quantization_param_path={model_config.quantization_param_path}, "
|
| 123 |
+
f"device_config={device_config.device}, "
|
| 124 |
+
f"decoding_config={decoding_config!r}, "
|
| 125 |
+
f"seed={model_config.seed})")
|
| 126 |
+
# TODO(woosuk): Print more configs in debug mode.
|
| 127 |
+
|
| 128 |
+
self.model_config = model_config
|
| 129 |
+
self.cache_config = cache_config
|
| 130 |
+
self.lora_config = lora_config
|
| 131 |
+
self.vision_language_config = vision_language_config
|
| 132 |
+
self.parallel_config = parallel_config
|
| 133 |
+
self.scheduler_config = scheduler_config
|
| 134 |
+
self.device_config = device_config
|
| 135 |
+
self.speculative_config = speculative_config
|
| 136 |
+
self.load_config = load_config
|
| 137 |
+
self.decoding_config = decoding_config or DecodingConfig()
|
| 138 |
+
self.log_stats = log_stats
|
| 139 |
+
|
| 140 |
+
if not self.model_config.skip_tokenizer_init:
|
| 141 |
+
self.tokenizer: BaseTokenizerGroup
|
| 142 |
+
self._init_tokenizer()
|
| 143 |
+
self.detokenizer = Detokenizer(self.tokenizer)
|
| 144 |
+
else:
|
| 145 |
+
self.detokenizer = None
|
| 146 |
+
self.tokenizer = None
|
| 147 |
+
|
| 148 |
+
self.seq_counter = Counter()
|
| 149 |
+
self.generation_config_fields = _load_generation_config_dict(
|
| 150 |
+
model_config)
|
| 151 |
+
|
| 152 |
+
self.model_executor = executor_class(
|
| 153 |
+
args=args,
|
| 154 |
+
model_config=model_config,
|
| 155 |
+
cache_config=cache_config,
|
| 156 |
+
parallel_config=parallel_config,
|
| 157 |
+
scheduler_config=scheduler_config,
|
| 158 |
+
device_config=device_config,
|
| 159 |
+
lora_config=lora_config,
|
| 160 |
+
vision_language_config=vision_language_config,
|
| 161 |
+
speculative_config=speculative_config,
|
| 162 |
+
load_config=load_config,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
self._initialize_kv_caches()
|
| 166 |
+
|
| 167 |
+
# If usage stat is enabled, collect relevant info.
|
| 168 |
+
if is_usage_stats_enabled():
|
| 169 |
+
from vllm.model_executor.model_loader import (
|
| 170 |
+
get_architecture_class_name)
|
| 171 |
+
usage_message.report_usage(
|
| 172 |
+
get_architecture_class_name(model_config),
|
| 173 |
+
usage_context,
|
| 174 |
+
extra_kvs={
|
| 175 |
+
# Common configuration
|
| 176 |
+
"dtype":
|
| 177 |
+
str(model_config.dtype),
|
| 178 |
+
"tensor_parallel_size":
|
| 179 |
+
parallel_config.tensor_parallel_size,
|
| 180 |
+
"block_size":
|
| 181 |
+
cache_config.block_size,
|
| 182 |
+
"gpu_memory_utilization":
|
| 183 |
+
cache_config.gpu_memory_utilization,
|
| 184 |
+
|
| 185 |
+
# Quantization
|
| 186 |
+
"quantization":
|
| 187 |
+
model_config.quantization,
|
| 188 |
+
"kv_cache_dtype":
|
| 189 |
+
cache_config.cache_dtype,
|
| 190 |
+
|
| 191 |
+
# Feature flags
|
| 192 |
+
"enable_lora":
|
| 193 |
+
bool(lora_config),
|
| 194 |
+
"enable_prefix_caching":
|
| 195 |
+
cache_config.enable_prefix_caching,
|
| 196 |
+
"enforce_eager":
|
| 197 |
+
model_config.enforce_eager,
|
| 198 |
+
"disable_custom_all_reduce":
|
| 199 |
+
parallel_config.disable_custom_all_reduce,
|
| 200 |
+
})
|
| 201 |
+
|
| 202 |
+
if self.tokenizer:
|
| 203 |
+
# Ping the tokenizer to ensure liveness if it runs in a
|
| 204 |
+
# different process.
|
| 205 |
+
self.tokenizer.ping()
|
| 206 |
+
|
| 207 |
+
# Create the scheduler.
|
| 208 |
+
# NOTE: the cache_config here have been updated with the numbers of
|
| 209 |
+
# GPU and CPU blocks, which are profiled in the distributed executor.
|
| 210 |
+
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
|
| 211 |
+
|
| 212 |
+
# Metric Logging.
|
| 213 |
+
if self.log_stats:
|
| 214 |
+
self.stat_logger = StatLogger(
|
| 215 |
+
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
|
| 216 |
+
labels=dict(model_name=model_config.model))
|
| 217 |
+
self.stat_logger.info("cache_config", self.cache_config)
|
| 218 |
+
|
| 219 |
+
# Create sequence output processor, e.g. for beam search or
|
| 220 |
+
# speculative decoding.
|
| 221 |
+
self.output_processor = (
|
| 222 |
+
SequenceGroupOutputProcessor.create_output_processor(
|
| 223 |
+
self.scheduler_config,
|
| 224 |
+
self.detokenizer,
|
| 225 |
+
self.scheduler,
|
| 226 |
+
self.seq_counter,
|
| 227 |
+
self.get_tokenizer_for_seq,
|
| 228 |
+
stop_checker=StopChecker(
|
| 229 |
+
self.scheduler_config.max_model_len,
|
| 230 |
+
self.get_tokenizer_for_seq,
|
| 231 |
+
),
|
| 232 |
+
))
|
| 233 |
+
|
| 234 |
+
def _initialize_kv_caches(self) -> None:
|
| 235 |
+
"""Initialize the KV cache in the worker(s).
|
| 236 |
+
|
| 237 |
+
The workers will determine the number of blocks in both the GPU cache
|
| 238 |
+
and the swap CPU cache.
|
| 239 |
+
"""
|
| 240 |
+
num_gpu_blocks, num_cpu_blocks = (
|
| 241 |
+
self.model_executor.determine_num_available_blocks())
|
| 242 |
+
|
| 243 |
+
if self.cache_config.num_gpu_blocks_override is not None:
|
| 244 |
+
num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
|
| 245 |
+
logger.info(f"Overriding {num_gpu_blocks=} with "
|
| 246 |
+
f"{num_gpu_blocks_override=}")
|
| 247 |
+
num_gpu_blocks = num_gpu_blocks_override
|
| 248 |
+
|
| 249 |
+
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
| 250 |
+
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
| 251 |
+
|
| 252 |
+
self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
| 253 |
+
|
| 254 |
+
@classmethod
|
| 255 |
+
def from_engine_args(
|
| 256 |
+
cls,
|
| 257 |
+
engine_args: EngineArgs,
|
| 258 |
+
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
| 259 |
+
args: argparse.ArgumentParser = None,
|
| 260 |
+
) -> "LLMEngine":
|
| 261 |
+
"""Creates an LLM engine from the engine arguments."""
|
| 262 |
+
# Create the engine configs.
|
| 263 |
+
engine_config = engine_args.create_engine_config()
|
| 264 |
+
|
| 265 |
+
# Initialize the cluster and specify the executor class.
|
| 266 |
+
if engine_config.device_config.device_type == "neuron":
|
| 267 |
+
from vllm.executor.neuron_executor import NeuronExecutor
|
| 268 |
+
executor_class = NeuronExecutor
|
| 269 |
+
elif engine_config.device_config.device_type == "cpu":
|
| 270 |
+
from vllm.executor.cpu_executor import CPUExecutor
|
| 271 |
+
executor_class = CPUExecutor
|
| 272 |
+
elif engine_config.parallel_config.worker_use_ray:
|
| 273 |
+
initialize_ray_cluster(engine_config.parallel_config)
|
| 274 |
+
from vllm.executor.ray_gpu_executor import RayGPUExecutor
|
| 275 |
+
executor_class = RayGPUExecutor
|
| 276 |
+
else:
|
| 277 |
+
assert engine_config.parallel_config.world_size == 1, (
|
| 278 |
+
"Ray is required if parallel_config.world_size > 1.")
|
| 279 |
+
# from vllm.executor.gpu_executor import GPUExecutor
|
| 280 |
+
from serve.gpu_executor import GPUExecutor
|
| 281 |
+
executor_class = GPUExecutor
|
| 282 |
+
|
| 283 |
+
# Create the LLM engine.
|
| 284 |
+
engine = cls(
|
| 285 |
+
**engine_config.to_dict(),
|
| 286 |
+
executor_class=executor_class,
|
| 287 |
+
log_stats=not engine_args.disable_log_stats,
|
| 288 |
+
usage_context=usage_context,
|
| 289 |
+
args=args,
|
| 290 |
+
)
|
| 291 |
+
return engine
|
| 292 |
+
|
| 293 |
+
def __reduce__(self):
|
| 294 |
+
# This is to ensure that the LLMEngine is not referenced in
|
| 295 |
+
# the closure used to initialize Ray worker actors
|
| 296 |
+
raise RuntimeError("LLMEngine should not be pickled!")
|
| 297 |
+
|
| 298 |
+
def get_tokenizer(self) -> "PreTrainedTokenizer":
|
| 299 |
+
return self.tokenizer.get_lora_tokenizer(None)
|
| 300 |
+
|
| 301 |
+
def get_tokenizer_for_seq(self,
|
| 302 |
+
sequence: Sequence) -> "PreTrainedTokenizer":
|
| 303 |
+
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
|
| 304 |
+
|
| 305 |
+
def _init_tokenizer(self, **tokenizer_init_kwargs):
|
| 306 |
+
init_kwargs = dict(
|
| 307 |
+
tokenizer_id=self.model_config.tokenizer,
|
| 308 |
+
enable_lora=bool(self.lora_config),
|
| 309 |
+
max_num_seqs=self.scheduler_config.max_num_seqs,
|
| 310 |
+
max_input_length=None,
|
| 311 |
+
tokenizer_mode=self.model_config.tokenizer_mode,
|
| 312 |
+
trust_remote_code=self.model_config.trust_remote_code,
|
| 313 |
+
revision=self.model_config.tokenizer_revision)
|
| 314 |
+
init_kwargs.update(tokenizer_init_kwargs)
|
| 315 |
+
self.tokenizer = get_tokenizer_group(
|
| 316 |
+
self.parallel_config.tokenizer_pool_config, **init_kwargs)
|
| 317 |
+
|
| 318 |
+
def _verify_args(self) -> None:
|
| 319 |
+
self.model_config.verify_with_parallel_config(self.parallel_config)
|
| 320 |
+
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
| 321 |
+
if self.lora_config:
|
| 322 |
+
self.lora_config.verify_with_model_config(self.model_config)
|
| 323 |
+
self.lora_config.verify_with_scheduler_config(
|
| 324 |
+
self.scheduler_config)
|
| 325 |
+
|
| 326 |
+
def encode_request(
|
| 327 |
+
self,
|
| 328 |
+
request_id: str, # pylint: disable=unused-argument
|
| 329 |
+
prompt: Optional[str],
|
| 330 |
+
prompt_token_ids: Optional[List[int]] = None,
|
| 331 |
+
lora_request: Optional[LoRARequest] = None,
|
| 332 |
+
):
|
| 333 |
+
if prompt_token_ids is None:
|
| 334 |
+
assert prompt is not None
|
| 335 |
+
prompt_token_ids = self.tokenizer.encode(request_id=request_id,
|
| 336 |
+
prompt=prompt,
|
| 337 |
+
lora_request=lora_request)
|
| 338 |
+
return prompt_token_ids
|
| 339 |
+
|
| 340 |
+
def add_request(
|
| 341 |
+
self,
|
| 342 |
+
request_id: str,
|
| 343 |
+
prompt: Optional[str],
|
| 344 |
+
sampling_params: SamplingParams,
|
| 345 |
+
prompt_token_ids: Optional[List[int]] = None,
|
| 346 |
+
arrival_time: Optional[float] = None,
|
| 347 |
+
lora_request: Optional[LoRARequest] = None,
|
| 348 |
+
multi_modal_data: Optional[MultiModalData] = None,
|
| 349 |
+
) -> None:
|
| 350 |
+
"""Add a request to the engine's request pool.
|
| 351 |
+
|
| 352 |
+
The request is added to the request pool and will be processed by the
|
| 353 |
+
scheduler as `engine.step()` is called. The exact scheduling policy is
|
| 354 |
+
determined by the scheduler.
|
| 355 |
+
|
| 356 |
+
Args:
|
| 357 |
+
request_id: The unique ID of the request.
|
| 358 |
+
prompt: The prompt string. Can be None if prompt_token_ids is
|
| 359 |
+
provided.
|
| 360 |
+
sampling_params: The sampling parameters for text generation.
|
| 361 |
+
prompt_token_ids: The token IDs of the prompt. If None, we
|
| 362 |
+
use the tokenizer to convert the prompts to token IDs.
|
| 363 |
+
arrival_time: The arrival time of the request. If None, we use
|
| 364 |
+
the current monotonic time.
|
| 365 |
+
multi_modal_data: Multi modal data per request.
|
| 366 |
+
|
| 367 |
+
Details:
|
| 368 |
+
- Set arrival_time to the current time if it is None.
|
| 369 |
+
- Set prompt_token_ids to the encoded prompt if it is None.
|
| 370 |
+
- Create `best_of` number of :class:`~vllm.Sequence` objects.
|
| 371 |
+
- Create a :class:`~vllm.SequenceGroup` object
|
| 372 |
+
from the list of :class:`~vllm.Sequence`.
|
| 373 |
+
- Add the :class:`~vllm.SequenceGroup` object to the scheduler.
|
| 374 |
+
|
| 375 |
+
Example:
|
| 376 |
+
>>> # initialize engine
|
| 377 |
+
>>> engine = LLMEngine.from_engine_args(engine_args)
|
| 378 |
+
>>> # set request arguments
|
| 379 |
+
>>> example_prompt = "Who is the president of the United States?"
|
| 380 |
+
>>> sampling_params = SamplingParams(temperature=0.0)
|
| 381 |
+
>>> request_id = 0
|
| 382 |
+
>>>
|
| 383 |
+
>>> # add the request to the engine
|
| 384 |
+
>>> engine.add_request(
|
| 385 |
+
>>> str(request_id),
|
| 386 |
+
>>> example_prompt,
|
| 387 |
+
>>> SamplingParams(temperature=0.0))
|
| 388 |
+
>>> # continue the request processing
|
| 389 |
+
>>> ...
|
| 390 |
+
"""
|
| 391 |
+
if lora_request is not None and not self.lora_config:
|
| 392 |
+
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
| 393 |
+
"not enabled!")
|
| 394 |
+
max_logprobs = self.get_model_config().max_logprobs
|
| 395 |
+
if (sampling_params.logprobs
|
| 396 |
+
and sampling_params.logprobs > max_logprobs) or (
|
| 397 |
+
sampling_params.prompt_logprobs
|
| 398 |
+
and sampling_params.prompt_logprobs > max_logprobs):
|
| 399 |
+
raise ValueError(f"Cannot request more than "
|
| 400 |
+
f"{max_logprobs} logprobs.")
|
| 401 |
+
if arrival_time is None:
|
| 402 |
+
arrival_time = time.time()
|
| 403 |
+
prompt_token_ids = self.encode_request(
|
| 404 |
+
request_id=request_id,
|
| 405 |
+
prompt=prompt,
|
| 406 |
+
prompt_token_ids=prompt_token_ids,
|
| 407 |
+
lora_request=lora_request)
|
| 408 |
+
|
| 409 |
+
# Create the sequences.
|
| 410 |
+
block_size = self.cache_config.block_size
|
| 411 |
+
seq_id = next(self.seq_counter)
|
| 412 |
+
eos_token_id = None
|
| 413 |
+
if self.tokenizer:
|
| 414 |
+
eos_token_id = self.tokenizer.get_lora_tokenizer(
|
| 415 |
+
lora_request).eos_token_id
|
| 416 |
+
else:
|
| 417 |
+
logger.warning("Use None for EOS token id because tokenizer is "
|
| 418 |
+
"not initialized")
|
| 419 |
+
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
|
| 420 |
+
eos_token_id, lora_request)
|
| 421 |
+
|
| 422 |
+
# Defensive copy of SamplingParams, which are used by the sampler,
|
| 423 |
+
# this doesn't deep-copy LogitsProcessor objects
|
| 424 |
+
sampling_params = sampling_params.clone()
|
| 425 |
+
# Add the eos token id into the sampling_params to support min_tokens
|
| 426 |
+
# processing
|
| 427 |
+
if seq.eos_token_id is not None:
|
| 428 |
+
sampling_params.all_stop_token_ids.add(seq.eos_token_id)
|
| 429 |
+
sampling_params.update_from_generation_config(
|
| 430 |
+
self.generation_config_fields)
|
| 431 |
+
|
| 432 |
+
# Create the sequence group.
|
| 433 |
+
seq_group = SequenceGroup(request_id, [seq], sampling_params,
|
| 434 |
+
arrival_time, lora_request, multi_modal_data)
|
| 435 |
+
|
| 436 |
+
# Add the sequence group to the scheduler.
|
| 437 |
+
self.scheduler.add_seq_group(seq_group)
|
| 438 |
+
|
| 439 |
+
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
|
| 440 |
+
"""Aborts a request(s) with the given ID.
|
| 441 |
+
|
| 442 |
+
Args:
|
| 443 |
+
request_id: The ID(s) of the request to abort.
|
| 444 |
+
|
| 445 |
+
Details:
|
| 446 |
+
- Refer to the
|
| 447 |
+
:meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
|
| 448 |
+
from class :class:`~vllm.core.scheduler.Scheduler`.
|
| 449 |
+
|
| 450 |
+
Example:
|
| 451 |
+
>>> # initialize engine and add a request with request_id
|
| 452 |
+
>>> request_id = str(0)
|
| 453 |
+
>>> # abort the request
|
| 454 |
+
>>> engine.abort_request(request_id)
|
| 455 |
+
"""
|
| 456 |
+
self.scheduler.abort_seq_group(request_id)
|
| 457 |
+
|
| 458 |
+
def get_model_config(self) -> ModelConfig:
|
| 459 |
+
"""Gets the model configuration."""
|
| 460 |
+
return self.model_config
|
| 461 |
+
|
| 462 |
+
def get_num_unfinished_requests(self) -> int:
|
| 463 |
+
"""Gets the number of unfinished requests."""
|
| 464 |
+
return self.scheduler.get_num_unfinished_seq_groups()
|
| 465 |
+
|
| 466 |
+
def has_unfinished_requests(self) -> bool:
|
| 467 |
+
"""Returns True if there are unfinished requests."""
|
| 468 |
+
return self.scheduler.has_unfinished_seqs()
|
| 469 |
+
|
| 470 |
+
def _process_model_outputs(
|
| 471 |
+
self, output: List[SamplerOutput],
|
| 472 |
+
scheduled_seq_groups: List[SequenceGroup],
|
| 473 |
+
ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
|
| 474 |
+
"""Apply the model output to the sequences in the scheduled seq groups.
|
| 475 |
+
|
| 476 |
+
Returns RequestOutputs that can be returned to the client.
|
| 477 |
+
"""
|
| 478 |
+
now = time.time()
|
| 479 |
+
|
| 480 |
+
# Organize outputs by [sequence group][step] instead of
|
| 481 |
+
# [step][sequence group].
|
| 482 |
+
output_by_sequence_group = create_output_by_sequence_group(
|
| 483 |
+
sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))
|
| 484 |
+
|
| 485 |
+
# Update the scheduled sequence groups with the model outputs.
|
| 486 |
+
for scheduled_seq_group, outputs in zip(scheduled_seq_groups,
|
| 487 |
+
output_by_sequence_group):
|
| 488 |
+
seq_group = scheduled_seq_group.seq_group
|
| 489 |
+
seq_group.update_num_computed_tokens(
|
| 490 |
+
scheduled_seq_group.token_chunk_size)
|
| 491 |
+
# If uncomputed tokens > 0, it means prefill is chunked.
|
| 492 |
+
# We don't need to process outputs in that case.
|
| 493 |
+
if seq_group.get_num_uncomputed_tokens() == 0:
|
| 494 |
+
self.output_processor.process_outputs(seq_group, outputs)
|
| 495 |
+
|
| 496 |
+
# Free the finished sequence groups.
|
| 497 |
+
self.scheduler.free_finished_seq_groups()
|
| 498 |
+
|
| 499 |
+
# Create the outputs.
|
| 500 |
+
request_outputs: List[RequestOutput] = []
|
| 501 |
+
for scheduled_seq_group in scheduled_seq_groups:
|
| 502 |
+
seq_group = scheduled_seq_group.seq_group
|
| 503 |
+
seq_group.maybe_set_first_token_time(now)
|
| 504 |
+
request_output = RequestOutput.from_seq_group(seq_group)
|
| 505 |
+
request_outputs.append(request_output)
|
| 506 |
+
for seq_group in ignored_seq_groups:
|
| 507 |
+
request_output = RequestOutput.from_seq_group(seq_group)
|
| 508 |
+
request_outputs.append(request_output)
|
| 509 |
+
return request_outputs
|
| 510 |
+
|
| 511 |
+
def step(self) -> List[RequestOutput]:
|
| 512 |
+
"""Performs one decoding iteration and returns newly generated results.
|
| 513 |
+
|
| 514 |
+
.. figure:: https://i.imgur.com/sv2HssD.png
|
| 515 |
+
:alt: Overview of the step function
|
| 516 |
+
:align: center
|
| 517 |
+
|
| 518 |
+
Overview of the step function.
|
| 519 |
+
|
| 520 |
+
Details:
|
| 521 |
+
- Step 1: Schedules the sequences to be executed in the next
|
| 522 |
+
iteration and the token blocks to be swapped in/out/copy.
|
| 523 |
+
|
| 524 |
+
- Depending on the scheduling policy,
|
| 525 |
+
sequences may be `preempted/reordered`.
|
| 526 |
+
- A Sequence Group (SG) refer to a group of sequences
|
| 527 |
+
that are generated from the same prompt.
|
| 528 |
+
|
| 529 |
+
- Step 2: Calls the distributed executor to execute the model.
|
| 530 |
+
- Step 3: Processes the model output. This mainly includes:
|
| 531 |
+
|
| 532 |
+
- Decodes the relevant outputs.
|
| 533 |
+
- Updates the scheduled sequence groups with model outputs
|
| 534 |
+
based on its `sampling parameters` (`use_beam_search` or not).
|
| 535 |
+
- Frees the finished sequence groups.
|
| 536 |
+
|
| 537 |
+
- Finally, it creates and returns the newly generated results.
|
| 538 |
+
|
| 539 |
+
Example:
|
| 540 |
+
>>> # Please see the example/ folder for more detailed examples.
|
| 541 |
+
>>>
|
| 542 |
+
>>> # initialize engine and request arguments
|
| 543 |
+
>>> engine = LLMEngine.from_engine_args(engine_args)
|
| 544 |
+
>>> example_inputs = [(0, "What is LLM?",
|
| 545 |
+
>>> SamplingParams(temperature=0.0))]
|
| 546 |
+
>>>
|
| 547 |
+
>>> # Start the engine with an event loop
|
| 548 |
+
>>> while True:
|
| 549 |
+
>>> if example_inputs:
|
| 550 |
+
>>> req_id, prompt, sampling_params = example_inputs.pop(0)
|
| 551 |
+
>>> engine.add_request(str(req_id), prompt, sampling_params)
|
| 552 |
+
>>>
|
| 553 |
+
>>> # continue the request processing
|
| 554 |
+
>>> request_outputs = engine.step()
|
| 555 |
+
>>> for request_output in request_outputs:
|
| 556 |
+
>>> if request_output.finished:
|
| 557 |
+
>>> # return or show the request output
|
| 558 |
+
>>>
|
| 559 |
+
>>> if not (engine.has_unfinished_requests() or example_inputs):
|
| 560 |
+
>>> break
|
| 561 |
+
"""
|
| 562 |
+
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
| 563 |
+
if not scheduler_outputs.is_empty():
|
| 564 |
+
output = self.model_executor.execute_model(
|
| 565 |
+
seq_group_metadata_list=seq_group_metadata_list,
|
| 566 |
+
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
|
| 567 |
+
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
| 568 |
+
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
| 569 |
+
num_lookahead_slots=scheduler_outputs.num_lookahead_slots)
|
| 570 |
+
else:
|
| 571 |
+
output = []
|
| 572 |
+
|
| 573 |
+
request_outputs = self._process_model_outputs(
|
| 574 |
+
output, scheduler_outputs.scheduled_seq_groups,
|
| 575 |
+
scheduler_outputs.ignored_seq_groups)
|
| 576 |
+
|
| 577 |
+
# Log stats.
|
| 578 |
+
if self.log_stats:
|
| 579 |
+
self.stat_logger.log(self._get_stats(scheduler_outputs))
|
| 580 |
+
|
| 581 |
+
return request_outputs
|
| 582 |
+
|
| 583 |
+
def do_log_stats(self) -> None:
|
| 584 |
+
"""Forced log when no requests active."""
|
| 585 |
+
if self.log_stats:
|
| 586 |
+
self.stat_logger.log(self._get_stats(scheduler_outputs=None))
|
| 587 |
+
|
| 588 |
+
def _get_stats(self,
|
| 589 |
+
scheduler_outputs: Optional[SchedulerOutputs]) -> Stats:
|
| 590 |
+
"""Get Stats to be Logged to Prometheus."""
|
| 591 |
+
now = time.time()
|
| 592 |
+
|
| 593 |
+
# KV Cache Usage in %.
|
| 594 |
+
num_total_gpu = self.cache_config.num_gpu_blocks
|
| 595 |
+
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks()
|
| 596 |
+
gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu)
|
| 597 |
+
|
| 598 |
+
num_total_cpu = self.cache_config.num_cpu_blocks
|
| 599 |
+
cpu_cache_usage = 0.
|
| 600 |
+
if num_total_cpu > 0:
|
| 601 |
+
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
|
| 602 |
+
)
|
| 603 |
+
cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu)
|
| 604 |
+
|
| 605 |
+
# Scheduler State
|
| 606 |
+
num_running = len(self.scheduler.running)
|
| 607 |
+
num_swapped = len(self.scheduler.swapped)
|
| 608 |
+
num_waiting = len(self.scheduler.waiting)
|
| 609 |
+
|
| 610 |
+
# Iteration stats if we have scheduler output.
|
| 611 |
+
num_prompt_tokens = 0
|
| 612 |
+
num_generation_tokens = 0
|
| 613 |
+
time_to_first_tokens = []
|
| 614 |
+
time_per_output_tokens = []
|
| 615 |
+
time_e2e_requests = []
|
| 616 |
+
if scheduler_outputs is not None:
|
| 617 |
+
prompt_run = scheduler_outputs.num_prefill_groups > 0
|
| 618 |
+
|
| 619 |
+
# Number of Tokens.
|
| 620 |
+
if prompt_run:
|
| 621 |
+
num_prompt_tokens = sum(
|
| 622 |
+
len(scheduled_seq_group.seq_group.prompt_token_ids)
|
| 623 |
+
for scheduled_seq_group in
|
| 624 |
+
scheduler_outputs.scheduled_seq_groups)
|
| 625 |
+
num_generation_tokens = sum(
|
| 626 |
+
scheduled_seq_group.seq_group.num_seqs()
|
| 627 |
+
for scheduled_seq_group in
|
| 628 |
+
scheduler_outputs.scheduled_seq_groups)
|
| 629 |
+
else:
|
| 630 |
+
num_generation_tokens = scheduler_outputs.num_batched_tokens
|
| 631 |
+
|
| 632 |
+
# Latency Timings.
|
| 633 |
+
time_last_iters = []
|
| 634 |
+
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
|
| 635 |
+
seq_group = scheduled_seq_group.seq_group
|
| 636 |
+
# Time since last token.
|
| 637 |
+
# (n.b. updates seq_group.metrics.last_token_time)
|
| 638 |
+
time_last_iters.append(seq_group.get_last_latency(now))
|
| 639 |
+
# Time since arrival for all finished requests.
|
| 640 |
+
if seq_group.is_finished():
|
| 641 |
+
time_e2e_requests.append(now -
|
| 642 |
+
seq_group.metrics.arrival_time)
|
| 643 |
+
|
| 644 |
+
time_to_first_tokens = time_last_iters if prompt_run else []
|
| 645 |
+
time_per_output_tokens = [] if prompt_run else time_last_iters
|
| 646 |
+
|
| 647 |
+
return Stats(
|
| 648 |
+
now=now,
|
| 649 |
+
num_running=num_running,
|
| 650 |
+
num_swapped=num_swapped,
|
| 651 |
+
num_waiting=num_waiting,
|
| 652 |
+
gpu_cache_usage=gpu_cache_usage,
|
| 653 |
+
cpu_cache_usage=cpu_cache_usage,
|
| 654 |
+
num_prompt_tokens=num_prompt_tokens,
|
| 655 |
+
num_generation_tokens=num_generation_tokens,
|
| 656 |
+
time_to_first_tokens=time_to_first_tokens,
|
| 657 |
+
time_per_output_tokens=time_per_output_tokens,
|
| 658 |
+
time_e2e_requests=time_e2e_requests,
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
def add_lora(self, lora_request: LoRARequest) -> bool:
|
| 662 |
+
return self.model_executor.add_lora(lora_request)
|
| 663 |
+
|
| 664 |
+
def remove_lora(self, lora_id: int) -> bool:
|
| 665 |
+
return self.model_executor.remove_lora(lora_id)
|
| 666 |
+
|
| 667 |
+
def list_loras(self) -> List[int]:
|
| 668 |
+
return self.model_executor.list_loras()
|
| 669 |
+
|
| 670 |
+
def check_health(self) -> None:
|
| 671 |
+
self.model_executor.check_health()
|
serve/model_runner.py
ADDED
|
@@ -0,0 +1,1223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
import time
|
| 3 |
+
from enum import IntEnum
|
| 4 |
+
from typing import Dict, List, NamedTuple, Optional, Set, Tuple
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
|
| 11 |
+
get_attn_backend)
|
| 12 |
+
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
|
| 13 |
+
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
|
| 14 |
+
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
|
| 15 |
+
from vllm.distributed.device_communicators import (custom_all_reduce,
|
| 16 |
+
pynccl_utils)
|
| 17 |
+
from vllm.logger import init_logger
|
| 18 |
+
from vllm.lora.layers import LoRAMapping
|
| 19 |
+
from vllm.lora.request import LoRARequest
|
| 20 |
+
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
| 21 |
+
from vllm.model_executor import SamplingMetadata
|
| 22 |
+
from vllm.model_executor.model_loader import get_model
|
| 23 |
+
from vllm.sampling_params import SamplingParams, SamplingType
|
| 24 |
+
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
|
| 25 |
+
SequenceGroupMetadata)
|
| 26 |
+
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, is_hip,
|
| 27 |
+
is_pin_memory_available, make_tensor_with_pad,
|
| 28 |
+
maybe_expand_dim)
|
| 29 |
+
from serve.gpt_model import GPT_models
|
| 30 |
+
|
| 31 |
+
logger = init_logger(__name__)
|
| 32 |
+
|
| 33 |
+
_PAD_SLOT_ID = -1
|
| 34 |
+
LORA_WARMUP_RANK = 8
|
| 35 |
+
_BATCH_SIZE_ALIGNMENT = 8
|
| 36 |
+
# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
|
| 37 |
+
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
|
| 38 |
+
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
|
| 39 |
+
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class PreparePromptMetadata(NamedTuple):
|
| 44 |
+
input_tokens: List[int]
|
| 45 |
+
input_positions: List[int]
|
| 46 |
+
attn_metadata: Optional[AttentionMetadataPerStage]
|
| 47 |
+
prompt_lens: List[int]
|
| 48 |
+
subquery_lens: List[int]
|
| 49 |
+
lora_index_mapping: List[int]
|
| 50 |
+
lora_prompt_mapping: List[int]
|
| 51 |
+
lora_requests: Set[LoRARequest]
|
| 52 |
+
multi_modal_input: Optional[torch.Tensor]
|
| 53 |
+
slot_mapping: List[int]
|
| 54 |
+
|
| 55 |
+
@classmethod
|
| 56 |
+
def empty(cls):
|
| 57 |
+
return PreparePromptMetadata(
|
| 58 |
+
input_tokens=[],
|
| 59 |
+
input_positions=[],
|
| 60 |
+
attn_metadata=None,
|
| 61 |
+
prompt_lens=[],
|
| 62 |
+
subquery_lens=[],
|
| 63 |
+
lora_index_mapping=[],
|
| 64 |
+
lora_prompt_mapping=[],
|
| 65 |
+
lora_requests=set(),
|
| 66 |
+
multi_modal_input=None,
|
| 67 |
+
slot_mapping=[],
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class PrepareDecodeMetadata(NamedTuple):
|
| 72 |
+
input_tokens: List[int]
|
| 73 |
+
input_positions: List[int]
|
| 74 |
+
attn_metadata: Optional[AttentionMetadata]
|
| 75 |
+
lora_index_mapping: List[int]
|
| 76 |
+
lora_prompt_mapping: List[int]
|
| 77 |
+
lora_requests: Set[LoRARequest]
|
| 78 |
+
slot_mapping: List[int]
|
| 79 |
+
|
| 80 |
+
@classmethod
|
| 81 |
+
def empty(cls):
|
| 82 |
+
return PrepareDecodeMetadata(
|
| 83 |
+
input_tokens=[],
|
| 84 |
+
input_positions=[],
|
| 85 |
+
attn_metadata=None,
|
| 86 |
+
lora_index_mapping=[],
|
| 87 |
+
lora_prompt_mapping=[],
|
| 88 |
+
lora_requests=set(),
|
| 89 |
+
slot_mapping=[],
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# How batches are constructed.
|
| 94 |
+
class BatchType(IntEnum):
|
| 95 |
+
# Every batch is prefill.
|
| 96 |
+
PREFILL = 0
|
| 97 |
+
# Every batch is decode.
|
| 98 |
+
DECODE = 1
|
| 99 |
+
# Batch is a mixture of prefill and decode.
|
| 100 |
+
MIXED = 2
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class ModelRunner:
|
| 104 |
+
|
| 105 |
+
def __init__(
|
| 106 |
+
self,
|
| 107 |
+
model_config: ModelConfig,
|
| 108 |
+
parallel_config: ParallelConfig,
|
| 109 |
+
scheduler_config: SchedulerConfig,
|
| 110 |
+
device_config: DeviceConfig,
|
| 111 |
+
load_config: LoadConfig,
|
| 112 |
+
lora_config: Optional[LoRAConfig],
|
| 113 |
+
kv_cache_dtype: Optional[str] = "auto",
|
| 114 |
+
is_driver_worker: bool = False,
|
| 115 |
+
vision_language_config: Optional[VisionLanguageConfig] = None,
|
| 116 |
+
):
|
| 117 |
+
self.model_config = model_config
|
| 118 |
+
self.parallel_config = parallel_config
|
| 119 |
+
self.scheduler_config = scheduler_config
|
| 120 |
+
self.lora_config = lora_config
|
| 121 |
+
self.load_config = load_config
|
| 122 |
+
self.is_driver_worker = is_driver_worker
|
| 123 |
+
|
| 124 |
+
# model_config can be None in tests/samplers/test_sampler.py.
|
| 125 |
+
# FIXME(woosuk): This is a hack to make the tests work. Refactor this.
|
| 126 |
+
self.sliding_window = (model_config.get_sliding_window()
|
| 127 |
+
if model_config is not None else None)
|
| 128 |
+
self.device_config = (device_config
|
| 129 |
+
if device_config is not None else DeviceConfig())
|
| 130 |
+
self.device = self.device_config.device
|
| 131 |
+
|
| 132 |
+
# Set after load_model.
|
| 133 |
+
self.lora_manager: LRUCacheWorkerLoRAManager = None
|
| 134 |
+
|
| 135 |
+
self.graph_runners: Dict[int, CUDAGraphRunner] = {}
|
| 136 |
+
self.graph_memory_pool: Optional[Tuple[
|
| 137 |
+
int, int]] = None # Set during graph capture.
|
| 138 |
+
|
| 139 |
+
self.max_context_len_to_capture = (
|
| 140 |
+
self.model_config.max_context_len_to_capture
|
| 141 |
+
if self.model_config is not None else 0)
|
| 142 |
+
|
| 143 |
+
self.pin_memory = is_pin_memory_available()
|
| 144 |
+
self.kv_cache_dtype = kv_cache_dtype
|
| 145 |
+
self.vision_language_config = vision_language_config
|
| 146 |
+
|
| 147 |
+
self.attn_backend = get_attn_backend(
|
| 148 |
+
self.model_config.dtype if model_config is not None else None)
|
| 149 |
+
|
| 150 |
+
# Lazy initialization
|
| 151 |
+
self.model: torch.nn.Module # Set after load_model
|
| 152 |
+
self.block_size: int # Set after initial profiling.
|
| 153 |
+
# When using CUDA graph, the input block tables must be padded to
|
| 154 |
+
# max_context_len_to_capture. However, creating the block table in
|
| 155 |
+
# Python can be expensive. To optimize this, we cache the block table
|
| 156 |
+
# in numpy and only copy the actual input content at every iteration.
|
| 157 |
+
# The shape of the cached block table will be
|
| 158 |
+
# (max batch size to capture, max context len to capture / block size).
|
| 159 |
+
self.graph_block_tables: torch.Tensor # Set after initial profiling.
|
| 160 |
+
|
| 161 |
+
def load_model(self, args) -> None:
|
| 162 |
+
with CudaMemoryProfiler() as m:
|
| 163 |
+
precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
|
| 164 |
+
latent_size = args.image_size // args.downsample_size
|
| 165 |
+
gpt_model = GPT_models[args.gpt_model](
|
| 166 |
+
vocab_size=args.codebook_size,
|
| 167 |
+
block_size=latent_size ** 2,
|
| 168 |
+
num_classes=args.num_classes,
|
| 169 |
+
cls_token_num=args.cls_token_num,
|
| 170 |
+
model_type=args.gpt_type,
|
| 171 |
+
cfg_scale=args.cfg_scale,
|
| 172 |
+
).to(device='cuda', dtype=precision) # TODO: make device configurable
|
| 173 |
+
|
| 174 |
+
checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
|
| 175 |
+
if args.from_fsdp: # fspd
|
| 176 |
+
model_weight = checkpoint
|
| 177 |
+
elif "model" in checkpoint: # ddp
|
| 178 |
+
model_weight = checkpoint["model"]
|
| 179 |
+
elif "state_dict" in checkpoint:
|
| 180 |
+
model_weight = checkpoint["state_dict"]
|
| 181 |
+
else:
|
| 182 |
+
raise Exception("please check model weight")
|
| 183 |
+
gpt_model.custom_load_state_dict(model_weight)
|
| 184 |
+
gpt_model.eval()
|
| 185 |
+
del checkpoint
|
| 186 |
+
self.model = gpt_model
|
| 187 |
+
|
| 188 |
+
self.model_memory_usage = m.consumed_memory
|
| 189 |
+
logger.info(f"Loading model weights took "
|
| 190 |
+
f"{self.model_memory_usage / float(2**30):.4f} GB")
|
| 191 |
+
|
| 192 |
+
if self.lora_config:
|
| 193 |
+
assert hasattr(self.model, "supported_lora_modules"
|
| 194 |
+
) and self.model.supported_lora_modules, (
|
| 195 |
+
"Model does not support LoRA")
|
| 196 |
+
assert hasattr(
|
| 197 |
+
self.model,
|
| 198 |
+
"embedding_modules"), "Model does not have embedding_modules"
|
| 199 |
+
assert hasattr(self.model, "embedding_padding_modules"
|
| 200 |
+
), "Model does not have embedding_padding_modules"
|
| 201 |
+
self.lora_manager = LRUCacheWorkerLoRAManager(
|
| 202 |
+
self.scheduler_config.max_num_seqs,
|
| 203 |
+
self.scheduler_config.max_num_batched_tokens, self.vocab_size,
|
| 204 |
+
self.lora_config, self.device, self.model.embedding_modules,
|
| 205 |
+
self.model.embedding_padding_modules)
|
| 206 |
+
self.model = self.lora_manager.create_lora_manager(self.model)
|
| 207 |
+
|
| 208 |
+
if self.kv_cache_dtype == "fp8" and is_hip():
|
| 209 |
+
# Currently scaled KV cache is only enabled on ROCm
|
| 210 |
+
if self.model_config.quantization_param_path is not None:
|
| 211 |
+
if callable(getattr(self.model, "load_kv_cache_scales", None)):
|
| 212 |
+
self.model.load_kv_cache_scales(
|
| 213 |
+
self.model_config.quantization_param_path)
|
| 214 |
+
else:
|
| 215 |
+
raise RuntimeError("Using FP8 KV cache and scaling "
|
| 216 |
+
"factors provided but model "
|
| 217 |
+
f"{self.model.__class__} does not "
|
| 218 |
+
"support loading scaling factors.")
|
| 219 |
+
else:
|
| 220 |
+
logger.warn("Using FP8 KV cache but no scaling factors "
|
| 221 |
+
"provided. Defaulting to scaling factors of 1.0. "
|
| 222 |
+
"This may lead to less accurate results!")
|
| 223 |
+
elif self.model_config.quantization_param_path is not None:
|
| 224 |
+
logger.warn("KV cache scaling factors provided, "
|
| 225 |
+
"but the KV cache data type is not FP8. "
|
| 226 |
+
"KV cache scaling factors will not be used.")
|
| 227 |
+
|
| 228 |
+
def set_block_size(self, block_size: int) -> None:
|
| 229 |
+
self.block_size = block_size
|
| 230 |
+
|
| 231 |
+
self.graph_block_tables = np.zeros(
|
| 232 |
+
(max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
|
| 233 |
+
dtype=np.int32)
|
| 234 |
+
|
| 235 |
+
def get_max_block_per_batch(self) -> int:
|
| 236 |
+
block_size = self.block_size
|
| 237 |
+
return (self.max_context_len_to_capture + block_size - 1) // block_size
|
| 238 |
+
|
| 239 |
+
def _prepare_prompt(
|
| 240 |
+
self,
|
| 241 |
+
seq_group_metadata_list: List[SequenceGroupMetadata],
|
| 242 |
+
) -> PreparePromptMetadata:
|
| 243 |
+
input_tokens: List[int] = []
|
| 244 |
+
input_positions: List[int] = []
|
| 245 |
+
slot_mapping: List[int] = []
|
| 246 |
+
lora_index_mapping: List[int] = []
|
| 247 |
+
lora_prompt_mapping: List[int] = []
|
| 248 |
+
lora_requests: Set[LoRARequest] = set()
|
| 249 |
+
|
| 250 |
+
prompt_lens: List[int] = []
|
| 251 |
+
context_lens: List[int] = []
|
| 252 |
+
subquery_lens: List[int] = []
|
| 253 |
+
prefix_block_tables: List[List[int]] = []
|
| 254 |
+
multi_modal_input_list: List[torch.Tensor] = []
|
| 255 |
+
|
| 256 |
+
if len(seq_group_metadata_list) == 0:
|
| 257 |
+
return PreparePromptMetadata.empty()
|
| 258 |
+
|
| 259 |
+
for seq_group_metadata in seq_group_metadata_list:
|
| 260 |
+
assert seq_group_metadata.is_prompt
|
| 261 |
+
seq_ids = list(seq_group_metadata.seq_data.keys())
|
| 262 |
+
assert len(seq_ids) == 1
|
| 263 |
+
seq_id = seq_ids[0]
|
| 264 |
+
|
| 265 |
+
computed_block_nums = seq_group_metadata.computed_block_nums
|
| 266 |
+
if (self.scheduler_config is not None
|
| 267 |
+
and self.scheduler_config.chunked_prefill_enabled
|
| 268 |
+
and not (computed_block_nums is None
|
| 269 |
+
or computed_block_nums == [])):
|
| 270 |
+
raise RuntimeError(
|
| 271 |
+
"chunked prefill cannot be used with prefix caching "
|
| 272 |
+
"now.")
|
| 273 |
+
|
| 274 |
+
token_chunk_size = seq_group_metadata.token_chunk_size
|
| 275 |
+
seq_data = seq_group_metadata.seq_data[seq_id]
|
| 276 |
+
computed_len = seq_data.get_num_computed_tokens()
|
| 277 |
+
# We should use get_len here because in case of preemption
|
| 278 |
+
# it contains output tokens.
|
| 279 |
+
prefill_end = min(seq_data.get_len(),
|
| 280 |
+
computed_len + token_chunk_size)
|
| 281 |
+
prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
|
| 282 |
+
prompt_len = prefill_end
|
| 283 |
+
prompt_lens.append(prompt_len)
|
| 284 |
+
|
| 285 |
+
# NOTE: This only works for oooooooxxx style attention.
|
| 286 |
+
if computed_block_nums is not None and len(
|
| 287 |
+
computed_block_nums) > 0 and self.sliding_window is None:
|
| 288 |
+
# Prefix is not supported with sliding_window
|
| 289 |
+
computed_len = len(computed_block_nums) * self.block_size
|
| 290 |
+
prompt_tokens = prompt_tokens[computed_len:]
|
| 291 |
+
prefix_block_tables.append(computed_block_nums)
|
| 292 |
+
elif self.scheduler_config.chunked_prefill_enabled:
|
| 293 |
+
if seq_group_metadata.block_tables is not None:
|
| 294 |
+
# Prefill has chunked before.
|
| 295 |
+
block_table = seq_group_metadata.block_tables[seq_id]
|
| 296 |
+
prefix_block_tables.append(block_table)
|
| 297 |
+
else:
|
| 298 |
+
# The first prefill.
|
| 299 |
+
prefix_block_tables.append([])
|
| 300 |
+
else:
|
| 301 |
+
prefix_block_tables.append([])
|
| 302 |
+
# Right now, prefill start is always 0. However, this
|
| 303 |
+
# assumption can be changed once chunked prefill is introduced.
|
| 304 |
+
assert computed_len == 0
|
| 305 |
+
|
| 306 |
+
# actual prompt lens
|
| 307 |
+
context_lens.append(computed_len)
|
| 308 |
+
subquery_lens.append(prompt_len - computed_len)
|
| 309 |
+
|
| 310 |
+
input_tokens.extend(prompt_tokens)
|
| 311 |
+
# NOTE(woosuk): Here we assume that the first token in the prompt
|
| 312 |
+
# is always the first token in the sequence.
|
| 313 |
+
input_positions.extend(list(range(computed_len, prefill_end)))
|
| 314 |
+
lora_id = seq_group_metadata.lora_int_id
|
| 315 |
+
|
| 316 |
+
if lora_id > 0:
|
| 317 |
+
lora_requests.add(seq_group_metadata.lora_request)
|
| 318 |
+
|
| 319 |
+
lora_index_mapping += [lora_id] * (prompt_len - computed_len)
|
| 320 |
+
lora_prompt_mapping.extend(
|
| 321 |
+
[lora_id] *
|
| 322 |
+
(prompt_len - computed_len
|
| 323 |
+
if seq_group_metadata.sampling_params.prompt_logprobs else 1))
|
| 324 |
+
|
| 325 |
+
if seq_group_metadata.multi_modal_data:
|
| 326 |
+
multi_modal_input_list.append(
|
| 327 |
+
seq_group_metadata.multi_modal_data.data)
|
| 328 |
+
|
| 329 |
+
if seq_group_metadata.block_tables is None:
|
| 330 |
+
# During memory profiling, the block tables are not initialized
|
| 331 |
+
# yet. In this case, we just use a dummy slot mapping.
|
| 332 |
+
slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
|
| 333 |
+
continue
|
| 334 |
+
|
| 335 |
+
# Compute the slot mapping.
|
| 336 |
+
block_table = seq_group_metadata.block_tables[seq_id]
|
| 337 |
+
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
|
| 338 |
+
# where start_idx is max(0, prompt_len - sliding_window).
|
| 339 |
+
# For example, if the prompt len is 10, sliding window is 8, and
|
| 340 |
+
# block size is 4, the first two tokens are masked and the slot
|
| 341 |
+
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
|
| 342 |
+
start_idx = 0
|
| 343 |
+
if self.sliding_window is not None:
|
| 344 |
+
assert computed_len == 0, (
|
| 345 |
+
"Prefix caching is currently not supported with "
|
| 346 |
+
"sliding window attention")
|
| 347 |
+
start_idx = max(0, prompt_len - self.sliding_window)
|
| 348 |
+
|
| 349 |
+
for i in range(computed_len, prefill_end):
|
| 350 |
+
if i < start_idx:
|
| 351 |
+
slot_mapping.append(_PAD_SLOT_ID)
|
| 352 |
+
continue
|
| 353 |
+
|
| 354 |
+
block_number = block_table[i // self.block_size]
|
| 355 |
+
block_offset = i % self.block_size
|
| 356 |
+
slot = block_number * self.block_size + block_offset
|
| 357 |
+
slot_mapping.append(slot)
|
| 358 |
+
|
| 359 |
+
max_subquery_len = max(subquery_lens)
|
| 360 |
+
max_prompt_len = max(prompt_lens)
|
| 361 |
+
assert max_subquery_len > 0
|
| 362 |
+
|
| 363 |
+
context_lens_tensor = torch.tensor(context_lens,
|
| 364 |
+
dtype=torch.int,
|
| 365 |
+
device=self.device)
|
| 366 |
+
|
| 367 |
+
if multi_modal_input_list:
|
| 368 |
+
assert self.vision_language_config, (
|
| 369 |
+
"Multi-modal inputs are only supported by "
|
| 370 |
+
"vision language models.")
|
| 371 |
+
multi_modal_input = torch.cat(multi_modal_input_list,
|
| 372 |
+
dim=0).to(self.device)
|
| 373 |
+
else:
|
| 374 |
+
multi_modal_input = None
|
| 375 |
+
|
| 376 |
+
# Prepare prefix block tables
|
| 377 |
+
max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
|
| 378 |
+
block_tables = make_tensor_with_pad(
|
| 379 |
+
prefix_block_tables,
|
| 380 |
+
max_len=max_prompt_block_table_len,
|
| 381 |
+
pad=0,
|
| 382 |
+
dtype=torch.int,
|
| 383 |
+
device=self.device,
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
# Query length can be shorter than key (i.e., prompt) when prefill
|
| 387 |
+
# is chunked or prefix cached.
|
| 388 |
+
subquery_lens_tensor = torch.tensor(subquery_lens,
|
| 389 |
+
dtype=torch.long,
|
| 390 |
+
device=self.device)
|
| 391 |
+
subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
|
| 392 |
+
dtype=torch.int32,
|
| 393 |
+
device=self.device)
|
| 394 |
+
|
| 395 |
+
prompt_lens_tensor = torch.tensor(prompt_lens,
|
| 396 |
+
dtype=torch.long,
|
| 397 |
+
device=self.device)
|
| 398 |
+
seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
|
| 399 |
+
dtype=torch.int32,
|
| 400 |
+
device=self.device)
|
| 401 |
+
|
| 402 |
+
torch.cumsum(subquery_lens_tensor,
|
| 403 |
+
dim=0,
|
| 404 |
+
dtype=subquery_start_loc.dtype,
|
| 405 |
+
out=subquery_start_loc[1:])
|
| 406 |
+
|
| 407 |
+
torch.cumsum(prompt_lens_tensor,
|
| 408 |
+
dim=0,
|
| 409 |
+
dtype=seq_start_loc.dtype,
|
| 410 |
+
out=seq_start_loc[1:])
|
| 411 |
+
|
| 412 |
+
attn_metadata = self.attn_backend.make_metadata(
|
| 413 |
+
is_prompt=True,
|
| 414 |
+
prompt_lens=prompt_lens,
|
| 415 |
+
prompt_lens_tensor=prompt_lens_tensor,
|
| 416 |
+
max_subquery_len=max_subquery_len,
|
| 417 |
+
max_context_len=None,
|
| 418 |
+
max_prompt_len=max_prompt_len,
|
| 419 |
+
subquery_start_loc=subquery_start_loc,
|
| 420 |
+
seq_start_loc=seq_start_loc,
|
| 421 |
+
context_lens=context_lens_tensor,
|
| 422 |
+
block_tables=block_tables,
|
| 423 |
+
use_cuda_graph=False,
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
return PreparePromptMetadata(
|
| 427 |
+
input_tokens=input_tokens,
|
| 428 |
+
input_positions=input_positions,
|
| 429 |
+
attn_metadata=attn_metadata,
|
| 430 |
+
prompt_lens=prompt_lens,
|
| 431 |
+
subquery_lens=subquery_lens,
|
| 432 |
+
lora_index_mapping=lora_index_mapping,
|
| 433 |
+
lora_prompt_mapping=lora_prompt_mapping,
|
| 434 |
+
lora_requests=lora_requests,
|
| 435 |
+
multi_modal_input=multi_modal_input,
|
| 436 |
+
slot_mapping=slot_mapping,
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
def _prepare_decode(
|
| 440 |
+
self,
|
| 441 |
+
seq_group_metadata_list: List[SequenceGroupMetadata],
|
| 442 |
+
) -> PrepareDecodeMetadata:
|
| 443 |
+
input_tokens: List[int] = []
|
| 444 |
+
input_positions: List[int] = []
|
| 445 |
+
slot_mapping: List[int] = []
|
| 446 |
+
context_lens: List[int] = []
|
| 447 |
+
block_tables: List[List[int]] = []
|
| 448 |
+
lora_index_mapping: List[int] = []
|
| 449 |
+
lora_prompt_mapping: List[int] = []
|
| 450 |
+
lora_requests: Set[LoRARequest] = set()
|
| 451 |
+
|
| 452 |
+
if len(seq_group_metadata_list) == 0:
|
| 453 |
+
return PrepareDecodeMetadata.empty()
|
| 454 |
+
|
| 455 |
+
for seq_group_metadata in seq_group_metadata_list:
|
| 456 |
+
assert not seq_group_metadata.is_prompt
|
| 457 |
+
assert seq_group_metadata.token_chunk_size == 1
|
| 458 |
+
|
| 459 |
+
seq_ids = list(seq_group_metadata.seq_data.keys())
|
| 460 |
+
lora_id = seq_group_metadata.lora_int_id
|
| 461 |
+
|
| 462 |
+
if lora_id > 0:
|
| 463 |
+
lora_requests.add(seq_group_metadata.lora_request)
|
| 464 |
+
|
| 465 |
+
for seq_id in seq_ids:
|
| 466 |
+
seq_data = seq_group_metadata.seq_data[seq_id]
|
| 467 |
+
generation_token = seq_data.get_last_token_id()
|
| 468 |
+
input_tokens.append(generation_token)
|
| 469 |
+
|
| 470 |
+
seq_len = seq_data.get_len()
|
| 471 |
+
position = seq_len - 1
|
| 472 |
+
input_positions.append(position)
|
| 473 |
+
|
| 474 |
+
context_len = seq_len if self.sliding_window is None else min(
|
| 475 |
+
seq_len, self.sliding_window)
|
| 476 |
+
context_lens.append(context_len)
|
| 477 |
+
|
| 478 |
+
block_table = seq_group_metadata.block_tables[seq_id]
|
| 479 |
+
block_number = block_table[position // self.block_size]
|
| 480 |
+
block_offset = position % self.block_size
|
| 481 |
+
slot = block_number * self.block_size + block_offset
|
| 482 |
+
slot_mapping.append(slot)
|
| 483 |
+
lora_index_mapping.append(lora_id)
|
| 484 |
+
lora_prompt_mapping.append(lora_id)
|
| 485 |
+
|
| 486 |
+
if self.sliding_window is not None:
|
| 487 |
+
sliding_window_blocks = (self.sliding_window //
|
| 488 |
+
self.block_size)
|
| 489 |
+
block_table = block_table[-sliding_window_blocks:]
|
| 490 |
+
block_tables.append(block_table)
|
| 491 |
+
|
| 492 |
+
# vLLM uses cuda graph only for decoding requests.
|
| 493 |
+
# See `capture_model` API for more details.
|
| 494 |
+
# For decoding requests, batch_size == input_tokens.
|
| 495 |
+
batch_size = len(input_tokens)
|
| 496 |
+
max_context_len = max(context_lens)
|
| 497 |
+
use_captured_graph = (
|
| 498 |
+
not self.model_config.enforce_eager
|
| 499 |
+
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
|
| 500 |
+
and max_context_len <= self.max_context_len_to_capture)
|
| 501 |
+
if use_captured_graph:
|
| 502 |
+
graph_batch_size = _get_graph_batch_size(batch_size)
|
| 503 |
+
assert graph_batch_size >= batch_size
|
| 504 |
+
for _ in range(graph_batch_size - batch_size):
|
| 505 |
+
input_tokens.append(0)
|
| 506 |
+
input_positions.append(0)
|
| 507 |
+
slot_mapping.append(_PAD_SLOT_ID)
|
| 508 |
+
context_lens.append(1)
|
| 509 |
+
block_tables.append([])
|
| 510 |
+
lora_index_mapping.append(0)
|
| 511 |
+
batch_size = graph_batch_size
|
| 512 |
+
|
| 513 |
+
context_lens_tensor = torch.tensor(context_lens,
|
| 514 |
+
dtype=torch.int,
|
| 515 |
+
device=self.device)
|
| 516 |
+
|
| 517 |
+
if use_captured_graph:
|
| 518 |
+
# When using cuda-graph all these tensors should be
|
| 519 |
+
# padded.
|
| 520 |
+
assert context_lens_tensor.shape[0] == len(input_tokens)
|
| 521 |
+
assert context_lens_tensor.shape[0] == len(input_positions)
|
| 522 |
+
assert context_lens_tensor.shape[0] == len(slot_mapping)
|
| 523 |
+
|
| 524 |
+
# The shape of graph_block_tables is
|
| 525 |
+
# [max batch size, max context len // block size].
|
| 526 |
+
input_block_tables = self.graph_block_tables[:batch_size]
|
| 527 |
+
for i, block_table in enumerate(block_tables):
|
| 528 |
+
if block_table:
|
| 529 |
+
input_block_tables[i, :len(block_table)] = block_table
|
| 530 |
+
block_tables = torch.tensor(input_block_tables, device=self.device)
|
| 531 |
+
else:
|
| 532 |
+
max_block_table_len = max(
|
| 533 |
+
len(block_table) for block_table in block_tables)
|
| 534 |
+
block_tables = make_tensor_with_pad(
|
| 535 |
+
block_tables,
|
| 536 |
+
max_len=max_block_table_len,
|
| 537 |
+
pad=0,
|
| 538 |
+
dtype=torch.int,
|
| 539 |
+
device=self.device,
|
| 540 |
+
)
|
| 541 |
+
|
| 542 |
+
attn_metadata = self.attn_backend.make_metadata(
|
| 543 |
+
is_prompt=False,
|
| 544 |
+
prompt_lens=None,
|
| 545 |
+
prompt_lens_tensor=None,
|
| 546 |
+
max_subquery_len=None,
|
| 547 |
+
max_context_len=max_context_len,
|
| 548 |
+
max_prompt_len=None,
|
| 549 |
+
subquery_start_loc=None,
|
| 550 |
+
seq_start_loc=None,
|
| 551 |
+
context_lens=context_lens_tensor,
|
| 552 |
+
block_tables=block_tables,
|
| 553 |
+
use_cuda_graph=use_captured_graph,
|
| 554 |
+
)
|
| 555 |
+
return PrepareDecodeMetadata(
|
| 556 |
+
input_tokens=input_tokens,
|
| 557 |
+
input_positions=input_positions,
|
| 558 |
+
attn_metadata=attn_metadata,
|
| 559 |
+
lora_index_mapping=lora_index_mapping,
|
| 560 |
+
lora_prompt_mapping=lora_prompt_mapping,
|
| 561 |
+
lora_requests=lora_requests,
|
| 562 |
+
slot_mapping=slot_mapping,
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
def _prepare_sample(
|
| 566 |
+
self,
|
| 567 |
+
seq_group_metadata_list: List[SequenceGroupMetadata],
|
| 568 |
+
prompt_lens: List[int],
|
| 569 |
+
subquery_lens: Optional[List[int]],
|
| 570 |
+
) -> SamplingMetadata:
|
| 571 |
+
seq_groups: List[Tuple[List[int], SamplingParams]] = []
|
| 572 |
+
selected_token_indices: List[int] = []
|
| 573 |
+
generators: List[torch.Generator] = []
|
| 574 |
+
selected_token_start_idx = 0
|
| 575 |
+
categorized_sample_indices: Dict[SamplingType,
|
| 576 |
+
List[Tuple[int, int]]] = {
|
| 577 |
+
t: []
|
| 578 |
+
for t in SamplingType
|
| 579 |
+
}
|
| 580 |
+
categorized_sample_indices_start_idx = 0
|
| 581 |
+
categorized_sampled_token_indices_start_idx = 0
|
| 582 |
+
|
| 583 |
+
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
| 584 |
+
seq_ids = list(seq_group_metadata.seq_data.keys())
|
| 585 |
+
sampling_params = seq_group_metadata.sampling_params
|
| 586 |
+
seq_groups.append((seq_ids, sampling_params))
|
| 587 |
+
|
| 588 |
+
if seq_group_metadata.is_prompt:
|
| 589 |
+
assert len(seq_ids) == 1
|
| 590 |
+
assert subquery_lens is not None
|
| 591 |
+
subquery_len = subquery_lens[i]
|
| 592 |
+
if sampling_params.prompt_logprobs is not None:
|
| 593 |
+
# NOTE: prompt token positions do not need sample, skip
|
| 594 |
+
categorized_sample_indices_start_idx += subquery_len - 1
|
| 595 |
+
|
| 596 |
+
categorized_sample_indices[
|
| 597 |
+
sampling_params.sampling_type].append(
|
| 598 |
+
(categorized_sample_indices_start_idx,
|
| 599 |
+
categorized_sampled_token_indices_start_idx))
|
| 600 |
+
categorized_sample_indices_start_idx += 1
|
| 601 |
+
categorized_sampled_token_indices_start_idx += 1
|
| 602 |
+
|
| 603 |
+
if sampling_params.prompt_logprobs is not None:
|
| 604 |
+
selected_token_indices.extend(
|
| 605 |
+
range(selected_token_start_idx,
|
| 606 |
+
selected_token_start_idx + subquery_len - 1))
|
| 607 |
+
selected_token_indices.append(selected_token_start_idx +
|
| 608 |
+
subquery_len - 1)
|
| 609 |
+
selected_token_start_idx += subquery_len
|
| 610 |
+
|
| 611 |
+
if sampling_params.seed is not None:
|
| 612 |
+
seq_group_metadata.state.generator = torch.Generator(
|
| 613 |
+
device=self.device).manual_seed(sampling_params.seed)
|
| 614 |
+
else:
|
| 615 |
+
num_seqs = len(seq_ids)
|
| 616 |
+
selected_token_indices.extend(
|
| 617 |
+
range(selected_token_start_idx,
|
| 618 |
+
selected_token_start_idx + num_seqs))
|
| 619 |
+
selected_token_start_idx += num_seqs
|
| 620 |
+
|
| 621 |
+
categorized_sample_indices[
|
| 622 |
+
sampling_params.sampling_type].extend(
|
| 623 |
+
list(
|
| 624 |
+
zip(
|
| 625 |
+
range(
|
| 626 |
+
categorized_sample_indices_start_idx,
|
| 627 |
+
categorized_sample_indices_start_idx +
|
| 628 |
+
num_seqs),
|
| 629 |
+
range(
|
| 630 |
+
categorized_sampled_token_indices_start_idx,
|
| 631 |
+
categorized_sampled_token_indices_start_idx
|
| 632 |
+
+ num_seqs))))
|
| 633 |
+
categorized_sample_indices_start_idx += num_seqs
|
| 634 |
+
categorized_sampled_token_indices_start_idx += num_seqs
|
| 635 |
+
|
| 636 |
+
if sampling_params.seed is not None:
|
| 637 |
+
generators.append(seq_group_metadata.state.generator)
|
| 638 |
+
|
| 639 |
+
selected_token_indices = async_tensor_h2d(selected_token_indices,
|
| 640 |
+
dtype=torch.long,
|
| 641 |
+
target_device=self.device,
|
| 642 |
+
pin_memory=self.pin_memory)
|
| 643 |
+
|
| 644 |
+
categorized_sample_indices = {
|
| 645 |
+
t: maybe_expand_dim(
|
| 646 |
+
async_tensor_h2d(seq_ids,
|
| 647 |
+
dtype=torch.int,
|
| 648 |
+
target_device=self.device,
|
| 649 |
+
pin_memory=self.pin_memory), 2, 2)
|
| 650 |
+
for t, seq_ids in categorized_sample_indices.items()
|
| 651 |
+
}
|
| 652 |
+
|
| 653 |
+
seq_data: Dict[int, SequenceData] = {}
|
| 654 |
+
for seq_group_metadata in seq_group_metadata_list:
|
| 655 |
+
seq_data.update(seq_group_metadata.seq_data)
|
| 656 |
+
|
| 657 |
+
sampling_metadata = SamplingMetadata(
|
| 658 |
+
seq_groups=seq_groups,
|
| 659 |
+
seq_data=seq_data,
|
| 660 |
+
prompt_lens=prompt_lens,
|
| 661 |
+
selected_token_indices=selected_token_indices,
|
| 662 |
+
categorized_sample_indices=categorized_sample_indices,
|
| 663 |
+
generators=generators,
|
| 664 |
+
)
|
| 665 |
+
return sampling_metadata
|
| 666 |
+
|
| 667 |
+
def prepare_input_tensors(
|
| 668 |
+
self,
|
| 669 |
+
seq_group_metadata_list: List[SequenceGroupMetadata],
|
| 670 |
+
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
|
| 671 |
+
Set[LoRARequest], LoRAMapping, torch.Tensor]:
|
| 672 |
+
if self.is_driver_worker:
|
| 673 |
+
prefill_reqs = []
|
| 674 |
+
decode_reqs = []
|
| 675 |
+
for seq_group_meta in seq_group_metadata_list:
|
| 676 |
+
if seq_group_meta.is_prompt:
|
| 677 |
+
prefill_reqs.append(seq_group_meta)
|
| 678 |
+
else:
|
| 679 |
+
decode_reqs.append(seq_group_meta)
|
| 680 |
+
|
| 681 |
+
# Prepare input tensors.
|
| 682 |
+
(
|
| 683 |
+
input_tokens,
|
| 684 |
+
input_positions,
|
| 685 |
+
prefill_attn_metadata,
|
| 686 |
+
prompt_lens,
|
| 687 |
+
subquery_lens,
|
| 688 |
+
lora_index_mapping,
|
| 689 |
+
lora_prompt_mapping,
|
| 690 |
+
lora_requests,
|
| 691 |
+
multi_modal_input,
|
| 692 |
+
slot_mapping,
|
| 693 |
+
) = self._prepare_prompt(prefill_reqs)
|
| 694 |
+
(
|
| 695 |
+
decode_input_tokens,
|
| 696 |
+
decode_input_positions,
|
| 697 |
+
decode_attn_metadata,
|
| 698 |
+
decode_lora_index_mapping,
|
| 699 |
+
decode_lora_prompt_mapping,
|
| 700 |
+
decode_lora_requests,
|
| 701 |
+
decode_slot_mapping,
|
| 702 |
+
) = self._prepare_decode(decode_reqs)
|
| 703 |
+
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
|
| 704 |
+
prompt_lens,
|
| 705 |
+
subquery_lens)
|
| 706 |
+
|
| 707 |
+
if not self.scheduler_config.chunked_prefill_enabled:
|
| 708 |
+
assert (len(prefill_reqs) and len(decode_reqs)) == 0
|
| 709 |
+
|
| 710 |
+
num_prefills = len(prompt_lens)
|
| 711 |
+
num_prefill_tokens = len(input_tokens)
|
| 712 |
+
num_decode_tokens = len(decode_input_tokens)
|
| 713 |
+
|
| 714 |
+
# Coalesce tensors. Note that attn_metadata is currently not
|
| 715 |
+
# coalesced for simplicity.
|
| 716 |
+
input_tokens.extend(decode_input_tokens)
|
| 717 |
+
input_positions.extend(decode_input_positions)
|
| 718 |
+
slot_mapping.extend(decode_slot_mapping)
|
| 719 |
+
lora_index_mapping.extend(decode_lora_index_mapping)
|
| 720 |
+
lora_prompt_mapping.extend(decode_lora_prompt_mapping)
|
| 721 |
+
lora_requests.update(decode_lora_requests)
|
| 722 |
+
|
| 723 |
+
input_tokens = torch.tensor(input_tokens,
|
| 724 |
+
dtype=torch.long,
|
| 725 |
+
device=self.device)
|
| 726 |
+
input_positions = torch.tensor(input_positions,
|
| 727 |
+
dtype=torch.long,
|
| 728 |
+
device=self.device)
|
| 729 |
+
slot_mapping = torch.tensor(slot_mapping,
|
| 730 |
+
dtype=torch.long,
|
| 731 |
+
device=self.device)
|
| 732 |
+
|
| 733 |
+
if self.lora_config:
|
| 734 |
+
lora_mapping = LoRAMapping(
|
| 735 |
+
lora_index_mapping,
|
| 736 |
+
lora_prompt_mapping,
|
| 737 |
+
)
|
| 738 |
+
else:
|
| 739 |
+
lora_mapping = None
|
| 740 |
+
|
| 741 |
+
# Broadcast the metadata.
|
| 742 |
+
# If batch contains both prefill and decode, it sends 2 broadcasts.
|
| 743 |
+
# If it only contains 1 type, it triggers a single broadcast.
|
| 744 |
+
if (prefill_attn_metadata is not None
|
| 745 |
+
and decode_attn_metadata is not None):
|
| 746 |
+
batch_type = BatchType.MIXED
|
| 747 |
+
elif prefill_attn_metadata is not None:
|
| 748 |
+
batch_type = BatchType.PREFILL
|
| 749 |
+
else:
|
| 750 |
+
batch_type = BatchType.DECODE
|
| 751 |
+
|
| 752 |
+
metadata_dict = {
|
| 753 |
+
"input_tokens": input_tokens,
|
| 754 |
+
"input_positions": input_positions,
|
| 755 |
+
"selected_token_indices":
|
| 756 |
+
sampling_metadata.selected_token_indices,
|
| 757 |
+
"lora_requests": lora_requests,
|
| 758 |
+
"lora_mapping": lora_mapping,
|
| 759 |
+
"multi_modal_input": multi_modal_input,
|
| 760 |
+
"num_prefill_tokens": num_prefill_tokens,
|
| 761 |
+
"num_decode_tokens": num_decode_tokens,
|
| 762 |
+
"slot_mapping": slot_mapping,
|
| 763 |
+
"num_prefills": num_prefills,
|
| 764 |
+
"batch_type": batch_type,
|
| 765 |
+
}
|
| 766 |
+
if prefill_attn_metadata is not None:
|
| 767 |
+
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
|
| 768 |
+
else:
|
| 769 |
+
assert decode_attn_metadata is not None
|
| 770 |
+
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
|
| 771 |
+
broadcast_tensor_dict(metadata_dict, src=0)
|
| 772 |
+
|
| 773 |
+
# Broadcast decode attn metadata for mixed batch type.
|
| 774 |
+
# The additional broadcast costs 300us overhead on 4 A10 GPUs.
|
| 775 |
+
# We can potentially reduce the overhead by coalescing tensors.
|
| 776 |
+
if batch_type == BatchType.MIXED:
|
| 777 |
+
assert decode_attn_metadata is not None
|
| 778 |
+
metadata_dict = decode_attn_metadata.asdict_zerocopy()
|
| 779 |
+
broadcast_tensor_dict(metadata_dict, src=0)
|
| 780 |
+
else:
|
| 781 |
+
metadata_dict = broadcast_tensor_dict(src=0)
|
| 782 |
+
input_tokens = metadata_dict.pop("input_tokens")
|
| 783 |
+
input_positions = metadata_dict.pop("input_positions")
|
| 784 |
+
slot_mapping = metadata_dict.pop("slot_mapping")
|
| 785 |
+
num_prefills = metadata_dict.pop("num_prefills")
|
| 786 |
+
selected_token_indices = metadata_dict.pop(
|
| 787 |
+
"selected_token_indices")
|
| 788 |
+
lora_mapping = metadata_dict.pop("lora_mapping")
|
| 789 |
+
lora_requests = metadata_dict.pop("lora_requests")
|
| 790 |
+
multi_modal_input = metadata_dict.pop("multi_modal_input")
|
| 791 |
+
num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
|
| 792 |
+
num_decode_tokens = metadata_dict.pop("num_decode_tokens")
|
| 793 |
+
batch_type = metadata_dict.pop("batch_type")
|
| 794 |
+
|
| 795 |
+
# Create an attention metadata.
|
| 796 |
+
prefill_attn_metadata = None
|
| 797 |
+
decode_attn_metadata = None
|
| 798 |
+
if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
|
| 799 |
+
prefill_attn_metadata = self.attn_backend.make_metadata(
|
| 800 |
+
**metadata_dict)
|
| 801 |
+
else:
|
| 802 |
+
decode_attn_metadata = self.attn_backend.make_metadata(
|
| 803 |
+
**metadata_dict)
|
| 804 |
+
sampling_metadata = SamplingMetadata(
|
| 805 |
+
seq_groups=None,
|
| 806 |
+
seq_data=None,
|
| 807 |
+
prompt_lens=None,
|
| 808 |
+
selected_token_indices=selected_token_indices,
|
| 809 |
+
categorized_sample_indices=None,
|
| 810 |
+
generators=None,
|
| 811 |
+
perform_sampling=False,
|
| 812 |
+
)
|
| 813 |
+
|
| 814 |
+
# if it is a mixed batch, decode attn_metadata is broadcasted
|
| 815 |
+
# separately.
|
| 816 |
+
if batch_type == BatchType.MIXED:
|
| 817 |
+
metadata_dict = broadcast_tensor_dict(src=0)
|
| 818 |
+
decode_attn_metadata = self.attn_backend.make_metadata(
|
| 819 |
+
**metadata_dict)
|
| 820 |
+
|
| 821 |
+
attn_metadata = AttentionMetadata(
|
| 822 |
+
num_prefills=num_prefills,
|
| 823 |
+
slot_mapping=slot_mapping,
|
| 824 |
+
num_prefill_tokens=num_prefill_tokens,
|
| 825 |
+
num_decode_tokens=num_decode_tokens,
|
| 826 |
+
prefill_metadata=prefill_attn_metadata,
|
| 827 |
+
decode_metadata=decode_attn_metadata,
|
| 828 |
+
kv_cache_dtype=self.kv_cache_dtype,
|
| 829 |
+
)
|
| 830 |
+
|
| 831 |
+
return (input_tokens, input_positions, attn_metadata,
|
| 832 |
+
sampling_metadata, lora_requests, lora_mapping,
|
| 833 |
+
multi_modal_input)
|
| 834 |
+
|
| 835 |
+
@torch.inference_mode()
|
| 836 |
+
def execute_model(
|
| 837 |
+
self,
|
| 838 |
+
seq_group_metadata_list: List[SequenceGroupMetadata],
|
| 839 |
+
kv_caches: List[torch.Tensor],
|
| 840 |
+
) -> Optional[SamplerOutput]:
|
| 841 |
+
(input_tokens, input_positions, attn_metadata, sampling_metadata,
|
| 842 |
+
lora_requests, lora_mapping, multi_modal_input
|
| 843 |
+
) = self.prepare_input_tensors(seq_group_metadata_list)
|
| 844 |
+
if self.lora_config:
|
| 845 |
+
self.set_active_loras(lora_requests, lora_mapping)
|
| 846 |
+
|
| 847 |
+
# Currently cuda graph is only supported by the decode phase.
|
| 848 |
+
prefill_meta = attn_metadata.prefill_metadata
|
| 849 |
+
decode_meta = attn_metadata.decode_metadata
|
| 850 |
+
if prefill_meta is None and decode_meta.use_cuda_graph:
|
| 851 |
+
graph_batch_size = input_tokens.shape[0]
|
| 852 |
+
model_executable = self.graph_runners[graph_batch_size]
|
| 853 |
+
else:
|
| 854 |
+
model_executable = self.model
|
| 855 |
+
execute_model_kwargs = {
|
| 856 |
+
"input_ids": input_tokens,
|
| 857 |
+
"positions": input_positions,
|
| 858 |
+
"kv_caches": kv_caches,
|
| 859 |
+
"attn_metadata": attn_metadata,
|
| 860 |
+
}
|
| 861 |
+
if self.vision_language_config:
|
| 862 |
+
execute_model_kwargs.update({"image_input": multi_modal_input})
|
| 863 |
+
hidden_states = model_executable(**execute_model_kwargs)
|
| 864 |
+
|
| 865 |
+
# Compute the logits.
|
| 866 |
+
logits = self.model.compute_logits(hidden_states, sampling_metadata)
|
| 867 |
+
|
| 868 |
+
# Only perform sampling in the driver worker.
|
| 869 |
+
if not sampling_metadata.perform_sampling:
|
| 870 |
+
return None
|
| 871 |
+
|
| 872 |
+
# Sample the next token.
|
| 873 |
+
output = self.model.sample(
|
| 874 |
+
logits=logits,
|
| 875 |
+
sampling_metadata=sampling_metadata,
|
| 876 |
+
)
|
| 877 |
+
return output
|
| 878 |
+
|
| 879 |
+
@torch.inference_mode()
|
| 880 |
+
def profile_run(self) -> None:
|
| 881 |
+
# Enable top-k sampling to reflect the accurate memory usage.
|
| 882 |
+
sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
|
| 883 |
+
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
|
| 884 |
+
max_num_seqs = self.scheduler_config.max_num_seqs
|
| 885 |
+
|
| 886 |
+
# This represents the maximum number of different requests
|
| 887 |
+
# that will have unique loras, and therefore the max amount of memory
|
| 888 |
+
# consumption. Create dummy lora request copies from the lora request
|
| 889 |
+
# passed in, which contains a lora from the lora warmup path.
|
| 890 |
+
dummy_lora_requests = []
|
| 891 |
+
dummy_lora_requests_per_seq = []
|
| 892 |
+
if self.lora_config:
|
| 893 |
+
for idx in range(self.lora_config.max_loras):
|
| 894 |
+
lora_id = idx + 1
|
| 895 |
+
dummy_lora_request = LoRARequest(
|
| 896 |
+
lora_name=f"warmup_{lora_id}",
|
| 897 |
+
lora_int_id=lora_id,
|
| 898 |
+
lora_local_path="/not/a/real/path",
|
| 899 |
+
)
|
| 900 |
+
self.lora_manager.add_dummy_lora(dummy_lora_request,
|
| 901 |
+
rank=LORA_WARMUP_RANK)
|
| 902 |
+
dummy_lora_requests.append(dummy_lora_request)
|
| 903 |
+
dummy_lora_requests_per_seq = [
|
| 904 |
+
dummy_lora_requests[idx % len(dummy_lora_requests)]
|
| 905 |
+
for idx in range(max_num_seqs)
|
| 906 |
+
]
|
| 907 |
+
|
| 908 |
+
# Profile memory usage with max_num_sequences sequences and the total
|
| 909 |
+
# number of tokens equal to max_num_batched_tokens.
|
| 910 |
+
seqs: List[SequenceGroupMetadata] = []
|
| 911 |
+
# Additional GPU memory may be needed for vision encoding, which needs
|
| 912 |
+
# to be accounted for when calculating the GPU blocks for
|
| 913 |
+
# the vLLM block manager.
|
| 914 |
+
# To exercise the worst scenario for GPU memory consumption,
|
| 915 |
+
# the number of seqs (batch_size) is chosen to maximize the number
|
| 916 |
+
# of images processed.
|
| 917 |
+
if self.vision_language_config:
|
| 918 |
+
max_num_seqs = min(
|
| 919 |
+
max_num_seqs,
|
| 920 |
+
int(max_num_batched_tokens /
|
| 921 |
+
self.vision_language_config.image_feature_size))
|
| 922 |
+
for group_id in range(max_num_seqs):
|
| 923 |
+
seq_len = (max_num_batched_tokens // max_num_seqs +
|
| 924 |
+
(group_id < max_num_batched_tokens % max_num_seqs))
|
| 925 |
+
seq_data, fake_multi_modal_input = _prepare_fake_inputs(
|
| 926 |
+
seq_len, self.vision_language_config)
|
| 927 |
+
seq = SequenceGroupMetadata(
|
| 928 |
+
request_id=str(group_id),
|
| 929 |
+
is_prompt=True,
|
| 930 |
+
seq_data={group_id: seq_data},
|
| 931 |
+
sampling_params=sampling_params,
|
| 932 |
+
block_tables=None,
|
| 933 |
+
lora_request=dummy_lora_requests_per_seq[group_id]
|
| 934 |
+
if dummy_lora_requests_per_seq else None,
|
| 935 |
+
multi_modal_data=fake_multi_modal_input,
|
| 936 |
+
)
|
| 937 |
+
seqs.append(seq)
|
| 938 |
+
|
| 939 |
+
# Run the model with the dummy inputs.
|
| 940 |
+
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
| 941 |
+
kv_caches = [None] * num_layers
|
| 942 |
+
self.execute_model(seqs, kv_caches)
|
| 943 |
+
torch.cuda.synchronize()
|
| 944 |
+
return
|
| 945 |
+
|
| 946 |
+
def remove_all_loras(self) -> bool:
|
| 947 |
+
if not self.lora_manager:
|
| 948 |
+
raise RuntimeError("LoRA is not enabled.")
|
| 949 |
+
return self.lora_manager.remove_all_loras()
|
| 950 |
+
|
| 951 |
+
def set_active_loras(self, lora_requests: Set[LoRARequest],
|
| 952 |
+
lora_mapping: LoRAMapping) -> None:
|
| 953 |
+
if not self.lora_manager:
|
| 954 |
+
raise RuntimeError("LoRA is not enabled.")
|
| 955 |
+
self.lora_manager.set_active_loras(lora_requests, lora_mapping)
|
| 956 |
+
|
| 957 |
+
def add_lora(self, lora_request: LoRARequest) -> bool:
|
| 958 |
+
if not self.lora_manager:
|
| 959 |
+
raise RuntimeError("LoRA is not enabled.")
|
| 960 |
+
return self.lora_manager.add_lora(lora_request)
|
| 961 |
+
|
| 962 |
+
def remove_lora(self, lora_id: int) -> bool:
|
| 963 |
+
if not self.lora_manager:
|
| 964 |
+
raise RuntimeError("LoRA is not enabled.")
|
| 965 |
+
return self.lora_manager.remove_lora(lora_id)
|
| 966 |
+
|
| 967 |
+
def list_loras(self) -> Set[int]:
|
| 968 |
+
if not self.lora_manager:
|
| 969 |
+
raise RuntimeError("LoRA is not enabled.")
|
| 970 |
+
return self.lora_manager.list_loras()
|
| 971 |
+
|
| 972 |
+
@torch.inference_mode()
|
| 973 |
+
def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
|
| 974 |
+
"""Cuda graph capture a model.
|
| 975 |
+
|
| 976 |
+
Note that CUDA graph's performance gain is negligible if the number
|
| 977 |
+
of batched tokens is larger than 200. And since CUDA graph
|
| 978 |
+
requires fixed sized tensors, supporting large/variable batch
|
| 979 |
+
size requires high GPU memory overhead. Thus, vLLM only captures
|
| 980 |
+
decoding requests. Mixed batch (chunked prefill + decoding) or
|
| 981 |
+
prefill requests are not captured.
|
| 982 |
+
|
| 983 |
+
Since it is used for decoding-only, it assumes there's only 1 token
|
| 984 |
+
per sequence in the batch.
|
| 985 |
+
"""
|
| 986 |
+
# NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
|
| 987 |
+
# deleted before the CUDA graphs.
|
| 988 |
+
self.pynccl_backend = pynccl_utils.get_nccl_backend()
|
| 989 |
+
|
| 990 |
+
assert not self.model_config.enforce_eager
|
| 991 |
+
logger.info("Capturing the model for CUDA graphs. This may lead to "
|
| 992 |
+
"unexpected consequences if the model is not static. To "
|
| 993 |
+
"run the model in eager mode, set 'enforce_eager=True' or "
|
| 994 |
+
"use '--enforce-eager' in the CLI.")
|
| 995 |
+
logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
|
| 996 |
+
"If you are running out of memory, consider decreasing "
|
| 997 |
+
"`gpu_memory_utilization` or enforcing eager mode. "
|
| 998 |
+
"You can also reduce the `max_num_seqs` as needed "
|
| 999 |
+
"to decrease memory usage.")
|
| 1000 |
+
start_time = time.perf_counter()
|
| 1001 |
+
|
| 1002 |
+
# Prepare dummy inputs. These will be reused for all batch sizes.
|
| 1003 |
+
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
|
| 1004 |
+
input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
|
| 1005 |
+
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
|
| 1006 |
+
slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
|
| 1007 |
+
slot_mapping.fill_(_PAD_SLOT_ID)
|
| 1008 |
+
context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
|
| 1009 |
+
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
|
| 1010 |
+
|
| 1011 |
+
graph_batch_size = _get_graph_batch_size(
|
| 1012 |
+
self.scheduler_config.max_num_seqs)
|
| 1013 |
+
batch_size_capture_list = [
|
| 1014 |
+
bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
|
| 1015 |
+
]
|
| 1016 |
+
|
| 1017 |
+
# NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
|
| 1018 |
+
# kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
|
| 1019 |
+
# either custom all-reduce kernel or pynccl. When not using CUDA
|
| 1020 |
+
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
|
| 1021 |
+
# We always prioritize using custom all-reduce kernel but fall back
|
| 1022 |
+
# to PyTorch or pynccl if it is disabled or not supported.
|
| 1023 |
+
with custom_all_reduce.capture():
|
| 1024 |
+
# NOTE: Capturing the largest batch size first may help reduce the
|
| 1025 |
+
# memory usage of CUDA graph.
|
| 1026 |
+
for batch_size in reversed(batch_size_capture_list):
|
| 1027 |
+
# Create dummy attn_metadata.
|
| 1028 |
+
decode_metadata = self.attn_backend.make_metadata(
|
| 1029 |
+
is_prompt=False,
|
| 1030 |
+
prompt_lens=None,
|
| 1031 |
+
prompt_lens_tensor=None,
|
| 1032 |
+
max_subquery_len=None,
|
| 1033 |
+
max_context_len=self.max_context_len_to_capture,
|
| 1034 |
+
max_prompt_len=None,
|
| 1035 |
+
subquery_start_loc=None,
|
| 1036 |
+
seq_start_loc=None,
|
| 1037 |
+
context_lens=context_lens[:batch_size],
|
| 1038 |
+
block_tables=block_tables[:batch_size],
|
| 1039 |
+
use_cuda_graph=True,
|
| 1040 |
+
)
|
| 1041 |
+
attn_metadata = AttentionMetadata(
|
| 1042 |
+
num_prefills=0,
|
| 1043 |
+
num_prefill_tokens=0,
|
| 1044 |
+
num_decode_tokens=batch_size,
|
| 1045 |
+
slot_mapping=slot_mapping[:batch_size],
|
| 1046 |
+
prefill_metadata=None,
|
| 1047 |
+
decode_metadata=decode_metadata,
|
| 1048 |
+
kv_cache_dtype=self.kv_cache_dtype,
|
| 1049 |
+
)
|
| 1050 |
+
|
| 1051 |
+
if self.lora_config:
|
| 1052 |
+
lora_mapping = LoRAMapping(
|
| 1053 |
+
[0] * batch_size,
|
| 1054 |
+
[0] * batch_size,
|
| 1055 |
+
)
|
| 1056 |
+
self.set_active_loras(set(), lora_mapping)
|
| 1057 |
+
|
| 1058 |
+
graph_runner = CUDAGraphRunner(self.model)
|
| 1059 |
+
graph_runner.capture(
|
| 1060 |
+
input_tokens[:batch_size],
|
| 1061 |
+
input_positions[:batch_size],
|
| 1062 |
+
kv_caches,
|
| 1063 |
+
attn_metadata,
|
| 1064 |
+
memory_pool=self.graph_memory_pool,
|
| 1065 |
+
)
|
| 1066 |
+
self.graph_memory_pool = graph_runner.graph.pool()
|
| 1067 |
+
self.graph_runners[batch_size] = graph_runner
|
| 1068 |
+
|
| 1069 |
+
end_time = time.perf_counter()
|
| 1070 |
+
elapsed_time = end_time - start_time
|
| 1071 |
+
# This usually takes < 10 seconds.
|
| 1072 |
+
logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
|
| 1073 |
+
|
| 1074 |
+
def __del__(self) -> None:
|
| 1075 |
+
# Delete the CUDA graphs before deleting the pynccl communicator.
|
| 1076 |
+
# NOTE(woosuk): This is necessary because otherwise deadlocks can
|
| 1077 |
+
# happen.
|
| 1078 |
+
# FIXME(woosuk): This is a bit hacky. Find a more robust solution.
|
| 1079 |
+
# TODO(youkaichao): when we get enough user feedback that pynccl is
|
| 1080 |
+
# more stable than cupy, we can remove this, e.g. in v0.4.1.
|
| 1081 |
+
self.graph_runners.clear()
|
| 1082 |
+
self.pynccl_backend = None
|
| 1083 |
+
|
| 1084 |
+
@property
|
| 1085 |
+
def vocab_size(self) -> int:
|
| 1086 |
+
return self.model_config.get_vocab_size()
|
| 1087 |
+
|
| 1088 |
+
|
| 1089 |
+
class CUDAGraphRunner:
|
| 1090 |
+
|
| 1091 |
+
def __init__(self, model: nn.Module):
|
| 1092 |
+
self.model = model
|
| 1093 |
+
self.input_buffers: Dict[str, torch.Tensor] = {}
|
| 1094 |
+
self.output_buffers: Dict[str, torch.Tensor] = {}
|
| 1095 |
+
|
| 1096 |
+
self._graph: Optional[torch.cuda.CUDAGraph] = None
|
| 1097 |
+
|
| 1098 |
+
@property
|
| 1099 |
+
def graph(self):
|
| 1100 |
+
assert self._graph is not None
|
| 1101 |
+
return self._graph
|
| 1102 |
+
|
| 1103 |
+
def capture(
|
| 1104 |
+
self,
|
| 1105 |
+
input_ids: torch.Tensor,
|
| 1106 |
+
positions: torch.Tensor,
|
| 1107 |
+
kv_caches: List[torch.Tensor],
|
| 1108 |
+
attn_metadata: AttentionMetadata,
|
| 1109 |
+
memory_pool,
|
| 1110 |
+
**kwargs,
|
| 1111 |
+
) -> None:
|
| 1112 |
+
assert self._graph is None
|
| 1113 |
+
# Run the model once without capturing the graph.
|
| 1114 |
+
# This is to make sure that the captured graph does not include the
|
| 1115 |
+
# kernel launches for initial benchmarking (e.g., Triton autotune).
|
| 1116 |
+
with _maybe_pynccl():
|
| 1117 |
+
self.model(
|
| 1118 |
+
input_ids,
|
| 1119 |
+
positions,
|
| 1120 |
+
kv_caches,
|
| 1121 |
+
attn_metadata,
|
| 1122 |
+
**kwargs,
|
| 1123 |
+
)
|
| 1124 |
+
torch.cuda.synchronize()
|
| 1125 |
+
|
| 1126 |
+
# Capture the graph.
|
| 1127 |
+
# NOTE(woosuk): Python 3.8 does not support multi-line with statements.
|
| 1128 |
+
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
|
| 1129 |
+
self._graph = torch.cuda.CUDAGraph()
|
| 1130 |
+
with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117
|
| 1131 |
+
with _maybe_pynccl():
|
| 1132 |
+
hidden_states = self.model(
|
| 1133 |
+
input_ids,
|
| 1134 |
+
positions,
|
| 1135 |
+
kv_caches,
|
| 1136 |
+
attn_metadata,
|
| 1137 |
+
**kwargs,
|
| 1138 |
+
)
|
| 1139 |
+
torch.cuda.synchronize()
|
| 1140 |
+
|
| 1141 |
+
# Save the input and output buffers.
|
| 1142 |
+
self.input_buffers = {
|
| 1143 |
+
"input_ids": input_ids,
|
| 1144 |
+
"positions": positions,
|
| 1145 |
+
"kv_caches": kv_caches,
|
| 1146 |
+
"slot_mapping": attn_metadata.slot_mapping,
|
| 1147 |
+
"context_lens": attn_metadata.decode_metadata.context_lens,
|
| 1148 |
+
"block_tables": attn_metadata.decode_metadata.block_tables,
|
| 1149 |
+
}
|
| 1150 |
+
self.output_buffers = {"hidden_states": hidden_states}
|
| 1151 |
+
return
|
| 1152 |
+
|
| 1153 |
+
def forward(
|
| 1154 |
+
self,
|
| 1155 |
+
input_ids: torch.Tensor,
|
| 1156 |
+
positions: torch.Tensor,
|
| 1157 |
+
kv_caches: List[torch.Tensor],
|
| 1158 |
+
attn_metadata: AttentionMetadata,
|
| 1159 |
+
**kwargs,
|
| 1160 |
+
) -> torch.Tensor:
|
| 1161 |
+
# KV caches are fixed tensors, so we don't need to copy them.
|
| 1162 |
+
del kv_caches
|
| 1163 |
+
|
| 1164 |
+
# Copy the input tensors to the input buffers.
|
| 1165 |
+
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
|
| 1166 |
+
self.input_buffers["positions"].copy_(positions, non_blocking=True)
|
| 1167 |
+
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
|
| 1168 |
+
non_blocking=True)
|
| 1169 |
+
self.input_buffers["context_lens"].copy_(
|
| 1170 |
+
attn_metadata.decode_metadata.context_lens, non_blocking=True)
|
| 1171 |
+
self.input_buffers["block_tables"].copy_(
|
| 1172 |
+
attn_metadata.decode_metadata.block_tables, non_blocking=True)
|
| 1173 |
+
# Run the graph.
|
| 1174 |
+
self.graph.replay()
|
| 1175 |
+
|
| 1176 |
+
# Return the output tensor.
|
| 1177 |
+
return self.output_buffers["hidden_states"]
|
| 1178 |
+
|
| 1179 |
+
def __call__(self, *args, **kwargs):
|
| 1180 |
+
return self.forward(*args, **kwargs)
|
| 1181 |
+
|
| 1182 |
+
|
| 1183 |
+
@contextlib.contextmanager
|
| 1184 |
+
def _maybe_pynccl():
|
| 1185 |
+
if pynccl_utils.is_initialized(
|
| 1186 |
+
) and not custom_all_reduce.is_initialized():
|
| 1187 |
+
with with_pynccl_for_all_reduce():
|
| 1188 |
+
yield
|
| 1189 |
+
else:
|
| 1190 |
+
yield
|
| 1191 |
+
|
| 1192 |
+
|
| 1193 |
+
def _get_graph_batch_size(batch_size: int) -> int:
|
| 1194 |
+
"""Returns the padded batch size given actual batch size.
|
| 1195 |
+
|
| 1196 |
+
Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
|
| 1197 |
+
2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
|
| 1198 |
+
"""
|
| 1199 |
+
if batch_size <= 2:
|
| 1200 |
+
return batch_size
|
| 1201 |
+
elif batch_size <= 4:
|
| 1202 |
+
return 4
|
| 1203 |
+
else:
|
| 1204 |
+
return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
|
| 1205 |
+
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
|
| 1206 |
+
|
| 1207 |
+
|
| 1208 |
+
def _prepare_fake_inputs(
|
| 1209 |
+
seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
|
| 1210 |
+
"""Prepare fake inputs for profile run."""
|
| 1211 |
+
if vision_language_config:
|
| 1212 |
+
prompt_tokens = [
|
| 1213 |
+
vision_language_config.image_token_id
|
| 1214 |
+
] * vision_language_config.image_feature_size + [0] * (
|
| 1215 |
+
seq_len - vision_language_config.image_feature_size)
|
| 1216 |
+
fake_image_input = MultiModalData(
|
| 1217 |
+
type=MultiModalData.Type.IMAGE,
|
| 1218 |
+
data=torch.zeros(vision_language_config.image_input_shape,
|
| 1219 |
+
dtype=torch.float16))
|
| 1220 |
+
else:
|
| 1221 |
+
prompt_tokens = [0] * seq_len
|
| 1222 |
+
fake_image_input = None
|
| 1223 |
+
return SequenceData(prompt_tokens), fake_image_input
|
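For reference, the batch-size padding that drives CUDA graph capture above can be exercised on its own. A minimal standalone sketch, assuming _BATCH_SIZE_ALIGNMENT = 8 (the real constant is defined elsewhere in serve/model_runner.py and does not appear in this hunk):

# Standalone sketch of the padding performed by _get_graph_batch_size.
_BATCH_SIZE_ALIGNMENT = 8  # assumption: the actual value lives outside this hunk

def pad_batch_size(batch_size: int) -> int:
    # Round an actual decode batch size up to the nearest captured graph size.
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1)
            // _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

if __name__ == "__main__":
    for bs in (1, 3, 5, 9, 17):
        print(bs, "->", pad_batch_size(bs))  # 1->1, 3->4, 5->8, 9->16, 17->24

Padding to a small fixed set of batch sizes keeps the number of captured CUDA graphs bounded while still covering every decode batch size seen at runtime.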
serve/sample_c2i.py
ADDED
|
@@ -0,0 +1,97 @@
| 1 |
+
import time
|
| 2 |
+
import argparse
|
| 3 |
+
import torch
|
| 4 |
+
from torchvision.utils import save_image
|
| 5 |
+
|
| 6 |
+
from tokenizer.tokenizer_image.vq_model import VQ_models
|
| 7 |
+
from serve.gpt_model import GPT_models
|
| 8 |
+
from serve.llm import LLM
|
| 9 |
+
from vllm import SamplingParams
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def main(args):
|
| 13 |
+
# Setup PyTorch:
|
| 14 |
+
torch.manual_seed(args.seed)
|
| 15 |
+
torch.backends.cudnn.deterministic = True
|
| 16 |
+
torch.backends.cudnn.benchmark = False
|
| 17 |
+
torch.set_grad_enabled(False)
|
| 18 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 19 |
+
|
| 20 |
+
# create and load model
|
| 21 |
+
vq_model = VQ_models[args.vq_model](
|
| 22 |
+
codebook_size=args.codebook_size,
|
| 23 |
+
codebook_embed_dim=args.codebook_embed_dim)
|
| 24 |
+
vq_model.to(device)
|
| 25 |
+
vq_model.eval()
|
| 26 |
+
checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
|
| 27 |
+
vq_model.load_state_dict(checkpoint["model"])
|
| 28 |
+
del checkpoint
|
| 29 |
+
print(f"image tokenizer is loaded")
|
| 30 |
+
|
| 31 |
+
# Labels to condition the model with (feel free to change):
|
| 32 |
+
class_labels = [207, 360, 387, 974, 88, 979, 417, 279]
|
| 33 |
+
latent_size = args.image_size // args.downsample_size
|
| 34 |
+
qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size]
|
| 35 |
+
prompt_token_ids = [[cind] for cind in class_labels]
|
| 36 |
+
if args.cfg_scale > 1.0:
|
| 37 |
+
prompt_token_ids.extend([[args.num_classes] for _ in range(len(prompt_token_ids))])
|
| 38 |
+
# Create an LLM.
|
| 39 |
+
llm = LLM(
|
| 40 |
+
args=args,
|
| 41 |
+
model='autoregressive/serve/fake_json/{}.json'.format(args.gpt_model),
|
| 42 |
+
gpu_memory_utilization=0.9,
|
| 43 |
+
skip_tokenizer_init=True)
|
| 44 |
+
print(f"gpt model is loaded")
|
| 45 |
+
|
| 46 |
+
# Create a sampling params object.
|
| 47 |
+
sampling_params = SamplingParams(
|
| 48 |
+
temperature=args.temperature, top_p=args.top_p, top_k=args.top_k,
|
| 49 |
+
max_tokens=latent_size ** 2)
|
| 50 |
+
|
| 51 |
+
# Generate image token ids from the prompts. The output is a list of RequestOutput objects
|
| 52 |
+
# that contain the prompt, generated token ids, and other information.
|
| 53 |
+
t1 = time.time()
|
| 54 |
+
outputs = llm.generate(
|
| 55 |
+
prompt_token_ids=prompt_token_ids,
|
| 56 |
+
sampling_params=sampling_params,
|
| 57 |
+
use_tqdm=False)
|
| 58 |
+
sampling_time = time.time() - t1
|
| 59 |
+
print(f"gpt sampling takes about {sampling_time:.2f} seconds.")
|
| 60 |
+
|
| 61 |
+
# decode to image
|
| 62 |
+
index_sample = torch.tensor([output.outputs[0].token_ids for output in outputs], device=device)
|
| 63 |
+
if args.cfg_scale > 1.0:
|
| 64 |
+
index_sample = index_sample[:len(class_labels)]
|
| 65 |
+
t2 = time.time()
|
| 66 |
+
samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
|
| 67 |
+
decoder_time = time.time() - t2
|
| 68 |
+
print(f"decoder takes about {decoder_time:.2f} seconds.")
|
| 69 |
+
|
| 70 |
+
# Save and display images:
|
| 71 |
+
save_image(samples, "sample_{}.png".format(args.gpt_type), nrow=4, normalize=True, value_range=(-1, 1))
|
| 72 |
+
print(f"image is saved to sample_{args.gpt_type}.png")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
if __name__ == '__main__':
|
| 76 |
+
parser = argparse.ArgumentParser()
|
| 77 |
+
parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B")
|
| 78 |
+
parser.add_argument("--gpt-ckpt", type=str, required=True, help="ckpt path for gpt model")
|
| 79 |
+
parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
|
| 80 |
+
parser.add_argument("--from-fsdp", action='store_true')
|
| 81 |
+
parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
|
| 82 |
+
parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
|
| 83 |
+
parser.add_argument("--compile", action='store_true', default=False)
|
| 84 |
+
parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
|
| 85 |
+
parser.add_argument("--vq-ckpt", type=str, required=True, help="ckpt path for vq model")
|
| 86 |
+
parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
|
| 87 |
+
parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
|
| 88 |
+
parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=384)
|
| 89 |
+
parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
|
| 90 |
+
parser.add_argument("--num-classes", type=int, default=1000)
|
| 91 |
+
parser.add_argument("--cfg-scale", type=float, default=4.0)
|
| 92 |
+
parser.add_argument("--seed", type=int, default=0)
|
| 93 |
+
parser.add_argument("--top-k", type=int, default=2000,help="top-k value to sample with")
|
| 94 |
+
parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
|
| 95 |
+
parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
|
| 96 |
+
args = parser.parse_args()
|
| 97 |
+
main(args)
|
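The classifier-free guidance wiring spans two of the added files: serve/sample_c2i.py above duplicates each prompt with the unconditional class id (args.num_classes) when cfg_scale > 1.0 and keeps only the first half of the sampled indices, while the custom Sampler in serve/sampler.py (next) splits each logits batch in half and recombines the halves before sampling. A minimal sketch of that combine step, assuming the conditional rows come first, as arranged above:

import torch

def cfg_combine(logits: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # First half of the batch: conditional logits; second half: unconditional.
    cond, uncond = torch.split(logits, logits.shape[0] // 2, dim=0)
    guided = uncond + (cond - uncond) * cfg_scale
    # Both halves continue decoding with the same guided logits,
    # mirroring the torch.cat in serve/sampler.py.
    return torch.cat([guided, guided], dim=0)

logits = torch.randn(4, 16384)  # 2 conditional + 2 unconditional rows, codebook size 16384
guided = cfg_combine(logits, cfg_scale=4.0)
assert guided.shape == logits.shape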
serve/sampler.py
ADDED
|
@@ -0,0 +1,868 @@
"""A layer that samples the next tokens from the model's outputs."""
import itertools
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

from vllm.model_executor.layers.ops.sample import sample as sample_triton
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                   SamplingTensors)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
                           SamplerOutput, SequenceData, SequenceGroupOutput,
                           SequenceOutput)


class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
        tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence, frequency and repetition penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.
    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).

    The structure of the logits tensor is coupled with the seq_groups in
    sampling_metadata. Typically, each sequence in each seq_group has one row in
    logits for the next token to be sampled; however, for a seq_group with a
    prompt request with the prompt_logprobs sampling parameter, there are rows
    in logits for each token in the input prompt.
    """

    def __init__(self, cfg_scale=1.0):
        super().__init__()
        self.cfg_scale = cfg_scale
        # Whether or not the SamplerOutput should have on-device tensors
        # containing the sampled token ids and probabilities. This is used by
        # speculative decoding.
        self.include_gpu_probs_tensor = False

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        assert logits is not None
        _, vocab_size = logits.shape

        if self.cfg_scale > 1.0:
            logits_combined = logits
            cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
            logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_scale
            logits = torch.cat([logits, logits], dim=0)

        # Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
        # have not been generated yet
        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Prepare sampling tensors with pinned memory to avoid blocking.
        (sampling_tensors, do_penalties, do_top_p_top_k,
         do_min_p) = SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype)

        # Apply presence and frequency penalties.
        if do_penalties:
            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
                                      sampling_tensors.output_tokens,
                                      sampling_tensors.presence_penalties,
                                      sampling_tensors.frequency_penalties,
                                      sampling_tensors.repetition_penalties)

        # Apply temperature scaling.
        # Use in-place division to avoid creating a new tensor.
        logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))

        if do_top_p_top_k:
            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                        sampling_tensors.top_ks)

        if do_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        # Use log_softmax to ensure numerical stability.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        sample_results, maybe_sampled_tokens_tensor = _sample(
            probs,
            logprobs,
            sampling_metadata,
            sampling_tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )


        if self.cfg_scale > 1.0:
            cond_result = sample_results[:len(sample_results) // 2]
            sample_results = cond_result + cond_result


        if self.include_gpu_probs_tensor:
            assert maybe_sampled_tokens_tensor is not None
            sampled_tokens_tensor = maybe_sampled_tokens_tensor
            on_device_tensors = (probs, sampled_tokens_tensor)
        else:
            on_device_tensors = None

        # Get the logprobs query results.
        prompt_logprobs, sample_logprobs = _get_logprobs(
            logprobs, sampling_metadata, sample_results)
        return _build_sampler_output(sample_results,
                                     sampling_metadata,
                                     prompt_logprobs,
                                     sample_logprobs,
                                     on_device_tensors=on_device_tensors)

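    # Note on classifier-free guidance (CFG): when cfg_scale > 1.0, forward()
    # assumes the batch is laid out as [conditional half, unconditional half]
    # along dim 0 (the caller is expected to schedule the paired sequences
    # accordingly). The guided logits are
    #     guided = uncond + cfg_scale * (cond - uncond),
    # which equals the conditional logits when cfg_scale == 1.0 and moves
    # further away from the unconditional distribution as cfg_scale grows.
    # The guided logits are duplicated so the tensor keeps its original batch
    # size, and after sampling the tokens drawn for the conditional half are
    # mirrored into the unconditional half so both halves of each pair stay
    # in sync.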
    @property
    def _should_modify_greedy_probs_inplace(self) -> bool:
        """Whether or not the sampler should modify the probability distribution
        of greedily-sampled tokens such that multinomial sampling would sample
        the greedily-sampled token.

        In other words, if True then we set the probability of the greedily-
        sampled token to 1.

        This is used by speculative decoding, which requires that the sampling
        method be encoded into the probability distribution.
        """
        # Modify greedy probs if include_gpu_probs_tensor is set.
        return self.include_gpu_probs_tensor


def _get_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Compute the bin counts for the tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=tokens.device)
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    bin_counts = bin_counts[:, :vocab_size]
    mask = bin_counts > 0

    return bin_counts, mask


def _apply_min_tokens_penalty(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    # list of indices in logits that will be set to -inf
    logits_to_penalize = []
    start_idx = 0
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        seq_ids, sampling_params = seq_group

        # handle prompt_logprobs by skipping rows in logits added for the prompt
        # tokens (prompt logprobs are not penalized)
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            assert len(seq_ids) == 1
            start_idx += sampling_metadata.prompt_lens[i] - 1

        min_tokens = sampling_params.min_tokens
        if min_tokens > 0:
            seqs_to_penalize = []
            for i, seq_id in enumerate(seq_ids):
                seq_data = sampling_metadata.seq_data[seq_id]
                if len(seq_data.output_token_ids) < min_tokens:
                    seqs_to_penalize.append(i)

            if seqs_to_penalize:
                # convert to the index into logits
                seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
                # use set() to remove any duplicates
                token_ids_to_penalize = set(sampling_params.stop_token_ids +
                                            [sampling_params.eos_token_id])
                # itertools.product pairs each seq index with every token id
                logits_to_penalize.extend(
                    itertools.product(seqs_to_penalize, token_ids_to_penalize))

        start_idx += len(seq_ids)

    if logits_to_penalize:
        # use zip and * to group indices along each dimension
        # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
        logits[tuple(zip(*logits_to_penalize))] = -float("inf")

    # verifies that no rows in logits were missed unexpectedly
    assert start_idx == logits.shape[0]
    return logits
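# Note on the penalty formulas used by _apply_penalties below (following the
# OpenAI API definitions): for every token that already appears in the prompt
# or output, the repetition penalty divides positive logits and multiplies
# negative logits by `repetition_penalty`, pushing repeated tokens down in
# both cases. The frequency penalty then subtracts
#     frequency_penalty * count(token in output)
# and the presence penalty subtracts a flat `presence_penalty` for any token
# that has appeared at least once in the output. For example, with
# frequency_penalty=0.5 and presence_penalty=1.0, a token generated 3 times so
# far has its logit reduced by 0.5 * 3 + 1.0 = 2.5 before sampling.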


def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                     output_tokens_tensor: torch.Tensor,
                     presence_penalties: torch.Tensor,
                     frequency_penalties: torch.Tensor,
                     repetition_penalties: torch.Tensor) -> torch.Tensor:
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
                                              num_seqs)
    output_bin_counts, output_mask = _get_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
    logits = torch.where(logits > 0, logits / repetition_penalties,
                         logits * repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
    return logits


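# Note on _apply_top_k_top_p below: logits are sorted in ascending order, so
# top-k keeps the last k entries of each row and top-p keeps the smallest
# suffix of the sorted probabilities whose cumulative mass reaches p (at
# least one token is always kept). For example, with sorted probabilities
# [0.05, 0.15, 0.30, 0.50] and p = 0.7, the cumulative sums are
# [0.05, 0.20, 0.50, 1.00]; entries with cumulative sum <= 1 - p = 0.3 are
# masked, leaving the two most likely tokens. The masked logits are finally
# scattered back to their original column order.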
def _apply_top_k_top_p(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # Apply top-k.
    top_k_mask = logits_sort.size(1) - k.to(torch.long)
    # Get all the top_k values.
    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
    top_k_mask = logits_sort < top_k_mask
    logits_sort.masked_fill_(top_k_mask, -float("inf"))

    # Apply top-p.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
    # at least one
    top_p_mask[:, -1] = False
    logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    src = torch.arange(logits_idx.shape[-1],
                       device=logits_idx.device).expand_as(logits_idx)
    logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
                                                           index=logits_idx,
                                                           src=src)
    logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
    return logits


def _apply_min_p(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    """
    Adapted from
    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
    """
    probs = torch.softmax(logits, dim=-1)
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
    tokens_to_remove = probs < scaled_min_p
    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))

    return logits


def _greedy_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    samples = samples.tolist()
    sample_idx = 0
    results = []
    for seq_group in selected_seq_groups:
        seq_ids, _ = seq_group
        num_parent_seqs = len(seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")
        parent_ids = list(range(num_parent_seqs))
        next_token_ids = [samples[sample_idx]]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _random_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    random_samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # Find the maximum best_of value of the prompt phase requests.
    random_samples = random_samples.cpu()
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase.
            parent_ids = [0] * sampling_params.best_of
            next_token_ids = random_samples[
                sample_idx, :sampling_params.best_of].tolist()
        else:
            # Generation phase.
            parent_ids = list(range(num_parent_seqs))
            next_token_ids = random_samples[sample_idx:sample_idx +
                                            num_parent_seqs, 0].tolist()
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _beam_search_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    seq_data: Dict[int, SequenceData],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # NOTE: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        beam_width = sampling_params.best_of
        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * (2 * beam_width)
            _, next_token_ids = torch.topk(seq_group_logprobs[0],
                                           2 * beam_width)
            next_token_ids = next_token_ids.tolist()
        else:
            # Generation phase.
            cumulative_logprobs = [
                seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
            ]
            cumulative_logprobs = torch.tensor(
                cumulative_logprobs,
                dtype=torch.float,
                device=seq_group_logprobs.device)
            seq_group_logprobs = (seq_group_logprobs +
                                  cumulative_logprobs.unsqueeze(dim=1))
            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                     2 * beam_width)
            topk_ids = topk_ids.tolist()
            vocab_size = seq_group_logprobs.size(-1)
            parent_ids = [i // vocab_size for i in topk_ids]
            next_token_ids = [i % vocab_size for i in topk_ids]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results


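# Note on the sampling trick used in _multinomial below: `q` is filled with
# i.i.d. Exponential(1) noise and the sampled index is argmax(probs / q),
# which is equivalent to argmin(q / probs). Since q_i / p_i is exponentially
# distributed with rate p_i, the minimum lands on index i with probability
# p_i / sum_j p_j, so this reproduces multinomial sampling without the
# GPU<->CPU synchronization that torch.multinomial would trigger.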
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
    seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
    generators: Optional[List[torch.Generator]] = None,
) -> torch.Tensor:
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleave (which also
        # forces a GPU<->CPU sync).
        # This allows us to do sampling with replacement by creating
        # num_samples copies of each row in the tensor, and then
        # batch sampling the resulting tensor.
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])
    q = torch.empty_like(probs)
    if seq_groups is None:
        q.exponential_()
    else:
        sample_idx = 0
        for (seq_ids, _), generator in zip(seq_groups, generators):
            next_sample_idx = sample_idx + len(seq_ids) * num_samples
            q[sample_idx:next_sample_idx].exponential_(generator=generator)
            sample_idx = next_sample_idx
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)


def _sample_with_torch(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    include_gpu_probs_tensor: bool,
    modify_greedy_probs: bool,
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        _, sampling_params = seq_group
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    multinomial_samples = {}

    # Create output tensor for sampled token ids.
    if include_gpu_probs_tensor:
        sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
                                               1,
                                               dtype=torch.long,
                                               device=logprobs.device)
    else:
        sampled_token_ids_tensor = None

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue
        seq_group_ids = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
        sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                          is_prompts, sample_indices)
        long_sample_indices = sample_indices.long()

        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                          dim=-1)

            if include_gpu_probs_tensor:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = greedy_samples.unsqueeze(-1)

            if modify_greedy_probs:
                # If required, modify the probabilities such that sampling from
                # the modified distribution would always sample the argmax
                # token id.
                _modify_greedy_probs_inplace(logprobs, probs,
                                             long_sample_indices,
                                             greedy_samples)

        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            max_best_of_in_batch = 1
            for seq_group, is_prompt in zip(seq_groups, is_prompts):
                if is_prompt:
                    _, sampling_params = seq_group
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
                "seq_groups": seq_groups,
                "generators": sampling_metadata.generators,
            }

            multinomial_samples[sampling_type] = _multinomial(
                probs[long_sample_indices], max_best_of_in_batch,
                **seeded_args)

            if include_gpu_probs_tensor:
                # Store sampled tokens in output tensor.
                sampled_token_ids_tensor[
                    long_sample_indices] = multinomial_samples[sampling_type]

        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # GPU<->CPU sync happens in the loop below.
    # This also converts the sample output to Python objects.

    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[
            sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(seq_groups, greedy_samples)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(seq_groups, is_prompts,
                                            multinomial_samples[sampling_type])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                 sampling_metadata.seq_data,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_ids, sample_results))

    sample_results = [
        sample_results_dict[i]
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results, sampled_token_ids_tensor


def _sample_with_triton_kernel(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
) -> List[Tuple[List[int], List[int]]]:
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        _, sampling_params = seq_group
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    max_best_of_in_batch = 1

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue
        seq_group_ids = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
        sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                          is_prompts, sample_indices,
                                          sampled_token_indices)
        if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
                             SamplingType.RANDOM_SEED):
            for seq_group, is_prompt in zip(seq_groups, is_prompts):
                if is_prompt:
                    _, sampling_params = seq_group
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
        elif sampling_type == SamplingType.BEAM:
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    sampled_tokens, _, _ = sample_triton(
        probs=probs,
        seeds=sampling_tensors.sampling_seeds,
        max_best_of=max_best_of_in_batch,
        sample_indices=sampling_tensors.sample_indices,
        logprobs=logprobs,
        # don't save logprobs because we have logic for that below
        # TODO: use this instead of the CPU-based logic below
        save_logprobs=False,
    )

    # GPU<->CPU sync happens in the loop below.

    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        (seq_group_ids, seq_groups, is_prompts, sample_indices,
         sampled_token_indices) = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(
                seq_groups, sampled_tokens[sampled_token_indices][:, 0])
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(
                seq_groups, is_prompts, sampled_tokens[sampled_token_indices])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                 sampling_metadata.seq_data,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_ids, sample_results))

    sample_results = [
        sample_results_dict[i]
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results


def _sample(
    probs: torch.Tensor, logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
    include_gpu_probs_tensor: bool, modify_greedy_probs: bool
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
    return _sample_with_torch(
        probs,
        logprobs,
        sampling_metadata,
        include_gpu_probs_tensor=include_gpu_probs_tensor,
        modify_greedy_probs=modify_greedy_probs,
    )

    # TODO: Enable once Triton kernel & associated code is faster.
    # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
    #                                   sampling_tensors)


def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """
    This function calculates the ranks of the chosen tokens in a logprob tensor.

    Args:
        x (torch.Tensor): 2D logprob tensor of shape (N, M)
                        where N is the no. of tokens and M is the vocab dim.
        indices (torch.Tensor): List of chosen token indices.

    Returns:
        torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
                    Each element in the returned tensor represents the rank
                    of the chosen token in the input logprob tensor.
    """
    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
             indices]
    return (x > vals[:, None]).long().sum(1).add_(1)


def _get_logprobs(
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sample_results: List[Tuple[List[int], List[int]]],
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
        int, float]]]]:
    # Prepare query indices
    batched_logprobs_query_seq_indices: List[int] = []
    batched_logprobs_query_token_indices: List[int] = []
    # at least get one logprob for each token
    largest_num_logprobs = 1
    sample_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(sampling_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result
        num_parent_seqs = len(seq_ids)
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.prompt_logprobs)
            prompt_len = sampling_metadata.prompt_lens[i]
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            batched_logprobs_query_seq_indices.extend(
                sample_idx + j for j in range(prompt_len - 1))
            batched_logprobs_query_token_indices.extend(
                token_id for token_id in prompt_tokens[1:])
            sample_idx += prompt_len - 1
        batched_logprobs_query_seq_indices.extend(
            [sample_idx + parent_id for parent_id in parent_ids])
        batched_logprobs_query_token_indices.extend(next_token_ids)
        if sampling_params.logprobs is not None:
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.logprobs)
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)

    batched_logprobs_query_seq_indices_gpu = torch.tensor(
        batched_logprobs_query_seq_indices, device=logprobs.device)
    batched_logprobs_query_token_indices_gpu = torch.tensor(
        batched_logprobs_query_token_indices, device=logprobs.device)

    # Batched query for logprobs of selected token
    batched_logprobs_query_result = logprobs[[
        batched_logprobs_query_seq_indices_gpu,
        batched_logprobs_query_token_indices_gpu
    ]]

    batched_ranks_query_result = _get_ranks(
        logprobs[batched_logprobs_query_seq_indices_gpu],
        batched_logprobs_query_token_indices_gpu)

    # Batched query for logprobs of topk tokens
    if largest_num_logprobs > 0:
        top_logprobs, top_token_ids = torch.topk(logprobs,
                                                 largest_num_logprobs,
                                                 dim=-1)
        top_logprobs = top_logprobs.cpu()
        top_token_ids = top_token_ids.cpu()
    else:
        top_logprobs, top_token_ids = None, None

    batched_logprobs_query_result = batched_logprobs_query_result.cpu()
    batched_ranks_query_result = batched_ranks_query_result.cpu()

    # Gather results
    result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
    result_sample_logprobs: List[SampleLogprobs] = []
    sample_idx = 0
    query_result_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(sampling_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result

        # Prompt logprobs
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            num_logprobs = sampling_params.prompt_logprobs
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            group_prompt_logprobs: PromptLogprobs = [None]
            for token_id in prompt_tokens[1:]:
                prompt_logprobs_dict = {
                    token_id:
                    (batched_logprobs_query_result[query_result_idx].item(),
                     batched_ranks_query_result[query_result_idx].item())
                }
                if num_logprobs > 0:
                    prompt_logprobs_dict.update(
                        zip(
                            top_token_ids[sample_idx, :num_logprobs].tolist(),
                            zip(
                                top_logprobs[
                                    sample_idx, :num_logprobs].tolist(),
                                range(1, num_logprobs + 1))))
                group_prompt_logprobs.append({
                    token_id: Logprob(*logprob_rank)
                    for token_id, logprob_rank in prompt_logprobs_dict.items()
                })
                sample_idx += 1
                query_result_idx += 1
            result_prompt_logprobs.append(group_prompt_logprobs)
        else:
            result_prompt_logprobs.append(None)

        # Sample logprobs
        num_logprobs = sampling_params.logprobs
        if num_logprobs is None:
            num_logprobs = 0
        group_sample_logprobs: SampleLogprobs = []
        for next_token_id, parent_id in zip(next_token_ids, parent_ids):
            sample_logprobs_dict = {
                next_token_id:
                (batched_logprobs_query_result[query_result_idx].item(),
                 batched_ranks_query_result[query_result_idx].item())
            }
            query_result_idx += 1
            if num_logprobs >= 0:
                sample_logprobs_dict.update(
                    zip(
                        top_token_ids[sample_idx +
                                      parent_id, :num_logprobs].tolist(),
                        zip(
                            top_logprobs[sample_idx +
                                         parent_id, :num_logprobs].tolist(),
                            range(1, num_logprobs + 1))))
            group_sample_logprobs.append({
                token_id: Logprob(*logprob_rank)
                for token_id, logprob_rank in sample_logprobs_dict.items()
            })
        result_sample_logprobs.append(group_sample_logprobs)
        sample_idx += len(seq_ids)

    return result_prompt_logprobs, result_sample_logprobs


def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
                                 sample_indices: torch.Tensor,
                                 greedy_samples: torch.Tensor) -> None:
    """Modify the probability distributions of the greedily-sampled tokens such
    that each sampled token has a "probability" of 1.0. This is required by
    speculative decoding, which depends on the sampling method being encoded
    within the probability distribution for correctness.

    # Why do we only need to do this for greedy sampling?

    vLLM's sampler performs the following steps for greedy or multinomial
    (random) sampling:
        1. Get logits from model.
        2. Modify logits according to per-sequence sampling parameters.
            - Multiply by temperature, top-k and top-p masking, penalize tokens
                according to their frequency, etc.
        3. Sample a token.
            - Random sampling simply samples from the modified probability
                distribution.
            - Greedy sampling performs `argmax` to obtain the token with the
                highest likelihood.

    Ignoring greedy sampling for a moment, we find that the computed probability
    distribution has the following property: we can sample from it independently
    and find that the token sampled by the Sampler has a frequency corresponding
    to how often we see it in our sampling. In other words, for tokens sampled
    with vLLM's random SamplingType, the computed probability distribution
    encodes the sampling methodology completely.

    Greedy sampling does not normally have this property. vLLM modifies logits
    according to sampling params, then performs `argmax`, then returns the
    sampled token and the computed probability distribution. If we sample from
    the distribution, we'll find the likelihood of the greedily-sampled token
    is not always 1.0.

    Since lossless speculative decoding requires that the sampling methodology
    be encoded within the probability distribution, we are motivated to modify
    the probability distribution such that the sampled token has probability 1
    when speculative decoding is used.

    NOTE: Alternatively, we could use an extremely low temperature to achieve
    greedy sampling using multinomial computation and unite the codepaths. This
    has implications on the overall design of the sampler, e.g. how to record
    accurate logprobs for the user, so this improvement is deferred to later.
    """
    logprobs[sample_indices, :] = -float('inf')
    logprobs[sample_indices, greedy_samples] = 0.0
    probs[sample_indices, :] = 0
    probs[sample_indices, greedy_samples] = 1.0


def _build_sampler_output(
    sample_results: List[Tuple[List[int], List[int]]],
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: List[Optional[PromptLogprobs]],
    sample_logprobs: List[SampleLogprobs],
    on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]],
) -> SamplerOutput:
    """Construct Python objects with the output of sampling.

    Args:
        on_device_tensors: Tuple containing on-device tensors with the
            probabilities used in sampling and the sampled token ids. This
            allows post-processing without copies to CPU/serialization, e.g. in
            speculative decoding rejection sampling.
    """

    sampler_output = []
    for (seq_group, sample_result, group_prompt_logprobs,
         group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                       sample_results, prompt_logprobs,
                                       sample_logprobs):
        seq_ids, _ = seq_group
        next_token_ids, parent_ids = sample_result
        seq_outputs = []
        for parent_id, next_token_id, logprobs in zip(parent_ids,
                                                      next_token_ids,
                                                      group_sample_logprobs):
            seq_outputs.append(
                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
        sampler_output.append(
            SequenceGroupOutput(seq_outputs, group_prompt_logprobs))

    # If not specified, store None values in SamplerOutput.
    if on_device_tensors is not None:
        sampled_token_probs, sampled_token_ids = on_device_tensors
    else:
        sampled_token_probs, sampled_token_ids = (None, None)

    return SamplerOutput(
        outputs=sampler_output,
        sampled_token_probs=sampled_token_probs,
        sampled_token_ids=sampled_token_ids,
    )
serve/worker.py
ADDED
@@ -0,0 +1,349 @@
"""A GPU worker class."""
import gc
import os
from typing import Any, Dict, List, Optional, Set, Tuple

import torch
import torch.distributed

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
                              ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.distributed.device_communicators import pynccl_utils
from vllm.distributed.device_communicators.custom_all_reduce import (
    init_custom_ar)
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
# from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker_base import WorkerBase
from serve.model_runner import ModelRunner


class Worker(WorkerBase):
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        vision_language_config: Optional[VisionLanguageConfig] = None,
        is_driver_worker: bool = False,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()
        self.vision_language_config = vision_language_config
        if self.vision_language_config:
            assert not self.lora_config, (
                "To be tested: vision language model with LoRA settings.")

        self.model_runner = ModelRunner(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            load_config=load_config,
            lora_config=self.lora_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
            vision_language_config=vision_language_config,
        )
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: CacheEngine
        self.gpu_cache: List[torch.Tensor]

    def init_device(self) -> None:
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            torch.cuda.empty_cache()
            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
        else:
            raise RuntimeError(
                f"Unsupported device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.parallel_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self, args):
        self.model_runner.load_model(args)

    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        peak_memory = self.init_gpu_memory - free_gpu_memory
        assert peak_memory > 0, (
            "Error in memory profiling. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        cache_block_size = self.get_cache_block_size_bytes()
        num_gpu_blocks = int(
            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
             peak_memory) // cache_block_size)
        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                             cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        if self.model_runner.lora_manager:
            self.model_runner.remove_all_loras()
        gc.collect()
        torch.cuda.empty_cache()
        return num_gpu_blocks, num_cpu_blocks

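    # Worked example for the block arithmetic above (illustrative numbers
    # only): with 80 GiB of total GPU memory, gpu_memory_utilization = 0.9,
    # a profiled peak usage of 20 GiB and a cache block size of 2 MiB,
    # num_gpu_blocks = (80 * 0.9 - 20) GiB // 2 MiB = 52 GiB // 2 MiB
    # = 26624 blocks, i.e. 26624 * block_size tokens of KV-cache capacity
    # shared by all sequences.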
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Allocate GPU and CPU KV cache with the specified number of blocks.

        This also warms up the model, which may record CUDA graphs.
        """
        raise_if_cache_size_invalid(num_gpu_blocks,
                                    self.cache_config.block_size,
                                    self.model_config.max_model_len)

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self._init_cache_engine()
        self._warm_up_model()

    def _init_cache_engine(self):
        assert self.cache_config.num_gpu_blocks is not None
        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config)
        self.gpu_cache = self.cache_engine.gpu_cache
        self.model_runner.set_block_size(self.cache_engine.block_size)

    def _warm_up_model(self) -> None:
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model(self.gpu_cache)
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def cache_swap(
        self,
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        # Issue cache operations.
        # TODO(woosuk): Profile swapping overhead and optimize if needed.
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
        blocks_to_swap_in: Optional[Dict[int, int]] = None,
        blocks_to_swap_out: Optional[Dict[int, int]] = None,
        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
        num_lookahead_slots: int = 0,
    ) -> List[SamplerOutput]:

        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            num_seq_groups = len(seq_group_metadata_list)
            assert blocks_to_swap_in is not None
            assert blocks_to_swap_out is not None
            assert blocks_to_copy is not None
            data: Dict[str, Any] = {
                "num_seq_groups": num_seq_groups,
                "blocks_to_swap_in": blocks_to_swap_in,
                "blocks_to_swap_out": blocks_to_swap_out,
                "blocks_to_copy": blocks_to_copy,
            }
            broadcast_tensor_dict(data, src=0)
        else:
            data = broadcast_tensor_dict(src=0)
            num_seq_groups = data["num_seq_groups"]
            blocks_to_swap_in = data["blocks_to_swap_in"]
            blocks_to_swap_out = data["blocks_to_swap_out"]
            blocks_to_copy = data["blocks_to_copy"]

        assert blocks_to_swap_in is not None
        assert blocks_to_swap_out is not None
        assert blocks_to_copy is not None
        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            return []

        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.gpu_cache)

        # Worker only supports single-step execution. Wrap the output in a list
        # to conform to interface.
        return [output]

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.model_runner.list_loras()

    @property
    def max_model_len(self) -> int:
        return self.model_config.max_model_len

    @property
    def vocab_size(self) -> int:
        return self.model_runner.vocab_size

    def get_cache_block_size_bytes(self) -> int:
        """Get the size of the KV cache block size in bytes.
        """
        return CacheEngine.get_cache_block_size(self.cache_config,
                                                self.model_config,
                                                self.parallel_config)


def init_worker_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Initialize the distributed environment."""
    init_distributed_environment(parallel_config.world_size, rank,
                                 distributed_init_method, local_rank)

    if pynccl_utils.is_initialized():
        pynccl_world_size = pynccl_utils.get_world_size()
        if pynccl_world_size != parallel_config.world_size:
            raise RuntimeError(
                "pynccl is already initialized but the pynccl world "
                "size does not match parallel_config.world_size "
                f"({pynccl_world_size} vs. {parallel_config.world_size}).")
    elif parallel_config.world_size > 1:
        # NOTE(woosuk): We don't initialize pynccl process group when world size
        # is 1.
        pynccl_utils.init_process_group(
            world_size=parallel_config.world_size,
            local_rank=local_rank,
            rank=rank,
            init_method=distributed_init_method,
        )

    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)

    # Initialize a custom fast all-reduce implementation.
    if not parallel_config.disable_custom_all_reduce:
        init_custom_ar()

    # A small all_reduce for warmup.
    torch.distributed.all_reduce(torch.zeros(1).cuda())
    if pynccl_utils.is_initialized():
        pynccl_utils.all_reduce(torch.zeros(1).cuda())


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:
        compute_capability = torch.cuda.get_device_capability()
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                f"{compute_capability[0]}.{compute_capability[1]}. "
                "You can use float16 instead by explicitly setting the "
                "`dtype` flag in CLI, for example: --dtype=half.")


def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
                                max_model_len) -> None:
    if num_gpu_blocks <= 0:
        raise ValueError("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization` when "
                         "initializing the engine.")
    max_seq_len = block_size * num_gpu_blocks
    if max_model_len > max_seq_len:
        raise ValueError(
            f"The model's max seq len ({max_model_len}) "
            "is larger than the maximum number of tokens that can be "
            f"stored in KV cache ({max_seq_len}). Try increasing "
            "`gpu_memory_utilization` or decreasing `max_model_len` when "
            "initializing the engine.")