derektan commited on
Commit
182a98a
·
1 Parent(s): 996aee8

Integrated with hf Zero GPU

Browse files
Files changed (2) hide show
  1. app.py +3 -1
  2. test_parameter.py +1 -1
app.py CHANGED
@@ -21,6 +21,7 @@ import os, glob, threading, time
21
  import torch
22
  from PIL import Image
23
  import json
 
24
 
25
  # Import configuration & RL / TTA utilities -------------------------------------------------
26
  # NOTE: we import * so that the global names (e.g. USE_GPU, MODEL_NAME, etc.)
@@ -137,7 +138,8 @@ def process_no_tta(
137
 
138
 
139
 
140
-
 
141
  def process(
142
  sat_path: str | None,
143
  ground_path: str | None,
 
21
  import torch
22
  from PIL import Image
23
  import json
24
+ import spaces # integration with ZeroGPU on hf
25
 
26
  # Import configuration & RL / TTA utilities -------------------------------------------------
27
  # NOTE: we import * so that the global names (e.g. USE_GPU, MODEL_NAME, etc.)
 
138
 
139
 
140
 
141
+ # integration with ZeroGPU on hf
142
+ @spaces.GPU
143
  def process(
144
  sat_path: str | None,
145
  ground_path: str | None,
test_parameter.py CHANGED
@@ -162,7 +162,7 @@ INPUT_DIM = 4
162
  EMBEDDING_DIM = 128
163
  K_SIZE = 8 # 8
164
 
165
- USE_GPU = False # do you want to use GPUS?
166
  NUM_GPU = getenv("NUM_GPU", default=2, cast_type=int) # the number of GPUs
167
  NUM_META_AGENT = getenv("NUM_META_AGENT", default=4, cast_type=int) # the number of processes
168
  FOLDER_NAME = 'inference'
 
162
  EMBEDDING_DIM = 128
163
  K_SIZE = 8 # 8
164
 
165
+ USE_GPU = True # do you want to use GPUS?
166
  NUM_GPU = getenv("NUM_GPU", default=2, cast_type=int) # the number of GPUs
167
  NUM_META_AGENT = getenv("NUM_META_AGENT", default=4, cast_type=int) # the number of processes
168
  FOLDER_NAME = 'inference'