diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py index 731d99e4..a2ad43c4 100644 --- a/src/llamafactory/cli.py +++ b/src/llamafactory/cli.py @@ -95,7 +95,8 @@ def main(): ( "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} " "--master_addr {master_addr} --master_port {master_port} {file_name} {args}" - ).format( + ) + .format( nnodes=os.getenv("NNODES", "1"), node_rank=os.getenv("NODE_RANK", "0"), nproc_per_node=os.getenv("NPROC_PER_NODE", str(get_device_count())), @@ -103,8 +104,8 @@ def main(): master_port=master_port, file_name=launcher.__file__, args=" ".join(sys.argv[1:]), - ), - shell=True, + ) + .split() ) sys.exit(process.returncode) else: