p3nguknight committed
Commit bc504d0 · 1 Parent(s): 593852b

Use pixtral with hf

Files changed (3):
  1. README.md +1 -1
  2. app.py +29 -35
  3. requirements.txt +1 -3
README.md CHANGED
@@ -16,5 +16,5 @@ models:
  preload_from_hub:
  - vidore/colpaligemma-3b-pt-448-base config.json,model-00001-of-00002.safetensors,model-00002-of-00002.safetensors,model.safetensors.index.json,preprocessor_config.json,special_tokens_map.json,tokenizer.json,tokenizer_config.json 30ab955d073de4a91dc5a288e8c97226647e3e5a
  - vidore/colpali-v1.3 adapter_config.json,adapter_model.safetensors,preprocessor_config.json,special_tokens_map.json,tokenizer.json,tokenizer_config.json 1b5c8929330df1a66de441a9b5409a878f0de5b0
- - mistral-community/pixtral-12b-240910 params.json,tekken.json,consolidated.safetensors 59794e97cb4f322f6223bb0d57b4d7523f0e27c6
+ - mistral-community/pixtral-12b chat_template.json,config.json,generation_config.json,model-00001-of-00006.safetensors,model-00002-of-00006.safetensors,model-00003-of-00006.safetensors,model-00004-of-00006.safetensors,model-00005-of-00006.safetensors,model-00006-of-00006.safetensors,model.safetensors.index.json,preprocessor_config.json,processor_config.json,special_tokens_map.json,tokenizer.json,tokenizer_config.json c2756cbbb9422eba9f6c5c439a214b0392dfc998
  ---
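
Each preload_from_hub entry has the form "<repo_id> <comma-separated filenames> <revision>": the listed files are fetched into the standard Hugging Face cache when the Space builds, which is why app.py can point straight at the snapshot directory. As a sanity check, here is a minimal sketch (assuming the default cache location and the huggingface_hub pin from requirements.txt) that resolves the same snapshot path at runtime:

# Sketch: resolve the preloaded pixtral-12b snapshot the same way app.py
# constructs PIXTRAL_MODEL_PATH. snapshot_download() returns the local
# snapshot directory and only downloads files missing from the cache.
from huggingface_hub import snapshot_download

path = snapshot_download(
    repo_id="mistral-community/pixtral-12b",
    revision="c2756cbbb9422eba9f6c5c439a214b0392dfc998",
)
print(path)
# ~/.cache/huggingface/hub/models--mistral-community--pixtral-12b/snapshots/c2756cbbb9422eba9f6c5c439a214b0392dfc998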
app.py CHANGED
@@ -4,21 +4,14 @@ import gradio as gr
  import spaces
  import torch
  from colpali_engine.models import ColPali, ColPaliProcessor
- from mistral_common.protocol.instruct.messages import (
-     ImageURLChunk,
-     TextChunk,
-     UserMessage,
- )
- from mistral_common.protocol.instruct.request import ChatCompletionRequest
- from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
- from mistral_inference.generate import generate
- from mistral_inference.transformer import Transformer
  from pdf2image import convert_from_path
  from torch.utils.data import DataLoader
  from tqdm import tqdm
+ from transformers import AutoProcessor, LlavaForConditionalGeneration
+

- PIXTAL_MODEL_ID = "mistral-community--pixtral-12b-240910"
- PIXTRAL_MODEL_SNAPSHOT = "59794e97cb4f322f6223bb0d57b4d7523f0e27c6"
+ PIXTAL_MODEL_ID = "mistral-community--pixtral-12b"
+ PIXTRAL_MODEL_SNAPSHOT = "c2756cbbb9422eba9f6c5c439a214b0392dfc998"
  PIXTRAL_MODEL_PATH = (
      pathlib.Path().home()
      / f".cache/huggingface/hub/models--{PIXTAL_MODEL_ID}/snapshots/{PIXTRAL_MODEL_SNAPSHOT}"
@@ -54,32 +47,33 @@ def pixtral_inference(
          raise gr.Error("No images for generation")
      if text == "":
          raise gr.Error("No query for generation")
-     tokenizer = MistralTokenizer.from_file(f"{PIXTRAL_MODEL_PATH}/tekken.json")
-     model = Transformer.from_folder(PIXTRAL_MODEL_PATH, dtype=torch.bfloat16)
-
-     messages = [
-         UserMessage(
-             content=[ImageURLChunk(image_url=image_to_base64(i[0])) for i in images]
-             + [TextChunk(text=text)]
-         )
+
+     model = LlavaForConditionalGeneration.from_pretrained(PIXTRAL_MODEL_PATH)
+     processor = AutoProcessor.from_pretrained(PIXTRAL_MODEL_PATH, use_fast=True)
+
+     chat = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "text", "content": text},
+             ]
+             + [{"type": "image", "url": image_to_base64(i[0])} for i in images],
+         }
      ]

-     completion_request = ChatCompletionRequest(messages=messages)
-
-     encoded = tokenizer.encode_chat_completion(completion_request)
-
-     images = encoded.images
-     tokens = encoded.tokens
-
-     out_tokens, _ = generate(
-         [tokens],
-         model,
-         images=[images],
-         max_tokens=512,
-         temperature=0.45,
-         eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
-     )
-     result = tokenizer.decode(out_tokens[0])
+     inputs = processor.apply_chat_template(
+         chat,
+         add_generation_prompt=True,
+         tokenize=True,
+         return_dict=True,
+         return_tensors="pt",
+     ).to(model.device)
+     generate_ids = model.generate(**inputs, max_new_tokens=500)
+     output = processor.batch_decode(
+         generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+     )[0]
+
+     result = output
      return result

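Two details of the new path are worth calling out: from_pretrained defaults to float32 (the removed Transformer.from_folder call loaded in bfloat16), and decoding the full generate output echoes the prompt back in front of the reply. Below is a hedged sketch of the same transformers flow with both points addressed; the snapshot path mirrors app.py, and the single-turn chat is illustrative only:

import pathlib

import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Same snapshot directory app.py builds from PIXTAL_MODEL_ID / PIXTRAL_MODEL_SNAPSHOT.
PIXTRAL_MODEL_PATH = (
    pathlib.Path.home()
    / ".cache/huggingface/hub/models--mistral-community--pixtral-12b/snapshots/c2756cbbb9422eba9f6c5c439a214b0392dfc998"
)

# Load in bfloat16, matching the dtype the mistral_inference path used.
model = LlavaForConditionalGeneration.from_pretrained(
    PIXTRAL_MODEL_PATH, torch_dtype=torch.bfloat16
)
processor = AutoProcessor.from_pretrained(PIXTRAL_MODEL_PATH, use_fast=True)

# Single-turn request in the chunked content format Pixtral's chat template expects.
chat = [
    {
        "role": "user",
        "content": [{"type": "text", "content": "Describe the first page."}],
    }
]

inputs = processor.apply_chat_template(
    chat,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

generate_ids = model.generate(**inputs, max_new_tokens=500)
# Decode only the newly generated tokens so the prompt is not echoed back.
new_tokens = generate_ids[:, inputs["input_ids"].shape[1] :]
result = processor.batch_decode(new_tokens, skip_special_tokens=True)[0]
print(result)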
requirements.txt CHANGED
@@ -3,6 +3,4 @@ transformers==4.53.2
  huggingface_hub==0.33.4
  pdf2image==1.17.0
  spaces==0.37.1
- colpali_engine==0.3.11
- mistral_inference==1.6.0
- mistral_common[opencv]==1.7.0
+ colpali_engine==0.3.11
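With mistral_inference and mistral_common dropped, the pinned transformers==4.53.2 (visible in the hunk header above) has to carry the whole Pixtral path on its own. A quick smoke test of that assumption:

# Sketch: confirm the pinned transformers build exposes the classes app.py
# now imports, with no mistral_* packages installed.
import transformers
from transformers import AutoProcessor, LlavaForConditionalGeneration  # noqa: F401

assert transformers.__version__ == "4.53.2", transformers.__version__
print("Pixtral via transformers:", transformers.__version__)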