|
import copy
from typing import Dict, Optional, Sequence, List

import torch
import transformers

from . import conversation as conversation_lib
from .constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
|
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    """Tokenize a prompt containing '<image>' placeholders.

    The prompt is split on '<image>', each chunk is tokenized separately, and the
    chunks are re-joined with `image_token_index` so the model can later replace
    that sentinel with visual features.
    """
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    # Keep the BOS token (if present) once at the very beginning.
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids
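
# Illustrative usage (not part of the original module); a minimal sketch assuming a
# HuggingFace LLaMA-style tokenizer is available locally (the path is a placeholder):
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("path/to/llama-tokenizer")
#   ids = tokenizer_image_token("USER: <image>\nDescribe the image. ASSISTANT:",
#                               tok, return_tensors='pt')
#   # Each '<image>' placeholder becomes the IMAGE_TOKEN_INDEX sentinel, which the
#   # model later swaps for projected visual features.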
|
|
|
def _add_speaker_and_signal(header, source, get_conversation=True):
    """Add speaker and start/end signal on each round."""
    BEGIN_SIGNAL = "### "
    END_SIGNAL = "\n"
    conversation = header
    for sentence in source:
        from_str = sentence["from"]
        if from_str.lower() == "human":
            from_str = conversation_lib.default_conversation.roles[0]
        elif from_str.lower() == "gpt":
            from_str = conversation_lib.default_conversation.roles[1]
        else:
            from_str = 'unknown'
        sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
                             sentence["value"] + END_SIGNAL)
        if get_conversation:
            conversation += sentence["value"]
    conversation += BEGIN_SIGNAL
    return conversation
|
|
|
def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ) for text in strings
    ]
    input_ids = labels = [
        tokenized.input_ids[0] for tokenized in tokenized_list
    ]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )
|
|
|
def _mask_targets(target, tokenized_lens, speakers):
    """Mask the header and human turns in `target` with IGNORE_INDEX so that only
    model responses are supervised."""
    cur_idx = tokenized_lens[0]
    tokenized_lens = tokenized_lens[1:]
    target[:cur_idx] = IGNORE_INDEX
    for tokenized_len, speaker in zip(tokenized_lens, speakers):
        if speaker == "human":
            target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
        cur_idx += tokenized_len
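
# Illustrative behaviour (not part of the original module): with
# tokenized_lens = [5, 4, 6] and speakers = ["human", "gpt"], the header span
# target[0:5] is masked, the human turn is masked from its third token onward
# (target[7:9], per the +2 offset above), and the gpt turn is left untouched so
# its tokens serve as supervision targets.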
|
|
|
def preprocess_llama_2(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Preprocess conversations rendered with the LLAMA_2 separator style."""
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply the prompt template to each conversation.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first message if it is not from the human.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize the rendered conversations.
    if has_image:
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2

    # Mask everything except the assistant responses.
    sep = "[/INST] "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
|
|
def preprocess_v1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
    training_mode: bool = True,
) -> Dict:
    """Preprocess conversations rendered with the v1 (SeparatorStyle.TWO) template.

    When `training_mode` is False and `has_image` is True, only the tokenized
    input_ids are returned and no label masking is performed.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply the prompt template to each conversation.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first message if it is not from the human.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize the rendered conversations.
    if has_image:
        if training_mode:
            input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
        else:
            input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
            return dict(
                input_ids=input_ids,
            )
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    input_ids = input_ids[:, :tokenizer.model_max_length]

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    # Mask everything except the assistant responses.
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                if len(rounds) != 1:
                    print(
                        f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                        f" (ignored)"
                    )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
|
|
def preprocess_mpt(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess conversations rendered with the MPT separator style."""
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply the prompt template to each conversation.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first message if it is not from the human.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize the rendered conversations.
    input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
    targets = input_ids.clone()
    assert conv.sep_style == conversation_lib.SeparatorStyle.MPT

    # Mask everything except the assistant responses.
    sep = conv.sep + conv.roles[1]
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep)
        re_rounds = [conv.sep.join(rounds[:3])]  # system + user + gpt
        for conv_idx in range(3, len(rounds), 2):
            re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2]))  # user + gpt
        cur_len = 0
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(re_rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            round_len = len(tokenizer_image_token(rou, tokenizer)) + len(tokenizer_image_token(conv.sep, tokenizer))
            instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
|
|
def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess plain (pretraining) samples: an image placeholder plus a caption.

    The image prompt is reduced to the bare image token and only the caption
    tokens are kept as supervision targets.
    """
    # Build "<image>" + caption + separator for each sample.
    conversations = []
    for source in sources:
        assert len(source) == 2
        assert DEFAULT_IMAGE_TOKEN in source[0]['value']
        source[0]['value'] = DEFAULT_IMAGE_TOKEN
        conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
        conversations.append(conversation)

    # Tokenize, then mask the image-prompt portion of the labels.
    input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
        target[:tokenized_len] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)
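
# Illustrative input for the plain format (not part of the original module); a sketch
# assuming the usual '<image>' placeholder. Only the caption is supervised:
#
#   sources = [[{"from": "human", "value": "<image>"},
#               {"from": "gpt",   "value": "A photo of a red bicycle."}]]
#   batch = preprocess_plain(sources, tokenizer)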
|
|
|
def preprocess(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
    training_mode: bool = True,
) -> Dict:
    """
    Given a list of sources, each a conversation list, this transform:
    1. Adds the '### ' begin signal to each sentence and a newline end signal;
    2. Concatenates the conversations together;
    3. Tokenizes the concatenated conversations;
    4. Sets the prompt part of the targets to IGNORE_INDEX.
    Conversation templates with their own separator styles are dispatched to the
    dedicated preprocessing functions above.
    """
    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
        return preprocess_plain(sources, tokenizer)
    if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
        return preprocess_llama_2(sources, tokenizer, has_image=has_image)
    if conversation_lib.default_conversation.version.startswith("v1"):
        return preprocess_v1(sources, tokenizer, has_image=has_image, training_mode=training_mode)
    if conversation_lib.default_conversation.version == "mpt":
        return preprocess_mpt(sources, tokenizer)

    # Default (legacy) path: add speaker signals and concatenate.
    conversations = []
    for source in sources:
        header = f"{conversation_lib.default_conversation.system}\n\n"
        conversation = _add_speaker_and_signal(header, source)
        conversations.append(conversation)

    # Tokenize the conversations.
    def get_tokenize_len(prompts):
        return [min(len(tokenizer_image_token(prompt, tokenizer)), tokenizer.model_max_length) for prompt in prompts]

    if has_image:
        input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt')[:tokenizer.model_max_length] for prompt in conversations]
    else:
        conversations_tokenized = _tokenize_fn(conversations, tokenizer)
        input_ids = conversations_tokenized["input_ids"]

    # Mask the header and human turns in the labels.
    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        if has_image:
            tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
        else:
            tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
        speakers = [sentence["from"] for sentence in source]
        _mask_targets(target, tokenized_lens, speakers)

    return dict(input_ids=input_ids, labels=targets)
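
# Illustrative usage (not part of the original module); a minimal sketch assuming a
# tokenizer has been loaded and conversation_lib.default_conversation is configured:
#
#   sources = [[
#       {"from": "human", "value": "<image>\nWhat is in the picture?"},
#       {"from": "gpt",   "value": "A cat sitting on a sofa."},
#   ]]
#   batch = preprocess(sources, tokenizer, has_image=True)
#   # batch["input_ids"] and batch["labels"] are aligned; label positions covering
#   # the prompt are IGNORE_INDEX, so the loss is computed only on the responses.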