[assets] fix npu docker (#8298)

parent 1a33d65a56
commit ed70f8d5a2
@@ -27,7 +27,7 @@ WORKDIR /app
 # Change pip source
 RUN pip config set global.index-url "${PIP_INDEX}" && \
     pip config set global.extra-index-url "${PIP_INDEX}" && \
-    python -m pip install --upgrade pip
+    pip install --no-cache-dir --upgrade pip packaging wheel setuptools
 
 # Install the requirements
 COPY requirements.txt /app
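Note: the pip hunks below repeat this same change across the project's Dockerfiles. Besides upgrading pip itself, the new line preinstalls packaging, wheel, and setuptools (needed when a dependency must be built from source, as is common on NPU where prebuilt wheels may be unavailable) and passes --no-cache-dir so downloaded wheels are not baked into the image layer. A minimal illustration, not part of the repo, that checks those build prerequisites are importable:

    # check_build_backends.py -- hypothetical helper, not in the repository
    import importlib

    for mod in ("packaging", "wheel", "setuptools"):
        try:
            importlib.import_module(mod)
            print(f"{mod}: available")
        except ImportError:
            print(f"{mod}: missing -- building sdists may fail")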
@@ -40,7 +40,7 @@ RUN apt-get update && \
 # Change pip source
 RUN pip config set global.index-url "${PIP_INDEX}" && \
     pip config set global.extra-index-url "${PIP_INDEX}" && \
-    python -m pip install --upgrade pip
+    pip install --no-cache-dir --upgrade pip packaging wheel setuptools
 
 # Install flash-attn-2.7.4.post1 (cxx11abi=False)
 RUN wget -nv https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl && \
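The wheel filename above pins the binary ABI: CUDA 12, torch 2.6, cxx11abi=False, CPython 3.11 (cp311), linux x86_64. A hedged sketch, not part of the repo, that extracts the CPython tag from such a filename and compares it to the running interpreter before installing:

    # wheel_tag_check.py -- illustrative only
    import re
    import sys

    WHEEL = "flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl"

    match = re.search(r"-cp(\d)(\d+)-", WHEEL)
    if match:
        major, minor = int(match.group(1)), int(match.group(2))
        print(f"wheel targets CPython {major}.{minor}; "
              f"interpreter matches: {sys.version_info[:2] == (major, minor)}")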
@@ -26,7 +26,7 @@ WORKDIR /app
 # Change pip source
 RUN pip config set global.index-url "${PIP_INDEX}" && \
     pip config set global.extra-index-url "${PIP_INDEX}" && \
-    python -m pip install --upgrade pip
+    pip install --no-cache-dir --upgrade pip packaging wheel setuptools
 
 # Install the requirements
 COPY requirements.txt /app
@@ -28,11 +28,11 @@ WORKDIR /app
 # Change pip source
 RUN pip config set global.index-url "${PIP_INDEX}" && \
     pip config set global.extra-index-url "${PIP_INDEX}" && \
-    python -m pip install --upgrade pip
+    pip install --no-cache-dir --upgrade pip packaging wheel setuptools
 
 # Reinstall pytorch rocm
 RUN pip uninstall -y torch torchvision torchaudio && \
-    pip install --pre torch torchvision torchaudio --index-url "${PYTORCH_INDEX}"
+    pip install --no-cache-dir --pre torch torchvision torchaudio --index-url "${PYTORCH_INDEX}"
 
 # Install the requirements
 COPY requirements.txt /app
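The ROCm hunk makes the same pip change and additionally adds --no-cache-dir to the nightly PyTorch reinstall pulled from ${PYTORCH_INDEX}. A quick hedged check, not part of the repo and assuming torch is installed, that the reinstalled build is actually the ROCm one: torch.version.hip is a version string on ROCm builds and None on CUDA/CPU builds.

    # verify_rocm_torch.py -- illustrative only; assumes torch is installed
    import torch

    print("torch:", torch.__version__)
    print("HIP runtime:", torch.version.hip)  # None unless this is a ROCm build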
@@ -18,8 +18,6 @@ import sys
 from copy import deepcopy
 from functools import partial
 
-from .hparams import get_train_args
-
 
 USAGE = (
     "-" * 70
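The removed lines drop an import of get_train_args that this module no longer uses. Dead imports like this are what pyflakes/ruff report as F401; assuming ruff is available (it is a dev tool, not a repo dependency) and that the file lives at src/llamafactory/cli.py, the check can be run directly:

    # run from the repository root; ruff and the path are assumptions
    import subprocess

    subprocess.run(["ruff", "check", "--select", "F401", "src/llamafactory/cli.py"])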
@@ -78,18 +76,20 @@ def main():
     command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
     if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
         # launch distributed training
-        max_restarts = os.getenv("MAX_RESTARTS", "0")
-        rdzv_id = os.getenv("RDZV_ID")
         nnodes = os.getenv("NNODES", "1")
-        min_nnodes = os.getenv("MIN_NNODES")
-        max_nnodes = os.getenv("MAX_NNODES")
         node_rank = os.getenv("NODE_RANK", "0")
         nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
         master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
         master_port = os.getenv("MASTER_PORT", str(find_available_port()))
         logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
         if int(nnodes) > 1:
-            print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
+            logger.info_rank0(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
 
+        # elastic launch support
+        max_restarts = os.getenv("MAX_RESTARTS", "0")
+        rdzv_id = os.getenv("RDZV_ID")
+        min_nnodes = os.getenv("MIN_NNODES")
+        max_nnodes = os.getenv("MAX_NNODES")
+
         env = deepcopy(os.environ)
         if is_env_enabled("OPTIM_TORCH", "1"):
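The hunk regroups the launch configuration: the multi-node banner now goes through logger.info_rank0 instead of a bare print, and the elastic-launch variables (MAX_RESTARTS, RDZV_ID, MIN_NNODES, MAX_NNODES) are read together under their own comment. A minimal sketch of the nnodes resolution that the next hunk relies on, restating the diff's own logic outside the repo:

    # resolve_nnodes.py -- illustrative restatement of the diff's logic
    import os

    def resolve_nnodes() -> str:
        """Return torchrun's --nnodes value: an elastic MIN:MAX range if both
        bounds are set, otherwise the fixed NNODES count (default 1)."""
        min_nnodes = os.getenv("MIN_NNODES")
        max_nnodes = os.getenv("MAX_NNODES")
        if min_nnodes is not None and max_nnodes is not None:
            return f"{min_nnodes}:{max_nnodes}"
        return os.getenv("NNODES", "1")

    print(resolve_nnodes())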
@@ -104,24 +104,27 @@ def main():
             # elastic number of nodes if MIN_NNODES and MAX_NNODES are set
             if min_nnodes is not None and max_nnodes is not None:
                 rdzv_nnodes = f"{min_nnodes}:{max_nnodes}"
-            cmd = [
-                "torchrun",
-                "--nnodes",
-                rdzv_nnodes,
-                "--nproc-per-node",
-                nproc_per_node,
-                "--rdzv-id",
-                rdzv_id,
-                "--rdzv-backend",
-                "c10d",
-                "--rdzv-endpoint",
-                f"{master_addr}:{master_port}",
-                "--max-restarts",
-                max_restarts,
-                launcher.__file__,
-                *sys.argv[1:],
-            ]
-            process = subprocess.run(cmd, env=env, check=True)
+
+            process = subprocess.run(
+                (
+                    "torchrun --nnodes {rdzv_nnodes} --nproc-per-node {nproc_per_node} "
+                    "--rdzv-id {rdzv_id} --rdzv-backend c10d --rdzv-endpoint {master_addr}:{master_port} "
+                    "--max-restarts {max_restarts} {file_name} {args}"
+                )
+                .format(
+                    rdzv_nnodes=rdzv_nnodes,
+                    nproc_per_node=nproc_per_node,
+                    rdzv_id=rdzv_id,
+                    master_addr=master_addr,
+                    master_port=master_port,
+                    max_restarts=max_restarts,
+                    file_name=launcher.__file__,
+                    args=" ".join(sys.argv[1:]),
+                )
+                .split(),
+                env=env,
+                check=True,
+            )
         else:
             # NOTE: DO NOT USE shell=True to avoid security risk
             process = subprocess.run(
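The list-built torchrun command is replaced by a format string that is whitespace-split back into an argument list. One behavioral consequence worth noting: because " ".join(sys.argv[1:]) is re-split on whitespace, any CLI argument that itself contains a space is broken apart, which the previous list form preserved. A two-line illustration with a hypothetical argument:

    args = ["--output_dir", "my run"]  # hypothetical argument containing a space
    print("torchrun {}".format(" ".join(args)).split())
    # ['torchrun', '--output_dir', 'my', 'run'] -- 'my run' was split in two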
@@ -142,6 +145,7 @@ def main():
                 env=env,
                 check=True,
             )
+
         sys.exit(process.returncode)
     elif command in COMMAND_MAP:
         COMMAND_MAP[command]()
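The added blank line separates the launch branch from the exit handling: sys.exit(process.returncode) forwards the child's status to the caller. Since both subprocess.run calls pass check=True, a failed launch raises CalledProcessError before this line is reached, so the explicit exit mainly covers the success path. A minimal sketch of that propagation, outside the repo:

    # exit_propagation.py -- illustrative only
    import subprocess
    import sys

    proc = subprocess.run([sys.executable, "-c", "print('child ok')"], check=True)
    sys.exit(proc.returncode)  # 0 on success; check=True raised on failure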