import os
import gc
import copy
import time

import torch
import warnings
import transformers

import numpy as np

from typing import Dict, Optional, Sequence
from omnilmm import conversation as conversation_lib

IGNORE_INDEX = -100
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"

def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings, truncating each to the model's max length."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ) for text in strings
    ]
    # NOTE: `input_ids` and `labels` are the *same* list of tensors; callers
    # that mutate the labels must copy them first (see omni_preprocess).
    input_ids = labels = [
        tokenized.input_ids[0] for tokenized in tokenized_list
    ]
    # Each string is tokenized on its own, so "longest" padding adds no pad
    # tokens and counting non-pad ids yields the true sequence length.
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )
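
# A minimal usage sketch (not part of the original module); it assumes any
# HF tokenizer whose `pad_token_id` is set:
#
#   tok = transformers.AutoTokenizer.from_pretrained('gpt2')
#   tok.pad_token = tok.eos_token
#   out = _tokenize_fn(['hello world'], tok)
#   out['input_ids'][0]       # 1-D LongTensor of token ids
#   out['input_ids_lens'][0]  # number of non-pad tokens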


def omni_preprocess(sources,
                    tokenizer: transformers.PreTrainedTokenizer,
                    generation=False):
    """Render conversations with the tokenizer's chat template and build
    labels that keep loss only on the assistant responses.

    `sources` is a list of conversations, each a list of turns in either the
    {'from', 'value'} or the {'role', 'content'} schema. Returns a dict with
    `input_ids` and `labels` lists of 1-D tensors.
    """
    system_content = ('You are an artificial intelligence assistant, which gives '
                      'helpful, detailed, and polite answers to the human\'s questions.')
    ignore_index = IGNORE_INDEX

    # Marker strings the chat template is expected to emit before each user /
    # assistant turn; their token ids delimit the regions to mask.
    response_template = '\n<|assistant|>\n'
    instruction_template = '\n<|user|>\n'
    response_token_ids = tokenizer.encode(
        response_template, add_special_tokens=False)
    instruction_token_ids = tokenizer.encode(
        instruction_template, add_special_tokens=False)

    batch_input_ids = []
    batch_labels = []
    for i in range(len(sources)):
        # Normalize every turn to the {'role', 'content'} schema and enforce
        # strict user/assistant alternation.
        new_source = []
        prev_role = 'unexpected'
        for conv_turn in sources[i]:
            role = conv_turn['from'] if 'from' in conv_turn else conv_turn['role']
            content = conv_turn['value'] if 'value' in conv_turn else conv_turn['content']

            role = 'user' if role == 'human' else role
            role = 'assistant' if role == 'gpt' else role

            assert role in ['user', 'assistant'], f'unexpected role: {role}'
            assert role != prev_role, f'role={role}, prev_role={prev_role}'
            prev_role = role

            new_turn = {
                'role': role,
                'content': content
            }
            new_source.append(new_turn)
        if new_source[0]['role'] != 'system':
            new_source.insert(0, {'role': 'system', 'content': system_content})

        res_text = tokenizer.apply_chat_template(
            new_source, tokenize=False, add_generation_prompt=generation)
        if not generation:
            res_text = res_text.strip()

        conversations_tokenized = _tokenize_fn([res_text], tokenizer)
        res_input_ids = conversations_tokenized["input_ids"][0]

        # `labels` aliases `input_ids` inside _tokenize_fn, so copy before
        # masking to avoid corrupting the input ids.
        res_labels = copy.deepcopy(conversations_tokenized["labels"][0])

        response_token_ids_idxs = []
        human_token_ids_idxs = []

        # Record the index right after each full response marker: tokens from
        # there on belong to an assistant answer.
        for assistant_idx in np.where(res_labels == response_token_ids[0])[0]:
            if (response_token_ids == res_labels[assistant_idx: assistant_idx + len(
                    response_token_ids)].tolist()):
                response_token_ids_idxs.append(
                    assistant_idx + len(response_token_ids))

        if len(response_token_ids_idxs) == 0:
            warnings.warn(
                f"Could not find response key `{response_template}` in the "
                f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
                f'Raw text is @===>{res_text}<===@ '
                f'Raw source is @===>{new_source}<===@ '
                f"This instance will be ignored in loss calculation. "
                f"Note: if this happens often, consider increasing `max_seq_length`."
            )
            res_labels[:] = ignore_index

        # Record the start index of each full instruction marker: tokens from
        # there up to the next response marker belong to a user turn.
        human_token_ids = instruction_token_ids
        for human_idx in np.where(res_labels == human_token_ids[0])[0]:
            if human_token_ids == res_labels[human_idx: human_idx + len(human_token_ids)].tolist():
                human_token_ids_idxs.append(human_idx)

        if len(human_token_ids_idxs) == 0:
            warnings.warn(
                f"Could not find instruction key `{instruction_template}` in the "
                f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
                f'Raw text is @===>{res_text}<===@ '
                f'Raw source is @===>{new_source}<===@ '
                f"This instance will be ignored in loss calculation. "
                f"Note: if this happens often, consider increasing `max_seq_length`."
            )
            res_labels[:] = ignore_index

        # Mask everything that is not an assistant response: the system prompt
        # plus first user turn (idx == 0) and every later user turn (idx != 0).
        for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
            if idx != 0:
                res_labels[start:end] = ignore_index
            else:
                res_labels[:end] = ignore_index

        # A trailing user turn with no assistant response is fully masked too.
        if len(response_token_ids_idxs) < len(human_token_ids_idxs):
            res_labels[human_token_ids_idxs[-1]:] = ignore_index

        batch_input_ids.append(res_input_ids)
        batch_labels.append(res_labels)

    return dict(input_ids=batch_input_ids, labels=batch_labels)
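
# A minimal usage sketch (not part of the original module); it assumes a
# tokenizer whose chat template wraps turns in the '\n<|user|>\n' /
# '\n<|assistant|>\n' markers expected above (e.g. a Zephyr-style template)
# and whose `pad_token_id` is set:
#
#   sources = [[
#       {'from': 'human', 'value': DEFAULT_IMAGE_TOKEN + '\nDescribe the image.'},
#       {'from': 'gpt', 'value': 'A cat is sitting on a windowsill.'},
#   ]]
#   batch = omni_preprocess(sources, tokenizer, generation=False)
#   input_ids, labels = batch['input_ids'][0], batch['labels'][0]
#   # `labels` matches `input_ids` on assistant tokens and is IGNORE_INDEX
#   # (-100) everywhere else.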