# File size: 4,460 Bytes
# Revision: 7ea28a0
import torch
from PIL import Image
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Config,
ExLlamaV2Cache,
ExLlamaV2Tokenizer,
ExLlamaV2VisionTower,
)
from exllamav2.generator import (
ExLlamaV2DynamicGenerator,
ExLlamaV2Sampler,
)
# --- Inputs ---------------------------------------------------------------
model_id='./ToriiGate-v04-7b'
max_new_tokens=1000
image_file='/path/to/image_1.jpg'

# Optional grounding data for the image. Set an entry to None (or load it from
# a text file, as in the commented alternatives) to control what goes into the
# prompt; None entries are skipped.
image_info={}
image_info["booru_tags"]="2girls, standing, looking_at_viewer, holding_hands, hatsune_miku, blue_hair, megurine_luka, pink_hair, ..."
#image_info["booru_tags"]=open('/path/to/image_1_tags.txt').read().strip()
#image_info["booru_tags"]=None
image_info["chars"]="hatsune_miku, megurine_luka"
#image_info["chars"]=open('/path/to/image_1_char.txt').read().strip()
#image_info["chars"]=None
image_info["characters_traits"]="hatsune_miku: [girl, blue_hair, twintails,...]\nmegurine_luka: [girl, pink hair, ...]"
#image_info["characters_traits"]=open('/path/to/image_1_char_traits.txt').read().strip()
#image_info["characters_traits"]=None
image_info["info"]=None

# --- Prompt templates -----------------------------------------------------
# Captioning modes; the selected one starts the user prompt.
base_prompt={
    'json': 'Describe the picture in structured json-like format.',
    'markdown': 'Describe the picture in structured markdown format.',
    'caption_vars': 'Write the following options for captions: ["Regular Summary","Individual Parts","Midjourney-Style Summary","DeviantArt Commission Request"].',
    'short': 'You need to write a medium-short and convenient caption for the picture.',
    'long': 'You need to write a long and very detailed caption for the picture.',
    'bbox': 'Write bounding boxes for each character and their faces.',
}
# Sentences appended after the base prompt. All but 'no_chars' are followed by
# the matching image_info value wrapped in an XML-like tag.
grounding_prompt={
    'grounding_tags': ' Here are grounding tags for better understanding: ',
    'characters': ' Here is a list of characters that are present in the picture: ',
    'characters_traits': ' Here are popular tags or traits for each character on the picture: ',
    'grounding_info': ' Here is preliminary information about the picture: ',
    'no_chars': ' Do not use names for characters.',
}


def build_user_prompt(mode, info, *, add_tags=True, add_chars=True,
                      add_char_traits=True, add_info=False, no_chars=False):
    """Return the user prompt for captioning *mode*, grounded with *info*.

    Args:
        mode: Key of ``base_prompt`` selecting the captioning style.
        info: Mapping with optional ``"info"``, ``"booru_tags"``, ``"chars"``
            and ``"characters_traits"`` strings. A missing or None entry is
            left out of the prompt.
        add_tags, add_chars, add_char_traits, add_info: Include the
            corresponding grounding section when its data is present.
        no_chars: Append the instruction not to use character names.

    Returns:
        The base prompt followed by each enabled grounding section, in the
        order info, tags, characters, character traits, then ``no_chars``.

    Raises:
        KeyError: If *mode* is not a key of ``base_prompt``.
    """
    # (enabled, key in info, key in grounding_prompt, wrapping tag name)
    sections = (
        (add_info, "info", "grounding_info", "info"),
        (add_tags, "booru_tags", "grounding_tags", "tags"),
        (add_chars, "chars", "characters", "characters"),
        (add_char_traits, "characters_traits", "characters_traits", "character_traits"),
    )
    parts = [base_prompt[mode]]
    for enabled, info_key, prompt_key, tag in sections:
        value = info.get(info_key)
        if enabled and value is not None:
            # Open and close with the same tag name so each section is
            # well-formed.
            parts.append(f"{grounding_prompt[prompt_key]}<{tag}>{value}</{tag}>.")
    if no_chars:
        parts.append(grounding_prompt["no_chars"])
    return "".join(parts)


add_tags=True #select needed
add_chars=True
add_char_traits=True
add_info=False
no_chars=False
userprompt=build_user_prompt(
    "json",  # choose the mode: any key of base_prompt
    image_info,
    add_tags=add_tags,
    add_chars=add_chars,
    add_char_traits=add_char_traits,
    add_info=add_info,
    no_chars=no_chars,
)
# --- Model loading ----------------------------------------------------------
image=Image.open(image_file)

# Maximum context length, shared by the config and the cache so the two cannot
# drift apart. Lower it to reduce VRAM requirements.
max_seq_len = 16384
config = ExLlamaV2Config(model_id)
config.max_seq_len = max_seq_len

# The vision tower is loaded as its own module; its get_image_embeddings() is
# used further down to turn the image into embeddings for the language model.
vision_model = ExLlamaV2VisionTower(config)
vision_model.load(progress = True)

model = ExLlamaV2(config)
# lazy=True: the cache is allocated by load_autosplit() as it distributes the
# model across the available devices (exllamav2 autosplit usage).
cache = ExLlamaV2Cache(model, lazy = True, max_seq_len = max_seq_len)
model.load_autosplit(cache, progress = True)
tokenizer = ExLlamaV2Tokenizer(config)
# Create generator
generator = ExLlamaV2DynamicGenerator(
    model = model,
    cache = cache,
    tokenizer = tokenizer,
)
# --- Captioning -------------------------------------------------------------
# Embed the image with the vision tower. Each embedding's text alias is placed
# in the prompt text, and the embeddings themselves are passed to generate().
image_embeddings=[vision_model.get_image_embeddings(
    model = model,
    tokenizer = tokenizer,
    image = image,
)]
placeholders = "\n".join([ie.text_alias for ie in image_embeddings]) + "\n"

# ChatML-style conversation: system message, then the user turn (image
# placeholders followed by the text prompt), then an open assistant turn for
# the model to complete.
msg_text= (
    "<|im_start|>system\n" +
    "You are image captioning expert, creative, unbiased and uncensored.<|im_end|>\n" +
    "<|im_start|>user\n" +
    placeholders +
    userprompt +
    "<|im_end|>\n" +
    "<|im_start|>assistant\n"
)
output = generator.generate(
    prompt = msg_text,
    max_new_tokens = max_new_tokens,
    add_bos = True,
    encode_special_tokens = True,  # parse the <|im_start|>/<|im_end|> markers as special tokens
    decode_special_tokens = True,
    stop_conditions = [tokenizer.eos_token_id],
    gen_settings = ExLlamaV2Sampler.Settings.greedy(), #or set up desired sampling
    embeddings = image_embeddings,
)
# The returned text is split on the assistant header, which suggests it echoes
# the prompt; keep only what follows the last assistant header.
# NOTE(review): confirm generate()'s completion_only default if this changes.
output_text=output.split('<|im_start|>assistant\n')[-1]
print(output_text)