import os
import time
import random
import functools
from typing import List, Optional, Tuple, Union
from pathlib import Path

from loguru import logger
import torch
import torch.distributed as dist

from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE
from hyvideo.vae import load_vae
from hyvideo.modules import load_model
from hyvideo.text_encoder import TextEncoder
from hyvideo.utils.data_utils import align_to
from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed
from hyvideo.modules.fp8_optimization import convert_fp8_linear
from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
from hyvideo.diffusion.pipelines import HunyuanVideoPipeline

try:
    import xfuser
    from xfuser.core.distributed import (
        get_sequence_parallel_world_size,
        get_sequence_parallel_rank,
        get_sp_group,
        initialize_model_parallel,
        init_distributed_environment,
    )
except ImportError:
    # xfuser is optional; sequence-parallel (Ulysses/Ring) inference is disabled without it.
    xfuser = None
    get_sequence_parallel_world_size = None
    get_sequence_parallel_rank = None
    get_sp_group = None
    initialize_model_parallel = None
    init_distributed_environment = None


def parallelize_transformer(pipe):
    transformer = pipe.transformer
    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,  # Should be in range(0, 1000).
        text_states: torch.Tensor = None,
        text_mask: torch.Tensor = None,  # Now we don't use it.
        text_states_2: Optional[torch.Tensor] = None,  # Text embedding for modulation.
        freqs_cos: Optional[torch.Tensor] = None,
        freqs_sin: Optional[torch.Tensor] = None,
        guidance: torch.Tensor = None,  # Guidance for modulation, should be cfg_scale x 1000.
        return_dict: bool = True,
    ):
        # Split the latent sequence across sequence-parallel ranks along height or width.
        if x.shape[-2] // 2 % get_sequence_parallel_world_size() == 0:
            # try to split x by height
            split_dim = -2
        elif x.shape[-1] // 2 % get_sequence_parallel_world_size() == 0:
            # try to split x by width
            split_dim = -1
        else:
            raise ValueError(
                f"Cannot split video sequence into ulysses_degree x ring_degree "
                f"({get_sequence_parallel_world_size()}) parts evenly"
            )

        # Patch sizes for the temporal, height, and width dimensions are 1, 2, and 2.
        temporal_size, h, w = x.shape[2], x.shape[3] // 2, x.shape[4] // 2

        x = torch.chunk(x, get_sequence_parallel_world_size(), dim=split_dim)[get_sequence_parallel_rank()]

        # Shard the rotary embeddings so they match the local slice of the sequence.
        dim_thw = freqs_cos.shape[-1]
        freqs_cos = freqs_cos.reshape(temporal_size, h, w, dim_thw)
        freqs_cos = torch.chunk(freqs_cos, get_sequence_parallel_world_size(), dim=split_dim - 1)[get_sequence_parallel_rank()]
        freqs_cos = freqs_cos.reshape(-1, dim_thw)
        dim_thw = freqs_sin.shape[-1]
        freqs_sin = freqs_sin.reshape(temporal_size, h, w, dim_thw)
        freqs_sin = torch.chunk(freqs_sin, get_sequence_parallel_world_size(), dim=split_dim - 1)[get_sequence_parallel_rank()]
        freqs_sin = freqs_sin.reshape(-1, dim_thw)

        from xfuser.core.long_ctx_attention import xFuserLongContextAttention

        for block in transformer.double_blocks + transformer.single_blocks:
            block.hybrid_seq_parallel_attn = xFuserLongContextAttention()

        output = original_forward(
            x,
            t,
            text_states,
            text_mask,
            text_states_2,
            freqs_cos,
            freqs_sin,
            guidance,
            return_dict,
        )

        return_dict = not isinstance(output, tuple)
        sample = output["x"]
        # Gather the per-rank outputs back into the full latent sequence.
        sample = get_sp_group().all_gather(sample, dim=split_dim)
        output["x"] = sample
        return output

    new_forward = new_forward.__get__(transformer)
    transformer.forward = new_forward
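
# Note: parallelize_transformer() is applied in HunyuanVideoSampler.__init__ when
# ulysses_degree > 1 or ring_degree > 1. It rebinds `new_forward` onto the
# transformer instance via __get__, so each rank runs the denoiser on a shard of
# the latent sequence and the shards are reassembled with all_gather afterwards.
# A typical multi-GPU launch (hedged: the script name and flag spellings below are
# assumptions about the surrounding repository) might look like:
#   torchrun --nproc_per_node=8 sample_video.py --ulysses-degree 8 --ring-degree 1 ...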


class Inference(object):
    def __init__(
        self,
        args,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        use_cpu_offload=False,
        device=None,
        logger=None,
        parallel_args=None,
    ):
        self.vae = vae
        self.vae_kwargs = vae_kwargs
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2
        self.model = model
        self.pipeline = pipeline
        self.use_cpu_offload = use_cpu_offload
        self.args = args
        self.device = (
            device
            if device is not None
            else "cuda"
            if torch.cuda.is_available()
            else "cpu"
        )
        self.logger = logger
        self.parallel_args = parallel_args

    @classmethod
    def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs):
        """
        Initialize the Inference pipeline.

        Args:
            pretrained_model_path (str or pathlib.Path): The model path, including t2v, text encoder and vae checkpoints.
            args (argparse.Namespace): The arguments for the pipeline.
            device (str or torch.device, optional): The device for inference. Defaults to None,
                in which case CUDA is used when available.
        """
        # ========================================================================
        logger.info(f"Got text-to-video model root path: {pretrained_model_path}")

        # ==================== Initialize Distributed Environment ================
        if args.ulysses_degree > 1 or args.ring_degree > 1:
            assert xfuser is not None, \
                "Ulysses Attention and Ring Attention requires xfuser package."
            assert args.use_cpu_offload is False, \
                "Cannot enable use_cpu_offload in the distributed environment."
            dist.init_process_group("nccl")
            assert dist.get_world_size() == args.ring_degree * args.ulysses_degree, \
                "number of GPUs should be equal to ring_degree * ulysses_degree."
            init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
            initialize_model_parallel(
                sequence_parallel_degree=dist.get_world_size(),
                ring_degree=args.ring_degree,
                ulysses_degree=args.ulysses_degree,
            )
            device = torch.device(f"cuda:{os.environ['LOCAL_RANK']}")
        else:
            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"

        parallel_args = {"ulysses_degree": args.ulysses_degree, "ring_degree": args.ring_degree}

        # ======================== Get the args path =============================
        # Disable gradient
        torch.set_grad_enabled(False)

        # =========================== Build main model ===========================
        logger.info("Building model...")
        factor_kwargs = {"device": device, "dtype": PRECISION_TO_TYPE[args.precision]}
        in_channels = args.latent_channels
        out_channels = args.latent_channels

        model = load_model(
            args,
            in_channels=in_channels,
            out_channels=out_channels,
            factor_kwargs=factor_kwargs,
        )
        if args.use_fp8:
            convert_fp8_linear(model, args.dit_weight, original_dtype=PRECISION_TO_TYPE[args.precision])
        model = model.to(device)
        model = Inference.load_state_dict(args, model, pretrained_model_path)
        model.eval()

        # ============================= Build extra models ========================
        # VAE
        vae, _, s_ratio, t_ratio = load_vae(
            args.vae,
            args.vae_precision,
            logger=logger,
            device=device if not args.use_cpu_offload else "cpu",
        )
        vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}

        # Text encoder
        if args.prompt_template_video is not None:
            crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get(
                "crop_start", 0
            )
        elif args.prompt_template is not None:
            crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
        else:
            crop_start = 0
        max_length = args.text_len + crop_start

        # prompt_template
        prompt_template = (
            PROMPT_TEMPLATE[args.prompt_template]
            if args.prompt_template is not None
            else None
        )

        # prompt_template_video
        prompt_template_video = (
            PROMPT_TEMPLATE[args.prompt_template_video]
            if args.prompt_template_video is not None
            else None
        )

        text_encoder = TextEncoder(
            text_encoder_type=args.text_encoder,
            max_length=max_length,
            text_encoder_precision=args.text_encoder_precision,
            tokenizer_type=args.tokenizer,
            prompt_template=prompt_template,
            prompt_template_video=prompt_template_video,
            hidden_state_skip_layer=args.hidden_state_skip_layer,
            apply_final_norm=args.apply_final_norm,
            reproduce=args.reproduce,
            logger=logger,
            device=device if not args.use_cpu_offload else "cpu",
        )
        text_encoder_2 = None
        if args.text_encoder_2 is not None:
            text_encoder_2 = TextEncoder(
                text_encoder_type=args.text_encoder_2,
                max_length=args.text_len_2,
                text_encoder_precision=args.text_encoder_precision_2,
                tokenizer_type=args.tokenizer_2,
                reproduce=args.reproduce,
                logger=logger,
                device=device if not args.use_cpu_offload else "cpu",
            )

        return cls(
            args=args,
            vae=vae,
            vae_kwargs=vae_kwargs,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            model=model,
            use_cpu_offload=args.use_cpu_offload,
            device=device,
            logger=logger,
            parallel_args=parallel_args,
        )
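
    # Sketch of typical construction (hedged: `parse_args` and the flag names are
    # assumptions about the surrounding repository, not defined in this module):
    #
    #   args = parse_args()                       # hypothetical CLI/config namespace
    #   sampler = HunyuanVideoSampler.from_pretrained(
    #       args.model_base, args=args, device="cuda"
    #   )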

    @staticmethod
    def load_state_dict(args, model, pretrained_model_path):
        load_key = args.load_key
        dit_weight = Path(args.dit_weight)

        if dit_weight is None:
            model_dir = pretrained_model_path / f"t2v_{args.model_resolution}"
            files = list(model_dir.glob("*.pt"))
            if len(files) == 0:
                raise ValueError(f"No model weights found in {model_dir}")
            if str(files[0]).startswith("pytorch_model_"):
                model_path = dit_weight / f"pytorch_model_{load_key}.pt"
                bare_model = True
            elif any(str(f).endswith("_model_states.pt") for f in files):
                files = [f for f in files if str(f).endswith("_model_states.pt")]
                model_path = files[0]
                if len(files) > 1:
                    logger.warning(
                        f"Multiple model weights found in {dit_weight}, using {model_path}"
                    )
                bare_model = False
            else:
                raise ValueError(
                    f"Invalid model path: {dit_weight} with unrecognized weight format: "
                    f"{list(map(str, files))}. When given a directory as --dit-weight, only "
                    f"`pytorch_model_*.pt` (provided by HunyuanDiT official) and "
                    f"`*_model_states.pt` (saved by DeepSpeed) can be parsed. If you want to load a "
                    f"specific weight file, please provide the full path to the file."
                )
        else:
            if dit_weight.is_dir():
                files = list(dit_weight.glob("*.pt"))
                if len(files) == 0:
                    raise ValueError(f"No model weights found in {dit_weight}")
                if str(files[0]).startswith("pytorch_model_"):
                    model_path = dit_weight / f"pytorch_model_{load_key}.pt"
                    bare_model = True
                elif any(str(f).endswith("_model_states.pt") for f in files):
                    files = [f for f in files if str(f).endswith("_model_states.pt")]
                    model_path = files[0]
                    if len(files) > 1:
                        logger.warning(
                            f"Multiple model weights found in {dit_weight}, using {model_path}"
                        )
                    bare_model = False
                else:
                    raise ValueError(
                        f"Invalid model path: {dit_weight} with unrecognized weight format: "
                        f"{list(map(str, files))}. When given a directory as --dit-weight, only "
                        f"`pytorch_model_*.pt` (provided by HunyuanDiT official) and "
                        f"`*_model_states.pt` (saved by DeepSpeed) can be parsed. If you want to load a "
                        f"specific weight file, please provide the full path to the file."
                    )
            elif dit_weight.is_file():
                model_path = dit_weight
                bare_model = "unknown"
            else:
                raise ValueError(f"Invalid model path: {dit_weight}")

        if not model_path.exists():
            raise ValueError(f"model_path does not exist: {model_path}")
        logger.info(f"Loading torch model {model_path}...")
        state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)

        if bare_model == "unknown" and ("ema" in state_dict or "module" in state_dict):
            bare_model = False
        if bare_model is False:
            if load_key in state_dict:
                state_dict = state_dict[load_key]
            else:
                raise KeyError(
                    f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
                    f"are: {list(state_dict.keys())}."
                )
        model.load_state_dict(state_dict, strict=True)
        return model
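
    # Accepted checkpoint layouts for --dit-weight, as handled above:
    #   * a directory containing `pytorch_model_<load_key>.pt` (bare state_dict), or
    #   * a directory containing `*_model_states.pt` (DeepSpeed-style, weights nested
    #     under `load_key`, e.g. "module" or "ema"), or
    #   * a direct path to a single weight file.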

    @staticmethod
    def parse_size(size):
        if isinstance(size, int):
            size = [size]
        if not isinstance(size, (list, tuple)):
            raise ValueError(f"Size must be an integer or (height, width), got {size}.")
        if len(size) == 1:
            size = [size[0], size[0]]
        if len(size) != 2:
            raise ValueError(f"Size must be an integer or (height, width), got {size}.")
        return size
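
    # Examples of the accepted inputs:
    #   parse_size(720)          -> [720, 720]
    #   parse_size((720, 1280))  -> (720, 1280)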


class HunyuanVideoSampler(Inference):
    def __init__(
        self,
        args,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        use_cpu_offload=False,
        device=0,
        logger=None,
        parallel_args=None,
    ):
        super().__init__(
            args,
            vae,
            vae_kwargs,
            text_encoder,
            model,
            text_encoder_2=text_encoder_2,
            pipeline=pipeline,
            use_cpu_offload=use_cpu_offload,
            device=device,
            logger=logger,
            parallel_args=parallel_args,
        )

        self.pipeline = self.load_diffusion_pipeline(
            args=args,
            vae=self.vae,
            text_encoder=self.text_encoder,
            text_encoder_2=self.text_encoder_2,
            model=self.model,
            device=self.device,
        )

        self.default_negative_prompt = NEGATIVE_PROMPT

        if self.parallel_args["ulysses_degree"] > 1 or self.parallel_args["ring_degree"] > 1:
            parallelize_transformer(self.pipeline)

    def load_diffusion_pipeline(
        self,
        args,
        vae,
        text_encoder,
        text_encoder_2,
        model,
        scheduler=None,
        device=None,
        progress_bar_config=None,
        data_type="video",
    ):
        """Build the denoising scheduler and the HunyuanVideoPipeline for inference."""
        if scheduler is None:
            if args.denoise_type == "flow":
                scheduler = FlowMatchDiscreteScheduler(
                    shift=args.flow_shift,
                    reverse=args.flow_reverse,
                    solver=args.flow_solver,
                )
            else:
                raise ValueError(f"Invalid denoise type {args.denoise_type}")

        pipeline = HunyuanVideoPipeline(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            transformer=model,
            scheduler=scheduler,
            progress_bar_config=progress_bar_config,
            args=args,
        )
        if self.use_cpu_offload:
            pipeline.enable_sequential_cpu_offload()
        else:
            pipeline = pipeline.to(device)
        return pipeline

    def get_rotary_pos_embed(self, video_length, height, width):
        target_ndim = 3
        ndim = 5 - 2  # B, C, F, H, W -> F, H, W

        # Compute the latent grid size from the VAE compression ratio
        # ("884": 4x temporal and 8x8 spatial; "888": 8x temporal and 8x8 spatial).
        if "884" in self.args.vae:
            latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        elif "888" in self.args.vae:
            latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
        else:
            latents_size = [video_length, height // 8, width // 8]

        if isinstance(self.model.patch_size, int):
            assert all(s % self.model.patch_size == 0 for s in latents_size), (
                f"Latent size (last {ndim} dimensions) should be divisible by patch size ({self.model.patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [s // self.model.patch_size for s in latents_size]
        elif isinstance(self.model.patch_size, list):
            assert all(
                s % self.model.patch_size[idx] == 0
                for idx, s in enumerate(latents_size)
            ), (
                f"Latent size (last {ndim} dimensions) should be divisible by patch size ({self.model.patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [
                s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)
            ]

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis

        head_dim = self.model.hidden_size // self.model.heads_num
        rope_dim_list = self.model.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert (
            sum(rope_dim_list) == head_dim
        ), "sum(rope_dim_list) should equal the head_dim of the attention layer"

        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list,
            rope_sizes,
            theta=self.args.rope_theta,
            use_real=True,
            theta_rescale_factor=1,
        )
        return freqs_cos, freqs_sin
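
    # Worked example, assuming the [1, 2, 2] (temporal, height, width) patch size
    # noted in parallelize_transformer(): with an "884" VAE (4x temporal, 8x8 spatial
    # compression), video_length=129, height=720, width=1280 yields latents of size
    # [33, 90, 160], RoPE grid sizes [33, 45, 80], and n_tokens = 33 * 45 * 80 = 118800.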

    def predict(
        self,
        prompt,
        height=192,
        width=336,
        video_length=129,
        seed=None,
        negative_prompt=None,
        infer_steps=50,
        guidance_scale=6,
        flow_shift=5.0,
        embedded_guidance_scale=None,
        batch_size=1,
        num_videos_per_prompt=1,
        **kwargs,
    ):
        """
        Predict the image/video from the given text.

        Args:
            prompt (str or List[str]): The input text.
            kwargs:
                height (int): The height of the output video. Default is 192.
                width (int): The width of the output video. Default is 336.
                video_length (int): The frame number of the output video. Default is 129.
                seed (int or List[int]): The random seed for the generation. Default is a random integer.
                negative_prompt (str or List[str]): The negative text prompt. Default is an empty string.
                guidance_scale (float): The guidance scale for the generation. Default is 6.0.
                num_videos_per_prompt (int): The number of videos per prompt. Default is 1.
                infer_steps (int): The number of inference steps. Default is 50.
        """
        out_dict = dict()

        # ========================================================================
        # Arguments: seed
        # ========================================================================
        if isinstance(seed, torch.Tensor):
            seed = seed.tolist()
        if seed is None:
            seeds = [
                random.randint(0, 1_000_000)
                for _ in range(batch_size * num_videos_per_prompt)
            ]
        elif isinstance(seed, int):
            seeds = [
                seed + i
                for _ in range(batch_size)
                for i in range(num_videos_per_prompt)
            ]
        elif isinstance(seed, (list, tuple)):
            if len(seed) == batch_size:
                seeds = [
                    int(seed[i]) + j
                    for i in range(batch_size)
                    for j in range(num_videos_per_prompt)
                ]
            elif len(seed) == batch_size * num_videos_per_prompt:
                seeds = [int(s) for s in seed]
            else:
                raise ValueError(
                    f"Length of seed must equal the number of prompts (batch_size) or "
                    f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
                )
        else:
            raise ValueError(
                f"Seed must be an integer, a list of integers, or None, got {seed}."
            )
        generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
        out_dict["seeds"] = seeds

        # ========================================================================
        # Arguments: target_width, target_height, target_video_length
        # ========================================================================
        if width <= 0 or height <= 0 or video_length <= 0:
            raise ValueError(
                f"`height`, `width`, and `video_length` must be positive integers, got height={height}, width={width}, video_length={video_length}"
            )
        if (video_length - 1) % 4 != 0:
            raise ValueError(
                f"`video_length - 1` must be a multiple of 4, got {video_length}"
            )
        logger.info(
            f"Input (height, width, video_length) = ({height}, {width}, {video_length})"
        )
        target_height = align_to(height, 16)
        target_width = align_to(width, 16)
        target_video_length = video_length
        out_dict["size"] = (target_height, target_width, target_video_length)

        # ========================================================================
        # Arguments: prompt, new_prompt, negative_prompt
        # ========================================================================
        if not isinstance(prompt, str):
            raise TypeError(f"`prompt` must be a string, but got {type(prompt)}")
        prompt = [prompt.strip()]

        # negative prompt
        if negative_prompt is None or negative_prompt == "":
            negative_prompt = self.default_negative_prompt
        if not isinstance(negative_prompt, str):
            raise TypeError(
                f"`negative_prompt` must be a string, but got {type(negative_prompt)}"
            )
        negative_prompt = [negative_prompt.strip()]

        # ========================================================================
        # Scheduler
        # ========================================================================
        scheduler = FlowMatchDiscreteScheduler(
            shift=flow_shift,
            reverse=self.args.flow_reverse,
            solver=self.args.flow_solver,
        )
        self.pipeline.scheduler = scheduler

        # ========================================================================
        # Build RoPE freqs
        # ========================================================================
        freqs_cos, freqs_sin = self.get_rotary_pos_embed(
            target_video_length, target_height, target_width
        )
        n_tokens = freqs_cos.shape[0]

        # ========================================================================
        # Print infer args
        # ========================================================================
        debug_str = f"""
                        height: {target_height}
                         width: {target_width}
                  video_length: {target_video_length}
                        prompt: {prompt}
                    neg_prompt: {negative_prompt}
                          seed: {seed}
                   infer_steps: {infer_steps}
         num_videos_per_prompt: {num_videos_per_prompt}
                guidance_scale: {guidance_scale}
                      n_tokens: {n_tokens}
                    flow_shift: {flow_shift}
       embedded_guidance_scale: {embedded_guidance_scale}"""
        logger.debug(debug_str)

        # ========================================================================
        # Pipeline inference
        # ========================================================================
        start_time = time.time()
        samples = self.pipeline(
            prompt=prompt,
            height=target_height,
            width=target_width,
            video_length=target_video_length,
            num_inference_steps=infer_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            generator=generator,
            output_type="pil",
            freqs_cis=(freqs_cos, freqs_sin),
            n_tokens=n_tokens,
            embedded_guidance_scale=embedded_guidance_scale,
            data_type="video" if target_video_length > 1 else "image",
            is_progress_bar=True,
            vae_ver=self.args.vae,
            enable_tiling=self.args.vae_tiling,
        )[0]
        out_dict["samples"] = samples
        out_dict["prompts"] = prompt

        gen_time = time.time() - start_time
        logger.info(f"Success, time: {gen_time}")
        return out_dict
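

# Minimal usage sketch (not executed on import). The `args` namespace below is
# hypothetical: it stands for the argparse.Namespace the surrounding repository
# builds with the model, VAE, text-encoder, and flow-matching flags referenced in
# from_pretrained()/predict() above.
#
#   sampler = HunyuanVideoSampler.from_pretrained("ckpts", args=args, device="cuda")
#   outputs = sampler.predict(
#       prompt="A cat walks on the grass, realistic style.",
#       height=720,
#       width=1280,
#       video_length=129,          # (video_length - 1) must be a multiple of 4
#       seed=42,
#       infer_steps=50,
#       guidance_scale=6.0,
#       flow_shift=7.0,
#       embedded_guidance_scale=6.0,
#   )
#   videos = outputs["samples"]    # generated samples; outputs also holds "seeds" and "size"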