"""Inference for FastChat models.""" |
import abc |
import gc |
import json |
import math |
import os |
import sys |
import time |
from typing import Iterable, Optional, Dict |
import warnings |
import psutil |
import torch |
from transformers import ( |
AutoTokenizer, |
AutoModelForCausalLM, |
LlamaTokenizer, |
LlamaForCausalLM, |
AutoModel, |
AutoModelForSeq2SeqLM, |
T5Tokenizer, |
AutoConfig, |
) |
from transformers.generation.logits_process import ( |
LogitsProcessorList, |
RepetitionPenaltyLogitsProcessor, |
TemperatureLogitsWarper, |
TopKLogitsWarper, |
TopPLogitsWarper, |
) |
from fastchat.conversation import get_conv_template, SeparatorStyle |
from fastchat.model.model_adapter import ( |
load_model, |
get_conversation_template, |
get_generate_stream_function, |
) |
from fastchat.modules.awq import AWQConfig |
from fastchat.modules.gptq import GptqConfig |
from fastchat.modules.exllama import ExllamaConfig |
from fastchat.modules.xfastertransformer import XftConfig |
from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length |
def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    """Assemble the logits processors implied by the sampling parameters.

    A processor is added only when its parameter lies in the range checked
    here:

    * ``temperature``: at least ``1e-5`` and not exactly ``1.0``
    * ``repetition_penalty``: strictly greater than ``1.0``
    * ``top_p``: within ``[1e-8, 1.0)``
    * ``top_k``: strictly positive

    Returns:
        A ``LogitsProcessorList`` with the selected processors in the order
        temperature, repetition penalty, top-p, top-k. The list is empty
        when none of the parameters is in range.
    """
    # Each row is (enabled, processor class, constructor argument). Classes
    # are only instantiated for enabled rows, so an out-of-range value never
    # reaches a constructor.
    candidates = (
        (
            temperature >= 1e-5 and temperature != 1.0,
            TemperatureLogitsWarper,
            temperature,
        ),
        (
            repetition_penalty > 1.0,
            RepetitionPenaltyLogitsProcessor,
            repetition_penalty,
        ),
        (1e-8 <= top_p < 1.0, TopPLogitsWarper, top_p),
        (top_k > 0, TopKLogitsWarper, top_k),
    )

    processors = LogitsProcessorList()
    processors.extend(
        factory(value) for enabled, factory, value in candidates if enabled
    )
    return processors
