Spaces: Running on Zero
PySpaces
- .gitignore +1 -0
- spaces/__init__.py +30 -0
- spaces/config.py +37 -0
- spaces/gradio.py +55 -0
- spaces/utils.py +85 -0
- spaces/zero/__init__.py +21 -0
- spaces/zero/api.py +156 -0
- spaces/zero/client.py +239 -0
- spaces/zero/decorator.py +113 -0
- spaces/zero/gradio.py +150 -0
- spaces/zero/torch/__init__.py +42 -0
- spaces/zero/torch/bitsandbytes.py +162 -0
- spaces/zero/torch/packing.py +209 -0
- spaces/zero/torch/patching.py +386 -0
- spaces/zero/torch/patching_legacy.py +266 -0
- spaces/zero/torch/types.py +23 -0
- spaces/zero/tqdm.py +24 -0
- spaces/zero/types.py +49 -0
- spaces/zero/wrappers.py +430 -0
.gitignore
ADDED
@@ -0,0 +1 @@
*.pyc
spaces/__init__.py
ADDED
@@ -0,0 +1,30 @@
"""
"""

import sys


if sys.version_info.minor < 8: # pragma: no cover
    raise RuntimeError("Importing PySpaces requires Python 3.8+")


# Prevent gradio from importing spaces
if (gr := sys.modules.get('gradio')) is not None: # pragma: no cover
    try:
        gr.Blocks
    except AttributeError:
        raise ImportError


from .zero.decorator import GPU
from .gradio import gradio_auto_wrap
from .gradio import disable_gradio_auto_wrap
from .gradio import enable_gradio_auto_wrap


__all__ = [
    'GPU',
    'gradio_auto_wrap',
    'disable_gradio_auto_wrap',
    'enable_gradio_auto_wrap',
]
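For context, a minimal sketch of how this top-level API is consumed from a Space's `app.py`; the function body and Gradio components are illustrative, not part of this commit:

```python
# app.py -- illustrative sketch for a ZeroGPU Space (placeholders, not the package's own demo)
import spaces
import gradio as gr
import torch

@spaces.GPU          # on ZeroGPU, CUDA work must happen inside a decorated function
def generate(prompt: str) -> str:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return f"[{device}] {prompt}"

gr.Interface(generate, gr.Text(), gr.Text()).launch()
```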
spaces/config.py
ADDED
@@ -0,0 +1,37 @@
"""
"""
from __future__ import annotations

import os
from pathlib import Path

from .utils import boolean


ZEROGPU_OFFLOAD_DIR_DEFAULT = str(Path.home() / '.zerogpu' / 'tensors')


class Settings:
    def __init__(self):
        self.zero_gpu = boolean(
            os.getenv('SPACES_ZERO_GPU'))
        self.zero_device_api_url = (
            os.getenv('SPACES_ZERO_DEVICE_API_URL'))
        self.gradio_auto_wrap = boolean(
            os.getenv('SPACES_GRADIO_AUTO_WRAP'))
        self.zero_patch_torch_device = boolean(
            os.getenv('ZERO_GPU_PATCH_TORCH_DEVICE'))
        self.zero_gpu_v2 = boolean(
            os.getenv('ZEROGPU_V2'))
        self.zerogpu_offload_dir = (
            os.getenv('ZEROGPU_OFFLOAD_DIR', ZEROGPU_OFFLOAD_DIR_DEFAULT))


Config = Settings()


if Config.zero_gpu:
    assert Config.zero_device_api_url is not None, (
        'SPACES_ZERO_DEVICE_API_URL env must be set '
        'on ZeroGPU Spaces (identified by SPACES_ZERO_GPU=true)'
    )
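A small illustration of how `Settings` resolves its flags; the environment values shown are hypothetical and must be exported before `spaces` is imported, since `Settings` reads them at import time:

```python
# Illustrative only: env-driven configuration (values are hypothetical)
import os
os.environ.setdefault('SPACES_GRADIO_AUTO_WRAP', 'true')

from spaces.config import Config
print(Config.zero_gpu)             # False unless SPACES_ZERO_GPU is "1"/"t"/"true"
print(Config.gradio_auto_wrap)     # True with the value set above
print(Config.zerogpu_offload_dir)  # ~/.zerogpu/tensors unless ZEROGPU_OFFLOAD_DIR overrides it
```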
spaces/gradio.py
ADDED
@@ -0,0 +1,55 @@
"""
"""
from __future__ import annotations

from typing import Callable
from typing import Generator
from typing import TypeVar
from typing import overload
from typing_extensions import ParamSpec

from .config import Config
from .zero.decorator import GPU


Param = ParamSpec('Param')
Res = TypeVar('Res')


gradio_auto_wrap_enabled = Config.gradio_auto_wrap


def disable_gradio_auto_wrap():
    global gradio_auto_wrap_enabled
    gradio_auto_wrap_enabled = False

def enable_gradio_auto_wrap():
    global gradio_auto_wrap_enabled
    gradio_auto_wrap_enabled = True


@overload
def gradio_auto_wrap(
    task:
        Callable[Param, Res],
) -> Callable[Param, Res]:
    ...
@overload
def gradio_auto_wrap(
    task:
        None,
) -> None:
    ...
def gradio_auto_wrap(
    task:
        Callable[Param, Res]
        | None,
) -> (Callable[Param, Res]
      | None):
    """
    """
    if not gradio_auto_wrap_enabled:
        return task
    if not callable(task):
        return task
    return GPU(task) # type: ignore
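A hedged sketch of the auto-wrap toggle in use; outside a ZeroGPU Space, `GPU()` returns the task unchanged, so this is safe to run anywhere:

```python
# Illustrative only: auto-wrapping is a no-op unless SPACES_GRADIO_AUTO_WRAP is truthy
import spaces

def fn(x):
    return x

wrapped = spaces.gradio_auto_wrap(fn)         # spaces.GPU(fn) when auto-wrap is enabled, else fn
spaces.disable_gradio_auto_wrap()             # later calls now return the task unchanged
assert spaces.gradio_auto_wrap(None) is None  # non-callables always pass through
```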
spaces/utils.py
ADDED
@@ -0,0 +1,85 @@
"""
"""
from __future__ import annotations

import ctypes
import sys
from functools import lru_cache as cache
from functools import partial

import multiprocessing
from multiprocessing.queues import SimpleQueue as _SimpleQueue
from pathlib import Path
from pickle import PicklingError
from typing import Callable
from typing import TypeVar


GRADIO_VERSION_ERROR_MESSAGE = "Make sure Gradio version is at least 3.46"


T = TypeVar('T')


@cache
def self_cgroup_device_path() -> str:
    cgroup_content = Path('/proc/self/cgroup').read_text()
    for line in cgroup_content.strip().split('\n'):
        contents = line.split(':devices:')
        if len(contents) != 2:
            continue # pragma: no cover
        return contents[1]
    raise Exception # pragma: no cover


if sys.version_info.minor < 9: # pragma: no cover
    _SimpleQueue.__class_getitem__ = classmethod(lambda cls, _: cls) # type: ignore

class SimpleQueue(_SimpleQueue[T]):
    def __init__(self, *args):
        super().__init__(*args, ctx=multiprocessing.get_context('fork'))
    def put(self, obj: T):
        try:
            super().put(obj)
        except PicklingError:
            raise # pragma: no cover
        # https://bugs.python.org/issue29187
        except Exception as e:
            message = str(e)
            if not "pickle" in message:
                raise # pragma: no cover
            raise PicklingError(message)
    def close(self): # Python 3.8 static typing trick
        super().close() # type: ignore
    def wlock_release(self):
        if (lock := getattr(self, '_wlock', None)) is None:
            return # pragma: no cover
        try:
            lock.release()
        except ValueError:
            pass


def drop_params(fn: Callable[[], T]) -> Callable[..., T]:
    def drop(*args):
        return fn()
    return drop


def boolean(value: str | None) -> bool:
    return value is not None and value.lower() in ("1", "t", "true")


def gradio_request_var():
    try:
        from gradio.context import LocalContext
    except ImportError: # pragma: no cover
        raise RuntimeError(GRADIO_VERSION_ERROR_MESSAGE)
    return LocalContext.request


def malloc_trim():
    ctypes.CDLL("libc.so.6").malloc_trim(0)


debug = partial(print, 'SPACES_ZERO_GPU_DEBUG')
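Two of these helpers are easy to exercise in isolation; a short illustrative snippet:

```python
# Illustrative only: small helpers from spaces/utils.py
from spaces.utils import boolean, drop_params

assert boolean("TRUE") and boolean("1") and not boolean(None) and not boolean("0")

ping = drop_params(lambda: "pong")   # accepts (and ignores) any positional arguments
assert ping(1, 2, 3) == "pong"
```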
spaces/zero/__init__.py
ADDED
@@ -0,0 +1,21 @@
"""
"""

from pathlib import Path

from ..config import Config


if Config.zero_gpu:

    from . import gradio
    from . import torch

    if torch.is_in_bad_fork():
        raise RuntimeError(
            "CUDA has been initialized before importing the `spaces` package"
        )

    torch.patch()
    gradio.one_launch(torch.pack)
    Path(Config.zerogpu_offload_dir).mkdir(parents=True, exist_ok=True)
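An import-order note implied by the `is_in_bad_fork` guard above, shown as an illustrative `app.py` fragment:

```python
# Illustrative only: on ZeroGPU, `spaces` must be imported before anything initializes CUDA,
# otherwise the is_in_bad_fork() check above raises at import time.
import spaces   # first
import torch    # then torch and the rest of the model stack
```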
spaces/zero/api.py
ADDED
@@ -0,0 +1,156 @@
"""
Synced with huggingface/pyspaces:spaces/zero/api.py
"""
from __future__ import annotations

from datetime import timedelta
from typing import Any
from typing import Generator
from typing import Literal
from typing import NamedTuple
from typing import Optional
from typing import overload

import httpx
from pydantic import BaseModel
from typing_extensions import assert_never


AllowToken = str
NvidiaIndex = int # TODO: Migrate to GpuIndex (less confusing for MIG)
NvidiaUUID = str
CGroupPath = str
VisitorId = str
Score = float

AuthLevel = Literal['regular', 'pro']


AUTHENTICATED_HEADER = 'X-Authenticated'


class ScheduleResponse(BaseModel):
    idle: bool
    nvidiaIndex: int
    nvidiaUUID: str
    allowToken: str


class QuotaInfos(BaseModel):
    left: int
    wait: timedelta


class ReportUsageMonitoringParams(NamedTuple):
    nvidia_index: int
    visitor_id: str
    duration: timedelta


class QueueEvent(BaseModel):
    event: Literal['ping', 'failed', 'succeeded']
    data: Optional[ScheduleResponse] = None


def sse_parse(text: str):
    event, *data = text.strip().splitlines()
    assert event.startswith('event:')
    event = event[6:].strip()
    if event in ('ping', 'failed'):
        return QueueEvent(event=event)
    assert event == 'succeeded'
    (data,) = data
    assert data.startswith('data:')
    data = data[5:].strip()
    return QueueEvent(event=event, data=ScheduleResponse.parse_raw(data))


def sse_stream(res: httpx.Response) -> Generator[QueueEvent, Any, None]:
    for text in res.iter_text():
        if len(text) == 0:
            break # pragma: no cover
        try:
            yield sse_parse(text)
        except GeneratorExit:
            res.close()
            break


class APIClient:

    def __init__(self, client: httpx.Client):
        self.client = client

    def startup_report(self) -> httpx.codes:
        res = self.client.post('/startup-report')
        return httpx.codes(res.status_code)

    def schedule(
        self,
        cgroup_path: str,
        task_id: int = 0,
        token: str | None = None,
        duration_seconds: int | None = None,
        enable_queue: bool = True,
    ):
        params: dict[str, str | int | bool] = {
            'cgroupPath': cgroup_path,
            'taskId': task_id,
            'enableQueue': enable_queue,
        }
        if duration_seconds is not None:
            params['durationSeconds'] = duration_seconds
        if token is not None:
            params['token'] = token
        res = self.client.send(
            request=self.client.build_request(
                method='POST',
                url='/schedule',
                params=params,
            ),
            stream=True,
        )
        status = httpx.codes(res.status_code)
        auth: AuthLevel | None = res.headers.get(AUTHENTICATED_HEADER)
        if (status is not httpx.codes.OK and
            status is not httpx.codes.TOO_MANY_REQUESTS
        ):
            res.close()
            return status, auth
        if "text/event-stream" in res.headers['content-type']:
            return sse_stream(res), auth
        res.read()
        if status is httpx.codes.TOO_MANY_REQUESTS:
            return QuotaInfos(**res.json()), auth # pragma: no cover
        if status is httpx.codes.OK:
            return ScheduleResponse(**res.json()), auth
        assert_never(status)

    def allow(
        self,
        allow_token: str,
        pid: int,
    ):
        res = self.client.post('/allow', params={
            'allowToken': allow_token,
            'pid': pid,
        })
        return httpx.codes(res.status_code)

    def release(
        self,
        allow_token: str,
        fail: bool = False,
    ) -> httpx.codes:
        res = self.client.post('/release', params={
            'allowToken': allow_token,
            'fail': fail,
        })
        return httpx.codes(res.status_code)

    def get_queue_size(self) -> int:
        res = self.client.get('/queue-size')
        assert res.status_code == 200, res.status_code
        size = res.json()
        assert isinstance(size, int)
        return size
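A small illustration of `sse_parse` on hand-written SSE frames; the payload values below are made up:

```python
# Illustrative only: mapping the device API's SSE frames to QueueEvent
from spaces.zero.api import sse_parse

ping = sse_parse("event: ping\n")
assert ping.event == 'ping' and ping.data is None

ok = sse_parse(
    'event: succeeded\n'
    'data: {"idle": true, "nvidiaIndex": 0, "nvidiaUUID": "GPU-xxxx", "allowToken": "abc"}\n'
)
assert ok.data is not None and ok.data.nvidiaIndex == 0
```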
spaces/zero/client.py
ADDED
@@ -0,0 +1,239 @@
"""
"""
from __future__ import annotations

import os
import time
import warnings
from datetime import timedelta

import gradio as gr
import httpx
from packaging import version
from typing_extensions import assert_never

from .. import utils
from ..config import Config
from .api import APIClient
from .api import AuthLevel
from .api import QuotaInfos
from .api import ScheduleResponse
from .gradio import HTMLError
from .gradio import get_event
from .gradio import supports_auth


TOKEN_HEADER = 'X-IP-Token'
DEFAULT_SCHEDULE_DURATION = 60

QUOTA_MESSAGE = "You have exceeded your GPU quota"
UNUSED_MESSAGE = "GPU device not used"
NO_GPU_MESSAGE_REGULAR = "No GPU was available"
NO_GPU_MESSAGE_INQUEUE = "No GPU was available after 60s"

SIGNUP_ON_HF_TXT = "Create a free account"
SIGNUP_ON_HF_URL = "https://huggingface.co/join"
SUBSCRIBE_TO_PRO_TXT = "Subscribe to Pro"
SUBSCRIBE_TO_PRO_URL = "https://huggingface.co/settings/billing/subscription"


def api_client():
    assert Config.zero_device_api_url is not None
    httpx_client = httpx.Client(base_url=Config.zero_device_api_url, timeout=60, verify=False)
    return APIClient(httpx_client)


def startup_report():
    retries, max_retries = 0, 2
    client = api_client()
    while (status := client.startup_report()) is httpx.codes.NOT_FOUND: # pragma: no cover
        time.sleep(1)
        if (retries := retries + 1) > max_retries:
            raise RuntimeError("Error while initializing ZeroGPU: NotFound")
    if status is not httpx.codes.OK: # pragma: no cover
        raise RuntimeError("Error while initializing ZeroGPU: Unknown")


def html_string(html_contents: str, text_contents: str): # pragma: no cover
    class HTMLString(str):
        def __str__(self):
            return text_contents
    return HTMLString(html_contents)


def _toast_action(
    auth: AuthLevel | None,
    supports_html: bool,
    pro_message: str,
    unlogged_desc: str,
    logged_desc: str,
    ending: str,
) -> tuple[str, str]: # pragma: no cover
    if not supports_auth() or auth == 'pro':
        return pro_message, pro_message
    html = ""
    link = SIGNUP_ON_HF_URL if auth is None else SUBSCRIBE_TO_PRO_URL
    text = SIGNUP_ON_HF_TXT if auth is None else SUBSCRIBE_TO_PRO_TXT
    desc = unlogged_desc if auth is None else logged_desc
    desc += f" {ending}."
    style = ";".join([
        "white-space: nowrap",
        "text-underline-offset: 2px",
        "color: var(--body-text-color)",
    ])
    if supports_html:
        html += f'<a style="{style}" href="{link}">'
    html += text
    if supports_html:
        html += '</a>'
    html += f" {desc}"
    markdown = f'[{text}]({link}) {desc}'
    return html, markdown


def schedule(
    task_id: int,
    request: gr.Request | None = None,
    duration: timedelta | None = None,
    _first_attempt: bool = True,
) -> ScheduleResponse:

    if not (gradio_version := version.parse(gr.__version__)).major >= 4: # pragma: no cover
        raise RuntimeError("ZeroGPU is only compatible with Gradio 4+")

    GRADIO_HTML_TOASTS = gradio_version >= version.Version('4.39')

    res, auth = api_client().schedule(
        cgroup_path=utils.self_cgroup_device_path(),
        task_id=task_id,
        token=_get_token(request),
        duration_seconds=duration.seconds if duration is not None else None,
    )

    if isinstance(res, ScheduleResponse):
        return res

    if isinstance(res, QuotaInfos): # pragma: no cover
        requested = duration.seconds if duration is not None else DEFAULT_SCHEDULE_DURATION
        if res.wait < timedelta(0):
            raise gr.Error(
                f"The requested GPU duration ({requested}s) "
                f"is larger than the maximum allowed"
            )
        else:
            gpu = "Pro GPU" if auth == 'pro' else ("free GPU" if auth == 'regular' else "GPU")
            message = (
                f"You have exceeded your {gpu} quota "
                f"({requested}s requested vs. {res.left}s left)."
            )
            details_html, details_markdown = _toast_action(
                auth=auth,
                supports_html=GRADIO_HTML_TOASTS,
                pro_message=f"Try again in {res.wait}",
                unlogged_desc="to get more",
                logged_desc="to get 5x more",
                ending="usage quota",
            )
            message_html = f"{message} {details_html}"
            message_text = f"{message} {details_markdown}"
            raise HTMLError(html_string(message_html, message_text))

    if not isinstance(res, httpx.codes): # pragma: no cover
        gr.Info("Waiting for a GPU to become available")
        # TODO: Sign-up message if not authenticated (after some time ?)
        connection_event = get_event()
        if connection_event is None and request is not None:
            warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance")
        while True:
            try:
                event = next(res)
            except StopIteration:
                raise RuntimeError("Unexpected end of stream")
            except httpx.RemoteProtocolError:
                if not _first_attempt:
                    raise RuntimeError("Error while re-trying after queue disconnect")
                return schedule(task_id, request, duration, _first_attempt=False)
            if event.event == 'ping':
                if connection_event is not None and not connection_event.alive:
                    res.close()
                    raise RuntimeError("Connection closed by visitor while queueing")
                continue
            if event.event == 'failed':
                details_html, details_markdown = _toast_action(
                    auth=auth,
                    supports_html=GRADIO_HTML_TOASTS,
                    pro_message="Retry later",
                    unlogged_desc="to get a higher",
                    logged_desc="to get the highest",
                    ending="priority in ZeroGPU queues",
                )
                message_html = f"{NO_GPU_MESSAGE_INQUEUE}. {details_html}"
                message_text = f"{NO_GPU_MESSAGE_INQUEUE} {details_markdown}"
                raise HTMLError(html_string(message_html, message_text))
            if event.event == 'succeeded':
                assert event.data is not None
                if connection_event is not None and not connection_event.alive:
                    release(event.data.allowToken)
                    raise RuntimeError("Connection closed by visitor on queue success")
                gr.Info("Successfully acquired a GPU")
                return event.data

    if res is httpx.codes.SERVICE_UNAVAILABLE:
        raise gr.Error(NO_GPU_MESSAGE_REGULAR)

    # TODO: Find a way to log 'detail' response field
    raise RuntimeError(f"ZeroGPU API /schedule error: {res} ({httpx.codes.get_reason_phrase(res)})") # pragma: no cover


def allow(allow_token: str) -> None:
    pid = os.getpid()
    assert pid != 1, "Allowing PID 1 on ZeroGPU will end up killing your Space"
    assert api_client().allow(allow_token=allow_token, pid=pid) is httpx.codes.OK


def release(
    allow_token: str, *,
    fail: bool = False,
    allow_404: bool = False,
) -> None:

    res = api_client().release(
        allow_token=allow_token,
        fail=fail,
    )

    if res is httpx.codes.NO_CONTENT: # pragma: no cover
        try:
            gr.Warning(UNUSED_MESSAGE)
        except AttributeError:
            pass
        warnings.warn(UNUSED_MESSAGE, RuntimeWarning)
        return None

    if res is httpx.codes.NOT_FOUND:
        if not allow_404:
            warnings.warn("ZeroGPU API /release warning: 404 Not Found")
        return None

    if httpx.codes.is_success(res):
        return None

    # TODO: Find a way to log 'detail' response field
    # TODO: Only raise in dev environment. Simply warn in production ?
    raise RuntimeError(f"ZeroGPU API /release error: {res} ({httpx.codes.get_reason_phrase(res)})") # pragma: no cover


def _get_token(request: gr.Request | None) -> str | None:

    if request is None:
        return None

    headers = getattr(request, 'headers', None)
    if headers is None or not hasattr(headers, '__dict__'):
        raise gr.Error("Internal Gradio error")

    # Compatibility trick
    if not hasattr(headers, 'get'):
        headers = headers.__dict__ # pragma: no cover

    return headers.get(TOKEN_HEADER.lower())
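A rough, illustrative sketch of the schedule/allow/release lifecycle these helpers implement (the wrappers module drives this in practice; it is not a public API and only works inside a configured ZeroGPU Space):

```python
# Illustrative only: manual use of the client helpers (normally done by spaces/zero/wrappers.py)
from spaces.zero import client

res = client.schedule(task_id=0)            # blocks (possibly queueing) until a GPU slot is granted
client.allow(res.allowToken)                # allow the current PID on the attached device
try:
    pass                                    # CUDA work would happen here
finally:
    client.release(res.allowToken, allow_404=True)
```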
spaces/zero/decorator.py
ADDED
@@ -0,0 +1,113 @@
"""
"""
from __future__ import annotations

import inspect
import sys
import warnings
from datetime import timedelta
from functools import partial
from typing import Callable
from typing import TypeVar
from typing import overload
from typing_extensions import ParamSpec
from typing_extensions import Unpack

from ..config import Config
from .types import DynamicDuration
from .types import EmptyKwargs


P = ParamSpec('P')
R = TypeVar('R')


decorated_cache: dict[Callable, Callable] = {}


@overload
def GPU(
    task: None = None, *,
    duration: DynamicDuration[P] = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    ...
@overload
def GPU(
    task: Callable[P, R], *,
    duration: DynamicDuration[P] = None,
) -> Callable[P, R]:
    ...
def GPU(
    task: Callable[P, R] | None = None, *,
    duration: DynamicDuration[P] = None,
    **kwargs: Unpack[EmptyKwargs],
) -> Callable[[Callable[P, R]], Callable[P, R]] | Callable[P, R]:
    """
    ZeroGPU decorator

    Basic usage:
    ```
    @spaces.GPU
    def fn(...):
        # CUDA is available here
        pass
    ```

    With custom duration:
    ```
    @spaces.GPU(duration=45) # Expressed in seconds
    def fn(...):
        # CUDA is available here
        pass
    ```

    Args:
        task (`Callable | None`): Python function that requires CUDA
        duration (`int | datetime.timedelta`): Estimated duration in seconds or `datetime.timedelta`

    Returns:
        `Callable`: GPU-ready function
    """
    if "enable_queue" in kwargs:
        warnings.warn("`enable_queue` parameter is now ignored and always set to `True`")
    if task is None:
        return partial(_GPU, duration=duration)
    return _GPU(task, duration)


def _GPU(
    task: Callable[P, R],
    duration: DynamicDuration[P],
) -> Callable[P, R]:

    if not Config.zero_gpu:
        return task

    from . import client
    from .wrappers import regular_function_wrapper
    from .wrappers import generator_function_wrapper

    if sys.version_info.minor < 9: # pragma: no cover
        raise RuntimeError("Actually using @spaces.GPU on a ZeroGPU Space requires Python 3.9+")

    if task in decorated_cache:
        # TODO: Assert same duration ?
        return decorated_cache[task] # type: ignore

    if inspect.iscoroutinefunction(task):
        raise NotImplementedError

    if inspect.isgeneratorfunction(task):
        decorated = generator_function_wrapper(task, duration)
    else:
        decorated = regular_function_wrapper(task, duration)

    setattr(decorated, 'zerogpu', None)

    client.startup_report()
    decorated_cache.update({
        task: decorated,
        decorated: decorated,
    })

    return decorated # type: ignore
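One more hedged usage variant beyond the docstring examples: per the docstring, `duration` also accepts a `datetime.timedelta`:

```python
# Illustrative only: duration expressed as a timedelta instead of seconds
from datetime import timedelta
import spaces

@spaces.GPU(duration=timedelta(minutes=2))
def heavy(prompt: str):
    ...  # CUDA is available here for the requested duration
```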
spaces/zero/gradio.py
ADDED
@@ -0,0 +1,150 @@
"""
"""
from __future__ import annotations

from functools import wraps
from packaging import version
from typing import Callable
from typing import NamedTuple
from typing import TYPE_CHECKING
import warnings

import gradio as gr
from gradio.context import Context
from gradio.context import LocalContext
from gradio.helpers import Progress
from gradio.helpers import TrackedIterable
from gradio.queueing import Queue
from typing_extensions import ParamSpec

from ..utils import SimpleQueue
from .types import GeneratorResQueueResult
from .types import GradioQueueEvent
from .types import RegularResQueueResult


QUEUE_RPC_METHODS = [
    "set_progress",
    "log_message",
]


class GradioPartialContext(NamedTuple):
    event_id: str | None
    in_event_listener: bool
    progress: Progress | None

    @staticmethod
    def get():
        TrackedIterable.__reduce__ = tracked_iterable__reduce__
        return GradioPartialContext(
            event_id=LocalContext.event_id.get(),
            in_event_listener=LocalContext.in_event_listener.get(),
            progress=LocalContext.progress.get(),
        )

    @staticmethod
    def apply(context: 'GradioPartialContext'):
        LocalContext.event_id.set(context.event_id)
        LocalContext.in_event_listener.set(context.in_event_listener)
        LocalContext.progress.set(context.progress)


def get_queue_instance():
    blocks = LocalContext.blocks.get()
    if blocks is None: # pragma: no cover
        return None
    return blocks._queue


def get_event():
    queue = get_queue_instance()
    event_id = LocalContext.event_id.get()
    if queue is None:
        return None
    if event_id is None: # pragma: no cover
        return None
    for job in queue.active_jobs:
        if job is None: # pragma: no cover
            continue
        for event in job:
            if event._id == event_id:
                return event


def get_server_port() -> int | None:
    from_request_context = True
    if (blocks := LocalContext.blocks.get()) is None: # Request
        from_request_context = False
        if (blocks := Context.root_block) is None: # Caching
            return None
    if (server := getattr(blocks, 'server', None)) is None: # pragma: no cover (Gradio 4)
        if from_request_context:
            warnings.warn("Gradio: No blocks.server inside a request") # pragma: no cover
        return -1
    if TYPE_CHECKING:
        assert (server := blocks.server)
    return server.config.port


def try_process_queue_event(method_name: str, *args, **kwargs):
    queue = get_queue_instance()
    if queue is None: # pragma: no cover
        warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance")
        return
    method = getattr(queue, method_name, None)
    assert callable(method)
    method(*args, **kwargs)


def patch_gradio_queue(
    res_queue: SimpleQueue[RegularResQueueResult | None] | SimpleQueue[GeneratorResQueueResult | None],
):

    def rpc_method(method_name: str):
        def method(*args, **kwargs):
            if args and isinstance(args[0], Queue):
                args = args[1:] # drop `self`
            res_queue.put(GradioQueueEvent(method_name, args, kwargs))
        return method

    for method_name in QUEUE_RPC_METHODS:
        if (method := getattr(Queue, method_name, None)) is None: # pragma: no cover
            warnings.warn(f"ZeroGPU: Gradio Queue has no {method_name} attribute")
            continue
        if not callable(method): # pragma: no cover
            warnings.warn(f"ZeroGPU: Gradio Queue {method_name} is not callable")
            continue
        setattr(Queue, method_name, rpc_method(method_name))

    TrackedIterable.__reduce__ = tracked_iterable__reduce__


def tracked_iterable__reduce__(self):
    res: tuple = super(TrackedIterable, self).__reduce__() # type: ignore
    cls, base, state, *_ = res
    return cls, base, {**state, **{
        'iterable': None,
        '_tqdm': None,
    }}


def supports_auth():
    return version.parse(gr.__version__) >= version.Version('4.27.0')


Param = ParamSpec('Param')

def one_launch(task: Callable[Param, None], *task_args: Param.args, **task_kwargs: Param.kwargs):
    _launch = gr.Blocks.launch
    @wraps(gr.Blocks.launch)
    def launch(*args, **kwargs):
        task(*task_args, **task_kwargs)
        gr.Blocks.launch = _launch
        return gr.Blocks.launch(*args, **kwargs)
    gr.Blocks.launch = launch


class HTMLError(gr.Error):
    def __str__(self): # pragma: no cover
        return self.message
spaces/zero/torch/__init__.py
ADDED
@@ -0,0 +1,42 @@
"""
"""

from ...config import Config


try:

    import torch

except ImportError:

    _patch = lambda *args, **kwargs: None
    _unpatch = lambda *args, **kwargs: None
    _pack = lambda *args, **kwargs: None
    _init = lambda *args, **kwargs: None
    _size = lambda *args, **kwargs: 0
    _move = lambda *args, **kwargs: None
    _is_in_bad_fork = lambda *args, **kwargs: False

else:

    if Config.zero_gpu_v2:
        from . import patching as _patching
    else: # pragma: no cover
        from . import patching_legacy as _patching

    _patch = _patching.patch
    _unpatch = _patching.unpatch
    _pack = _patching.pack
    _init = _patching.init
    _size = _patching.size
    _move = _patching.move
    _is_in_bad_fork = _patching.is_in_bad_fork

patch = _patch
unpatch = _unpatch
pack = _pack
init = _init
size = _size
move = _move
is_in_bad_fork = _is_in_bad_fork
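An illustrative consequence of the fallback lambdas above: the facade can be probed even when torch is not installed:

```python
# Illustrative only: the facade degrades to safe defaults without torch
from spaces.zero import torch as zero_torch

print(zero_torch.is_in_bad_fork())  # False when torch is absent or CUDA was never initialized before a fork
```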
spaces/zero/torch/bitsandbytes.py
ADDED
@@ -0,0 +1,162 @@
"""
"""
# pyright: reportPrivateImportUsage=false

from __future__ import annotations

import importlib
from contextlib import contextmanager
from importlib import metadata
from types import ModuleType
from typing import TYPE_CHECKING
from typing import Tuple

import torch
from packaging import version

if TYPE_CHECKING:
    import torch as Torch


@contextmanager
def cuda_unavailable(torch: ModuleType):
    _is_available = torch.cuda.is_available
    torch.cuda.is_available = lambda: False
    yield
    torch.cuda.is_available = _is_available


def maybe_import_bitsandbytes():
    try:
        import torch
    except ImportError: # pragma: no cover
        return None
    with cuda_unavailable(torch):
        try:
            import bitsandbytes
        except ImportError:
            bitsandbytes = None
        else:
            if (bnb_version := version.parse(metadata.version('bitsandbytes'))) < version.parse('0.40.0'):
                raise RuntimeError(f"ZeroGPU requires bitsandbytes >= 0.40.0 (installed: {bnb_version})") # pragma: no cover
            print("↑ Those bitsandbytes warnings are expected on ZeroGPU ↑")
    return bitsandbytes


if (bnb := maybe_import_bitsandbytes()):

    from torch.utils.weak import WeakTensorKeyDictionary

    with cuda_unavailable(torch):
        from bitsandbytes import cextension
        from bitsandbytes import functional
        try: # bitsandbytes < 0.44
            from bitsandbytes.cuda_setup.main import CUDASetup
        except ModuleNotFoundError: # pragma: no cover
            CUDASetup = None
        from bitsandbytes.nn import Int8Params
        from bitsandbytes.nn import Params4bit

    _param_to_8bit = Int8Params.to # type: ignore
    _param_cuda_8bit = Int8Params.cuda
    _param_to_4bit = Params4bit.to # type: ignore
    _param_cuda_4bit = Params4bit.cuda

    TensorToArgs = Tuple[torch.device, torch.dtype, bool, torch.memory_format]

    to_ops_8bit: dict[Int8Params, TensorToArgs | None] = WeakTensorKeyDictionary() # type: ignore
    to_ops_4bit: dict[Params4bit, TensorToArgs | None] = WeakTensorKeyDictionary() # type: ignore

    def _to_op_register_8bit(self: Int8Params, *args, **kwargs):
        parsed = torch._C._nn._parse_to(*args, **kwargs)
        device, *_ = parsed
        if not isinstance(device, torch.device): # pragma: no cover
            return _param_to_8bit(self, *args, **kwargs)
        if device.type != 'cuda':
            return _param_to_8bit(self, *args, **kwargs)
        to_ops_8bit[self] = parsed
        return self

    def _to_op_register_4bit(self: Params4bit, *args, **kwargs):
        parsed = torch._C._nn._parse_to(*args, **kwargs)
        device, *_ = parsed
        if not isinstance(device, torch.device): # pragma: no cover
            return _param_to_4bit(self, *args, **kwargs)
        if device.type != 'cuda':
            return _param_to_4bit(self, *args, **kwargs)
        to_ops_4bit[self] = parsed
        return self

    def _cuda_op_arg_check(device: Torch.device | int | str | None) -> bool:
        if device is None: # pragma: no cover
            return True
        if isinstance(device, int):
            return True
        if isinstance(device, str): # pragma: no cover
            device = torch.device(device)
        return device.type == 'cuda' # pragma: no cover

    def _cuda_op_register_8bit(self: Int8Params, device: Torch.device | int | str | None = None, **kwargs):
        if not _cuda_op_arg_check(device): # pragma: no cover
            # Let PyTorch handle the fail
            return _param_cuda_8bit(self, device, **kwargs)
        to_ops_8bit[self] = None
        return self

    def _cuda_op_register_4bit(self: Params4bit, device: Torch.device | int | str | None = None, **kwargs):
        if not _cuda_op_arg_check(device): # pragma: no cover
            # Let PyTorch handle the fail
            return _param_cuda_4bit(self, device, **kwargs)
        to_ops_4bit[self] = None
        return self

    def _patch():
        Int8Params.to = _to_op_register_8bit # type: ignore
        Int8Params.cuda = _cuda_op_register_8bit # type: ignore
        Params4bit.to = _to_op_register_4bit # type: ignore
        Params4bit.cuda = _cuda_op_register_4bit # type: ignore

    def _unpatch():
        Int8Params.to = _param_to_8bit # type: ignore
        Int8Params.cuda = _param_cuda_8bit
        Params4bit.to = _param_to_4bit # type: ignore
        Params4bit.cuda = _param_cuda_4bit

    def _move():
        if CUDASetup is not None:
            CUDASetup._instance = None
        importlib.reload(cextension)
        functional.lib = cextension.lib
        for op in to_ops_8bit.items():
            tensor, parsed_args = op
            if parsed_args:
                _, dtype, _, memory_format = parsed_args
            else:
                dtype, memory_format = None, None
            tensor.data = _param_to_8bit(tensor,
                device='cuda',
                dtype=dtype,
                memory_format=memory_format,
            ) # type: ignore
        for op in to_ops_4bit.items():
            tensor, parsed_args = op
            if parsed_args:
                _, dtype, _, memory_format = parsed_args
            else:
                dtype, memory_format = None, None
            tensor.data = _param_to_4bit(tensor,
                device='cuda',
                dtype=dtype,
                memory_format=memory_format,
            ) # type: ignore

else:

    _patch = lambda: None
    _unpatch = lambda: None
    _move = lambda: None


patch = _patch
unpatch = _unpatch
move = _move
spaces/zero/torch/packing.py
ADDED
@@ -0,0 +1,209 @@
"""
"""
from __future__ import annotations

import time

import ctypes
import os
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from contextvars import copy_context
from dataclasses import dataclass
from queue import Queue
from typing import Callable

from ...utils import debug

import torch
from typing_extensions import TypeAlias


PAGE_SIZE = 4096
TOTAL_MEMORY = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
VM_MAX_SIZE = min(2**38, TOTAL_MEMORY // 2)

BUFFER_SIZE = 64 * 2**20
BUFFER_COUNT = 2


TensorWithSizes: TypeAlias = 'tuple[torch.Tensor, int, int]'

@dataclass
class ZeroGPUTensorPack:
    base_dir: str
    batches: list[list[TensorWithSizes]]
    big_tensors: list[TensorWithSizes]
    fakes: dict[torch.Tensor, list[torch.Tensor]]
    total_size: int
    def path(self):
        return f'{self.base_dir}/{id(self)}'
    def __del__(self):
        try:
            os.remove(self.path())
        except FileNotFoundError: # pragma: no cover
            pass


def write(fd: int, tensor: torch.Tensor):
    clone = torch.empty_like(tensor)
    size = clone.untyped_storage().size() # pyright: ignore [reportAttributeAccessIssue]
    buffer = torch.UntypedStorage(VM_MAX_SIZE)
    buffer_ptr = buffer.data_ptr()
    offset = -buffer_ptr % PAGE_SIZE
    padding = -size % PAGE_SIZE
    clone.set_(buffer[offset:offset+size], 0, clone.shape, clone.stride()) # pyright: ignore [reportArgumentType]
    clone.copy_(tensor)
    mv = memoryview((ctypes.c_char * (size+padding)).from_address(buffer_ptr+offset))
    written_bytes = 0
    while written_bytes < size:
        written_bytes += os.write(fd, mv[written_bytes:])


def pack_tensors(
    tensors: set[torch.Tensor],
    fakes: dict[torch.Tensor, list[torch.Tensor]],
    offload_dir: str,
    callback: Callable[[int]] | None = None,
):

    callback = (lambda bytes: None) if callback is None else callback

    batches: list[list[TensorWithSizes]] = []
    big_tensors: list[TensorWithSizes] = []

    tensors_with_sizes: list[tuple[torch.Tensor, int, int]] = []
    for tensor in tensors:
        size = tensor.numel() * tensor.element_size()
        aligned_size = size + (-size % PAGE_SIZE)
        tensors_with_sizes += [(tensor, size, aligned_size)]

    current_batch, current_size = [], 0
    for (tensor, size, aligned_size) in sorted(tensors_with_sizes, key=lambda item: item[2]):
        if aligned_size > BUFFER_SIZE:
            big_tensors += [(tensor, size, aligned_size)]
            continue
        current_size += aligned_size
        if current_size > BUFFER_SIZE:
            batches += [current_batch]
            current_batch, current_size = [(tensor, size, aligned_size)], aligned_size
        else:
            current_batch += [(tensor, size, aligned_size)]

    if current_batch:
        batches += [current_batch]

    get_meta = {tensor: torch.empty_like(tensor) for tensor in tensors}
    batches_meta = [[(get_meta[tensor], size, asize) for tensor, size, asize in batch] for batch in batches]
    big_tensors_meta = [(get_meta[tensor], size, asize) for tensor, size, asize in big_tensors]
    fakes_meta = {get_meta[tensor]: fake_list for tensor, fake_list in fakes.items()}

    pack = ZeroGPUTensorPack(
        base_dir=offload_dir,
        batches=batches_meta,
        big_tensors=big_tensors_meta,
        fakes=fakes_meta,
        total_size=sum([size for _, size, _ in tensors_with_sizes]),
    )

    fd = os.open(pack.path(), os.O_CREAT | os.O_WRONLY | os.O_DIRECT)
    try:
        total_asize = sum([aligned_size for batch in batches for *_, aligned_size in batch])
        total_asize += sum([aligned_size for *_, aligned_size in big_tensors])
        if total_asize > 0:
            os.posix_fallocate(fd, 0, total_asize)
        for batch in batches:
            for tensor, size, _ in batch:
                write(fd, tensor)
                callback(size)
        for tensor, size, _ in big_tensors:
            write(fd, tensor)
            callback(size)
        return pack
    finally:
        os.close(fd)


def pack_to_cuda(pack: ZeroGPUTensorPack, callback: Callable[[int]] | None = None):

    callback = (lambda bytes: None) if callback is None else callback

    free_buffers: Queue[torch.Tensor] = Queue()
    read_buffers: Queue[torch.Tensor] = Queue()

    for _ in range(BUFFER_COUNT):
        free_buffers.put(torch.ByteTensor(BUFFER_SIZE).pin_memory())

    def read(fd: int, buffer: torch.Tensor, size: int):
        mv = memoryview((ctypes.c_char * size).from_address(buffer.data_ptr()))
        read_bytes = 0
        while read_bytes < size:
            read_bytes += os.readv(fd, [mv[read_bytes:]])

    def disk_to_pin(fd: int):
        for batch in pack.batches:
            buffer = free_buffers.get()
            batch_size = sum([aligned_size for *_, aligned_size in batch])
            read(fd, buffer, batch_size)
            read_buffers.put(buffer)
        for *_, aligned_size in pack.big_tensors:
            read_bytes = 0
            while read_bytes < aligned_size:
                buffer = free_buffers.get()
                read_size = min(BUFFER_SIZE, aligned_size - read_bytes)
                read(fd, buffer, read_size)
                read_buffers.put(buffer)
                read_bytes += read_size

    def pin_to_cuda():
        total_duration_in_callback = 0
        for batch in pack.batches:
            buffer = read_buffers.get()
            offset = 0
            cuda_storages = []
            for tensor, size, aligned_size in batch:
                cuda_storages += [buffer[offset:offset+size].cuda(non_blocking=True)]
                offset += aligned_size
            torch.cuda.synchronize()
            free_buffers.put(buffer)
            batch_total_size = 0
            for (tensor, size, _), cuda_storage in zip(batch, cuda_storages):
                cuda_tensor = torch.tensor([], dtype=tensor.dtype, device='cuda')
                cuda_tensor = cuda_tensor.set_(cuda_storage.untyped_storage(), 0, tensor.shape, tensor.stride())
                for fake in pack.fakes[tensor]:
                    fake.data = cuda_tensor
                batch_total_size += size
            t0 = time.perf_counter()
            callback(batch_total_size)
            total_duration_in_callback += time.perf_counter() - t0
        for tensor, size, _ in pack.big_tensors:
            cuda_storage = torch.empty(size, dtype=torch.uint8, device='cuda')
            offset = 0
            while offset < size:
                buffer = read_buffers.get()
                read_size = min(BUFFER_SIZE, size - offset)
                cuda_storage[offset:offset+read_size] = buffer[:read_size]
                offset += read_size
                torch.cuda.synchronize() # Probably not needed
                free_buffers.put(buffer)
                t0 = time.perf_counter()
                callback(read_size)
                total_duration_in_callback += time.perf_counter() - t0
            cuda_tensor = torch.tensor([], dtype=tensor.dtype, device='cuda')
            cuda_tensor = cuda_tensor.set_(cuda_storage.untyped_storage(), 0, tensor.shape, tensor.stride())
            for fake in pack.fakes[tensor]:
                fake.data = cuda_tensor

        debug(f"{total_duration_in_callback=}")

    with ThreadPoolExecutor(2) as e:
        fd = os.open(pack.path(), os.O_RDONLY | os.O_DIRECT)
        try:
            futures = [
                e.submit(copy_context().run, disk_to_pin, fd),
                e.submit(copy_context().run, pin_to_cuda),
            ]
            for future in as_completed(futures):
                future.result()
        finally:
            os.close(fd)
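The batching decisions in `pack_tensors` hinge on page alignment; a tiny worked example of that arithmetic (the byte count is arbitrary):

```python
# Illustrative only: the page-alignment math used when deciding batches vs. big_tensors
PAGE_SIZE = 4096
BUFFER_SIZE = 64 * 2**20                   # 64 MiB staging buffers

size = 10_000_000                          # raw tensor bytes
aligned = size + (-size % PAGE_SIZE)       # rounded up to a 4 KiB multiple for O_DIRECT I/O
assert aligned == 10_002_432 and aligned % PAGE_SIZE == 0
assert aligned <= BUFFER_SIZE              # fits a buffer, so it joins a regular batch, not big_tensors
```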
spaces/zero/torch/patching.py
ADDED
@@ -0,0 +1,386 @@
"""
"""
# pyright: reportPrivateImportUsage=false

from __future__ import annotations

import gc
import multiprocessing
import os
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
from contextvars import copy_context
from types import SimpleNamespace
from typing import Any
from typing import Callable

import torch
from torch.overrides import TorchFunctionMode
from torch.overrides import resolve_name
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakTensorKeyDictionary

from ...config import Config
from ...utils import malloc_trim
from ..tqdm import tqdm
from . import bitsandbytes
from .packing import ZeroGPUTensorPack
from .packing import pack_tensors
from .packing import pack_to_cuda
from .types import AliasId


# Nvidia A100.80G MIG (drivers 535) / Torch 2.2.0
CUDA_DEVICE_NAME = 'NVIDIA A100-SXM4-80GB MIG 3g.40gb'
CUDA_TOTAL_MEMORY = 42144366592
|
| 39 |
+
CUDA_MEM_GET_INFO = (41911451648, CUDA_TOTAL_MEMORY)
|
| 40 |
+
CUDA_DEVICE_CAPABILITY = (8, 0)
|
| 41 |
+
CUDA_DEVICE_PROPERTIES = SimpleNamespace(name=CUDA_DEVICE_NAME, major=8, minor=0, total_memory=CUDA_TOTAL_MEMORY, multi_processor_count=42)
|
| 42 |
+
|
| 43 |
+
OPS_INPUTS_CHECK_NO_RETURN = (
|
| 44 |
+
torch.Tensor.equal,
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
OPS_INPUT_CHECK_SELF_RETURN = (
|
| 48 |
+
torch.Tensor.set_, # probably never dispatched
|
| 49 |
+
torch.ops.aten.set_.source_Tensor, # pyright: ignore [reportAttributeAccessIssue]
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
OFFLOADED_ERROR_MESSAGE = "Cannot apply function {} on disk-offloaded Tensor {}"
|
| 53 |
+
|
| 54 |
+
_tensor_make_subclass = torch.Tensor._make_subclass
|
| 55 |
+
_asarray = torch.asarray
|
| 56 |
+
_cuda_init = torch._C._cuda_init
|
| 57 |
+
_cuda_exchange_device = torch.cuda._exchange_device
|
| 58 |
+
_cuda_available = torch.cuda.is_available
|
| 59 |
+
_cuda_device_count = torch.cuda.device_count
|
| 60 |
+
_cuda_current_device = torch.cuda.current_device
|
| 61 |
+
_cuda_mem_get_info = torch.cuda.mem_get_info
|
| 62 |
+
_cuda_get_device_capability = torch.cuda.get_device_capability
|
| 63 |
+
_cuda_get_device_properties = torch.cuda.get_device_properties
|
| 64 |
+
_cuda_get_device_name = torch.cuda.get_device_name
|
| 65 |
+
|
| 66 |
+
# PyTorch 2.3
|
| 67 |
+
_cuda_maybe_exchange_device = getattr(torch.cuda, '_maybe_exchange_device', None)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
cuda_aliases: dict[torch.Tensor, torch.Tensor | None] = WeakTensorKeyDictionary() # pyright: ignore [reportAssignmentType]
|
| 71 |
+
|
| 72 |
+
tensor_packs: list[ZeroGPUTensorPack] = []
|
| 73 |
+
|
| 74 |
+
class ZeroGPUTensor(torch.Tensor):
|
| 75 |
+
pass
|
| 76 |
+
|
| 77 |
+
def empty_fake(tensor: torch.Tensor):
|
| 78 |
+
fake = torch.empty_like(tensor, requires_grad=tensor.requires_grad)
|
| 79 |
+
if fake.__class__ != tensor.__class__:
|
| 80 |
+
fake = _tensor_make_subclass(tensor.__class__, fake, require_grad=tensor.requires_grad) # pyright: ignore [reportArgumentType]
|
| 81 |
+
return fake
|
| 82 |
+
|
| 83 |
+
class ZeroGPUFunctionMode(TorchFunctionMode):
|
| 84 |
+
|
| 85 |
+
def __torch_function__(self, func, types, args=(), kwargs: dict[str, Any] | None = None):
|
| 86 |
+
|
| 87 |
+
kwargs = {} if kwargs is None else kwargs
|
| 88 |
+
|
| 89 |
+
if func == torch._C._nn._parse_to:
|
| 90 |
+
return func(*args, **kwargs)
|
| 91 |
+
|
| 92 |
+
# Redispatch: tensor.cuda() -> tensor.to(device='cuda')
|
| 93 |
+
if func == torch.Tensor.cuda or func == torch.Tensor.cpu:
|
| 94 |
+
memory_format = kwargs.get('memory_format')
|
| 95 |
+
return self.__torch_function__(torch.Tensor.to, types, (args[0],), {
|
| 96 |
+
'device': 'cuda' if func == torch.Tensor.cuda else 'cpu',
|
| 97 |
+
**({'memory_format': memory_format} if memory_format is not None else {}),
|
| 98 |
+
})
|
| 99 |
+
|
| 100 |
+
# Redispatch: tensor.to('cuda') -> tensor.to(device='cuda')
|
| 101 |
+
if func == torch.Tensor.to and len(args) > 1:
|
| 102 |
+
device, dtype, _, memory_format = torch._C._nn._parse_to(*args[1:], **kwargs)
|
| 103 |
+
return self.__torch_function__(torch.Tensor.to, types, (args[0],), {
|
| 104 |
+
'device': device,
|
| 105 |
+
'dtype': dtype,
|
| 106 |
+
'memory_format': memory_format,
|
| 107 |
+
})
|
| 108 |
+
|
| 109 |
+
if func == torch.Tensor.data.__set__: # pyright: ignore [reportAttributeAccessIssue]
|
| 110 |
+
self, target = args
|
| 111 |
+
if target in cuda_aliases:
|
| 112 |
+
if (target_original := cuda_aliases[target]) is None:
|
| 113 |
+
raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), target))
|
| 114 |
+
original = empty_fake(self)
|
| 115 |
+
original.data = target_original
|
| 116 |
+
cuda_aliases[self] = original
|
| 117 |
+
elif self in cuda_aliases:
|
| 118 |
+
del cuda_aliases[self]
|
| 119 |
+
self.data = target
|
| 120 |
+
return
|
| 121 |
+
|
| 122 |
+
if func == torch.Tensor.device.__get__:
|
| 123 |
+
tensor, = args
|
| 124 |
+
if tensor in cuda_aliases:
|
| 125 |
+
return torch.device('cuda', index=0)
|
| 126 |
+
|
| 127 |
+
elif func == torch.Tensor.__repr__:
|
| 128 |
+
tensor, = args
|
| 129 |
+
if tensor in cuda_aliases:
|
| 130 |
+
if (original := cuda_aliases[tensor]) is None:
|
| 131 |
+
original = tensor.to('meta')
|
| 132 |
+
original_class = original.__class__
|
| 133 |
+
original.__class__ = ZeroGPUTensor
|
| 134 |
+
try:
|
| 135 |
+
return func(original, **kwargs)
|
| 136 |
+
finally:
|
| 137 |
+
original.__class__ = original_class
|
| 138 |
+
|
| 139 |
+
elif func == torch.Tensor.untyped_storage:
|
| 140 |
+
tensor, = args
|
| 141 |
+
if tensor in cuda_aliases:
|
| 142 |
+
if (original := cuda_aliases[tensor]) is None:
|
| 143 |
+
raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), tensor))
|
| 144 |
+
res = func(original, **kwargs)
|
| 145 |
+
res._zerogpu = True
|
| 146 |
+
return res
|
| 147 |
+
|
| 148 |
+
cuda: bool | None = None
|
| 149 |
+
|
| 150 |
+
# Handle device kwarg
|
| 151 |
+
if (device := kwargs.get('device')) is not None:
|
| 152 |
+
device = torch.device(device)
|
| 153 |
+
if device.type == 'cuda':
|
| 154 |
+
kwargs['device'] = torch.device('cpu')
|
| 155 |
+
cuda = True
|
| 156 |
+
else:
|
| 157 |
+
cuda = False
|
| 158 |
+
|
| 159 |
+
# Swap fake inputs with original data
|
| 160 |
+
swapped = {}
|
| 161 |
+
inputs_are_cuda = set()
|
| 162 |
+
def swap(tensor: torch.Tensor):
|
| 163 |
+
nonlocal inputs_are_cuda
|
| 164 |
+
if tensor not in cuda_aliases:
|
| 165 |
+
inputs_are_cuda |= {False}
|
| 166 |
+
return tensor
|
| 167 |
+
if (original := cuda_aliases[tensor]) is None:
|
| 168 |
+
raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), tensor))
|
| 169 |
+
swapped[original] = tensor
|
| 170 |
+
inputs_are_cuda |= {True}
|
| 171 |
+
return original
|
| 172 |
+
args_ = tree_map_only(torch.Tensor, swap, args)
|
| 173 |
+
kwargs_ = tree_map_only(torch.Tensor, swap, kwargs)
|
| 174 |
+
if inputs_are_cuda == {True}:
|
| 175 |
+
if cuda is not False:
|
| 176 |
+
cuda = True
|
| 177 |
+
|
| 178 |
+
res = func(*args_, **kwargs_)
|
| 179 |
+
|
| 180 |
+
# Re-generate swapped fakes in case of mutation
|
| 181 |
+
for original, fake in swapped.items():
|
| 182 |
+
fake.data = empty_fake(original)
|
| 183 |
+
|
| 184 |
+
# Special case for Tensor indexing where only 'self' matters
|
| 185 |
+
if func in {
|
| 186 |
+
torch.ops.aten.index.Tensor, # pyright: ignore [reportAttributeAccessIssue]
|
| 187 |
+
torch.Tensor.__getitem__, # PyTorch 2.4+
|
| 188 |
+
}:
|
| 189 |
+
self = args[0]
|
| 190 |
+
cuda = self in cuda_aliases
|
| 191 |
+
inputs_are_cuda = {cuda}
|
| 192 |
+
|
| 193 |
+
# Emulate device check
|
| 194 |
+
if isinstance(res, torch.Tensor) or func in OPS_INPUTS_CHECK_NO_RETURN:
|
| 195 |
+
self = None
|
| 196 |
+
if len(args_) >= 1 and isinstance(args_[0], torch.Tensor):
|
| 197 |
+
self = args_[0]
|
| 198 |
+
# Only raise if func does not return its first input (Tensor.copy_)
|
| 199 |
+
if res is not self or func in OPS_INPUT_CHECK_SELF_RETURN:
|
| 200 |
+
if inputs_are_cuda == {True, False}:
|
| 201 |
+
raise RuntimeError(
|
| 202 |
+
"Expected all tensors to be on the same device, "
|
| 203 |
+
"but found at least two devices, cuda:0 (ZeroGPU) and cpu!"
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# Register output
|
| 207 |
+
def register(tensor: torch.Tensor):
|
| 208 |
+
if tensor in swapped and cuda is not False:
|
| 209 |
+
return swapped[tensor]
|
| 210 |
+
if cuda is not True:
|
| 211 |
+
return tensor
|
| 212 |
+
fake = empty_fake(tensor)
|
| 213 |
+
cuda_aliases[fake] = tensor
|
| 214 |
+
return fake
|
| 215 |
+
|
| 216 |
+
return tree_map_only(torch.Tensor, register, res)
|
| 217 |
+
|
| 218 |
+
# When enabling DispatchMode, some aten ops are dispatched to FunctionMode
|
| 219 |
+
# We are using it for aten.alias.default and aten.set_.source_Tensor
|
| 220 |
+
class DefaultDispatchMode(TorchDispatchMode):
|
| 221 |
+
def __torch_dispatch__(self, func, types, args=(), kwargs: dict[str, Any] | None = None):
|
| 222 |
+
return func(*args, **(kwargs or {}))
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
function_mode = ZeroGPUFunctionMode()
|
| 226 |
+
dispatch_mode = DefaultDispatchMode()
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def _untyped_storage_new_register(*args, **kwargs):
|
| 230 |
+
cuda = False
|
| 231 |
+
if (device := kwargs.get('device')) is not None and device.type == 'cuda':
|
| 232 |
+
cuda = True
|
| 233 |
+
del kwargs['device']
|
| 234 |
+
storage = torch._C.StorageBase.__new__(*args, **kwargs)
|
| 235 |
+
if cuda:
|
| 236 |
+
storage._zerogpu = True
|
| 237 |
+
return storage
|
| 238 |
+
|
| 239 |
+
@property
|
| 240 |
+
def _untyped_storage_device(self):
|
| 241 |
+
if hasattr(self, '_zerogpu'):
|
| 242 |
+
return torch.device('cuda', index=0)
|
| 243 |
+
return torch._C.StorageBase.device.__get__(self) # pyright: ignore [reportAttributeAccessIssue]
|
| 244 |
+
|
| 245 |
+
# Force dispatch
|
| 246 |
+
def _tensor_make_subclass_function_mode(*args, **kwargs):
|
| 247 |
+
with torch._C.DisableTorchFunction():
|
| 248 |
+
return function_mode.__torch_function__(_tensor_make_subclass, (), args=args, kwargs=kwargs)
|
| 249 |
+
def _asarray_function_mode(*args, **kwargs):
|
| 250 |
+
with torch._C.DisableTorchFunction():
|
| 251 |
+
return function_mode.__torch_function__(_asarray, (), args=args, kwargs=kwargs)
|
| 252 |
+
|
| 253 |
+
def _cuda_init_raise():
|
| 254 |
+
raise RuntimeError(
|
| 255 |
+
"CUDA must not be initialized in the main process "
|
| 256 |
+
"on Spaces with Stateless GPU environment.\n"
|
| 257 |
+
"You can look at this Stacktrace to find out "
|
| 258 |
+
"which part of your code triggered a CUDA init"
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
def _cuda_dummy_exchange_device(device):
|
| 262 |
+
assert device in {-1, 0}
|
| 263 |
+
return device
|
| 264 |
+
|
| 265 |
+
def patch():
|
| 266 |
+
function_mode.__enter__()
|
| 267 |
+
dispatch_mode.__enter__()
|
| 268 |
+
# TODO: only patch bellow methods on current Thread to be consistent with TorchModes
|
| 269 |
+
# (or hijack threading.Thread.__init__ to force Modes on all threads)
|
| 270 |
+
torch.Tensor._make_subclass = _tensor_make_subclass_function_mode # pyright: ignore [reportAttributeAccessIssue]
|
| 271 |
+
torch.UntypedStorage.__new__ = _untyped_storage_new_register
|
| 272 |
+
torch.UntypedStorage.device = _untyped_storage_device # pyright: ignore [reportAttributeAccessIssue]
|
| 273 |
+
torch.asarray = _asarray_function_mode
|
| 274 |
+
torch._C._cuda_init = _cuda_init_raise
|
| 275 |
+
torch.cuda._exchange_device = _cuda_dummy_exchange_device
|
| 276 |
+
torch.cuda.is_available = lambda: True
|
| 277 |
+
torch.cuda.device_count = lambda: 1
|
| 278 |
+
torch.cuda.current_device = lambda: 0
|
| 279 |
+
torch.cuda.mem_get_info = lambda *args, **kwargs: CUDA_MEM_GET_INFO
|
| 280 |
+
torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY
|
| 281 |
+
torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES
|
| 282 |
+
torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME
|
| 283 |
+
# PyTorch 2.3
|
| 284 |
+
if _cuda_maybe_exchange_device is not None: # pragma: no cover
|
| 285 |
+
setattr(torch.cuda, '_maybe_exchange_device', _cuda_dummy_exchange_device)
|
| 286 |
+
bitsandbytes.patch()
|
| 287 |
+
|
| 288 |
+
def unpatch():
|
| 289 |
+
try:
|
| 290 |
+
dispatch_mode.__exit__(None, None, None)
|
| 291 |
+
function_mode.__exit__(None, None, None)
|
| 292 |
+
except RuntimeError:
|
| 293 |
+
pass # patch() and unpatch() called from != threads
|
| 294 |
+
torch.Tensor._make_subclass = _tensor_make_subclass
|
| 295 |
+
torch.UntypedStorage.__new__ = torch._C.StorageBase.__new__
|
| 296 |
+
torch.UntypedStorage.device = torch._C.StorageBase.device # pyright: ignore [reportAttributeAccessIssue]
|
| 297 |
+
torch.asarray = _asarray
|
| 298 |
+
torch._C._cuda_init = _cuda_init
|
| 299 |
+
torch.cuda._exchange_device = _cuda_exchange_device
|
| 300 |
+
torch.cuda.is_available = _cuda_available
|
| 301 |
+
torch.cuda.device_count = _cuda_device_count
|
| 302 |
+
torch.cuda.current_device = _cuda_current_device
|
| 303 |
+
torch.cuda.mem_get_info = _cuda_mem_get_info
|
| 304 |
+
torch.cuda.get_device_capability = _cuda_get_device_capability
|
| 305 |
+
torch.cuda.get_device_properties = _cuda_get_device_properties
|
| 306 |
+
torch.cuda.get_device_name = _cuda_get_device_name
|
| 307 |
+
# PyTorch 2.3
|
| 308 |
+
if _cuda_maybe_exchange_device is not None: # pragma: no cover
|
| 309 |
+
setattr(torch.cuda, '_maybe_exchange_device', _cuda_exchange_device)
|
| 310 |
+
bitsandbytes.unpatch()
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def _total_unpacked_size():
|
| 314 |
+
tensors = [tensor for tensor in cuda_aliases.values() if tensor is not None]
|
| 315 |
+
deduped = {AliasId.from_tensor(tensor): tensor for tensor in tensors}
|
| 316 |
+
return sum([tensor.numel() * tensor.element_size() for tensor in deduped.values()])
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def _pack(offload_dir: str):
|
| 320 |
+
# Pack to disk
|
| 321 |
+
originals: set[torch.Tensor] = set()
|
| 322 |
+
originals_dedup: dict[AliasId, torch.Tensor] = {}
|
| 323 |
+
fakes: dict[torch.Tensor, list[torch.Tensor]] = defaultdict(list)
|
| 324 |
+
for fake, original in cuda_aliases.items():
|
| 325 |
+
# TODO filter-out sparse Tensors
|
| 326 |
+
if original is not None:
|
| 327 |
+
original_id = AliasId.from_tensor(original)
|
| 328 |
+
if original_id not in originals_dedup:
|
| 329 |
+
originals_dedup[original_id] = original
|
| 330 |
+
originals |= {original}
|
| 331 |
+
fakes[originals_dedup[original_id]] += [fake]
|
| 332 |
+
progress = tqdm(
|
| 333 |
+
total=_total_unpacked_size(),
|
| 334 |
+
unit='B',
|
| 335 |
+
unit_scale=True,
|
| 336 |
+
desc="ZeroGPU tensors packing",
|
| 337 |
+
) if tqdm is not None else nullcontext()
|
| 338 |
+
with progress as progress:
|
| 339 |
+
update = progress.update if progress is not None else lambda _: None
|
| 340 |
+
pack = pack_tensors(originals, fakes, offload_dir, callback=update)
|
| 341 |
+
tensor_packs.append(pack)
|
| 342 |
+
# Free memory
|
| 343 |
+
for fake_list in fakes.values():
|
| 344 |
+
for fake in fake_list:
|
| 345 |
+
cuda_aliases[fake] = None
|
| 346 |
+
|
| 347 |
+
def pack():
|
| 348 |
+
_pack(Config.zerogpu_offload_dir)
|
| 349 |
+
gc.collect()
|
| 350 |
+
malloc_trim()
|
| 351 |
+
|
| 352 |
+
def init(nvidia_uuid: str):
|
| 353 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
|
| 354 |
+
torch.Tensor([0]).cuda()
|
| 355 |
+
|
| 356 |
+
def size():
|
| 357 |
+
return _total_unpacked_size() + sum([pack.total_size for pack in tensor_packs])
|
| 358 |
+
|
| 359 |
+
def _move(callback: Callable[[int]] | None = None):
|
| 360 |
+
callback = callback if callback is not None else lambda _: None
|
| 361 |
+
# CPU -> CUDA
|
| 362 |
+
moved: dict[AliasId, torch.Tensor] = {}
|
| 363 |
+
for fake, original in cuda_aliases.items():
|
| 364 |
+
if original is not None:
|
| 365 |
+
original_id = AliasId.from_tensor(original)
|
| 366 |
+
if original_id not in moved:
|
| 367 |
+
moved[original_id] = original.cuda()
|
| 368 |
+
callback(fake.numel() * fake.element_size())
|
| 369 |
+
for fake, original in cuda_aliases.items():
|
| 370 |
+
if original is not None:
|
| 371 |
+
fake.data = moved[AliasId.from_tensor(original)]
|
| 372 |
+
# Disk -> CUDA
|
| 373 |
+
for tensor_pack in tensor_packs:
|
| 374 |
+
pack_to_cuda(tensor_pack, callback=callback)
|
| 375 |
+
bitsandbytes.move()
|
| 376 |
+
|
| 377 |
+
def move(callback: Callable[[int]] | None = None):
|
| 378 |
+
callback = callback if callback is not None else lambda _: None
|
| 379 |
+
with ThreadPoolExecutor(1) as e:
|
| 380 |
+
e.submit(copy_context().run, _move, callback=callback).result()
|
| 381 |
+
torch.cuda.synchronize()
|
| 382 |
+
|
| 383 |
+
def is_in_bad_fork():
|
| 384 |
+
with ProcessPoolExecutor(mp_context=multiprocessing.get_context('fork')) as e:
|
| 385 |
+
f = e.submit(torch.cuda._is_in_bad_fork)
|
| 386 |
+
return f.result()
|
spaces/zero/torch/patching_legacy.py
ADDED
|
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
"""
|
| 3 |
+
# pyright: reportPrivateImportUsage=false
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import multiprocessing
|
| 8 |
+
import os
|
| 9 |
+
from concurrent.futures import ProcessPoolExecutor
|
| 10 |
+
from contextlib import suppress
|
| 11 |
+
from functools import partial
|
| 12 |
+
from types import SimpleNamespace
|
| 13 |
+
from typing import Any
|
| 14 |
+
from typing import Callable
|
| 15 |
+
from typing import Optional
|
| 16 |
+
from typing import Tuple
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from torch.utils.weak import WeakTensorKeyDictionary
|
| 20 |
+
|
| 21 |
+
from ...config import Config
|
| 22 |
+
from . import bitsandbytes
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Nvidia A100.80G MIG (drivers 535) / Torch 2.2.0
|
| 26 |
+
CUDA_DEVICE_NAME = 'NVIDIA A100-SXM4-80GB MIG 3g.40gb'
|
| 27 |
+
CUDA_TOTAL_MEMORY = 42144366592
|
| 28 |
+
CUDA_MEM_GET_INFO = (41911451648, CUDA_TOTAL_MEMORY)
|
| 29 |
+
CUDA_DEVICE_CAPABILITY = (8, 0)
|
| 30 |
+
CUDA_DEVICE_PROPERTIES = SimpleNamespace(name=CUDA_DEVICE_NAME, major=8, minor=0, total_memory=CUDA_TOTAL_MEMORY, multi_processor_count=42)
|
| 31 |
+
|
| 32 |
+
GENERIC_METHOD_NAMES = [
|
| 33 |
+
'arange',
|
| 34 |
+
'as_tensor',
|
| 35 |
+
'asarray',
|
| 36 |
+
'bartlett_window',
|
| 37 |
+
'blackman_window',
|
| 38 |
+
'empty',
|
| 39 |
+
'empty_like',
|
| 40 |
+
'empty_strided',
|
| 41 |
+
'eye',
|
| 42 |
+
'full',
|
| 43 |
+
'full_like',
|
| 44 |
+
'hamming_window',
|
| 45 |
+
'hann_window',
|
| 46 |
+
'kaiser_window',
|
| 47 |
+
'linspace',
|
| 48 |
+
'logspace',
|
| 49 |
+
'ones',
|
| 50 |
+
'ones_like',
|
| 51 |
+
'rand',
|
| 52 |
+
'rand_like',
|
| 53 |
+
'randint',
|
| 54 |
+
'randint_like',
|
| 55 |
+
'randn',
|
| 56 |
+
'randn_like',
|
| 57 |
+
'randperm',
|
| 58 |
+
'range',
|
| 59 |
+
'sparse_bsc_tensor',
|
| 60 |
+
'sparse_bsr_tensor',
|
| 61 |
+
'sparse_compressed_tensor',
|
| 62 |
+
'sparse_coo_tensor',
|
| 63 |
+
'sparse_csc_tensor',
|
| 64 |
+
'sparse_csr_tensor',
|
| 65 |
+
'tensor',
|
| 66 |
+
'tril_indices',
|
| 67 |
+
'triu_indices',
|
| 68 |
+
'zeros',
|
| 69 |
+
'zeros_like',
|
| 70 |
+
]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
TO_CUDA = (torch.device('cuda'), None, False, None)
|
| 74 |
+
|
| 75 |
+
_tensor__deepcopy__ = torch.Tensor.__deepcopy__
|
| 76 |
+
_tensor_to = torch.Tensor.to
|
| 77 |
+
_tensor_cuda = torch.Tensor.cuda
|
| 78 |
+
_tensor_cpu = torch.Tensor.cpu
|
| 79 |
+
_torch_generics = {name: getattr(torch, name) for name in GENERIC_METHOD_NAMES}
|
| 80 |
+
_cuda_init = torch._C._cuda_init
|
| 81 |
+
_cuda_available = torch.cuda.is_available
|
| 82 |
+
_cuda_device_count = torch.cuda.device_count
|
| 83 |
+
_cuda_current_device = torch.cuda.current_device
|
| 84 |
+
_cuda_mem_get_info = torch.cuda.mem_get_info
|
| 85 |
+
_cuda_get_device_capability = torch.cuda.get_device_capability
|
| 86 |
+
_cuda_get_device_properties = torch.cuda.get_device_properties
|
| 87 |
+
_cuda_get_device_name = torch.cuda.get_device_name
|
| 88 |
+
|
| 89 |
+
TensorToArgs = Tuple[Optional[torch.device], Optional[torch.dtype], bool, Optional[torch.memory_format]]
|
| 90 |
+
|
| 91 |
+
to_ops: dict[torch.Tensor, TensorToArgs] = WeakTensorKeyDictionary() # type: ignore
|
| 92 |
+
|
| 93 |
+
def _tensor_new_register(*args, **kwargs):
|
| 94 |
+
new_tensor: torch.Tensor = torch._C._TensorBase.__new__(*args, **kwargs)
|
| 95 |
+
if (base_tensor := new_tensor._base) is not None:
|
| 96 |
+
if base_tensor in to_ops:
|
| 97 |
+
to_ops[new_tensor] = to_ops[base_tensor]
|
| 98 |
+
return new_tensor
|
| 99 |
+
|
| 100 |
+
def _tensor_deepcopy_register(self: torch.Tensor, memo):
|
| 101 |
+
new_tensor = _tensor__deepcopy__(self, memo)
|
| 102 |
+
if isinstance(new_tensor, torch.Tensor):
|
| 103 |
+
if self in to_ops:
|
| 104 |
+
to_ops[new_tensor] = to_ops[self]
|
| 105 |
+
return new_tensor
|
| 106 |
+
|
| 107 |
+
@property
|
| 108 |
+
def _tensor_device_property(self: torch.Tensor):
|
| 109 |
+
if self in to_ops:
|
| 110 |
+
return torch.device(type='cuda', index=0)
|
| 111 |
+
del torch.Tensor.device
|
| 112 |
+
try:
|
| 113 |
+
return self.device
|
| 114 |
+
finally:
|
| 115 |
+
torch.Tensor.device = _tensor_device_property # type: ignore
|
| 116 |
+
|
| 117 |
+
@property
|
| 118 |
+
def _tensor_dtype_property(self: torch.Tensor):
|
| 119 |
+
if self in to_ops:
|
| 120 |
+
if (to_dtype := to_ops[self][1]) is not None:
|
| 121 |
+
return to_dtype
|
| 122 |
+
del torch.Tensor.dtype
|
| 123 |
+
try:
|
| 124 |
+
return self.dtype
|
| 125 |
+
finally:
|
| 126 |
+
torch.Tensor.dtype = _tensor_dtype_property # type: ignore
|
| 127 |
+
|
| 128 |
+
def _to_op_register(self: torch.Tensor, *args, **kwargs):
|
| 129 |
+
parsed = torch._C._nn._parse_to(*args, **kwargs)
|
| 130 |
+
device, dtype, *_ = parsed
|
| 131 |
+
try:
|
| 132 |
+
to_args = to_ops.pop(self)
|
| 133 |
+
except KeyError:
|
| 134 |
+
to_args = None
|
| 135 |
+
if device is None: # pyright: ignore [reportUnnecessaryComparison]
|
| 136 |
+
if to_args is not None:
|
| 137 |
+
to_ops[self] = (to_args[0], dtype, *to_args[2:])
|
| 138 |
+
return self
|
| 139 |
+
return _tensor_to(self, *args, **kwargs)
|
| 140 |
+
if device.type != 'cuda':
|
| 141 |
+
if to_args is not None:
|
| 142 |
+
if (to_dtype := to_args[1]) is not None:
|
| 143 |
+
kwargs = {'dtype': to_dtype, **kwargs}
|
| 144 |
+
return _tensor_to(self, *args, **kwargs)
|
| 145 |
+
to_ops[self] = parsed
|
| 146 |
+
return self
|
| 147 |
+
|
| 148 |
+
def _cuda_op_arg_check(device: torch.device | int | str | None) -> bool:
|
| 149 |
+
if device is None:
|
| 150 |
+
return True
|
| 151 |
+
if isinstance(device, int):
|
| 152 |
+
return True
|
| 153 |
+
if isinstance(device, str):
|
| 154 |
+
device = torch.device(device)
|
| 155 |
+
return device.type == 'cuda'
|
| 156 |
+
|
| 157 |
+
def _cuda_op_register(self: torch.Tensor, device: torch.device | int | str | None = None, **kwargs):
|
| 158 |
+
if not _cuda_op_arg_check(device):
|
| 159 |
+
# Let PyTorch handle the fail
|
| 160 |
+
return _tensor_cuda(self, device, **kwargs)
|
| 161 |
+
to_ops[self] = TO_CUDA
|
| 162 |
+
return self
|
| 163 |
+
|
| 164 |
+
def _cpu_op_remove(self: torch.Tensor, **kwargs):
|
| 165 |
+
try:
|
| 166 |
+
to_args = to_ops.pop(self)
|
| 167 |
+
except KeyError:
|
| 168 |
+
to_args = None
|
| 169 |
+
if to_args is not None:
|
| 170 |
+
if (to_dtype := to_args[1]) is not None:
|
| 171 |
+
return _tensor_to(self, 'cpu', **{'dtype': to_dtype, **kwargs})
|
| 172 |
+
return _tensor_cpu(self, **kwargs)
|
| 173 |
+
|
| 174 |
+
def _cuda_init_raise():
|
| 175 |
+
raise RuntimeError(
|
| 176 |
+
"CUDA must not be initialized in the main process "
|
| 177 |
+
"on Spaces with Stateless GPU environment.\n"
|
| 178 |
+
"You can look at this Stacktrace to find out "
|
| 179 |
+
"which part of your code triggered a CUDA init"
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
def _generic_method_register(name: str, *args: Any, **kwargs: Any):
|
| 183 |
+
try:
|
| 184 |
+
device = torch.device(kwargs.get('device', "cpu"))
|
| 185 |
+
except Exception:
|
| 186 |
+
return _torch_generics[name](*args, **kwargs)
|
| 187 |
+
if device.type != 'cuda':
|
| 188 |
+
return _torch_generics[name](*args, **kwargs)
|
| 189 |
+
tensor = _torch_generics[name](*args, **{**kwargs, 'device': "cpu"})
|
| 190 |
+
to_ops[tensor] = TO_CUDA
|
| 191 |
+
return tensor
|
| 192 |
+
|
| 193 |
+
def patch():
|
| 194 |
+
torch.Tensor.__deepcopy__ = _tensor_deepcopy_register
|
| 195 |
+
torch.Tensor.__new__ = _tensor_new_register # pyright: ignore [reportAttributeAccessIssue]
|
| 196 |
+
torch.Tensor.to = _to_op_register # type: ignore
|
| 197 |
+
torch.Tensor.cuda = _cuda_op_register # type: ignore
|
| 198 |
+
torch.Tensor.cpu = _cpu_op_remove # type: ignore
|
| 199 |
+
if Config.zero_patch_torch_device:
|
| 200 |
+
torch.Tensor.device = _tensor_device_property # type: ignore
|
| 201 |
+
torch.Tensor.dtype = _tensor_dtype_property # pyright: ignore [reportAttributeAccessIssue]
|
| 202 |
+
for name in GENERIC_METHOD_NAMES:
|
| 203 |
+
setattr(torch, name, partial(_generic_method_register, name))
|
| 204 |
+
torch._C._cuda_init = _cuda_init_raise
|
| 205 |
+
torch.cuda.is_available = lambda: True
|
| 206 |
+
torch.cuda.device_count = lambda: 1
|
| 207 |
+
torch.cuda.current_device = lambda: 0
|
| 208 |
+
torch.cuda.mem_get_info = lambda *args, **kwargs: CUDA_MEM_GET_INFO
|
| 209 |
+
torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY
|
| 210 |
+
torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES
|
| 211 |
+
torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME
|
| 212 |
+
bitsandbytes.patch()
|
| 213 |
+
|
| 214 |
+
def unpatch():
|
| 215 |
+
torch.Tensor.__deepcopy__ = _tensor__deepcopy__
|
| 216 |
+
with suppress(AttributeError):
|
| 217 |
+
del torch.Tensor.__new__
|
| 218 |
+
torch.Tensor.to = _tensor_to
|
| 219 |
+
torch.Tensor.cuda = _tensor_cuda
|
| 220 |
+
torch.Tensor.cpu = _tensor_cpu
|
| 221 |
+
with suppress(AttributeError):
|
| 222 |
+
del torch.Tensor.device
|
| 223 |
+
with suppress(AttributeError):
|
| 224 |
+
del torch.Tensor.dtype
|
| 225 |
+
for name in GENERIC_METHOD_NAMES:
|
| 226 |
+
setattr(torch, name, _torch_generics[name])
|
| 227 |
+
torch._C._cuda_init = _cuda_init
|
| 228 |
+
torch.cuda.is_available = _cuda_available
|
| 229 |
+
torch.cuda.device_count = _cuda_device_count
|
| 230 |
+
torch.cuda.current_device = _cuda_current_device
|
| 231 |
+
torch.cuda.mem_get_info = _cuda_mem_get_info
|
| 232 |
+
torch.cuda.get_device_capability = _cuda_get_device_capability
|
| 233 |
+
torch.cuda.get_device_properties = _cuda_get_device_properties
|
| 234 |
+
torch.cuda.get_device_name = _cuda_get_device_name
|
| 235 |
+
bitsandbytes.unpatch()
|
| 236 |
+
|
| 237 |
+
def pack():
|
| 238 |
+
pass
|
| 239 |
+
|
| 240 |
+
def init(nvidia_uuid: str):
|
| 241 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
|
| 242 |
+
torch.Tensor([0]).cuda() # CUDA init
|
| 243 |
+
|
| 244 |
+
def size():
|
| 245 |
+
return 0
|
| 246 |
+
|
| 247 |
+
def move(callback: Callable[[int]] | None = None):
|
| 248 |
+
for op in to_ops.items():
|
| 249 |
+
tensor, parsed_args = op
|
| 250 |
+
_, dtype, _, memory_format = parsed_args
|
| 251 |
+
tensor.data = _tensor_to(tensor,
|
| 252 |
+
device='cuda',
|
| 253 |
+
dtype=dtype,
|
| 254 |
+
memory_format=memory_format,
|
| 255 |
+
) # type: ignore
|
| 256 |
+
bitsandbytes.move()
|
| 257 |
+
torch.cuda.synchronize()
|
| 258 |
+
|
| 259 |
+
def is_in_bad_fork():
|
| 260 |
+
with ProcessPoolExecutor(mp_context=multiprocessing.get_context('fork')) as e:
|
| 261 |
+
f = e.submit(torch.cuda._is_in_bad_fork)
|
| 262 |
+
return f.result()
|
| 263 |
+
|
| 264 |
+
def disable_cuda_intercept():
|
| 265 |
+
torch.Tensor.to = _tensor_to
|
| 266 |
+
torch.Tensor.cuda = _tensor_cuda
|
spaces/zero/torch/types.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
"""
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import NamedTuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class AliasId(NamedTuple):
|
| 11 |
+
data_ptr: int
|
| 12 |
+
dtype: torch.dtype
|
| 13 |
+
shape: tuple[int, ...]
|
| 14 |
+
stride: tuple[int, ...]
|
| 15 |
+
|
| 16 |
+
@classmethod
|
| 17 |
+
def from_tensor(cls, tensor: torch.Tensor):
|
| 18 |
+
return cls(
|
| 19 |
+
tensor.data_ptr(),
|
| 20 |
+
tensor.dtype,
|
| 21 |
+
tensor.shape,
|
| 22 |
+
tensor.stride(),
|
| 23 |
+
)
|
spaces/zero/tqdm.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
"""
|
| 3 |
+
|
| 4 |
+
from multiprocessing.synchronize import RLock as MultiprocessingRLock
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
try:
|
| 8 |
+
from tqdm import tqdm as _tqdm
|
| 9 |
+
except ImportError: # pragma: no cover
|
| 10 |
+
_tqdm = None
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def remove_tqdm_multiprocessing_lock():
|
| 14 |
+
if _tqdm is None: # pragma: no cover
|
| 15 |
+
return
|
| 16 |
+
tqdm_lock = _tqdm.get_lock()
|
| 17 |
+
assert tqdm_lock.__class__.__name__ == 'TqdmDefaultWriteLock'
|
| 18 |
+
tqdm_lock.locks = [
|
| 19 |
+
lock for lock in tqdm_lock.locks
|
| 20 |
+
if not isinstance(lock, MultiprocessingRLock)
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
tqdm = _tqdm
|
spaces/zero/types.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
"""
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from datetime import timedelta
|
| 8 |
+
from typing import Any
|
| 9 |
+
from typing import Dict
|
| 10 |
+
from typing import Tuple
|
| 11 |
+
from typing import TypedDict
|
| 12 |
+
from typing_extensions import Callable
|
| 13 |
+
from typing_extensions import Generic
|
| 14 |
+
from typing_extensions import ParamSpec
|
| 15 |
+
from typing_extensions import TypeAlias
|
| 16 |
+
from typing_extensions import TypeVar
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
Params = Tuple[Tuple[object, ...], Dict[str, Any]]
|
| 20 |
+
Res = TypeVar('Res')
|
| 21 |
+
Param = ParamSpec('Param')
|
| 22 |
+
|
| 23 |
+
class EmptyKwargs(TypedDict):
|
| 24 |
+
pass
|
| 25 |
+
|
| 26 |
+
@dataclass
|
| 27 |
+
class OkResult(Generic[Res]):
|
| 28 |
+
value: Res
|
| 29 |
+
@dataclass
|
| 30 |
+
class ExceptionResult:
|
| 31 |
+
value: Exception
|
| 32 |
+
@dataclass
|
| 33 |
+
class AbortedResult:
|
| 34 |
+
pass
|
| 35 |
+
@dataclass
|
| 36 |
+
class EndResult:
|
| 37 |
+
pass
|
| 38 |
+
@dataclass
|
| 39 |
+
class GradioQueueEvent:
|
| 40 |
+
method_name: str
|
| 41 |
+
args: tuple[Any, ...]
|
| 42 |
+
kwargs: dict[str, Any]
|
| 43 |
+
|
| 44 |
+
RegularResQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | GradioQueueEvent"
|
| 45 |
+
GeneratorResQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | EndResult | GradioQueueEvent"
|
| 46 |
+
YieldQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | EndResult | AbortedResult"
|
| 47 |
+
|
| 48 |
+
Duration: TypeAlias = "int | timedelta"
|
| 49 |
+
DynamicDuration: TypeAlias = "Duration | Callable[Param, Duration] | None"
|
spaces/zero/wrappers.py
ADDED
|
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
"""
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import multiprocessing
|
| 6 |
+
import os
|
| 7 |
+
import signal
|
| 8 |
+
import time
|
| 9 |
+
import traceback
|
| 10 |
+
import warnings
|
| 11 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 12 |
+
from contextlib import nullcontext
|
| 13 |
+
from contextvars import copy_context
|
| 14 |
+
from datetime import timedelta
|
| 15 |
+
from functools import partial
|
| 16 |
+
from functools import wraps
|
| 17 |
+
from multiprocessing.context import ForkProcess
|
| 18 |
+
from pickle import PicklingError
|
| 19 |
+
from queue import Empty
|
| 20 |
+
from queue import Queue as ThreadQueue
|
| 21 |
+
from threading import Thread
|
| 22 |
+
from typing import TYPE_CHECKING
|
| 23 |
+
from typing import Callable
|
| 24 |
+
from typing import Generator
|
| 25 |
+
from typing import Generic
|
| 26 |
+
from typing_extensions import assert_never
|
| 27 |
+
|
| 28 |
+
import psutil
|
| 29 |
+
|
| 30 |
+
from ..config import Config
|
| 31 |
+
from ..utils import debug
|
| 32 |
+
from ..utils import drop_params
|
| 33 |
+
from ..utils import gradio_request_var
|
| 34 |
+
from ..utils import SimpleQueue as Queue
|
| 35 |
+
from . import client
|
| 36 |
+
from . import torch
|
| 37 |
+
from .api import AllowToken
|
| 38 |
+
from .api import NvidiaIndex
|
| 39 |
+
from .api import NvidiaUUID
|
| 40 |
+
from .gradio import GradioPartialContext
|
| 41 |
+
from .gradio import get_server_port
|
| 42 |
+
from .gradio import patch_gradio_queue
|
| 43 |
+
from .gradio import try_process_queue_event
|
| 44 |
+
from .tqdm import remove_tqdm_multiprocessing_lock
|
| 45 |
+
from .tqdm import tqdm
|
| 46 |
+
from .types import * # TODO: Please don't do that
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
GENERATOR_GLOBAL_TIMEOUT = 20 * 60
|
| 50 |
+
|
| 51 |
+
SPAWN_PROGRESS_CLEANUP = 0.1
|
| 52 |
+
SPAWN_PROGRESS_INIT = 0.1
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
Process = multiprocessing.get_context('fork').Process
|
| 56 |
+
forked = False
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class Worker(Generic[Res]):
|
| 60 |
+
process: ForkProcess
|
| 61 |
+
arg_queue: Queue[tuple[Params, GradioPartialContext]]
|
| 62 |
+
res_queue: Queue[Res | None]
|
| 63 |
+
_sentinel: Thread
|
| 64 |
+
|
| 65 |
+
def __init__(
|
| 66 |
+
self,
|
| 67 |
+
target: Callable[[
|
| 68 |
+
Queue[tuple[Params, GradioPartialContext]],
|
| 69 |
+
Queue[Res | None],
|
| 70 |
+
AllowToken,
|
| 71 |
+
NvidiaUUID,
|
| 72 |
+
list[int],
|
| 73 |
+
], None],
|
| 74 |
+
allow_token: str,
|
| 75 |
+
nvidia_uuid: str,
|
| 76 |
+
):
|
| 77 |
+
self._sentinel = Thread(target=self._close_on_exit, daemon=True)
|
| 78 |
+
self.arg_queue = Queue()
|
| 79 |
+
self.res_queue = Queue()
|
| 80 |
+
debug(f"{self.arg_queue._writer.fileno()=}") # pyright: ignore [reportAttributeAccessIssue]
|
| 81 |
+
debug(f"{self.res_queue._writer.fileno()=}") # pyright: ignore [reportAttributeAccessIssue]
|
| 82 |
+
if (server_port := get_server_port()) is not None:
|
| 83 |
+
fds = [c.fd for c in psutil.Process().connections() if c.laddr.port == server_port]
|
| 84 |
+
debug(f"{fds=}")
|
| 85 |
+
else:
|
| 86 |
+
warnings.warn("Using a ZeroGPU function outside of Gradio caching or request might block the app")
|
| 87 |
+
fds = []
|
| 88 |
+
args = self.arg_queue, self.res_queue, allow_token, nvidia_uuid, fds
|
| 89 |
+
if TYPE_CHECKING:
|
| 90 |
+
target(*args)
|
| 91 |
+
self.process = Process(
|
| 92 |
+
target=target,
|
| 93 |
+
args=args,
|
| 94 |
+
daemon=True,
|
| 95 |
+
)
|
| 96 |
+
self.process.start()
|
| 97 |
+
self._sentinel.start()
|
| 98 |
+
|
| 99 |
+
def _close_on_exit(self):
|
| 100 |
+
self.process.join()
|
| 101 |
+
self.arg_queue.close()
|
| 102 |
+
self.res_queue.wlock_release()
|
| 103 |
+
self.res_queue.put(None)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def worker_init(
|
| 107 |
+
res_queue: Queue[RegularResQueueResult | None] | Queue[GeneratorResQueueResult | None],
|
| 108 |
+
allow_token: str,
|
| 109 |
+
nvidia_uuid: str,
|
| 110 |
+
fds: list[int],
|
| 111 |
+
) -> None | ExceptionResult:
|
| 112 |
+
# Immediately close file descriptors
|
| 113 |
+
for fd in fds:
|
| 114 |
+
try:
|
| 115 |
+
os.close(fd)
|
| 116 |
+
except Exception as e: # pragma: no cover
|
| 117 |
+
if isinstance(e, OSError) and e.errno == 9:
|
| 118 |
+
continue
|
| 119 |
+
traceback.print_exc()
|
| 120 |
+
return ExceptionResult(e)
|
| 121 |
+
progress = nullcontext()
|
| 122 |
+
if tqdm is not None and Config.zero_gpu_v2:
|
| 123 |
+
progress = tqdm(total=100, desc="ZeroGPU init", file=open(os.devnull, 'w'))
|
| 124 |
+
try: # Unrecoverable init part
|
| 125 |
+
patch_gradio_queue(res_queue)
|
| 126 |
+
with progress as progress:
|
| 127 |
+
current_progress = 0 # Gradio does not support float progress updates
|
| 128 |
+
def update(n: float):
|
| 129 |
+
nonlocal current_progress
|
| 130 |
+
current_progress += n
|
| 131 |
+
if progress is not None:
|
| 132 |
+
progress.update(round(current_progress * 100) - progress.n)
|
| 133 |
+
t0 = time.perf_counter()
|
| 134 |
+
client.allow(allow_token)
|
| 135 |
+
print("client.allow", (dt := time.perf_counter() - t0)); t0 = dt
|
| 136 |
+
update(SPAWN_PROGRESS_CLEANUP)
|
| 137 |
+
torch.unpatch()
|
| 138 |
+
print("torch.unpatch", (dt := time.perf_counter() - t0)); t0 = dt
|
| 139 |
+
torch.init(nvidia_uuid)
|
| 140 |
+
print("torch.init", (dt := time.perf_counter() - t0)); t0 = dt
|
| 141 |
+
update(SPAWN_PROGRESS_INIT)
|
| 142 |
+
callback = None
|
| 143 |
+
if (transfer_size := torch.size()) > 0:
|
| 144 |
+
remaining = 1 - (SPAWN_PROGRESS_CLEANUP + SPAWN_PROGRESS_INIT)
|
| 145 |
+
callback = lambda n: update(n * remaining / transfer_size)
|
| 146 |
+
torch.move(callback=callback)
|
| 147 |
+
print("torch.move", (dt := time.perf_counter() - t0)); t0 = dt
|
| 148 |
+
except Exception as e: # pragma: no cover
|
| 149 |
+
traceback.print_exc()
|
| 150 |
+
return ExceptionResult(e)
|
| 151 |
+
try:
|
| 152 |
+
remove_tqdm_multiprocessing_lock()
|
| 153 |
+
except Exception: # pragma: no cover
|
| 154 |
+
print("Error while trying to remove tqdm mp_lock:")
|
| 155 |
+
traceback.print_exc()
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def process_duration(duration: Duration | None):
|
| 159 |
+
if duration is None or isinstance(duration, timedelta):
|
| 160 |
+
return duration
|
| 161 |
+
return timedelta(seconds=duration)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def static_duration(duration: DynamicDuration[Param], *args: Param.args, **kwargs: Param.kwargs):
|
| 165 |
+
if not callable(duration):
|
| 166 |
+
return duration
|
| 167 |
+
return duration(*args, **kwargs)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def regular_function_wrapper(
|
| 171 |
+
task: Callable[Param, Res],
|
| 172 |
+
duration: DynamicDuration[Param],
|
| 173 |
+
) -> Callable[Param, Res]:
|
| 174 |
+
|
| 175 |
+
import gradio as gr
|
| 176 |
+
|
| 177 |
+
request_var = gradio_request_var()
|
| 178 |
+
workers: dict[NvidiaIndex, Worker[RegularResQueueResult[Res]]] = {}
|
| 179 |
+
task_id = id(task)
|
| 180 |
+
|
| 181 |
+
@wraps(task)
|
| 182 |
+
def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Res:
|
| 183 |
+
|
| 184 |
+
if forked:
|
| 185 |
+
return task(*args, **kwargs)
|
| 186 |
+
|
| 187 |
+
request = request_var.get()
|
| 188 |
+
duration_ = static_duration(duration, *args, **kwargs)
|
| 189 |
+
duration_ = process_duration(duration_)
|
| 190 |
+
t0 = time.perf_counter()
|
| 191 |
+
schedule_response = client.schedule(task_id=task_id, request=request, duration=duration_)
|
| 192 |
+
print("client.schedule", time.perf_counter() - t0)
|
| 193 |
+
allow_token = schedule_response.allowToken
|
| 194 |
+
nvidia_index = schedule_response.nvidiaIndex
|
| 195 |
+
nvidia_uuid = schedule_response.nvidiaUUID
|
| 196 |
+
release = partial(client.release, allow_token)
|
| 197 |
+
|
| 198 |
+
try:
|
| 199 |
+
worker = workers.pop(nvidia_index)
|
| 200 |
+
except KeyError:
|
| 201 |
+
worker = None
|
| 202 |
+
|
| 203 |
+
if worker is not None and worker.process.is_alive() and schedule_response.idle:
|
| 204 |
+
assert worker.arg_queue.empty()
|
| 205 |
+
assert worker.res_queue.empty()
|
| 206 |
+
else:
|
| 207 |
+
worker = Worker(thread_wrapper, allow_token, nvidia_uuid)
|
| 208 |
+
|
| 209 |
+
try:
|
| 210 |
+
worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
|
| 211 |
+
except PicklingError: # TODO: detailed serialization diagnostic
|
| 212 |
+
release(fail=True)
|
| 213 |
+
raise
|
| 214 |
+
|
| 215 |
+
while True:
|
| 216 |
+
res = worker.res_queue.get()
|
| 217 |
+
if res is None:
|
| 218 |
+
release(fail=True, allow_404=True)
|
| 219 |
+
raise gr.Error("GPU task aborted")
|
| 220 |
+
if isinstance(res, ExceptionResult):
|
| 221 |
+
release(fail=True)
|
| 222 |
+
raise res.value
|
| 223 |
+
if isinstance(res, OkResult):
|
| 224 |
+
t0 = time.perf_counter()
|
| 225 |
+
release()
|
| 226 |
+
print("client.release", time.perf_counter() - t0)
|
| 227 |
+
workers[nvidia_index] = worker
|
| 228 |
+
return res.value
|
| 229 |
+
if isinstance(res, GradioQueueEvent):
|
| 230 |
+
try_process_queue_event(res.method_name, *res.args, **res.kwargs)
|
| 231 |
+
continue
|
| 232 |
+
assert_never(res)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def thread_wrapper(
|
| 236 |
+
arg_queue: Queue[tuple[Params, GradioPartialContext]],
|
| 237 |
+
res_queue: Queue[RegularResQueueResult[Res] | None],
|
| 238 |
+
allow_token: str,
|
| 239 |
+
nvidia_uuid: str,
|
| 240 |
+
fds: list[int],
|
| 241 |
+
):
|
| 242 |
+
global forked
|
| 243 |
+
forked = True
|
| 244 |
+
signal.signal(signal.SIGTERM, drop_params(arg_queue.close))
|
| 245 |
+
initialized = False
|
| 246 |
+
while True:
|
| 247 |
+
try:
|
| 248 |
+
(args, kwargs), gradio_context = arg_queue.get()
|
| 249 |
+
except OSError:
|
| 250 |
+
break
|
| 251 |
+
if not initialized:
|
| 252 |
+
t0 = time.perf_counter()
|
| 253 |
+
if (res := worker_init(
|
| 254 |
+
res_queue=res_queue,
|
| 255 |
+
allow_token=allow_token,
|
| 256 |
+
nvidia_uuid=nvidia_uuid,
|
| 257 |
+
fds=fds,
|
| 258 |
+
)) is not None:
|
| 259 |
+
res_queue.put(res)
|
| 260 |
+
return
|
| 261 |
+
print("worker_init", time.perf_counter() - t0)
|
| 262 |
+
initialized = True
|
| 263 |
+
GradioPartialContext.apply(gradio_context)
|
| 264 |
+
context = copy_context()
|
| 265 |
+
with ThreadPoolExecutor() as executor:
|
| 266 |
+
future = executor.submit(context.run, task, *args, **kwargs) # type: ignore
|
| 267 |
+
try:
|
| 268 |
+
res = future.result()
|
| 269 |
+
except Exception as e:
|
| 270 |
+
traceback.print_exc()
|
| 271 |
+
res = ExceptionResult(e)
|
| 272 |
+
else:
|
| 273 |
+
res = OkResult(res)
|
| 274 |
+
try:
|
| 275 |
+
res_queue.put(res)
|
| 276 |
+
except PicklingError as e:
|
| 277 |
+
res_queue.put(ExceptionResult(e))
|
| 278 |
+
|
| 279 |
+
# https://github.com/python/cpython/issues/91002
|
| 280 |
+
if not hasattr(task, '__annotations__'):
|
| 281 |
+
gradio_handler.__annotations__ = {}
|
| 282 |
+
|
| 283 |
+
return gradio_handler
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def generator_function_wrapper(
|
| 287 |
+
task: Callable[Param, Generator[Res, None, None]],
|
| 288 |
+
duration: DynamicDuration[Param],
|
| 289 |
+
) -> Callable[Param, Generator[Res, None, None]]:
|
| 290 |
+
|
| 291 |
+
import gradio as gr
|
| 292 |
+
|
| 293 |
+
request_var = gradio_request_var()
|
| 294 |
+
workers: dict[NvidiaIndex, Worker[GeneratorResQueueResult[Res]]] = {}
|
| 295 |
+
task_id = id(task)
|
| 296 |
+
|
| 297 |
+
@wraps(task)
|
| 298 |
+
def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Generator[Res, None, None]:
|
| 299 |
+
|
| 300 |
+
if forked:
|
| 301 |
+
yield from task(*args, **kwargs)
|
| 302 |
+
return
|
| 303 |
+
|
| 304 |
+
request = request_var.get()
|
| 305 |
+
duration_ = static_duration(duration, *args, **kwargs)
|
| 306 |
+
duration_ = process_duration(duration_)
|
| 307 |
+
schedule_response = client.schedule(task_id=task_id, request=request, duration=duration_)
|
| 308 |
+
allow_token = schedule_response.allowToken
|
| 309 |
+
nvidia_index = schedule_response.nvidiaIndex
|
| 310 |
+
nvidia_uuid = schedule_response.nvidiaUUID
|
| 311 |
+
release = partial(client.release, allow_token)
|
| 312 |
+
|
| 313 |
+
try:
|
| 314 |
+
worker = workers.pop(nvidia_index)
|
| 315 |
+
except KeyError:
|
| 316 |
+
worker = None
|
| 317 |
+
|
| 318 |
+
if worker is not None and worker.process.is_alive() and schedule_response.idle:
|
| 319 |
+
assert worker.arg_queue.empty()
|
| 320 |
+
assert worker.res_queue.empty()
|
| 321 |
+
else:
|
| 322 |
+
worker = Worker(thread_wrapper, allow_token, nvidia_uuid)
|
| 323 |
+
|
| 324 |
+
try:
|
| 325 |
+
worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
|
| 326 |
+
except PicklingError: # TODO: detailed serialization diagnostic
|
| 327 |
+
release(fail=True)
|
| 328 |
+
raise
|
| 329 |
+
|
| 330 |
+
yield_queue: ThreadQueue[YieldQueueResult[Res]] = ThreadQueue()
|
| 331 |
+
def fill_yield_queue(worker: Worker[GeneratorResQueueResult[Res]]):
|
| 332 |
+
while True:
|
| 333 |
+
res = worker.res_queue.get()
|
| 334 |
+
if res is None:
|
| 335 |
+
release(fail=True, allow_404=True)
|
| 336 |
+
yield_queue.put(AbortedResult())
|
| 337 |
+
return
|
| 338 |
+
if isinstance(res, ExceptionResult):
|
| 339 |
+
release(fail=True)
|
| 340 |
+
yield_queue.put(ExceptionResult(res.value))
|
| 341 |
+
return
|
| 342 |
+
if isinstance(res, EndResult):
|
| 343 |
+
release()
|
| 344 |
+
workers[nvidia_index] = worker
|
| 345 |
+
yield_queue.put(EndResult())
|
| 346 |
+
return
|
| 347 |
+
if isinstance(res, OkResult):
|
| 348 |
+
yield_queue.put(OkResult(res.value))
|
| 349 |
+
continue
|
| 350 |
+
if isinstance(res, GradioQueueEvent): # pragma: no cover (not working properly on Gradio side)
|
| 351 |
+
try_process_queue_event(res.method_name, *res.args, **res.kwargs)
|
| 352 |
+
continue
|
| 353 |
+
debug(f"fill_yield_queue: assert_never({res=})")
|
| 354 |
+
assert_never(res)
|
| 355 |
+
from typing_extensions import assert_never
|
| 356 |
+
with ThreadPoolExecutor() as e:
|
| 357 |
+
f = e.submit(copy_context().run, fill_yield_queue, worker)
|
| 358 |
+
f.add_done_callback(lambda _: debug("fill_yield_queue DONE"))
|
| 359 |
+
while True:
|
| 360 |
+
try:
|
| 361 |
+
res = yield_queue.get(timeout=GENERATOR_GLOBAL_TIMEOUT)
|
| 362 |
+
except Empty: # pragma: no cover
|
| 363 |
+
debug(f"yield_queue TIMEOUT ({GENERATOR_GLOBAL_TIMEOUT=})")
|
| 364 |
+
raise
|
| 365 |
+
if isinstance(res, AbortedResult):
|
| 366 |
+
raise gr.Error("GPU task aborted")
|
| 367 |
+
if isinstance(res, ExceptionResult):
|
| 368 |
+
raise res.value
|
| 369 |
+
if isinstance(res, EndResult):
|
| 370 |
+
break
|
| 371 |
+
if isinstance(res, OkResult):
|
| 372 |
+
yield res.value
|
| 373 |
+
continue
|
| 374 |
+
debug(f"gradio_handler: assert_never({res=})")
|
| 375 |
+
assert_never(res)
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def thread_wrapper(
|
| 379 |
+
arg_queue: Queue[tuple[Params, GradioPartialContext]],
|
| 380 |
+
res_queue: Queue[GeneratorResQueueResult[Res] | None],
|
| 381 |
+
allow_token: str,
|
| 382 |
+
nvidia_uuid: str,
|
| 383 |
+
fds: list[int],
|
| 384 |
+
):
|
| 385 |
+
global forked
|
| 386 |
+
forked = True
|
| 387 |
+
signal.signal(signal.SIGTERM, drop_params(arg_queue.close))
|
| 388 |
+
initialized = False
|
| 389 |
+
while True:
|
| 390 |
+
try:
|
| 391 |
+
(args, kwargs), gradio_context = arg_queue.get()
|
| 392 |
+
except OSError:
|
| 393 |
+
break
|
| 394 |
+
if not initialized:
|
| 395 |
+
if (res := worker_init(
|
| 396 |
+
res_queue=res_queue,
|
| 397 |
+
allow_token=allow_token,
|
| 398 |
+
nvidia_uuid=nvidia_uuid,
|
| 399 |
+
fds=fds,
|
| 400 |
+
)) is not None:
|
| 401 |
+
res_queue.put(res)
|
| 402 |
+
return
|
| 403 |
+
initialized = True
|
| 404 |
+
def iterate():
|
| 405 |
+
gen = task(*args, **kwargs) # type: ignore
|
| 406 |
+
while True:
|
| 407 |
+
try:
|
| 408 |
+
res = next(gen)
|
| 409 |
+
except StopIteration:
|
| 410 |
+
break
|
| 411 |
+
except Exception as e:
|
| 412 |
+
res_queue.put(ExceptionResult(e))
|
| 413 |
+
break
|
| 414 |
+
try:
|
| 415 |
+
res_queue.put(OkResult(res))
|
| 416 |
+
except PicklingError as e:
|
| 417 |
+
res_queue.put(ExceptionResult(e))
|
| 418 |
+
break
|
| 419 |
+
else:
|
| 420 |
+
continue
|
| 421 |
+
GradioPartialContext.apply(gradio_context)
|
| 422 |
+
with ThreadPoolExecutor() as executor:
|
| 423 |
+
executor.submit(copy_context().run, iterate)
|
| 424 |
+
res_queue.put(EndResult())
|
| 425 |
+
|
| 426 |
+
# https://github.com/python/cpython/issues/91002
|
| 427 |
+
if not hasattr(task, '__annotations__'):
|
| 428 |
+
gradio_handler.__annotations__ = {}
|
| 429 |
+
|
| 430 |
+
return gradio_handler
|