Upload 2 files

joycaption.py CHANGED (+6 -122)
@@ -33,8 +33,9 @@ use_inference_client = False
 PIXTRAL_PATH = "mistral-community/pixtral-12b"
 
 llm_models = {
-    "
+    "Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2": None,
     #PIXTRAL_PATH: None,
+    "bunnycore/LLama-3.1-8B-Matrix": None,
     "Sao10K/Llama-3.1-8B-Stheno-v3.4": None,
     "unsloth/Meta-Llama-3.1-8B-bnb-4bit": None,
     "DevQuasar/HermesNova-Llama-3.1-8B": None,
@@ -157,6 +158,8 @@ def load_text_model(model_name: str=MODEL_PATH, gguf_file: Union[str, None]=None
     else:
         text_model = LlavaForConditionalGeneration.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16).eval()
         image_adapter = AutoProcessor.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16)
+        tokenizer = None
+        peft_config = None
 
     print("Loading tokenizer")
     if gguf_file: tokenizer = AutoTokenizer.from_pretrained(model_name, gguf_file=gguf_file, use_fast=True, legacy=False)
@@ -217,88 +220,10 @@ clip_model.eval().requires_grad_(False).to(device)
 load_text_model()
 
 @spaces.GPU()
-@torch.no_grad()
-def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: Union[str, int]) -> str:
-    torch.cuda.empty_cache()
-
-    # 'any' means no length specified
-    length = None if caption_length == "any" else caption_length
-
-    if isinstance(length, str):
-        try:
-            length = int(length)
-        except ValueError:
-            pass
-
-    # 'rng-tags' and 'training_prompt' don't have formal/informal tones
-    if caption_type == "rng-tags" or caption_type == "training_prompt":
-        caption_tone = "formal"
-
-    # Build prompt
-    prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
-    if prompt_key not in CAPTION_TYPE_MAP:
-        raise ValueError(f"Invalid caption type: {prompt_key}")
-
-    prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
-    print(f"Prompt: {prompt_str}")
-
-    # Preprocess image
-    #image = clip_processor(images=input_image, return_tensors='pt').pixel_values
-    image = input_image.resize((384, 384), Image.LANCZOS)
-    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
-    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
-    pixel_values = pixel_values.to('cuda')
-
-    # Tokenize the prompt
-    prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
-
-    # Embed image
-    with torch.amp.autocast_mode.autocast('cuda', enabled=True):
-        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
-        image_features = vision_outputs.hidden_states
-        embedded_images = image_adapter(image_features)
-        embedded_images = embedded_images.to('cuda')
-
-    # Embed prompt
-    prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
-    assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
-    embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
-    eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
-
-    # Construct prompts
-    inputs_embeds = torch.cat([
-        embedded_bos.expand(embedded_images.shape[0], -1, -1),
-        embedded_images.to(dtype=embedded_bos.dtype),
-        prompt_embeds.expand(embedded_images.shape[0], -1, -1),
-        eot_embed.expand(embedded_images.shape[0], -1, -1),
-    ], dim=1)
-
-    input_ids = torch.cat([
-        torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
-        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
-        prompt,
-        torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
-    ], dim=1).to('cuda')
-    attention_mask = torch.ones_like(input_ids)
-
-    #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=False, suppress_tokens=None)
-    #generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, top_k=10, temperature=0.5, suppress_tokens=None)
-    generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=300, do_sample=True, suppress_tokens=None)  # Uses the default which is temp=0.6, top_p=0.9
-
-    # Trim off the prompt
-    generate_ids = generate_ids[:, input_ids.shape[1]:]
-    if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
-        generate_ids = generate_ids[:, :-1]
-
-    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
-
-    return caption.strip()
-
-@spaces.GPU()
-@torch.no_grad()
+@torch.inference_mode()
 def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: Union[str, int],
                     max_new_tokens: int=300, top_p: float=0.9, temperature: float=0.6, model_name: str=MODEL_PATH, progress=gr.Progress(track_tqdm=True)) -> str:
-    global
+    global tokenizer, text_model, image_adapter, peft_config, text_model_client, use_inference_client
     torch.cuda.empty_cache()
     gc.collect()
 
@@ -476,44 +401,3 @@ def change_text_model(model_name: str=MODEL_PATH, use_client: bool=False, gguf_f
         return gr.update(choices=get_text_model())
     except Exception as e:
         raise gr.Error(f"Model load error: {model_name}, {e}")
-
-
-# original UI
-with gr.Blocks() as demo:
-    gr.HTML(TITLE)
-
-    with gr.Row():
-        with gr.Column():
-            input_image = gr.Image(type="pil", label="Input Image")
-
-            caption_type = gr.Dropdown(
-                choices=["descriptive", "training_prompt", "rng-tags"],
-                label="Caption Type",
-                value="descriptive",
-            )
-
-            caption_tone = gr.Dropdown(
-                choices=["formal", "informal"],
-                label="Caption Tone",
-                value="formal",
-            )
-
-            caption_length = gr.Dropdown(
-                choices=["any", "very short", "short", "medium-length", "long", "very long"] +
-                        [str(i) for i in range(20, 261, 10)],
-                label="Caption Length",
-                value="any",
-            )
-
-            gr.Markdown("**Note:** Caption tone doesn't affect `rng-tags` and `training_prompt`.")
-
-            run_button = gr.Button("Caption")
-
-        with gr.Column():
-            output_caption = gr.Textbox(label="Caption")
-
-    run_button.click(fn=stream_chat, inputs=[input_image, caption_type, caption_tone, caption_length], outputs=[output_caption])
-
-
-if __name__ == "__main__":
-    demo.launch()
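In the third hunk, the deleted `stream_chat` path used `@torch.no_grad()`, while the surviving `stream_chat_mod` is now decorated with `@torch.inference_mode()`. A minimal sketch of the practical difference, using a toy function rather than anything from joycaption.py:

```python
import torch

# Toy example (not from the repo): inference_mode(), like no_grad(), skips
# recording the autograd graph, but it additionally marks results as
# "inference tensors", which can never be re-used in a graph later.
@torch.inference_mode()
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

y = double(torch.ones(2, requires_grad=True))
print(y.requires_grad)   # False: no graph was recorded
print(y.is_inference())  # True: permanently excluded from autograd
```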
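With the original `stream_chat` function and the bundled Gradio UI removed in the last two hunks, `stream_chat_mod` is the remaining captioning entry point. A rough usage sketch based only on the signature visible in the diff; the import path and image file are assumptions, and the Space normally drives this function through its own UI rather than a direct call:

```python
from PIL import Image

from joycaption import stream_chat_mod  # assumed import path for joycaption.py

image = Image.open("example.jpg")  # any RGB image; the path is illustrative

# Keyword defaults mirror the signature in the diff (max_new_tokens=300,
# top_p=0.9, temperature=0.6); the caption choices match the removed UI's dropdowns.
caption = stream_chat_mod(
    image,
    caption_type="descriptive",
    caption_tone="formal",
    caption_length="any",
)
print(caption)
```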