From 4a4e4b4354b2ee476988c71d5dc625a8672e5109 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Sun, 10 Mar 2024 00:24:11 +0800
Subject: [PATCH] support layerwise galore

Former-commit-id: 8664262cde3919e10eaecbd66e8c5d356856362e
---
 README.md                                   |   5 +-
 README_zh.md                                |   5 +-
 examples/extras/galore/adamw.sh             |   2 +-
 examples/extras/galore/adamw_8bit_bf16.sh   |   2 +-
 examples/extras/galore/galore_adamw.sh      |   5 +-
 .../extras/galore/galore_adamw_8bit_bf16.sh |   5 +-
 src/llmtuner/data/loader.py                 |   2 +-
 src/llmtuner/hparams/finetuning_args.py     |   4 +
 src/llmtuner/train/dpo/workflow.py          |   2 +-
 src/llmtuner/train/ppo/workflow.py          |   2 +-
 src/llmtuner/train/pt/workflow.py           |   2 +-
 src/llmtuner/train/rm/workflow.py           |   2 +-
 src/llmtuner/train/sft/workflow.py          |   2 +-
 src/llmtuner/train/utils.py                 | 120 +++++++++++++-----
 14 files changed, 109 insertions(+), 51 deletions(-)

diff --git a/README.md b/README.md
index a287ac55..a0f7d9cc 100644
--- a/README.md
+++ b/README.md
@@ -276,16 +276,13 @@ huggingface-cli login
 | ------ | ---- | ----- | ----- | ----- | ------ | ------ |
 | Full | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB |
 | Full | 16 | 60GB | 120GB | 300GB | 600GB | 400GB |
-| GaLore | 16 | 28GB | 60GB | 150GB | 300GB | 200GB |
+| GaLore | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
 | Freeze | 16 | 20GB | 40GB | 80GB | 200GB | 160GB |
 | LoRA | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
 | QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB |
 | QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB |
 | QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB |
 
-> [!NOTE]
-> We report the GaLore results without per-layer weight updates.
-
 ## Getting Started
 
 ### Data Preparation (optional)
diff --git a/README_zh.md b/README_zh.md
index d92ab254..d034edc8 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -276,16 +276,13 @@ huggingface-cli login
 | ------- | ---- | ----- | ----- | ----- | ------ | ------ |
 | 全参数 | AMP | 120GB | 240GB | 600GB | 1200GB | 900GB |
 | 全参数 | 16 | 60GB | 120GB | 300GB | 600GB | 400GB |
-| GaLore | 16 | 28GB | 60GB | 150GB | 300GB | 200GB |
+| GaLore | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
 | 部分参数 | 16 | 20GB | 40GB | 80GB | 200GB | 160GB |
 | LoRA | 16 | 16GB | 32GB | 64GB | 160GB | 120GB |
 | QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | 60GB |
 | QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | 30GB |
 | QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | 18GB |
 
-> [!NOTE]
-> 上述 GaLore 的结果中不包含逐层权重更新。
-
 ## 如何使用
 
 ### 数据准备(可跳过)
diff --git a/examples/extras/galore/adamw.sh b/examples/extras/galore/adamw.sh
index 1fd2aaf0..d4f5afb4 100644
--- a/examples/extras/galore/adamw.sh
+++ b/examples/extras/galore/adamw.sh
@@ -15,7 +15,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --preprocessing_num_workers 16 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 1 \
-    --gradient_accumulation_steps 2 \
+    --gradient_accumulation_steps 1 \
     --lr_scheduler_type cosine \
     --logging_steps 10 \
     --warmup_steps 20 \
diff --git a/examples/extras/galore/adamw_8bit_bf16.sh b/examples/extras/galore/adamw_8bit_bf16.sh
index 01f4e8de..ecb4fa96 100644
--- a/examples/extras/galore/adamw_8bit_bf16.sh
+++ b/examples/extras/galore/adamw_8bit_bf16.sh
@@ -16,7 +16,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --preprocessing_num_workers 16 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 1 \
-    --gradient_accumulation_steps 2 \
+    --gradient_accumulation_steps 1 \
     --lr_scheduler_type cosine \
     --logging_steps 10 \
     --warmup_steps 20 \
diff --git a/examples/extras/galore/galore_adamw.sh b/examples/extras/galore/galore_adamw.sh
index 83be6a51..063bb6df 100644
--- a/examples/extras/galore/galore_adamw.sh
+++ b/examples/extras/galore/galore_adamw.sh
@@ -9,8 +9,9 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --template default \
     --finetuning_type full \
     --use_galore \
+    --galore_layerwise \
     --galore_target mlp,self_attn \
-    --galore_rank 32 \
+    --galore_rank 128 \
     --output_dir ../../../saves/LLaMA2-7B/galore/sft \
     --overwrite_cache \
     --overwrite_output_dir \
@@ -18,7 +19,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --preprocessing_num_workers 16 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 1 \
-    --gradient_accumulation_steps 2 \
+    --gradient_accumulation_steps 1 \
     --lr_scheduler_type cosine \
     --logging_steps 10 \
     --warmup_steps 20 \
diff --git a/examples/extras/galore/galore_adamw_8bit_bf16.sh b/examples/extras/galore/galore_adamw_8bit_bf16.sh
index 881ab2eb..cedc8bee 100644
--- a/examples/extras/galore/galore_adamw_8bit_bf16.sh
+++ b/examples/extras/galore/galore_adamw_8bit_bf16.sh
@@ -10,8 +10,9 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --finetuning_type full \
     --optim adamw_8bit \
     --use_galore \
+    --galore_layerwise \
     --galore_target mlp,self_attn \
-    --galore_rank 16 \
+    --galore_rank 128 \
     --output_dir ../../../saves/LLaMA2-7B/galore/sft \
     --overwrite_cache \
     --overwrite_output_dir \
@@ -19,7 +20,7 @@ CUDA_VISIBLE_DEVICES=0 python ../../../src/train_bash.py \
     --preprocessing_num_workers 16 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 1 \
-    --gradient_accumulation_steps 2 \
+    --gradient_accumulation_steps 1 \
     --lr_scheduler_type cosine \
     --logging_steps 10 \
     --warmup_steps 20 \
diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py
index f51369bc..937fdb36 100644
--- a/src/llmtuner/data/loader.py
+++ b/src/llmtuner/data/loader.py
@@ -29,7 +29,7 @@ def load_single_dataset(
     dataset_attr: "DatasetAttr",
     model_args: "ModelArguments",
     data_args: "DataArguments",
-):
+) -> Union["Dataset", "IterableDataset"]:
     logger.info("Loading dataset {}...".format(dataset_attr))
     data_path, data_name, data_dir, data_files = None, None, None, None
     if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py
index 1fb270ab..be1fd12c 100644
--- a/src/llmtuner/hparams/finetuning_args.py
+++ b/src/llmtuner/hparams/finetuning_args.py
@@ -182,6 +182,10 @@ class GaloreArguments:
         default="std",
         metadata={"help": "Type of GaLore projection."},
     )
+    galore_layerwise: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
+    )
 
 
 @dataclass
diff --git a/src/llmtuner/train/dpo/workflow.py b/src/llmtuner/train/dpo/workflow.py
index ba3a323f..39ea1a0e 100644
--- a/src/llmtuner/train/dpo/workflow.py
+++ b/src/llmtuner/train/dpo/workflow.py
@@ -44,7 +44,7 @@ def run_dpo(
         training_args.remove_unused_columns = False  # important for pairwise dataset
 
     # Initialize our Trainer
-    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
+    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
     trainer = CustomDPOTrainer(
         beta=finetuning_args.dpo_beta,
         loss_type=finetuning_args.dpo_loss,
diff --git a/src/llmtuner/train/ppo/workflow.py b/src/llmtuner/train/ppo/workflow.py
index 4c164f7a..de9f2a2f 100644
--- a/src/llmtuner/train/ppo/workflow.py
+++ b/src/llmtuner/train/ppo/workflow.py
@@ -64,7 +64,7 @@ def run_ppo(
     )
 
     # Create optimizer and scheduler
-    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
+    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
     if optimizer is None:
         optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
 
diff --git a/src/llmtuner/train/pt/workflow.py b/src/llmtuner/train/pt/workflow.py
index debf600f..5a08854a 100644
--- a/src/llmtuner/train/pt/workflow.py
+++ b/src/llmtuner/train/pt/workflow.py
@@ -30,7 +30,7 @@ def run_pt(
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
     # Initialize our Trainer
-    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
+    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
     trainer = Trainer(
         model=model,
         args=training_args,
diff --git a/src/llmtuner/train/rm/workflow.py b/src/llmtuner/train/rm/workflow.py
index 5daf374c..9dfef302 100644
--- a/src/llmtuner/train/rm/workflow.py
+++ b/src/llmtuner/train/rm/workflow.py
@@ -35,7 +35,7 @@ def run_rm(
     training_args.remove_unused_columns = False  # important for pairwise dataset
 
     # Initialize our Trainer
-    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
+    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
     trainer = PairwiseTrainer(
         model=model,
         args=training_args,
diff --git a/src/llmtuner/train/sft/workflow.py b/src/llmtuner/train/sft/workflow.py
index 10df3b3f..099edc14 100644
--- a/src/llmtuner/train/sft/workflow.py
+++ b/src/llmtuner/train/sft/workflow.py
@@ -50,7 +50,7 @@ def run_sft(
     training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
 
     # Initialize our Trainer
-    optimizer = create_custom_optimzer(model, training_args, finetuning_args)
+    optimizer = create_custom_optimzer(model, dataset, training_args, finetuning_args)
     trainer = CustomSeq2SeqTrainer(
         model=model,
         args=training_args,
diff --git a/src/llmtuner/train/utils.py b/src/llmtuner/train/utils.py
index 9d4526a0..75006ee0 100644
--- a/src/llmtuner/train/utils.py
+++ b/src/llmtuner/train/utils.py
@@ -1,6 +1,8 @@
-from typing import TYPE_CHECKING, Optional, Union
+import math
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
 
 import torch
+from transformers.optimization import get_scheduler
 from transformers.utils.versions import require_version
 
 from ..extras.logging import get_logger
@@ -14,6 +16,7 @@ if is_galore_available():
 
 
 if TYPE_CHECKING:
+    from datasets import Dataset, IterableDataset
     from transformers import Seq2SeqTrainingArguments, Trainer
     from transformers.modeling_utils import PreTrainedModel
     from trl import AutoModelForCausalLMWithValueHead
@@ -24,6 +27,18 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
+class DummyOptimizer(torch.optim.Optimizer):
+    def __init__(self, *args, **kwargs):
+        dummy_tensor = torch.randn(1, 1)
+        super().__init__([dummy_tensor], {"lr": 1e-3})
+
+    def zero_grad(self, set_to_none: bool = True) -> None:
+        pass
+
+    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
+        pass
+
+
 def create_modelcard_and_push(
     trainer: "Trainer",
     model_args: "ModelArguments",
@@ -127,7 +142,10 @@ def create_reward_model(
 
 
 def create_custom_optimzer(
-    model: "PreTrainedModel", training_args: "Seq2SeqTrainingArguments", finetuning_args: "FinetuningArguments"
+    model: "PreTrainedModel",
+    dataset: Union["Dataset", "IterableDataset"],
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
Optional["torch.optim.Optimizer"]: if not finetuning_args.use_galore: return None @@ -144,40 +162,80 @@ def create_custom_optimzer( trainable_params = filter(lambda p: p.requires_grad, model.parameters()) non_galore_params = [p for p in trainable_params if id(p) not in id_galore_params] - # define param groups as galore_params and non_galore_params - param_groups = [ - {"params": non_galore_params}, - { - "params": galore_params, - "rank": finetuning_args.galore_rank, - "update_proj_gap": finetuning_args.galore_update_interval, - "scale": finetuning_args.galore_scale, - "proj_type": finetuning_args.galore_proj_type, - }, - ] if training_args.optim == "adamw_torch": - optimizer = GaLoreAdamW( - param_groups, - lr=training_args.learning_rate, - eps=training_args.adam_epsilon, - betas=(training_args.adam_beta1, training_args.adam_beta2), - ) + optim_class = GaLoreAdamW + optim_kwargs = { + "lr": training_args.learning_rate, + "eps": training_args.adam_epsilon, + "betas": (training_args.adam_beta1, training_args.adam_beta2), + } + elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]: - optimizer = GaLoreAdamW8bit( - param_groups, - lr=training_args.learning_rate, - eps=training_args.adam_epsilon, - betas=(training_args.adam_beta1, training_args.adam_beta2), - optim_bits=8, - is_paged="paged" in training_args.optim, - ) + optim_class = GaLoreAdamW8bit + optim_kwargs = { + "lr": training_args.learning_rate, + "eps": training_args.adam_epsilon, + "betas": (training_args.adam_beta1, training_args.adam_beta2), + "optim_bits": 8, + "is_paged": "paged" in training_args.optim, + } + elif training_args.optim == "adafactor": - optimizer = GaLoreAdafactor( - param_groups, - lr=training_args.learning_rate, - ) + optim_class = GaLoreAdafactor + optim_kwargs = { + "lr": training_args.learning_rate, + } + else: raise NotImplementedError("Unknow optim: {}".format(training_args.optim)) + galore_kwargs = { + "rank": finetuning_args.galore_rank, + "update_proj_gap": finetuning_args.galore_update_interval, + "scale": finetuning_args.galore_scale, + "proj_type": finetuning_args.galore_proj_type, + } + + if finetuning_args.galore_layerwise: + if training_args.gradient_accumulation_steps != 1: + raise ValueError("Per-layer GaLore does not support gradient accumulation.") + + if training_args.max_steps > 0: + num_training_steps = training_args.max_steps + else: + total_train_batch_size = training_args.per_device_train_batch_size * training_args.world_size + num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size) + + optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {} + for param in non_galore_params: + param_groups = [dict(params=[param])] + optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) + for param in galore_params: + param_groups = [dict(params=[param], **galore_kwargs)] + optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) + + scheduler_dict: Dict["torch.Tensor", "torch.optim.lr_scheduler.LRScheduler"] = {} + for param in non_galore_params + galore_params: + scheduler_dict[param] = get_scheduler( + training_args.lr_scheduler_type, + optimizer=optimizer_dict[param], + num_warmup_steps=training_args.get_warmup_steps(num_training_steps) * 2, + num_training_steps=num_training_steps * 2, + ) + + def optimizer_hook(param: "torch.Tensor"): + if param.grad is not None: + optimizer_dict[param].step() + optimizer_dict[param].zero_grad() + scheduler_dict[param].step() + + for param in non_galore_params + 
+        for param in non_galore_params + galore_params:
+            param.register_post_accumulate_grad_hook(optimizer_hook)
+
+        optimizer = DummyOptimizer()
+    else:
+        param_groups = [dict(params=non_galore_params), dict(params=galore_params, **galore_kwargs)]
+        optimizer = optim_class(param_groups, **optim_kwargs)
+
     logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
     return optimizer
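
Note on the layer-wise path above: each trainable parameter gets its own optimizer and scheduler, and the update is fired from a gradient hook during backward(), so the full set of gradients never has to stay alive at once. Below is a minimal, self-contained sketch of that pattern, not part of the patch itself: it assumes PyTorch >= 2.1 (for Tensor.register_post_accumulate_grad_hook) and uses plain AdamW on a toy model; the name TinyNet and the loop are illustrative only.

import torch


class TinyNet(torch.nn.Module):  # illustrative toy model, not part of this patch
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 16)
        self.fc2 = torch.nn.Linear(16, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.relu(self.fc1(x)))


model = TinyNet()

# One optimizer per parameter: only that parameter's optimizer state is touched
# when its gradient becomes ready, which is what keeps peak memory low.
optimizer_dict = {param: torch.optim.AdamW([param], lr=1e-3) for param in model.parameters()}


def optimizer_hook(param: torch.Tensor) -> None:
    # Runs right after the gradient of `param` is accumulated during backward():
    # step this parameter and drop its gradient immediately.
    if param.grad is not None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()


for param in model.parameters():
    param.register_post_accumulate_grad_hook(optimizer_hook)

# The training loop only calls backward(); the hooks do the stepping. Because
# updates happen inside backward(), gradients cannot be accumulated across
# micro-batches, which is why the patch requires gradient_accumulation_steps == 1.
for _ in range(3):
    inputs, targets = torch.randn(4, 8), torch.randn(4, 1)
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()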