@torch.inference_mode()
def generate_stream(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
):
    """Generate text one token at a time and yield partial results.

    Args:
        model: Model to sample from. ``model.config.is_encoder_decoder``
            selects between calling ``model(...)`` directly and driving
            ``model.encoder`` / ``model.decoder`` / ``model.lm_head``.
        tokenizer: Used to encode ``params["prompt"]`` and decode the
            generated ids. Its ``eos_token_id`` is always a stop token.
        params: Generation request. ``"prompt"`` is required. Optional keys
            and their defaults: ``"temperature"`` (1.0),
            ``"repetition_penalty"`` (1.0), ``"top_p"`` (1.0), ``"top_k"``
            (-1), ``"max_new_tokens"`` (256), ``"logprobs"`` (None; any
            non-None value turns log-probability reporting on), ``"echo"``
            (True), ``"stop"`` (a string or an iterable of strings) and
            ``"stop_token_ids"``. The dict and its values are not modified.
        device: Device for the input tensors. Replaced by ``model.device``
            when the model has that attribute.
        context_len: Token budget used to truncate the prompt from the left.
        stream_interval: An intermediate chunk is emitted on every step
            whose index is a multiple of this value, and also on the last
            step and on a stop.
        judge_sent_end: When True, a stop token that arrives while
            ``is_sentence_complete(output)`` is false is replaced by the
            second candidate token and generation continues.

    Yields:
        Dicts with the keys ``"text"``, ``"logprobs"``, ``"usage"`` and
        ``"finish_reason"``. Intermediate chunks have ``finish_reason``
        ``None``; the final chunk carries ``"stop"`` or ``"length"``.
        ``usage["completion_tokens"]`` is the index of the decoding step
        that produced the chunk.

    Raises:
        NotImplementedError: ``logprobs`` requested for an encoder-decoder
            model.
        ValueError: ``params["stop"]`` is neither a string nor an iterable.
    """
    if hasattr(model, "device"):
        device = model.device

    # Read parameters
    prompt = params["prompt"]
    len_prompt = len(prompt)
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    top_k = int(params.get("top_k", -1))  # non-positive adds no top-k processor
    max_new_tokens = int(params.get("max_new_tokens", 256))
    logprobs = params.get("logprobs", None)
    echo = bool(params.get("echo", True))
    stop_str = params.get("stop", None)
    # Work on a copy: EOS is appended below and the caller's list (for
    # example a conversation template's stop_token_ids) must stay untouched.
    stop_token_ids = list(params.get("stop_token_ids", None) or [])
    if tokenizer.eos_token_id not in stop_token_ids:
        stop_token_ids.append(tokenizer.eos_token_id)

    logits_processor = prepare_logits_processor(
        temperature, repetition_penalty, top_p, top_k
    )
    input_ids = tokenizer(prompt).input_ids

    if model.config.is_encoder_decoder:
        max_src_len = context_len
    else:
        # Leave room for the tokens that will be generated.
        max_src_len = context_len - max_new_tokens - 1

    # Keep only the most recent max_src_len prompt tokens.
    # NOTE(review): when max_src_len <= 0 this slice does not truncate as
    # intended (``[-0:]`` keeps everything) - confirm callers keep
    # max_new_tokens below context_len.
    input_ids = input_ids[-max_src_len:]
    output_ids = list(input_ids)
    input_echo_len = len(input_ids)

    if model.config.is_encoder_decoder:
        if logprobs is not None:
            raise NotImplementedError
        encoder_output = model.encoder(
            input_ids=torch.as_tensor([input_ids], device=device)
        )[0]
        start_ids = torch.as_tensor(
            [[model.generation_config.decoder_start_token_id]],
            dtype=torch.int64,
            device=device,
        )
    else:
        start_ids = torch.as_tensor([input_ids], device=device)

    past_key_values = out = None
    token_logprobs = [None]  # no log-probability is recorded for the first token
    sent_interrupt = False
    finish_reason = None
    stopped = False
    # Defined up front so the final chunk below is well-formed even when the
    # loop body never runs (max_new_tokens <= 0); previously that case raised
    # UnboundLocalError.
    output = ""
    ret_logprobs = None
    i = 0
    for i in range(max_new_tokens):
        if i == 0:
            # First step: run the whole prompt (or the decoder start token).
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=start_ids,
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                )
                logits = model.lm_head(out[0])
            else:
                out = model(input_ids=start_ids, use_cache=True)
                logits = out.logits
            past_key_values = out.past_key_values

            if logprobs is not None:
                # Log-probabilities of the prompt tokens: position k of the
                # logits is scored against the token at position k + 1.
                shift_input_ids = start_ids[..., 1:].contiguous()
                shift_logits = logits[..., :-1, :].contiguous()
                shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
                for label_id, logit in zip(
                    shift_input_ids[0].tolist(), shift_logits[0]
                ):
                    token_logprobs.append(logit[label_id])
        else:
            # Later steps: feed only the newest token together with the
            # cache, unless the previous step replaced a stop token
            # (sent_interrupt); then all of output_ids is fed with no cache.
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=torch.as_tensor(
                        [[token] if not sent_interrupt else output_ids],
                        device=device,
                    ),
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                    past_key_values=past_key_values if not sent_interrupt else None,
                )
                sent_interrupt = False

                logits = model.lm_head(out[0])
            else:
                out = model(
                    input_ids=torch.as_tensor(
                        [[token] if not sent_interrupt else output_ids],
                        device=device,
                    ),
                    use_cache=True,
                    past_key_values=past_key_values if not sent_interrupt else None,
                )
                sent_interrupt = False
                logits = out.logits
            past_key_values = out.past_key_values

        if logits_processor:
            # Only the repetition penalty is given the ids produced so far.
            if repetition_penalty > 1.0:
                tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
            else:
                tmp_output_ids = None
            last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
        else:
            last_token_logits = logits[0, -1, :]

        # NOTE(review): device may have been replaced by model.device above;
        # confirm that value compares equal to the plain string "mps".
        if device == "mps":
            # Sample on the CPU in float32 (presumably a workaround for the
            # mps backend - confirm).
            last_token_logits = last_token_logits.float().to("cpu")

        # Two candidates are always drawn; the second one is the fallback
        # used by judge_sent_end.
        if temperature < 1e-5 or top_p < 1e-8:
            # Greedy: take the two highest-scoring tokens.
            _, indices = torch.topk(last_token_logits, 2)
            tokens = [int(index) for index in indices.tolist()]
        else:
            probs = torch.softmax(last_token_logits, dim=-1)
            indices = torch.multinomial(probs, num_samples=2)
            tokens = [int(token) for token in indices.tolist()]
        token = tokens[0]
        output_ids.append(token)
        if logprobs is not None:
            # Taken from the raw logits, not from the processed ones.
            token_logprobs.append(
                torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
            )

        if token in stop_token_ids:
            stopped = True
        else:
            stopped = False

        # Emit a chunk
        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            if echo:
                tmp_output_ids = output_ids
                # Stop strings are only searched for after the prompt text.
                rfind_start = len_prompt
            else:
                tmp_output_ids = output_ids[input_echo_len:]
                rfind_start = 0

            output = tokenizer.decode(
                tmp_output_ids,
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
                clean_up_tokenization_spaces=True,
            )
            ret_logprobs = None
            if logprobs is not None:
                ret_logprobs = {
                    "text_offset": [],
                    "tokens": [
                        tokenizer.decode(token)
                        for token in (
                            output_ids if echo else output_ids[input_echo_len:]
                        )
                    ],
                    "token_logprobs": token_logprobs
                    if echo
                    else token_logprobs[input_echo_len:],
                    "top_logprobs": [{}]
                    * len(token_logprobs if echo else token_logprobs[input_echo_len:]),
                }
                # text_offset[k] is the summed length of the decoded tokens
                # that precede token k.
                curr_pos = 0
                for text in ret_logprobs["tokens"]:
                    ret_logprobs["text_offset"].append(curr_pos)
                    curr_pos += len(text)

            # A stop token arrived but the text is not a complete sentence:
            # swap in the runner-up token and keep generating.
            if judge_sent_end and stopped and not is_sentence_complete(output):
                if len(tokens) > 1:
                    token = tokens[1]
                    output_ids[-1] = token
                else:
                    output_ids.pop()
                stopped = False
                sent_interrupt = True

            partially_stopped = False
            if stop_str:
                if isinstance(stop_str, str):
                    pos = output.rfind(stop_str, rfind_start)
                    if pos != -1:
                        output = output[:pos]
                        stopped = True
                    else:
                        partially_stopped = is_partial_stop(output, stop_str)
                elif isinstance(stop_str, Iterable):
                    for each_stop in stop_str:
                        pos = output.rfind(each_stop, rfind_start)
                        if pos != -1:
                            output = output[:pos]
                            stopped = True
                            break
                        else:
                            partially_stopped = is_partial_stop(output, each_stop)
                            if partially_stopped:
                                break
                else:
                    raise ValueError("Invalid stop field type.")

            # Hold the chunk back while is_partial_stop reports a match
            # (presumably the text ends with the start of a stop string).
            if not partially_stopped:
                yield {
                    "text": output,
                    "logprobs": ret_logprobs,
                    "usage": {
                        "prompt_tokens": input_echo_len,
                        "completion_tokens": i,
                        "total_tokens": input_echo_len + i,
                    },
                    "finish_reason": None,
                }

        if stopped:
            break

    # for/else: the loop ran out of steps without hitting a stop.
    else:
        finish_reason = "length"

    if stopped:
        finish_reason = "stop"

    # Final chunk, the only one that carries a finish reason.
    yield {
        "text": output,
        "logprobs": ret_logprobs,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
        "finish_reason": finish_reason,
    }

    # Drop references to the cache and model output, then release memory.
    del past_key_values, out
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()
class ChatIO(abc.ABC):
    """Abstract front end through which ``chat_loop`` talks to the user.

    All four methods are abstract, so a concrete subclass has to provide
    every one of them before it can be instantiated.
    """

    @abc.abstractmethod
    def prompt_for_input(self, role: str) -> str:
        """Ask for the next message on behalf of ``role`` and return it."""

    @abc.abstractmethod
    def prompt_for_output(self, role: str):
        """Announce that output attributed to ``role`` is about to follow."""

    @abc.abstractmethod
    def stream_output(self, output_stream):
        """Present the chunks of ``output_stream``.

        ``chat_loop`` treats the return value as the generated text.
        """

    @abc.abstractmethod
    def print_output(self, text: str):
        """Present ``text`` that is already complete."""
