flash attn pip install (#426)
* flash attn pip
* add packaging
* add packaging to apt get
* install flash attn in dockerfile
* remove unused whls
* add wheel
* clean up pr
* fix packaging requirement for ci
* upgrade pip for ci
* skip build isolation for requirements to get flash-attn working
* install flash-attn separately
* install wheel for ci
* no flash-attn for basic cicd
* install flash-attn as pip extras
---------
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: mhenrichsen <[email protected]>
Co-authored-by: Mads Henrichsen <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
- .github/workflows/main.yml +6 -5
- README.md +1 -1
- docker/Dockerfile +2 -2
- docker/Dockerfile-base +1 -26
- requirements.txt +1 -0
- setup.py +4 -1
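
Net effect of the commit: flash-attn is no longer baked into the base image as locally built wheels; it becomes a pinned pip extra that the Docker images and end users opt into. A minimal sketch of the post-commit install flow, assuming a fresh checkout:

```bash
# flash-attn is opt-in now; the extra pins flash-attn==2.0.8.
git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip3 install -e '.[flash-attn]'
```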
    	
.github/workflows/main.yml

```diff
@@ -13,17 +13,17 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda:
+          - cuda: 118
             cuda_version: 11.8.0
             python_version: "3.9"
             pytorch: 2.0.1
             axolotl_extras:
-          - cuda:
+          - cuda: 118
             cuda_version: 11.8.0
             python_version: "3.10"
             pytorch: 2.0.1
             axolotl_extras:
-          - cuda:
+          - cuda: 118
             cuda_version: 11.8.0
             python_version: "3.9"
             pytorch: 2.0.1
@@ -49,10 +49,11 @@ jobs:
         with:
           context: .
           build-args: |
-            BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}
+            BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
+            CUDA=${{ matrix.cuda }}
           file: ./docker/Dockerfile
           push: ${{ github.event_name != 'pull_request' }}
-          tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}
+          tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
           labels: ${{ steps.metadata.outputs.labels }}
   build-axolotl-runpod:
     needs: build-axolotl
```
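
The tag expression uses GitHub Actions' `&&`/`||` chaining as a conditional hyphen, so the extras suffix only appears when `axolotl_extras` is non-empty. A rough bash equivalent of the composition, with hypothetical sample values not taken from any workflow run:

```bash
# ${extras:+-} mirrors the workflow's
# `${{ matrix.axolotl_extras != '' && '-' || '' }}` conditional hyphen.
ref="main"; py="3.10"; cuda="118"; pt="2.0.1"; extras="gptq"
echo "${ref}-py${py}-cu${cuda}-${pt}${extras:+-}${extras}"
# -> main-py3.10-cu118-2.0.1-gptq  (no trailing hyphen when extras is empty)
```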
    	
README.md

````diff
@@ -69,7 +69,7 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
 ```bash
 git clone https://github.com/OpenAccess-AI-Collective/axolotl
 
-pip3 install -e .
+pip3 install -e .[flash-attn]
 pip3 install -U git+https://github.com/huggingface/peft.git
 
 # finetune lora
````
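
One caveat for the quickstart line: bash passes `.[flash-attn]` through untouched, but zsh treats the square brackets as a glob pattern, so a quoted form is safer across shells:

```bash
# Shell-safe equivalent of the README command; the quotes stop zsh
# from trying to expand the brackets as a glob.
pip3 install -e '.[flash-attn]'
```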
    	
docker/Dockerfile

```diff
@@ -16,9 +16,9 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN cd axolotl && \
     if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[$AXOLOTL_EXTRAS]; \
+        pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
     else \
-        pip install -e .; \
+        pip install -e .[flash-attn]; \
     fi
 
 # fix so that git fetch/pull from remote works
```
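
Since `flash-attn` is now always included and `$AXOLOTL_EXTRAS` only appends additional extras, the image build can be driven entirely by build args. A sketch of a local build invocation; the tag and extras value are hypothetical, and other build args such as `BASE_TAG` and `CUDA` are omitted:

```bash
# AXOLOTL_EXTRAS lands in the bracket expression inside the image,
# e.g. pip install -e .[flash-attn,gptq]
docker build -f docker/Dockerfile \
  --build-arg AXOLOTL_EXTRAS=gptq \
  -t axolotl:local .
```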
    	
docker/Dockerfile-base

```diff
@@ -31,26 +31,6 @@ WORKDIR /workspace
 RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
     python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
 
-
-FROM base-builder AS flash-attn-builder
-
-WORKDIR /workspace
-
-ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
-
-RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
-    cd flash-attention && \
-    git checkout v2.0.4  && \
-    python3 setup.py bdist_wheel && \
-    cd csrc/fused_dense_lib && \
-    python3 setup.py bdist_wheel && \
-    cd ../xentropy && \
-    python3 setup.py bdist_wheel && \
-    cd ../rotary && \
-    python3 setup.py bdist_wheel && \
-    cd ../layer_norm && \
-    python3 setup.py bdist_wheel
-
 FROM base-builder AS deepspeed-builder
 
 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
@@ -90,13 +70,8 @@ RUN mkdir -p /workspace/wheels/bitsandbytes
 COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
 COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
 COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
-COPY --from=flash-attn-builder /workspace/flash-attention/dist/flash_attn-*.whl wheels
-COPY --from=flash-attn-builder /workspace/flash-attention/csrc/fused_dense_lib/dist/fused_dense_lib-*.whl wheels
-COPY --from=flash-attn-builder /workspace/flash-attention/csrc/xentropy/dist/xentropy_cuda_lib-*.whl wheels
-COPY --from=flash-attn-builder /workspace/flash-attention/csrc/rotary/dist/rotary_emb-*.whl wheels
-COPY --from=flash-attn-builder /workspace/flash-attention/csrc/layer_norm/dist/dropout_layer_norm-*.whl wheels
 
-RUN pip3 install wheels/deepspeed-*.whl …
+RUN pip3 install wheels/deepspeed-*.whl
 RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
 RUN git lfs install --skip-repo
 RUN pip3 install awscli && \
```
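
Dropping the `flash-attn-builder` stage means the base image no longer ships wheels for `flash_attn` or the optional csrc kernels (`fused_dense_lib`, `xentropy_cuda_lib`, `rotary_emb`, `dropout_layer_norm`); the package now comes from PyPI when the runtime image installs the extra. A sanity check one might run inside the final image (hypothetical, not part of the build):

```bash
# Confirm the PyPI package is what ended up installed.
pip3 show flash-attn | grep -E '^(Name|Version)'
# Expected, per the pin: Version: 2.0.8
```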
    	
requirements.txt

```diff
@@ -6,6 +6,7 @@ addict
 fire
 PyYAML==6.0
 datasets
+flash-attn==2.0.8
 sentencepiece
 wandb
 einops
```
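
Pinning `flash-attn==2.0.8` here keeps the requirements file self-contained for CI, but flash-attn compiles against the torch that is already installed, which is why the commit notes mention upgrading pip, installing `wheel`/`packaging` first, and skipping build isolation. A sketch of that CI-style sequence; the exact order of steps in the workflow may differ:

```bash
# Build prerequisites must exist before flash-attn's setup.py runs;
# --no-build-isolation lets it see the preinstalled torch.
pip3 install -U pip wheel packaging
pip3 install --no-build-isolation -r requirements.txt
```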
    	
setup.py

```diff
@@ -7,6 +7,7 @@ with open("./requirements.txt", encoding="utf-8") as requirements_file:
     # don't include peft yet until we check the int4
     # need to manually install peft for now...
     reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
+    reqs = [r for r in reqs if "flash-attn" not in r]
     reqs = [r for r in reqs if r and r[0] != "#"]
     for r in reqs:
         install_requires.append(r)
@@ -25,8 +26,10 @@ setup(
         "gptq_triton": [
             "alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
         ],
+        "flash-attn": [
+            "flash-attn==2.0.8",
+        ],
         "extras": [
-            "flash-attn",
             "deepspeed",
         ],
     },
```
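
The new filter strips the pinned `flash-attn` line out of `install_requires` (mirroring the existing peft exclusion), and the dedicated `flash-attn` extra re-exposes it on demand, so a plain editable install stays free of the CUDA build. The resulting install matrix looks roughly like this:

```bash
pip3 install -e .                         # core deps only; flash-attn filtered out
pip3 install -e '.[flash-attn]'           # adds the pinned flash-attn==2.0.8
pip3 install -e '.[flash-attn,extras]'    # flash-attn plus deepspeed
```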
