Julian Bilcke commited on
Commit
d178f60
·
1 Parent(s): a06f683

testing a runtime build for apex

Browse files
Files changed (2) hide show
  1. Dockerfile +1 -8
  2. run_hf_space.py +49 -0
Dockerfile CHANGED
@@ -39,14 +39,7 @@ ENV PYTHONPATH=$HOME/app \
39
  RUN echo "Installing requirements.txt"
40
  RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
41
 
42
- # Install NVIDIA Apex with CUDA and C++ extensions
43
- RUN cd $HOME && \
44
- git clone https://github.com/NVIDIA/apex && \
45
- cd apex && \
46
- # sadly this command works at compile time and can be installed, but at runtime it doesn't work (No module named 'fused_layer_norm_cuda')
47
- #NVCC_APPEND_FLAGS="--threads 4" pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--parallel" --global-option="8" ./
48
- # so let's try a python only build:
49
- pip install -v --disable-pip-version-check --no-build-isolation --no-cache-dir ./
50
 
51
  WORKDIR $HOME/app
52
 
 
39
  RUN echo "Installing requirements.txt"
40
  RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
41
 
42
+ # Apex will be installed at runtime in run_hf_space.py to ensure CUDA compatibility
 
 
 
 
 
 
 
43
 
44
  WORKDIR $HOME/app
45
 
run_hf_space.py CHANGED
@@ -21,6 +21,52 @@ logging.basicConfig(
21
  )
22
  logger = logging.getLogger(__name__)
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  async def run_async():
25
  """Run the server using the async API directly for better stability"""
26
  # Get the port from environment variable in Hugging Face Space
@@ -89,6 +135,9 @@ def main():
89
  subprocess.run(cmd, shell=True)
90
 
91
  if __name__ == "__main__":
 
 
 
92
  # First try to run using the async API
93
  try:
94
  logger.info("Starting server using async API")
 
21
  )
22
  logger = logging.getLogger(__name__)
23
 
24
+ def install_apex():
25
+ """Install NVIDIA Apex at runtime with CUDA support"""
26
+ try:
27
+ logger.info("Installing NVIDIA Apex...")
28
+
29
+ # Clone the Apex repository
30
+ subprocess.check_call([
31
+ "git", "clone", "https://github.com/NVIDIA/apex"
32
+ ])
33
+
34
+ # Change to apex directory and install
35
+ os.chdir("apex")
36
+
37
+ # Try to install with CUDA extensions first
38
+ try:
39
+ logger.info("Attempting to install Apex with CUDA extensions...")
40
+ subprocess.check_call([
41
+ sys.executable, "-m", "pip", "install", "-v",
42
+ "--disable-pip-version-check", "--no-cache-dir",
43
+ "--no-build-isolation", "--global-option=--cpp_ext",
44
+ "--global-option=--cuda_ext", "./"
45
+ ])
46
+ logger.info("Apex installed successfully with CUDA extensions!")
47
+ except subprocess.CalledProcessError as e:
48
+ logger.warning(f"Failed to install Apex with CUDA extensions: {e}")
49
+ logger.info("Falling back to Python-only build...")
50
+
51
+ # Fall back to Python-only build
52
+ subprocess.check_call([
53
+ sys.executable, "-m", "pip", "install", "-v",
54
+ "--disable-pip-version-check", "--no-build-isolation",
55
+ "--no-cache-dir", "./"
56
+ ])
57
+ logger.info("Apex installed successfully (Python-only build)!")
58
+
59
+ except subprocess.CalledProcessError as e:
60
+ logger.error(f"Failed to install Apex. Error: {e}")
61
+ # Don't fail the entire startup if Apex installation fails
62
+ logger.warning("Continuing without Apex...")
63
+ except Exception as e:
64
+ logger.error(f"Unexpected error during Apex installation: {e}")
65
+ logger.warning("Continuing without Apex...")
66
+ finally:
67
+ # Change back to original directory
68
+ os.chdir("..")
69
+
70
  async def run_async():
71
  """Run the server using the async API directly for better stability"""
72
  # Get the port from environment variable in Hugging Face Space
 
135
  subprocess.run(cmd, shell=True)
136
 
137
  if __name__ == "__main__":
138
+ # Install Apex at runtime before starting the server
139
+ install_apex()
140
+
141
  # First try to run using the async API
142
  try:
143
  logger.info("Starting server using async API")