diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..8947beb5c6e003965437ce530e229e9cc597d66e 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +mllm/demo/assets/baseball.png filter=lfs diff=lfs merge=lfs -text diff --git a/mllm/__init__.py b/mllm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mllm/__pycache__/__init__.cpython-310.pyc b/mllm/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a6eef3367b49010c6dcdad9982d8c3453c7df0e Binary files /dev/null and b/mllm/__pycache__/__init__.cpython-310.pyc differ diff --git a/mllm/config/__init__.py b/mllm/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..79fb153aa86f34ac37f9af0078f19c75a25636a4 --- /dev/null +++ b/mllm/config/__init__.py @@ -0,0 +1 @@ +from .config import prepare_args diff --git a/mllm/config/__pycache__/__init__.cpython-310.pyc b/mllm/config/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c2e691cd2e6722a9e8521e183f35f595ab34c96 Binary files /dev/null and b/mllm/config/__pycache__/__init__.cpython-310.pyc differ diff --git a/mllm/config/__pycache__/config.cpython-310.pyc b/mllm/config/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27dc55d83ffc46c0fc59db79cd84a9e754de5ec4 Binary files /dev/null and b/mllm/config/__pycache__/config.cpython-310.pyc differ diff --git a/mllm/config/config.py b/mllm/config/config.py new file mode 100644 index 0000000000000000000000000000000000000000..cc417e0a6df10f2eda6ca9b4e921e7e3c86d48dc --- /dev/null +++ b/mllm/config/config.py @@ -0,0 +1,135 @@ +import os +import sys +import logging +import argparse +from dataclasses import dataclass, field +from typing import List, Tuple +from argparse import SUPPRESS + +import datasets +import transformers +from mmengine.config import Config, DictAction +from transformers import HfArgumentParser, set_seed, add_start_docstrings +from transformers import Seq2SeqTrainingArguments as HFSeq2SeqTrainingArguments +from transformers.trainer_utils import get_last_checkpoint, is_main_process + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout), ], +) + + +@dataclass +@add_start_docstrings(HFSeq2SeqTrainingArguments.__doc__) +class Seq2SeqTrainingArguments(HFSeq2SeqTrainingArguments): + do_multi_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the multi-test set."}) + + +def prepare_args(args=None): + parser = argparse.ArgumentParser() + parser.add_argument('config', help='train config file path') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + + hf_parser = HfArgumentParser((Seq2SeqTrainingArguments,)) + hf_parser, required = block_required_error(hf_parser) + + args, unknown_args = parser.parse_known_args(args) + known_hf_args, unknown_args = hf_parser.parse_known_args(unknown_args) + if unknown_args: + raise ValueError(f"Some specified arguments are not used " + f"by the ArgumentParser or HfArgumentParser\n: {unknown_args}") + + # load 'cfg' and 'training_args' from file and cli + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + training_args = cfg.training_args + training_args.update(vars(known_hf_args)) + + # check training_args require + req_but_not_assign = [item for item in required if item not in training_args] + if req_but_not_assign: + raise ValueError(f"Requires {req_but_not_assign} but not assign.") + + # update cfg.training_args + cfg.training_args = training_args + + # initialize and return + training_args = Seq2SeqTrainingArguments(**training_args) + training_args = check_output_dir(training_args) + + # logging + if is_main_process(training_args.local_rank): + to_logging_cfg = Config() + to_logging_cfg.model_args = cfg.model_args + to_logging_cfg.data_args = cfg.data_args + to_logging_cfg.training_args = cfg.training_args + logger.info(to_logging_cfg.pretty_text) + + # setup logger + if training_args.should_log: + # The default of training_args.log_level is passive, so we set log level at info here to have that default. + transformers.logging.set_verbosity_info() + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.logging.set_verbosity(log_level) + transformers.logging.enable_default_handler() + transformers.logging.enable_explicit_format() + # setup_print_for_distributed(is_main_process(training_args)) + + # Log on each process the small summary: + logger.info(f"Training/evaluation parameters {training_args}") + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n" + + f" distributed training: {bool(training_args.local_rank != -1)}, fp16 training: {training_args.fp16}" + ) + + # Set seed before initializing model. + set_seed(training_args.seed) + + return cfg, training_args + + +def block_required_error(hf_parser: HfArgumentParser) -> Tuple[HfArgumentParser, List]: + required = [] + # noinspection PyProtectedMember + for action in hf_parser._actions: + if action.required: + required.append(action.dest) + action.required = False + action.default = SUPPRESS + return hf_parser, required + + +def check_output_dir(training_args): + # Detecting last checkpoint. + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + return training_args + + +if __name__ == "__main__": + _ = prepare_args() diff --git a/mllm/conversation/__init__.py b/mllm/conversation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4eb9bdf74eec1dee51ff477ec6a56439b3831228 --- /dev/null +++ b/mllm/conversation/__init__.py @@ -0,0 +1 @@ +from .base_conversation import SeparatorStyle, Conversation, register_conv_template, get_conv_template \ No newline at end of file diff --git a/mllm/conversation/__pycache__/__init__.cpython-310.pyc b/mllm/conversation/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1b4ed130b6425ae339f27748390a1bff321a4b5 Binary files /dev/null and b/mllm/conversation/__pycache__/__init__.cpython-310.pyc differ diff --git a/mllm/conversation/__pycache__/base_conversation.cpython-310.pyc b/mllm/conversation/__pycache__/base_conversation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e67a97bb2d8f160120ce901c4368a9faa64e57b Binary files /dev/null and b/mllm/conversation/__pycache__/base_conversation.cpython-310.pyc differ diff --git a/mllm/conversation/base_conversation.py b/mllm/conversation/base_conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..16a2d40ada694587880365582f5dbb454a79903a --- /dev/null +++ b/mllm/conversation/base_conversation.py @@ -0,0 +1,503 @@ +# copy from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py +""" +Conversation prompt templates. +""" + +import dataclasses +from enum import auto, Enum +from typing import List, Tuple, Any, Dict + + +class SeparatorStyle(Enum): + """Separator styles.""" + + ADD_COLON_SINGLE = auto() + ADD_COLON_TWO = auto() + ADD_SPACE_TWO = auto() + NO_COLON_SINGLE = auto() + BAIZE = auto() + DOLLY = auto() + RWKV = auto() + PHOENIX = auto() + NEW_LINE = auto() + BILLA = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + + # The name of this template + name: str + # System prompts + system: str + # Two roles + roles: List[str] + # All messages + messages: List[List[str]] + # Offset of few shot examples + offset: int + # Separators + sep_style: SeparatorStyle + sep: str + sep2: str = None + # Stop criteria (the default one is EOS token) + stop_str: str = None + # Stops generation if meeting any token in this list + stop_token_ids: List[int] = None + + # Used for the state in the gradio servers. + # TODO(lmzheng): move this out of this class. + conv_id: Any = None + skip_next: bool = False + model_name: str = None + + def get_prompt(self) -> str: + """Get the prompt for generation.""" + if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: + ret = self.system + self.sep + for role, message in self.messages: + if message: + ret += role + ": " + message + self.sep + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.ADD_COLON_TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ": " + message + seps[i % 2] + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.ADD_SPACE_TWO: + seps = [self.sep, self.sep2] + ret = self.system + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + " " + message + seps[i % 2] + else: + ret += role + "" + return ret + elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: + ret = self.system + for role, message in self.messages: + if message: + ret += role + message + self.sep + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.BAIZE: + ret = self.system + "\n" + for role, message in self.messages: + if message: + ret += role + message + "\n" + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.DOLLY: + seps = [self.sep, self.sep2] + ret = self.system + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ":\n" + message + seps[i % 2] + if i % 2 == 1: + ret += "\n\n" + else: + ret += role + ":\n" + return ret + elif self.sep_style == SeparatorStyle.RWKV: + ret = self.system + for i, (role, message) in enumerate(self.messages): + if message: + ret += ( + role + + ": " + + message.replace("\r\n", "\n").replace("\n\n", "\n") + ) + ret += "\n\n" + else: + ret += role + ":" + return ret + elif self.sep_style == SeparatorStyle.PHOENIX: + ret = self.system + for role, message in self.messages: + if message: + ret += role + ": " + "" + message + "" + else: + ret += role + ": " + "" + return ret + elif self.sep_style == SeparatorStyle.NEW_LINE: + ret = self.system + self.sep + for role, message in self.messages: + if message: + ret += role + "\n" + message + self.sep + else: + ret += role + "\n" + return ret + elif self.sep_style == SeparatorStyle.BILLA: + ret = self.system + self.sep + for role, message in self.messages: + if message: + ret += role + ": " + message + self.sep + else: + ret += role + ": " # must be end with a space + return ret + else: + raise ValueError(f"Invalid style: {self.sep_style}") + + def append_message(self, role: str, message: str): + """Append a new message.""" + self.messages.append([role, message]) + + def to_gradio_chatbot(self): + """Convert the history to gradio chatbot format""" + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def to_openai_api_messages(self): + """Convert the conversation to OpenAI chat completion format.""" + ret = [{"role": "system", "content": self.system}] + + for i, (_, msg) in enumerate(self.messages[self.offset:]): + if i % 2 == 0: + ret.append({"role": "user", "content": msg}) + else: + if msg is not None: + ret.append({"role": "assistant", "content": msg}) + return ret + + def copy(self): + return Conversation( + name=self.name, + system=self.system, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + stop_str=self.stop_str, + stop_token_ids=self.stop_token_ids, + conv_id=self.conv_id, + model_name=self.model_name, + ) + + def dict(self): + return { + "name": self.name, + "system": self.system, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + "conv_id": self.conv_id, + "model_name": self.model_name, + } + + +# A global registry for all conversation templates +conv_templates: Dict[str, Conversation] = {} + + +def register_conv_template(template: Conversation, override: bool = False): + """Register a new conversation template.""" + if not override: + assert template.name not in conv_templates, f"{template.name} has been registered." + conv_templates[template.name] = template + + +def get_conv_template(name: str) -> Conversation: + """Get a conversation template.""" + return conv_templates[name].copy() + + +# A template with one conversation example +register_conv_template( + Conversation( + name="one_shot", + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", + roles=("Human", "Assistant"), + messages=( + ( + "Human", + "What are the key differences between renewable and non-renewable energy sources?", + ), + ( + "Assistant", + "Renewable energy sources are those that can be replenished naturally in a relatively " + "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " + "Non-renewable energy sources, on the other hand, are finite and will eventually be " + "depleted, such as coal, oil, and natural gas. Here are some key differences between " + "renewable and non-renewable energy sources:\n" + "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " + "energy sources are finite and will eventually run out.\n" + "2. Environmental impact: Renewable energy sources have a much lower environmental impact " + "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " + "and other negative effects.\n" + "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " + "have lower operational costs than non-renewable sources.\n" + "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " + "locations than non-renewable sources.\n" + "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " + "situations and needs, while non-renewable sources are more rigid and inflexible.\n" + "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " + "non-renewable sources are not, and their depletion can lead to economic and social instability.", + ), + ), + offset=2, + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n### ", + stop_str="###", + ) +) + +# Vicuna v1.1 template +register_conv_template( + Conversation( + name="vicuna_v1.1", + system="A chat between a curious user and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions.", + roles=("USER", "ASSISTANT"), + messages=(), + offset=0, + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep=" ", + sep2="", + ) +) + +# Koala default template +register_conv_template( + Conversation( + name="koala_v1", + system="BEGINNING OF CONVERSATION:", + roles=("USER", "GPT"), + messages=(), + offset=0, + sep_style=SeparatorStyle.ADD_COLON_TWO, + sep=" ", + sep2="", + ) +) + +# Dolly V2 default template +register_conv_template( + Conversation( + name="dolly_v2", + system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n", + roles=("### Instruction", "### Response"), + messages=(), + offset=0, + sep_style=SeparatorStyle.DOLLY, + sep="\n\n", + sep2="### End", + ) +) + +# OpenAssistant Pythia default template +register_conv_template( + Conversation( + name="oasst_pythia", + system="", + roles=("<|prompter|>", "<|assistant|>"), + messages=(), + offset=0, + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="<|endoftext|>", + ) +) + +# StableLM Alpha default template +register_conv_template( + Conversation( + name="stablelm", + system="""<|SYSTEM|># StableLM Tuned (Alpha version) +- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI. +- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. +- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes. +- StableLM will refuse to participate in anything that could harm a human. +""", + roles=("<|USER|>", "<|ASSISTANT|>"), + messages=(), + offset=0, + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="", + stop_token_ids=[50278, 50279, 50277, 1, 0], + ) +) + +# Baize default template +register_conv_template( + Conversation( + name="baize", + system="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.", + roles=("[|Human|]", "[|AI|]"), + messages=( + ("[|Human|]", "Hello!"), + ("[|AI|]", "Hi!"), + ), + offset=2, + sep_style=SeparatorStyle.BAIZE, + sep="[|Human|]", + stop_str="[|Human|]", + ) +) + +# RWKV-4-Raven default template +register_conv_template( + Conversation( + name="rwkv", + system="The following is a coherent verbose detailed conversation between Bob and Alice.\n\n", + roles=("Bob", "Alice"), + messages=( + ("Bob", "Hi"), + ( + "Alice", + "Hi. I am your assistant and I will answer all questions. Please feel free to ask any question and I will always answer it.", + ), + ), + offset=2, + sep_style=SeparatorStyle.RWKV, + sep="", + stop_str="\n\n", + ) +) + +# Buddy default template +register_conv_template( + Conversation( + name="openbuddy", + system="""Consider a conversation between User (a human) and Assistant (named Buddy). +Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy +Buddy cannot access the Internet. +Buddy can fluently speak the user's language (e.g. English, Chinese). +Buddy can generate poems, stories, code, essays, songs, parodies, and more. +Buddy possesses vast knowledge about the world, history, and culture. +Buddy's responses are always safe, creative, high-quality, human-like, and interesting. +Buddy strictly refuses to discuss political, NSFW, or other unsafe topics. + +User: Hi. +Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?""", + roles=("User", "Assistant"), + messages=(), + offset=0, + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n", + ) +) + +# Phoenix default template +register_conv_template( + Conversation( + name="phoenix", + system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", + roles=("Human", "Assistant"), + messages=(), + offset=0, + sep_style=SeparatorStyle.PHOENIX, + sep="", + ) +) + +# ChatGPT default template +register_conv_template( + Conversation( + name="chatgpt", + system="You are a helpful assistant.", + roles=("user", "assistant"), + messages=(), + offset=0, + sep_style=None, + sep=None, + ) +) + +# Claude default template +register_conv_template( + Conversation( + name="claude", + system="", + roles=("Human", "Assistant"), + messages=(), + offset=0, + sep_style=SeparatorStyle.ADD_COLON_SINGLE, + sep="\n\n", + ) +) + +# MPT default template +register_conv_template( + Conversation( + name="mpt", + system="""<|im_start|>system +- You are a helpful assistant chatbot trained by MosaicML. +- You answer questions. +- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. +- You are more than just an information source, you are also able to write poetry, short stories, and make jokes. +""", + roles=("<|im_start|>user", "<|im_start|>assistant"), + messages=(), + offset=0, + sep_style=SeparatorStyle.NEW_LINE, + sep="<|im_end|>", + stop_token_ids=[50278, 0], + ) +) + +# Bard default template +# Reference: https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L150 +# https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L40 +register_conv_template( + Conversation( + name="bard", + system="", + roles=("0", "1"), + messages=(), + offset=0, + sep_style=None, + sep=None, + ) +) + +# BiLLa default template +register_conv_template( + Conversation( + name="billa", + system="", + roles=("Human", "Assistant"), + messages=(), + offset=0, + sep_style=SeparatorStyle.BILLA, + sep="\n", + stop_str="Human:", + ) +) + +# custom otter template +register_conv_template( + Conversation( + name='otter', + system='', + roles=('User:', 'GPT:'), + messages=(), + offset=0, + sep_style=SeparatorStyle.ADD_SPACE_TWO, + sep=' ', + sep2='<|endofchunk|>', + ) +) + +if __name__ == "__main__": + conv = get_conv_template("vicuna_v1.1") + conv.append_message(conv.roles[0], "Hello!") + conv.append_message(conv.roles[1], "Hi!") + conv.append_message(conv.roles[0], "How are you?") + conv.append_message(conv.roles[1], None) + print(conv.get_prompt()) \ No newline at end of file diff --git a/mllm/dataset/__init__.py b/mllm/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57680bc50dba8deaa8bf7515efe517e94739f347 --- /dev/null +++ b/mllm/dataset/__init__.py @@ -0,0 +1,7 @@ +from .root import * +from .utils import * +from .process_function import * +from .single_image_convsation import * +from .single_image_dataset import * + +from .builder import prepare_data diff --git a/mllm/dataset/__pycache__/__init__.cpython-310.pyc b/mllm/dataset/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fb564284462a9e517193618a80c73bb3b2d27cc Binary files /dev/null and b/mllm/dataset/__pycache__/__init__.cpython-310.pyc differ diff --git a/mllm/dataset/__pycache__/builder.cpython-310.pyc b/mllm/dataset/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2690a2c039b856e46480dd024a41affea47e47f Binary files /dev/null and b/mllm/dataset/__pycache__/builder.cpython-310.pyc differ diff --git a/mllm/dataset/__pycache__/root.cpython-310.pyc b/mllm/dataset/__pycache__/root.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..700f5d72bc2e8954a7cfd6573116a45fd9740979 Binary files /dev/null and b/mllm/dataset/__pycache__/root.cpython-310.pyc differ diff --git a/mllm/dataset/__pycache__/single_image_convsation.cpython-310.pyc b/mllm/dataset/__pycache__/single_image_convsation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12cf385f1a5caf801c32a8bfdc51118655f99222 Binary files /dev/null and b/mllm/dataset/__pycache__/single_image_convsation.cpython-310.pyc differ diff --git a/mllm/dataset/__pycache__/single_image_interactive.cpython-310.pyc b/mllm/dataset/__pycache__/single_image_interactive.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..500d76e6d735c0941460d71b795059008cfd1e2e Binary files /dev/null and b/mllm/dataset/__pycache__/single_image_interactive.cpython-310.pyc differ diff --git a/mllm/dataset/builder.py b/mllm/dataset/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..fd25c33a3973048d97a8f19247d5891e921c7ad0 --- /dev/null +++ b/mllm/dataset/builder.py @@ -0,0 +1,118 @@ +from functools import partial +from typing import Callable, Dict, Tuple, Any, Optional + +from torch.utils.data import Dataset +from transformers import EvalPrediction, TrainingArguments + +from .root import DATASETS, METRICS, TRANSFORMS, FUNCTIONS +from .single_image_convsation import SingleImageConvDataset +from .single_image_interactive import SingleImageInteractive +from ..conversation import get_conv_template +from .utils import init_ceph_client_if_needed + +DatasetDict = Dict[str, Dataset] +ComputeMetrics = Callable[[EvalPrediction], Dict] + + +def prepare_data( + data_args, + model_args, + training_args: TrainingArguments, + preprocessor: Dict[str, Any], +) -> Tuple[DatasetDict, Optional[ComputeMetrics]]: + # raw dataset + datasets = { + 'train': partial(DATASETS.build, data_args.train) if training_args.do_train else None, + 'validation': partial(DATASETS.build, data_args.validation) if training_args.do_eval else None, + 'test': partial(DATASETS.build, data_args.test) if training_args.do_predict else None, + } + # compute metric + compute_metric_cfg = data_args.get('compute_metric', None) + compute_metrics = build_compute_metric(compute_metric_cfg, preprocessor) + # conv dataset wrap + conv_args = model_args.conv_args + tokenize_kwargs = conv_args.get('tokenize_kwargs', {}) + conv_template = conv_args.get('conv_template', 'vicuna_v1.1') + conv_template = partial(get_conv_template, name=conv_template) + transforms = conv_args.get('transforms', None) + if transforms is not None: + transforms = TRANSFORMS.build(transforms) + # process func + process_func = {} + for k, v in model_args.process_func_args.items(): + process_func[k] = FUNCTIONS.build(cfg=v) + + conv_dataset_cls = partial( + SingleImageConvDataset, + preprocessor=preprocessor, + process_func=process_func, + tokenize_kwargs=tokenize_kwargs, + conv_template=conv_template, + training_args=training_args, + transforms=transforms, + ) + ds = { + 'train': conv_dataset_cls(dataset_generator=datasets['train'], mode='train') if datasets['train'] is not None else None, + 'validation': conv_dataset_cls(dataset_generator=datasets['validation'], mode='validation') if datasets['validation'] is not None else None, + 'test': conv_dataset_cls(dataset_generator=datasets['test'], mode='test') if datasets['test'] is not None else None, + } + + # multi test set + if hasattr(data_args, 'multitest') and bool(data_args.multitest) \ + and hasattr(training_args, 'do_multi_predict') and training_args.do_multi_predict: + print(f"processing multitest set") + k2v = {} + for k, item in data_args.multitest.items(): + _dataset_cls = partial(DATASETS.build, item['cfg']) + _compute_metric = build_compute_metric(item['compute_metric'], preprocessor) + k2v[k] = { + "dataset": conv_dataset_cls(dataset_generator=_dataset_cls, mode='test'), + "compute_metric": _compute_metric + } + ds['multitest'] = k2v + print(f"processing multitest set. done.") + + # in default, ceph client do init at the beginning of program. + # importantly, before dataloader worker fork. + lazy_init = data_args.get('lazy_init', True) + if not lazy_init: + init_ceph_client_if_needed() + return ds, compute_metrics + + +def build_compute_metric(compute_metric_cfg, preprocessor): + if compute_metric_cfg is not None: + compute_metric_cfg = dict(compute_metric_cfg) # copy cfg because we modify it + compute_metric_cfg.update(dict(preprocessor=preprocessor)) + compute_metrics = METRICS.build(cfg=compute_metric_cfg) + else: + compute_metrics = None + return compute_metrics + + +def prepare_interactive( + model_args, + preprocessor: Dict[str, Any], +): + conv_args = model_args.conv_args + tokenize_kwargs = conv_args.get('tokenize_kwargs', {}) + conv_template = conv_args.get('conv_template', 'vicuna_v1.1') + conv_template = partial(get_conv_template, name=conv_template) + transforms = conv_args.get('transforms', None) + if transforms is not None: + transforms = TRANSFORMS.build(transforms) + # process func + process_func = {} + for k, v in model_args.process_func_args.items(): + process_func[k] = FUNCTIONS.build(cfg=v) + + ds = SingleImageInteractive( + preprocessor=preprocessor, + process_func=process_func, + tokenize_kwargs=tokenize_kwargs, + conv_template=conv_template, + training_args=None, + transforms=transforms, + mode='test', + ) + return ds diff --git a/mllm/dataset/process_function/__init__.py b/mllm/dataset/process_function/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bafbed5a43f773e31e38b9bbf9d7c2a88ee83844 --- /dev/null +++ b/mllm/dataset/process_function/__init__.py @@ -0,0 +1,13 @@ +from .shikra_process_function import ( + ShikraConvProcess, + ShikraImageProcessor, + ShikraTextProcess, +) + +from .box_process_function import ( + BoxFormatProcess, + BoxFormatter, + PlainBoxFormatter, + TokenFormatter, + prepare_target_processor, +) diff --git a/mllm/dataset/process_function/__pycache__/__init__.cpython-310.pyc b/mllm/dataset/process_function/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1270e226c8e56b7c81dfedc4c478601a8cd0a5ad Binary files /dev/null and b/mllm/dataset/process_function/__pycache__/__init__.cpython-310.pyc differ diff --git a/mllm/dataset/process_function/__pycache__/box_process_function.cpython-310.pyc b/mllm/dataset/process_function/__pycache__/box_process_function.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22b21f49989c1729138b214d7f96b01d5b76009c Binary files /dev/null and b/mllm/dataset/process_function/__pycache__/box_process_function.cpython-310.pyc differ diff --git a/mllm/dataset/process_function/__pycache__/shikra_process_function.cpython-310.pyc b/mllm/dataset/process_function/__pycache__/shikra_process_function.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd22af471059aaf8e60db7af3a3f36f7cc93fd49 Binary files /dev/null and b/mllm/dataset/process_function/__pycache__/shikra_process_function.cpython-310.pyc differ diff --git a/mllm/dataset/process_function/box_process_function.py b/mllm/dataset/process_function/box_process_function.py new file mode 100644 index 0000000000000000000000000000000000000000..df654c3b70e9758c5960f291aa1cb7b49d670585 --- /dev/null +++ b/mllm/dataset/process_function/box_process_function.py @@ -0,0 +1,326 @@ +import re +import sys +import logging +import typing +from typing import List, Dict, Any, Tuple, Union + +from ..utils.transform import norm_box_xyxy, norm_point_xyxy + +from ..root import ( + FUNCTIONS, + BaseTargetProcessFunc, + BOXES_PLACEHOLDER, + BOXES_PROCESSOR, + POINTS_PLACEHOLDER, +) + +from ...utils import smart_tokenizer_and_embedding_resize + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout), ], +) + +Box = List[Union[float, int]] +Boxes = List[Box] +BoxesSeq = List[Boxes] + + +@FUNCTIONS.register_module() +class BoxFormatProcess(BaseTargetProcessFunc): + def __call__(self, raw_conv: List[Dict[str, Any]], target: Dict[str, Any], preprocessor: Dict[str, Any], + multimage_mode=False) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + box_formatter = preprocessor['target']['boxes'] + + if multimage_mode: + target = typing.cast(list, target) + outer_normalized_boxes = [] + for tgt in target: + normalized_boxes = [] + if tgt is not None and 'boxes' in tgt: + for box in tgt['boxes']: + normalized_boxes.append( + norm_box_xyxy(box, w=tgt['width'], h=tgt['height']) + ) + outer_normalized_boxes.append(normalized_boxes) + normalized_boxes = outer_normalized_boxes + outer_normalized_points = [] + for tgt in target: + normalized_points = [] + if tgt is not None and 'boxes' in tgt: + for box in tgt['boxes']: + normalized_points.append( + norm_box_xyxy(box, w=tgt['width'], h=tgt['height']) + ) + outer_normalized_points.append(normalized_points) + normalized_points = outer_normalized_points + else: + # normalize target + normalized_boxes = [] + if target is not None and 'boxes' in target: + for box in target['boxes']: + normalized_boxes.append( + norm_box_xyxy(box, w=target['width'], h=target['height']) + ) + normalized_points = [] + if target is not None and 'points' in target: + for point in target['points']: + normalized_points.append( + norm_point_xyxy(point, w=target['width'], h=target['height']) + ) + + # convert bboxes_seq + for sentence in raw_conv: + words: str = sentence['value'] + boxes_seq: List[List[int]] = sentence.get('boxes_seq', None) + if boxes_seq is not None: + # map box seq + boxes_seq: List[Boxes] = map_obj(normalized_boxes, boxes_seq) + # reformat; replace placeholder + converted = box_formatter(words, boxes_seq) + words = converted + points_seq: List[List[int]] = sentence.get('points_seq', None) + if points_seq is not None: + # map point seq + points_seq: List[Boxes] = map_obj(normalized_points, points_seq) + # reformat; replace placeholder + converted = box_formatter.call_on_point(words, points_seq) + words = converted + if boxes_seq is not None or points_seq is not None: + sentence['raw_value'] = sentence['value'] + sentence['value'] = words + return raw_conv, target + + +def map_obj(boxes_value: List[List[float]], boxes_seq: List[List[int]]) -> List[List[List[float]]]: + """ + >>> normalized_boxes = [[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2], [0.3, 0.3, 0.3, 0.3]] + >>> boxes_seq_ = [[3, 1], [2]] + >>> var = map_obj(normalized_boxes, boxes_seq_) + >>> assert var == [[[0.3,0.3,0.3,0.3], [0.1,0.1,0.1,0.1]], [0.2,0.2,0.2,0.2]] + """ + try: + ret = [] + for boxes in boxes_seq: + boxes_ret = [] + for box_index in boxes: + if isinstance(box_index, (list, tuple)): + boxes_ret.append(boxes_value[box_index[0]][box_index[1]]) + else: + boxes_ret.append(boxes_value[box_index]) + ret.append(boxes_ret) + return ret + except: + raise SystemExit(f"error: map obj {boxes_value} {boxes_seq}") + + +class BoxFormatter: + def __init__(self, bboxes_token=BOXES_PLACEHOLDER, points_token=POINTS_PLACEHOLDER): + self.bboxes_token = bboxes_token + self.points_token = points_token + # normally the bboxes_token_pat is the same as bboxes_token if u not use some weird token + self.bboxes_token_pat = re.compile(bboxes_token) + self.points_token_pat = re.compile(points_token) + + def __call__(self, sentence: str, bboxes_seq: BoxesSeq) -> str: + all_box = self.bboxes_token_pat.findall(sentence) + assert len(all_box) == len(bboxes_seq), f"not match. sentence: {sentence}. boxes:{bboxes_seq}" + if len(all_box) == 0: + return sentence + bboxes_strs = [self.format_box(bboxes) for bboxes in bboxes_seq] + converted = sentence.replace(self.bboxes_token, '{}').format(*bboxes_strs) + return converted + + def call_on_point(self, sentence: str, points_seq: BoxesSeq) -> str: + all_box = self.points_token_pat.findall(sentence) + assert len(all_box) == len(points_seq), f"not match. sentence: {sentence}. boxes:{points_seq}" + if len(all_box) == 0: + return sentence + bboxes_strs = [self.format_point(bboxes) for bboxes in points_seq] + converted = sentence.replace(self.points_token, '{}').format(*bboxes_strs) + return converted + + def format_point(self, points) -> str: + raise NotImplementedError + + def format_box(self, bboxes: Boxes) -> str: + raise NotImplementedError + + def extract(self, string: str) -> List[Boxes]: + raise NotImplementedError + + def extract_point(self, string: str) -> List[Boxes]: + raise NotImplementedError + + +@BOXES_PROCESSOR.register_module() +class PlainBoxFormatter(BoxFormatter): + + def __init__(self, *args, precision=3, use_small_brackets=False, **kwargs): + super().__init__(*args, **kwargs) + self.precision = precision + self.use_small_brackets = use_small_brackets + + small_brackets_pat = re.compile(r'\(\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3}(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3})*\)') + small_brackets_point_pat = re.compile(r'\(\d(?:\.\d*)?(?:,\d(?:\.\d*)?)(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?))*\)') + + middle_brackets_pat = re.compile(r'\[\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3}(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3})*\]') + middle_brackets_point_pat = re.compile(r'\[\d(?:\.\d*)?(?:,\d(?:\.\d*)?)(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?))*\]') + + self.pat = small_brackets_pat if use_small_brackets else middle_brackets_pat + self.point_pat = small_brackets_point_pat if use_small_brackets else middle_brackets_point_pat + + def format_box(self, boxes: Boxes) -> str: + box_strs = [] + for box in boxes: + box_strs.append(','.join([f"{elem:.{self.precision}f}" for elem in box])) + box_str = ';'.join(box_strs) + if self.use_small_brackets: + return "(" + box_str + ")" + return "[" + box_str + "]" + + def format_point(self, points) -> str: + return self.format_box(points) + + def extract(self, string: str) -> List[Boxes]: + """ balabalabalabala -> [boxes, boxes] """ + ret = [] + for bboxes_str in self.pat.findall(string): + bboxes = [] + bbox_strs = bboxes_str.replace("(", "").replace(")", "").replace("[", "").replace("]", "").split(";") + for bbox_str in bbox_strs: + bbox = list(map(float, bbox_str.split(','))) + bboxes.append(bbox) + ret.append(bboxes) + return ret + + def extract_point(self, string: str) -> List[Boxes]: + """ balabalabalabala -> [boxes, boxes] """ + ret = [] + for bboxes_str in self.point_pat.findall(string): + bboxes = [] + bbox_strs = bboxes_str.replace("(", "").replace(")", "").replace("[", "").replace("]", "").split(";") + for bbox_str in bbox_strs: + bbox = list(map(float, bbox_str.split(','))) + bboxes.append(bbox) + ret.append(bboxes) + return ret + + +@BOXES_PROCESSOR.register_module() +class TokenFormatter(BoxFormatter): + + def __init__(self, num_bins=1001): + super().__init__() + self.extract_box_pat = re.compile(r'(?:){3}(?:(?:){3})*') + self.extract_point_pat = re.compile(r'(?:){1}(?:(?:){1})*') + self.num_bins = num_bins + self.use_sep = True + self.use_begin_end = True + + self.box_begin = '' + self.box_sep = '' + self.box_end = '' + + self.point_begin = '' + self.point_sep = '' + self.point_end = '' + + def format_point(self, points) -> str: + final_str = [] + for bbox in points: + quant_x0 = "".format(round((bbox[0] * (self.num_bins - 1)))) + quant_y0 = "".format(round((bbox[1] * (self.num_bins - 1)))) + region_coord = "{} {}".format(quant_x0, quant_y0) + final_str.append(region_coord) + if self.use_sep: + final_str = self.point_sep.join(final_str) + else: + final_str = ''.join(final_str) + if self.use_begin_end: + final_str = self.point_begin + final_str + self.point_end + return final_str + + def format_box(self, bboxes: Boxes) -> str: + final_str = [] + for bbox in bboxes: + quant_x0 = "".format(round((bbox[0] * (self.num_bins - 1)))) + quant_y0 = "".format(round((bbox[1] * (self.num_bins - 1)))) + quant_x1 = "".format(round((bbox[2] * (self.num_bins - 1)))) + quant_y1 = "".format(round((bbox[3] * (self.num_bins - 1)))) + region_coord = "{} {} {} {}".format(quant_x0, quant_y0, quant_x1, quant_y1) + final_str.append(region_coord) + if self.use_sep: + final_str = self.box_sep.join(final_str) + else: + final_str = ''.join(final_str) + if self.use_begin_end: + final_str = self.box_begin + final_str + self.box_end + return final_str + + def extract(self, string: str) -> List[Boxes]: + ret = [] + for bboxes_str in self.extract_box_pat.findall(string.replace(" ", "")): + bboxes = [] + bbox_strs = bboxes_str.replace(self.box_begin, "").replace(self.box_end, "").split(self.box_sep) + for bbox_str in bbox_strs: + elems = list(map(int, re.findall(r'', bbox_str))) + bbox = [elem / (self.num_bins - 1) for elem in elems] + bboxes.append(bbox) + ret.append(bboxes) + return ret + + def extract_point(self, string: str) -> List[Boxes]: + ret = [] + for bboxes_str in self.extract_point_pat.findall(string): + bboxes = [] + bbox_strs = bboxes_str.replace(self.point_begin, "").replace(self.point_end, "").split(self.point_sep) + for bbox_str in bbox_strs: + elems = list(map(int, re.findall(r'', bbox_str))) + bbox = [elem / (self.num_bins - 1) for elem in elems] + bboxes.append(bbox) + ret.append(bboxes) + return ret + + def post_process_model_tokenizer(self, model, preprocessor, model_args, training_args): + tokenizer = preprocessor['text'] + + additional_special_tokens = [ + self.box_begin, self.box_sep, self.box_end, + self.point_begin, self.point_sep, self.point_end, + ] + for i in range(self.num_bins): + additional_special_tokens.append(f'') + + smart_tokenizer_and_embedding_resize( + {'additional_special_tokens': additional_special_tokens}, + tokenizer, + model, + ) + return model, preprocessor + + +# FIXME: merge into load_pretrained +def prepare_target_processor( + model, # multimodal llm + preprocessor: Dict[str, Any], + model_args, + training_args, +): + if not hasattr(model_args, 'target_processor'): + return model, preprocessor + + target_processor = {} + if 'boxes' in model_args['target_processor']: + boxes_cfg = model_args['target_processor']['boxes'] + boxes_processor = BOXES_PROCESSOR.build(boxes_cfg) + target_processor['boxes'] = boxes_processor + if hasattr(boxes_processor, "post_process_model_tokenizer"): + model, preprocessor = boxes_processor.post_process_model_tokenizer( + model, preprocessor, model_args, training_args, + ) + preprocessor['target'] = target_processor + return model, preprocessor diff --git a/mllm/dataset/process_function/shikra_process_function.py b/mllm/dataset/process_function/shikra_process_function.py new file mode 100644 index 0000000000000000000000000000000000000000..155829eb594368e101f8d92bb304f0e8e9e0616c --- /dev/null +++ b/mllm/dataset/process_function/shikra_process_function.py @@ -0,0 +1,178 @@ +import sys +import copy +import warnings +import logging +from typing import Dict, Any, List + +import PIL.Image +import torch +from PIL import Image +from transformers import LlamaTokenizer + +from ..root import ( + FUNCTIONS, + IMAGE_PLACEHOLDER, + BaseImageProcessFunc, + BaseConvProcessFunc, + BaseTextProcessFunc, +) +from ...conversation import SeparatorStyle, Conversation + +IGNORE_INDEX = -100 +DEFAULT_IMAGE_TOKEN = IMAGE_PLACEHOLDER +DEFAULT_IMAGE_PATCH_TOKEN = "" +DEFAULT_IM_START_TOKEN = "" +DEFAULT_IM_END_TOKEN = "" + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout), ], +) + + +@FUNCTIONS.register_module() +class ShikraConvProcess(BaseConvProcessFunc): + def __call__(self, raw_conv: List[Dict[str, Any]], preprocessor: Dict[str, Any], conv_template: Conversation) -> List[Dict[str, Any]]: + conv_processor_cfg = preprocessor['conv'] + + image_token_len = conv_processor_cfg['image_token_len'] + sep_image_conv_front = conv_processor_cfg.get('sep_image_conv_front', False) + use_im_start_end = conv_processor_cfg.get('use_im_start_end', False) + # assert DEFAULT_IMAGE_PATCH_TOKEN in preprocessor['text'].get_vocab() + # if use_im_start_end: + # assert DEFAULT_IM_START_TOKEN in preprocessor['text'].get_vocab() + # assert DEFAULT_IM_END_TOKEN in preprocessor['text'].get_vocab() + + if sep_image_conv_front: + raw_conv[0]['value'] = raw_conv[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() + raw_conv[0]['value'] = DEFAULT_IMAGE_TOKEN + conv_template.sep + conv_template.roles[0] + ": " + raw_conv[0]['value'] + for sentence in raw_conv: + replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + if use_im_start_end: + replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN + sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) + + return raw_conv + + +@FUNCTIONS.register_module() +class ShikraTextProcess(BaseTextProcessFunc): + + def __call__(self, conv: Conversation, preprocessor: Dict[str, Any], mode: str, **tokenize_kwargs) -> Dict[str, Any]: + tokenizer = preprocessor['text'] + assert isinstance(tokenizer, LlamaTokenizer), "only work for LlamaTokenizer" + + _truncation_size = tokenize_kwargs.pop('truncation_size', None) + _kwargs = {'return_tensors': 'pt'} + _kwargs.update(tokenize_kwargs) + + if conv.sep_style == SeparatorStyle.ADD_COLON_TWO: + if mode in ['train']: + ret = self.tk_conv_colon_two_train(conv, tokenizer, **_kwargs) + else: + ret = self.tk_conv_colon_two_eval(conv, tokenizer, **_kwargs) + else: + raise ValueError(f"unrecognized conv_style: {conv.sep_style}.\n the conv is {conv}") + + if _truncation_size is None: + return ret + if len(ret['input_ids']) <= _truncation_size: + return ret + + origin_len = len(ret['input_ids']) + ids_to_remove_num = origin_len - _truncation_size + # truncation. should carefully not truncate + ids_should_not_remove = list(map( + tokenizer.convert_tokens_to_ids, + (DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN) + )) + back_no_image = all(ids not in ids_should_not_remove for ids in ret['input_ids'][_truncation_size:]) + if back_no_image: + tgt_ids = list(range(_truncation_size)) + else: + ids_to_remove = set() + for idx in range(origin_len - 1, -1, -1): + if ret['input_ids'][idx] not in ids_should_not_remove: + ids_to_remove.add(idx) + if len(ids_to_remove) >= ids_to_remove_num: + break + tgt_ids = [_ for _ in range(origin_len) if _ not in ids_to_remove] + logger.warning(f"truncate sample size from {origin_len} to {len(tgt_ids)}.") + assert len(tgt_ids) == _truncation_size, f"{len(tgt_ids)}, {_truncation_size}, {ret['input_ids'].tolist()}" + truncated_ret = {k: v[tgt_ids] for k, v in ret.items()} + return truncated_ret + + # noinspection PyMethodMayBeStatic + def tk_conv_colon_two_train(self, conv, tokenizer, **kwargs): + conversation = conv.get_prompt() + input_ids = tokenizer([conversation, ], **kwargs).input_ids[0] + target = copy.deepcopy(input_ids) + assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO + # Mask targets + sep = conv.sep + conv.roles[1] + ": " + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + rounds = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_INDEX + for i, rou in enumerate(rounds): + if rou == "": + break + parts = rou.split(sep) + if len(parts) != 2: + break + parts[0] += sep + round_len = len(tokenizer(rou).input_ids) + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 # + target[cur_len: cur_len + instruction_len] = IGNORE_INDEX + cur_len += round_len + target[cur_len:] = IGNORE_INDEX + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_INDEX + warnings.warn(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}. (ignored):\n{conversation}") + return dict( + input_ids=input_ids, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + labels=target, + ) + + # noinspection PyMethodMayBeStatic + def tk_conv_colon_two_eval(self, conv, tokenizer, **kwargs): + assert len(conv.messages) >= 2 + # target = conv.messages[-1][-1] + target = conv.get_prompt() + + conv.messages[-1][-1] = "" + conversation = conv.get_prompt() + input_ids = tokenizer([conversation, ], **kwargs).input_ids[0] + + target = tokenizer([target, ], add_special_tokens=False, **kwargs).input_ids[0] + target[target == tokenizer.pad_token_id] = IGNORE_INDEX + return dict( + input_ids=input_ids, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + labels=target, + ) + + +@FUNCTIONS.register_module() +class ShikraImageProcessor(BaseImageProcessFunc): + def __call__(self, image: Image.Image, preprocessor: Dict[str, Any]) -> Dict[str, Any]: + image_processor = preprocessor['image'] + + if isinstance(image, (list, tuple)): + image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'] + assert False, 'Shikra not support MultiImage' + elif isinstance(image, PIL.Image.Image): + image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] + else: + if hasattr(image_processor, 'crop_size'): + crop_size = image_processor.crop_size + height, width = crop_size['height'], crop_size['width'] + else: + raise ValueError("got empty image. and don't know how to pad") + image = torch.zeros(3, height, width) + return {'image': image} diff --git a/mllm/dataset/root.py b/mllm/dataset/root.py new file mode 100644 index 0000000000000000000000000000000000000000..6ddbd0bcc5cdd8051497ccee90ef1d62179d94e9 --- /dev/null +++ b/mllm/dataset/root.py @@ -0,0 +1,67 @@ +from typing import Dict, Any, List, Tuple + +from PIL import Image +from mmengine import DATASETS, TRANSFORMS, METRICS, FUNCTIONS, Registry + +from ..conversation import Conversation + +IMAGE_PLACEHOLDER = '' +BOXES_PLACEHOLDER = '' +EXPR_PLACEHOLDER = '' +OBJS_PLACEHOLDER = '' +QUESTION_PLACEHOLDER = '' +POINTS_PLACEHOLDER = '' +# processor +BOXES_PROCESSOR = Registry('Processor for Boxes') + + +# only for static type checking +class BaseConvProcessFunc: + def __call__( + self, + raw_conv: List[Dict[str, Any]], + preprocessor: Dict[str, Any], + conv_template: Conversation, + ) -> List[Dict[str, Any]]: + raise NotImplementedError + + +class BaseTargetProcessFunc: + def __call__( + self, + raw_conv: List[Dict[str, Any]], + target: Dict[str, Any], + preprocessor: Dict[str, Any], + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + raise NotImplementedError + + +class BaseTextProcessFunc: + def __call__( + self, + conv: Conversation, + preprocessor: Dict[str, Any], + mode: str, + **tokenize_kwargs, + ) -> Dict[str, Any]: + raise NotImplementedError + + +class BaseImageProcessFunc: + def __call__( + self, + image: Image.Image, + preprocessor: Dict[str, Any], + ) -> Dict[str, Any]: + raise NotImplementedError + + +__all__ = [ + 'IMAGE_PLACEHOLDER', 'BOXES_PLACEHOLDER', 'EXPR_PLACEHOLDER', 'OBJS_PLACEHOLDER', 'QUESTION_PLACEHOLDER', 'POINTS_PLACEHOLDER', + 'FUNCTIONS', + 'DATASETS', + 'TRANSFORMS', + 'METRICS', + 'BOXES_PROCESSOR', + 'BaseConvProcessFunc', 'BaseTargetProcessFunc', 'BaseTextProcessFunc', 'BaseImageProcessFunc', +] diff --git a/mllm/dataset/single_image_convsation.py b/mllm/dataset/single_image_convsation.py new file mode 100644 index 0000000000000000000000000000000000000000..dd7c09a91b62eea9272906cb0fb8a0b73fa7c938 --- /dev/null +++ b/mllm/dataset/single_image_convsation.py @@ -0,0 +1,284 @@ +import warnings +from functools import partial +from typing import Dict, Any, Callable, List, Optional, Tuple, Type + +import torch +from PIL import Image +from torch.utils.data import Dataset +from transformers import TrainingArguments + +from .root import IMAGE_PLACEHOLDER, BOXES_PLACEHOLDER +from ..conversation import Conversation, get_conv_template +from ..utils import post_process_generate_ids + + +class SingleImageConvDatasetMixin: + + def __init__( + self, + *args, + preprocessor: Dict[str, Any], + process_func: Dict[str, Any], + conv_template: Callable[[], Conversation] = partial(get_conv_template, name='vicuna_v1.1'), + mode='train', + tokenize_kwargs: dict = None, + training_args: TrainingArguments = None, + transforms: Optional[Callable] = None, + **kwargs, + ): + super().__init__(*args, **kwargs) + assert mode in ['train', 'validation', 'test'] + + self.preprocessor = preprocessor + self.process_func = process_func + self.conv_template = conv_template + self.mode = mode + self.tokenize_kwargs = tokenize_kwargs if tokenize_kwargs is not None else {} + self.training_args = training_args + self.transforms = transforms + + def __getitem__(self, index, debug_mode=False, return_conv=False) -> Dict[str, Any]: + # getitem + item = self.get_raw_item(index) + image: Image.Image = item.get('image', None) + target: Dict[str, Any] = item.get('target', None) + raw_conv: List[Dict[str, Any]] = item['conversations'] + + # transform + assert isinstance(image, list) == isinstance(target, list) + multimage_mode = isinstance(image, list) + if isinstance(image, list): + # TODO: validate raw item + transformed_image, transformed_target = [], [] + for img, tgt in zip(image, target): + if self.transforms is not None and image is not None: + img, tgt = self.transforms(img, tgt) + if tgt is not None: + tgt['width'], tgt['height'] = img.width, img.height + transformed_image.append(img) + transformed_target.append(tgt) + image, target = transformed_image, transformed_target + else: + self.validate_raw_item(item) # only validate for single image. + if self.transforms is not None and image is not None: + image, target = self.transforms(image, target) + has_image = 'image' in item and bool(item['image']) + has_target = 'target' in item and bool(item['target']) and any(bool(elem) for elem in item['target'].values()) + if has_target and has_image: + target['width'], target['height'] = image.width, image.height + + # preprocess + raw_conv = self.process_conv(raw_conv) + raw_conv, image = self.process_conv_multimage(raw_conv, image) + raw_conv, _ = self.process_target(raw_conv, target, multimage_mode=multimage_mode) + conv = self.build_conv(raw_conv) + if return_conv: + # noinspection PyTypeChecker + return conv + text_dict = self.process_text(conv) + image_dict = self.process_image(image) + + # return + ret_dict = {} + ret_dict.update(text_dict) + ret_dict.update(image_dict) + self._print_sample(ret_dict, raw_conv, conv) + if debug_mode: + return {'ret': ret_dict, 'raw_conv': raw_conv, 'conv': conv, 'image': image} + return ret_dict + + def __len__(self): + raise NotImplementedError + + # noinspection PyMethodMayBeStatic + def process_conv_multimage(self, raw_conv, image): + # re-sort multi image + if image is None: + return raw_conv, image + if not isinstance(image, (list, tuple)): + return raw_conv, image + image_seqs = [] + for conv in raw_conv: + image_seqs.extend(conv['image_seq'] if 'image_seq' in conv else []) + images = [] + for idx in image_seqs: + images.append(image[idx]) + return raw_conv, images + + def get_raw_item(self, index) -> Dict[str, Any]: + """ + return item format like this. + item = { + 'image': # PIL.Image.Image, + 'target': { + # xmin, ymin, xmax, ymax + 'boxes': [ + [10, 10, 256, 265], # dog1 + [24, 18, 378, 768], # dog2 + [100, 310, 670, 653], # man + [278, 320, 809, 673], # rope + ], + } + + "conversations": [ + { + 'from': 'human', + 'value': 'What is the relation between the two dogs and the man in the image ?', + 'boxes_seq': [[0, 1], [2], ], + }, + { + 'from': 'gpt', + 'value': 'a rope is connecting the left dog with the man . ' + 'So the man is walking the dog .' + 'And the man has no relationship with the right dog ', + 'boxes_seq': [[3], [0], [2], [2], [0], [2], [1]], + } + ] + } + # placeholder: + """ + raise NotImplementedError + + # noinspection PyMethodMayBeStatic + def validate_raw_item(self, item): + has_image = 'image' in item and bool(item['image']) + has_target = 'target' in item and bool(item['target']) and any(bool(elem) for elem in item['target'].values()) + has_target_boxes = 'boxes' in item['target'] if has_target else False + raw_conv: List[Dict[str, Any]] = item['conversations'] + + # check image + human_input_has_image_placeholder = any( + sentence['from'] == 'human' and IMAGE_PLACEHOLDER in sentence['value'] for sentence in raw_conv + ) + if human_input_has_image_placeholder: + assert has_image + if has_image and (not human_input_has_image_placeholder): + warnings.warn(f'item has image but the question has no image placeholder.\n{item}') + gpt_input_has_image_placeholder = any( + sentence['from'] == 'gpt' and IMAGE_PLACEHOLDER in sentence['value'] for sentence in raw_conv + ) + assert not gpt_input_has_image_placeholder + + # check target + has_boxes_placeholder = any( + BOXES_PLACEHOLDER in sentence['value'] for sentence in raw_conv + ) + if has_boxes_placeholder: + assert has_target_boxes + # not check box placeholder num this will be checked in format process + + def build_conv(self, source: List[Dict[str, Any]]) -> Conversation: + conv = self.conv_template() + role_map = {"human": conv.roles[0], "gpt": conv.roles[1]} + assert len(source) > 0 + assert source[0]['from'] == 'human' + for sentence in source: + role = role_map[sentence['from']] + conv.append_message(role, sentence['value']) + return conv + + def process_conv(self, raw_conv: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + some utils preprocess for raw_conv. + e.g. replace placeholder to sequence *256 + """ + return self.process_func['conv'](raw_conv, self.preprocessor, self.conv_template) + + def process_target(self, raw_conv: List[Dict[str, Any]], target: Dict[str, Any], multimage_mode=False) -> Tuple[ + List[Dict[str, Any]], Dict[str, Any]]: + """ + convert target placeholder to actual information in raw_conv. + e.g. normalize bounding boxes; convert bounding boxes format; replace placeholder + """ + return self.process_func['target'](raw_conv, target, self.preprocessor, multimage_mode=multimage_mode) + + def process_text(self, conv: Conversation) -> Dict[str, Any]: + """ + convert Conversation object to torch.Tensor, e.g. input_ids, labels, attention_mask, etc. + self.tokenize_kwargs control something like padding/truncation behavior. + """ + return self.process_func['text'](conv, self.preprocessor, self.mode, **self.tokenize_kwargs) + + def process_image(self, image: Image.Image) -> Dict[str, Any]: + """ + convert Image.Image object to torch.Tensor + """ + return self.process_func['image'](image, self.preprocessor) + + def _print_sample(self, ret_dict, raw_conv, conv): + if not hasattr(self, '_printed_sample'): + self._printed_sample = True + post_processed_labels = post_process_generate_ids(self.preprocessor['text'], ret_dict['labels']) + print(f"=================== {self.mode} sample ===================", flush=True) + print(f" input_ids: {self.preprocessor['text'].convert_ids_to_tokens(ret_dict['input_ids'])}") + print(f" labels: {self.preprocessor['text'].convert_ids_to_tokens(post_processed_labels)}") + print(f"decoded input_ids: {self.preprocessor['text'].decode(ret_dict['input_ids'])}") + print(f"decoded labels: {self.preprocessor['text'].decode(post_processed_labels)}") + if 'image' in ret_dict and ret_dict['image'] is not None: + image = ret_dict['image'] + if isinstance(image, torch.Tensor): + print(f" image: {image.shape}") + elif isinstance(image, dict): + print(f" image: {image.keys()}") + elif isinstance(image, list) and len(image) > 0: + print(f" image: {len(image)}, {type(image[0])}") + else: + print(f" image: {type(image)}") + print("====================================================", flush=True) + try: + if self.training_args is not None: + _save_obj = { + 'ret_dict': ret_dict, + 'raw_conv': raw_conv, + 'conv': conv.get_prompt(), + } + from pathlib import Path + output_dir = Path(self.training_args.output_dir) + output_dir.mkdir(exist_ok=True, parents=True) + _local_rank = self.training_args.local_rank + _word_size = self.training_args.world_size + _file_path = str(output_dir / f'sample_check_{self.mode}_{_local_rank}_{_word_size}.pt') + print(f'saving some sample to {_file_path} for check.') + torch.save(_save_obj, _file_path) + except Exception as e: + warnings.warn(f'try to save samples but get exception: {e.args}. ignored.') + + +class SingleImageConvDataset(SingleImageConvDatasetMixin, Dataset): + _repr_indent = 4 + + def __init__(self, *args, dataset_generator: Type[Dataset], **kwargs): + super().__init__(*args, **kwargs) + self.dataset_generator = dataset_generator + self.dataset = None + + def initialize_if_needed(self): + """ + lazy initialize for big in-memory python object due to python 'copy-on-read' behavior + when num_worker > 0. refer: https://github.com/pytorch/pytorch/issues/13246 + """ + if self.dataset is None: + # warnings.warn("it's highly recommended that set persistent_workers=True, " + # "otherwise this initialize code will run in every epoch beginning." + # "(ignore me if set)") + self.dataset = self.dataset_generator() + + def __len__(self): + self.initialize_if_needed() + return len(self.dataset) + + def get_raw_item(self, index) -> Dict[str, Any]: + self.initialize_if_needed() + return self.dataset[index] + + def __repr__(self) -> str: + head = "Dataset " + self.__class__.__name__ + body = [ + f"Number of datapoints: {self.__len__()}", + ] + body += self.dataset.__repr__().splitlines() + lines = [head] + [" " * self._repr_indent + line for line in body] + return "\n".join(lines) + + +__all__ = ['SingleImageConvDatasetMixin', 'SingleImageConvDataset'] diff --git a/mllm/dataset/single_image_dataset/__init__.py b/mllm/dataset/single_image_dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10ae774affb5a9b6c34e9915679629e5e13d5abe --- /dev/null +++ b/mllm/dataset/single_image_dataset/__init__.py @@ -0,0 +1,13 @@ +from .flickr import FlickrParser, FlickrDataset +from .rec import RECDataset, RECComputeMetrics +from .reg import REGDataset, GCDataset +from .caption import CaptionDataset +from .instr import InstructDataset +from .gqa import GQADataset, GQAComputeMetrics +from .clevr import ClevrDataset +from .point_qa import Point_QA_local, Point_QA_twice, V7W_POINT, PointQAComputeMetrics +from .gpt_gen import GPT4Gen +from .vcr import VCRDataset, VCRPredDataset +from .vqav2 import VQAv2Dataset +from .vqaex import VQAEXDataset +from .pope import POPEVQADataset diff --git a/mllm/dataset/single_image_dataset/__pycache__/__init__.cpython-310.pyc b/mllm/dataset/single_image_dataset/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cd4ded6ced4fbffa49c8809121bd9f12ceb57a5 Binary files /dev/null and b/mllm/dataset/single_image_dataset/__pycache__/__init__.cpython-310.pyc differ diff --git a/mllm/dataset/single_image_dataset/__pycache__/caption.cpython-310.pyc b/mllm/dataset/single_image_dataset/__pycache__/caption.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ea9674ae10b8c176766e9fd5dff57178b2dede5 Binary files /dev/null and b/mllm/dataset/single_image_dataset/__pycache__/caption.cpython-310.pyc differ diff --git a/mllm/dataset/single_image_dataset/__pycache__/clevr.cpython-310.pyc b/mllm/dataset/single_image_dataset/__pycache__/clevr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c6692aa01082a45afc42cb15608ab524e53e7ac Binary files /dev/null and b/mllm/dataset/single_image_dataset/__pycache__/clevr.cpython-310.pyc differ diff --git a/mllm/dataset/single_image_dataset/__pycache__/flickr.cpython-310.pyc b/mllm/dataset/single_image_dataset/__pycache__/flickr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3715df2cc5b1b19194faf290a6ba5205dbbe51cc Binary files /dev/null and b/mllm/dataset/single_image_dataset/__pycache__/flickr.cpython-310.pyc differ diff --git a/mllm/dataset/single_image_dataset/__pycache__/gpt_gen.cpython-310.pyc b/mllm/dataset/single_image_dataset/__pycache__/gpt_gen.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a857a4ba4f8ad1b32a6d2d6ba94bc032c984a2a Binary files /dev/null and b/mllm/dataset/single_image_dataset/__pycache__/gpt_gen.cpython-310.pyc differ diff --git a/mllm/dataset/single_image_dataset/__pycache__/gqa.cpython-310.pyc b/mllm/dataset/single_image_dataset/__pycache__/gqa.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d6af9d9ced3c298de20d421bb9ffc95e26013c0 Binary files /dev/null and b/mllm/dataset/single_image_dataset/__pycache__/gqa.cpython-310.pyc differ diff --git a/mllm/dataset/single_image_dataset/__pycache__/instr.cpython-310.pyc b/mllm/dataset/single_image_dataset/__pycache__/instr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dda5b1c387b1239695b1850f84e24146eb576082 Binary files /dev/null and b/mllm/dataset/single_image_dataset/__pycache__/instr.cpython-310.pyc differ diff --git a/mllm/dataset/single_image_dataset/__pycache__/point_qa.cpython-310.pyc b/mllm/dataset/single_image_dataset/__pycache__/point_qa.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..109374bdd9fc85c21d481e05f7e91eaa0b514dae Binary files /dev/null and b/mllm/dataset/single_image_dataset/__pycache__/point_qa.cpython-310.pyc differ diff --git a/mllm/dataset/single_image_dataset/__pycache__/pope.cpython-310.pyc b/mllm/dataset/single_image_dataset/__pycache__/pope.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ae52b7a28794ae6c80e62bfd2787f1381d0bf11 Binary files /dev/null and b/mllm/dataset/single_image_dataset/__pycache__/pope.cpython-310.pyc differ diff --git a/mllm/dataset/single_image_dataset/__pycache__/rec.cpython-310.pyc b/mllm/dataset/single_image_dataset/__pycache__/rec.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b94433fdad35da8a631f97362eaf89ce4f7df33b Binary files /dev/null and b/mllm/dataset/single_image_dataset/__pycache__/rec.cpython-310.pyc differ diff --git a/mllm/dataset/single_image_dataset/__pycache__/reg.cpython-310.pyc b/mllm/dataset/single_image_dataset/__pycache__/reg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe3cd538826eec1033ee6bcb41f0eff33a6a9f0b Binary files /dev/null and b/mllm/dataset/single_image_dataset/__pycache__/reg.cpython-310.pyc differ diff --git a/mllm/dataset/single_image_dataset/__pycache__/vcr.cpython-310.pyc b/mllm/dataset/single_image_dataset/__pycache__/vcr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..620bcefd3b1b81c4bf9b9f6fc010881e2a1d6fa9 Binary files /dev/null and b/mllm/dataset/single_image_dataset/__pycache__/vcr.cpython-310.pyc differ diff --git a/mllm/dataset/single_image_dataset/__pycache__/vqaex.cpython-310.pyc b/mllm/dataset/single_image_dataset/__pycache__/vqaex.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a554a969e6e08df135dd9189ff48f78981fcd175 Binary files /dev/null and b/mllm/dataset/single_image_dataset/__pycache__/vqaex.cpython-310.pyc differ diff --git a/mllm/dataset/single_image_dataset/__pycache__/vqav2.cpython-310.pyc b/mllm/dataset/single_image_dataset/__pycache__/vqav2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3eeb0d6d558e419c96ec601ec8b8fb5996c05816 Binary files /dev/null and b/mllm/dataset/single_image_dataset/__pycache__/vqav2.cpython-310.pyc differ diff --git a/mllm/dataset/single_image_dataset/caption.py b/mllm/dataset/single_image_dataset/caption.py new file mode 100644 index 0000000000000000000000000000000000000000..5149c9f457caf04b2878f91c38b9cf9a58fb0e06 --- /dev/null +++ b/mllm/dataset/single_image_dataset/caption.py @@ -0,0 +1,31 @@ +from ..root import DATASETS, IMAGE_PLACEHOLDER +from ..utils import MInstrDataset + + +@DATASETS.register_module() +class CaptionDataset(MInstrDataset): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER,)) + + def __getitem__(self, index): + item = self.get_raw_item(index) + img_path = item['img_path'] + caption = item['caption'] + + image = self.get_image(img_path) + question = self.get_template() + + ret = { + 'image': image, + 'conversations': [ + { + 'from': 'human', + 'value': question, + }, + { + 'from': 'gpt', + 'value': caption, + } + ] + } + return ret diff --git a/mllm/dataset/single_image_dataset/clevr.py b/mllm/dataset/single_image_dataset/clevr.py new file mode 100644 index 0000000000000000000000000000000000000000..cc7609a0bcc7fffd736ad8f0cddd6634e84cc872 --- /dev/null +++ b/mllm/dataset/single_image_dataset/clevr.py @@ -0,0 +1,116 @@ +import json + +from ..root import DATASETS, IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER, POINTS_PLACEHOLDER +from ..utils import MInstrDataset + + +@DATASETS.register_module() +class ClevrDataset(MInstrDataset): + def __init__(self, *args, scene_graph_file, version, **kwargs): + super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER)) + self.scene_graph_file = scene_graph_file + self.version = version + qtype, atype = version.split('-') + assert qtype in ['q'] + assert atype in ['a', 's', 'bs'] + self.qtype = qtype + self.atype = atype + + if scene_graph_file is None: + self.scene_graph = None + else: + self.scene_graph = [line for line in open(scene_graph_file, 'r', encoding='utf8')] + + def get_raw_item(self, index): + question = json.loads(self.data[index]) + if self.scene_graph is None: + scene = None + else: + scene = json.loads(self.scene_graph[question['image_index']]) + return question, scene + + def __getitem__(self, index): + question, scene = self.get_raw_item(index) + img_path = question['image_filename'] + image = self.get_image(img_path) + + if self.atype == 'a': + boxes = [] + answer = f"The answer is {question['answer']}." + answer_boxes_seq = [] + elif self.atype == 's': + answer, boxes, answer_boxes_seq = clevr_ss_cot(obj=question, scene=scene, add_ref=False) + answer += f" The answer is {question['answer']}." + elif self.atype == 'bs': + answer, boxes, answer_boxes_seq = clevr_ss_cot(obj=question, scene=scene, add_ref=True) + answer += f" The answer is {question['answer']}." + else: + assert False + + if self.qtype == 'q': + query_boxes_seq = [] + final_query = self.get_template().replace(QUESTION_PLACEHOLDER, question['question']) + else: + assert False + + ret = { + 'image': image, + 'target': {'points': boxes}, + 'conversations': [ + { + 'from': 'human', + 'value': final_query, + 'points_seq': query_boxes_seq, + }, + { + 'from': 'gpt', + 'value': answer, + 'points_seq': answer_boxes_seq, + } + ] + } + return ret + + +def get_boxes_idx(boxes_list, refs): + def get_idx(boxes_list, box): + if box in boxes_list: + return boxes_list.index(box) + else: + boxes_list.append(box) + return len(boxes_list) - 1 + + idx = [get_idx(boxes_list, box) for box in refs] + return idx + + +def clevr_ss_cot(obj, scene, add_ref=False): + cot = [] + boxes = [] + seq = [] + + def can_add_ref(): + if p['function'] in ['unique', 'union', 'intersect', 'relate', 'same_size', 'same_shape', 'same_material', 'same_color']: + return True + if p['function'] in ['scene', 'filter_color', 'filter_material', 'filter_shape', 'filter_size']: + if idx + 1 < len(obj['program']) and obj['program'][idx + 1]['function'] in ['exist', 'count']: + return True + return False + + for idx, p in enumerate(obj['program']): + func = f"{p['function']}:{p['value_inputs'][0]}" if 'value_inputs' in p and p['value_inputs'] else p['function'] + inputs = f"[{','.join(map(str, p['inputs']))}]" if p['inputs'] else "" + + if add_ref and can_add_ref(): + if p['ans']: + objs = POINTS_PLACEHOLDER + idx = get_boxes_idx(boxes_list=boxes, refs=[scene['objects'][_]['pixel_coords'][:2] for _ in p['ans']]) + seq.append(idx) + else: + objs = f" Found no object." + else: + objs = "" + cot.append(f"{func}{inputs}{objs}") + + ret = " -> ".join(cot) + return ret, boxes, seq diff --git a/mllm/dataset/single_image_dataset/flickr.py b/mllm/dataset/single_image_dataset/flickr.py new file mode 100644 index 0000000000000000000000000000000000000000..859ba841d42451e941030cba417eb3a16671065c --- /dev/null +++ b/mllm/dataset/single_image_dataset/flickr.py @@ -0,0 +1,68 @@ +from torch.utils.data import Dataset + +from ..root import DATASETS, BOXES_PLACEHOLDER, IMAGE_PLACEHOLDER +from ..utils import MInstrDataset +from ..utils.flickr30k_entities_utils import ( + flatten_annotation, + PHRASE_ED_PLACEHOLDER, + PHRASE_ST_PLACEHOLDER, +) + + +class FlickrParser(Dataset): + def __init__(self, filename, annotation_dir): + self.filename = filename + self.annotation_dir = annotation_dir + + self.indexes = [line.strip() for line in open(filename, 'r', encoding='utf8')] + self.data = flatten_annotation(self.annotation_dir, self.indexes) + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + return self.data[index] + + def dump(self, filename): + import json + with open(filename, 'w', encoding='utf8') as f: + for obj in self.data: + obj_str = json.dumps(obj) + f.write(obj_str) + f.write('\n') + + +@DATASETS.register_module() +class FlickrDataset(MInstrDataset): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER,)) + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + item = self.get_raw_item(index) + img_path = f"{item['image_id']}.jpg" + caption = item['sentence'] + + image = self.get_image(img_path) + caption = caption.replace(PHRASE_ST_PLACEHOLDER, "").replace(PHRASE_ED_PLACEHOLDER, BOXES_PLACEHOLDER) + question = self.get_template() + + ret = { + 'image': image, + 'target': {'boxes': item['boxes']}, + 'conversations': [ + { + 'from': 'human', + 'value': question, + }, + { + 'from': 'gpt', + 'value': caption, + 'boxes_seq': item['boxes_seq'], + } + ] + } + return ret diff --git a/mllm/dataset/single_image_dataset/gpt_gen.py b/mllm/dataset/single_image_dataset/gpt_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..a2d46121fe6bdd009549fc9d19edab55d158d1f9 --- /dev/null +++ b/mllm/dataset/single_image_dataset/gpt_gen.py @@ -0,0 +1,58 @@ +from ..root import ( + DATASETS, + QUESTION_PLACEHOLDER, + IMAGE_PLACEHOLDER, + BOXES_PLACEHOLDER, +) +from ..utils import MInstrDataset +from ..utils.flickr30k_entities_utils import PHRASE_ST_PLACEHOLDER, PHRASE_ED_PLACEHOLDER + + +@DATASETS.register_module() +class GPT4Gen(MInstrDataset): + def __init__(self, *args, version, **kwargs): + super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER)) + self.version = version + assert version in ['a', 'c', 'bc'] + + def __getitem__(self, item): + raw = self.get_raw_item(item) + # + image = self.get_image(raw['img_path']) + # + boxes = raw['boxes'] + # + question = raw['question'] + question = question.replace(PHRASE_ST_PLACEHOLDER, '').replace(PHRASE_ED_PLACEHOLDER, BOXES_PLACEHOLDER) + final_question = self.get_template().replace(QUESTION_PLACEHOLDER, question) + query_boxes_seq = raw['question_boxes_seq'] + + if self.version == 'a': + final_answer = raw['answer'] + answer_boxes_seq = None + elif self.version == 'c': + final_answer = raw['cot_with_ans'].replace(PHRASE_ST_PLACEHOLDER, '').replace(PHRASE_ED_PLACEHOLDER, '') + answer_boxes_seq = None + elif self.version == 'bc': + final_answer = raw['cot_with_ans'].replace(PHRASE_ST_PLACEHOLDER, '').replace(PHRASE_ED_PLACEHOLDER, BOXES_PLACEHOLDER) + answer_boxes_seq = raw['answer_boxes_seq'] + else: + assert False + + ret = { + 'image': image, + 'target': {'boxes': boxes}, + 'conversations': [ + { + 'from': 'human', + 'value': final_question, + 'boxes_seq': query_boxes_seq, + }, + { + 'from': 'gpt', + 'value': final_answer, + 'boxes_seq': answer_boxes_seq, + } + ] + } + return ret diff --git a/mllm/dataset/single_image_dataset/gqa.py b/mllm/dataset/single_image_dataset/gqa.py new file mode 100644 index 0000000000000000000000000000000000000000..33be7c0de40d5eb165193d76d3f09ab5fed12a6e --- /dev/null +++ b/mllm/dataset/single_image_dataset/gqa.py @@ -0,0 +1,233 @@ +import json +import re + +from ..root import DATASETS, IMAGE_PLACEHOLDER, BOXES_PLACEHOLDER, QUESTION_PLACEHOLDER, METRICS +from ..utils.flickr30k_entities_utils import PHRASE_ST_PLACEHOLDER, PHRASE_ED_PLACEHOLDER +from ..utils import MInstrDataset, BaseComputeMetrics + +REFID_PAT = re.compile(r'(\s\((?:(?:\d+(?:,\d+)*)|-)\)\s?)') +ANS_EXTRACT_PAT = re.compile(r'(?:(?:(?:(?:(?:So t)|(?:T)|(?:t))he answer is)|(?:Answer:)) (.+))') + + +@DATASETS.register_module() +class GQADataset(MInstrDataset): + def __init__( + self, + *args, + scene_graph_file, + scene_graph_index, + version, + question_box_prob=0.5, + **kwargs + ): + super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER)) + self.scene_graph_file = scene_graph_file + self.scene_graph_index = scene_graph_index + self.version = version + self.question_box_prob = question_box_prob + qtype, atype = version.split('-') + assert qtype in ['q', 'qb', 'qbp'] + assert atype in ['a', 'c', 'bc', 's', 'bs', 'l', 'bl'] + self.qtype = qtype + self.atype = atype + + assert bool(scene_graph_file) == bool(scene_graph_index) + if scene_graph_file is not None and scene_graph_index is not None: + self.scene_graph = [line for line in open(scene_graph_file, 'r', encoding='utf8')] + self.scene_index = json.load(open(scene_graph_index, 'r', encoding='utf8')) + else: + self.scene_graph = None + self.scene_index = None + + def get_raw_item(self, index): + question = json.loads(self.data[index]) + if self.scene_graph is None: + return question, None + scene = json.loads(self.scene_graph[self.scene_index[question['imageId']]]) + return question, scene + + def __getitem__(self, index): + question, scene = self.get_raw_item(index) + img_path = f"{question['imageId']}.jpg" + image = self.get_image(img_path) + + # answer + if self.atype == 'bc': + boxes = question['cot']['boxes'] + answer = question['cot']['value'].replace(PHRASE_ST_PLACEHOLDER, "").replace(PHRASE_ED_PLACEHOLDER, BOXES_PLACEHOLDER) + answer_boxes_seq = question['cot']['seq'] + elif self.atype == 'c': + boxes = [] + answer = question['cot']['value'].replace(PHRASE_ST_PLACEHOLDER, "").replace(PHRASE_ED_PLACEHOLDER, "") + answer_boxes_seq = [] + elif self.atype == 'bs': + boxes, bss, answer_boxes_seq = get_bss_example(question, scene) + answer = f"{bss}. The answer is {question['answer']}." + elif self.atype == 's': + boxes = [] + ss = REFID_PAT.sub('', question['semanticStr']) + answer = f"{ss}. The answer is {question['answer']}." + answer_boxes_seq = [] + elif self.atype == 'bl': + boxes, answer, answer_boxes_seq = get_bl_example(question, scene) + elif self.atype == 'l': + boxes = [] + _, answer, _ = get_bl_example(question, scene) + answer = answer.replace(BOXES_PLACEHOLDER, "") + answer_boxes_seq = [] + elif self.atype == 'a': + boxes = [] + answer = f"The answer is {question['answer']}." + answer_boxes_seq = [] + else: + assert False + + # question + if self.qtype == 'q': + boxes, query, query_boxes_seq = prepare_query_dummy(boxes, question, scene) + elif self.qtype == 'qb': + boxes, query, query_boxes_seq = prepare_query_box(boxes, question, scene) + elif self.qtype == 'qbp': + if self.rng.uniform() > self.question_box_prob: + boxes, query, query_boxes_seq = prepare_query_dummy(boxes, question, scene) + else: + boxes, query, query_boxes_seq = prepare_query_box(boxes, question, scene) + else: + assert False + + final_query = self.get_template().replace(QUESTION_PLACEHOLDER, query) + + ret = { + 'image': image, + 'target': {'boxes': boxes}, + 'conversations': [ + { + 'from': 'human', + 'value': final_query, + 'boxes_seq': query_boxes_seq, + }, + { + 'from': 'gpt', + 'value': answer, + 'boxes_seq': answer_boxes_seq, + } + ] + } + return ret + + +def prepare_query_dummy(boxes_list, q, scene): + return boxes_list, q['question'], [] + + +def prepare_query_box(boxes_list, q, scene): + def get_boxes_idx(box): + if box in boxes_list: + return boxes_list.index(box) + else: + boxes_list.append(box) + return len(boxes_list) - 1 + + def add_boxes_by_rids(rids): + def get_box_xyxy(obj): + x, y, w, h = obj['x'], obj['y'], obj['w'], obj['h'] + return x, y, x + w, y + h + + boxes_idx = [] + for rid in rids: + ref = scene['objects'][rid] + ref_box = list(get_box_xyxy(ref)) + boxes_idx.append(get_boxes_idx(ref_box)) + return boxes_idx + + sent = list(q['question'].split()) + query_boxes_seq = [] + for span, rids_str in q['annotations']['question'].items(): + span = tuple(map(int, span.split(':'))) + if len(span) == 1: + span = [span[0], span[0] + 1] + sent[span[1] - 1] = f"{sent[span[1] - 1]}{BOXES_PLACEHOLDER}" + boxes_idx = add_boxes_by_rids(rids_str.split(',')) + query_boxes_seq.append(boxes_idx) + sent_converted = " ".join(sent).strip() + return boxes_list, sent_converted, query_boxes_seq + + +def add_boxes_by_rids(boxes_list, rids, scene): + def get_boxes_idx(boxes_list, box): + if box in boxes_list: + return boxes_list.index(box) + else: + boxes_list.append(box) + return len(boxes_list) - 1 + + def get_box_xyxy(obj): + x, y, w, h = obj['x'], obj['y'], obj['w'], obj['h'] + return x, y, x + w, y + h + + boxes_idx = [] + for rid in rids: + ref = scene['objects'][rid] + ref_box = list(get_box_xyxy(ref)) + boxes_idx.append(get_boxes_idx(boxes_list, ref_box)) + return boxes_idx + + +def get_bss_example(question, scene): + def format_refids(item): + item = item.strip()[1:-1] + return item.split(',') + + s = question['semanticStr'] + print(REFID_PAT.findall(s)) + formats = [] + boxes = [] + seqs = [] + + for item in REFID_PAT.findall(s): + if '-' in item: + formats.append('') + else: + formats.append('') + refids = format_refids(item) + idx = add_boxes_by_rids(boxes, refids, scene) + seqs.append(idx) + answer = REFID_PAT.sub('{}', s).format(*formats) + + print(answer) + print(boxes) + print(seqs) + return boxes, answer, seqs + + +def get_bl_example(ann, scene): + boxes = [] + boxes_seq = [] + + origin_sent = ann['fullAnswer'] + origin_sent = re.sub('(?:^Yes,)|(?:^No,)', '', origin_sent).strip() + sent = list(origin_sent.split()) + for span, rids_str in ann['annotations']['fullAnswer'].items(): + span = tuple(map(int, span.split(':'))) + if len(span) == 1: + span = [span[0], span[0] + 1] + sent[span[1] - 1] = f"{sent[span[1] - 1]}{BOXES_PLACEHOLDER}" + rids = rids_str.split(',') + boxes_idx = add_boxes_by_rids(boxes, rids, scene) + boxes_seq.append(boxes_idx) + + answer = "".join(sent) + answer += f"The answer is {ann['answer']}." + return boxes, answer, boxes_seq + + +@METRICS.register_module() +class GQAComputeMetrics(BaseComputeMetrics): + def extract_ans(self, string: str): + try: + found = ANS_EXTRACT_PAT.findall(string.strip()) + if len(found) != 1: + return None + return found[0].strip().rstrip('.').strip() + except (IndexError, AttributeError): + return None diff --git a/mllm/dataset/single_image_dataset/instr.py b/mllm/dataset/single_image_dataset/instr.py new file mode 100644 index 0000000000000000000000000000000000000000..1954665d14175d0f2331552d70ef5038f838dbae --- /dev/null +++ b/mllm/dataset/single_image_dataset/instr.py @@ -0,0 +1,24 @@ +from ..root import DATASETS +from ..utils import MInstrDataset + + +@DATASETS.register_module() +class InstructDataset(MInstrDataset): + def __init__(self, *args, add_coco_prefix=False, **kwargs): + super().__init__(*args, **kwargs, placeholders=(), template_string='', template_file=None) + self.add_coco_prefix = add_coco_prefix + + def __getitem__(self, index): + item = self.get_raw_item(index) + if self.add_coco_prefix: + img_path = f"COCO_train2014_{item['image']}" + else: + img_path = item['image'] + conversations = item['conversations'] + + image = self.get_image(img_path) + ret = { + 'image': image, + 'conversations': conversations, + } + return ret diff --git a/mllm/dataset/single_image_dataset/point_qa.py b/mllm/dataset/single_image_dataset/point_qa.py new file mode 100644 index 0000000000000000000000000000000000000000..7150099eb060a06a27c062b28f21dd73786f35f2 --- /dev/null +++ b/mllm/dataset/single_image_dataset/point_qa.py @@ -0,0 +1,247 @@ +import re + +from .. import BaseComputeMetrics +from ..root import ( + DATASETS, + METRICS, + QUESTION_PLACEHOLDER, + IMAGE_PLACEHOLDER, + BOXES_PLACEHOLDER, + POINTS_PLACEHOLDER, +) +from ..utils import MInstrDataset + + +# noinspection PyPep8Naming +@DATASETS.register_module() +class Point_QA_local(MInstrDataset): + def __init__(self, *args, version='p', qbp_p_prob=0.5, **kwargs): + super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER)) + assert version in ['b', 'p', 'bp'] + self.version = version + self.qbp_p_prob = qbp_p_prob + + def __getitem__(self, index): + item = self.get_raw_item(index) + # image + img_path = item['file_path'] + image = self.get_image(img_path) + # answer + answer = item['answer'] + # question + question = item['question'] + bbox = item['bbox'] + point = item['point'] + + version = self.version + if version == 'bp': + version = 'p' if self.rng.random() < self.qbp_p_prob else 'b' + if version == 'b': + question = question + BOXES_PLACEHOLDER + query_boxes_seq = [[0]] + query_points_seq = None + elif version == 'p': + question = question + POINTS_PLACEHOLDER + query_boxes_seq = None + query_points_seq = [[0]] + else: + assert False + final_question = self.get_template().replace(QUESTION_PLACEHOLDER, question) + + ret = { + 'image': image, + 'target': { + 'boxes': [bbox], + 'points': [point], + }, + 'conversations': [ + { + 'from': 'human', + 'value': final_question, + 'boxes_seq': query_boxes_seq, + 'points_seq': query_points_seq, + }, + { + 'from': 'gpt', + 'value': f'The answer is {answer} .', + } + ] + } + return ret + + +# noinspection PyPep8Naming +@DATASETS.register_module() +class Point_QA_twice(MInstrDataset): + def __init__(self, *args, version='gq-p', bp_p_prob=0.5, **kwargs): + super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER)) + self.version = version + self.bp_p_prob = bp_p_prob + qtype, rtype = version.split('-') + assert qtype in ['oq', 'sq', 'gq'] + assert rtype in ['b', 'p', 'bp'] + self.qtype = qtype + self.rtype = rtype + + def __getitem__(self, index): + item = self.get_raw_item(index) + # image + img_path = item['file_path'] + image = self.get_image(img_path) + # answer + answer = item['answer'] + # question + bbox = item['bbox'] + point = item['point'] + if self.qtype == 'oq': + question = item['obj_question'] + elif self.qtype == 'sq': + question = item['super_question'] + elif self.qtype == 'gq': + question = item['general_question'] + else: + assert False + rtype = self.rtype + if rtype == 'bp': + rtype = 'p' if self.rng.random() < self.bp_p_prob else 'b' + if rtype == 'p': + question = question + POINTS_PLACEHOLDER + query_boxes_seq = None + query_points_seq = [[0]] + elif rtype == 'b': + question = question + BOXES_PLACEHOLDER + query_boxes_seq = [[0]] + query_points_seq = None + else: + assert False + final_question = self.get_template().replace(QUESTION_PLACEHOLDER, question) + + ret = { + 'image': image, + 'target': { + 'boxes': [bbox], + 'points': [point], + }, + 'conversations': [ + { + 'from': 'human', + 'value': final_question, + 'boxes_seq': query_boxes_seq, + 'points_seq': query_points_seq, + }, + { + 'from': 'gpt', + 'value': f'The answer is {answer} .', + } + ] + } + return ret + + +# noinspection PyPep8Naming +@DATASETS.register_module() +class V7W_POINT(MInstrDataset): + def __init__(self, *args, version, do_shuffle_choice=True, **kwargs): + super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER)) + self.version = version + self.do_shuffle_choice = do_shuffle_choice + assert version in ['p', 'b'] + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + item = self.get_raw_item(index) + # image + img_path = item['file_path'] + image = self.get_image(img_path) + # question + bboxes = item['candidates'] + points = [] + final_question = item['question'] + ' Candidates: ' + " ".join([BOXES_PLACEHOLDER for _ in range(len(bboxes))]) + query_boxes_seq = [] + for _ in range(len(bboxes)): + query_boxes_seq.append([_]) + # answer + if self.version == 'p': + final_question += f" answer in point format." + points.append(item['point']) + final_answer = f"The answer is {POINTS_PLACEHOLDER} ." + answer_boxes_seq = None + answer_points_seq = [[0]] + elif self.version == 'b': + final_question += f" answer in box format." + idx = bboxes.index(item['answer']) + final_answer = f"The answer is {BOXES_PLACEHOLDER} ." + answer_boxes_seq = [[idx]] + answer_points_seq = None + else: + assert False + final_question = self.get_template().replace(QUESTION_PLACEHOLDER, final_question) + if self.do_shuffle_choice: + self.rng.shuffle(query_boxes_seq) + # bboxes, query_boxes_seq, answer_boxes_seq = self.shuffle_boxes(bboxes, query_boxes_seq, answer_boxes_seq) + + ret = { + 'image': image, + 'target': { + 'boxes': bboxes, + 'points': points, + }, + 'conversations': [ + { + 'from': 'human', + 'value': final_question, + 'boxes_seq': query_boxes_seq, + }, + { + 'from': 'gpt', + 'value': final_answer, + 'boxes_seq': answer_boxes_seq, + 'points_seq': answer_points_seq, + + } + ] + } + return ret + + # def shuffle_boxes(self, bboxes, query_boxes_seq, answer_boxes_seq): + # idx_mapping = list(range(len(bboxes))) + # self.rng.shuffle(idx_mapping) + # + # new_bboxes = [None for _ in range(len(bboxes))] + # for idx_old, idx_new in enumerate(idx_mapping): + # new_bboxes[idx_new] = bboxes[idx_old] + # + # if query_boxes_seq is None: + # new_query_boxes_seq = None + # else: + # new_query_boxes_seq = [] + # for boxes in query_boxes_seq: + # new_boxes = [idx_mapping[box_idx] for box_idx in boxes] + # new_query_boxes_seq.append(new_boxes) + # + # if answer_boxes_seq is None: + # new_answer_boxes_seq = None + # else: + # new_answer_boxes_seq = [] + # for boxes in answer_boxes_seq: + # new_boxes = [idx_mapping[box_idx] for box_idx in boxes] + # new_answer_boxes_seq.append(new_boxes) + # + # return new_bboxes, new_query_boxes_seq, new_answer_boxes_seq + + +ANS_EXTRACT_PAT = re.compile(r'(?:The answer is (.+?)\.)') + + +@METRICS.register_module() +class PointQAComputeMetrics(BaseComputeMetrics): + def extract_ans(self, string: str): + try: + found = ANS_EXTRACT_PAT.findall(string.strip()) + if len(found) != 1: + return None + return found[0].strip() + except (IndexError, AttributeError): + return None diff --git a/mllm/dataset/single_image_dataset/pope.py b/mllm/dataset/single_image_dataset/pope.py new file mode 100644 index 0000000000000000000000000000000000000000..648fcfd18c6d9455b5edc676f321a0101e84d328 --- /dev/null +++ b/mllm/dataset/single_image_dataset/pope.py @@ -0,0 +1,36 @@ +from ..root import ( + DATASETS, + QUESTION_PLACEHOLDER, + IMAGE_PLACEHOLDER, +) +from ..utils import MInstrDataset + + +@DATASETS.register_module() +class POPEVQADataset(MInstrDataset): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER)) + + def __getitem__(self, index): + item = self.get_raw_item(index) + image = self.get_image(image_path=item['image']) + + question = item['text'] + final_question = self.get_template().replace(QUESTION_PLACEHOLDER, question) + + label = str(item['label']).lower() + + ret = { + 'image': image, + 'conversations': [ + { + 'from': 'human', + 'value': final_question, + }, + { + 'from': 'gpt', + 'value': f"The answer is {label} .", + }, + ] + } + return ret diff --git a/mllm/dataset/single_image_dataset/rec.py b/mllm/dataset/single_image_dataset/rec.py new file mode 100644 index 0000000000000000000000000000000000000000..4745842ba4121989c8bca85724870222da18de8f --- /dev/null +++ b/mllm/dataset/single_image_dataset/rec.py @@ -0,0 +1,128 @@ +import sys +import logging +import warnings +from typing import Dict, Any, Sequence + +import torch +from torchvision.ops import box_iou + +from ..utils import ( + MInstrDataset, + BaseComputeMetrics, +) + +from ..process_function import ( + BoxFormatter, +) + +from ..root import ( + DATASETS, + METRICS, + IMAGE_PLACEHOLDER, + BOXES_PLACEHOLDER, + EXPR_PLACEHOLDER, +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout), ], +) + + +@DATASETS.register_module() +class RECDataset(MInstrDataset): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, EXPR_PLACEHOLDER)) + + def __getitem__(self, index): + item = self.get_raw_item(index) + img_path = item['img_path'] + expr = item['expression'] + bbox = item['bbox'] + + image = self.get_image(img_path) + question = self.get_template().replace(EXPR_PLACEHOLDER, expr) + + ret = { + 'image': image, + 'target': { + 'boxes': [bbox], + }, + 'conversations': [ + { + 'from': 'human', + 'value': question, + }, + { + 'from': 'gpt', + 'value': f'Answer: {BOXES_PLACEHOLDER} .', + 'boxes_seq': [[0]], + } + ] + } + return ret + + +@METRICS.register_module() +class RECComputeMetrics(BaseComputeMetrics): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.box_formatter: BoxFormatter = self.preprocessor['target']['boxes'] + + def calculate_metric(self, preds: Sequence[str], targets: Sequence[str]) -> Dict[str, Any]: + failed = 0 + target_failed = 0 + + pred_boxes, target_boxes = [], [] + for pred, target in zip(preds, targets): + extract_pred = self.extract_ans(pred) + extract_target = self.extract_ans(target) + if extract_target is None: + target_failed += 1 + logger.warning(f"failed to extract ans for target: {target}") + continue + if extract_pred is None: + failed += 1 + logger.warning(f"failed to extract ans for pred: {pred}") + extract_pred = [0, 0, 0, 0] + target_boxes.append(extract_target) + pred_boxes.append(extract_pred) + + with torch.no_grad(): + target_boxes = torch.tensor(target_boxes) + pred_boxes = torch.tensor(pred_boxes) + # normalized box value is too small, so that the area is 0. + ious = box_iou(pred_boxes * 1000, target_boxes * 1000) + ious = torch.einsum('i i -> i', ious) # take diag elem + # NOTE: please note iou only calculate for success target + iou = ious.mean().item() + correct = (ious > 0.5).sum().item() + + # HACK: currently we expand image to square. so this iou is the real iou. + warn_message = "this iou is calculate on normalized box. just for non-rigorous training progress checking." \ + "the value is consistent with real iou only if image.width == image.height." + warnings.warn(warn_message) + + return { + 'accuracy': 1.0 * correct / len(targets), + 'target_failed': target_failed, + 'failed': failed, + 'iou': iou, + 'warning': warn_message, + } + + def extract_ans(self, string: str): + try: + list_of_boxes = self.box_formatter.extract(string) + if len(list_of_boxes) != 1 or len(list_of_boxes[0]) != 1: + return None + box = list_of_boxes[0][0] + if len(box) != 4: + return None + return box + except Exception as e: + logger.warning(f"extract_ans for {string} but get exception: {e}") + return None diff --git a/mllm/dataset/single_image_dataset/reg.py b/mllm/dataset/single_image_dataset/reg.py new file mode 100644 index 0000000000000000000000000000000000000000..d5886f8795d769729f6a4f2a64f8cd7c93f21b1e --- /dev/null +++ b/mllm/dataset/single_image_dataset/reg.py @@ -0,0 +1,50 @@ +from ..utils import ( + MInstrDataset, +) + +from ..root import ( + DATASETS, + IMAGE_PLACEHOLDER, + BOXES_PLACEHOLDER, + OBJS_PLACEHOLDER, +) + + +@DATASETS.register_module() +class REGDataset(MInstrDataset): + def __init__(self, *args, **kwargs): + super().__init__(*args, placeholders=(IMAGE_PLACEHOLDER, OBJS_PLACEHOLDER), **kwargs) + + def __getitem__(self, index): + item = self.get_raw_item(index) + img_path = item['img_path'] + expr = item['expression'] + bbox = item['bbox'] + + image = self.get_image(img_path) + question = self.get_template().replace(OBJS_PLACEHOLDER, BOXES_PLACEHOLDER) + caption = expr + + ret = { + 'image': image, + 'target': { + 'boxes': [bbox], + }, + 'conversations': [ + { + 'from': 'human', + 'value': question, + 'boxes_seq': [[0]], + }, + { + 'from': 'gpt', + 'value': f'{caption}', + } + ] + } + return ret + + +@DATASETS.register_module() +class GCDataset(REGDataset): + pass diff --git a/mllm/dataset/single_image_dataset/vcr.py b/mllm/dataset/single_image_dataset/vcr.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a619edfd7e8ac4971c727ed1c0a3950e58d444 --- /dev/null +++ b/mllm/dataset/single_image_dataset/vcr.py @@ -0,0 +1,190 @@ +from ..root import ( + DATASETS, + QUESTION_PLACEHOLDER, + IMAGE_PLACEHOLDER, + BOXES_PLACEHOLDER, +) +from ..utils import MInstrDataset + + +def prepare_sentence(sent): + ret_str = [] + ret_box_seq = [] + for word in sent: + if isinstance(word, list): + ret_str.append(BOXES_PLACEHOLDER) + ret_box_seq.append(word) + else: + ret_str.append(word) + return " ".join(ret_str), ret_box_seq + + +def prepare_choice(pack_choices, label_index, *, options='ABCDEFG'): + ret_str = [] + ret_box_seq = [] + for pack, op in zip(pack_choices, options): + ret_str.append(f"({op}) {pack[0]}") + ret_box_seq.extend(pack[1]) + ret_pack = (" ".join(ret_str), ret_box_seq) + label_choice = f"The answer is ({options[label_index]})." + return ret_pack, (label_choice, []) + + +def merge(packs, *, prefixs, postfixs=None): + if postfixs is None: + postfixs = ['' for _ in range(len(packs))] + assert len(packs) == len(prefixs) == len(postfixs), f"{len(packs)},{len(prefixs)},{len(postfixs)}" + ret_str = [] + ret_box_seq = [] + for pack, prefix, postfix in zip(packs, prefixs, postfixs): + if prefix: + ret_str.append(prefix) + ret_str.append(pack[0]) + if postfix: + ret_str.append(postfix) + ret_box_seq.extend(pack[1]) + return " ".join(ret_str), ret_box_seq + + +@DATASETS.register_module() +class VCRDataset(MInstrDataset): + def __init__(self, *args, version, **kwargs): + super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER)) + self.version = version + assert version in [ + 'q-a', 'q-ra', + 'qc-a', 'qc-ra', 'qc-rac', # for evaluation: A + 'qa-r', 'q-a-q-r', + 'qac-r', 'qc-a-qc-r', # for evaluation: R + ] + # for evaluation: + # A: 'qc-a' 'qc-ra' 'qc-rac' + # R: 'qac-r' 'qc-a-qc-r' + + def __getitem__(self, index, force_answer_label=None, force_rationale_label=None): + item = self.get_raw_item(index) + image = self.get_image(item['img_fn']) + + boxes_with_prob = item['boxes'] + boxes = [box[:4] for box in boxes_with_prob] + + question = item['question'] + answer_choices = item['answer_choices'] + rationale_choices = item['rationale_choices'] + if force_answer_label is not None: + answer_label = force_answer_label + else: + answer_label = item['answer_label'] + if force_rationale_label is not None: + rationale_label = force_rationale_label + else: + rationale_label = item['rationale_label'] + + question_pack = prepare_sentence(question) + answer_pack_choices = [prepare_sentence(_) for _ in answer_choices] + rationale_pack_choices = [prepare_sentence(_) for _ in rationale_choices] + + answer_choices_pack, answer_choice = prepare_choice(answer_pack_choices, answer_label) + rationale_choices_pack, rationale_choice = prepare_choice(rationale_pack_choices, rationale_label) + answer_gold_pack = answer_pack_choices[answer_label] + rationale_gold_pack = rationale_pack_choices[rationale_label] + + version = self.version + if version == 'q-a': + final_packs = [ + merge([question_pack], prefixs=['QUESTION:'], ), + answer_gold_pack, + ] + elif version == 'q-ra': + final_packs = [ + merge([question_pack], prefixs=['QUESTION:'], ), + merge([rationale_gold_pack, answer_gold_pack], prefixs=['', '']), + ] + elif version == 'qc-a': + final_packs = [ + merge([question_pack, answer_choices_pack], prefixs=['QUESTION:', '\nOPTIONS:'], postfixs=['', 'You should decide on the best choice and output the corresponding letter.']), + answer_choice, + ] + elif version == 'qc-ra': + final_packs = [ + merge([question_pack, answer_choices_pack], prefixs=['QUESTION:', '\nOPTIONS:'], postfixs=['', 'You should decide on the best choice and output the corresponding letter.']), + merge([rationale_gold_pack, answer_choice], prefixs=['', '']), + ] + elif version == 'qc-rac': + final_packs = [ + merge([question_pack, answer_choices_pack], prefixs=['QUESTION:', '\nOPTIONS:'], postfixs=['', 'You should decide on the best choice and output the corresponding letter.']), + merge([rationale_gold_pack, answer_gold_pack, answer_choice], prefixs=['', '', '']), + ] + elif version == 'qa-r': + final_packs = [ + merge([question_pack, answer_gold_pack], prefixs=['QUESTION:', '\nANSWER:'], postfixs=['', 'You should explain the reason for the above answer.']), + rationale_gold_pack, + ] + elif version == 'qac-r': + final_packs = [ + merge([question_pack, answer_gold_pack, rationale_choices_pack], prefixs=['QUESTION:', '\nANSWER:', '\nRATIONALE OPTIONS:'], postfixs=['', '', 'You should decide on the best choice that explains the above answer and output the corresponding letter.']), + rationale_choice, + ] + elif version == 'q-a-q-r': + final_packs = [ + merge([question_pack], prefixs=['QUESTION:'], ), + answer_gold_pack, + ('You should explain the reason for the above answer.', ()), + rationale_gold_pack, + ] + elif version == 'qc-a-qc-r': + final_packs = [ + merge([question_pack, answer_choices_pack], prefixs=['QUESTION:', '\nOPTIONS:'], postfixs=['', 'You should decide on the best choice and output the corresponding letter.']), + answer_choice, + merge([rationale_choices_pack], prefixs=['RATIONALE OPTIONS:'], postfixs=['You should decide on the best choice that explains the above answer and output the corresponding letter.']), + rationale_choice, + ] + else: + assert False + + conversations = [] + roles = ['human', 'gpt'] + for idx, pack in enumerate(final_packs): + conversations.append({ + 'from': roles[idx % 2], + 'value': pack[0], + 'boxes_seq': pack[1], + }) + conversations[0]['value'] = self.get_template().replace(QUESTION_PLACEHOLDER, conversations[0]['value']) + + ret = { + 'image': image, + 'target': {'boxes': boxes}, + 'conversations': conversations, + } + return ret + + +@DATASETS.register_module() +class VCRPredDataset(VCRDataset): + def __init__(self, *args, version, **kwargs): + super().__init__(*args, version=version, **kwargs) + assert version in [ + 'qc-a', 'qc-ra', 'qc-rac', # for evaluation: A + 'qac-r', 'qc-a-qc-r', # for evaluation: R + ] + self.is_pred_for_r = version in [ + 'qac-r', 'qc-a-qc-r', # for evaluation: R + ] + + def __len__(self): + if self.is_pred_for_r: + return super().__len__() * 4 + else: + return super().__len__() + + # noinspection PyMethodOverriding + def __getitem__(self, index): + if self.is_pred_for_r: + item_index = index // 4 + answer_index = index % 4 + ret = super().__getitem__(item_index, force_answer_label=answer_index, force_rationale_label=0) + else: + ret = super().__getitem__(index, force_answer_label=0, force_rationale_label=0) + ret['conversations'][-1]['value'] += "WARNING: answer and rationale here are just placeholders. we have no real anno." + return ret diff --git a/mllm/dataset/single_image_dataset/vqaex.py b/mllm/dataset/single_image_dataset/vqaex.py new file mode 100644 index 0000000000000000000000000000000000000000..1e74d78583dc8e8e0ce809314af82d4faaf0691a --- /dev/null +++ b/mllm/dataset/single_image_dataset/vqaex.py @@ -0,0 +1,48 @@ +from ..root import ( + DATASETS, + QUESTION_PLACEHOLDER, + IMAGE_PLACEHOLDER, +) +from ..utils import MInstrDataset + + +@DATASETS.register_module() +class VQAEXDataset(MInstrDataset): + def __init__(self, *args, is_e_dataset: bool, has_annotation=True, **kwargs): + super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER)) + self.has_annotation = has_annotation + self.is_e_dataset = is_e_dataset + + def __getitem__(self, index): + item = self.get_raw_item(index) + image = self.get_image(image_path=item['file_path']) + + question = item['question'] + final_question = self.get_template().replace(QUESTION_PLACEHOLDER, question) + + if self.has_annotation: + if self.is_e_dataset: + final_answer = "" + final_answer += item['explanation'][0] + final_answer += f" So the answer is {item['multiple_answers']}." + else: + final_answer = "" + final_answer += "".join(item['justification']) + final_answer += f" So the answer is {item['multiple_choice_answer']}." + else: + final_answer = 'UNKNOWN' + + ret = { + 'image': image, + 'conversations': [ + { + 'from': 'human', + 'value': final_question, + }, + { + 'from': 'gpt', + 'value': final_answer, + }, + ] + } + return ret diff --git a/mllm/dataset/single_image_dataset/vqav2.py b/mllm/dataset/single_image_dataset/vqav2.py new file mode 100644 index 0000000000000000000000000000000000000000..6aee67f5e21e2e972e62eafd652cff2401234ace --- /dev/null +++ b/mllm/dataset/single_image_dataset/vqav2.py @@ -0,0 +1,40 @@ +from ..root import ( + DATASETS, + QUESTION_PLACEHOLDER, + IMAGE_PLACEHOLDER, +) +from ..utils import MInstrDataset + + +@DATASETS.register_module() +class VQAv2Dataset(MInstrDataset): + def __init__(self, *args, has_annotation=True, **kwargs): + super().__init__(*args, **kwargs, placeholders=(IMAGE_PLACEHOLDER, QUESTION_PLACEHOLDER)) + self.has_annotation = has_annotation + + def __getitem__(self, index): + item = self.get_raw_item(index) + image = self.get_image(image_path=item['image_path']) + + question = item['question'] + final_question = self.get_template().replace(QUESTION_PLACEHOLDER, question) + + if self.has_annotation: + final_answer = item['annotation']['multiple_choice_answer'] + else: + final_answer = 'UNKNOWN' + + ret = { + 'image': image, + 'conversations': [ + { + 'from': 'human', + 'value': final_question, + }, + { + 'from': 'gpt', + 'value': f"The answer is {final_answer}.", + }, + ] + } + return ret diff --git a/mllm/dataset/single_image_interactive.py b/mllm/dataset/single_image_interactive.py new file mode 100644 index 0000000000000000000000000000000000000000..22256debe5748a364fa3a489fbc482192b301247 --- /dev/null +++ b/mllm/dataset/single_image_interactive.py @@ -0,0 +1,122 @@ +import copy +from typing import Optional + +from PIL import Image + +from .single_image_convsation import SingleImageConvDatasetMixin + + +class SingleImageInteractive(SingleImageConvDatasetMixin): + _printed_sample = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.image: Optional[Image.Image] = None + self.roles = ('human', 'gpt') + self.boxes = [] + self.points = [] + self.raw_conv = [] + self.conversations = [] + + def set_image(self, image: Image.Image): + assert self.image is None, f"{image}" + self.image = image + + def append_message(self, role: str, message: str, *, boxes=None, points=None, boxes_seq=None, points_seq=None): + """Append a new message.""" + assert role in self.roles + + def convert_idx(objs_seq, objs_value, get_obj_idx_func): + if objs_seq is None: + return None + ret = [] + for objs_idx in objs_seq: + new_objs_idx = [] + for idx in objs_idx: + new_idx = get_obj_idx_func(objs_value[idx]) + new_objs_idx.append(new_idx) + ret.append(tuple(new_objs_idx)) + return tuple(ret) + + boxes_seq = convert_idx(boxes_seq, boxes, self._get_box_idx) + points_seq = convert_idx(points_seq, points, self._get_point_idx) + + if self.image is not None: + previous_message_has_image_placeholder = any( + '' in item['value'] for item in self.conversations + ) + if not previous_message_has_image_placeholder and '' not in message: + message = ' ' + message + if previous_message_has_image_placeholder and '' in message: + message = message.replace('', '') + + self.conversations.append( + { + 'from': role, + 'value': message, + 'boxes_seq': copy.deepcopy(boxes_seq), + 'points_seq': copy.deepcopy(points_seq), + } + ) + + def get_raw_item(self, index=None): + ret = copy.deepcopy({ + 'image': self.image, + 'target': { + 'boxes': self.boxes, + 'points': self.points, + }, + 'conversations': self.conversations, + }) + assert ret['conversations'][0]['from'] == self.roles[0] + if ret['conversations'][-1]['from'] == self.roles[0]: + ret['conversations'].append( + { + 'from': self.roles[1], + 'value': '', + } + ) + return ret + + def to_model_input(self): + item = self.__getitem__(0) + ret = {'input_ids': item['input_ids'].unsqueeze(0).cuda()} + if 'image' in item and item['image'] is not None: + ret['images'] = item['image'].unsqueeze(0).cuda() + else: + ret['images'] = None + return ret + + def to_gradio_chatbot_new_messages(self): + conv = self.__getitem__(0, return_conv=True) + new_messages = conv.messages[-2:] + ret_messages = [] + for r, m in new_messages: + nm = m.replace('', '').replace('', '').replace('', '') + ret_messages.append((r, nm)) + return ret_messages + + def _get_box_idx(self, box): + assert isinstance(box, (tuple, list)), f"{type(box)}" + assert isinstance(box[0], (int, float)), f"{type(box[0])}" + assert len(box) == 4 + box = tuple(box) + if box not in self.boxes: + self.boxes.append(box) + return len(self.boxes) - 1 + else: + return self.boxes.index(box) + + def _get_point_idx(self, point): + assert isinstance(point, (tuple, list)) + assert isinstance(point[0], (int, float)) + assert len(point) == 2 + point = tuple(point) + if point not in self.points: + self.points.append(tuple(point)) + return len(self.points) - 1 + else: + return self.points.index(point) + + def __len__(self): + return 1 diff --git a/mllm/dataset/utils/__init__.py b/mllm/dataset/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..29286a40b81b9faa27893f7245a3ba3e83d1000a --- /dev/null +++ b/mllm/dataset/utils/__init__.py @@ -0,0 +1,5 @@ +from .io import read_img_general, init_ceph_client_if_needed +from .transform import Expand2square, de_norm_box_xyxy, norm_box_xyxy, expand2square, box_xywh_to_xyxy +from .compute_metrics import BaseComputeMetrics +from .mixin import QuestionTemplateMixin, MInstrDataset +from .concatenate_dataset import ConcatDataset, InterleaveDateset, SubSet, ConcatDatasetWithShuffle diff --git a/mllm/dataset/utils/__pycache__/__init__.cpython-310.pyc b/mllm/dataset/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fecb1aeb33fedf927bef987799ea8969ba30bcf7 Binary files /dev/null and b/mllm/dataset/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/mllm/dataset/utils/__pycache__/compute_metrics.cpython-310.pyc b/mllm/dataset/utils/__pycache__/compute_metrics.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60ae90a197c09c479d236b16b681e77c4fd6c4f1 Binary files /dev/null and b/mllm/dataset/utils/__pycache__/compute_metrics.cpython-310.pyc differ diff --git a/mllm/dataset/utils/__pycache__/concatenate_dataset.cpython-310.pyc b/mllm/dataset/utils/__pycache__/concatenate_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ee636a20467c479f563bd00a231677c90e82248 Binary files /dev/null and b/mllm/dataset/utils/__pycache__/concatenate_dataset.cpython-310.pyc differ diff --git a/mllm/dataset/utils/__pycache__/flickr30k_entities_utils.cpython-310.pyc b/mllm/dataset/utils/__pycache__/flickr30k_entities_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c531294a53db1bbdb495eec109f4cb1d51b1d82a Binary files /dev/null and b/mllm/dataset/utils/__pycache__/flickr30k_entities_utils.cpython-310.pyc differ diff --git a/mllm/dataset/utils/__pycache__/io.cpython-310.pyc b/mllm/dataset/utils/__pycache__/io.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d63ba3ed1109c9547a6188548ba5584ef97789d8 Binary files /dev/null and b/mllm/dataset/utils/__pycache__/io.cpython-310.pyc differ diff --git a/mllm/dataset/utils/__pycache__/mixin.cpython-310.pyc b/mllm/dataset/utils/__pycache__/mixin.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64421127f996816dd450325f4081e91e1b8921bd Binary files /dev/null and b/mllm/dataset/utils/__pycache__/mixin.cpython-310.pyc differ diff --git a/mllm/dataset/utils/__pycache__/transform.cpython-310.pyc b/mllm/dataset/utils/__pycache__/transform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2a70684c6c4a1dddc9c9df4e799f6bed24c91a5 Binary files /dev/null and b/mllm/dataset/utils/__pycache__/transform.cpython-310.pyc differ diff --git a/mllm/dataset/utils/compute_metrics.py b/mllm/dataset/utils/compute_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..dc67a0d9efac0a24fbb11a7f97117c1d495b58b5 --- /dev/null +++ b/mllm/dataset/utils/compute_metrics.py @@ -0,0 +1,53 @@ +import sys +import logging +from typing import Dict, Any, Sequence + +from transformers import EvalPrediction + +from ...utils import decode_generate_ids + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout), ], +) + + +class BaseComputeMetrics: + def __init__(self, preprocessor: Dict[str, Any]): + self.preprocessor = preprocessor + self.tokenizer = self.preprocessor['text'] + + def __call__(self, eval_preds: EvalPrediction) -> Dict[str, Any]: + preds, targets = eval_preds + logger.warning(f"preds shape: {preds.shape}. targets shape: {targets.shape}") + preds = decode_generate_ids(self.tokenizer, preds) + targets = decode_generate_ids(self.tokenizer, targets) + assert len(preds) == len(targets) + return self.calculate_metric(preds, targets) + + def calculate_metric(self, preds: Sequence[str], targets: Sequence[str]) -> Dict[str, Any]: + correct = 0 + failed = 0 + target_failed = 0 + for pred, target in zip(preds, targets): + extract_pred = self.extract_ans(pred) + extract_target = self.extract_ans(target) + if extract_target is None: + target_failed += 1 + logger.warning(f"failed to extract ans from target. maybe the response string is truncated: {target}.") + continue + if extract_pred is None: + failed += 1 + if extract_pred == extract_target: + correct += 1 + return { + 'accuracy': 1.0 * correct / len(targets), + 'target_failed': target_failed, + 'failed': failed, + } + + def extract_ans(self, string: str): + raise NotImplementedError diff --git a/mllm/dataset/utils/concatenate_dataset.py b/mllm/dataset/utils/concatenate_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..739e3373bc05e687195ea4a74c75853fc644d4b3 --- /dev/null +++ b/mllm/dataset/utils/concatenate_dataset.py @@ -0,0 +1,192 @@ +from typing import List, Optional, Literal + +import numpy as np +from torch.utils.data import Dataset +from torch.utils.data import ConcatDataset as TorchConcatDataset +from torch.utils.data import Subset as TorchSubset + +from ..root import DATASETS + + +@DATASETS.register_module() +class ConcatDataset(Dataset): + _repr_indent = 4 + + def __init__(self, cfgs): + self.cfgs = cfgs + datasets = [DATASETS.build(cfg) for cfg in cfgs] + self.concat_dataset = TorchConcatDataset(datasets) + + def __len__(self): + return len(self.concat_dataset) + + def __getitem__(self, index): + return self.concat_dataset[index] + + def __repr__(self) -> str: + head = "Dataset " + self.__class__.__name__ + body = [ + f"Number of datapoints: {self.__len__()}", + ] + for i, ds in enumerate(self.concat_dataset.datasets): + body.append(f"Subset {i + 1}/{len(self.concat_dataset.datasets)}") + body += ds.__repr__().splitlines() + lines = [head] + [" " * self._repr_indent + line for line in body] + return "\n".join(lines) + + +@DATASETS.register_module() +class InterleaveDateset(Dataset): + _repr_indent = 4 + + def __init__( + self, + cfgs, + probabilities: Optional[List[float]] = None, + seed: Optional[int] = 42, + stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted", + ): + self.cfgs = cfgs + self.probabilities = probabilities + self.seed = seed + self.stopping_strategy = stopping_strategy + + datasets = [DATASETS.build(cfg) for cfg in cfgs] + self.concat_dataset = TorchConcatDataset(datasets) + + self.index_mapping = _interleave_dataset_index( + lengths=[len(ds) for ds in datasets], + probabilities=probabilities, + seed=seed, + stopping_strategy=stopping_strategy, + ) + + def __len__(self): + return len(self.index_mapping) + + def __getitem__(self, index): + return self.concat_dataset[self.index_mapping[index]] + + def __repr__(self) -> str: + head = "Dataset " + self.__class__.__name__ + body = [ + f"Number of datapoints: {self.__len__()}", + f"Probabilities: {self.probabilities}", + f"stopping_strategy: {self.stopping_strategy}", + f"seed: {self.seed}", + ] + for i, ds in enumerate(self.concat_dataset.datasets): + body.append(f"Subset {i + 1}/{len(self.concat_dataset.datasets)}") + body += ds.__repr__().splitlines() + lines = [head] + [" " * self._repr_indent + line for line in body] + return "\n".join(lines) + + +# stolen from huggingface/datasets +# https://github.com/huggingface/datasets/blob/074925b9b7c1dfd33b8675aa99c07cc26375665c/src/datasets/arrow_dataset.py#L5987 +def _interleave_dataset_index( + *, + lengths: List[int], + probabilities: Optional[List[float]] = None, + seed: Optional[int] = None, + stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted", +): + if probabilities is not None and 0 in probabilities: + assert stopping_strategy == 'first_exhausted', "you will meet a Infinite loop" + # Let's now build the indices to pass to .select() + offsets = np.cumsum([0] + lengths[:-1]) + + # if stopping_strategy is "first_exhausted", it is an undersampling situation whereas it is an oversampling situation if it is "all_exhausted" + oversampling = stopping_strategy == "all_exhausted" + + if probabilities is None and not oversampling: + # Undersampling situation with cycling between each sources + # Example:: If lengths of the datasets are [3, 4, 5] + # Then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 6, 9] + # Note that we only have 3 examples per dataset since the first dataset ran out of examples + + # Reasoning behind the following operation: keeping the min_length first indices of each dataset + # while offsetting in order to correspond to the right indices of the concatenated dataset + # and flattening to effectively interleave the datasets + indices = (offsets.reshape(1, -1) + np.arange(min(lengths)).reshape(-1, 1)).flatten().tolist() + elif probabilities is None: + # Oversampling situation with cycling between each sources + # Then the resulting indices should be [0, 3, 7, 1, 4, 8, 2, 5, 9, 0, 6, 10, 1, 3, 11] + # Note that we have 5 examples per dataset with a rolling window since the longest dataset has 5 samples + + # Reasoning behind the following operation: for each dataset indices (i.e column) repeat the indices to have max_length indices per dataset + # For example, if the max_length is 5 and the i-th dataset has 3 samples, the i-th column will be [0,1,2,0,1] + indices = np.mod(np.arange(max(lengths)).reshape(-1, 1), np.array(lengths).reshape(1, -1)) + + # We have to keep the indices to their respective dataset offsets and to flatten to effectively interleave the datasets + indices = (indices + offsets).flatten().tolist() + + else: + # boolean array indicating if at index i if the dataset_i has been fully exhausted + is_exhausted = np.full(len(lengths), False) + + # if undersampling ("first_exhausted"), we stop as soon as one dataset is exhausted + # if oversampling ("all_exhausted"), we stop as soons as every dataset is exhausted, i.e as soon as every samples of every dataset has been visited at least once + bool_strategy_func = np.all if oversampling else np.any + + def iter_random_indices(): + """Get an infinite iterator that randomly samples the index of the source to pick examples from.""" + rng = np.random.default_rng(seed) + while True: + yield from (int(i) for i in rng.choice(len(lengths), size=1000, p=probabilities)) + + current_index = [0] * len(lengths) + indices = [] + for source_idx in iter_random_indices(): + # If no oversampling, we stop as soon as a dataset has ran out of examples (np.any) + # Otherwise, we stop as soon as every dataset has ran out of examples (np.all) + if bool_strategy_func(is_exhausted): + # the stopping condition was reached, let's stop + break + + # let's add the example at the current index of the `source_idx`-th dataset + indices.append(current_index[source_idx] + offsets[source_idx]) + current_index[source_idx] += 1 + + # we've ran out of examples for the current dataset, let's update our boolean array and bring the current_index back to 0 + if current_index[source_idx] >= lengths[source_idx]: + is_exhausted[source_idx] = True + current_index[source_idx] = 0 + return indices + + +@DATASETS.register_module() +class SubSet(TorchSubset): + def __init__(self, cfg, portion, do_shuffle=True, seed=42): + assert 0 < portion <= 1 + dataset = DATASETS.build(cfg=cfg) + target_len = int(len(dataset) * portion) + if do_shuffle: + rng = np.random.default_rng(seed) + indices = list(range(len(dataset))) + rng.shuffle(indices) + indices = indices[:target_len] + else: + indices = list(range(target_len)) + super().__init__(dataset, indices) + + +@DATASETS.register_module() +class ConcatDatasetWithShuffle(TorchSubset): + _repr_indent = 4 + + def __init__(self, cfgs, seed=42, portion=1): + self.cfgs = cfgs + self.seed = seed + self.portion = portion + dataset = TorchConcatDataset([DATASETS.build(cfg) for cfg in cfgs]) + + target_len = int(len(dataset) * portion) + indices = list(range(len(dataset))) * int(np.ceil(portion)) + rng = np.random.default_rng(seed) + rng.shuffle(indices) + indices = indices[:target_len] + super().__init__(dataset, indices) + + +__all__ = ['ConcatDataset', 'InterleaveDateset', 'SubSet', 'ConcatDatasetWithShuffle'] diff --git a/mllm/dataset/utils/flickr30k_entities_utils.py b/mllm/dataset/utils/flickr30k_entities_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0d49979fdf5688770bab532486f3c96f445ffb42 --- /dev/null +++ b/mllm/dataset/utils/flickr30k_entities_utils.py @@ -0,0 +1,205 @@ +import os +import re +import xml.etree.ElementTree as ET +from typing import Dict, List + +from tqdm import tqdm + + +def get_sentence_data(fn): + """ + Parses a sentence file from the Flickr30K Entities dataset + + input: + fn - full file path to the sentence file to parse + + output: + a list of dictionaries for each sentence with the following fields: + sentence - the original sentence + phrases - a list of dictionaries for each phrase with the + following fields: + phrase - the text of the annotated phrase + first_word_index - the position of the first word of + the phrase in the sentence + phrase_id - an identifier for this phrase + phrase_type - a list of the coarse categories this + phrase belongs to + + """ + with open(fn, 'r', encoding='utf8') as f: + sentences = f.read().split('\n') + + annotations = [] + for sentence in sentences: + if not sentence: + continue + + first_word = [] + phrases = [] + phrase_id = [] + phrase_type = [] + words = [] + current_phrase = [] + add_to_phrase = False + for token in sentence.split(): + if add_to_phrase: + if token[-1] == ']': + add_to_phrase = False + token = token[:-1] + current_phrase.append(token) + phrases.append(' '.join(current_phrase)) + current_phrase = [] + else: + current_phrase.append(token) + + words.append(token) + else: + if token[0] == '[': + add_to_phrase = True + first_word.append(len(words)) + parts = token.split('/') + phrase_id.append(parts[1][3:]) + phrase_type.append(parts[2:]) + else: + words.append(token) + + sentence_data = {'sentence': ' '.join(words), 'phrases': []} + for index, phrase, p_id, p_type in zip(first_word, phrases, phrase_id, phrase_type): + sentence_data['phrases'].append({'first_word_index': index, + 'phrase': phrase, + 'phrase_id': p_id, + 'phrase_type': p_type}) + + annotations.append(sentence_data) + + return annotations + + +def get_annotations(fn): + """ + Parses the xml files in the Flickr30K Entities dataset + + input: + fn - full file path to the annotations file to parse + + output: + dictionary with the following fields: + scene - list of identifiers which were annotated as + pertaining to the whole scene + nobox - list of identifiers which were annotated as + not being visible in the image + boxes - a dictionary where the fields are identifiers + and the values are its list of boxes in the + [xmin ymin xmax ymax] format + """ + tree = ET.parse(fn) + root = tree.getroot() + size_container = root.findall('size')[0] + anno_info = {'boxes': {}, 'scene': [], 'nobox': []} + for size_element in size_container: + anno_info[size_element.tag] = int(size_element.text) + + for object_container in root.findall('object'): + for names in object_container.findall('name'): + box_id = names.text + box_container = object_container.findall('bndbox') + if len(box_container) > 0: + if box_id not in anno_info['boxes']: + anno_info['boxes'][box_id] = [] + xmin = int(box_container[0].findall('xmin')[0].text) - 1 + ymin = int(box_container[0].findall('ymin')[0].text) - 1 + xmax = int(box_container[0].findall('xmax')[0].text) - 1 + ymax = int(box_container[0].findall('ymax')[0].text) - 1 + anno_info['boxes'][box_id].append([xmin, ymin, xmax, ymax]) + else: + nobndbox = int(object_container.findall('nobndbox')[0].text) + if nobndbox > 0: + anno_info['nobox'].append(box_id) + + scene = int(object_container.findall('scene')[0].text) + if scene > 0: + anno_info['scene'].append(box_id) + + return anno_info + + +def get_ann_path(idx, *, annotation_dir=""): + return os.path.join(annotation_dir, rf'Annotations/{idx}.xml') + + +def get_sen_path(idx, *, annotation_dir=""): + return os.path.join(annotation_dir, rf"Sentences/{idx}.txt") + + +def get_img_path(idx, *, image_dir=""): + return os.path.join(image_dir, rf'{idx}.jpg') + + +PHRASE_ST_PLACEHOLDER = '' +PHRASE_ED_PLACEHOLDER = '' + + +def flatten_annotation(annotation_dir, indexes): + data = [] + + for index in tqdm(indexes): + image_id = index + ann_path = get_ann_path(index, annotation_dir=annotation_dir) + sen_path = get_sen_path(index, annotation_dir=annotation_dir) + anns = get_annotations(ann_path) + sens = get_sentence_data(sen_path) + + for sen in sens: + pids = list(set(phrase['phrase_id'] for phrase in sen['phrases'] if phrase['phrase_id'] in anns['boxes'])) + boxes_mapping: Dict[str, List[int]] = {} + boxes_filtered: List[List[int]] = [] + for pid in pids: + v = anns['boxes'][pid] + mapping = [] + for box in v: + mapping.append(len(boxes_filtered)) + boxes_filtered.append(box) + boxes_mapping[pid] = mapping + + boxes_seq: List[List[int]] = [] + for phrase in sen['phrases']: + if not phrase['phrase_id'] in anns['boxes']: + continue + pid = phrase['phrase_id'] + boxes_seq.append(boxes_mapping[pid]) + + sent = list(sen['sentence'].split()) + for phrase in sen['phrases'][::-1]: + if not phrase['phrase_id'] in anns['boxes']: + continue + span = [phrase['first_word_index'], phrase['first_word_index'] + len(phrase['phrase'].split())] + sent[span[0]:span[1]] = [f"{PHRASE_ST_PLACEHOLDER}{' '.join(sent[span[0]:span[1]])}{PHRASE_ED_PLACEHOLDER}"] + sent_converted = " ".join(sent) + + assert len(re.findall(PHRASE_ST_PLACEHOLDER, sent_converted)) \ + == len(re.findall(PHRASE_ED_PLACEHOLDER, sent_converted)) \ + == len(boxes_seq), f"error when parse: {sent_converted}, {boxes_seq}, {sen}, {anns}" + assert sent_converted.replace(PHRASE_ST_PLACEHOLDER, "").replace(PHRASE_ED_PLACEHOLDER, "") == sen['sentence'] + + item = { + 'id': len(data), + 'image_id': image_id, + 'boxes': boxes_filtered, + 'sentence': sent_converted, + 'boxes_seq': boxes_seq, + } + data.append(item) + + return data + + +if __name__ == '__main__': + filenames = [ + r'D:\home\dataset\flickr30kentities\train.txt', + r'D:\home\dataset\flickr30kentities\val.txt', + r'D:\home\dataset\flickr30kentities\test.txt', + ] + for filename in filenames: + annotation_dir = r'D:\home\dataset\flickr30kentities' + indexes = [line.strip() for line in open(filename, 'r', encoding='utf8')] + flatten_annotation(annotation_dir, indexes) diff --git a/mllm/dataset/utils/io.py b/mllm/dataset/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..b28d9a08f0833582a3f3fa9e78503884a1e06d87 --- /dev/null +++ b/mllm/dataset/utils/io.py @@ -0,0 +1,49 @@ +import sys +import time +import logging + +import cv2 +import numpy as np +from PIL import Image + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout), ], +) + + +def read_img_general(img_path): + if "s3://" in img_path: + cv_img = read_img_ceph(img_path) + # noinspection PyUnresolvedReferences + return Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)) + else: + return Image.open(img_path).convert('RGB') + + +client = None + + +def read_img_ceph(img_path): + init_ceph_client_if_needed() + img_bytes = client.get(img_path) + assert img_bytes is not None, f"Please check image at {img_path}" + img_mem_view = memoryview(img_bytes) + img_array = np.frombuffer(img_mem_view, np.uint8) + # noinspection PyUnresolvedReferences + img = cv2.imdecode(img_array, cv2.IMREAD_COLOR) + return img + + +def init_ceph_client_if_needed(): + global client + if client is None: + logger.info(f"initializing ceph client ...") + st = time.time() + from petrel_client.client import Client # noqa + client = Client(enable_mc=True) + ed = time.time() + logger.info(f"initialize client cost {ed - st:.2f} s") \ No newline at end of file diff --git a/mllm/dataset/utils/mixin.py b/mllm/dataset/utils/mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..6814ee09884aad4a775b26d011f5f7b3cc89801e --- /dev/null +++ b/mllm/dataset/utils/mixin.py @@ -0,0 +1,98 @@ +import json +import os + +import numpy as np +from torch.utils.data import Dataset + +from .io import read_img_general + + +class QuestionTemplateMixin: + def __init__( + self, + *args, + template_string=None, + template_file=None, + max_dynamic_size=None, + placeholders=None, + **kwargs + ): + super().__init__(*args, **kwargs) + self.template_string = template_string + self.template_file = template_file + self.max_dynamic_size = max_dynamic_size + self.placeholders = placeholders + if template_string is None and template_file is None: + raise ValueError("assign either template_string or template_file") + if template_string is not None and template_file is not None: + raise ValueError(f"assign both template_string and template_file:\nstring:{template_string}\nfile:{template_file}") + if template_string is not None: + self.templates = [self.template_string] + else: + assert template_file is not None + self.templates = json.load(open(template_file, 'r', encoding='utf8')) + if self.max_dynamic_size is not None: + self.templates = self.templates[: self.max_dynamic_size] + + # sanity check + assert self.placeholders is not None + for template in self.templates: + for placeholder in placeholders: + assert str(template).count(placeholder) == 1, f"template: {template}\nplaceholder:{placeholder}" + + def get_template(self): + import random + return random.choice(self.templates) + + def template_nums(self): + return len(self.templates) + + +class MInstrDataset(QuestionTemplateMixin, Dataset): + _repr_indent = 4 + + def __init__(self, filename, image_folder=None, seed=None, **kwargs): + super().__init__(**kwargs) + self.filename = filename + self.image_folder = image_folder + self.rng = np.random.default_rng(seed) + + self.data = [] + with open(filename, 'r', encoding='utf8') as f: + # for line in tqdm(f, desc=f'{self.__class__.__name__} loading ann {self.filename}'): + for line in f: + self.data.append(line) + + def get_raw_item(self, index): + return json.loads(self.data[index]) + + def get_image(self, image_path): + if self.image_folder is not None: + image_path = os.path.join(self.image_folder, image_path) + image = read_img_general(image_path) + return image + + def get_template(self): + return self.rng.choice(self.templates) + + def __getitem__(self, index): + raise NotImplementedError + + def __len__(self): + return len(self.data) + + def __repr__(self) -> str: + head = "Dataset " + self.__class__.__name__ + body = [ + f"Number of datapoints: {self.__len__()}", + f"ann file: {self.filename}" + ] + if self.image_folder is not None: + body.append(f"image folder: {self.image_folder}") + body += self.extra_repr().splitlines() + lines = [head] + [" " * self._repr_indent + line for line in body] + return "\n".join(lines) + + # noinspection PyMethodMayBeStatic + def extra_repr(self) -> str: + return "" diff --git a/mllm/dataset/utils/transform.py b/mllm/dataset/utils/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..8a3b3dacaf171faf8ceffb617ea02bda4a023e07 --- /dev/null +++ b/mllm/dataset/utils/transform.py @@ -0,0 +1,106 @@ +from typing import Dict, Any, Tuple, Optional + +from PIL import Image + +from ..root import TRANSFORMS + + +def de_norm_box_xyxy(box, *, w, h): + x1, y1, x2, y2 = box + x1 = x1 * w + x2 = x2 * w + y1 = y1 * h + y2 = y2 * h + box = x1, y1, x2, y2 + return box + + +def box_xywh_to_xyxy(box, *, w=None, h=None): + x, y, bw, bh = box + x2 = x + bw + y2 = y + bh + if w is not None: + x2 = min(x2, w) + if h is not None: + y2 = min(y2, h) + box = x, y, x2, y2 + return box + + +def norm_box_xyxy(box, *, w, h): + x1, y1, x2, y2 = box + + # Calculate the normalized coordinates with min-max clamping + norm_x1 = max(0.0, min(x1 / w, 1.0)) + norm_y1 = max(0.0, min(y1 / h, 1.0)) + norm_x2 = max(0.0, min(x2 / w, 1.0)) + norm_y2 = max(0.0, min(y2 / h, 1.0)) + + # Return the normalized box coordinates + normalized_box = (round(norm_x1, 3), round(norm_y1, 3), round(norm_x2, 3), round(norm_y2, 3)) + return normalized_box + + +def norm_point_xyxy(point, *, w, h): + x, y = point + norm_x = max(0.0, min(x / w, 1.0)) + norm_y = max(0.0, min(y / h, 1.0)) + point = norm_x, norm_y + return point + + +def expand2square(pil_img, background_color=(255, 255, 255)): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +def box_xyxy_expand2square(box, *, w, h): + if w == h: + return box + if w > h: + x1, y1, x2, y2 = box + y1 += (w - h) // 2 + y2 += (w - h) // 2 + box = x1, y1, x2, y2 + return box + assert w < h + x1, y1, x2, y2 = box + x1 += (h - w) // 2 + x2 += (h - w) // 2 + box = x1, y1, x2, y2 + return box + + +def point_xy_expand2square(point, *, w, h): + pseudo_box = (point[0], point[1], point[0], point[1]) + expanded_box = box_xyxy_expand2square(box=pseudo_box, w=w, h=h) + expanded_point = (expanded_box[0], expanded_box[1]) + return expanded_point + + +@TRANSFORMS.register_module() +class Expand2square: + def __init__(self, background_color=(255, 255, 255)): + self.background_color = background_color + + def __call__(self, image: Image.Image, labels: Dict[str, Any] = None) -> Tuple[Image.Image, Optional[Dict[str, Any]]]: + width, height = image.size + processed_image = expand2square(image, background_color=self.background_color) + if labels is None: + return processed_image, labels + if 'boxes' in labels: + bboxes = [box_xyxy_expand2square(bbox, w=width, h=height) for bbox in labels['boxes']] + labels['boxes'] = bboxes + if 'points' in labels: + points = [point_xy_expand2square(point, w=width, h=height) for point in labels['points']] + labels['points'] = points + return processed_image, labels diff --git a/mllm/demo/__init__.py b/mllm/demo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mllm/demo/assets/DejaVuSansMono.ttf b/mllm/demo/assets/DejaVuSansMono.ttf new file mode 100644 index 0000000000000000000000000000000000000000..f5786022f18216b4c59c6fb0c634b52c8b6e7990 Binary files /dev/null and b/mllm/demo/assets/DejaVuSansMono.ttf differ diff --git a/mllm/demo/assets/airplane.jpg b/mllm/demo/assets/airplane.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d0de27002359327413f4e879855c8fe0a90ae7d9 Binary files /dev/null and b/mllm/demo/assets/airplane.jpg differ diff --git a/mllm/demo/assets/ball.jpg b/mllm/demo/assets/ball.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2bc4aa295453102f537d63ee0270068333510fd6 Binary files /dev/null and b/mllm/demo/assets/ball.jpg differ diff --git a/mllm/demo/assets/banana_phone.png b/mllm/demo/assets/banana_phone.png new file mode 100644 index 0000000000000000000000000000000000000000..edd544456930bca18fff95ee1d71b33d05fbbc96 Binary files /dev/null and b/mllm/demo/assets/banana_phone.png differ diff --git a/mllm/demo/assets/baseball.png b/mllm/demo/assets/baseball.png new file mode 100644 index 0000000000000000000000000000000000000000..7ce0cb54b992c19c4b751ff9d5d1d6322f681769 --- /dev/null +++ b/mllm/demo/assets/baseball.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:de707222f8ede912fe9c07b2a0849321f9ea341aaf078a995c1228279497afc3 +size 1224375 diff --git a/mllm/demo/assets/bear-792466_1280.jpg b/mllm/demo/assets/bear-792466_1280.jpg new file mode 100644 index 0000000000000000000000000000000000000000..47c46660bea26e4c4320579bacef3a7c590c8ddd Binary files /dev/null and b/mllm/demo/assets/bear-792466_1280.jpg differ diff --git a/mllm/demo/assets/bearhat.png b/mllm/demo/assets/bearhat.png new file mode 100644 index 0000000000000000000000000000000000000000..666ef7e50359ef5767f96efcf4c32846e4fffae0 Binary files /dev/null and b/mllm/demo/assets/bearhat.png differ diff --git a/mllm/demo/assets/boxes_seq_explanation.jpg b/mllm/demo/assets/boxes_seq_explanation.jpg new file mode 100644 index 0000000000000000000000000000000000000000..33710814b196d356088af8b128034ef1d407f43f Binary files /dev/null and b/mllm/demo/assets/boxes_seq_explanation.jpg differ diff --git a/mllm/demo/assets/dog_rabbit.jpg b/mllm/demo/assets/dog_rabbit.jpg new file mode 100644 index 0000000000000000000000000000000000000000..62dd0e69b33997e78b816998cebb419afb6f0b46 Binary files /dev/null and b/mllm/demo/assets/dog_rabbit.jpg differ diff --git a/mllm/demo/assets/dog_selfcontrol.jpg b/mllm/demo/assets/dog_selfcontrol.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f9e81d72cac22ae87c9c16f239475cb22b3cad31 Binary files /dev/null and b/mllm/demo/assets/dog_selfcontrol.jpg differ diff --git a/mllm/demo/assets/fishing.jpg b/mllm/demo/assets/fishing.jpg new file mode 100644 index 0000000000000000000000000000000000000000..afda179a06d1e0a417d233b54fed80ed657f7f61 Binary files /dev/null and b/mllm/demo/assets/fishing.jpg differ diff --git a/mllm/demo/assets/food-1898194_640.jpg b/mllm/demo/assets/food-1898194_640.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fb6bf167e47238e949e1f65ac8a290142dceba9d Binary files /dev/null and b/mllm/demo/assets/food-1898194_640.jpg differ diff --git a/mllm/demo/assets/fruits.jpg b/mllm/demo/assets/fruits.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9a61259d2be343940cb54e26c5d16109e4d3271d Binary files /dev/null and b/mllm/demo/assets/fruits.jpg differ diff --git a/mllm/demo/assets/g2.jpg b/mllm/demo/assets/g2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f3bb120c6a7b60646c4b5c46da881eb22e879798 Binary files /dev/null and b/mllm/demo/assets/g2.jpg differ diff --git a/mllm/demo/assets/giraffes.jpg b/mllm/demo/assets/giraffes.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2fbf7c2b0e7a621e9f9a9defae2baa425e9010f1 Binary files /dev/null and b/mllm/demo/assets/giraffes.jpg differ diff --git a/mllm/demo/assets/logo.png b/mllm/demo/assets/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..9fd1ebc1015857acd807e2a8d963febc76f50ad3 Binary files /dev/null and b/mllm/demo/assets/logo.png differ diff --git a/mllm/demo/assets/man.jpg b/mllm/demo/assets/man.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f520a848eece94b103d0addc368ab9048d0f75a0 Binary files /dev/null and b/mllm/demo/assets/man.jpg differ diff --git a/mllm/demo/assets/oven.jpg b/mllm/demo/assets/oven.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1aa964cecbfb93e31413c3d5cc1a98f539337c5e Binary files /dev/null and b/mllm/demo/assets/oven.jpg differ diff --git a/mllm/demo/assets/petal_20230711_153216_Compressed.mp4 b/mllm/demo/assets/petal_20230711_153216_Compressed.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..7e874c29d2a06bf1d391894535a9068846516fbf Binary files /dev/null and b/mllm/demo/assets/petal_20230711_153216_Compressed.mp4 differ diff --git a/mllm/demo/assets/potato.jpg b/mllm/demo/assets/potato.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a0b7a30de654b3935b2ba2500d13e7cae61cb909 Binary files /dev/null and b/mllm/demo/assets/potato.jpg differ diff --git a/mllm/demo/assets/proposal.jpg b/mllm/demo/assets/proposal.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6915cf88f7d4d13451b5ab06d5a500ad46435856 Binary files /dev/null and b/mllm/demo/assets/proposal.jpg differ diff --git a/mllm/demo/assets/puzzle.jpg b/mllm/demo/assets/puzzle.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f9630e030ab93b63e7f6c6315d8891052c612508 Binary files /dev/null and b/mllm/demo/assets/puzzle.jpg differ diff --git a/mllm/demo/assets/reaction1.png b/mllm/demo/assets/reaction1.png new file mode 100644 index 0000000000000000000000000000000000000000..65f6291e5d0dabee694b667433bca915a830990e Binary files /dev/null and b/mllm/demo/assets/reaction1.png differ diff --git a/mllm/demo/assets/reaction2.png b/mllm/demo/assets/reaction2.png new file mode 100644 index 0000000000000000000000000000000000000000..111258a3502714b5205bf68e9a98f10605bb7fde Binary files /dev/null and b/mllm/demo/assets/reaction2.png differ diff --git a/mllm/demo/assets/reaction3.png b/mllm/demo/assets/reaction3.png new file mode 100644 index 0000000000000000000000000000000000000000..8b8d77fe1cd7b730c133baa4369002fce7884cfe Binary files /dev/null and b/mllm/demo/assets/reaction3.png differ diff --git a/mllm/demo/assets/rec_bear.png b/mllm/demo/assets/rec_bear.png new file mode 100644 index 0000000000000000000000000000000000000000..e644502c26816179d2cde4251ae3550528fcdbb7 Binary files /dev/null and b/mllm/demo/assets/rec_bear.png differ diff --git a/mllm/demo/assets/relogo.png b/mllm/demo/assets/relogo.png new file mode 100644 index 0000000000000000000000000000000000000000..f7ac4f5a4bcd5e5b131552c7dbebd5ead1774aa0 Binary files /dev/null and b/mllm/demo/assets/relogo.png differ diff --git a/mllm/demo/assets/staircase-274614_640.jpg b/mllm/demo/assets/staircase-274614_640.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6a62d87df9e1875f30b56818273d046b0c1e5961 Binary files /dev/null and b/mllm/demo/assets/staircase-274614_640.jpg differ diff --git a/mllm/demo/assets/water_question.jpg b/mllm/demo/assets/water_question.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f3a0307386d539d3c2614b3d30b8fafa707532b9 Binary files /dev/null and b/mllm/demo/assets/water_question.jpg differ diff --git a/mllm/demo/assets/wet_paint1.jpg b/mllm/demo/assets/wet_paint1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e1d57781b017ce1dc746d6af95288ee5eadcd39d Binary files /dev/null and b/mllm/demo/assets/wet_paint1.jpg differ diff --git a/mllm/demo/assets/woman_door.jpg b/mllm/demo/assets/woman_door.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d04197e527e97f4e32a822f71d574a2a7bec2d2f Binary files /dev/null and b/mllm/demo/assets/woman_door.jpg differ diff --git a/mllm/demo/client.py b/mllm/demo/client.py new file mode 100644 index 0000000000000000000000000000000000000000..4c61bcd3fd9f50c98669c0a7c135b52a7dc04e1a --- /dev/null +++ b/mllm/demo/client.py @@ -0,0 +1,169 @@ +2111# client 端 api 调用案例--------------------------------- +import os +import re +import base64 +from io import BytesIO +from typing import Union + +import torch +import requests +from PIL import Image +from torchvision.transforms import ToPILImage, PILToTensor +from torchvision.utils import draw_bounding_boxes as _draw_bounding_boxes + + +######################################## +# helper +######################################## + +def pil_to_base64(pil_img): + output_buffer = BytesIO() + pil_img.save(output_buffer, format="PNG") + byte_data = output_buffer.getvalue() + encode_img = base64.b64encode(byte_data) + return str(encode_img, encoding='utf-8') + + +def de_norm_box_xyxy(box, *, w, h): + x1, y1, x2, y2 = box + x1 = x1 * w + x2 = x2 * w + y1 = y1 * h + y2 = y2 * h + box = x1, y1, x2, y2 + return box + + +def draw_bounding_boxes( + image, + boxes, + **kwargs, +): + if isinstance(image, Image.Image): + image = PILToTensor()(image) + assert isinstance(image, torch.Tensor), "" + + if not isinstance(boxes, torch.Tensor): + boxes = torch.as_tensor(boxes) + assert isinstance(boxes, torch.Tensor) + + return _draw_bounding_boxes(image, boxes, **kwargs) + + +def expand2square(pil_img, background_color=(255, 255, 255)): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +######################################## +# +######################################## + +def query(image: Union[Image.Image, str], text: str, boxes_value: list, boxes_seq: list, server_url='http://127.0.0.1:12345/shikra'): + if isinstance(image, str): + image = Image.open(image) + pload = { + "img_base64": pil_to_base64(image), + "text": text, + "boxes_value": boxes_value, + "boxes_seq": boxes_seq, + } + resp = requests.post(server_url, json=pload) + if resp.status_code != 200: + raise ValueError(resp.reason) + ret = resp.json() + return ret + + +def postprocess(text, image): + if image is None: + return text, None + + image = expand2square(image) + + colors = ['#ed7d31', '#5b9bd5', '#70ad47', '#7030a0', '#c00000', '#ffff00', "olive", "brown", "cyan"] + pat = re.compile(r'\[\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3}(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3})*\]') + + def extract_boxes(string): + ret = [] + for bboxes_str in pat.findall(string): + bboxes = [] + bbox_strs = bboxes_str.replace("(", "").replace(")", "").replace("[", "").replace("]", "").split(";") + for bbox_str in bbox_strs: + bbox = list(map(float, bbox_str.split(','))) + bboxes.append(bbox) + ret.append(bboxes) + return ret + + extract_pred = extract_boxes(text) + boxes_to_draw = [] + color_to_draw = [] + for idx, boxes in enumerate(extract_pred): + color = colors[idx % len(colors)] + for box in boxes: + boxes_to_draw.append(de_norm_box_xyxy(box, w=image.width, h=image.height)) + color_to_draw.append(color) + if not boxes_to_draw: + return text, None + res = draw_bounding_boxes(image=image, boxes=boxes_to_draw, colors=color_to_draw, width=8) + res = ToPILImage()(res) + + # post process text color + location_text = text + edit_text = list(text) + bboxes_str = pat.findall(text) + for idx in range(len(bboxes_str) - 1, -1, -1): + color = colors[idx % len(colors)] + boxes = bboxes_str[idx] + span = location_text.rfind(boxes), location_text.rfind(boxes) + len(boxes) + location_text = location_text[:span[0]] + edit_text[span[0]:span[1]] = f'{boxes}' + text = "".join(edit_text) + return text, res + + +if __name__ == '__main__': + server_url = 'http://127.0.0.1:12345' + "/shikra" + + + def example1(): + image_path = os.path.join(os.path.dirname(__file__), 'assets/rec_bear.png') + text = 'Can you point out a brown teddy bear with a blue bow in the image and provide the coordinates of its location?' + boxes_value = [] + boxes_seq = [] + + response = query(image_path, text, boxes_value, boxes_seq, server_url) + print(response) + + _, image = postprocess(response['response'], image=Image.open(image_path)) + print(_) + if image is not None: + image.show() + + + def example2(): + image_path = os.path.join(os.path.dirname(__file__), 'assets/man.jpg') + text = "What is the person scared of?" + boxes_value = [[148, 99, 576, 497]] + boxes_seq = [[0]] + + response = query(image_path, text, boxes_value, boxes_seq, server_url) + print(response) + + _, image = postprocess(response['response'], image=Image.open(image_path)) + print(_) + if image is not None: + image.show() + + + example1() + example2() diff --git a/mllm/demo/server.py b/mllm/demo/server.py new file mode 100644 index 0000000000000000000000000000000000000000..bb6297c05c1d8f7c82675d9cae65fe163c78efca --- /dev/null +++ b/mllm/demo/server.py @@ -0,0 +1,198 @@ +# server 端--------------------------------------- +import argparse +import os +import sys +import base64 +import logging +import time +from pathlib import Path +from io import BytesIO + +import torch +import uvicorn +import transformers +from PIL import Image +from mmengine import Config +from transformers import BitsAndBytesConfig +from fastapi import FastAPI, Request, HTTPException + +sys.path.append(str(Path(__file__).parent.parent.parent)) + +from mllm.dataset.process_function import PlainBoxFormatter +from mllm.dataset.builder import prepare_interactive +from mllm.models.builder.build_shikra import load_pretrained_shikra +from mllm.dataset.utils.transform import expand2square, box_xyxy_expand2square + +log_level = logging.DEBUG +transformers.logging.set_verbosity(log_level) +transformers.logging.enable_default_handler() +transformers.logging.enable_explicit_format() + +######################################### +# mllm model init +######################################### +parser = argparse.ArgumentParser("Shikra Server Demo") +parser.add_argument('--model_path', required=True) +parser.add_argument('--load_in_8bit', action='store_true') +parser.add_argument('--server_name', default='127.0.0.1') +parser.add_argument('--server_port', type=int, default=12345) + +args = parser.parse_args() +print(args) + +model_name_or_path = args.model_path + +model_args = Config(dict( + type='shikra', + version='v1', + + # checkpoint config + cache_dir=None, + model_name_or_path=model_name_or_path, + vision_tower=r'vit-h', + pretrain_mm_mlp_adapter=None, + + # model config + mm_vision_select_layer=-2, + model_max_length=3072, + + # finetune config + freeze_backbone=False, + tune_mm_mlp_adapter=False, + freeze_mm_mlp_adapter=False, + + # data process config + is_multimodal=True, + sep_image_conv_front=False, + image_token_len=256, + mm_use_im_start_end=True, + + target_processor=dict( + boxes=dict(type='PlainBoxFormatter'), + ), + + process_func_args=dict( + conv=dict(type='ShikraConvProcess'), + target=dict(type='BoxFormatProcess'), + text=dict(type='ShikraTextProcess'), + image=dict(type='ShikraImageProcessor'), + ), + + conv_args=dict( + conv_template='vicuna_v1.1', + transforms=dict(type='Expand2square'), + tokenize_kwargs=dict(truncation_size=None), + ), + + gen_kwargs_set_pad_token_id=True, + gen_kwargs_set_bos_token_id=True, + gen_kwargs_set_eos_token_id=True, +)) +training_args = Config(dict( + bf16=False, + fp16=True, + device='cuda', + fsdp=None, +)) + +if args.load_in_8bit: + quantization_kwargs = dict( + quantization_config=BitsAndBytesConfig( + load_in_8bit=True, + ) + ) +else: + quantization_kwargs = dict() + +model, preprocessor = load_pretrained_shikra(model_args, training_args, **quantization_kwargs) +if not getattr(model, 'is_quantized', False): + model.to(dtype=torch.float16, device=torch.device('cuda')) +if not getattr(model.model.vision_tower[0], 'is_quantized', False): + model.model.vision_tower[0].to(dtype=torch.float16, device=torch.device('cuda')) +print( + f"LLM device: {model.device}, is_quantized: {getattr(model, 'is_quantized', False)}, is_loaded_in_4bit: {getattr(model, 'is_loaded_in_4bit', False)}, is_loaded_in_8bit: {getattr(model, 'is_loaded_in_8bit', False)}") +print( + f"vision device: {model.model.vision_tower[0].device}, is_quantized: {getattr(model.model.vision_tower[0], 'is_quantized', False)}, is_loaded_in_4bit: {getattr(model, 'is_loaded_in_4bit', False)}, is_loaded_in_8bit: {getattr(model, 'is_loaded_in_8bit', False)}") + +preprocessor['target'] = {'boxes': PlainBoxFormatter()} +tokenizer = preprocessor['text'] + +######################################### +# fast api +######################################### +app = FastAPI() + + +@app.post("/shikra") +async def shikra(request: Request): + try: + # receive parameters + para = await request.json() + img_base64 = para["img_base64"] + user_input = para["text"] + boxes_value = para.get('boxes_value', []) + boxes_seq = para.get('boxes_seq', []) + + do_sample = para.get('do_sample', False) + max_length = para.get('max_length', 512) + top_p = para.get('top_p', 1.0) + temperature = para.get('temperature', 1.0) + + # parameters preprocess + pil_image = Image.open(BytesIO(base64.b64decode(img_base64))).convert("RGB") + ds = prepare_interactive(model_args, preprocessor) + + image = expand2square(pil_image) + boxes_value = [box_xyxy_expand2square(box, w=pil_image.width, h=pil_image.height) for box in boxes_value] + + ds.set_image(image) + ds.append_message(role=ds.roles[0], message=user_input, boxes=boxes_value, boxes_seq=boxes_seq) + model_inputs = ds.to_model_input() + model_inputs['images'] = model_inputs['images'].to(torch.float16) + print(f"model_inputs: {model_inputs}") + + # generate + if do_sample: + gen_kwargs = dict( + use_cache=True, + do_sample=do_sample, + pad_token_id=tokenizer.pad_token_id, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + max_new_tokens=max_length, + top_p=top_p, + temperature=float(temperature), + ) + else: + gen_kwargs = dict( + use_cache=True, + do_sample=do_sample, + pad_token_id=tokenizer.pad_token_id, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + max_new_tokens=max_length, + ) + print(gen_kwargs) + input_ids = model_inputs['input_ids'] + st_time = time.time() + with torch.inference_mode(): + with torch.autocast(dtype=torch.float16, device_type='cuda'): + output_ids = model.generate(**model_inputs, **gen_kwargs) + print(f"done generated in {time.time() - st_time} seconds") + input_token_len = input_ids.shape[-1] + response = tokenizer.batch_decode(output_ids[:, input_token_len:])[0] + print(f"response: {response}") + + input_text = tokenizer.batch_decode(input_ids)[0] + return { + "input": input_text, + "response": response, + } + + except Exception as e: + logging.exception(str(e)) + raise HTTPException(status_code=500, detail=str(e)) + + +if __name__ == "__main__": + uvicorn.run(app, host=args.server_name, port=args.server_port, log_level="info") diff --git a/mllm/demo/temp/tmp06cxea62.jpg b/mllm/demo/temp/tmp06cxea62.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5ea8b60eef1abb8a3839775efd6478447daf67c9 Binary files /dev/null and b/mllm/demo/temp/tmp06cxea62.jpg differ diff --git a/mllm/demo/temp/tmp11amgkps.jpg b/mllm/demo/temp/tmp11amgkps.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ddcc7cdd9256f1a96e4042c037a7315433340593 Binary files /dev/null and b/mllm/demo/temp/tmp11amgkps.jpg differ diff --git a/mllm/demo/temp/tmp1pxg1mrf.jpg b/mllm/demo/temp/tmp1pxg1mrf.jpg new file mode 100644 index 0000000000000000000000000000000000000000..63494f9913728e12b09482d24a19a00529ac5c13 Binary files /dev/null and b/mllm/demo/temp/tmp1pxg1mrf.jpg differ diff --git a/mllm/demo/temp/tmp1xhkhofh.jpg b/mllm/demo/temp/tmp1xhkhofh.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f1e534f0d269915e5ff80205f7a9f7c49800d450 Binary files /dev/null and b/mllm/demo/temp/tmp1xhkhofh.jpg differ diff --git a/mllm/demo/temp/tmp2l27ky23.jpg b/mllm/demo/temp/tmp2l27ky23.jpg new file mode 100644 index 0000000000000000000000000000000000000000..28d100e3b277791679bd81cb6c9a550b01fd8bf8 Binary files /dev/null and b/mllm/demo/temp/tmp2l27ky23.jpg differ diff --git a/mllm/demo/temp/tmp2ly3ve3l.jpg b/mllm/demo/temp/tmp2ly3ve3l.jpg new file mode 100644 index 0000000000000000000000000000000000000000..315bc621421745f361bef9a32ddb47dd1d4b514b Binary files /dev/null and b/mllm/demo/temp/tmp2ly3ve3l.jpg differ diff --git a/mllm/demo/temp/tmp2zavyam4.jpg b/mllm/demo/temp/tmp2zavyam4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..579244b4c2614f288f2f675a1416c74ed0697b04 Binary files /dev/null and b/mllm/demo/temp/tmp2zavyam4.jpg differ diff --git a/mllm/demo/temp/tmp32q_tkqx.jpg b/mllm/demo/temp/tmp32q_tkqx.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9168659f319bba7a82724c713364919a30cba794 Binary files /dev/null and b/mllm/demo/temp/tmp32q_tkqx.jpg differ diff --git a/mllm/demo/temp/tmp44_6k1al.jpg b/mllm/demo/temp/tmp44_6k1al.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ab211902fbbb82df91222781987cfa887e0d7b9f Binary files /dev/null and b/mllm/demo/temp/tmp44_6k1al.jpg differ diff --git a/mllm/demo/temp/tmp4dy0uabc.jpg b/mllm/demo/temp/tmp4dy0uabc.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7856aff454f95737a551d1ebb7f78a2a9a90e4eb Binary files /dev/null and b/mllm/demo/temp/tmp4dy0uabc.jpg differ diff --git a/mllm/demo/temp/tmp4f4tcqgl.jpg b/mllm/demo/temp/tmp4f4tcqgl.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d0284af218354ce13902efeab6bb7deb32b7b839 Binary files /dev/null and b/mllm/demo/temp/tmp4f4tcqgl.jpg differ diff --git a/mllm/demo/temp/tmp4m9dwpm7.jpg b/mllm/demo/temp/tmp4m9dwpm7.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d0336b74680646bb3e3dbca304bf496371342b55 Binary files /dev/null and b/mllm/demo/temp/tmp4m9dwpm7.jpg differ diff --git a/mllm/demo/temp/tmp52apz2o4.jpg b/mllm/demo/temp/tmp52apz2o4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e6e1e725802c8ef7c51db9fd1ca860bd83228b62 Binary files /dev/null and b/mllm/demo/temp/tmp52apz2o4.jpg differ diff --git a/mllm/demo/temp/tmp5e8s9q2u.jpg b/mllm/demo/temp/tmp5e8s9q2u.jpg new file mode 100644 index 0000000000000000000000000000000000000000..587aec0099f73aeb5a36a911a4670c58d346c0f7 Binary files /dev/null and b/mllm/demo/temp/tmp5e8s9q2u.jpg differ diff --git a/mllm/demo/temp/tmp5kbdvyt2.jpg b/mllm/demo/temp/tmp5kbdvyt2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5d4886efec530a20b11877cee9e1a6883102788f Binary files /dev/null and b/mllm/demo/temp/tmp5kbdvyt2.jpg differ diff --git a/mllm/demo/temp/tmp5qu9qw7m.jpg b/mllm/demo/temp/tmp5qu9qw7m.jpg new file mode 100644 index 0000000000000000000000000000000000000000..63e4a9c41e4348f9a3ee6cf826b567ed210b844d Binary files /dev/null and b/mllm/demo/temp/tmp5qu9qw7m.jpg differ diff --git a/mllm/demo/temp/tmp6h20dp17.jpg b/mllm/demo/temp/tmp6h20dp17.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ef76e73175e6600179d0d0dc009af8f61d7585f8 Binary files /dev/null and b/mllm/demo/temp/tmp6h20dp17.jpg differ diff --git a/mllm/demo/temp/tmp6qt5g3uu.jpg b/mllm/demo/temp/tmp6qt5g3uu.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ee3a48e541b50035a7b58533dbb23eb4751e3b76 Binary files /dev/null and b/mllm/demo/temp/tmp6qt5g3uu.jpg differ diff --git a/mllm/demo/temp/tmp8p_4smpq.jpg b/mllm/demo/temp/tmp8p_4smpq.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8806b6ee14373e56ef4dd361bd3ac16865d91564 Binary files /dev/null and b/mllm/demo/temp/tmp8p_4smpq.jpg differ diff --git a/mllm/demo/temp/tmp93twa4d7.jpg b/mllm/demo/temp/tmp93twa4d7.jpg new file mode 100644 index 0000000000000000000000000000000000000000..12671c492c11a53233a7c223026f5180e69e28f3 Binary files /dev/null and b/mllm/demo/temp/tmp93twa4d7.jpg differ diff --git a/mllm/demo/temp/tmp98upp3ol.jpg b/mllm/demo/temp/tmp98upp3ol.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f4f1ef642b8dac2c1041efae82c5fb07d2c63e91 Binary files /dev/null and b/mllm/demo/temp/tmp98upp3ol.jpg differ diff --git a/mllm/demo/temp/tmp99jeit_w.jpg b/mllm/demo/temp/tmp99jeit_w.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f88c62df5b1432932a00eeeb512a4c2fdc97858a Binary files /dev/null and b/mllm/demo/temp/tmp99jeit_w.jpg differ diff --git a/mllm/demo/temp/tmp_1vm_1d5.jpg b/mllm/demo/temp/tmp_1vm_1d5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e4ed8905c6478699e75345307675524256caa266 Binary files /dev/null and b/mllm/demo/temp/tmp_1vm_1d5.jpg differ diff --git a/mllm/demo/temp/tmp_1wpbu5h.jpg b/mllm/demo/temp/tmp_1wpbu5h.jpg new file mode 100644 index 0000000000000000000000000000000000000000..316eb8dbea6b8c77a509fa81b414b42c830ed26c Binary files /dev/null and b/mllm/demo/temp/tmp_1wpbu5h.jpg differ diff --git a/mllm/demo/temp/tmp_o3fvmxe.jpg b/mllm/demo/temp/tmp_o3fvmxe.jpg new file mode 100644 index 0000000000000000000000000000000000000000..892e0a171289a6eb9125325751ae4129b27070f2 Binary files /dev/null and b/mllm/demo/temp/tmp_o3fvmxe.jpg differ diff --git a/mllm/demo/temp/tmp_q_o1dk2.jpg b/mllm/demo/temp/tmp_q_o1dk2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9b02954b6b2a664a239e69c473374f9e7a16fbd9 Binary files /dev/null and b/mllm/demo/temp/tmp_q_o1dk2.jpg differ diff --git a/mllm/demo/temp/tmpa6tns0jv.jpg b/mllm/demo/temp/tmpa6tns0jv.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6dbf7e3ba3ff22a845ecb38239eff5a5068dd416 Binary files /dev/null and b/mllm/demo/temp/tmpa6tns0jv.jpg differ diff --git a/mllm/demo/temp/tmpafz_mg0y.jpg b/mllm/demo/temp/tmpafz_mg0y.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3ba18266776f406187ef4a8762ac439b16a4fd75 Binary files /dev/null and b/mllm/demo/temp/tmpafz_mg0y.jpg differ diff --git a/mllm/demo/temp/tmpbe4gxvof.jpg b/mllm/demo/temp/tmpbe4gxvof.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8cc07b298696e41ddb52cb6b99250b160363e862 Binary files /dev/null and b/mllm/demo/temp/tmpbe4gxvof.jpg differ diff --git a/mllm/demo/temp/tmpbnnkoeqa.jpg b/mllm/demo/temp/tmpbnnkoeqa.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4e7fb70f2d5d8663891691216f092800f250e005 Binary files /dev/null and b/mllm/demo/temp/tmpbnnkoeqa.jpg differ diff --git a/mllm/demo/temp/tmpbo_8fovv.jpg b/mllm/demo/temp/tmpbo_8fovv.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a710325fbd0af7f5a28f4a95c03ad4b492bd1b19 Binary files /dev/null and b/mllm/demo/temp/tmpbo_8fovv.jpg differ diff --git a/mllm/demo/temp/tmpboyb8wgk.jpg b/mllm/demo/temp/tmpboyb8wgk.jpg new file mode 100644 index 0000000000000000000000000000000000000000..874ec4000e0d60de96beeca759c81bd70310fbd8 Binary files /dev/null and b/mllm/demo/temp/tmpboyb8wgk.jpg differ diff --git a/mllm/demo/temp/tmpbqf2tomv.jpg b/mllm/demo/temp/tmpbqf2tomv.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6dbf7e3ba3ff22a845ecb38239eff5a5068dd416 Binary files /dev/null and b/mllm/demo/temp/tmpbqf2tomv.jpg differ diff --git a/mllm/demo/temp/tmpbqnkjcmh.jpg b/mllm/demo/temp/tmpbqnkjcmh.jpg new file mode 100644 index 0000000000000000000000000000000000000000..757b6a173bebf10fa98d61f5f30e5923e46ae8d2 Binary files /dev/null and b/mllm/demo/temp/tmpbqnkjcmh.jpg differ diff --git a/mllm/demo/temp/tmpcahuh4_u.jpg b/mllm/demo/temp/tmpcahuh4_u.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1d7446a4bd8af7d1111e529a837cfcb7f7f3c4a7 Binary files /dev/null and b/mllm/demo/temp/tmpcahuh4_u.jpg differ diff --git a/mllm/demo/temp/tmpcqygcmum.jpg b/mllm/demo/temp/tmpcqygcmum.jpg new file mode 100644 index 0000000000000000000000000000000000000000..376a60548360236f80463049a76a862085583e89 Binary files /dev/null and b/mllm/demo/temp/tmpcqygcmum.jpg differ diff --git a/mllm/demo/temp/tmpd8wd8p71.jpg b/mllm/demo/temp/tmpd8wd8p71.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1be488b8c86dee9a5234980a68ac864ecf33acce Binary files /dev/null and b/mllm/demo/temp/tmpd8wd8p71.jpg differ diff --git a/mllm/demo/temp/tmpda_sgg63.jpg b/mllm/demo/temp/tmpda_sgg63.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e6eb14abe52cc2a5af06f9d4a4a85d7b276bbd5b Binary files /dev/null and b/mllm/demo/temp/tmpda_sgg63.jpg differ diff --git a/mllm/demo/temp/tmpdbp6zekd.jpg b/mllm/demo/temp/tmpdbp6zekd.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7aabbc793d771f45051312d9c5b255acb6a7f4e2 Binary files /dev/null and b/mllm/demo/temp/tmpdbp6zekd.jpg differ diff --git a/mllm/demo/temp/tmpeixux2wk.jpg b/mllm/demo/temp/tmpeixux2wk.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8dd8b9eb7f1da1251150780366932e17ff80368d Binary files /dev/null and b/mllm/demo/temp/tmpeixux2wk.jpg differ diff --git a/mllm/demo/temp/tmpf83m4lbn.jpg b/mllm/demo/temp/tmpf83m4lbn.jpg new file mode 100644 index 0000000000000000000000000000000000000000..007a27c146218bf9af4a41cf02defa28682116aa Binary files /dev/null and b/mllm/demo/temp/tmpf83m4lbn.jpg differ diff --git a/mllm/demo/temp/tmpfilsb5rn.jpg b/mllm/demo/temp/tmpfilsb5rn.jpg new file mode 100644 index 0000000000000000000000000000000000000000..63876f5bd342f042c9d78bdf575c904898a390b3 Binary files /dev/null and b/mllm/demo/temp/tmpfilsb5rn.jpg differ diff --git a/mllm/demo/temp/tmpgcvw07po.jpg b/mllm/demo/temp/tmpgcvw07po.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e3972bcaf1a61913b29919a841c19d66af090c14 Binary files /dev/null and b/mllm/demo/temp/tmpgcvw07po.jpg differ diff --git a/mllm/demo/temp/tmpglu63rin.jpg b/mllm/demo/temp/tmpglu63rin.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b42faf7f97b44bfde937d5bcc609429476c04e7c Binary files /dev/null and b/mllm/demo/temp/tmpglu63rin.jpg differ diff --git a/mllm/demo/temp/tmpgtpavzna.jpg b/mllm/demo/temp/tmpgtpavzna.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f1e251bb21ac75f71568a90ae321df6983978811 Binary files /dev/null and b/mllm/demo/temp/tmpgtpavzna.jpg differ diff --git a/mllm/demo/temp/tmpi05uhxzc.jpg b/mllm/demo/temp/tmpi05uhxzc.jpg new file mode 100644 index 0000000000000000000000000000000000000000..04c284dc0037b87400d554a0a9c38a5e2998c6ad Binary files /dev/null and b/mllm/demo/temp/tmpi05uhxzc.jpg differ diff --git a/mllm/demo/temp/tmpi71dddas.jpg b/mllm/demo/temp/tmpi71dddas.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a25547fedb1fa566566ab4a6eed4b63c6b4034ad Binary files /dev/null and b/mllm/demo/temp/tmpi71dddas.jpg differ diff --git a/mllm/demo/temp/tmpirij6_40.jpg b/mllm/demo/temp/tmpirij6_40.jpg new file mode 100644 index 0000000000000000000000000000000000000000..62f35361ca845afec898076778d761173926f808 Binary files /dev/null and b/mllm/demo/temp/tmpirij6_40.jpg differ diff --git a/mllm/demo/temp/tmpitfhflpn.jpg b/mllm/demo/temp/tmpitfhflpn.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5799b9dcc22ad2a58e5dea600d353db906e5a2dc Binary files /dev/null and b/mllm/demo/temp/tmpitfhflpn.jpg differ diff --git a/mllm/demo/temp/tmpiyd65r70.jpg b/mllm/demo/temp/tmpiyd65r70.jpg new file mode 100644 index 0000000000000000000000000000000000000000..21204f11a2e6d35524c65e67f9c2d3ed8e01394a Binary files /dev/null and b/mllm/demo/temp/tmpiyd65r70.jpg differ diff --git a/mllm/demo/temp/tmpj8nkxnh4.jpg b/mllm/demo/temp/tmpj8nkxnh4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..beede0caafc971c8d806514c7eb7b5356e5e6701 Binary files /dev/null and b/mllm/demo/temp/tmpj8nkxnh4.jpg differ diff --git a/mllm/demo/temp/tmpjr4s1rd0.jpg b/mllm/demo/temp/tmpjr4s1rd0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..03101a76bf4e2bd3751780b348689c442b4cbec8 Binary files /dev/null and b/mllm/demo/temp/tmpjr4s1rd0.jpg differ diff --git a/mllm/demo/temp/tmpk68b8tti.jpg b/mllm/demo/temp/tmpk68b8tti.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cf3ee1d86bb6ab2e0d2ea3acedf89716a019a87d Binary files /dev/null and b/mllm/demo/temp/tmpk68b8tti.jpg differ diff --git a/mllm/demo/temp/tmpkai6shkm.jpg b/mllm/demo/temp/tmpkai6shkm.jpg new file mode 100644 index 0000000000000000000000000000000000000000..37a8b9a35f040e080c130f51ca7328be728ceee2 Binary files /dev/null and b/mllm/demo/temp/tmpkai6shkm.jpg differ diff --git a/mllm/demo/temp/tmpkky8gapr.jpg b/mllm/demo/temp/tmpkky8gapr.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f1b6aa176116041bd6067da0c2a686c80234f3c7 Binary files /dev/null and b/mllm/demo/temp/tmpkky8gapr.jpg differ diff --git a/mllm/demo/temp/tmpkwofl2uh.jpg b/mllm/demo/temp/tmpkwofl2uh.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1fa1e4410974e7dafba8e91c5ba50b15742b555c Binary files /dev/null and b/mllm/demo/temp/tmpkwofl2uh.jpg differ diff --git a/mllm/demo/temp/tmpkybfxuh_.jpg b/mllm/demo/temp/tmpkybfxuh_.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1f22d6ccaddc964fae8f149961ff46277cc584ac Binary files /dev/null and b/mllm/demo/temp/tmpkybfxuh_.jpg differ diff --git a/mllm/demo/temp/tmpkzo1n16l.jpg b/mllm/demo/temp/tmpkzo1n16l.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8f9882dbdddb9121520446211cb237715b6a310b Binary files /dev/null and b/mllm/demo/temp/tmpkzo1n16l.jpg differ diff --git a/mllm/demo/temp/tmpkzpo_7qd.jpg b/mllm/demo/temp/tmpkzpo_7qd.jpg new file mode 100644 index 0000000000000000000000000000000000000000..579244b4c2614f288f2f675a1416c74ed0697b04 Binary files /dev/null and b/mllm/demo/temp/tmpkzpo_7qd.jpg differ diff --git a/mllm/demo/temp/tmpl2nw89e0.jpg b/mllm/demo/temp/tmpl2nw89e0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..874ec4000e0d60de96beeca759c81bd70310fbd8 Binary files /dev/null and b/mllm/demo/temp/tmpl2nw89e0.jpg differ diff --git a/mllm/demo/temp/tmple4xphw3.jpg b/mllm/demo/temp/tmple4xphw3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6a49de1aac6211f7614d50640452f9020dac067c Binary files /dev/null and b/mllm/demo/temp/tmple4xphw3.jpg differ diff --git a/mllm/demo/temp/tmplf4gvocg.jpg b/mllm/demo/temp/tmplf4gvocg.jpg new file mode 100644 index 0000000000000000000000000000000000000000..246512cd42c675185cbcfef28266019b2f5f40c2 Binary files /dev/null and b/mllm/demo/temp/tmplf4gvocg.jpg differ diff --git a/mllm/demo/temp/tmpm1x780i5.jpg b/mllm/demo/temp/tmpm1x780i5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..42b1983b8977e30f9a7edb411d417e6d47975e23 Binary files /dev/null and b/mllm/demo/temp/tmpm1x780i5.jpg differ diff --git a/mllm/demo/temp/tmpm__tuo_e.jpg b/mllm/demo/temp/tmpm__tuo_e.jpg new file mode 100644 index 0000000000000000000000000000000000000000..458a8e6586ffe64c9ad43cf3b0bf1e8287fcd651 Binary files /dev/null and b/mllm/demo/temp/tmpm__tuo_e.jpg differ diff --git a/mllm/demo/temp/tmpmbjkmmsp.jpg b/mllm/demo/temp/tmpmbjkmmsp.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0498f9555fc27a871b9f967c3adac5a9bc0182e7 Binary files /dev/null and b/mllm/demo/temp/tmpmbjkmmsp.jpg differ diff --git a/mllm/demo/temp/tmpni9db1y9.jpg b/mllm/demo/temp/tmpni9db1y9.jpg new file mode 100644 index 0000000000000000000000000000000000000000..50992dc6ad02efdf48386b12d197f60208d6557e Binary files /dev/null and b/mllm/demo/temp/tmpni9db1y9.jpg differ diff --git a/mllm/demo/temp/tmpnijhiand.jpg b/mllm/demo/temp/tmpnijhiand.jpg new file mode 100644 index 0000000000000000000000000000000000000000..931b9a0bc4ba94dd396485012644266c6060480e Binary files /dev/null and b/mllm/demo/temp/tmpnijhiand.jpg differ diff --git a/mllm/demo/temp/tmpnt0dhynl.jpg b/mllm/demo/temp/tmpnt0dhynl.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6437e80fecac9714550c26cca027d0c69c92b920 Binary files /dev/null and b/mllm/demo/temp/tmpnt0dhynl.jpg differ diff --git a/mllm/demo/temp/tmpp0vg0f5h.jpg b/mllm/demo/temp/tmpp0vg0f5h.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9ddaf98600f194d50131b34c82ec65e9f19b20e7 Binary files /dev/null and b/mllm/demo/temp/tmpp0vg0f5h.jpg differ diff --git a/mllm/demo/temp/tmppc1b_fw2.jpg b/mllm/demo/temp/tmppc1b_fw2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c2fb81f1b6cf22852c661d87ad830e68eae928fb Binary files /dev/null and b/mllm/demo/temp/tmppc1b_fw2.jpg differ diff --git a/mllm/demo/temp/tmpps6azmre.jpg b/mllm/demo/temp/tmpps6azmre.jpg new file mode 100644 index 0000000000000000000000000000000000000000..086fb7066228888314d619aa2fa420f43f82eecb Binary files /dev/null and b/mllm/demo/temp/tmpps6azmre.jpg differ diff --git a/mllm/demo/temp/tmpqhcdtdt_.jpg b/mllm/demo/temp/tmpqhcdtdt_.jpg new file mode 100644 index 0000000000000000000000000000000000000000..656411e15390f661fa6796f42521223850998ca4 Binary files /dev/null and b/mllm/demo/temp/tmpqhcdtdt_.jpg differ diff --git a/mllm/demo/temp/tmpqhdik7hp.jpg b/mllm/demo/temp/tmpqhdik7hp.jpg new file mode 100644 index 0000000000000000000000000000000000000000..393f30e61ce12a82e8bbe90fb1ef0aff02ccb4fc Binary files /dev/null and b/mllm/demo/temp/tmpqhdik7hp.jpg differ diff --git a/mllm/demo/temp/tmprb9r1jac.jpg b/mllm/demo/temp/tmprb9r1jac.jpg new file mode 100644 index 0000000000000000000000000000000000000000..00e9968e32cefb2e19658c49873769bb0430b7d1 Binary files /dev/null and b/mllm/demo/temp/tmprb9r1jac.jpg differ diff --git a/mllm/demo/temp/tmprk8h2o7p.jpg b/mllm/demo/temp/tmprk8h2o7p.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3b708af32c247e443587f03632e8560d78dd882a Binary files /dev/null and b/mllm/demo/temp/tmprk8h2o7p.jpg differ diff --git a/mllm/demo/temp/tmps3o5ntof.jpg b/mllm/demo/temp/tmps3o5ntof.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1b554469f0cd6a618d6a13f672ebf96d1286cb70 Binary files /dev/null and b/mllm/demo/temp/tmps3o5ntof.jpg differ diff --git a/mllm/demo/temp/tmpsevpisyi.jpg b/mllm/demo/temp/tmpsevpisyi.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c0f93000509a078367cad254192f6c60b12d4a71 Binary files /dev/null and b/mllm/demo/temp/tmpsevpisyi.jpg differ diff --git a/mllm/demo/temp/tmpshwpfaob.jpg b/mllm/demo/temp/tmpshwpfaob.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7d8f9446505200f3b227d61eb5793530c39cc846 Binary files /dev/null and b/mllm/demo/temp/tmpshwpfaob.jpg differ diff --git a/mllm/demo/temp/tmptfsz7ne0.jpg b/mllm/demo/temp/tmptfsz7ne0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..393a32c79a6b3e924624e47d037e71ca3a65cb94 Binary files /dev/null and b/mllm/demo/temp/tmptfsz7ne0.jpg differ diff --git a/mllm/demo/temp/tmpug0dhmwi.jpg b/mllm/demo/temp/tmpug0dhmwi.jpg new file mode 100644 index 0000000000000000000000000000000000000000..55834f47dd365a22f1466e1d5408886468e0180b Binary files /dev/null and b/mllm/demo/temp/tmpug0dhmwi.jpg differ diff --git a/mllm/demo/temp/tmpurg40rt8.jpg b/mllm/demo/temp/tmpurg40rt8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c6f69d8a25159d7d896b82b1d32b55da70c39a8e Binary files /dev/null and b/mllm/demo/temp/tmpurg40rt8.jpg differ diff --git a/mllm/demo/temp/tmpv5bxbca5.jpg b/mllm/demo/temp/tmpv5bxbca5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1fa1e4410974e7dafba8e91c5ba50b15742b555c Binary files /dev/null and b/mllm/demo/temp/tmpv5bxbca5.jpg differ diff --git a/mllm/demo/temp/tmpv9fegjfs.jpg b/mllm/demo/temp/tmpv9fegjfs.jpg new file mode 100644 index 0000000000000000000000000000000000000000..393a32c79a6b3e924624e47d037e71ca3a65cb94 Binary files /dev/null and b/mllm/demo/temp/tmpv9fegjfs.jpg differ diff --git a/mllm/demo/temp/tmpvd3oh0_5.jpg b/mllm/demo/temp/tmpvd3oh0_5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b24a497ddc450bcdb3b1a9c258b84e3f2106f64c Binary files /dev/null and b/mllm/demo/temp/tmpvd3oh0_5.jpg differ diff --git a/mllm/demo/temp/tmpvreygw_i.jpg b/mllm/demo/temp/tmpvreygw_i.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7dc28b878d1f72e518fec25fdb92747af9d784a7 Binary files /dev/null and b/mllm/demo/temp/tmpvreygw_i.jpg differ diff --git a/mllm/demo/temp/tmpvtb5l4pd.jpg b/mllm/demo/temp/tmpvtb5l4pd.jpg new file mode 100644 index 0000000000000000000000000000000000000000..edd77abaeeac96ea7426bd863e30af054ae2e248 Binary files /dev/null and b/mllm/demo/temp/tmpvtb5l4pd.jpg differ diff --git a/mllm/demo/temp/tmpvu62fu4r.jpg b/mllm/demo/temp/tmpvu62fu4r.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4429ff94d1565c40cae4983b4b68d19a78f18d8c Binary files /dev/null and b/mllm/demo/temp/tmpvu62fu4r.jpg differ diff --git a/mllm/demo/temp/tmpwbscy95i.jpg b/mllm/demo/temp/tmpwbscy95i.jpg new file mode 100644 index 0000000000000000000000000000000000000000..84b250e8650de6190bcf6552e99648d93c401def Binary files /dev/null and b/mllm/demo/temp/tmpwbscy95i.jpg differ diff --git a/mllm/demo/temp/tmpxak2h7lc.jpg b/mllm/demo/temp/tmpxak2h7lc.jpg new file mode 100644 index 0000000000000000000000000000000000000000..eca597e49c91e717767d78e4e1f28275bd7911cb Binary files /dev/null and b/mllm/demo/temp/tmpxak2h7lc.jpg differ diff --git a/mllm/demo/temp/tmpxasfuuga.jpg b/mllm/demo/temp/tmpxasfuuga.jpg new file mode 100644 index 0000000000000000000000000000000000000000..04665645b4182761de359517cfa5e72ab5085f87 Binary files /dev/null and b/mllm/demo/temp/tmpxasfuuga.jpg differ diff --git a/mllm/demo/temp/tmpxeevsjf8.jpg b/mllm/demo/temp/tmpxeevsjf8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ab9b52197b655857cd12623e64b18efe16a21799 Binary files /dev/null and b/mllm/demo/temp/tmpxeevsjf8.jpg differ diff --git a/mllm/demo/temp/tmpy2xnuimz.jpg b/mllm/demo/temp/tmpy2xnuimz.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0ed951a1f5fe746ee2231c2efc68418de3717566 Binary files /dev/null and b/mllm/demo/temp/tmpy2xnuimz.jpg differ diff --git a/mllm/demo/temp/tmpycavp8gt.jpg b/mllm/demo/temp/tmpycavp8gt.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b72b61b9b95ac6aca821803c0bcd2a6b7e732711 Binary files /dev/null and b/mllm/demo/temp/tmpycavp8gt.jpg differ diff --git a/mllm/demo/temp/tmpyjy2p7vg.jpg b/mllm/demo/temp/tmpyjy2p7vg.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6f10dace7aa143ffe5c0a3e04bbb740926cd3340 Binary files /dev/null and b/mllm/demo/temp/tmpyjy2p7vg.jpg differ diff --git a/mllm/demo/temp/tmpyk2nqt5h.jpg b/mllm/demo/temp/tmpyk2nqt5h.jpg new file mode 100644 index 0000000000000000000000000000000000000000..777edaf7e7dd39bc74e549f9e25ad7cccfcb5bf1 Binary files /dev/null and b/mllm/demo/temp/tmpyk2nqt5h.jpg differ diff --git a/mllm/demo/temp/tmpz6_sgo66.jpg b/mllm/demo/temp/tmpz6_sgo66.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f08c05e58f6a8e5fddc6b188d01af1029c067911 Binary files /dev/null and b/mllm/demo/temp/tmpz6_sgo66.jpg differ diff --git a/mllm/demo/temp/tmpz_a8zka_.jpg b/mllm/demo/temp/tmpz_a8zka_.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1fa1e4410974e7dafba8e91c5ba50b15742b555c Binary files /dev/null and b/mllm/demo/temp/tmpz_a8zka_.jpg differ diff --git a/mllm/demo/temp/tmpzik1pdtm.jpg b/mllm/demo/temp/tmpzik1pdtm.jpg new file mode 100644 index 0000000000000000000000000000000000000000..03101a76bf4e2bd3751780b348689c442b4cbec8 Binary files /dev/null and b/mllm/demo/temp/tmpzik1pdtm.jpg differ diff --git a/mllm/demo/temp/tmpzvuz8tsx.jpg b/mllm/demo/temp/tmpzvuz8tsx.jpg new file mode 100644 index 0000000000000000000000000000000000000000..46f51b6edd406f15bc0f49d6e722f60b7855cc56 Binary files /dev/null and b/mllm/demo/temp/tmpzvuz8tsx.jpg differ diff --git a/mllm/demo/temp/tmpzyhf603b.jpg b/mllm/demo/temp/tmpzyhf603b.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ede52675a51b13dfd006c2244749635887c186e6 Binary files /dev/null and b/mllm/demo/temp/tmpzyhf603b.jpg differ diff --git a/mllm/demo/webdemo.py b/mllm/demo/webdemo.py new file mode 100644 index 0000000000000000000000000000000000000000..795fde281e697ca4127357265b57cc24fc61a73b --- /dev/null +++ b/mllm/demo/webdemo.py @@ -0,0 +1,1013 @@ +import os +import sys +import logging +import time +import argparse +import tempfile +from pathlib import Path +from typing import List, Any, Union + +import torch +import numpy as np +import gradio as gr +from PIL import Image +from PIL import ImageDraw, ImageFont +from mmengine import Config +import transformers +from transformers import BitsAndBytesConfig + +sys.path.append(str(Path(__file__).parent.parent.parent)) + +from mllm.dataset.process_function import PlainBoxFormatter +from mllm.dataset.builder import prepare_interactive +from mllm.utils import draw_bounding_boxes +from mllm.models.builder.build_shikra import load_pretrained_shikra + +log_level = logging.DEBUG +transformers.logging.set_verbosity(log_level) +transformers.logging.enable_default_handler() +transformers.logging.enable_explicit_format() + +TEMP_FILE_DIR = Path(__file__).parent / 'temp' +TEMP_FILE_DIR.mkdir(parents=True, exist_ok=True) + +######################################### +# mllm model init +######################################### +parser = argparse.ArgumentParser("Shikra Web Demo") +parser.add_argument('--model_path', required=True) +parser.add_argument('--load_in_8bit', action='store_true') +parser.add_argument('--server_name', default=None) +parser.add_argument('--server_port', type=int, default=None) + +args = parser.parse_args() +print(args) + +model_name_or_path = args.model_path + +model_args = Config(dict( + type='shikra', + version='v1', + + # checkpoint config + cache_dir=None, + model_name_or_path=model_name_or_path, + vision_tower=r'openai/clip-vit-large-patch14', + pretrain_mm_mlp_adapter=None, + + # model config + mm_vision_select_layer=-2, + model_max_length=3072, + + # finetune config + freeze_backbone=False, + tune_mm_mlp_adapter=False, + freeze_mm_mlp_adapter=False, + + # data process config + is_multimodal=True, + sep_image_conv_front=False, + image_token_len=256, + mm_use_im_start_end=True, + + target_processor=dict( + boxes=dict(type='PlainBoxFormatter'), + ), + + process_func_args=dict( + conv=dict(type='ShikraConvProcess'), + target=dict(type='BoxFormatProcess'), + text=dict(type='ShikraTextProcess'), + image=dict(type='ShikraImageProcessor'), + ), + + conv_args=dict( + conv_template='vicuna_v1.1', + transforms=dict(type='Expand2square'), + tokenize_kwargs=dict(truncation_size=None), + ), + + gen_kwargs_set_pad_token_id=True, + gen_kwargs_set_bos_token_id=True, + gen_kwargs_set_eos_token_id=True, +)) +training_args = Config(dict( + bf16=False, + fp16=True, + device='cuda', + fsdp=None, +)) + +if args.load_in_8bit: + quantization_kwargs = dict( + quantization_config=BitsAndBytesConfig( + load_in_8bit=True, + ) + ) +else: + quantization_kwargs = dict() + +model, preprocessor = load_pretrained_shikra(model_args, training_args, **quantization_kwargs) +if not getattr(model, 'is_quantized', False): + model.to(dtype=torch.float16, device=torch.device('cuda')) +if not getattr(model.model.vision_tower[0], 'is_quantized', False): + model.model.vision_tower[0].to(dtype=torch.float16, device=torch.device('cuda')) +print(f"LLM device: {model.device}, is_quantized: {getattr(model, 'is_quantized', False)}, is_loaded_in_4bit: {getattr(model, 'is_loaded_in_4bit', False)}, is_loaded_in_8bit: {getattr(model, 'is_loaded_in_8bit', False)}") +print(f"vision device: {model.model.vision_tower[0].device}, is_quantized: {getattr(model.model.vision_tower[0], 'is_quantized', False)}, is_loaded_in_4bit: {getattr(model, 'is_loaded_in_4bit', False)}, is_loaded_in_8bit: {getattr(model, 'is_loaded_in_8bit', False)}") + +preprocessor['target'] = {'boxes': PlainBoxFormatter()} +tokenizer = preprocessor['text'] + + +######################################### +# demo utils +######################################### + +def parse_text(text): + text = text.replace("", "<image>") + return text + + +def setup_gradio_warning(level=1): + """ + level 0 1 2 3 + level IGNORE Weak Strong Error + has Warning _foo Warning Warning Error + no Warning _foo _foo Error Error + """ + + def _dummy_func(*args, **kwargs): + pass + + def _raise_error(*args, **kwargs): + raise gr.Error(*args, **kwargs) + + assert level in [0, 1, 2, 3] + if level >= 3: + return _raise_error + if level <= 0: + return _dummy_func + if hasattr(gr, 'Warning'): + return gr.Warning + if level == 1: + return _dummy_func + return _raise_error + + +grWarning = setup_gradio_warning() + + +def de_norm_box_xyxy(box, *, w, h): + x1, y1, x2, y2 = box + x1 = x1 * w + x2 = x2 * w + y1 = y1 * h + y2 = y2 * h + box = x1, y1, x2, y2 + return box + + +def expand2square(pil_img, background_color=(255, 255, 255)): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +def box_xyxy_expand2square(box, *, w, h): + if w == h: + return box + if w > h: + x1, y1, x2, y2 = box + y1 += (w - h) // 2 + y2 += (w - h) // 2 + box = x1, y1, x2, y2 + return box + assert w < h + x1, y1, x2, y2 = box + x1 += (h - w) // 2 + x2 += (h - w) // 2 + box = x1, y1, x2, y2 + return box + + +def resize_pil_img(pil_img: Image.Image, *, w, h): + old_height, old_width = pil_img.height, pil_img.width + new_height, new_width = (h, w) + if (new_height, new_width) == (old_height, old_width): + return pil_img + return pil_img.resize((new_width, new_height)) + + +def resize_box_xyxy(boxes, *, w, h, ow, oh): + old_height, old_width = (oh, ow) + new_height, new_width = (h, w) + if (new_height, new_width) == (old_height, old_width): + return boxes + w_ratio = new_width / old_width + h_ratio = new_height / old_height + out_boxes = [] + for box in boxes: + x1, y1, x2, y2 = box + x1 = x1 * w_ratio + x2 = x2 * w_ratio + y1 = y1 * h_ratio + y2 = y2 * h_ratio + nb = (x1, y1, x2, y2) + out_boxes.append(nb) + return out_boxes + + +# use mask to simulate box +# copy from https://github.com/gligen/GLIGEN/blob/master/demo/app.py +class ImageMask(gr.components.Image): + is_template = True + + def __init__(self, **kwargs): + super().__init__(source="upload", tool="sketch", interactive=True, **kwargs) + #super().__init__(tool = "sketch", interactive=True, **kwargs) + + +def binarize(x): + return (x != 0).astype('uint8') * 255 + + +class ImageBoxState: + def __init__(self, draw_size: Union[int, float, tuple, list] = 512): + if isinstance(draw_size, (float, int)): + draw_size = (draw_size, draw_size) + assert len(draw_size) == 2 + self.size = draw_size + self.height, self.width = self.size[0], self.size[1] + self.reset_state() + + # noinspection PyAttributeOutsideInit + def reset_state(self): + self.image = None + self.boxes = [] + self.masks = [] + + # noinspection PyAttributeOutsideInit + def reset_masks(self): + self.boxes = [] + self.masks = [] + + # noinspection PyAttributeOutsideInit + def update_image(self, image): + if image != self.image: + self.reset_state() + self.image = image + + def update_mask(self, mask): + if len(self.masks) == 0: + last_mask = np.zeros_like(mask) + else: + last_mask = self.masks[-1] + + if type(mask) == np.ndarray and mask.size > 1: + diff_mask = mask - last_mask + else: + diff_mask = np.zeros([]) + + if diff_mask.sum() > 0: + # noinspection PyArgumentList + x1x2 = np.where(diff_mask.max(0) != 0)[0] + # noinspection PyArgumentList + y1y2 = np.where(diff_mask.max(1) != 0)[0] + y1, y2 = y1y2.min(), y1y2.max() + x1, x2 = x1x2.min(), x1x2.max() + if (x2 - x1 > 5) and (y2 - y1 > 5): + self.masks.append(mask.copy()) + self.boxes.append(tuple(map(int, (x1, y1, x2, y2)))) + + def update_box(self, box): + x1, y1, x2, y2 = box + x1, x2 = min(x1, x2), max(x1, x2) + y1, y2 = min(y1, y2), max(y1, y2) + self.boxes.append(tuple(map(int, (x1, y1, x2, y2)))) + + def to_model(self): + if self.image is None: + return {} + image = expand2square(self.image) + boxes = [box_xyxy_expand2square(box, w=self.image.width, h=self.image.height) for box in self.boxes] + return {'image': image, 'boxes': boxes} + + def draw_boxes(self): + assert self.image is not None + grounding_texts = [f'{bid}' for bid in range(len(self.boxes))] + image = expand2square(self.image) + boxes = [box_xyxy_expand2square(box, w=self.image.width, h=self.image.height) for box in self.boxes] + + image_to_draw = resize_pil_img(image, w=self.width, h=self.height) + boxes_to_draw = resize_box_xyxy(boxes, w=self.width, h=self.height, ow=image.width, oh=image.height) + + def _draw(img, _boxes: List[Any], texts: List[str]): + assert img is not None + colors = ["red", "blue", "green", "olive", "orange", "brown", "cyan", "purple"] + _img_draw = ImageDraw.Draw(img) + font = ImageFont.truetype(os.path.join(os.path.dirname(__file__), 'assets/DejaVuSansMono.ttf'), size=18) + for bid, box in enumerate(_boxes): + _img_draw.rectangle((box[0], box[1], box[2], box[3]), outline=colors[bid % len(colors)], width=4) + anno_text = texts[bid] + _img_draw.rectangle((box[0], box[3] - int(font.size * 1.2), box[0] + int((len(anno_text) + 0.8) * font.size * 0.6), box[3]), + outline=colors[bid % len(colors)], fill=colors[bid % len(colors)], width=4) + _img_draw.text((box[0] + int(font.size * 0.2), box[3] - int(font.size * 1.2)), anno_text, font=font, fill=(255, 255, 255)) + return img + + out_draw = _draw(image_to_draw, boxes_to_draw, grounding_texts) + return out_draw + + +def add_submit_temp_image(state, temp_image_path): + if '_submit_temp_images' not in state: + state['_submit_temp_images'] = [] + state['_submit_temp_images'].append(temp_image_path) + return state + + +def clear_submit_temp_image(state): + if '_submit_temp_images' in state: + for path in state['_submit_temp_images']: + os.remove(path) + del state['_submit_temp_images'] + return state + + +if __name__ == '__main__': + with gr.Blocks() as demo: + logo_file_url = f"file={os.path.join(os.path.dirname(__file__), 'assets/logo.png')}" + gr.HTML( + f""" + +

Logo

+

Shikra: Unleashing Multimodal LLM’s Referential Dialogue Magic

+

+ [Project] + [Paper] +

+

+ Shikra, an MLLM designed to kick off referential dialogue by excelling in spatial coordinate inputs/outputs in natural language, without additional vocabularies, position encoders, pre-/post-detection, or external plug-in models. +

+

User Manual

+
    +
  • Step 1. Upload an image

    +
  • +
  • Step 2. Select Question Format in Task Template. Task template and user input (if exists) will be assembled into final inputs to the model.

    +
      +
    • SpotCap: Ask the model to generate a grounded caption.
    • +
    • GCoT: Ask the model to answer the question and provide a Grounding-CoT, which is a step-by-step reasoning with explicit grounding information.
    • +
    • Cap: Ask the model to generate a short caption.
    • +
    • VQA: Ask the model to answer the question directly.
    • +
    • REC: Referring Expression Comprehension. Ask the model to output the bounding box of <expr>.
    • +
    • REG: Referring Expression Generation. Ask the model to generate a distinguishable description for RoI.
    • +
    • Advanced: Use no predefined template. You can take full control of inputs.
    • + +
    +
  • + +
  • Step 3. Ask Question. Use <boxes> placeholder if input has bounding box.

    +
  • + +
+

The following step are needed only when input has bounding box.

+
    +
  • Step 4. Draw Bounding Box in Sketch Pad.

    +

    Each bbox has a unique index, which will show at the corner of the bbox in Parsed Sketch Pad.

    +
  • +
  • Step 5. Assign the bbox index in Boexs Seq for each <boxes> placeholder. Boexs Seq take a 2-d list as input, each sub-list will replace the <boxes> placeholder in order.

    +
  • +
+""" + ) + + with gr.Row(): + with gr.Column(): + gr.HTML( + """ +

Video example

+

a video example demonstrate how to input with boxes

+ """ + ) + video_file_url = os.path.join(os.path.dirname(__file__), f"assets/petal_20230711_153216_Compressed.mp4") + gr.Video(value=video_file_url, interactive=False, width=600) + with gr.Column(): + boxes_seq_usage_file_url = f'file={os.path.join(os.path.dirname(__file__), f"assets/boxes_seq_explanation.jpg")}' + gr.HTML( + f""" +

Boxes Seq Usage Explanation

+

the [0,2] boxes will replace the first <boxes> placeholder. the [1] boxes will replace the second <boxes> placeholder.

+

+""" + ) + + gr.HTML( + """ +

Demo

+ """ + ) + with gr.Row(): + with gr.Column(): + chatbot = gr.Chatbot() + with gr.Accordion("Parameters", open=False): + with gr.Row(): + do_sample = gr.Checkbox(value=False, label='do sampling', interactive=True) + max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="max length", interactive=True) + top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) + temperature = gr.Slider(0, 10, value=0.75, step=0.01, label="Temperature", interactive=True) + with gr.Column(): + with gr.Row(variant='compact'): + sketch_pad = ImageMask(label="Sketch Pad", elem_id="img2img_image") + out_imagebox = gr.Image(label="Parsed Sketch Pad") + with gr.Column(): + radio = gr.Radio( + ["SpotCap", "GCoT", "Cap", "VQA", "REC", "REG", "Advanced"], label="Task Template", value='SpotCap', + ) + with gr.Group(): + template = gr.Textbox(label='Template', show_label=True, lines=1, interactive=False, + value='Provide a comprehensive description of the image and specify the positions of any mentioned objects in square brackets.') + user_input = gr.Textbox(label='', show_label=True, placeholder="Input...", lines=3, + value=None, visible=False, interactive=False) + boxes_seq = gr.Textbox(label='Boxes Seq', show_label=False, placeholder="Boxes Seq...", lines=1, + value=None, visible=False, interactive=False) + with gr.Row(): + reset_all = gr.Button('Reset All') + reset_chat = gr.Button('Reset Chat') + reset_boxes = gr.Button('Reset Boxes') + submitBtn = gr.Button('Run') + + + ############################################## + # reset state + ############################################## + + def reset_state_func(): + ret = { + 'ibs': ImageBoxState(), + 'ds': prepare_interactive(model_args, preprocessor), + } + return ret + + + state = gr.State(reset_state_func) + example_image_boxes = gr.State(None) + + + ############################################## + # reset dialogue + ############################################## + + def reset_all_func(state): + # clear_submit_temp_image(state) + new_state = reset_state_func() + boxes_seq = '[[0]]' if radio in ['REG', 'GC'] else None + return [new_state, None, None, None, boxes_seq, None] + + + reset_all.click( + fn=reset_all_func, + inputs=[state], + outputs=[state, sketch_pad, out_imagebox, user_input, boxes_seq, chatbot], + ) + + + def reset_chat_func_step1(state, radio): + state['ibs'].reset_masks() + new_state = reset_state_func() + new_state['_reset_boxes_func_image'] = state['ibs'].image + boxes_seq = '[[0]]' if radio in ['REG', 'GC'] else None + return [new_state, None, None, None, boxes_seq, None] + + + def reset_chat_func_step2(state): + image = state['_reset_boxes_func_image'] + del state['_reset_boxes_func_image'] + return state, gr.update(value=image) + + + reset_chat.click( + fn=reset_chat_func_step1, + inputs=[state, radio], + outputs=[state, sketch_pad, out_imagebox, user_input, boxes_seq, chatbot], + ).then( + fn=reset_chat_func_step2, + inputs=[state], + outputs=[state, sketch_pad], + ) + + + ############################################## + # reset boxes + ############################################## + + def reset_boxes_func_step1(state): + state['_reset_boxes_func_image'] = state['ibs'].image + state['ibs'].reset_masks() + return state, None + + + def reset_boxes_func_step2(state): + image = state['_reset_boxes_func_image'] + del state['_reset_boxes_func_image'] + return state, gr.update(value=image) + + + # reset boxes + reset_boxes.click( + fn=reset_boxes_func_step1, + inputs=[state], + outputs=[state, sketch_pad], + ).then( + fn=reset_boxes_func_step2, + inputs=[state], + outputs=[state, sketch_pad], + ) + + + ############################################## + # examples + ############################################## + + def parese_example(image, boxes): + state = reset_state_func() + image = Image.open(image) + state['ibs'].update_image(image) + for box in boxes: + state['ibs'].update_box(box) + image = state['ibs'].draw_boxes() + + _, path = tempfile.mkstemp(suffix='.jpg', dir=TEMP_FILE_DIR) + image.save(path) + return path, state + + + with gr.Column(visible=True) as example_SpotCap: + _examples_cap_raw = [ + os.path.join(os.path.dirname(__file__), 'assets/proposal.jpg'), + os.path.join(os.path.dirname(__file__), 'assets/water_question.jpg'), + os.path.join(os.path.dirname(__file__), 'assets/fishing.jpg'), + os.path.join(os.path.dirname(__file__), 'assets/ball.jpg'), + + os.path.join(os.path.dirname(__file__), 'assets/banana_phone.png'), + os.path.join(os.path.dirname(__file__), "assets/airplane.jpg"), + os.path.join(os.path.dirname(__file__), 'assets/baseball.png'), + ] + _examples_cap_parsed = [[item, []] for item in _examples_cap_raw] + gr.Examples( + examples=_examples_cap_parsed, + inputs=[sketch_pad, example_image_boxes], + ) + + with gr.Column(visible=False) as example_vqabox: + _examples_vqabox_parsed = [ + [ + os.path.join(os.path.dirname(__file__), 'assets/proposal.jpg'), + 'How is the person in the picture feeling?', + '[[0]]', + [[785, 108, 1063, 844]], + ], + [ + os.path.join(os.path.dirname(__file__), 'assets/woman_door.jpg'), + "Which one is the woman's reflection in the mirror?", + '[[0,1]]', + [(770, 138, 1024, 752), (469, 146, 732, 744)], + ], + [ + os.path.join(os.path.dirname(__file__), 'assets/man.jpg'), + "What is the person scared of?", + '[[0]]', + [(148, 99, 576, 497)], + ], + [ + os.path.join(os.path.dirname(__file__), "assets/giraffes.jpg"), + "How many animals in the image?", + "", + [], + ], + [ + os.path.join(os.path.dirname(__file__), "assets/dog_selfcontrol.jpg"), + "Is this dog on a lead held by someone able to control it?", + "", + [], + ], + [ + os.path.join(os.path.dirname(__file__), 'assets/wet_paint1.jpg'), + 'What does the board say?', + '', + [], + ], + [ + os.path.join(os.path.dirname(__file__), 'assets/g2.jpg'), + "What is unusual about the image?", + '', + [], + ], + ] + + gr.Examples( + examples=_examples_vqabox_parsed, + inputs=[sketch_pad, user_input, boxes_seq, example_image_boxes], + ) + + with gr.Column(visible=False) as example_vqa: + _examples_vqa_parsed = [ + [ + os.path.join(os.path.dirname(__file__), 'assets/food-1898194_640.jpg'), + "QUESTION: Which of the following is meat?\nOPTION:\n(A) \n(B) \n(C) \n(D) ", + '[[0],[1],[2],[3]]', + [[20, 216, 70, 343], [8, 3, 187, 127], [332, 386, 424, 494], [158, 518, 330, 605]], + ], + [ + os.path.join(os.path.dirname(__file__), 'assets/potato.jpg'), + "What color is this?", + '[[0]]', + [[75, 408, 481, 802]], + ], + [ + os.path.join(os.path.dirname(__file__), 'assets/potato.jpg'), + "What color is this?", + '[[0]]', + [[147, 274, 266, 437]], + ], + [ + os.path.join(os.path.dirname(__file__), 'assets/staircase-274614_640.jpg'), + "Is this a sea snail?", + '', + [], + ], + [ + os.path.join(os.path.dirname(__file__), 'assets/staircase-274614_640.jpg'), + "Is this shape like a sea snail?", + '', + [], + ], + ] + gr.Examples( + examples=_examples_vqa_parsed, + inputs=[sketch_pad, user_input, boxes_seq, example_image_boxes], + ) + + with gr.Column(visible=False) as example_rec: + gr.Examples( + examples=[ + [ + os.path.join(os.path.dirname(__file__), "assets/rec_bear.png"), + "a brown teddy bear with a blue bow", + [], + ], + [ + os.path.join(os.path.dirname(__file__), "assets/bear-792466_1280.jpg"), + "the teddy bear lay on the sofa edge", + [], + ], + ], + inputs=[sketch_pad, user_input, example_image_boxes], + ) + + with gr.Column(visible=False) as example_reg: + gr.Examples( + examples=[ + [ + os.path.join(os.path.dirname(__file__), "assets/fruits.jpg"), + "[[0]]", + [[833, 527, 646, 315]], + ], + [ + os.path.join(os.path.dirname(__file__), "assets/bearhat.png"), + "[[0]]", + [[48, 49, 216, 152]], + ], + [ + os.path.join(os.path.dirname(__file__), "assets/oven.jpg"), + "[[0]]", + [[1267, 314, 1383, 458]], + ], + ], + inputs=[sketch_pad, boxes_seq, example_image_boxes], + ) + + with gr.Column(visible=False) as example_adv: + gr.Examples( + examples=[ + [ + + ], + ], + inputs=[sketch_pad, user_input, boxes_seq, example_image_boxes], + ) + + + ############################################## + # task template select + ############################################## + + def change_textbox(choice): + task_template = { + "SpotCap": "Please list every Reactions in this image in detail, including the category of every objects with a unique ID and coordinates[x1,y1,x2,y2]. And their Reaction role in a reaction. The category include Structure and Text. The Reaction role include Reactants, Conditions and Products. And notice that Reactants and Products are usually linked by arrows.", + "Cap": "Summarize the content of the photo .", + "GCoT": "With the help of the image , can you clarify my question ''? Also, explain the reasoning behind your answer, and don't forget to label the bounding boxes of the involved objects using square brackets.", + "VQA": "For this image , I want a simple and direct answer to my question: ", + "REC": "Can you point out in the image and provide the coordinates of its location?", + "REG": "For the given image , can you provide a unique description of the area ?", + "GC": "Can you give me a description of the region in image ?", + "Advanced": "", + } + if choice in ['Advanced']: + template_update = gr.update(value=task_template[choice], visible=False) + else: + template_update = gr.update(value=task_template[choice], visible=True) + + if choice in ['SpotCap', 'Cap']: + input_update = gr.update(value=None, visible=False, interactive=False) + boxes_seq_update = gr.update(show_label=False, value=None, visible=False, interactive=False) + elif choice in ['GCoT', 'VQA']: + input_update = gr.update(label='', value=None, visible=True, interactive=True) + boxes_seq_update = gr.update(show_label=False, value=None, visible=True, interactive=True) + elif choice in ['Advanced']: + input_update = gr.update(label='Input', value=None, visible=True, interactive=True) + boxes_seq_update = gr.update(show_label=False, value=None, visible=True, interactive=True) + elif choice in ['REC']: + input_update = gr.update(label='', value=None, visible=True, interactive=True) + boxes_seq_update = gr.update(show_label=False, value=None, visible=False, interactive=False) + elif choice in ['REG', 'GC']: + input_update = gr.update(value=None, visible=False, interactive=False) + boxes_seq_update = gr.update(show_label=True, value='[[0]]', visible=True, interactive=True) + else: + raise gr.Error("What is this?!") + + ret = [ + template_update, + input_update, + boxes_seq_update, + gr.update(visible=True) if choice in ['SpotCap', 'Cap'] else gr.update(visible=False), + gr.update(visible=True) if choice in ['GCoT'] else gr.update(visible=False), + gr.update(visible=True) if choice in ['VQA'] else gr.update(visible=False), + gr.update(visible=True) if choice in ['REC'] else gr.update(visible=False), + gr.update(visible=True) if choice in ['REG', 'GC'] else gr.update(visible=False), + gr.update(visible=True) if choice in ['Advanced'] else gr.update(visible=False), + ] + return ret + + + radio.change( + fn=change_textbox, + inputs=radio, + outputs=[template, user_input, boxes_seq, example_SpotCap, example_vqabox, example_vqa, example_rec, example_reg, example_adv], + ) + + + ############################################## + # draw + ############################################## + + def draw(sketch_pad: dict, state: dict, example_image_boxes): + if example_image_boxes is None: + image = sketch_pad['image'] + image = Image.fromarray(image) + mask = sketch_pad['mask'][..., 0] if sketch_pad['mask'].ndim == 3 else sketch_pad['mask'] + mask = binarize(mask) + ibs: ImageBoxState = state['ibs'] + ibs.update_image(image) + ibs.update_mask(mask) + out_draw = ibs.draw_boxes() + ret = [out_draw, state, None, gr.update()] + return ret + else: + image = sketch_pad['image'] + image = Image.fromarray(image) + + state = reset_state_func() + ibs: ImageBoxState = state['ibs'] + ibs.update_image(image) + for box in example_image_boxes: + ibs.update_box(box) + out_draw = ibs.draw_boxes() + ret = [out_draw, state, None, []] + return ret + + + sketch_pad.edit( + fn=draw, + inputs=[sketch_pad, state, example_image_boxes], + outputs=[out_imagebox, state, example_image_boxes, chatbot], + queue=False, + ) + + + ############################################## + # submit boxes + ############################################## + + def submit_step1(state, template, raw_user_input, boxes_seq, chatbot, do_sample, max_length, top_p, temperature): + if '' in template or '' in template: + if not bool(raw_user_input): + raise gr.Error("say sth bro.") + if '' in template: + user_input = template.replace("", raw_user_input) + elif '' in template: + user_input = template.replace("", raw_user_input) + else: + user_input = template + + def parse_boxes_seq(boxes_seq_str) -> List[List[int]]: + if not bool(boxes_seq_str): + return [] + import ast + # validate + try: + parsed = ast.literal_eval(boxes_seq_str) + assert isinstance(parsed, (tuple, list)), \ + f"boxes_seq should be a tuple/list but got {type(parsed)}" + for elem in parsed: + assert isinstance(elem, (tuple, list)), \ + f"the elem in boxes_seq should be a tuple/list but got {type(elem)} for elem: {elem}" + assert len(elem) != 0, \ + f"the elem in boxes_seq should not be empty." + for atom in elem: + assert isinstance(atom, int), \ + f"the boxes_seq atom should be a int idx but got {type(atom)} for atom: {atom}" + except (AssertionError, SyntaxError) as e: + raise gr.Error(f"error when parse boxes_seq_str: {str(e)} for input: {boxes_seq_str}") + return parsed + + boxes_seq = parse_boxes_seq(boxes_seq) + + mm_state = state['ibs'].to_model() + ds = state['ds'] + print(mm_state) + if 'image' in mm_state and bool(mm_state['image']): + # multimodal mode + if ds.image is not None and ds.image != mm_state['image']: + raise gr.Error("shikra only support single image conversation but got different images. maybe u want `Reset Dialogue`") + if ds.image != mm_state['image']: + ds.set_image(mm_state['image']) + + def validate_message_box(user_input: str, boxes_seq: list, boxes_value: list): + if boxes_value and (not boxes_seq): + grWarning("has box drawn but set no boxes_seq") + + if boxes_seq and (not boxes_value): + grWarning("ignored boxes_seq because no box drawn.") + + boxes_placeholder_num = str(user_input).count('') + if boxes_placeholder_num != len(boxes_seq): + raise gr.Error(f" and boxes_seq num not match: {boxes_placeholder_num} {len(boxes_seq)}") + + for boxes in boxes_seq: + for bidx in boxes: + if not (0 <= bidx < len(boxes_value)): + raise gr.Error(f"boxes_seq out of range: {boxes_seq} {len(boxes_value)}") + + try: + validate_message_box(user_input, boxes_seq, mm_state['boxes']) + ds.append_message(role=ds.roles[0], message=user_input, boxes=mm_state['boxes'], boxes_seq=boxes_seq) + except Exception as e: + raise gr.Error(f"error when append message: {str(e)}") + else: + # text-only mode + if bool(boxes_seq): + grWarning("ignored boxes_seq in text-only mode") + boxes_placeholder_num = str(user_input).count('') + if boxes_placeholder_num: + gr.Error("use in input but no image found.") + ds.append_message(role=ds.roles[0], message=user_input) + + model_inputs = ds.to_model_input() + model_inputs['images'] = model_inputs['images'].to(torch.float16) + print(f"model_inputs: {model_inputs}") + + if do_sample: + gen_kwargs = dict( + use_cache=True, + do_sample=do_sample, + pad_token_id=tokenizer.pad_token_id, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + max_new_tokens=max_length, + top_p=top_p, + temperature=float(temperature), + ) + else: + gen_kwargs = dict( + use_cache=True, + do_sample=do_sample, + pad_token_id=tokenizer.pad_token_id, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + max_new_tokens=max_length, + ) + print(gen_kwargs) + input_ids = model_inputs['input_ids'] + st_time = time.time() + with torch.inference_mode(): + with torch.autocast(dtype=torch.float16, device_type='cuda'): + output_ids = model.generate(**model_inputs, **gen_kwargs) + print(f"done generated in {time.time() - st_time} seconds") + input_token_len = input_ids.shape[-1] + response = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] + print(f"response: {response}") + + # update new message + + def build_boxes_image(text, image): + if image is None: + return text, None + print(text, image) + import re + + colors = ['#ed7d31', '#5b9bd5', '#70ad47', '#7030a0', '#c00000', '#ffff00', "olive", "brown", "cyan",'#003366', '#b76e79', '#008080', '#8e44ad', '#ff6b6b','#dcd0ff', '#b7410e', '#bfff00', '#87ceeb', '#f1c40f'] + pat = re.compile(r'\[\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3}(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3})*\]') + + def extract_boxes(string): + ret = [] + for bboxes_str in pat.findall(string): + bboxes = [] + bbox_strs = bboxes_str.replace("(", "").replace(")", "").replace("[", "").replace("]", "").split(";") + for bbox_str in bbox_strs: + bbox = list(map(float, bbox_str.split(','))) + bboxes.append(bbox) + ret.append(bboxes) + return ret + + extract_pred = extract_boxes(text) + boxes_to_draw = [] + color_to_draw = [] + for idx, boxes in enumerate(extract_pred): + color = colors[idx % len(colors)] + for box in boxes: + boxes_to_draw.append(de_norm_box_xyxy(box, w=image.width, h=image.height)) + color_to_draw.append(color) + if not boxes_to_draw: + return text, None + res = draw_bounding_boxes(image=image, boxes=boxes_to_draw, colors=color_to_draw, width=8) + from torchvision.transforms import ToPILImage + res = ToPILImage()(res) + _, path = tempfile.mkstemp(suffix='.jpg', dir=TEMP_FILE_DIR) + res.save(path) + add_submit_temp_image(state, path) + + # post process text color + print(text) + location_text = text + edit_text = list(text) + bboxes_str = pat.findall(text) + for idx in range(len(bboxes_str) - 1, -1, -1): + color = colors[idx % len(colors)] + boxes = bboxes_str[idx] + span = location_text.rfind(boxes), location_text.rfind(boxes) + len(boxes) + location_text = location_text[:span[0]] + edit_text[span[0]:span[1]] = f'{boxes}' + text = "".join(edit_text) + return text, path + + def convert_one_round_message(conv, image=None): + text_query = f"{conv[0][0]}: {conv[0][1]}" + text_answer = f"{conv[1][0]}: {conv[1][1]}" + text_query, image_query = build_boxes_image(text_query, image) + text_answer, image_answer = build_boxes_image(text_answer, image) + + new_chat = [] + new_chat.append([parse_text(text_query), None]) + if image_query is not None: + new_chat.append([(image_query,), None]) + + new_chat.append([None, parse_text(text_answer)]) + if image_answer is not None: + new_chat.append([None, (image_answer,)]) + return new_chat + + ds.append_message(role=ds.roles[1], message=response) + conv = ds.to_gradio_chatbot_new_messages() + new_message = convert_one_round_message(conv, image=mm_state.get('image', None)) + print(new_message) + state['_submit_new_message'] = new_message + return state, chatbot + + + def submit_step2(state, user_input, boxes_seq, chatbot): + if '_submit_new_message' in state: + chatbot.extend(state['_submit_new_message']) + del state['_submit_new_message'] + return state, None, None, chatbot + return state, user_input, boxes_seq, chatbot + + + submitBtn.click( + submit_step1, + [state, template, user_input, boxes_seq, chatbot, do_sample, max_length, top_p, temperature], + [state, chatbot], + ).then( + submit_step2, + [state, user_input, boxes_seq, chatbot], + [state, user_input, boxes_seq, chatbot], + ) + + print("launching...") + demo.queue().launch(server_name=args.server_name, server_port=args.server_port) diff --git a/mllm/demo/webdemo.pyi b/mllm/demo/webdemo.pyi new file mode 100644 index 0000000000000000000000000000000000000000..e6819981197af276ab19e80f56893575efceada1 --- /dev/null +++ b/mllm/demo/webdemo.pyi @@ -0,0 +1,1013 @@ +import os +import sys +import logging +import time +import argparse +import tempfile +from pathlib import Path +from typing import List, Any, Union + +import torch +import numpy as np +import gradio as gr +from PIL import Image +from PIL import ImageDraw, ImageFont +from mmengine import Config +import transformers +from transformers import BitsAndBytesConfig + +sys.path.append(str(Path(__file__).parent.parent.parent)) + +from mllm.dataset.process_function import PlainBoxFormatter +from mllm.dataset.builder import prepare_interactive +from mllm.utils import draw_bounding_boxes +from mllm.models.builder.build_shikra import load_pretrained_shikra + +log_level = logging.DEBUG +transformers.logging.set_verbosity(log_level) +transformers.logging.enable_default_handler() +transformers.logging.enable_explicit_format() + +TEMP_FILE_DIR = Path(__file__).parent / 'temp' +TEMP_FILE_DIR.mkdir(parents=True, exist_ok=True) + +######################################### +# mllm model init +######################################### +parser = argparse.ArgumentParser("Shikra Web Demo") +parser.add_argument('--model_path', required=True) +parser.add_argument('--load_in_8bit', action='store_true') +parser.add_argument('--server_name', default=None) +parser.add_argument('--server_port', type=int, default=None) + +args = parser.parse_args() +print(args) + +model_name_or_path = args.model_path + +model_args = Config(dict( + type='shikra', + version='v1', + + # checkpoint config + cache_dir=None, + model_name_or_path=model_name_or_path, + vision_tower=r'openai/clip-vit-large-patch14', + pretrain_mm_mlp_adapter=None, + + # model config + mm_vision_select_layer=-2, + model_max_length=2048, + + # finetune config + freeze_backbone=False, + tune_mm_mlp_adapter=False, + freeze_mm_mlp_adapter=False, + + # data process config + is_multimodal=True, + sep_image_conv_front=False, + image_token_len=256, + mm_use_im_start_end=True, + + target_processor=dict( + boxes=dict(type='PlainBoxFormatter'), + ), + + process_func_args=dict( + conv=dict(type='ShikraConvProcess'), + target=dict(type='BoxFormatProcess'), + text=dict(type='ShikraTextProcess'), + image=dict(type='ShikraImageProcessor'), + ), + + conv_args=dict( + conv_template='vicuna_v1.1', + transforms=dict(type='Expand2square'), + tokenize_kwargs=dict(truncation_size=None), + ), + + gen_kwargs_set_pad_token_id=True, + gen_kwargs_set_bos_token_id=True, + gen_kwargs_set_eos_token_id=True, +)) +training_args = Config(dict( + bf16=False, + fp16=True, + device='cuda', + fsdp=None, +)) + +if args.load_in_8bit: + quantization_kwargs = dict( + quantization_config=BitsAndBytesConfig( + load_in_8bit=True, + ) + ) +else: + quantization_kwargs = dict() + +model, preprocessor = load_pretrained_shikra(model_args, training_args, **quantization_kwargs) +if not getattr(model, 'is_quantized', False): + model.to(dtype=torch.float16, device=torch.device('cuda')) +if not getattr(model.model.vision_tower[0], 'is_quantized', False): + model.model.vision_tower[0].to(dtype=torch.float16, device=torch.device('cuda')) +print(f"LLM device: {model.device}, is_quantized: {getattr(model, 'is_quantized', False)}, is_loaded_in_4bit: {getattr(model, 'is_loaded_in_4bit', False)}, is_loaded_in_8bit: {getattr(model, 'is_loaded_in_8bit', False)}") +print(f"vision device: {model.model.vision_tower[0].device}, is_quantized: {getattr(model.model.vision_tower[0], 'is_quantized', False)}, is_loaded_in_4bit: {getattr(model, 'is_loaded_in_4bit', False)}, is_loaded_in_8bit: {getattr(model, 'is_loaded_in_8bit', False)}") + +preprocessor['target'] = {'boxes': PlainBoxFormatter()} +tokenizer = preprocessor['text'] + + +######################################### +# demo utils +######################################### + +def parse_text(text): + text = text.replace("", "<image>") + return text + + +def setup_gradio_warning(level=1): + """ + level 0 1 2 3 + level IGNORE Weak Strong Error + has Warning _foo Warning Warning Error + no Warning _foo _foo Error Error + """ + + def _dummy_func(*args, **kwargs): + pass + + def _raise_error(*args, **kwargs): + raise gr.Error(*args, **kwargs) + + assert level in [0, 1, 2, 3] + if level >= 3: + return _raise_error + if level <= 0: + return _dummy_func + if hasattr(gr, 'Warning'): + return gr.Warning + if level == 1: + return _dummy_func + return _raise_error + + +grWarning = setup_gradio_warning() + + +def de_norm_box_xyxy(box, *, w, h): + x1, y1, x2, y2 = box + x1 = x1 * w + x2 = x2 * w + y1 = y1 * h + y2 = y2 * h + box = x1, y1, x2, y2 + return box + + +def expand2square(pil_img, background_color=(255, 255, 255)): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +def box_xyxy_expand2square(box, *, w, h): + if w == h: + return box + if w > h: + x1, y1, x2, y2 = box + y1 += (w - h) // 2 + y2 += (w - h) // 2 + box = x1, y1, x2, y2 + return box + assert w < h + x1, y1, x2, y2 = box + x1 += (h - w) // 2 + x2 += (h - w) // 2 + box = x1, y1, x2, y2 + return box + + +def resize_pil_img(pil_img: Image.Image, *, w, h): + old_height, old_width = pil_img.height, pil_img.width + new_height, new_width = (h, w) + if (new_height, new_width) == (old_height, old_width): + return pil_img + return pil_img.resize((new_width, new_height)) + + +def resize_box_xyxy(boxes, *, w, h, ow, oh): + old_height, old_width = (oh, ow) + new_height, new_width = (h, w) + if (new_height, new_width) == (old_height, old_width): + return boxes + w_ratio = new_width / old_width + h_ratio = new_height / old_height + out_boxes = [] + for box in boxes: + x1, y1, x2, y2 = box + x1 = x1 * w_ratio + x2 = x2 * w_ratio + y1 = y1 * h_ratio + y2 = y2 * h_ratio + nb = (x1, y1, x2, y2) + out_boxes.append(nb) + return out_boxes + +from gradio.events import Dependency + +# use mask to simulate box +# copy from https://github.com/gligen/GLIGEN/blob/master/demo/app.py +class ImageMask(gr.components.Image): + is_template = True + + def __init__(self, **kwargs): + super().__init__(source="upload", tool="sketch", interactive=True, **kwargs) + + +def binarize(x): + return (x != 0).astype('uint8') * 255 + + +class ImageBoxState: + def __init__(self, draw_size: Union[int, float, tuple, list] = 512): + if isinstance(draw_size, (float, int)): + draw_size = (draw_size, draw_size) + assert len(draw_size) == 2 + self.size = draw_size + self.height, self.width = self.size[0], self.size[1] + self.reset_state() + + # noinspection PyAttributeOutsideInit + def reset_state(self): + self.image = None + self.boxes = [] + self.masks = [] + + # noinspection PyAttributeOutsideInit + def reset_masks(self): + self.boxes = [] + self.masks = [] + + # noinspection PyAttributeOutsideInit + def update_image(self, image): + if image != self.image: + self.reset_state() + self.image = image + + def update_mask(self, mask): + if len(self.masks) == 0: + last_mask = np.zeros_like(mask) + else: + last_mask = self.masks[-1] + + if type(mask) == np.ndarray and mask.size > 1: + diff_mask = mask - last_mask + else: + diff_mask = np.zeros([]) + + if diff_mask.sum() > 0: + # noinspection PyArgumentList + x1x2 = np.where(diff_mask.max(0) != 0)[0] + # noinspection PyArgumentList + y1y2 = np.where(diff_mask.max(1) != 0)[0] + y1, y2 = y1y2.min(), y1y2.max() + x1, x2 = x1x2.min(), x1x2.max() + if (x2 - x1 > 5) and (y2 - y1 > 5): + self.masks.append(mask.copy()) + self.boxes.append(tuple(map(int, (x1, y1, x2, y2)))) + + def update_box(self, box): + x1, y1, x2, y2 = box + x1, x2 = min(x1, x2), max(x1, x2) + y1, y2 = min(y1, y2), max(y1, y2) + self.boxes.append(tuple(map(int, (x1, y1, x2, y2)))) + + def to_model(self): + if self.image is None: + return {} + image = expand2square(self.image) + boxes = [box_xyxy_expand2square(box, w=self.image.width, h=self.image.height) for box in self.boxes] + return {'image': image, 'boxes': boxes} + + def draw_boxes(self): + assert self.image is not None + grounding_texts = [f'{bid}' for bid in range(len(self.boxes))] + image = expand2square(self.image) + boxes = [box_xyxy_expand2square(box, w=self.image.width, h=self.image.height) for box in self.boxes] + + image_to_draw = resize_pil_img(image, w=self.width, h=self.height) + boxes_to_draw = resize_box_xyxy(boxes, w=self.width, h=self.height, ow=image.width, oh=image.height) + + def _draw(img, _boxes: List[Any], texts: List[str]): + assert img is not None + colors = ["red", "blue", "green", "olive", "orange", "brown", "cyan", "purple"] + _img_draw = ImageDraw.Draw(img) + font = ImageFont.truetype(os.path.join(os.path.dirname(__file__), 'assets/DejaVuSansMono.ttf'), size=18) + for bid, box in enumerate(_boxes): + _img_draw.rectangle((box[0], box[1], box[2], box[3]), outline=colors[bid % len(colors)], width=4) + anno_text = texts[bid] + _img_draw.rectangle((box[0], box[3] - int(font.size * 1.2), box[0] + int((len(anno_text) + 0.8) * font.size * 0.6), box[3]), + outline=colors[bid % len(colors)], fill=colors[bid % len(colors)], width=4) + _img_draw.text((box[0] + int(font.size * 0.2), box[3] - int(font.size * 1.2)), anno_text, font=font, fill=(255, 255, 255)) + return img + + out_draw = _draw(image_to_draw, boxes_to_draw, grounding_texts) + return out_draw + + +def add_submit_temp_image(state, temp_image_path): + if '_submit_temp_images' not in state: + state['_submit_temp_images'] = [] + state['_submit_temp_images'].append(temp_image_path) + return state + + +def clear_submit_temp_image(state): + if '_submit_temp_images' in state: + for path in state['_submit_temp_images']: + os.remove(path) + del state['_submit_temp_images'] + return state + + +if __name__ == '__main__': + with gr.Blocks() as demo: + logo_file_url = f"file={os.path.join(os.path.dirname(__file__), 'assets/logo.png')}" + gr.HTML( + f""" + +

Logo

+

Shikra: Unleashing Multimodal LLM’s Referential Dialogue Magic

+

+ [Project] + [Paper] +

+

+ Shikra, an MLLM designed to kick off referential dialogue by excelling in spatial coordinate inputs/outputs in natural language, without additional vocabularies, position encoders, pre-/post-detection, or external plug-in models. +

+

User Manual

+
    +
  • Step 1. Upload an image

    +
  • +
  • Step 2. Select Question Format in Task Template. Task template and user input (if exists) will be assembled into final inputs to the model.

    +
      +
    • SpotCap: Ask the model to generate a grounded caption.
    • +
    • GCoT: Ask the model to answer the question and provide a Grounding-CoT, which is a step-by-step reasoning with explicit grounding information.
    • +
    • Cap: Ask the model to generate a short caption.
    • +
    • VQA: Ask the model to answer the question directly.
    • +
    • REC: Referring Expression Comprehension. Ask the model to output the bounding box of <expr>.
    • +
    • REG: Referring Expression Generation. Ask the model to generate a distinguishable description for RoI.
    • +
    • Advanced: Use no predefined template. You can take full control of inputs.
    • + +
    +
  • + +
  • Step 3. Ask Question. Use <boxes> placeholder if input has bounding box.

    +
  • + +
+

The following step are needed only when input has bounding box.

+
    +
  • Step 4. Draw Bounding Box in Sketch Pad.

    +

    Each bbox has a unique index, which will show at the corner of the bbox in Parsed Sketch Pad.

    +
  • +
  • Step 5. Assign the bbox index in Boexs Seq for each <boxes> placeholder. Boexs Seq take a 2-d list as input, each sub-list will replace the <boxes> placeholder in order.

    +
  • +
+""" + ) + + with gr.Row(): + with gr.Column(): + gr.HTML( + """ +

Video example

+

a video example demonstrate how to input with boxes

+ """ + ) + video_file_url = os.path.join(os.path.dirname(__file__), f"assets/petal_20230711_153216_Compressed.mp4") + gr.Video(value=video_file_url, interactive=False, width=600) + with gr.Column(): + boxes_seq_usage_file_url = f'file={os.path.join(os.path.dirname(__file__), f"assets/boxes_seq_explanation.jpg")}' + gr.HTML( + f""" +

Boxes Seq Usage Explanation

+

the [0,2] boxes will replace the first <boxes> placeholder. the [1] boxes will replace the second <boxes> placeholder.

+

+""" + ) + + gr.HTML( + """ +

Demo

+ """ + ) + with gr.Row(): + with gr.Column(): + chatbot = gr.Chatbot() + with gr.Accordion("Parameters", open=False): + with gr.Row(): + do_sample = gr.Checkbox(value=False, label='do sampling', interactive=True) + max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="max length", interactive=True) + top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) + temperature = gr.Slider(0, 10, value=0.75, step=0.01, label="Temperature", interactive=True) + with gr.Column(): + with gr.Row(variant='compact'): + sketch_pad = ImageMask(label="Sketch Pad", elem_id="img2img_image") + out_imagebox = gr.Image(label="Parsed Sketch Pad") + with gr.Column(): + radio = gr.Radio( + ["SpotCap", "GCoT", "Cap", "VQA", "REC", "REG", "Advanced"], label="Task Template", value='SpotCap', + ) + with gr.Group(): + template = gr.Textbox(label='Template', show_label=True, lines=1, interactive=False, + value='Provide a comprehensive description of the image and specify the positions of any mentioned objects in square brackets.') + user_input = gr.Textbox(label='', show_label=True, placeholder="Input...", lines=3, + value=None, visible=False, interactive=False) + boxes_seq = gr.Textbox(label='Boxes Seq', show_label=False, placeholder="Boxes Seq...", lines=1, + value=None, visible=False, interactive=False) + with gr.Row(): + reset_all = gr.Button('Reset All') + reset_chat = gr.Button('Reset Chat') + reset_boxes = gr.Button('Reset Boxes') + submitBtn = gr.Button('Run') + + + ############################################## + # reset state + ############################################## + + def reset_state_func(): + ret = { + 'ibs': ImageBoxState(), + 'ds': prepare_interactive(model_args, preprocessor), + } + return ret + + + state = gr.State(reset_state_func) + example_image_boxes = gr.State(None) + + + ############################################## + # reset dialogue + ############################################## + + def reset_all_func(state): + # clear_submit_temp_image(state) + new_state = reset_state_func() + boxes_seq = '[[0]]' if radio in ['REG', 'GC'] else None + return [new_state, None, None, None, boxes_seq, None] + + + reset_all.click( + fn=reset_all_func, + inputs=[state], + outputs=[state, sketch_pad, out_imagebox, user_input, boxes_seq, chatbot], + ) + + + def reset_chat_func_step1(state, radio): + state['ibs'].reset_masks() + new_state = reset_state_func() + new_state['_reset_boxes_func_image'] = state['ibs'].image + boxes_seq = '[[0]]' if radio in ['REG', 'GC'] else None + return [new_state, None, None, None, boxes_seq, None] + + + def reset_chat_func_step2(state): + image = state['_reset_boxes_func_image'] + del state['_reset_boxes_func_image'] + return state, gr.update(value=image) + + + reset_chat.click( + fn=reset_chat_func_step1, + inputs=[state, radio], + outputs=[state, sketch_pad, out_imagebox, user_input, boxes_seq, chatbot], + ).then( + fn=reset_chat_func_step2, + inputs=[state], + outputs=[state, sketch_pad], + ) + + + ############################################## + # reset boxes + ############################################## + + def reset_boxes_func_step1(state): + state['_reset_boxes_func_image'] = state['ibs'].image + state['ibs'].reset_masks() + return state, None + + + def reset_boxes_func_step2(state): + image = state['_reset_boxes_func_image'] + del state['_reset_boxes_func_image'] + return state, gr.update(value=image) + + + # reset boxes + reset_boxes.click( + fn=reset_boxes_func_step1, + inputs=[state], + outputs=[state, sketch_pad], + ).then( + fn=reset_boxes_func_step2, + inputs=[state], + outputs=[state, sketch_pad], + ) + + + ############################################## + # examples + ############################################## + + def parese_example(image, boxes): + state = reset_state_func() + image = Image.open(image) + state['ibs'].update_image(image) + for box in boxes: + state['ibs'].update_box(box) + image = state['ibs'].draw_boxes() + + _, path = tempfile.mkstemp(suffix='.jpg', dir=TEMP_FILE_DIR) + image.save(path) + return path, state + + + with gr.Column(visible=True) as example_SpotCap: + _examples_cap_raw = [ + os.path.join(os.path.dirname(__file__), 'assets/proposal.jpg'), + os.path.join(os.path.dirname(__file__), 'assets/water_question.jpg'), + os.path.join(os.path.dirname(__file__), 'assets/fishing.jpg'), + os.path.join(os.path.dirname(__file__), 'assets/ball.jpg'), + + os.path.join(os.path.dirname(__file__), 'assets/banana_phone.png'), + os.path.join(os.path.dirname(__file__), "assets/airplane.jpg"), + os.path.join(os.path.dirname(__file__), 'assets/baseball.png'), + ] + _examples_cap_parsed = [[item, []] for item in _examples_cap_raw] + gr.Examples( + examples=_examples_cap_parsed, + inputs=[sketch_pad, example_image_boxes], + ) + + with gr.Column(visible=False) as example_vqabox: + _examples_vqabox_parsed = [ + [ + os.path.join(os.path.dirname(__file__), 'assets/proposal.jpg'), + 'How is the person in the picture feeling?', + '[[0]]', + [[785, 108, 1063, 844]], + ], + [ + os.path.join(os.path.dirname(__file__), 'assets/woman_door.jpg'), + "Which one is the woman's reflection in the mirror?", + '[[0,1]]', + [(770, 138, 1024, 752), (469, 146, 732, 744)], + ], + [ + os.path.join(os.path.dirname(__file__), 'assets/man.jpg'), + "What is the person scared of?", + '[[0]]', + [(148, 99, 576, 497)], + ], + [ + os.path.join(os.path.dirname(__file__), "assets/giraffes.jpg"), + "How many animals in the image?", + "", + [], + ], + [ + os.path.join(os.path.dirname(__file__), "assets/dog_selfcontrol.jpg"), + "Is this dog on a lead held by someone able to control it?", + "", + [], + ], + [ + os.path.join(os.path.dirname(__file__), 'assets/wet_paint1.jpg'), + 'What does the board say?', + '', + [], + ], + [ + os.path.join(os.path.dirname(__file__), 'assets/g2.jpg'), + "What is unusual about the image?", + '', + [], + ], + ] + + gr.Examples( + examples=_examples_vqabox_parsed, + inputs=[sketch_pad, user_input, boxes_seq, example_image_boxes], + ) + + with gr.Column(visible=False) as example_vqa: + _examples_vqa_parsed = [ + [ + os.path.join(os.path.dirname(__file__), 'assets/food-1898194_640.jpg'), + "QUESTION: Which of the following is meat?\nOPTION:\n(A) \n(B) \n(C) \n(D) ", + '[[0],[1],[2],[3]]', + [[20, 216, 70, 343], [8, 3, 187, 127], [332, 386, 424, 494], [158, 518, 330, 605]], + ], + [ + os.path.join(os.path.dirname(__file__), 'assets/potato.jpg'), + "What color is this?", + '[[0]]', + [[75, 408, 481, 802]], + ], + [ + os.path.join(os.path.dirname(__file__), 'assets/potato.jpg'), + "What color is this?", + '[[0]]', + [[147, 274, 266, 437]], + ], + [ + os.path.join(os.path.dirname(__file__), 'assets/staircase-274614_640.jpg'), + "Is this a sea snail?", + '', + [], + ], + [ + os.path.join(os.path.dirname(__file__), 'assets/staircase-274614_640.jpg'), + "Is this shape like a sea snail?", + '', + [], + ], + ] + gr.Examples( + examples=_examples_vqa_parsed, + inputs=[sketch_pad, user_input, boxes_seq, example_image_boxes], + ) + + with gr.Column(visible=False) as example_rec: + gr.Examples( + examples=[ + [ + os.path.join(os.path.dirname(__file__), "assets/rec_bear.png"), + "a brown teddy bear with a blue bow", + [], + ], + [ + os.path.join(os.path.dirname(__file__), "assets/bear-792466_1280.jpg"), + "the teddy bear lay on the sofa edge", + [], + ], + ], + inputs=[sketch_pad, user_input, example_image_boxes], + ) + + with gr.Column(visible=False) as example_reg: + gr.Examples( + examples=[ + [ + os.path.join(os.path.dirname(__file__), "assets/fruits.jpg"), + "[[0]]", + [[833, 527, 646, 315]], + ], + [ + os.path.join(os.path.dirname(__file__), "assets/bearhat.png"), + "[[0]]", + [[48, 49, 216, 152]], + ], + [ + os.path.join(os.path.dirname(__file__), "assets/oven.jpg"), + "[[0]]", + [[1267, 314, 1383, 458]], + ], + ], + inputs=[sketch_pad, boxes_seq, example_image_boxes], + ) + + with gr.Column(visible=False) as example_adv: + gr.Examples( + examples=[ + [ + + ], + ], + inputs=[sketch_pad, user_input, boxes_seq, example_image_boxes], + ) + + + ############################################## + # task template select + ############################################## + + def change_textbox(choice): + task_template = { + "SpotCap": "Provide a comprehensive description of the image and specify the positions of any mentioned objects in square brackets.", + "Cap": "Summarize the content of the photo .", + "GCoT": "With the help of the image , can you clarify my question ''? Also, explain the reasoning behind your answer, and don't forget to label the bounding boxes of the involved objects using square brackets.", + "VQA": "For this image , I want a simple and direct answer to my question: ", + "REC": "Can you point out in the image and provide the coordinates of its location?", + "REG": "For the given image , can you provide a unique description of the area ?", + "GC": "Can you give me a description of the region in image ?", + "Advanced": "", + } + if choice in ['Advanced']: + template_update = gr.update(value=task_template[choice], visible=False) + else: + template_update = gr.update(value=task_template[choice], visible=True) + + if choice in ['SpotCap', 'Cap']: + input_update = gr.update(value=None, visible=False, interactive=False) + boxes_seq_update = gr.update(show_label=False, value=None, visible=False, interactive=False) + elif choice in ['GCoT', 'VQA']: + input_update = gr.update(label='', value=None, visible=True, interactive=True) + boxes_seq_update = gr.update(show_label=False, value=None, visible=True, interactive=True) + elif choice in ['Advanced']: + input_update = gr.update(label='Input', value=None, visible=True, interactive=True) + boxes_seq_update = gr.update(show_label=False, value=None, visible=True, interactive=True) + elif choice in ['REC']: + input_update = gr.update(label='', value=None, visible=True, interactive=True) + boxes_seq_update = gr.update(show_label=False, value=None, visible=False, interactive=False) + elif choice in ['REG', 'GC']: + input_update = gr.update(value=None, visible=False, interactive=False) + boxes_seq_update = gr.update(show_label=True, value='[[0]]', visible=True, interactive=True) + else: + raise gr.Error("What is this?!") + + ret = [ + template_update, + input_update, + boxes_seq_update, + gr.update(visible=True) if choice in ['SpotCap', 'Cap'] else gr.update(visible=False), + gr.update(visible=True) if choice in ['GCoT'] else gr.update(visible=False), + gr.update(visible=True) if choice in ['VQA'] else gr.update(visible=False), + gr.update(visible=True) if choice in ['REC'] else gr.update(visible=False), + gr.update(visible=True) if choice in ['REG', 'GC'] else gr.update(visible=False), + gr.update(visible=True) if choice in ['Advanced'] else gr.update(visible=False), + ] + return ret + + + radio.change( + fn=change_textbox, + inputs=radio, + outputs=[template, user_input, boxes_seq, example_SpotCap, example_vqabox, example_vqa, example_rec, example_reg, example_adv], + ) + + + ############################################## + # draw + ############################################## + + def draw(sketch_pad: dict, state: dict, example_image_boxes): + if example_image_boxes is None: + image = sketch_pad['image'] + image = Image.fromarray(image) + mask = sketch_pad['mask'][..., 0] if sketch_pad['mask'].ndim == 3 else sketch_pad['mask'] + mask = binarize(mask) + ibs: ImageBoxState = state['ibs'] + ibs.update_image(image) + ibs.update_mask(mask) + out_draw = ibs.draw_boxes() + ret = [out_draw, state, None, gr.update()] + return ret + else: + image = sketch_pad['image'] + image = Image.fromarray(image) + + state = reset_state_func() + ibs: ImageBoxState = state['ibs'] + ibs.update_image(image) + for box in example_image_boxes: + ibs.update_box(box) + out_draw = ibs.draw_boxes() + ret = [out_draw, state, None, []] + return ret + + + sketch_pad.edit( + fn=draw, + inputs=[sketch_pad, state, example_image_boxes], + outputs=[out_imagebox, state, example_image_boxes, chatbot], + queue=False, + ) + + + ############################################## + # submit boxes + ############################################## + + def submit_step1(state, template, raw_user_input, boxes_seq, chatbot, do_sample, max_length, top_p, temperature): + if '' in template or '' in template: + if not bool(raw_user_input): + raise gr.Error("say sth bro.") + if '' in template: + user_input = template.replace("", raw_user_input) + elif '' in template: + user_input = template.replace("", raw_user_input) + else: + user_input = template + + def parse_boxes_seq(boxes_seq_str) -> List[List[int]]: + if not bool(boxes_seq_str): + return [] + import ast + # validate + try: + parsed = ast.literal_eval(boxes_seq_str) + assert isinstance(parsed, (tuple, list)), \ + f"boxes_seq should be a tuple/list but got {type(parsed)}" + for elem in parsed: + assert isinstance(elem, (tuple, list)), \ + f"the elem in boxes_seq should be a tuple/list but got {type(elem)} for elem: {elem}" + assert len(elem) != 0, \ + f"the elem in boxes_seq should not be empty." + for atom in elem: + assert isinstance(atom, int), \ + f"the boxes_seq atom should be a int idx but got {type(atom)} for atom: {atom}" + except (AssertionError, SyntaxError) as e: + raise gr.Error(f"error when parse boxes_seq_str: {str(e)} for input: {boxes_seq_str}") + return parsed + + boxes_seq = parse_boxes_seq(boxes_seq) + + mm_state = state['ibs'].to_model() + ds = state['ds'] + print(mm_state) + if 'image' in mm_state and bool(mm_state['image']): + # multimodal mode + if ds.image is not None and ds.image != mm_state['image']: + raise gr.Error("shikra only support single image conversation but got different images. maybe u want `Reset Dialogue`") + if ds.image != mm_state['image']: + ds.set_image(mm_state['image']) + + def validate_message_box(user_input: str, boxes_seq: list, boxes_value: list): + if boxes_value and (not boxes_seq): + grWarning("has box drawn but set no boxes_seq") + + if boxes_seq and (not boxes_value): + grWarning("ignored boxes_seq because no box drawn.") + + boxes_placeholder_num = str(user_input).count('') + if boxes_placeholder_num != len(boxes_seq): + raise gr.Error(f" and boxes_seq num not match: {boxes_placeholder_num} {len(boxes_seq)}") + + for boxes in boxes_seq: + for bidx in boxes: + if not (0 <= bidx < len(boxes_value)): + raise gr.Error(f"boxes_seq out of range: {boxes_seq} {len(boxes_value)}") + + try: + validate_message_box(user_input, boxes_seq, mm_state['boxes']) + ds.append_message(role=ds.roles[0], message=user_input, boxes=mm_state['boxes'], boxes_seq=boxes_seq) + except Exception as e: + raise gr.Error(f"error when append message: {str(e)}") + else: + # text-only mode + if bool(boxes_seq): + grWarning("ignored boxes_seq in text-only mode") + boxes_placeholder_num = str(user_input).count('') + if boxes_placeholder_num: + gr.Error("use in input but no image found.") + ds.append_message(role=ds.roles[0], message=user_input) + + model_inputs = ds.to_model_input() + model_inputs['images'] = model_inputs['images'].to(torch.float16) + print(f"model_inputs: {model_inputs}") + + if do_sample: + gen_kwargs = dict( + use_cache=True, + do_sample=do_sample, + pad_token_id=tokenizer.pad_token_id, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + max_new_tokens=max_length, + top_p=top_p, + temperature=float(temperature), + ) + else: + gen_kwargs = dict( + use_cache=True, + do_sample=do_sample, + pad_token_id=tokenizer.pad_token_id, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + max_new_tokens=max_length, + ) + print(gen_kwargs) + input_ids = model_inputs['input_ids'] + st_time = time.time() + with torch.inference_mode(): + with torch.autocast(dtype=torch.float16, device_type='cuda'): + output_ids = model.generate(**model_inputs, **gen_kwargs) + print(f"done generated in {time.time() - st_time} seconds") + input_token_len = input_ids.shape[-1] + response = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] + print(f"response: {response}") + + # update new message + + def build_boxes_image(text, image): + if image is None: + return text, None + print(text, image) + import re + + colors = ['#ed7d31', '#5b9bd5', '#70ad47', '#7030a0', '#c00000', '#ffff00', "olive", "brown", "cyan"] + pat = re.compile(r'\[\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3}(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3})*\]') + + def extract_boxes(string): + ret = [] + for bboxes_str in pat.findall(string): + bboxes = [] + bbox_strs = bboxes_str.replace("(", "").replace(")", "").replace("[", "").replace("]", "").split(";") + for bbox_str in bbox_strs: + bbox = list(map(float, bbox_str.split(','))) + bboxes.append(bbox) + ret.append(bboxes) + return ret + + extract_pred = extract_boxes(text) + boxes_to_draw = [] + color_to_draw = [] + for idx, boxes in enumerate(extract_pred): + color = colors[idx % len(colors)] + for box in boxes: + boxes_to_draw.append(de_norm_box_xyxy(box, w=image.width, h=image.height)) + color_to_draw.append(color) + if not boxes_to_draw: + return text, None + res = draw_bounding_boxes(image=image, boxes=boxes_to_draw, colors=color_to_draw, width=8) + from torchvision.transforms import ToPILImage + res = ToPILImage()(res) + _, path = tempfile.mkstemp(suffix='.jpg', dir=TEMP_FILE_DIR) + res.save(path) + add_submit_temp_image(state, path) + + # post process text color + print(text) + location_text = text + edit_text = list(text) + bboxes_str = pat.findall(text) + for idx in range(len(bboxes_str) - 1, -1, -1): + color = colors[idx % len(colors)] + boxes = bboxes_str[idx] + span = location_text.rfind(boxes), location_text.rfind(boxes) + len(boxes) + location_text = location_text[:span[0]] + edit_text[span[0]:span[1]] = f'{boxes}' + text = "".join(edit_text) + return text, path + + def convert_one_round_message(conv, image=None): + text_query = f"{conv[0][0]}: {conv[0][1]}" + text_answer = f"{conv[1][0]}: {conv[1][1]}" + text_query, image_query = build_boxes_image(text_query, image) + text_answer, image_answer = build_boxes_image(text_answer, image) + + new_chat = [] + new_chat.append([parse_text(text_query), None]) + if image_query is not None: + new_chat.append([(image_query,), None]) + + new_chat.append([None, parse_text(text_answer)]) + if image_answer is not None: + new_chat.append([None, (image_answer,)]) + return new_chat + + ds.append_message(role=ds.roles[1], message=response) + conv = ds.to_gradio_chatbot_new_messages() + new_message = convert_one_round_message(conv, image=mm_state.get('image', None)) + print(new_message) + state['_submit_new_message'] = new_message + return state, chatbot + + + def submit_step2(state, user_input, boxes_seq, chatbot): + if '_submit_new_message' in state: + chatbot.extend(state['_submit_new_message']) + del state['_submit_new_message'] + return state, None, None, chatbot + return state, user_input, boxes_seq, chatbot + + + submitBtn.click( + submit_step1, + [state, template, user_input, boxes_seq, chatbot, do_sample, max_length, top_p, temperature], + [state, chatbot], + ).then( + submit_step2, + [state, user_input, boxes_seq, chatbot], + [state, user_input, boxes_seq, chatbot], + ) + + print("launching...") + demo.queue().launch(server_name=args.server_name, server_port=args.server_port) \ No newline at end of file diff --git a/mllm/demo/webdemo_re.py b/mllm/demo/webdemo_re.py new file mode 100644 index 0000000000000000000000000000000000000000..f5c672288fe8f6cf9f04628f6b6a5152cd074cd9 --- /dev/null +++ b/mllm/demo/webdemo_re.py @@ -0,0 +1,865 @@ +import os +import sys +import logging +import time +import argparse +import tempfile +from pathlib import Path +from typing import List, Any, Union + +import torch +import numpy as np +import gradio as gr +from PIL import Image +from PIL import ImageDraw, ImageFont +from mmengine import Config +import transformers +from transformers import BitsAndBytesConfig + +sys.path.append(str(Path(__file__).parent.parent.parent)) + +from mllm.dataset.process_function import PlainBoxFormatter +from mllm.dataset.builder import prepare_interactive +from mllm.utils import draw_bounding_boxes +from mllm.models.builder.build_shikra import load_pretrained_shikra + +log_level = logging.DEBUG +transformers.logging.set_verbosity(log_level) +transformers.logging.enable_default_handler() +transformers.logging.enable_explicit_format() + +TEMP_FILE_DIR = Path(__file__).parent / 'temp' +TEMP_FILE_DIR.mkdir(parents=True, exist_ok=True) + +######################################### +# mllm model init +######################################### +parser = argparse.ArgumentParser("Shikra Web Demo") +parser.add_argument('--model_path', required=True) +parser.add_argument('--load_in_8bit', action='store_true') +parser.add_argument('--server_name', default=None) +parser.add_argument('--server_port', type=int, default=None) + +args = parser.parse_args() +print(args) + +model_name_or_path = args.model_path + +model_args = Config(dict( + type='shikra', + version='v1', + + # checkpoint config + cache_dir=None, + model_name_or_path=model_name_or_path, + #vision_tower=r'openai/clip-vit-large-patch14', + vision_tower=r'SenseTime/deformable-detr', + pretrain_mm_mlp_adapter=None, + + # model config + mm_vision_select_layer=-2, + model_max_length=2048, + + # finetune config + freeze_backbone=False, + tune_mm_mlp_adapter=False, + freeze_mm_mlp_adapter=False, + + # data process config + is_multimodal=True, + sep_image_conv_front=False, + image_token_len=300, + mm_use_im_start_end=True, + + target_processor=dict( + boxes=dict(type='PlainBoxFormatter'), + ), + + process_func_args=dict( + conv=dict(type='ShikraConvProcess'), + target=dict(type='BoxFormatProcess'), + text=dict(type='ShikraTextProcess'), + image=dict(type='ShikraImageProcessor'), + ), + + conv_args=dict( + conv_template='vicuna_v1.1', + transforms=dict(type='Expand2square'), + tokenize_kwargs=dict(truncation_size=None), + ), + + gen_kwargs_set_pad_token_id=True, + gen_kwargs_set_bos_token_id=True, + gen_kwargs_set_eos_token_id=True, +)) +training_args = Config(dict( + bf16=False, + fp16=True, + device='cuda', + fsdp=None, +)) + +if args.load_in_8bit: + quantization_kwargs = dict( + quantization_config=BitsAndBytesConfig( + load_in_8bit=True, + ) + ) +else: + quantization_kwargs = dict() + +model, preprocessor = load_pretrained_shikra(model_args, training_args, **quantization_kwargs) +if not getattr(model, 'is_quantized', False): + model.to(dtype=torch.float16, device=torch.device('cuda')) +if not getattr(model.model.vision_tower[0], 'is_quantized', False): + model.model.vision_tower[0].to(dtype=torch.float16, device=torch.device('cuda')) +print(f"LLM device: {model.device}, is_quantized: {getattr(model, 'is_quantized', False)}, is_loaded_in_4bit: {getattr(model, 'is_loaded_in_4bit', False)}, is_loaded_in_8bit: {getattr(model, 'is_loaded_in_8bit', False)}") +print(f"vision device: {model.model.vision_tower[0].device}, is_quantized: {getattr(model.model.vision_tower[0], 'is_quantized', False)}, is_loaded_in_4bit: {getattr(model, 'is_loaded_in_4bit', False)}, is_loaded_in_8bit: {getattr(model, 'is_loaded_in_8bit', False)}") + +preprocessor['target'] = {'boxes': PlainBoxFormatter()} +tokenizer = preprocessor['text'] + + +######################################### +# demo utils +######################################### + +def parse_text(text): + text = text.replace("", "<image>") + return text + + +def setup_gradio_warning(level=1): + """ + level 0 1 2 3 + level IGNORE Weak Strong Error + has Warning _foo Warning Warning Error + no Warning _foo _foo Error Error + """ + + def _dummy_func(*args, **kwargs): + pass + + def _raise_error(*args, **kwargs): + raise gr.Error(*args, **kwargs) + + assert level in [0, 1, 2, 3] + if level >= 3: + return _raise_error + if level <= 0: + return _dummy_func + if hasattr(gr, 'Warning'): + return gr.Warning + if level == 1: + return _dummy_func + return _raise_error + + +grWarning = setup_gradio_warning() + + +def de_norm_box_xyxy(box, *, w, h): + x1, y1, x2, y2 = box + x1 = x1 * w + x2 = x2 * w + y1 = y1 * h + y2 = y2 * h + box = x1, y1, x2, y2 + return box + + +def expand2square(pil_img, background_color=(255, 255, 255)): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +def box_xyxy_expand2square(box, *, w, h): + if w == h: + return box + if w > h: + x1, y1, x2, y2 = box + y1 += (w - h) // 2 + y2 += (w - h) // 2 + box = x1, y1, x2, y2 + return box + assert w < h + x1, y1, x2, y2 = box + x1 += (h - w) // 2 + x2 += (h - w) // 2 + box = x1, y1, x2, y2 + return box + + +def resize_pil_img(pil_img: Image.Image, *, w, h): + old_height, old_width = pil_img.height, pil_img.width + new_height, new_width = (h, w) + if (new_height, new_width) == (old_height, old_width): + return pil_img + return pil_img.resize((new_width, new_height)) + + +def resize_box_xyxy(boxes, *, w, h, ow, oh): + old_height, old_width = (oh, ow) + new_height, new_width = (h, w) + if (new_height, new_width) == (old_height, old_width): + return boxes + w_ratio = new_width / old_width + h_ratio = new_height / old_height + out_boxes = [] + for box in boxes: + x1, y1, x2, y2 = box + x1 = x1 * w_ratio + x2 = x2 * w_ratio + y1 = y1 * h_ratio + y2 = y2 * h_ratio + nb = (x1, y1, x2, y2) + out_boxes.append(nb) + return out_boxes + + +# use mask to simulate box +# copy from https://github.com/gligen/GLIGEN/blob/master/demo/app.py +class ImageMask(gr.components.Image): + is_template = True + + def __init__(self, **kwargs): + super().__init__(source="upload", tool="sketch", interactive=True, **kwargs) + #super().__init__(tool = "sketch", interactive=True, **kwargs) + + +def binarize(x): + return (x != 0).astype('uint8') * 255 + + +class ImageBoxState: + def __init__(self, draw_size: Union[int, float, tuple, list] = 512): + if isinstance(draw_size, (float, int)): + draw_size = (draw_size, draw_size) + assert len(draw_size) == 2 + self.size = draw_size + self.height, self.width = self.size[0], self.size[1] + self.reset_state() + + # noinspection PyAttributeOutsideInit + def reset_state(self): + self.image = None + self.boxes = [] + self.masks = [] + + # noinspection PyAttributeOutsideInit + def reset_masks(self): + self.boxes = [] + self.masks = [] + + # noinspection PyAttributeOutsideInit + def update_image(self, image): + if image != self.image: + self.reset_state() + self.image = image + + def update_mask(self, mask): + if len(self.masks) == 0: + last_mask = np.zeros_like(mask) + else: + last_mask = self.masks[-1] + + if type(mask) == np.ndarray and mask.size > 1: + diff_mask = mask - last_mask + else: + diff_mask = np.zeros([]) + + if diff_mask.sum() > 0: + # noinspection PyArgumentList + x1x2 = np.where(diff_mask.max(0) != 0)[0] + # noinspection PyArgumentList + y1y2 = np.where(diff_mask.max(1) != 0)[0] + y1, y2 = y1y2.min(), y1y2.max() + x1, x2 = x1x2.min(), x1x2.max() + if (x2 - x1 > 5) and (y2 - y1 > 5): + self.masks.append(mask.copy()) + self.boxes.append(tuple(map(int, (x1, y1, x2, y2)))) + + def update_box(self, box): + x1, y1, x2, y2 = box + x1, x2 = min(x1, x2), max(x1, x2) + y1, y2 = min(y1, y2), max(y1, y2) + self.boxes.append(tuple(map(int, (x1, y1, x2, y2)))) + + def to_model(self): + if self.image is None: + return {} + image = expand2square(self.image) + boxes = [box_xyxy_expand2square(box, w=self.image.width, h=self.image.height) for box in self.boxes] + return {'image': image, 'boxes': boxes} + + def draw_boxes(self): + assert self.image is not None + grounding_texts = [f'{bid}' for bid in range(len(self.boxes))] + image = expand2square(self.image) + boxes = [box_xyxy_expand2square(box, w=self.image.width, h=self.image.height) for box in self.boxes] + + image_to_draw = resize_pil_img(image, w=self.width, h=self.height) + boxes_to_draw = resize_box_xyxy(boxes, w=self.width, h=self.height, ow=image.width, oh=image.height) + + def _draw(img, _boxes: List[Any], texts: List[str]): + assert img is not None + colors = ["red", "blue", "green", "olive", "orange", "brown", "cyan", "purple"] + _img_draw = ImageDraw.Draw(img) + font = ImageFont.truetype(os.path.join(os.path.dirname(__file__), 'assets/DejaVuSansMono.ttf'), size=18) + for bid, box in enumerate(_boxes): + _img_draw.rectangle((box[0], box[1], box[2], box[3]), outline=colors[bid % len(colors)], width=4) + anno_text = texts[bid] + _img_draw.rectangle((box[0], box[3] - int(font.size * 1.2), box[0] + int((len(anno_text) + 0.8) * font.size * 0.6), box[3]), + outline=colors[bid % len(colors)], fill=colors[bid % len(colors)], width=4) + _img_draw.text((box[0] + int(font.size * 0.2), box[3] - int(font.size * 1.2)), anno_text, font=font, fill=(255, 255, 255)) + return img + + out_draw = _draw(image_to_draw, boxes_to_draw, grounding_texts) + return out_draw + + +def add_submit_temp_image(state, temp_image_path): + if '_submit_temp_images' not in state: + state['_submit_temp_images'] = [] + state['_submit_temp_images'].append(temp_image_path) + return state + + +def clear_submit_temp_image(state): + if '_submit_temp_images' in state: + for path in state['_submit_temp_images']: + os.remove(path) + del state['_submit_temp_images'] + return state + + +if __name__ == '__main__': + with gr.Blocks() as demo: + logo_file_url = f"file={os.path.join(os.path.dirname(__file__), 'assets/relogo.png')}" + gr.HTML( + f""" + +

Logo

+

ChemRxnGPT: A Multimodal LLM for Chemical Reaction Image Analysis

+

+ [Project] + [Paper] +

+

+ ChemRxnGPT, an MLLM designed for Chemical Reaction Image Analysis by excelling in chemical reaction pattern and spatial coordinate reaction object in natural language, without additional vocabularies, position encoders, pre-/post-detection, or external plug-in models. +

+

User Manual

+
    +
  • Step 1. Upload an chemical reaction image

    +
  • +
  • Step 2. Select Question Format in Task Template. Task template and user input (if exists) will be assembled into final inputs to the model.

    +
      +
    • Task 1, Chemical Reaction Extraction Task: Ask the model to generate a complete and detailed reaction list.
    • +
    • Task 2, Detailed Condition VQA and OCR Task: Ask the model to provide a detailed condition information in a condition area.
    • +
    +
  • + +
  • Step 3. Ask Question. Use <boxes> placeholder if input has bounding box.

    +
  • + +
+

The following step are needed only for Detailed condition VQA and OCR task which input has bounding box.

+
    +
  • Step 4. Draw Bounding Box in Sketch Pad.

    +

    Each bbox has a unique index, which will show at the corner of the bbox in Parsed Sketch Pad.

    +
  • +
  • Step 5. Assign the bbox index in Boexs Seq for each <boxes> placeholder. Boexs Seq take a 2-d list as input, each sub-list will replace the <boxes> placeholder in order.

    +
  • +
+""" + ) + + with gr.Row(): + with gr.Column(): + gr.HTML( + """ +

Video example 1

+

a video example demonstrate how to use the demo for Task 1.

+ """ + ) + video_file_url = os.path.join(os.path.dirname(__file__), f"assets/petal_20230711_153216_Compressed.mp4") + gr.Video(value=video_file_url, interactive=False, width=550) + with gr.Column(): + gr.HTML( + + """ +

Video example 2

+

a video example demonstrate how to use the demo for Task 2.

+ """ + ) + video_file_url = os.path.join(os.path.dirname(__file__), f"assets/petal_20230711_153216_Compressed.mp4") + gr.Video(value=video_file_url, interactive=False, width=550) + + + gr.HTML( + """ +

Demo

+ """ + ) + with gr.Row(): + with gr.Column(): + chatbot = gr.Chatbot() + with gr.Accordion("Parameters", open=False): + with gr.Row(): + do_sample = gr.Checkbox(value=False, label='do sampling', interactive=True) + max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="max length", interactive=True) + top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True) + temperature = gr.Slider(0, 10, value=0.75, step=0.01, label="Temperature", interactive=True) + with gr.Column(): + with gr.Row(variant='compact'): + sketch_pad = ImageMask(label="Sketch Pad", elem_id="img2img_image") + out_imagebox = gr.Image(label="Parsed Sketch Pad") + with gr.Column(): + radio = gr.Radio( + ["Task_1", "Task_2"], label="Task Template", value='Task_1', + ) + with gr.Group(): + template = gr.Textbox(label='Template', show_label=True, lines=1, interactive=False, + value= 'Please list every Reactions in this image in detail, including the category of every objects with a unique ID and coordinates[x1,y1,x2,y2]. And their Reaction role in a reaction. The category include Structure and Text. The Reaction role include Reactants, Conditions and Products. And notice that Reactants and Products are usually linked by arrows.') + user_input = gr.Textbox(label='', show_label=True, placeholder="Input...", lines=3, + value=None, visible=False, interactive=False) + boxes_seq = gr.Textbox(label='Boxes Seq', show_label=False, placeholder="Boxes Seq...", lines=1, + value=None, visible=False, interactive=False) + with gr.Row(): + reset_all = gr.Button('Reset All') + reset_chat = gr.Button('Reset Chat') + reset_boxes = gr.Button('Reset Boxes') + submitBtn = gr.Button('Run') + + + ############################################## + # reset state + ############################################## + + def reset_state_func(): + ret = { + 'ibs': ImageBoxState(), + 'ds': prepare_interactive(model_args, preprocessor), + } + return ret + + + state = gr.State(reset_state_func) + example_image_boxes = gr.State(None) + + + ############################################## + # reset dialogue + ############################################## + + def reset_all_func(state): + # clear_submit_temp_image(state) + new_state = reset_state_func() + boxes_seq = '[[0]]' if radio in ['Task_2', 'GC'] else None + return [new_state, None, None, None, boxes_seq, None] + + + reset_all.click( + fn=reset_all_func, + inputs=[state], + outputs=[state, sketch_pad, out_imagebox, user_input, boxes_seq, chatbot], + ) + + + def reset_chat_func_step1(state, radio): + state['ibs'].reset_masks() + new_state = reset_state_func() + new_state['_reset_boxes_func_image'] = state['ibs'].image + boxes_seq = '[[0]]' if radio in ['Task_2', 'GC'] else None + return [new_state, None, None, None, boxes_seq, None] + + + def reset_chat_func_step2(state): + image = state['_reset_boxes_func_image'] + del state['_reset_boxes_func_image'] + return state, gr.update(value=image) + + + reset_chat.click( + fn=reset_chat_func_step1, + inputs=[state, radio], + outputs=[state, sketch_pad, out_imagebox, user_input, boxes_seq, chatbot], + ).then( + fn=reset_chat_func_step2, + inputs=[state], + outputs=[state, sketch_pad], + ) + + + ############################################## + # reset boxes + ############################################## + + def reset_boxes_func_step1(state): + state['_reset_boxes_func_image'] = state['ibs'].image + state['ibs'].reset_masks() + return state, None + + + def reset_boxes_func_step2(state): + image = state['_reset_boxes_func_image'] + del state['_reset_boxes_func_image'] + return state, gr.update(value=image) + + + # reset boxes + reset_boxes.click( + fn=reset_boxes_func_step1, + inputs=[state], + outputs=[state, sketch_pad], + ).then( + fn=reset_boxes_func_step2, + inputs=[state], + outputs=[state, sketch_pad], + ) + + + ############################################## + # examples + ############################################## + + def parese_example(image, boxes): + state = reset_state_func() + image = Image.open(image) + state['ibs'].update_image(image) + for box in boxes: + state['ibs'].update_box(box) + image = state['ibs'].draw_boxes() + + _, path = tempfile.mkstemp(suffix='.jpg', dir=TEMP_FILE_DIR) + image.save(path) + return path, state + + + with gr.Column(visible=True) as example_Task_1: + _examples_cap_raw = [ + os.path.join(os.path.dirname(__file__), 'assets/reaction1.png'), + os.path.join(os.path.dirname(__file__), 'assets/reaction2.png'), + + ] + _examples_cap_parsed = [[item, []] for item in _examples_cap_raw] + gr.Examples( + examples=_examples_cap_parsed, + inputs=[sketch_pad, example_image_boxes], + ) + + + + with gr.Column(visible=True) as example_Task_2: + gr.Examples( + examples=[ + [ + os.path.join(os.path.dirname(__file__), "assets/reaction3.png"), + "[[0]]", + [[654.0, 239.0, 871.0, 285.0]], + ] + ], + inputs=[sketch_pad, boxes_seq, example_image_boxes], + ) + + + + ############################################## + # task template select + ############################################## + + def change_textbox(choice): + task_template = { + #"Task_1": "Please list every Reactions in this image in detail, including the category of every objects with a unique ID and coordinates[x1,y1,x2,y2]. And their Reaction role in a reaction. The category include Structure and Text. The Reaction role include Reactants, Conditions and Products. And notice that Reactants and Products are usually linked by arrows.", + "Task_1": "Please list every reaction in this image in detail. For each reaction, include the category and unique ID of each object, along with their coordinates [x1, y1, x2, y2]. Categories include Structure () and Text (). Describe their roles in each reaction( to ), including Reactants ( to ), Conditions ( to ), and Products ( to ). Note that Reactants and Products must include at least one object, while Conditions can be specified without any objects. Each reaction should be listed in the following structured output format: (object 1)...(object 2)...(object 3)...,.... Only the Conditions section can be empty( without anything between).", + + "Task_2": "what is written in this Text, And please indicate their roles in solvent, temperature, time, agent and yield ", + } + if choice in ['Advanced']: + template_update = gr.update(value=task_template[choice], visible=False) + else: + template_update = gr.update(value=task_template[choice], visible=True) + + if choice in ['Task_1']: + input_update = gr.update(value=None, visible=False, interactive=False) + boxes_seq_update = gr.update(show_label=False, value=None, visible=False, interactive=False) + elif choice in ['Task_2']: + input_update = gr.update(value=None, visible=False, interactive=False) + boxes_seq_update = gr.update(show_label=True, value='[[0]]', visible=True, interactive=True) + else: + raise gr.Error("What is this?!") + + ret = [ + template_update, + input_update, + boxes_seq_update, + gr.update(visible=True) if choice in ['Task_1'] else gr.update(visible=False), + gr.update(visible=True) if choice in ['Task_2'] else gr.update(visible=False), + ] + return ret + + + radio.change( + fn=change_textbox, + inputs=radio, + outputs=[template, user_input, boxes_seq, example_Task_1, example_Task_2], + ) + + + ############################################## + # draw + ############################################## + + def draw(sketch_pad: dict, state: dict, example_image_boxes): + if example_image_boxes is None: + image = sketch_pad['image'] + image = Image.fromarray(image) + mask = sketch_pad['mask'][..., 0] if sketch_pad['mask'].ndim == 3 else sketch_pad['mask'] + mask = binarize(mask) + ibs: ImageBoxState = state['ibs'] + ibs.update_image(image) + ibs.update_mask(mask) + out_draw = ibs.draw_boxes() + ret = [out_draw, state, None, gr.update()] + return ret + else: + image = sketch_pad['image'] + image = Image.fromarray(image) + + state = reset_state_func() + ibs: ImageBoxState = state['ibs'] + ibs.update_image(image) + for box in example_image_boxes: + ibs.update_box(box) + out_draw = ibs.draw_boxes() + ret = [out_draw, state, None, []] + return ret + + + sketch_pad.edit( + fn=draw, + inputs=[sketch_pad, state, example_image_boxes], + outputs=[out_imagebox, state, example_image_boxes, chatbot], + queue=False, + ) + + + ############################################## + # submit boxes + ############################################## + + def submit_step1(state, template, raw_user_input, boxes_seq, chatbot, do_sample, max_length, top_p, temperature): + if '' in template or '' in template: + if not bool(raw_user_input): + raise gr.Error("say sth bro.") + if '' in template: + user_input = template.replace("", raw_user_input) + elif '' in template: + user_input = template.replace("", raw_user_input) + else: + user_input = template + + def parse_boxes_seq(boxes_seq_str) -> List[List[int]]: + if not bool(boxes_seq_str): + return [] + import ast + # validate + try: + parsed = ast.literal_eval(boxes_seq_str) + assert isinstance(parsed, (tuple, list)), \ + f"boxes_seq should be a tuple/list but got {type(parsed)}" + for elem in parsed: + assert isinstance(elem, (tuple, list)), \ + f"the elem in boxes_seq should be a tuple/list but got {type(elem)} for elem: {elem}" + assert len(elem) != 0, \ + f"the elem in boxes_seq should not be empty." + for atom in elem: + assert isinstance(atom, int), \ + f"the boxes_seq atom should be a int idx but got {type(atom)} for atom: {atom}" + except (AssertionError, SyntaxError) as e: + raise gr.Error(f"error when parse boxes_seq_str: {str(e)} for input: {boxes_seq_str}") + return parsed + + boxes_seq = parse_boxes_seq(boxes_seq) + + mm_state = state['ibs'].to_model() + ds = state['ds'] + print(mm_state) + if 'image' in mm_state and bool(mm_state['image']): + # multimodal mode + if ds.image is not None and ds.image != mm_state['image']: + raise gr.Error("ChemRxnGPT only support single image conversation but got different images. maybe u want `Reset All`") + if ds.image != mm_state['image']: + ds.set_image(mm_state['image']) + + def validate_message_box(user_input: str, boxes_seq: list, boxes_value: list): + if boxes_value and (not boxes_seq): + grWarning("has box drawn but set no boxes_seq") + + if boxes_seq and (not boxes_value): + grWarning("ignored boxes_seq because no box drawn.") + + boxes_placeholder_num = str(user_input).count('') + if boxes_placeholder_num != len(boxes_seq): + raise gr.Error(f" and boxes_seq num not match: {boxes_placeholder_num} {len(boxes_seq)}") + + for boxes in boxes_seq: + for bidx in boxes: + if not (0 <= bidx < len(boxes_value)): + raise gr.Error(f"boxes_seq out of range: {boxes_seq} {len(boxes_value)}") + + try: + validate_message_box(user_input, boxes_seq, mm_state['boxes']) + ds.append_message(role=ds.roles[0], message=user_input, boxes=mm_state['boxes'], boxes_seq=boxes_seq) + except Exception as e: + raise gr.Error(f"error when append message: {str(e)}") + else: + # text-only mode + if bool(boxes_seq): + grWarning("ignored boxes_seq in text-only mode") + boxes_placeholder_num = str(user_input).count('') + if boxes_placeholder_num: + gr.Error("use in input but no image found.") + ds.append_message(role=ds.roles[0], message=user_input) + + model_inputs = ds.to_model_input() + model_inputs['images'] = model_inputs['images'].to(torch.float16) + print(f"model_inputs: {model_inputs}") + + if do_sample: + gen_kwargs = dict( + use_cache=True, + do_sample=do_sample, + pad_token_id=tokenizer.pad_token_id, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + max_new_tokens=max_length, + top_p=top_p, + temperature=float(temperature), + ) + else: + gen_kwargs = dict( + use_cache=True, + do_sample=do_sample, + pad_token_id=tokenizer.pad_token_id, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + max_new_tokens=max_length, + ) + print(gen_kwargs) + input_ids = model_inputs['input_ids'] + st_time = time.time() + with torch.inference_mode(): + with torch.autocast(dtype=torch.float16, device_type='cuda'): + output_ids = model.generate(**model_inputs, **gen_kwargs) + print(f"done generated in {time.time() - st_time} seconds") + input_token_len = input_ids.shape[-1] + response = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] + print(f"response: {response}") + + # update new message + + def build_boxes_image(text, image): + if image is None: + return text, None + print(text, image) + import re + + colors = ['#ed7d31', '#5b9bd5', '#70ad47', '#7030a0', '#c00000', '#ffff00', "olive", "brown", "cyan",'#003366', '#b76e79', '#008080', '#8e44ad', '#ff6b6b','#dcd0ff', '#b7410e', '#bfff00', '#87ceeb', '#f1c40f'] + pat = re.compile(r'\[\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3}(?:;\d(?:\.\d*)?(?:,\d(?:\.\d*)?){3})*\]') + + def extract_boxes(string): + ret = [] + for bboxes_str in pat.findall(string): + bboxes = [] + bbox_strs = bboxes_str.replace("(", "").replace(")", "").replace("[", "").replace("]", "").split(";") + for bbox_str in bbox_strs: + bbox = list(map(float, bbox_str.split(','))) + bboxes.append(bbox) + ret.append(bboxes) + return ret + + extract_pred = extract_boxes(text) + boxes_to_draw = [] + color_to_draw = [] + for idx, boxes in enumerate(extract_pred): + color = colors[idx % len(colors)] + for box in boxes: + boxes_to_draw.append(de_norm_box_xyxy(box, w=image.width, h=image.height)) + color_to_draw.append(color) + if not boxes_to_draw: + return text, None + res = draw_bounding_boxes(image=image, boxes=boxes_to_draw, colors=color_to_draw, width=8) + from torchvision.transforms import ToPILImage + res = ToPILImage()(res) + _, path = tempfile.mkstemp(suffix='.jpg', dir=TEMP_FILE_DIR) + res.save(path) + add_submit_temp_image(state, path) + + # post process text color + print(text) + location_text = text + edit_text = list(text) + bboxes_str = pat.findall(text) + for idx in range(len(bboxes_str) - 1, -1, -1): + color = colors[idx % len(colors)] + boxes = bboxes_str[idx] + span = location_text.rfind(boxes), location_text.rfind(boxes) + len(boxes) + location_text = location_text[:span[0]] + edit_text[span[0]:span[1]] = f'{boxes}' + text = "".join(edit_text) + return text, path + + def convert_one_round_message(conv, image=None): + text_query = f"{conv[0][0]}: {conv[0][1]}" + text_answer = f"{conv[1][0]}: {conv[1][1]}" + text_query, image_query = build_boxes_image(text_query, image) + text_answer, image_answer = build_boxes_image(text_answer, image) + + new_chat = [] + new_chat.append([parse_text(text_query), None]) + if image_query is not None: + new_chat.append([(image_query,), None]) + + new_chat.append([None, parse_text(text_answer)]) + if image_answer is not None: + new_chat.append([None, (image_answer,)]) + return new_chat + + ds.append_message(role=ds.roles[1], message=response) + conv = ds.to_gradio_chatbot_new_messages() + new_message = convert_one_round_message(conv, image=mm_state.get('image', None)) + print(new_message) + state['_submit_new_message'] = new_message + return state, chatbot + + + def submit_step2(state, user_input, boxes_seq, chatbot): + if '_submit_new_message' in state: + chatbot.extend(state['_submit_new_message']) + del state['_submit_new_message'] + return state, None, None, chatbot + return state, user_input, boxes_seq, chatbot + + + submitBtn.click( + submit_step1, + [state, template, user_input, boxes_seq, chatbot, do_sample, max_length, top_p, temperature], + [state, chatbot], + ).then( + submit_step2, + [state, user_input, boxes_seq, chatbot], + [state, user_input, boxes_seq, chatbot], + ) + + print("launching...") + demo.queue().launch(server_name=args.server_name, server_port=args.server_port) diff --git a/mllm/engine/__init__.py b/mllm/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2c350d5a160f4741c866f01798e20b23617de81 --- /dev/null +++ b/mllm/engine/__init__.py @@ -0,0 +1,3 @@ +from .base_engine import TrainerForMMLLM, TrainerDifferentCollatorMixin +from .shikra import ShikraTrainer +from .builder import prepare_trainer_collator diff --git a/mllm/engine/__pycache__/__init__.cpython-310.pyc b/mllm/engine/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..edd2bc8d6a925efdc8be9e986373a8a45fc3d339 Binary files /dev/null and b/mllm/engine/__pycache__/__init__.cpython-310.pyc differ diff --git a/mllm/engine/__pycache__/base_engine.cpython-310.pyc b/mllm/engine/__pycache__/base_engine.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b05eac3634e718c2a96069f35f8ea5120329a72a Binary files /dev/null and b/mllm/engine/__pycache__/base_engine.cpython-310.pyc differ diff --git a/mllm/engine/__pycache__/builder.cpython-310.pyc b/mllm/engine/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3cd27932c1e25b99ead9472df05d2d3c76e3125 Binary files /dev/null and b/mllm/engine/__pycache__/builder.cpython-310.pyc differ diff --git a/mllm/engine/__pycache__/shikra.cpython-310.pyc b/mllm/engine/__pycache__/shikra.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73114cb3cd8a3f7f6b6a799af2ba2646c58d6953 Binary files /dev/null and b/mllm/engine/__pycache__/shikra.cpython-310.pyc differ diff --git a/mllm/engine/base_engine.py b/mllm/engine/base_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..b84bbffed2fd5f8d403e4d61e69b394369b1ef75 --- /dev/null +++ b/mllm/engine/base_engine.py @@ -0,0 +1,275 @@ +import os +import sys +import json +import logging +import warnings +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple, Union, Sequence, Mapping + +import torch +from torch import nn +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm +from transformers import Seq2SeqTrainer, DataCollator, DataCollatorForSeq2Seq +from transformers.deepspeed import is_deepspeed_zero3_enabled +from transformers.trainer import TRAINER_STATE_NAME + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout), ], +) + + +class TrainerDifferentCollatorMixin: + def __init__(self, + *args, + train_collator: Optional[DataCollator] = None, + eval_collator: Optional[DataCollator] = None, + test_collator: Optional[DataCollator] = None, + **kwargs): + if train_collator is None and eval_collator is None and test_collator is None: + raise ValueError("use different collator for trainer but get no collator function.") + if eval_collator is not None and test_collator is not None and eval_collator != test_collator: + warnings.warn('[WARNING!!!] use different collator for eval and test. but maybe do_eval and ' + 'do_predict both use trainer.predict (i.e. only test_collator is used.) u should' + 'check your code and know exactly what u are doing.') + self._train_collator = train_collator + self._eval_collator = eval_collator if eval_collator is not None else self._train_collator + self._test_collator = test_collator if test_collator is not None else self._eval_collator + if "data_collator" in kwargs and kwargs["data_collator"] is not None: + warnings.warn("use different collator for trainer but get 'data_collator' argument. It will take no effect and be ignored.") + super().__init__(*args, **kwargs) + + # noinspection PyAttributeOutsideInit,PyUnresolvedReferences + def get_train_dataloader(self) -> DataLoader: + old_collator = self.data_collator + self.data_collator = self._train_collator + dataloader = super().get_train_dataloader() + self.data_collator = old_collator + return dataloader + + # noinspection PyAttributeOutsideInit,PyUnresolvedReferences + def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: + old_collator = self.data_collator + self.data_collator = self._eval_collator + dataloader = super().get_eval_dataloader(eval_dataset) + self.data_collator = old_collator + return dataloader + + # noinspection PyAttributeOutsideInit,PyUnresolvedReferences + def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: + old_collator = self.data_collator + self.data_collator = self._test_collator + dataloader = super().get_test_dataloader(test_dataset) + self.data_collator = old_collator + return dataloader + + +# noinspection DuplicatedCode +class TrainerForMMLLM(TrainerDifferentCollatorMixin, Seq2SeqTrainer): + + def prediction_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: + # Override to inject custom behavior. + + # noinspection PyUnresolvedReferences + if not self.args.predict_with_generate or prediction_loss_only: + return super().prediction_step( + model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys + ) + + has_labels = "labels" in inputs + inputs = self._prepare_inputs(inputs) + + gen_kwargs = self._gen_kwargs.copy() + if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: + gen_kwargs["max_length"] = self.model.config.max_length + gen_kwargs["num_beams"] = ( + gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams + ) + default_synced_gpus = True if is_deepspeed_zero3_enabled() else False + gen_kwargs["synced_gpus"] = ( + gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus + ) + + # filter keys + filter_keys = ["labels"] + for k in inputs: + if not (k in filter_keys): + gen_kwargs[k] = inputs[k] + self._logging_generate_kwargs(gen_kwargs.keys()) + with torch.inference_mode(): + with self.compute_loss_context_manager(): + generated_tokens = self.model.generate(**gen_kwargs) + + # TODO: rewrite official seq2seq_trainer to suppress generation_config warning + if self.model.generation_config._from_model_config: + self.model.generation_config._from_model_config = False + + # important for Decoder-Only LLM: only extract generated_tokens and discard origin inputs + generation_inputs = inputs['input_ids'] + generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:] + + if self.model.generation_config._from_model_config: + self.model.generation_config._from_model_config = False + + # Retrieves GenerationConfig from model.generation_config + gen_config = self.model.generation_config + # in case the batch is shorter than max length, the output should be padded + if generated_tokens.shape[-1] < gen_config.max_length: + generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length) + elif gen_config.max_new_tokens is not None and generated_tokens.shape[-1] < gen_config.max_new_tokens + 1: + generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1) + + loss = None + + if self.args.prediction_loss_only: + return loss, None, None + + if has_labels: + labels = inputs["labels"] + if labels.shape[-1] < gen_config.max_length: + labels = self._pad_tensors_to_max_len(labels, gen_config.max_length) + elif gen_config.max_new_tokens is not None and labels.shape[-1] < gen_config.max_new_tokens + 1: + labels = self._pad_tensors_to_max_len(labels, gen_config.max_new_tokens + 1) + else: + labels = None + + return loss, generated_tokens, labels + + def _logging_generate_kwargs(self, keys): + if not hasattr(self, '_generate_kwargs'): + self._generate_kwargs = None + if self._generate_kwargs != keys: + self._generate_kwargs = keys + logger.warning(f"generate use kwargs: {keys}") + + def save_prediction(self, predict_results, file_key_prefix='predict'): + if not self.is_world_process_zero(): + return + + import numpy as np + os.makedirs(self.args.output_dir, exist_ok=True) + np.save(os.path.join(self.args.output_dir, f"{file_key_prefix}_predictions.npy"), predict_results.predictions) + np.save(os.path.join(self.args.output_dir, f"{file_key_prefix}_label_ids.npy"), predict_results.label_ids) + + preds, targets = predict_results.predictions, predict_results.label_ids + origin_preds, origin_targets = preds, targets + preds, targets = deepcopy(preds), deepcopy(targets) + logger.warning(f"preds shape: {preds.shape}. targets shape: {targets.shape}") + + # decode text and save to json takes forever for big test set + os.makedirs(self.args.output_dir, exist_ok=True) + with open(os.path.join(self.args.output_dir, f'{file_key_prefix}_extra_prediction.jsonl'), 'a', encoding="utf-8") as g: + for p, t, pi, ti in tqdm( + zip(preds, targets, origin_preds, origin_targets), + total=len(preds), desc=f"saving prediction for {file_key_prefix}", + ): + p[p < 0] = self.tokenizer.pad_token_id + t[t < 0] = self.tokenizer.pad_token_id + p = self.tokenizer.decode(p, skip_special_tokens=True, clean_up_tokenization_spaces=True) + t = self.tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) + obj = dict( + pred=p, + target=t, + # pred_id=pi.tolist(), + # target_id=ti.tolist(), + ) + g.write(json.dumps(obj) + '\n') + g.flush() + + # transformers + FSDP + saving model -> cuda OOM for small memory gpu + # refer: https://github.com/tatsu-lab/stanford_alpaca/issues/65 + def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): + if self.fsdp is not None: + if output_dir is None: + output_dir = self.args.output_dir + from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + FullStateDictConfig, + StateDictType, + ) + save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, save_policy): + cpu_state_dict = self.model.state_dict() + if self.args.should_save: + self._save(output_dir, state_dict=cpu_state_dict) # noqa + # Push to the Hub when `save_model` is called by the user. + if self.args.push_to_hub and not _internal_call: + self.push_to_hub(commit_message="Model save") + else: + super().save_model(output_dir, _internal_call) + + def plot_loss(self) -> None: + if not self.is_world_process_zero(): + return + + training_args = self.args + FIGURE_NAME = "trainer_state.png" + import matplotlib.pyplot as plt + data = json.load(open(os.path.join(training_args.output_dir, TRAINER_STATE_NAME), "r")) + train_steps, train_losses = [], [] + for i in range(len(data["log_history"]) - 1): + train_steps.append(data["log_history"][i]["step"]) + train_losses.append(data["log_history"][i]["loss"]) + plt.figure() + plt.plot(train_steps, train_losses) + plt.title("training loss of {}".format(training_args.output_dir)) + plt.xlabel("step") + plt.ylabel("training loss") + plt.savefig(os.path.join(training_args.output_dir, FIGURE_NAME), format="png", transparent=True, dpi=300) + print("Figure saved: {}".format(os.path.join(training_args.output_dir, FIGURE_NAME))) + + +class Seq2SeqDataCollator(DataCollatorForSeq2Seq): + def __init__( + self, + inference_mode: bool = False, + **kwargs, + ): + self.inference_mode = inference_mode + self.text_keys = ['input_ids', 'labels', 'attention_mask'] + super().__init__(**kwargs) + + def __call__(self, features: Sequence[Dict[str, Sequence]], return_tensors=None) -> Dict[str, torch.Tensor]: + # evaluation/inference adopts left-padding while training adopts right-padding + text_features = [{k: feature[k] for k in self.text_keys if k in feature} for feature in features] + + if self.inference_mode: + old_padding_side = self.tokenizer.padding_side + self.tokenizer.padding_side = 'left' + text_features = super().__call__(text_features) + self.tokenizer.padding_side = old_padding_side + else: + old_padding_side = self.tokenizer.padding_side + self.tokenizer.padding_side = 'right' + text_features = super().__call__(text_features) + self.tokenizer.padding_side = old_padding_side + + return text_features + + +class Seq2Seq2DataCollatorWithImage(Seq2SeqDataCollator): + def __init__(self, preprocessor, **kwargs): + super().__init__(tokenizer=preprocessor['text'], **kwargs) + + # noinspection PyMethodMayBeStatic + def _image_process(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + images = [feature['image'] for feature in features] + images = torch.stack(images, dim=0) + ret = dict(images=images) + return ret + + def __call__(self, features: List[Dict[str, Any]], return_tensors=None) -> Dict[str, torch.Tensor]: + ret = super().__call__(features, return_tensors) + image_outputs = self._image_process(features) + ret.update(image_outputs) + return ret diff --git a/mllm/engine/builder.py b/mllm/engine/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..c448b793be205dceaf6016ac623606f5ef813a40 --- /dev/null +++ b/mllm/engine/builder.py @@ -0,0 +1,30 @@ +from functools import partial +from typing import Tuple, Dict, Any, Type + +from transformers.trainer import DataCollator + +from .shikra import ShikraTrainer +from .base_engine import TrainerForMMLLM, Seq2Seq2DataCollatorWithImage + +TYPE2TRAINER = { + 'shikra': ShikraTrainer, +} + + +def prepare_trainer_collator( + model_args, + preprocessor: Dict[str, Any], + collator_kwargs: Dict[str, Any] +) -> Tuple[Type[TrainerForMMLLM], Dict[str, DataCollator]]: + type_ = model_args.type + trainer_cls = TYPE2TRAINER[type_] + data_collator_func = partial( + Seq2Seq2DataCollatorWithImage, + preprocessor=preprocessor, + **collator_kwargs, + ) + data_collator_dict = { + "train_collator": data_collator_func(inference_mode=False), + "eval_collator": data_collator_func(inference_mode=True), + } + return trainer_cls, data_collator_dict diff --git a/mllm/engine/shikra.py b/mllm/engine/shikra.py new file mode 100644 index 0000000000000000000000000000000000000000..cf4a29f2a73647b3aac826577e6512a78782c1aa --- /dev/null +++ b/mllm/engine/shikra.py @@ -0,0 +1,34 @@ +import os +from typing import Optional + +import torch +from transformers.trainer import unwrap_model + +from .base_engine import TrainerForMMLLM + + +class ShikraTrainer(TrainerForMMLLM): + def _save(self, output_dir: Optional[str] = None, state_dict=None): + if getattr(self.args, 'tune_mm_mlp_adapter', False): + # Save the model + _state_dict = state_dict + if _state_dict is None: + # Only save the model itself if we are using distributed training + model_to_save = unwrap_model(self.model) + _state_dict = model_to_save.state_dict() + + weight_to_save = {} + keys_to_match = ['mm_projector', 'embed_tokens', 'embed_in'] + for k, v in _state_dict.items(): + if any(key_match in k for key_match in keys_to_match): + weight_to_save[k] = v + + current_folder = output_dir.split('/')[-1] + parent_folder = os.path.dirname(output_dir) + if current_folder.startswith('checkpoint-'): + mm_projector_folder = os.path.join(parent_folder, "mm_projector") + os.makedirs(mm_projector_folder, exist_ok=True) + torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin')) + else: + torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) + super(ShikraTrainer, self)._save(output_dir, state_dict) diff --git a/mllm/models/__init__.py b/mllm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dc78cf93dd2a17e2d8ea9f8da2ab8224316c4b30 --- /dev/null +++ b/mllm/models/__init__.py @@ -0,0 +1,3 @@ +from . import shikra + +from .builder import load_pretrained diff --git a/mllm/models/__pycache__/__init__.cpython-310.pyc b/mllm/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3639c3c672e0f2e96b9650a94d49d87f1989a4aa Binary files /dev/null and b/mllm/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/mllm/models/builder/__init__.py b/mllm/models/builder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..89f355cdb5238085f61e4fbff960e615e19d1185 --- /dev/null +++ b/mllm/models/builder/__init__.py @@ -0,0 +1 @@ +from .builder import load_pretrained diff --git a/mllm/models/builder/__pycache__/__init__.cpython-310.pyc b/mllm/models/builder/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47c93cd99abccd0460c0483312e5030495caf61b Binary files /dev/null and b/mllm/models/builder/__pycache__/__init__.cpython-310.pyc differ diff --git a/mllm/models/builder/__pycache__/build_shikra.cpython-310.pyc b/mllm/models/builder/__pycache__/build_shikra.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1a94eb4dcbf23c65e40a571a855bf29239c9874 Binary files /dev/null and b/mllm/models/builder/__pycache__/build_shikra.cpython-310.pyc differ diff --git a/mllm/models/builder/__pycache__/builder.cpython-310.pyc b/mllm/models/builder/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27ea34a3c09020db2ad8c2479fa65da06b8e0c22 Binary files /dev/null and b/mllm/models/builder/__pycache__/builder.cpython-310.pyc differ diff --git a/mllm/models/builder/build_shikra.py b/mllm/models/builder/build_shikra.py new file mode 100644 index 0000000000000000000000000000000000000000..8ec51bafdf0c26c5efee3c272642089f0d9d112f --- /dev/null +++ b/mllm/models/builder/build_shikra.py @@ -0,0 +1,146 @@ +from typing import Dict, Any, Tuple + +import torch +import transformers +from torch import nn + +from ..shikra import ShikraLlamaForCausalLM + +PREPROCESSOR = Dict[str, Any] + +DEFAULT_PAD_TOKEN = "[PAD]" +DEFAULT_EOS_TOKEN = "
" +DEFAULT_BOS_TOKEN = "" +DEFAULT_UNK_TOKEN = "" + + +def load_pretrained_shikra(model_args, training_args, **kwargs) -> Tuple[nn.Module, PREPROCESSOR]: + model = ShikraLlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + **kwargs + ) + model.config.use_cache = False + if model_args.freeze_backbone: + model.model.requires_grad_(False) + + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + model_max_length=model_args.model_max_length, + padding_side="right", + use_fast=False, + ) + + assert model_args.version == 'v1' + if model_args.version == "v0": + if tokenizer.pad_token is None: + smart_tokenizer_and_embedding_resize( + special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), + tokenizer=tokenizer, + model=model, + ) + if "llama" in model_args.model_name_or_path: + tokenizer.add_special_tokens({ + "eos_token": DEFAULT_EOS_TOKEN, + "bos_token": DEFAULT_BOS_TOKEN, + "unk_token": DEFAULT_UNK_TOKEN, + }) + else: + tokenizer.pad_token = tokenizer.unk_token + + model_vision_dict = model.model.initialize_vision_modules( + vision_tower=model_args.vision_tower, + mm_vision_select_layer=model_args.mm_vision_select_layer, + pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter + ) + dtype = torch.float32 + if training_args.fp16: + dtype = torch.float16 + if training_args.bf16: + dtype = torch.bfloat16 + # HACK for quantization + if model.model.vision_tower[0].device != torch.device('meta'): + model.model.vision_tower[0].to(dtype=dtype, device=training_args.device) + else: + from transformers import CLIPVisionModel + model.model.vision_tower[0] = CLIPVisionModel.from_pretrained(model_args.vision_tower) # not quantize clip + # model.model.vision_tower[0] = CLIPVisionModel.from_pretrained(model_args.vision_tower, **kwargs) # quantize clip、 + vision_config = model_vision_dict['vision_config'] + + model.config.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter + if model_args.tune_mm_mlp_adapter: + model.requires_grad_(False) + for p in model.model.mm_projector.parameters(): + p.requires_grad = True + + model.config.freeze_mm_mlp_adapter = model_args.freeze_mm_mlp_adapter + if model_args.freeze_mm_mlp_adapter: + for p in model.model.mm_projector.parameters(): + p.requires_grad = False + + model.config.mm_use_im_start_end = model_args.mm_use_im_start_end + vision_config.use_im_start_end = model_args.mm_use_im_start_end + model.initialize_vision_tokenizer(mm_use_im_start_end=model_args.mm_use_im_start_end, + tokenizer=tokenizer, + device=training_args.device, + tune_mm_mlp_adapter=model_args.tune_mm_mlp_adapter, + pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter) + + params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad] + if len(params_no_grad) > 0: + if training_args.fsdp is not None and len(training_args.fsdp) > 0: + if len(params_no_grad) < 10: + print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'.format(len(params_no_grad), + params_no_grad)) + else: + print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'.format( + len(params_no_grad), ', '.join(params_no_grad[:10]))) + print("[WARNING] Attempting to use FSDP with partially frozen parameters, this is experimental.") + print( + "[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining") + + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + + def patch_FSDP_use_orig_params(func): + def wrap_func(*args, **kwargs): + use_orig_params = kwargs.pop('use_orig_params', True) + return func(*args, **kwargs, use_orig_params=use_orig_params) + + return wrap_func + + FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__) + + preprocessor = dict( + image=model_vision_dict['image_processor'], + text=tokenizer, + conv=dict( + image_token_len=model_args.image_token_len, + sep_image_conv_front=model_args.sep_image_conv_front, + use_im_start_end=model_args.mm_use_im_start_end, + ) + ) + return model, preprocessor + + +def smart_tokenizer_and_embedding_resize( + special_tokens_dict: Dict, + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, +): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. + """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg diff --git a/mllm/models/builder/builder.py b/mllm/models/builder/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..7e3473f22fe25a1c5d1b0c897fa1e18d1c550638 --- /dev/null +++ b/mllm/models/builder/builder.py @@ -0,0 +1,16 @@ +from typing import Dict, Any, Tuple + +from torch import nn + +from .build_shikra import load_pretrained_shikra + +PREPROCESSOR = Dict[str, Any] + + +# TODO: Registry +def load_pretrained(model_args, training_args) -> Tuple[nn.Module, PREPROCESSOR]: + type_ = model_args.type + if type_ == 'shikra': + return load_pretrained_shikra(model_args, training_args) + else: + assert False diff --git a/mllm/models/shikra/__init__.py b/mllm/models/shikra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..62cdbc7d01a41fed337dcc332541e7d1d353a01e --- /dev/null +++ b/mllm/models/shikra/__init__.py @@ -0,0 +1 @@ +from .shikra import ShikraLlamaForCausalLM, ShikraConfig diff --git a/mllm/models/shikra/__pycache__/__init__.cpython-310.pyc b/mllm/models/shikra/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94fb89bb00cea3c7d35f72a4964b0da97da46397 Binary files /dev/null and b/mllm/models/shikra/__pycache__/__init__.cpython-310.pyc differ diff --git a/mllm/models/shikra/__pycache__/shikra.cpython-310.pyc b/mllm/models/shikra/__pycache__/shikra.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f27a5629186afc0827d76b2bf90a157d1f89d187 Binary files /dev/null and b/mllm/models/shikra/__pycache__/shikra.cpython-310.pyc differ diff --git a/mllm/models/shikra/apply_delta.py b/mllm/models/shikra/apply_delta.py new file mode 100644 index 0000000000000000000000000000000000000000..06cffa67f1d3cefe66d05e369180c1868f283ba2 --- /dev/null +++ b/mllm/models/shikra/apply_delta.py @@ -0,0 +1,47 @@ +""" +Usage: +python3 apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/shikra-7b --delta lmsys/shikra-7b-delta +""" +import argparse + +import torch +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM +from shikra import ShikraLlamaForCausalLM + + +def apply_delta(base_model_path, target_model_path, delta_path): + print("Loading base model") + base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + + print("Loading delta") + delta = ShikraLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) + + print("Applying delta") + for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): + if name not in base.state_dict(): + assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' + continue + if param.data.shape == base.state_dict()[name].shape: + param.data += base.state_dict()[name] + else: + assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ + f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' + bparam = base.state_dict()[name] + param.data[:bparam.shape[0], :bparam.shape[1]] += bparam + + print("Saving target model") + delta.save_pretrained(target_model_path) + delta_tokenizer.save_pretrained(target_model_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base", type=str, required=True) + parser.add_argument("--target", type=str, required=True) + parser.add_argument("--delta", type=str, required=True) + + args = parser.parse_args() + + apply_delta(args.base, args.target, args.delta) diff --git a/mllm/models/shikra/make_delta.py b/mllm/models/shikra/make_delta.py new file mode 100644 index 0000000000000000000000000000000000000000..5b80a48563bd3b037d1b5065990cf723933b9559 --- /dev/null +++ b/mllm/models/shikra/make_delta.py @@ -0,0 +1,50 @@ +""" +Usage: +python3 make_delta --base ~/model_weights/llama-7b --target ~/model_weights/shikra-7b --delta ~/model_weights/shikra-7b-delta +""" +import argparse + +import torch +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM +from shikra import ShikraLlamaForCausalLM + + +def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): + print("Loading base model") + base = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + + print("Loading target model") + target = ShikraLlamaForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) + + print("Calculating delta") + for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): + if name not in base.state_dict(): + assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' + continue + if param.data.shape == base.state_dict()[name].shape: + param.data -= base.state_dict()[name] + else: + assert name in ['model.embed_tokens.weight', + 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' + bparam = base.state_dict()[name] + param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam + + print("Saving delta") + if hub_repo_id: + kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} + else: + kwargs = {} + target.save_pretrained(delta_path, **kwargs) + target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) + target_tokenizer.save_pretrained(delta_path, **kwargs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base", type=str, required=True) + parser.add_argument("--target", type=str, required=True) + parser.add_argument("--delta", type=str, required=True) + args = parser.parse_args() + + make_delta(args.base, args.target, args.delta, None) diff --git a/mllm/models/shikra/shikra.py b/mllm/models/shikra/shikra.py new file mode 100644 index 0000000000000000000000000000000000000000..b776cabd0ba15183a549831f889578b12edab1f8 --- /dev/null +++ b/mllm/models/shikra/shikra.py @@ -0,0 +1,412 @@ +from typing import List, Optional, Tuple, Union +from PIL import Image +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss +import torchvision.transforms.functional as TF +from transformers import LlamaConfig, LlamaModel, LlamaForCausalLM, CLIPVisionModel, CLIPImageProcessor,AutoImageProcessor, DeformableDetrModel + +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast + +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_IMAGE_PATCH_TOKEN = "" +DEFAULT_IM_START_TOKEN = "" +DEFAULT_IM_END_TOKEN = "" + + +Rxn_st = "" +Rxn_ed = "" # Reaction +Rct_st = "" +Rct_ed = "" # Reactant +Prd_st = " " +Prd_ed = "" # Product +Cnd_st = "" +Cnd_ed = "" + +Mol = "[Str]" # Molecule +Txt = "[Txt]" # Text +Sol = "[Sol]" +Age = "[Age]" +Tem = "[Tem]" +Yld = "[Yld]" +Obj = "[Obj]" + +rxn_tokens = [Rxn_st, Rxn_ed,Rct_st, Rct_ed, Prd_st, Prd_ed, Cnd_st, Cnd_ed, Mol, Txt,Sol,Age,Tem, Yld, Obj] +number_tokens = [f"{i:03}" for i in range(1, 1000)] +ID_tokens = [f"" for i in range(1, 51)] + +def resize_batch(images, size): + """ + Resize a batch of images to the given size. + + Args: + - images (torch.Tensor): Input tensor of shape (B, C, H, W) + - size (tuple): Desired output size (new_h, new_w) + + Returns: + - torch.Tensor: Resized images of shape (B, C, new_h, new_w) + """ + resized_images = [] + for image in images: + # Resize image and add it to the list + resized = TF.resize(image, size, interpolation=Image.BICUBIC) + resized_images.append(resized) + + # Stack all resized images along the batch dimension + return torch.stack(resized_images) + +class VisionLanguageAdapter(nn.Module): + def __init__(self, feature_dim=1280, num_queries=256, num_heads=16): + super(VisionLanguageAdapter, self).__init__() + self.num_queries = num_queries + self.query_embeds = nn.Parameter(torch.randn(num_queries, feature_dim)) + self.cross_attention = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=num_heads, batch_first=True) + self.positional_encoding = nn.Parameter(torch.randn(num_queries, feature_dim)) + self.layer_norm = nn.LayerNorm(feature_dim) + self.linear = nn.Linear(feature_dim, 5120) + def forward(self, image_features): + # Add positional encoding to query embeddings + query_embeds = self.query_embeds + self.positional_encoding + + # Flag to check if input was unbatched + was_unbatched = image_features.dim() == 2 + + # Adjust dimensions based on whether input is batched or unbatched + if was_unbatched: + # For unbatched input, add a batch dimension for compatibility + image_features = image_features.unsqueeze(0) + query_embeds = query_embeds.unsqueeze(0) + else: + # For batched input, adjust the query embeddings to match the batch size + batch_size = image_features.size(0) + query_embeds = query_embeds.unsqueeze(0).expand(batch_size, -1, -1) + + # Apply cross attention + attn_output, _ = self.cross_attention(query=query_embeds, key=image_features, value=image_features) + + attn_output = self.layer_norm(attn_output) + attn_output = self.linear(attn_output) + + # If the input was unbatched, remove the batch dimension from the output + if was_unbatched: + attn_output = attn_output.squeeze(0) + + return attn_output + + +class ShikraConfig(LlamaConfig): + model_type = "shikra" + + +class ShikraLlamaModel(LlamaModel): + config_class = ShikraConfig + + def __init__(self, config: LlamaConfig, mm_vision_tower=None, mm_hidden_size=None): + super(ShikraLlamaModel, self).__init__(config) + + if hasattr(config, "mm_vision_tower"): + # HACK: for FSDP + self.vision_tower = nn.ModuleList([DeformableDetrModel.from_pretrained(config.mm_vision_tower)]) + #self.vision_tower = nn.ModuleList([CLIPVisionModel.from_pretrained(config.mm_vision_tower)]) + + if hasattr(config, "use_mm_proj"): + self.mm_projector = nn.Linear(256, config.hidden_size) + + def initialize_vision_modules(self, vision_tower, mm_vision_select_layer, + pretrain_mm_mlp_adapter=None, tune_mm_mlp_adapter=False): + self.config.mm_vision_tower = vision_tower + + image_processor = AutoImageProcessor.from_pretrained(vision_tower) + #image_processor = CLIPImageProcessor.from_pretrained(vision_tower) + + if not hasattr(self, 'vision_tower'): + vision_tower = DeformableDetrModel.from_pretrained(vision_tower) + #vision_tower = CLIPVisionModel.from_pretrained(vision_tower) + self.vision_tower = nn.ModuleList([vision_tower]) # 使用 ModuleList 包装模型 + else: + self.vision_tower[0] = DeformableDetrModel.from_pretrained(vision_tower) + #self.vision_tower[0] = CLIPVisionModel.from_pretrained(vision_tower)# 直接赋值到 ModuleList 中的相应位置 + + # 设置模型为训练模式 + self.vision_tower[0].requires_grad_(True) + self.vision_tower[0] = self.vision_tower[0].to(torch.float16) + + vision_config = self.vision_tower[0].config + num_patches = 300 + self.config.use_mm_proj = True + self.config.mm_hidden_size = 256 + self.config.mm_vision_select_layer = mm_vision_select_layer + + if not hasattr(self, 'mm_projector'): + self.mm_projector = nn.Linear(vision_config.hidden_size, self.config.hidden_size) + + if pretrain_mm_mlp_adapter is not None: + mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') + self.mm_projector.load_state_dict({k.split('.')[-1]: v for k, v in mm_projector_weights.items()}) + + return dict( + image_processor=image_processor, + image_token_len=num_patches, + vision_config=vision_config + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + orig_embeds_params = getattr(self, 'orig_embeds_params', None) + # if orig_embeds_params is not None: + # orig_embeds_params = orig_embeds_params[0] + # with torch.no_grad(): + # self.get_input_embeddings().weight.data[:-2] = orig_embeds_params[:-2].data + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + vision_tower = getattr(self, 'vision_tower', None) + if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None: + # TODO: this is a modified multimodal LLM -- Haotian Liu + vision_tower = vision_tower[0] # HACK: for FSDP + new_size = (1333, 1333) + images = resize_batch(images, new_size) + with torch.no_grad(): + if type(images) is list: + # variable length images + image_features = [] + for image in images: + image_forward_out = vision_tower(image.unsqueeze(0)) + select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1) + image_feature = image_forward_out.last_hidden_state + # image_feature = select_hidden_state[:, 1:] + image_features.append(image_feature) + #print(image_features.shape) + else: + #print(images.shape) + image_forward_outs = vision_tower(images) + select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1) + image_features = image_forward_outs.last_hidden_state + # print(image_features.shape) + + # image_forward_outs = vision_tower(images, output_hidden_states=True) + # select_hidden_state_layer = getattr(self.config, "mm_vision_select_layer", -1) + # select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer] + # image_features = select_hidden_state[:, 1:] + + if type(images) is list: + image_features = [self.mm_projector(image_feature)[0] for image_feature in image_features] + else: + image_features = self.mm_projector(image_features) + + dummy_image_features = torch.zeros(300, 256, device=inputs_embeds.device, dtype=inputs_embeds.dtype) + dummy_image_features = self.mm_projector(dummy_image_features) + + new_input_embeds = [] + cur_image_idx = 0 + for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds): + if (cur_input_ids == vision_tower.config.im_patch_token).sum() == 0: + # multimodal LLM, but the current sample is not multimodal + cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum() + new_input_embeds.append(cur_input_embeds) + continue + if vision_tower.config.use_im_start_end: + cur_image_features = image_features[cur_image_idx] + num_patches = cur_image_features.shape[0] + if (cur_input_ids == vision_tower.config.im_start_token).sum() != ( + cur_input_ids == vision_tower.config.im_end_token).sum(): + raise ValueError("The number of image start tokens and image end tokens should be the same.") + image_start_tokens = torch.where(cur_input_ids == vision_tower.config.im_start_token)[0] + for image_start_token_pos in image_start_tokens: + cur_image_features = image_features[cur_image_idx].to(device=cur_input_embeds.device) + num_patches = cur_image_features.shape[0] + if cur_input_ids[image_start_token_pos + num_patches + 1] != vision_tower.config.im_end_token: + raise ValueError("The image end token should follow the image start token.") + if orig_embeds_params is not None: + cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), + cur_input_embeds[image_start_token_pos:image_start_token_pos + 1], + cur_image_features, cur_input_embeds[ + image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], + cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0) + else: + cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos + 1], cur_image_features, + cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0) + cur_image_idx += 1 + new_input_embeds.append(cur_new_input_embeds) + else: + cur_image_features = image_features[cur_image_idx] + num_patches = cur_image_features.shape[0] + if (cur_input_ids == vision_tower.config.im_patch_token).sum() != num_patches: + raise ValueError("The number of image patch tokens should be the same as the number of image patches.") + masked_indices = torch.where(cur_input_ids == vision_tower.config.im_patch_token)[0] + mask_index_start = masked_indices[0] + if (masked_indices != torch.arange(mask_index_start, mask_index_start + num_patches, device=masked_indices.device, + dtype=masked_indices.dtype)).any(): + raise ValueError("The image patch tokens should be consecutive.") + if orig_embeds_params is not None: + cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start].detach(), cur_image_features, + cur_input_embeds[mask_index_start + num_patches:].detach()), dim=0) + else: + cur_new_input_embeds = torch.cat( + (cur_input_embeds[:mask_index_start], cur_image_features, cur_input_embeds[mask_index_start + num_patches:]), + dim=0) + new_input_embeds.append(cur_new_input_embeds) + inputs_embeds = torch.stack(new_input_embeds, dim=0) + + return super(ShikraLlamaModel, self).forward( + input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, + inputs_embeds=inputs_embeds, use_cache=use_cache, + output_attentions=output_attentions, output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + +class ShikraLlamaForCausalLM(LlamaForCausalLM): + config_class = ShikraConfig + + def __init__(self, config: ShikraConfig): + super(LlamaForCausalLM, self).__init__(config) + self.model = ShikraLlamaModel(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + images: Optional[torch.FloatTensor] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + images=images + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model/pipeline parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "images": kwargs.get("images", None), + } + ) + return model_inputs + + def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device, + tune_mm_mlp_adapter=False, pretrain_mm_mlp_adapter=None): + vision_config = self.model.vision_tower[0].config + vision_config.use_im_start_end = mm_use_im_start_end + tokenizer.add_tokens(rxn_tokens) + tokenizer.add_tokens(ID_tokens) + #tokenizer.add_tokens(number_tokens) + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + + if mm_use_im_start_end: + num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) + self.resize_token_embeddings(len(tokenizer)) + vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) + + if num_new_tokens > 0: + input_embeddings = self.get_input_embeddings().weight.data + output_embeddings = self.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + if tune_mm_mlp_adapter: + self.model.orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)] + for p in self.get_input_embeddings().parameters(): + p.requires_grad = True + for p in self.get_output_embeddings().parameters(): + p.requires_grad = False + + if pretrain_mm_mlp_adapter: + mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') + embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] + assert num_new_tokens == 2 + if input_embeddings.shape == embed_tokens_weight.shape: + input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] + elif embed_tokens_weight.shape[0] == num_new_tokens: + input_embeddings[-num_new_tokens:] = embed_tokens_weight + else: + raise ValueError( + f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.") + + vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] \ No newline at end of file diff --git a/mllm/pipeline/__init__.py b/mllm/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mllm/pipeline/finetune.py b/mllm/pipeline/finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..59eab434f8416b0800a18e9ae78be15540205f4f --- /dev/null +++ b/mllm/pipeline/finetune.py @@ -0,0 +1,141 @@ +import os +import sys +import logging +import pathlib +import typing +import warnings + +SLURM_ENV = {k: v for k, v in os.environ.items() if 'SLURM' in k} +if SLURM_ENV: + print(f"SLURM_ENV: {SLURM_ENV}") +project_path = pathlib.Path(__file__).parent.parent.parent +sys.path.append(str(project_path)) + +import torch +import torch.cuda + +from mllm.config import prepare_args +from mllm.models import load_pretrained +from mllm.utils import print_trainable_params +from mllm.engine import prepare_trainer_collator +from mllm.dataset import prepare_data, prepare_target_processor + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout), ], +) + + +def main(): + cfg, training_args = prepare_args() + model, preprocessor = load_pretrained(cfg.model_args, training_args) + # Some ugly codes to inject target_processor into preprocessor. + # maybe effect model. (e.g. add special token; resize embedding) + model, preprocessor = prepare_target_processor(model, preprocessor, cfg.model_args, training_args) + print_trainable_params(model) + + # Prepare data_collator + collator_kwargs = cfg.data_args.collator_kwargs + trainer_cls, data_collator_dict = prepare_trainer_collator(cfg.model_args, preprocessor, collator_kwargs) + dataset, compute_metrics = prepare_data(cfg.data_args, cfg.model_args, training_args, preprocessor) + + # Initialize Trainer + trainer = trainer_cls( + model=model, + args=training_args, + tokenizer=preprocessor['text'], + train_dataset=dataset['train'] if training_args.do_train else None, + eval_dataset=dataset['validation'] if training_args.do_eval else None, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + **data_collator_dict, + ) + + # Training + if training_args.do_train: + try: + if (not training_args.overwrite_output_dir) and list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): + train_result = trainer.train(resume_from_checkpoint=True) + else: + train_result = trainer.train() + trainer.log_metrics("train", train_result.metrics) # noqa + trainer.save_metrics("train", train_result.metrics) # noqa + trainer.save_model() + except RuntimeError as e: + print(f"got RuntimeError: {e.args}") + try: + print(f"#### device {training_args.local_rank} summary ####\n{torch.cuda.memory_summary(training_args.local_rank)}") + except Exception as inner_e: + print(f"get Exception when show cuda summary: {inner_e.args}") + raise e + finally: + trainer.save_state() # noqa + trainer.plot_loss() + + # save cfg to output_dir + try: + output_dir = training_args.output_dir + pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True) + cfg.dump(os.path.join(output_dir, "cfg.py")) + except Exception as e: + warnings.warn(f'try to save cfg to output_dir, but get exception {e.args}') + + # Keyword arguments for `model.generate` + gen_kwargs = dict(cfg.data_args.gen_kwargs) + gen_kwargs.setdefault('use_cache', True) + # important for use model.generate in batch mode. some model config with wrong special_token_id + # (e.g. shikra generationConfig set pad_token_id to -1) + if hasattr(cfg.model_args, 'gen_kwargs_set_pad_token_id') and cfg.model_args.gen_kwargs_set_pad_token_id: + gen_kwargs['pad_token_id'] = preprocessor['text'].pad_token_id + if hasattr(cfg.model_args, 'gen_kwargs_set_bos_token_id') and cfg.model_args.gen_kwargs_set_bos_token_id: + gen_kwargs['bos_token_id'] = preprocessor['text'].bos_token_id + if hasattr(cfg.model_args, 'gen_kwargs_set_eos_token_id') and cfg.model_args.gen_kwargs_set_eos_token_id: + gen_kwargs['eos_token_id'] = preprocessor['text'].eos_token_id + + # Evaluation + if training_args.do_eval: + if hasattr(trainer, '_test_collator') and hasattr(trainer, '_eval_collator') \ + and trainer._test_collator != trainer._eval_collator: # noqa + warnings.warn('[WARNING!!!] use different collator for eval and test. but do_eval and ' + 'do_predict both use trainer.predict (i.e. only test_collator is used.)') + eval_results = trainer.predict(dataset['validation'], metric_key_prefix="eval", **gen_kwargs) + trainer.log_metrics("eval", eval_results.metrics) # noqa + trainer.save_metrics("eval", eval_results.metrics) # noqa + trainer.save_prediction(eval_results, file_key_prefix='eval') + + # Predict + if training_args.do_predict: + predict_results = trainer.predict(dataset['test'], metric_key_prefix="test", **gen_kwargs) + trainer.log_metrics("test", predict_results.metrics) # noqa + trainer.save_metrics("test", predict_results.metrics) # noqa + trainer.save_prediction(predict_results, file_key_prefix='test') + + # Multi Predict + if training_args.do_multi_predict: + old_compute_metrics = trainer.compute_metrics + multitest = dataset['multitest'] + multitest = typing.cast(dict, multitest) + for _idx, (k, item) in enumerate(multitest.items()): + print(f'processing multitest set {_idx}/{len(multitest)}: {k}') + _ds = item['dataset'] + _compute_metrics = item['compute_metric'] + _prefix = f"multitest_{k}" + + trainer.compute_metrics = _compute_metrics + _pred_results = trainer.predict(_ds, metric_key_prefix=_prefix, **gen_kwargs) + trainer.log_metrics(_prefix, _pred_results.metrics) # noqa + trainer.save_metrics(_prefix, _pred_results.metrics) # noqa + trainer.save_prediction(_pred_results, file_key_prefix=_prefix) + trainer.compute_metrics = old_compute_metrics + + +# noinspection PyUnusedLocal +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/mllm/pipeline/finetune_mem.py b/mllm/pipeline/finetune_mem.py new file mode 100644 index 0000000000000000000000000000000000000000..80151e06e0c64ce2c56a58d9c59fae86f26e7af8 --- /dev/null +++ b/mllm/pipeline/finetune_mem.py @@ -0,0 +1,25 @@ +# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: +# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: +# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. + +import sys +import pathlib +project_path = pathlib.Path(__file__).parent.parent.parent +sys.path.append(str(project_path)) + +# Need to call this before importing transformers. +from mllm.utils.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn + +replace_llama_attn_with_flash_attn() + +from mllm.pipeline.finetune import main + + +# noinspection PyUnusedLocal +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/mllm/utils/__init__.py b/mllm/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c79560824e77821a1285520795aefffef1d1698c --- /dev/null +++ b/mllm/utils/__init__.py @@ -0,0 +1,8 @@ +from .common import ( + print_trainable_params, + show, + draw_bounding_boxes, + post_process_generate_ids, + decode_generate_ids, + smart_tokenizer_and_embedding_resize, +) diff --git a/mllm/utils/__pycache__/__init__.cpython-310.pyc b/mllm/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b4b791d4d2ea6ebe49d2ee3a431ac14e577b2ca Binary files /dev/null and b/mllm/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/mllm/utils/__pycache__/common.cpython-310.pyc b/mllm/utils/__pycache__/common.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c57f1c3c5b8cd48224673d2b7764323b3bfc17f7 Binary files /dev/null and b/mllm/utils/__pycache__/common.cpython-310.pyc differ diff --git a/mllm/utils/common.py b/mllm/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..287c5d17305348bfe8f02a1876a0a2a2347c6470 --- /dev/null +++ b/mllm/utils/common.py @@ -0,0 +1,97 @@ +import copy +from typing import List, Union, Dict + +import PIL.Image +import torch +import numpy as np +import torchvision.transforms.functional as F +import transformers +from matplotlib import pyplot as plt + +from transformers import PreTrainedTokenizer + + +def print_trainable_params(model: torch.nn.Module) -> None: + trainable_params, all_param = 0, 0 + for param in model.parameters(): + num_params = param.numel() + # if using DS Zero 3 and the weights are initialized empty + if num_params == 0 and hasattr(param, "ds_numel"): + num_params = param.ds_numel + all_param += num_params + if param.requires_grad: + trainable_params += num_params + print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format( + trainable_params, all_param, 100 * trainable_params / all_param)) + + +def post_process_generate_ids(tokenizer: PreTrainedTokenizer, ids: torch.Tensor): + ids = copy.deepcopy(ids) # do not modify origin preds and targets + ids[ids < 0] = tokenizer.pad_token_id + return ids + + +def decode_generate_ids(tokenizer: PreTrainedTokenizer, ids: torch.Tensor) -> Union[List[str], str]: + assert ids.ndim in [1, 2] + only_one_sentence = ids.ndim == 1 + if only_one_sentence: + ids = ids.unsqueeze(0) + ids = post_process_generate_ids(tokenizer, ids) + res = tokenizer.batch_decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) + if only_one_sentence: + return res[0] + return res + + +def show(imgs: Union[torch.Tensor, List[Union[torch.Tensor, PIL.Image.Image]]]): + if not isinstance(imgs, list): + imgs = [imgs] + fig, axs = plt.subplots(ncols=len(imgs), squeeze=False) + for i, img in enumerate(imgs): + if isinstance(img, torch.Tensor): + img = img.detach() + img = F.to_pil_image(img) + axs[0, i].imshow(np.asarray(img)) + axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) + + +def draw_bounding_boxes( + image: Union[torch.Tensor, PIL.Image.Image], + boxes: Union[torch.Tensor, List, np.ndarray], + **kwargs, +): + if isinstance(image, PIL.Image.Image): + from torchvision.transforms import PILToTensor + image = PILToTensor()(image) + assert isinstance(image, torch.Tensor), "" + + if not isinstance(boxes, torch.Tensor): + boxes = torch.as_tensor(boxes) + assert isinstance(boxes, torch.Tensor) + + from torchvision.utils import draw_bounding_boxes as _draw_bounding_boxes + return _draw_bounding_boxes(image, boxes, **kwargs) + + +# https://github.com/huggingface/tokenizers/issues/247#issuecomment-675458087 +def smart_tokenizer_and_embedding_resize( + special_tokens_dict: Dict, + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, +): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. + """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg diff --git a/mllm/utils/llama_flash_attn_monkey_patch.py b/mllm/utils/llama_flash_attn_monkey_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..3daef763fd3dec2ab6899a011f3cceba135ec3d0 --- /dev/null +++ b/mllm/utils/llama_flash_attn_monkey_patch.py @@ -0,0 +1,102 @@ +# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: +from typing import List, Optional, Tuple + +import torch +from torch import nn + +import transformers +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb + +from einops import rearrange + +from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func +from flash_attn.bert_padding import unpad_input, pad_input + + +def forward( + self, + hidden_states: torch.Tensor, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], +Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel + + attention_mask: [bsz, q_len] + """ + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view( + bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view( + bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view( + bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + # [bsz, q_len, nh, hd] + # [bsz, nh, q_len, hd] + + kv_seq_len = key_states.shape[-2] + offset = 0 + if past_key_value is not None: + offset = past_key_value[0].shape[-2] + kv_seq_len += offset + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, + key_states, + cos, + sin, + offset=offset) + # [bsz, nh, t, hd] + assert not output_attentions, "output_attentions is not supported" + assert not use_cache, "use_cache is not supported" + assert past_key_value is None, "past_key_value is not supported" + + # Flash attention codes from + # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py + + # transform the data into the format required by flash attention + qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd] + qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] + # We have disabled _prepare_decoder_attention_mask in LlamaModel + # the attention_mask should be the same as the key_padding_mask + key_padding_mask = attention_mask + + if key_padding_mask is None: + qkv = rearrange(qkv, 'b s ... -> (b s) ...') + max_s = q_len + cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, + device=qkv.device) + output = flash_attn_unpadded_qkvpacked_func( + qkv, cu_q_lens, max_s, 0.0, + softmax_scale=None, causal=True + ) + output = rearrange(output, '(b s) ... -> b s ...', b=bsz) + else: + nheads = qkv.shape[-2] + x = rearrange(qkv, 'b s three h d -> b s (three h d)') + x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) + x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) + output_unpad = flash_attn_unpadded_qkvpacked_func( + x_unpad, cu_q_lens, max_s, 0.0, + softmax_scale=None, causal=True + ) + output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), + indices, bsz, q_len), + 'b s (h d) -> b s h d', h=nheads) + return self.o_proj(rearrange(output, + 'b s h d -> b s (h d)')), None, None + + +# Disable the transformation of the attention mask in LlamaModel as the flash attention +# requires the attention mask to be the same as the key_padding_mask +def _prepare_decoder_attention_mask(self, attention_mask, input_shape, + inputs_embeds, past_key_values_length): + # [bsz, seq_len] + return attention_mask + + +def replace_llama_attn_with_flash_attn(): + transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + transformers.models.llama.modeling_llama.LlamaAttention.forward = forward