import copy
import json
import logging
import os
import sys
import warnings
from typing import Optional, Tuple, Union, List, Callable

import torch
import torch.distributed as dist
from torch import nn

from accelerate import dispatch_model, infer_auto_device_map
from accelerate.hooks import (
    AlignDevicesHook,
    add_hook_to_module,
    remove_hook_from_submodules,
)
from accelerate.utils import get_balanced_memory
from huggingface_hub import hf_hub_download
from packaging import version

from transformers import LlamaForCausalLM
from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers.generation.beam_search import BeamSearchScorer
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import (
    LogitsProcessorList,
    StoppingCriteriaList,
    GenerationConfig,
    GenerationMixin,
)

import peft
from peft import PeftModel, PeftModelForCausalLM, LoraConfig
from peft.utils import PeftType, set_peft_model_state_dict


def printf(*args, **kwargs):
    # Debug-only print: active when the DEBUG environment variable is set.
    if os.environ.get('DEBUG', False):
        end = kwargs.get('end', '\n')
        print(*args, end=end, flush=True)
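

# Example: `DEBUG=1 python your_script.py` (script name illustrative) enables
# the printf diagnostics above; with DEBUG unset, printf is a no-op.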


class ColorFormatter(logging.Formatter):

    grey = "\x1b[38;20m"
    blue = "\x1b[34;20m"
    yellow = "\x1b[33;20m"
    red = "\x1b[31;20m"
    bold_red = "\x1b[31;1m"
    reset = "\x1b[0m"

    def __init__(self, fmt):
        super().__init__(fmt)
        self.FORMATS = {
            logging.DEBUG: self.grey + fmt + self.reset,
            logging.INFO: self.blue + fmt + self.reset,
            logging.WARNING: self.yellow + fmt + self.reset,
            logging.ERROR: self.red + fmt + self.reset,
            logging.CRITICAL: self.bold_red + fmt + self.reset,
        }

    def format(self, record):
        log_fmt = self.FORMATS.get(record.levelno)
        formatter = logging.Formatter(log_fmt)
        return formatter.format(record)


def set_console_logger(name):
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    consoleHandler = logging.StreamHandler(sys.stdout)
    consoleHandler.setLevel(logging.INFO)
    consoleHandler.setFormatter(ColorFormatter("%(asctime)s | %(levelname)s %(message)s"))
    logger.addHandler(consoleHandler)
    return logger


def set_file_logger(name, dir, use_console=False):
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    os.makedirs(dir, exist_ok=True)

    if use_console:
        # Stop records from propagating to the root logger, which would
        # otherwise print them twice.
        logger.propagate = False
        consoleHandler = logging.StreamHandler(sys.stdout)
        consoleHandler.setLevel(logging.INFO)
        consoleHandler.setFormatter(ColorFormatter("%(asctime)s | %(levelname)s %(message)s"))
        logger.addHandler(consoleHandler)

    fileHandler = logging.FileHandler(os.path.join(dir, 'session.log'), mode='a')
    fileHandler.setLevel(logging.INFO)
    fileHandler.setFormatter(logging.Formatter("%(asctime)s | %(levelname)s %(message)s"))
    logger.addHandler(fileHandler)
    return logger
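

# A minimal usage sketch for the logging helpers above ("outputs/run" is only
# an illustration; nothing is created until the function is called):
#
#   logger = set_file_logger("train", "outputs/run", use_console=True)
#   logger.info("colored on stdout, plain text in outputs/run/session.log")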


def to_jsonl(data, path):
    # Append one JSON object per line (JSON Lines format).
    with open(path, 'a') as f:
        for line in data:
            f.write(json.dumps(line, ensure_ascii=False) + '\n')


def from_json(path):
    with open(path) as f:
        return json.load(f)


def from_jsonl(path):
    with open(path) as f:
        return [json.loads(line) for line in f]


def to_json(data, path):
    with open(path, 'w') as f:
        json.dump(data, f, ensure_ascii=False)
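

# Round-trip sketch for the JSON helpers above (the file name is illustrative):
#
#   to_jsonl([{"q": "hi"}, {"q": "bye"}], "records.jsonl")
#   assert from_jsonl("records.jsonl")[-1] == {"q": "bye"}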


class StreamGenerationMixin(GenerationMixin):
    # Streaming counterparts of `GenerationMixin.generate`: every `stream_*`
    # method below is a generator that yields the running `input_ids` after
    # each decoding step.

    @torch.no_grad()
    def stream_generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[
            Callable[[int, torch.Tensor], List[int]]
        ] = None,
        **kwargs,
    ):
        # Under DeepSpeed ZeRO-3 with more than one process, all ranks must
        # keep issuing forward passes until every rank has finished.
        if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
            synced_gpus = True
        else:
            synced_gpus = False
if kwargs.get("attention_mask", None) is not None: |
|
|
|
prefix_attention_mask = torch.ones( |
|
kwargs["input_ids"].shape[0], self.peft_config.num_virtual_tokens |
|
).to(kwargs["input_ids"].device) |
|
kwargs["attention_mask"] = torch.cat( |
|
(prefix_attention_mask, kwargs["attention_mask"]), dim=1 |
|
) |
|
if kwargs.get("position_ids", None) is not None: |
|
warnings.warn( |
|
"Position ids are not supported for parameter efficient tuning. Ignoring position ids." |
|
) |
|
kwargs["position_ids"] = None |
|
if kwargs.get("token_type_ids", None) is not None: |
|
warnings.warn( |
|
"Token type ids are not supported for parameter efficient tuning. Ignoring token type ids" |
|
) |
|
kwargs["token_type_ids"] = None |
|
|
|

        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]

        if generation_config is None:
            generation_config = self.generation_config
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)

        bos_token_id, eos_token_id, pad_token_id = (
            generation_config.bos_token_id,
            generation_config.eos_token_id,
            generation_config.pad_token_id,
        )

        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]

        has_default_max_length = (
            kwargs.get("max_length") is None
            and generation_config.max_length is not None
        )
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = (
                generation_config.max_new_tokens + input_ids_seq_length
            )
            if generation_config.min_new_tokens is not None:
                generation_config.min_length = (
                    generation_config.min_new_tokens + input_ids_seq_length
                )
        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = (
                "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            )
            warnings.warn(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )
        logits_processor = (
            logits_processor if logits_processor is not None else LogitsProcessorList()
        )
        stopping_criteria = (
            stopping_criteria
            if stopping_criteria is not None
            else StoppingCriteriaList()
        )

        # Decide which decoding mode the config asks for, mirroring
        # `GenerationMixin.generate`.
        is_constraint_gen_mode = (
            generation_config.constraints is not None or generation_config.force_words_ids is not None
        )

        is_contrastive_search_gen_mode = (
            generation_config.top_k is not None
            and generation_config.top_k > 1
            and generation_config.do_sample is False
            and generation_config.penalty_alpha is not None
            and generation_config.penalty_alpha > 0
        )

        is_greedy_gen_mode = (
            (generation_config.num_beams == 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is False
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )

        is_sample_gen_mode = (
            (generation_config.num_beams == 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is True
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
        is_beam_gen_mode = (
            (generation_config.num_beams > 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is False
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
        is_beam_sample_gen_mode = (
            (generation_config.num_beams > 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is True
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
        is_group_beam_gen_mode = (
            (generation_config.num_beams > 1)
            and (generation_config.num_beam_groups > 1)
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )

        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )

        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        logits_warper = self._get_logits_warper(generation_config)
        if is_greedy_gen_mode:
            return self.stream_greedy_search(
                input_ids,
                logits_processor,
                stopping_criteria,
                generation_config,
                synced_gpus,
                **model_kwargs,
            )
        elif is_sample_gen_mode:
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
            return self.stream_sample(
                generation_config,
                input_ids,
                logits_processor,
                logits_warper,
                stopping_criteria,
                synced_gpus,
                **model_kwargs,
            )
        elif is_beam_gen_mode:
            return self.stream_beam_search(
                generation_config,
                input_ids,
                logits_processor,
                stopping_criteria,
                synced_gpus,
                **model_kwargs,
            )
        elif is_beam_sample_gen_mode:
            return self.stream_beam_sample(
                input_ids,
                logits_processor,
                logits_warper,
                stopping_criteria,
                generation_config,
                synced_gpus,
                **model_kwargs,
            )
        else:
            raise NotImplementedError(
                "Streaming is not implemented for group beam search, "
                "constrained decoding, or contrastive search."
            )
    def stream_sample(
        self,
        generation_config,
        input_ids,
        logits_processor,
        logits_warper,
        stopping_criteria,
        synced_gpus,
        **model_kwargs,
    ):
        bos_token_id, eos_token_id, pad_token_id = (
            generation_config.bos_token_id,
            generation_config.eos_token_id,
            generation_config.pad_token_id,
        )
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        this_peer_finished = False
        scores = ()

        while True:
            if synced_gpus:
                # Sum the per-rank "still running" flags and stop only once
                # every rank has finished.
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
                **model_inputs,
                return_dict=True,
            )
            if synced_gpus and this_peer_finished:
                continue  # this forward pass only kept the ranks in sync
            next_token_logits = outputs.logits[:, -1, :]

            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # Sample the next token from the warped distribution.
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # Finished sequences keep emitting the pad token.
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            yield input_ids

            # Mark sequences that just produced an eos token as finished.
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                )

            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True
        yield input_ids
    def stream_beam_sample(
        self,
        input_ids,
        logits_processor,
        logits_warper,
        stopping_criteria,
        generation_config,
        synced_gpus,
        **model_kwargs,
    ):
        bos_token_id, eos_token_id, pad_token_id = (
            generation_config.bos_token_id,
            generation_config.eos_token_id,
            generation_config.pad_token_id,
        )
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        num_beams = generation_config.num_beams
        batch_size, cur_len = input_ids.shape[0], input_ids.shape[-1]
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=generation_config.num_beams,
            device=input_ids.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=generation_config.num_return_sequences,
            max_length=generation_config.max_length,
        )
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_beams * generation_config.num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        scores = ()
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False
        while True:
            if synced_gpus:
                # See stream_sample: all ranks loop until everyone is done.
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # this forward pass only kept the ranks in sync

            next_token_logits = outputs.logits[:, -1, :]

            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)

            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

            next_token_scores = logits_warper(input_ids, next_token_scores)

            # Reshape so each batch item scores all of its beams' candidates at once.
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

            probs = nn.functional.softmax(next_token_scores, dim=-1)

            # Sample 2 * num_beams candidates so the scorer always has enough
            # non-eos continuations to fill every beam.
            next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
            next_token_scores = torch.gather(next_token_scores, -1, next_tokens)

            next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
            next_tokens = torch.gather(next_tokens, -1, _indices)

            # Recover the beam index and the in-vocabulary token id.
            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size

            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=None,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
            yield input_ids
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past_key_values"] is not None:
                model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)

            cur_len = cur_len + 1

            if beam_scorer.is_done or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=None,
        )
        yield sequence_outputs["sequences"]
    def stream_greedy_search(
        self,
        input_ids,
        logits_processor,
        stopping_criteria,
        generation_config,
        synced_gpus,
        **model_kwargs,
    ):
        bos_token_id, eos_token_id, pad_token_id = (
            generation_config.bos_token_id,
            generation_config.eos_token_id,
            generation_config.pad_token_id,
        )
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

        scores = ()

        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        this_peer_finished = False
        while True:
            if synced_gpus:
                # See stream_sample: all ranks loop until everyone is done.
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            outputs = self(
                **model_inputs,
                return_dict=True,
            )

            if synced_gpus and this_peer_finished:
                continue  # this forward pass only kept the ranks in sync

            next_token_logits = outputs.logits[:, -1, :]

            next_tokens_scores = logits_processor(input_ids, next_token_logits)

            # Greedy decoding: always take the highest-scoring token.
            next_tokens = torch.argmax(next_tokens_scores, dim=-1)

            # Finished sequences keep emitting the pad token.
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            yield input_ids

            # Mark sequences that just produced an eos token as finished.
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                )

            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True
        yield input_ids
    def stream_beam_search(
        self,
        generation_config,
        input_ids,
        logits_processor,
        stopping_criteria,
        synced_gpus,
        **model_kwargs,
    ):
        bos_token_id, eos_token_id, pad_token_id = (
            generation_config.bos_token_id,
            generation_config.eos_token_id,
            generation_config.pad_token_id,
        )
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        num_beams = generation_config.num_beams
        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=generation_config.num_beams,
            device=input_ids.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=generation_config.num_return_sequences,
            max_length=generation_config.max_length,
        )

        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_beams,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )

        batch_beam_size, cur_len = input_ids.shape
        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )
        beam_scores = torch.zeros(
            (batch_size, num_beams), dtype=torch.float, device=input_ids.device
        )
        # Only the first beam starts "alive"; otherwise every beam would decode
        # the same tokens in the first step.
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))
        this_peer_finished = False
        while True:
            if synced_gpus:
                # See stream_sample: all ranks loop until everyone is done.
                this_peer_finished_flag = torch.tensor(
                    0.0 if this_peer_finished else 1.0
                ).to(input_ids.device)
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                if this_peer_finished_flag.item() == 0.0:
                    break

            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            if synced_gpus and this_peer_finished:
                cur_len = cur_len + 1
                continue  # this forward pass only kept the ranks in sync

            next_token_logits = outputs.logits[:, -1, :]

            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)
            next_token_scores_processed = logits_processor(input_ids, next_token_scores)
            next_token_scores = next_token_scores_processed + beam_scores[
                :, None
            ].expand_as(next_token_scores)

            # Reshape so each batch item ranks all of its beams' candidates at once.
            vocab_size = next_token_scores.shape[-1]
            next_token_scores = next_token_scores.view(
                batch_size, num_beams * vocab_size
            )

            # Keep 2 * num_beams candidates so the scorer always has enough
            # non-eos continuations to fill every beam.
            next_token_scores, next_tokens = torch.topk(
                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
            )
            # Recover the beam index and the in-vocabulary token id.
            next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
            next_tokens = next_tokens % vocab_size

            beam_outputs = beam_scorer.process(
                input_ids,
                next_token_scores,
                next_tokens,
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=None,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            input_ids = torch.cat(
                [input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1
            )
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if model_kwargs["past_key_values"] is not None:
                model_kwargs["past_key_values"] = self._reorder_cache(
                    model_kwargs["past_key_values"], beam_idx
                )

            cur_len = cur_len + 1

            yield input_ids

            if beam_scorer.is_done or stopping_criteria(input_ids, None):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        final_result = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=None,
        )
        yield final_result["sequences"]


class StreamLlamaForCausalLM(LlamaForCausalLM, StreamGenerationMixin):
    pass
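
# The subclass body is empty on purpose: LlamaForCausalLM supplies the weights
# and forward pass, while StreamGenerationMixin contributes the stream_*
# generator methods via the MRO.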


class StreamPeftGenerationMixin(PeftModelForCausalLM, StreamGenerationMixin):

    @classmethod
    def from_pretrained(cls, model, model_id, adapter_name="default", is_trainable=False, **kwargs):
        # peft >= 0.3.0 exposes PromptLearningConfig and the adapter_name API;
        # older releases (and 0.3.0.dev0, which `packaging` orders below 0.3.0)
        # go through the legacy loader. `version.parse` is used because plain
        # string comparison misorders versions such as "0.10.0" and "0.3.0".
        if version.parse(peft.__version__) >= version.parse("0.3.0"):
            from peft.utils import PromptLearningConfig

            config = LoraConfig.from_pretrained(model_id)

            if (getattr(model, "hf_device_map", None) is not None) and len(
                set(model.hf_device_map.values()).intersection({"cpu", "disk"})
            ) > 0:
                remove_hook_from_submodules(model)

            if isinstance(config, PromptLearningConfig) and is_trainable:
                raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
            else:
                config.inference_mode = not is_trainable

            model = cls(model, config, adapter_name)
            model.load_adapter(model_id, adapter_name, **kwargs)

            # Re-expose the base model's generation helpers on the wrapper so
            # the stream_* methods can call them directly.
            model.base_model_prepare_inputs_for_generation = model.base_model.prepare_inputs_for_generation
            model._reorder_cache = model.base_model._reorder_cache
            return model
        else:
            return cls.from_pretrained_old_peft_version(model, model_id, **kwargs)

    @classmethod
    def from_pretrained_old_peft_version(cls, model, model_id, **kwargs):
        # Legacy loader for peft < 0.3.0, which had neither the adapter_name
        # argument nor load_adapter.
        config = LoraConfig.from_pretrained(model_id)

        if getattr(model, "hf_device_map", None) is not None:
            remove_hook_from_submodules(model)

        model = cls(model, config)
        model._reorder_cache = model.base_model._reorder_cache

        # Prefer a local adapter checkpoint; fall back to the Hugging Face Hub.
        if os.path.exists(os.path.join(model_id, "adapter_model.bin")):
            filename = os.path.join(model_id, "adapter_model.bin")
        else:
            try:
                filename = hf_hub_download(model_id, "adapter_model.bin")
            except Exception:
                raise ValueError(
                    f"Can't find weights for {model_id} in {model_id} or in the Hugging Face Hub. "
                    f"Please check that the file {'adapter_model.bin'} is present at {model_id}."
                )

        adapters_weights = torch.load(
            filename,
            map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )

        model = set_peft_model_state_dict(model, adapters_weights)
        if getattr(model, "hf_device_map", None) is not None:
            device_map = kwargs.get("device_map", "auto")
            max_memory = kwargs.get("max_memory", None)
            no_split_module_classes = model._no_split_modules
            if device_map != "sequential":
                max_memory = get_balanced_memory(
                    model,
                    max_memory=max_memory,
                    no_split_module_classes=no_split_module_classes,
                    low_zero=(device_map == "balanced_low_0"),
                )
            if isinstance(device_map, str):
                device_map = infer_auto_device_map(
                    model,
                    max_memory=max_memory,
                    no_split_module_classes=no_split_module_classes,
                )
            model = dispatch_model(model, device_map=device_map)
            hook = AlignDevicesHook(io_same_device=True)
            if model.peft_config.peft_type == PeftType.LORA:
                add_hook_to_module(model.base_model.model, hook)
            else:
                remove_hook_from_submodules(model.prompt_encoder)
                add_hook_to_module(model.base_model, hook)
        return model
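

# Minimal streaming sketch, assuming a local LLaMA checkpoint and (optionally)
# a LoRA adapter directory; "path/to/llama" and "path/to/lora" are placeholders,
# not paths shipped with this file:
#
#   from transformers import LlamaTokenizer
#
#   tokenizer = LlamaTokenizer.from_pretrained("path/to/llama")
#   model = StreamLlamaForCausalLM.from_pretrained("path/to/llama")
#   # or, wrapping the base model with an adapter:
#   # model = StreamPeftGenerationMixin.from_pretrained(model, "path/to/lora")
#   inputs = tokenizer("Hello, my name is", return_tensors="pt")
#   for ids in model.stream_generate(**inputs, max_new_tokens=32):
#       print(tokenizer.decode(ids[0], skip_special_tokens=True))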