FROM nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04

# Build-time only: stops apt from prompting (e.g. tzdata) without leaking
# DEBIAN_FRONTEND into the runtime environment of the container.
ARG DEBIAN_FRONTEND=noninteractive

# 1. System packages (one per line, sorted). build-essential, python3-dev and
#    ninja-build are only needed if flash-attn has to compile from source.
RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        build-essential \
        git \
        ninja-build \
        python3 \
        python3-dev \
        python3-pip \
    && rm -rf /var/lib/apt/lists/*

# 2. Upgrade pip only. `python3 -m pip` avoids depending on which pip shim is
#    first on PATH. setuptools is intentionally left at the distro version:
#    torch 2.1's cpp_extension (used by the flash-attn build) is known to break
#    with setuptools>=70.
RUN python3 -m pip install --no-cache-dir --upgrade pip

# 3. PyTorch built against CUDA 12.1. /whl/cu121 is a package index (not a flat
#    page of wheel links), so it is passed as --index-url, as in PyTorch's own
#    install instructions, rather than as find-links (-f).
RUN python3 -m pip install --no-cache-dir \
        torch==2.1.0+cu121 --index-url https://download.pytorch.org/whl/cu121

# 4. flash-attn. Its setup.py imports torch at build time, so it must be built
#    with --no-build-isolation (an isolated build env has no torch and fails);
#    packaging and ninja are its documented build prerequisites.
#    NOTE(review): 2.5.8 is pinned as a release expected to ship prebuilt
#    wheels for torch 2.1 / CUDA 12 — confirm. If it falls back to compiling,
#    MAX_JOBS caps parallel nvcc jobs to keep RAM usage bounded.
RUN python3 -m pip install --no-cache-dir packaging ninja \
    && MAX_JOBS=4 python3 -m pip install --no-cache-dir --no-build-isolation \
        flash-attn==2.5.8

# 5. Application libraries. numpy<2 because torch 2.1 wheels were built
#    against NumPy 1.x. NOTE(review): transformers/gradio are unpinned — newer
#    releases may require a newer torch; pin them (ideally in a requirements.txt
#    copied before the app code) for reproducible builds.
RUN python3 -m pip install --no-cache-dir "numpy<2" transformers gradio

# 6. Unprivileged runtime user with a home directory (for ~/.cache, e.g. model
#    downloads), and an app directory it owns so the app can write next to
#    its code.
RUN useradd --create-home --uid 1000 --shell /bin/bash app \
    && install -d -o app -g app /app

# 7. Application code, copied last so code edits don't invalidate the
#    dependency layers above. Keep a .dockerignore (.git, caches, .env, model
#    weights) to keep the context small and secrets out of the image.
WORKDIR /app
COPY --chown=app:app . /app
USER app

# 8. Runtime configuration.
#    - PYTHONUNBUFFERED: stream stdout/stderr to `docker logs` immediately.
#    - GRADIO_SERVER_NAME: Gradio binds to 127.0.0.1 by default, which is not
#      reachable through a published port; an explicit server_name= passed to
#      launch() in app.py still overrides this.
ENV PYTHONUNBUFFERED=1 \
    GRADIO_SERVER_NAME=0.0.0.0 \
    GRADIO_SERVER_PORT=7860

# Gradio default port. EXPOSE only documents it; publish with `-p 7860:7860`.
EXPOSE 7860

# 9. Exec form so python3 runs as PID 1 and receives SIGTERM on `docker stop`.
CMD ["python3", "app.py"]