NandiniLokeshReddy committed
Commit 761cff5 · verified · 1 Parent(s): 8c0556e

Update app.py

Files changed (1):
  1. app.py +23 -20
app.py CHANGED
@@ -4,28 +4,31 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 from PIL import Image
 import warnings
 
-# Disable warnings
+# Suppress warnings
 warnings.filterwarnings('ignore')
 
-# Set default device to GPU
+# Ensure CUDA device is used
 torch.set_default_device('cuda')
 
 # Load the model and tokenizer
 model_name = 'qnguyen3/nanoLLaVA-1.5'
-model = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    torch_dtype=torch.float16,
-    device_map='auto',
-    trust_remote_code=True
-)
-tokenizer = AutoTokenizer.from_pretrained(
-    model_name,
-    trust_remote_code=True
-)
-
-# Function to generate a description of the uploaded image
+try:
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.float16,
+        device_map='auto',
+        trust_remote_code=True
+    )
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name,
+        trust_remote_code=True
+    )
+except ImportError as e:
+    print("Error: Missing required dependencies. Make sure flash_attn is installed.")
+    raise e
+
+# Function to describe the uploaded image
 def describe_image(image, prompt="Describe this image in detail"):
-    # Prepare input prompt
     messages = [{"role": "user", "content": f'<image>\n{prompt}'}]
     text = tokenizer.apply_chat_template(
         messages,
@@ -33,14 +36,14 @@ def describe_image(image, prompt="Describe this image in detail"):
         add_generation_prompt=True
     )
 
-    # Tokenize input text
+    # Tokenize the text
    text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
    input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)
 
     # Process the image
     image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)
 
-    # Generate response
+    # Generate a response
     output_ids = model.generate(
         input_ids,
         images=image_tensor,
@@ -48,15 +51,15 @@
         use_cache=True
     )[0]
 
-    # Decode the generated text
+    # Decode and return the response
     description = tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
     return description
 
-# Create Gradio interface
+# Set up the Gradio interface
 gr.Interface(
     fn=describe_image,
     inputs=[gr.inputs.Image(type="pil"), gr.inputs.Textbox(default="Describe this image in detail")],
     outputs="text",
     title="Image Description Model",
-    description="Upload an image and get a detailed description generated by the model."
+    description="Upload an image and receive a detailed description."
 ).launch()
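
Note on the new try/except: it reports a missing flash_attn dependency only once from_pretrained has already started importing the remote nanoLLaVA code. A minimal, optional sketch of doing the same check up front with the standard library is below; only the fact that the model load needs flash_attn comes from the commit, while the find_spec probe and the RuntimeError message are illustrative and not part of app.py.

import importlib.util

# Illustrative pre-flight check (not in app.py): fail fast if flash_attn is absent,
# since the remote nanoLLaVA code is expected to import it when the model loads.
if importlib.util.find_spec("flash_attn") is None:
    raise RuntimeError(
        "flash_attn is not installed; add it to the Space's requirements "
        "before loading qnguyen3/nanoLLaVA-1.5."
    )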
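
The interface construction still uses the legacy gr.inputs.Image / gr.inputs.Textbox(default=...) namespace, which newer Gradio releases have removed in favour of top-level components. A minimal sketch of the equivalent call under the current API follows, assuming a recent Gradio version and the describe_image function defined above; it is not part of this commit.

import gradio as gr

# Sketch only: the same interface built with current top-level components
# (gr.Image / gr.Textbox replace gr.inputs.*, and `value` replaces `default`).
gr.Interface(
    fn=describe_image,  # the function defined in app.py above
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(value="Describe this image in detail"),
    ],
    outputs="text",
    title="Image Description Model",
    description="Upload an image and receive a detailed description.",
).launch()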