diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py index f9c32d43..6eb61702 100644 --- a/src/llamafactory/cli.py +++ b/src/llamafactory/cli.py @@ -73,7 +73,7 @@ def main(): "help": partial(print, USAGE), } - command = sys.argv.pop(1) if len(sys.argv) >= 1 else "help" + command = sys.argv.pop(1) if len(sys.argv) > 1 else "help" if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())): # launch distributed training nnodes = os.getenv("NNODES", "1")