hughtayloe committed on
Commit
6a67ddb
·
verified ·
1 Parent(s): eaeacdb

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -9
handler.py CHANGED
@@ -3,7 +3,8 @@ from PIL import Image
3
  import requests
4
  import torch
5
  import numpy as np
6
- from transformers import AutoProcessor, LlavaForConditionalGeneration
 
7
 
8
  class EndpointHandler():
9
  def __init__(self, path=""):
@@ -12,16 +13,16 @@ class EndpointHandler():
12
  model_id,
13
  torch_dtype=torch.float16,
14
  low_cpu_mem_usage=True,
 
15
  ).to(0)
16
  self.processor = AutoProcessor.from_pretrained(model_id)
17
 
18
  def __call__(self, data: Dict[str, Any]):
19
  parameters = data.pop("inputs", data)
20
- if parameters is not None:
21
- url = "http://images.cocodataset.org/val2017/000000039769.jpg"
22
- prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
23
- raw_image = Image.open(requests.get(url, stream=True).raw)
24
- inputs = self.processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)
25
- output = self.model.generate(**inputs, max_new_tokens=200, do_sample=False)
26
- readable = self.processor.decode(output[0][2:], skip_special_tokens=True)
27
- return readable
 
3
  import requests
4
  import torch
5
  import numpy as np
6
+ from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig
7
+
8
 
9
  class EndpointHandler():
10
  def __init__(self, path=""):
 
13
  model_id,
14
  torch_dtype=torch.float16,
15
  low_cpu_mem_usage=True,
16
+ load_in_4bit=True
17
  ).to(0)
18
  self.processor = AutoProcessor.from_pretrained(model_id)
19
 
20
  def __call__(self, data: Dict[str, Any]):
21
  parameters = data.pop("inputs", data)
22
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
23
+ prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
24
+ raw_image = Image.open(requests.get(url, stream=True).raw)
25
+ inputs = self.processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)
26
+ output = self.model.generate(**inputs, max_new_tokens=200, do_sample=False)
27
+ print(self.processor.decode(output[0][2:], skip_special_tokens=True))
28
+ return output