diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py
index 177618bb..33aa735b 100644
--- a/src/llamafactory/cli.py
+++ b/src/llamafactory/cli.py
@@ -88,18 +88,24 @@ def main():
     elif command == Command.TRAIN:
         force_torchrun = is_env_enabled("FORCE_TORCHRUN")
         if force_torchrun or (get_device_count() > 1 and not use_ray()):
+            nnodes = os.getenv("NNODES", "1")
+            node_rank = os.getenv("NODE_RANK", "0")
+            nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
             master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
             master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
-            logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
+            logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
+            if int(nnodes) > 1:
+                print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
+
             process = subprocess.run(
                 (
                     "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
                     "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
                 )
                 .format(
-                    nnodes=os.getenv("NNODES", "1"),
-                    node_rank=os.getenv("NODE_RANK", "0"),
-                    nproc_per_node=os.getenv("NPROC_PER_NODE", str(get_device_count())),
+                    nnodes=nnodes,
+                    node_rank=node_rank,
+                    nproc_per_node=nproc_per_node,
                     master_addr=master_addr,
                     master_port=master_port,
                     file_name=launcher.__file__,
@@ -119,7 +125,7 @@ def main():
     elif command == Command.HELP:
         print(USAGE)
     else:
-        raise NotImplementedError(f"Unknown command: {command}.")
+        print(f"Unknown command: {command}.\n{USAGE}")
 
 
 if __name__ == "__main__":
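
Note on the cli.py hunks above: the NNODES / NODE_RANK / NPROC_PER_NODE lookups are hoisted out of the format() call so the process count can be logged and multi-node runs announced before torchrun is spawned. A minimal standalone sketch of that environment-driven command construction (placeholder defaults and file names, not the project's code):

    import os

    nnodes = os.getenv("NNODES", "1")
    node_rank = os.getenv("NODE_RANK", "0")
    nproc_per_node = os.getenv("NPROC_PER_NODE", "8")  # the real code defaults to the detected device count
    master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
    master_port = os.getenv("MASTER_PORT", "29500")  # the real code picks a random port in 20001-29999

    command = (
        f"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
        f"--master_addr {master_addr} --master_port {master_port} launcher.py config.yaml"  # placeholder names
    )
    print(command)  # node 1 of 2 would export NNODES=2 NODE_RANK=1 and point MASTER_ADDR at node 0
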
diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 52e37a3b..933ab9e5 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -387,7 +387,7 @@ class SwanLabArguments:
 
 @dataclass
 class FinetuningArguments(
-    FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, ApolloArguments, BAdamArgument, SwanLabArguments
+    SwanLabArguments, BAdamArgument, ApolloArguments, GaloreArguments, RLHFArguments, LoraArguments, FreezeArguments
 ):
     r"""
     Arguments pertaining to which techniques we are going to fine-tuning with.
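
Note on the finetuning_args.py hunk above: the base classes are reversed because @dataclass collects inherited fields by walking the MRO in reverse, so the right-most base contributes its fields first, which is the order HfArgumentParser prints them in --help. A toy illustration with made-up classes (not the project's):

    from dataclasses import dataclass, fields

    @dataclass
    class FreezeLike:
        freeze_opt: int = 0

    @dataclass
    class SwanLabLike:
        swanlab_opt: int = 1

    @dataclass
    class Combined(SwanLabLike, FreezeLike):  # right-most base comes first in the field order
        extra_opt: int = 2

    print([f.name for f in fields(Combined)])  # ['freeze_opt', 'swanlab_opt', 'extra_opt']
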
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index fba55832..3ec60b7b 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -24,6 +24,162 @@ from transformers.training_args import _convert_str_dict
 from typing_extensions import Self
 
 
+@dataclass
+class BaseModelArguments:
+    r"""
+    Arguments pertaining to the model.
+    """
+
+    model_name_or_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
+        },
+    )
+    adapter_name_or_path: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Path to the adapter weight or identifier from huggingface.co/models. "
+                "Use commas to separate multiple adapters."
+            )
+        },
+    )
+    adapter_folder: Optional[str] = field(
+        default=None,
+        metadata={"help": "The folder containing the adapter weights to load."},
+    )
+    cache_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
+    )
+    use_fast_tokenizer: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
+    )
+    resize_vocab: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
+    )
+    split_special_tokens: bool = field(
+        default=False,
+        metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
+    )
+    new_special_tokens: Optional[str] = field(
+        default=None,
+        metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
+    )
+    model_revision: str = field(
+        default="main",
+        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+    )
+    low_cpu_mem_usage: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to use memory-efficient model loading."},
+    )
+    rope_scaling: Optional[Literal["linear", "dynamic", "yarn", "llama3"]] = field(
+        default=None,
+        metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
+    )
+    flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field(
+        default="auto",
+        metadata={"help": "Enable FlashAttention for faster training and inference."},
+    )
+    shift_attn: bool = field(
+        default=False,
+        metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
+    )
+    mixture_of_depths: Optional[Literal["convert", "load"]] = field(
+        default=None,
+        metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
+    )
+    use_unsloth: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
+    )
+    use_unsloth_gc: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use unsloth's gradient checkpointing (no need to install unsloth)."},
+    )
+    enable_liger_kernel: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to enable liger kernel for faster training."},
+    )
+    moe_aux_loss_coef: Optional[float] = field(
+        default=None,
+        metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
+    )
+    disable_gradient_checkpointing: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to disable gradient checkpointing."},
+    )
+    use_reentrant_gc: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to use reentrant gradient checkpointing."},
+    )
+    upcast_layernorm: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
+    )
+    upcast_lmhead_output: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
+    )
+    train_from_scratch: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to randomly initialize the model weights."},
+    )
+    infer_backend: Literal["huggingface", "vllm"] = field(
+        default="huggingface",
+        metadata={"help": "Backend engine used at inference."},
+    )
+    offload_folder: str = field(
+        default="offload",
+        metadata={"help": "Path to offload model weights."},
+    )
+    use_cache: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to use KV cache in generation."},
+    )
+    infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
+        default="auto",
+        metadata={"help": "Data type for model weights and activations at inference."},
+    )
+    hf_hub_token: Optional[str] = field(
+        default=None,
+        metadata={"help": "Auth token to log in with Hugging Face Hub."},
+    )
+    ms_hub_token: Optional[str] = field(
+        default=None,
+        metadata={"help": "Auth token to log in with ModelScope Hub."},
+    )
+    om_hub_token: Optional[str] = field(
+        default=None,
+        metadata={"help": "Auth token to log in with Modelers Hub."},
+    )
+    print_param_status: bool = field(
+        default=False,
+        metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
+    )
+    trust_remote_code: bool = field(
+        default=False,
+        metadata={"help": "Whether to trust the execution of code from datasets/models defined on the Hub or not."},
+    )
+
+    def __post_init__(self):
+        if self.model_name_or_path is None:
+            raise ValueError("Please provide `model_name_or_path`.")
+
+        if self.split_special_tokens and self.use_fast_tokenizer:
+            raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
+
+        if self.adapter_name_or_path is not None:  # support merging multiple lora weights
+            self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
+
+        if self.new_special_tokens is not None:  # support multiple special tokens
+            self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
+
+
 @dataclass
 class QuantizationArguments:
     r"""
@@ -127,6 +283,10 @@ class ExportArguments:
         metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
     )
 
+    def __post_init__(self):
+        if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
+            raise ValueError("Quantization dataset is necessary for exporting.")
+
 
 @dataclass
 class VllmArguments:
@@ -155,148 +315,19 @@
         metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
     )
 
+    def __post_init__(self):
+        if isinstance(self.vllm_config, str) and self.vllm_config.startswith("{"):
+            self.vllm_config = _convert_str_dict(json.loads(self.vllm_config))
+
 
 @dataclass
-class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments, VllmArguments):
+class ModelArguments(VllmArguments, ExportArguments, ProcessorArguments, QuantizationArguments, BaseModelArguments):
     r"""
     Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer.
+
+    The class on the most right will be displayed first.
     """
 
-    model_name_or_path: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
-        },
-    )
-    adapter_name_or_path: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": (
-                "Path to the adapter weight or identifier from huggingface.co/models. "
-                "Use commas to separate multiple adapters."
-            )
-        },
-    )
-    adapter_folder: Optional[str] = field(
-        default=None,
-        metadata={"help": "The folder containing the adapter weights to load."},
-    )
-    cache_dir: Optional[str] = field(
-        default=None,
-        metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
-    )
-    use_fast_tokenizer: bool = field(
-        default=True,
-        metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."},
-    )
-    resize_vocab: bool = field(
-        default=False,
-        metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."},
-    )
-    split_special_tokens: bool = field(
-        default=False,
-        metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
-    )
-    new_special_tokens: Optional[str] = field(
-        default=None,
-        metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
-    )
-    model_revision: str = field(
-        default="main",
-        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
-    )
-    low_cpu_mem_usage: bool = field(
-        default=True,
-        metadata={"help": "Whether or not to use memory-efficient model loading."},
-    )
-    rope_scaling: Optional[Literal["linear", "dynamic", "yarn", "llama3"]] = field(
-        default=None,
-        metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
-    )
-    flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field(
-        default="auto",
-        metadata={"help": "Enable FlashAttention for faster training and inference."},
-    )
-    shift_attn: bool = field(
-        default=False,
-        metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
-    )
-    mixture_of_depths: Optional[Literal["convert", "load"]] = field(
-        default=None,
-        metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
-    )
-    use_unsloth: bool = field(
-        default=False,
-        metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."},
-    )
-    use_unsloth_gc: bool = field(
-        default=False,
-        metadata={"help": "Whether or not to use unsloth's gradient checkpointing."},
-    )
-    enable_liger_kernel: bool = field(
-        default=False,
-        metadata={"help": "Whether or not to enable liger kernel for faster training."},
-    )
-    moe_aux_loss_coef: Optional[float] = field(
-        default=None,
-        metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
-    )
-    disable_gradient_checkpointing: bool = field(
-        default=False,
-        metadata={"help": "Whether or not to disable gradient checkpointing."},
-    )
-    use_reentrant_gc: bool = field(
-        default=True,
-        metadata={"help": "Whether or not to use reentrant gradient checkpointing."},
-    )
-    upcast_layernorm: bool = field(
-        default=False,
-        metadata={"help": "Whether or not to upcast the layernorm weights in fp32."},
-    )
-    upcast_lmhead_output: bool = field(
-        default=False,
-        metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
-    )
-    train_from_scratch: bool = field(
-        default=False,
-        metadata={"help": "Whether or not to randomly initialize the model weights."},
-    )
-    infer_backend: Literal["huggingface", "vllm"] = field(
-        default="huggingface",
-        metadata={"help": "Backend engine used at inference."},
-    )
-    offload_folder: str = field(
-        default="offload",
-        metadata={"help": "Path to offload model weights."},
-    )
-    use_cache: bool = field(
-        default=True,
-        metadata={"help": "Whether or not to use KV cache in generation."},
-    )
-    infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
-        default="auto",
-        metadata={"help": "Data type for model weights and activations at inference."},
-    )
-    hf_hub_token: Optional[str] = field(
-        default=None,
-        metadata={"help": "Auth token to log in with Hugging Face Hub."},
-    )
-    ms_hub_token: Optional[str] = field(
-        default=None,
-        metadata={"help": "Auth token to log in with ModelScope Hub."},
-    )
-    om_hub_token: Optional[str] = field(
-        default=None,
-        metadata={"help": "Auth token to log in with Modelers Hub."},
-    )
-    print_param_status: bool = field(
-        default=False,
-        metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
-    )
-    trust_remote_code: bool = field(
-        default=False,
-        metadata={"help": "Whether to trust the execution of code from datasets/models defined on the Hub or not."},
-    )
     compute_dtype: Optional[torch.dtype] = field(
         default=None,
         init=False,
@@ -319,23 +350,9 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
     )
 
     def __post_init__(self):
-        if self.model_name_or_path is None:
-            raise ValueError("Please provide `model_name_or_path`.")
-
-        if self.split_special_tokens and self.use_fast_tokenizer:
-            raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
-
-        if self.adapter_name_or_path is not None:  # support merging multiple lora weights
-            self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
-
-        if self.new_special_tokens is not None:  # support multiple special tokens
-            self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
-
-        if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
-            raise ValueError("Quantization dataset is necessary for exporting.")
-
-        if isinstance(self.vllm_config, str) and self.vllm_config.startswith("{"):
-            self.vllm_config = _convert_str_dict(json.loads(self.vllm_config))
+        BaseModelArguments.__post_init__(self)
+        ExportArguments.__post_init__(self)
+        VllmArguments.__post_init__(self)
 
     @classmethod
     def copyfrom(cls, source: "Self", **kwargs) -> "Self":
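
Note on the model_args.py hunks above: once the validation moves into the per-mixin __post_init__ methods, Python only runs the most derived __post_init__ automatically, so ModelArguments has to chain the parents by hand. A toy sketch of the pattern with hypothetical classes (not the project's code):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class BaseLike:
        name: Optional[str] = None

        def __post_init__(self):
            if self.name is None:
                raise ValueError("Please provide `name`.")

    @dataclass
    class ExportLike:
        export_bit: Optional[int] = None
        export_dataset: Optional[str] = None

        def __post_init__(self):
            if self.export_bit is not None and self.export_dataset is None:
                raise ValueError("An export dataset is required when export_bit is set.")

    @dataclass
    class CombinedLike(ExportLike, BaseLike):
        def __post_init__(self):
            # only this method is called by the generated __init__; chain the parents explicitly
            BaseLike.__post_init__(self)
            ExportLike.__post_init__(self)

    CombinedLike(name="demo")  # passes both parent checks
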
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index b40148fd..0e22868c 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -382,10 +382,10 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
 
     # Log on each process the small summary
     logger.info(
-        "Process rank: {}, device: {}, n_gpu: {}, distributed training: {}, compute dtype: {}".format(
-            training_args.local_rank,
+        "Process rank: {}, world size: {}, device: {}, distributed training: {}, compute dtype: {}".format(
+            training_args.process_index,
+            training_args.world_size,
             training_args.device,
-            training_args.n_gpu,
             training_args.parallel_mode == ParallelMode.DISTRIBUTED,
             str(model_args.compute_dtype),
         )
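
Note on the parser.py hunk above: the summary now logs the global process index and the world size instead of local_rank and n_gpu, which distinguishes processes across nodes. Under torchrun these values come from standard environment variables; a quick illustration with assumed values (not the project's code):

    import os

    # torchrun exports RANK (global rank), LOCAL_RANK (rank within a node) and WORLD_SIZE.
    # In a 2-node x 4-GPU job: WORLD_SIZE=8, RANK in 0..7, LOCAL_RANK in 0..3 on each node.
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    global_rank = int(os.getenv("RANK", "0"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    print(f"Process rank: {global_rank}, world size: {world_size}, local rank: {local_rank}")
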
generation."}, - ) - infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field( - default="auto", - metadata={"help": "Data type for model weights and activations at inference."}, - ) - hf_hub_token: Optional[str] = field( - default=None, - metadata={"help": "Auth token to log in with Hugging Face Hub."}, - ) - ms_hub_token: Optional[str] = field( - default=None, - metadata={"help": "Auth token to log in with ModelScope Hub."}, - ) - om_hub_token: Optional[str] = field( - default=None, - metadata={"help": "Auth token to log in with Modelers Hub."}, - ) - print_param_status: bool = field( - default=False, - metadata={"help": "For debugging purposes, print the status of the parameters in the model."}, - ) - trust_remote_code: bool = field( - default=False, - metadata={"help": "Whether to trust the execution of code from datasets/models defined on the Hub or not."}, - ) compute_dtype: Optional[torch.dtype] = field( default=None, init=False, @@ -319,23 +350,9 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments, ) def __post_init__(self): - if self.model_name_or_path is None: - raise ValueError("Please provide `model_name_or_path`.") - - if self.split_special_tokens and self.use_fast_tokenizer: - raise ValueError("`split_special_tokens` is only supported for slow tokenizers.") - - if self.adapter_name_or_path is not None: # support merging multiple lora weights - self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")] - - if self.new_special_tokens is not None: # support multiple special tokens - self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")] - - if self.export_quantization_bit is not None and self.export_quantization_dataset is None: - raise ValueError("Quantization dataset is necessary for exporting.") - - if isinstance(self.vllm_config, str) and self.vllm_config.startswith("{"): - self.vllm_config = _convert_str_dict(json.loads(self.vllm_config)) + BaseModelArguments.__post_init__(self) + ExportArguments.__post_init__(self) + VllmArguments.__post_init__(self) @classmethod def copyfrom(cls, source: "Self", **kwargs) -> "Self": diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index b40148fd..0e22868c 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -382,10 +382,10 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _ # Log on each process the small summary logger.info( - "Process rank: {}, device: {}, n_gpu: {}, distributed training: {}, compute dtype: {}".format( - training_args.local_rank, + "Process rank: {}, world size: {}, device: {}, distributed training: {}, compute dtype: {}".format( + training_args.process_index, + training_args.world_size, training_args.device, - training_args.n_gpu, training_args.parallel_mode == ParallelMode.DISTRIBUTED, str(model_args.compute_dtype), ) diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 767d0cda..8ea961a9 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -86,6 +86,9 @@ def _training_function(config: Dict[str, Any]) -> None: def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None: args = read_args(args) + if "-h" in args or "--help" in args: + get_train_args(args) + ray_args = get_ray_args(args) callbacks = callbacks or [] if ray_args.use_ray: