remg1997 commited on
Commit
9a1c01d
·
1 Parent(s): 42aabcc
Files changed (1) hide show
  1. handler.py +1 -1
handler.py CHANGED
@@ -8,7 +8,7 @@ class EndpointHandler():
8
  self.path = path
9
  self.model = "remg1997/dynabench-sdxl10"
10
  self.pipeline = DiffusionPipeline.from_pretrained(self.model, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
11
- self.pipeline = self.pipeline.to("cuda")
12
 
13
  def __call__(self, data: Dict[str, Any])-> List[Dict[str, Any]]:
14
  print("Torch version is", torch.__version__)
 
8
  self.path = path
9
  self.model = "remg1997/dynabench-sdxl10"
10
  self.pipeline = DiffusionPipeline.from_pretrained(self.model, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
11
+ self.pipeline = self.pipeline.to("cuda", torch.float16)
12
 
13
  def __call__(self, data: Dict[str, Any])-> List[Dict[str, Any]]:
14
  print("Torch version is", torch.__version__)