From 5f48dd545e5ab6e6222dc3d7c0fe6e4dce0c5adb Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Tue, 16 Apr 2024 17:26:30 +0800
Subject: [PATCH] Update finetuning_args.py

Former-commit-id: ec899cccf3b8710510e496a3cd8e4c302bb99a19
---
 src/llmtuner/hparams/finetuning_args.py | 90 +++++++++++++------------
 1 file changed, 48 insertions(+), 42 deletions(-)

diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py
index d64f1583..899c7284 100644
--- a/src/llmtuner/hparams/finetuning_args.py
+++ b/src/llmtuner/hparams/finetuning_args.py
@@ -163,47 +163,6 @@ class RLHFArguments:
         metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
     )
 
-@dataclass
-class BAdamArgument:
-    r"""
-    Arguments for BAdam optimizer.
-    """
-    use_badam: bool = field(
-        default=False,
-        metadata={"help": "Whether or not to use BAdam optimizer."},
-    )
-    badam_mode: Literal["layer", "ratio"] = field(
-        default="layer",
-        metadata={"help": "The mode of BAdam optimizer. 'layer' for layer-wise, 'ratio' for ratio-wise."},
-    )
-
-    # ======== Arguments for layer-wise update ========
-    start_block: Optional[int] = field(
-        default=None,
-        metadata={"help": "The starting block index for block-wise fine-tuning."}
-    )
-    switch_block_every: Optional[int] = field(
-        default=50,
-        metadata={"help": "how often to switch model's block update. Set to -1 to disable the block update."}
-    )
-    switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
-        default="ascending",
-        metadata={"help": "the strategy of picking block to update."}
-    )
-
-    # ======== Arguments for ratio-wise update ========
-    badam_update_ratio: float = field(
-        default=0.,
-        metadata={"help": "The ratio of the update for the BAdam optimizer."}
-    )
-    badam_mask_mode: Literal["adjacent", "scatter"] = field(
-        default="adjacent",
-        metadata={"help": "The mode of the mask for BAdam optimizer. `adjacent` means that the trainable parameters are adjacent to each other; `scatter` means that trainable parameters are randomly choosed from the weight."}
-    )
-    badam_verbose: int = field(
-        default=0,
-        metadata={"help": "The verbosity level of BAdam optimizer. 0 for no print, 1 for print the block prefix, 2 for print trainable parameters"}
-    )
 
 @dataclass
 class GaloreArguments:
@@ -213,7 +172,7 @@ class GaloreArguments:
 
     use_galore: bool = field(
         default=False,
-        metadata={"help": "Whether or not to use gradient low-Rank projection."},
+        metadata={"help": "Whether or not to use the gradient low-Rank projection (GaLore)."},
     )
     galore_target: str = field(
         default="all",
@@ -244,6 +203,53 @@ class GaloreArguments:
     )
 
 
+@dataclass
+class BAdamArgument:
+    r"""
+    Arguments pertaining to the BAdam optimizer.
+    """
+
+    use_badam: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the BAdam optimizer."},
+    )
+    badam_mode: Literal["layer", "ratio"] = field(
+        default="layer",
+        metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
+    )
+    badam_start_block: Optional[int] = field(
+        default=None,
+        metadata={"help": "The starting block index for layer-wise BAdam."},
+    )
+    badam_switch_block_every: Optional[int] = field(
+        default=50,
+        metadata={"help": "How often to switch model's block update. Set to -1 to disable the block update."},
+    )
+    badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
+        default="ascending",
+        metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
+    )
+    badam_update_ratio: float = field(
+        default=0.0,
+        metadata={"help": "The ratio of the update for ratio-wise BAdam."},
+    )
+    badam_mask_mode: Literal["adjacent", "scatter"] = field(
+        default="adjacent",
+        metadata={
+            "help": """The mode of the mask for BAdam optimizer. \
+            `adjacent` means that the trainable parameters are adjacent to each other, \
+            `scatter` means that trainable parameters are randomly choosed from the weight."""
+        },
+    )
+    badam_verbose: int = field(
+        default=0,
+        metadata={
+            "help": """The verbosity level of BAdam optimizer. \
+            0 for no print, 1 for print the block prefix, 2 for print trainable parameters"""
+        },
+    )
+
+
 @dataclass
 class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
     r"""
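
Note (not part of the patch): the sketch below is a minimal, hypothetical example of how the renamed BAdam fields could be consumed after this change, assuming the hyperparameters are parsed with transformers.HfArgumentParser. The BAdamArgsSketch dataclass and the sample values are illustrative stand-ins, not the project's actual FinetuningArguments class.

# Minimal sketch, assuming transformers is installed. BAdamArgsSketch is a
# hypothetical stand-in that mirrors a few fields of the patched BAdamArgument.
from dataclasses import dataclass, field
from typing import Literal, Optional

from transformers import HfArgumentParser


@dataclass
class BAdamArgsSketch:
    use_badam: bool = field(default=False)
    badam_mode: Literal["layer", "ratio"] = field(default="layer")
    badam_start_block: Optional[int] = field(default=None)
    badam_switch_block_every: Optional[int] = field(default=50)
    badam_switch_mode: Literal["ascending", "descending", "random", "fixed"] = field(default="ascending")
    badam_update_ratio: float = field(default=0.0)


if __name__ == "__main__":
    parser = HfArgumentParser(BAdamArgsSketch)
    # parse_dict fills the dataclass from a config-style dict; the same fields
    # could also be read from the command line via parse_args_into_dataclasses().
    (badam_args,) = parser.parse_dict({"use_badam": True, "badam_switch_block_every": 100})
    print(badam_args)  # e.g. BAdamArgsSketch(use_badam=True, badam_switch_block_every=100, ...)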