def chat_loop(
    model_path: str,
    device: str,
    num_gpus: int,
    max_gpu_memory: str,
    dtype: Optional[torch.dtype],
    load_8bit: bool,
    cpu_offloading: bool,
    conv_template: Optional[str],
    conv_system_msg: Optional[str],
    temperature: float,
    repetition_penalty: float,
    max_new_tokens: int,
    chatio: ChatIO,
    gptq_config: Optional[GptqConfig] = None,
    awq_config: Optional[AWQConfig] = None,
    exllama_config: Optional[ExllamaConfig] = None,
    xft_config: Optional[XftConfig] = None,
    revision: str = "main",
    judge_sent_end: bool = True,
    debug: bool = True,
    history: bool = True,
):
    """Load a model and run an interactive chat session until the user exits.

    Besides ordinary messages the loop understands these commands:

    * ``!!exit`` (or empty input / end of input): leave the loop
    * ``!!reset``: start a fresh conversation
    * ``!!remove``: drop the last assistant and user message
    * ``!!regen``: drop the last assistant reply and generate it again
    * ``!!save <filename>``: write the conversation as JSON
    * ``!!load <filename>``: replace the conversation with a saved one

    Args:
        model_path: Passed to ``load_model`` and used to pick the
            conversation template when ``conv_template`` is not given.
        device, num_gpus, max_gpu_memory, dtype, load_8bit, cpu_offloading,
        gptq_config, awq_config, exllama_config, xft_config, revision:
            Forwarded unchanged to ``load_model``.
        conv_template: Name of the conversation template. When falsy, the
            template is looked up from ``model_path``.
        conv_system_msg: System message applied to each new conversation
            when it is not None.
        temperature: Forwarded to the generation function.
        repetition_penalty: Forwarded to the generation function. A value
            of exactly 1.0 is replaced by 1.2 for models whose type name
            contains "t5".
        max_new_tokens: Forwarded to the generation function.
        chatio: Front end used for all prompting and output.
        judge_sent_end: Forwarded to the generation function.
        debug: Forwarded to ``load_model``. When true, a summary with the
            prompt, the output and the token rate is printed after each
            reply.
        history: When False, a new conversation is started on every turn.
    """
    # Model
    model, tokenizer = load_model(
        model_path,
        device=device,
        num_gpus=num_gpus,
        max_gpu_memory=max_gpu_memory,
        dtype=dtype,
        load_8bit=load_8bit,
        cpu_offloading=cpu_offloading,
        gptq_config=gptq_config,
        awq_config=awq_config,
        exllama_config=exllama_config,
        xft_config=xft_config,
        revision=revision,
        debug=debug,
    )
    generate_stream_func = get_generate_stream_function(model, model_path)

    model_type = str(type(model)).lower()
    is_t5 = "t5" in model_type
    is_codet5p = "codet5p" in model_type

    # T5-type models get a repetition penalty of 1.2 unless one was chosen.
    if is_t5 and repetition_penalty == 1.0:
        repetition_penalty = 1.2

    context_len = get_context_length(model.config)

    def new_chat():
        """Create an empty conversation with the configured system message."""
        if conv_template:
            conv = get_conv_template(conv_template)
        else:
            conv = get_conversation_template(model_path)
        if conv_system_msg is not None:
            conv.set_system_message(conv_system_msg)
        return conv

    def reload_conv(conv):
        """
        Reprints the conversation from the start.
        """
        for message in conv.messages[conv.offset :]:
            chatio.prompt_for_output(message[0])
            chatio.print_output(message[1])

    def has_removable_message(conv):
        """Tell whether a message beyond the template's own ones exists."""
        return len(conv.messages) > conv.offset

    conv = None

    while True:
        if not history or not conv:
            conv = new_chat()

        try:
            inp = chatio.prompt_for_input(conv.roles[0])
        except EOFError:
            inp = ""

        if inp == "!!exit" or not inp:
            print("exit...")
            break
        elif inp == "!!reset":
            print("resetting...")
            conv = new_chat()
            continue
        elif inp == "!!remove":
            print("removing last message...")
            if has_removable_message(conv):
                # Assistant
                if conv.messages[-1][0] == conv.roles[1]:
                    conv.messages.pop()
                # User. Re-check the length: the pop above may have emptied
                # the list or reached the messages below conv.offset.
                if (
                    has_removable_message(conv)
                    and conv.messages[-1][0] == conv.roles[0]
                ):
                    conv.messages.pop()
                reload_conv(conv)
            else:
                print("No messages to remove.")
            continue
        elif inp == "!!regen":
            print("regenerating last message...")
            if has_removable_message(conv):
                # Assistant
                if conv.messages[-1][0] == conv.roles[1]:
                    conv.messages.pop()
                # User. Same length re-check as for !!remove.
                if (
                    has_removable_message(conv)
                    and conv.messages[-1][0] == conv.roles[0]
                ):
                    reload_conv(conv)
                    # Reuse the previous user message as this turn's input.
                    inp = conv.messages.pop()[1]
                else:
                    print("No user message to regenerate from.")
                    continue
            else:
                print("No messages to regenerate.")
                continue
        elif inp.startswith("!!save"):
            args = inp.split(" ", 1)

            if len(args) != 2:
                print("usage: !!save <filename>")
                continue
            filename = args[1]

            # Add .json when the name contains no dot at all.
            if "." not in filename:
                filename += ".json"

            print("saving...", filename)
            with open(filename, "w") as outfile:
                json.dump(conv.dict(), outfile)
            continue
        elif inp.startswith("!!load"):
            args = inp.split(" ", 1)

            if len(args) != 2:
                print("usage: !!load <filename>")
                continue
            filename = args[1]

            # Fall back to "<filename>.json" when the plain name is missing.
            if not os.path.exists(filename):
                if (not filename.endswith(".json")) and os.path.exists(
                    filename + ".json"
                ):
                    filename += ".json"
                else:
                    print("file not found:", filename)
                    continue

            print("loading...", filename)
            with open(filename, "r") as infile:
                new_conv = json.load(infile)

            conv = get_conv_template(new_conv["template_name"])
            conv.set_system_message(new_conv["system_message"])
            conv.messages = new_conv["messages"]
            reload_conv(conv)
            continue

        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        # For codet5p the raw input is sent instead of the templated prompt.
        if is_codet5p:
            prompt = inp

        gen_params = {
            "model": model_path,
            "prompt": prompt,
            "temperature": temperature,
            "repetition_penalty": repetition_penalty,
            "max_new_tokens": max_new_tokens,
            "stop": conv.stop_str,
            "stop_token_ids": conv.stop_token_ids,
            "echo": False,
        }

        try:
            chatio.prompt_for_output(conv.roles[1])
            output_stream = generate_stream_func(
                model,
                tokenizer,
                gen_params,
                device,
                context_len=context_len,
                judge_sent_end=judge_sent_end,
            )
            t = time.time()
            outputs = chatio.stream_output(output_stream)
            duration = time.time() - t
            conv.update_last_message(outputs.strip())

            if debug:
                num_tokens = len(tokenizer.encode(outputs))
                # The clock may report no elapsed time; avoid dividing by 0.
                if duration > 0:
                    speed = round(num_tokens / duration, 2)
                else:
                    speed = float("inf")
                msg = {
                    "conv_template": conv.name,
                    "prompt": prompt,
                    "outputs": outputs,
                    "speed (token/s)": speed,
                }
                print(f"\n{msg}\n")

        except KeyboardInterrupt:
            print("stopped generation.")
            # The reply placeholder is still None: generation did not finish.
            if conv.messages[-1][1] is None:
                conv.messages.pop()
                # Also drop the user message so it is not doubled on retry.
                if conv.messages[-1][0] == conv.roles[0]:
                    conv.messages.pop()

                reload_conv(conv)