From b5d667cebf8d0e2479fae8ad2ca8cda112900a6b Mon Sep 17 00:00:00 2001 From: Joe Schoonover <11430768+fluidnumerics-joe@users.noreply.github.com> Date: Tue, 15 Apr 2025 01:36:39 -0400 Subject: [PATCH] [docker] patch docker-rocm (#7725) * Update Dockerfile * Fix typo * Fix syntax for /bin/sh conditional * Add build args to docker-compose * Change shell to /bin/bash This is required for "==" syntax in conditional string comparison --- docker/docker-rocm/Dockerfile | 12 ++++++++++++ docker/docker-rocm/docker-compose.yml | 2 ++ 2 files changed, 14 insertions(+) diff --git a/docker/docker-rocm/Dockerfile b/docker/docker-rocm/Dockerfile index 61eb68e5..9595bafa 100644 --- a/docker/docker-rocm/Dockerfile +++ b/docker/docker-rocm/Dockerfile @@ -12,8 +12,13 @@ ARG INSTALL_DEEPSPEED=false ARG INSTALL_FLASHATTN=false ARG INSTALL_LIGER_KERNEL=false ARG INSTALL_HQQ=false +ARG INSTALL_PYTORCH=true ARG PIP_INDEX=https://pypi.org/simple ARG HTTP_PROXY= +ARG PYTORCH_INDEX=https://download.pytorch.org/whl/nightly/rocm6.3 + +# Use Bash instead of default /bin/sh +SHELL ["/bin/bash", "-c"] # Set the working directory WORKDIR /app @@ -62,6 +67,13 @@ RUN EXTRA_PACKAGES="metrics"; \ pip install -e ".[$EXTRA_PACKAGES]"; \ fi +# Reinstall pytorch +# This is necessary to ensure that the correct version of PyTorch is installed +RUN if [ "$INSTALL_PYTORCH" == "true" ]; then \ + pip uninstall -y torch torchvision torchaudio && \ + pip install --pre torch torchvision torchaudio --index-url "$PYTORCH_INDEX"; \ + fi + # Rebuild flash attention RUN pip uninstall -y transformer-engine flash-attn && \ if [ "$INSTALL_FLASHATTN" == "true" ]; then \ diff --git a/docker/docker-rocm/docker-compose.yml b/docker/docker-rocm/docker-compose.yml index 4233dbff..caaf4e16 100644 --- a/docker/docker-rocm/docker-compose.yml +++ b/docker/docker-rocm/docker-compose.yml @@ -9,8 +9,10 @@ services: INSTALL_DEEPSPEED: "false" INSTALL_FLASHATTN: "false" INSTALL_LIGER_KERNEL: "false" + INSTALL_PYTORCH: "true" INSTALL_HQQ: "false" PIP_INDEX: https://pypi.org/simple + PYTORCH_INDEX: https://download.pytorch.org/whl/nightly/rocm6.3 container_name: llamafactory volumes: - ../../hf_cache:/root/.cache/huggingface