Dibiddo committed
Commit ef6d057 · verified · 1 Parent(s): 72b520a

Update app.py

Files changed (1)
  1. app.py +38 -11
app.py CHANGED
@@ -1,23 +1,50 @@
-from transformers import QwenVLForConditionalGeneration, AutoProcessor
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+from qwen_vl_utils import process_vision_info
 import gradio as gr
 
 # Load the model and processor
-model = QwenVLForConditionalGeneration.from_pretrained(
-    "Qwen/Qwen2.5-VL-7B-Instruct",
-    torch_dtype="auto",
+model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    "Qwen/Qwen2.5-VL-7B-Instruct",
+    torch_dtype="auto",
     device_map="auto"
 )
 processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
 
-# Define the recognition and analysis function
+# Define the processing function
 def recognize_and_analyze(image, text_prompt):
-    # Process the input image and text prompt
-    inputs = processor(images=image, text=text_prompt, return_tensors="pt").to(model.device)
+    # Wrap the image and text prompt in the chat message format
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "image": image},
+                {"type": "text", "text": text_prompt},
+            ],
+        }
+    ]
 
-    # Generate the output
-    outputs = model.generate(**inputs)
-    result = processor.batch_decode(outputs, skip_special_tokens=True)
-    return result[0]
+    # Prepare the inputs for inference
+    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    image_inputs, video_inputs = process_vision_info(messages)
+    inputs = processor(
+        text=[text],
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
+    )
+    inputs = inputs.to(model.device)
+
+    # Inference: generate the output text, then trim the prompt tokens from each sequence
+    generated_ids = model.generate(**inputs, max_new_tokens=128)
+    generated_ids_trimmed = [
+        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+    output_text = processor.batch_decode(
+        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )
+
+    return output_text[0]
 
 # Set up the Gradio interface
 interface = gr.Interface(
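
The hunk ends at the opening gr.Interface( call, so the interface arguments themselves are outside this diff. For context, here is a minimal sketch of how the updated handler could be wired into Gradio; the component types and labels below are assumptions for illustration, not part of the commit:

# Hypothetical wiring sketch (assumed, not from the commit).
# A PIL image is a reasonable input type here because
# qwen_vl_utils.process_vision_info accepts PIL.Image objects
# in the "image" entry of the message content.
interface = gr.Interface(
    fn=recognize_and_analyze,                 # handler updated in this commit
    inputs=[
        gr.Image(type="pil", label="Image"),  # assumed input component
        gr.Textbox(label="Prompt"),           # assumed prompt component
    ],
    outputs=gr.Textbox(label="Result"),       # assumed output component
)

if __name__ == "__main__":
    interface.launch()

With wiring like this, Gradio passes the uploaded image and the prompt string straight through as the two positional arguments of recognize_and_analyze and displays the returned string.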