mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-07-31 10:42:50 +08:00
[assets] fix npu docker (#8298)
This commit is contained in:
parent
83688b0b4d
commit
cecba57b3e
@ -27,7 +27,7 @@ WORKDIR /app
|
||||
# Change pip source
|
||||
RUN pip config set global.index-url "${PIP_INDEX}" && \
|
||||
pip config set global.extra-index-url "${PIP_INDEX}" && \
|
||||
python -m pip install --upgrade pip
|
||||
pip install --no-cache-dir --upgrade pip packaging wheel setuptools
|
||||
|
||||
# Install the requirements
|
||||
COPY requirements.txt /app
|
||||
|
@ -40,7 +40,7 @@ RUN apt-get update && \
|
||||
# Change pip source
|
||||
RUN pip config set global.index-url "${PIP_INDEX}" && \
|
||||
pip config set global.extra-index-url "${PIP_INDEX}" && \
|
||||
python -m pip install --upgrade pip
|
||||
pip install --no-cache-dir --upgrade pip packaging wheel setuptools
|
||||
|
||||
# Install flash-attn-2.7.4.post1 (cxx11abi=False)
|
||||
RUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl && \
|
||||
|
@ -26,7 +26,7 @@ WORKDIR /app
|
||||
# Change pip source
|
||||
RUN pip config set global.index-url "${PIP_INDEX}" && \
|
||||
pip config set global.extra-index-url "${PIP_INDEX}" && \
|
||||
python -m pip install --upgrade pip
|
||||
pip install --no-cache-dir --upgrade pip packaging wheel setuptools
|
||||
|
||||
# Install the requirements
|
||||
COPY requirements.txt /app
|
||||
|
@ -28,11 +28,11 @@ WORKDIR /app
|
||||
# Change pip source
|
||||
RUN pip config set global.index-url "${PIP_INDEX}" && \
|
||||
pip config set global.extra-index-url "${PIP_INDEX}" && \
|
||||
python -m pip install --upgrade pip
|
||||
pip install --no-cache-dir --upgrade pip packaging wheel setuptools
|
||||
|
||||
# Reinstall pytorch rocm
|
||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
||||
pip install --pre torch torchvision torchaudio --index-url "${PYTORCH_INDEX}"
|
||||
pip install --no-cache-dir --pre torch torchvision torchaudio --index-url "${PYTORCH_INDEX}"
|
||||
|
||||
# Install the requirements
|
||||
COPY requirements.txt /app
|
||||
|
@ -18,8 +18,6 @@ import sys
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
from .hparams import get_train_args
|
||||
|
||||
|
||||
USAGE = (
|
||||
"-" * 70
|
||||
@ -78,18 +76,20 @@ def main():
|
||||
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
|
||||
if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
|
||||
# launch distributed training
|
||||
max_restarts = os.getenv("MAX_RESTARTS", "0")
|
||||
rdzv_id = os.getenv("RDZV_ID")
|
||||
nnodes = os.getenv("NNODES", "1")
|
||||
min_nnodes = os.getenv("MIN_NNODES")
|
||||
max_nnodes = os.getenv("MAX_NNODES")
|
||||
node_rank = os.getenv("NODE_RANK", "0")
|
||||
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
|
||||
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
|
||||
master_port = os.getenv("MASTER_PORT", str(find_available_port()))
|
||||
logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
|
||||
if int(nnodes) > 1:
|
||||
print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
|
||||
logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
|
||||
|
||||
# elastic launch support
|
||||
max_restarts = os.getenv("MAX_RESTARTS", "0")
|
||||
rdzv_id = os.getenv("RDZV_ID")
|
||||
min_nnodes = os.getenv("MIN_NNODES")
|
||||
max_nnodes = os.getenv("MAX_NNODES")
|
||||
|
||||
env = deepcopy(os.environ)
|
||||
if is_env_enabled("OPTIM_TORCH", "1"):
|
||||
@ -104,24 +104,27 @@ def main():
|
||||
# elastic number of nodes if MIN_NNODES and MAX_NNODES are set
|
||||
if min_nnodes is not None and max_nnodes is not None:
|
||||
rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"
|
||||
cmd = [
|
||||
"torchrun",
|
||||
"--nnodes",
|
||||
rdzv_nnodes,
|
||||
"--nproc-per-node",
|
||||
nproc_per_node,
|
||||
"--rdzv-id",
|
||||
rdzv_id,
|
||||
"--rdzv-backend",
|
||||
"c10d",
|
||||
"--rdzv-endpoint",
|
||||
f"{master_addr}:{master_port}",
|
||||
"--max-restarts",
|
||||
max_restarts,
|
||||
launcher.__file__,
|
||||
*sys.argv[1:],
|
||||
]
|
||||
process = subprocess.run(cmd, env=env, check=True)
|
||||
|
||||
process = subprocess.run(
|
||||
(
|
||||
"torchrun --nnodes {rdzv_nnodes} --nproc-per-node {nproc_per_node} "
|
||||
"--rdzv-id {rdzv_id} --rdzv-backend c10d --rdzv-endpoint {master_addr}:{master_port} "
|
||||
"--max-restarts {max_restarts} {file_name} {args}"
|
||||
)
|
||||
.format(
|
||||
rdzv_nnodes=rdzv_nnodes,
|
||||
nproc_per_node=nproc_per_node,
|
||||
rdzv_id=rdzv_id,
|
||||
master_addr=master_addr,
|
||||
master_port=master_port,
|
||||
max_restarts=max_restarts,
|
||||
file_name=launcher.__file__,
|
||||
args=" ".join(sys.argv[1:]),
|
||||
)
|
||||
.split(),
|
||||
env=env,
|
||||
check=True,
|
||||
)
|
||||
else:
|
||||
# NOTE: DO NOT USE shell=True to avoid security risk
|
||||
process = subprocess.run(
|
||||
@ -142,6 +145,7 @@ def main():
|
||||
env=env,
|
||||
check=True,
|
||||
)
|
||||
|
||||
sys.exit(process.returncode)
|
||||
elif command in COMMAND_MAP:
|
||||
COMMAND_MAP[command]()
|
||||
|
Loading…
x
Reference in New Issue
Block a user