diff --git a/examples/extras/apollo/llama3_full_sft.yaml b/examples/extras/apollo/llama3_full_sft.yaml
new file mode 100644
index 00000000..c90a0147
--- /dev/null
+++ b/examples/extras/apollo/llama3_full_sft.yaml
@@ -0,0 +1,45 @@
+### model
+model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
+trust_remote_code: true
+
+### method
+stage: sft
+do_train: true
+finetuning_type: full
+use_apollo: true
+apollo_layerwise: true
+apollo_target: mlp,self_attn
+apollo_rank: 128
+apollo_scale: 32.0
+apollo_scale_type: channel
+
+### dataset
+dataset: identity,alpaca_en_demo
+template: llama3
+cutoff_len: 2048
+max_samples: 1000
+overwrite_cache: true
+preprocessing_num_workers: 16
+
+### output
+output_dir: saves/llama3-8b/apollo_full-scale32/sft
+logging_steps: 10
+save_steps: 500
+plot_loss: true
+overwrite_output_dir: true
+
+### train
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 1
+learning_rate: 1.0e-5
+num_train_epochs: 3.0
+lr_scheduler_type: cosine
+warmup_ratio: 0.1
+pure_bf16: true
+ddp_timeout: 180000000
+
+### eval
+val_size: 0.1
+per_device_eval_batch_size: 1
+eval_strategy: steps
+eval_steps: 500
diff --git a/setup.py b/setup.py
index dd0cf74a..6f0d09e1 100644
--- a/setup.py
+++ b/setup.py
@@ -56,6 +56,7 @@ extra_require = {
     "aqlm": ["aqlm[gpu]>=1.1.0"],
     "vllm": ["vllm>=0.4.3,<=0.6.5"],
     "galore": ["galore-torch"],
+    "apollo": ["apollo-torch"],
     "badam": ["badam>=1.2.1"],
     "adam-mini": ["adam-mini"],
     "qwen": ["transformers_stream_generator"],
diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py
index 6b2bc3f3..3dda9560 100644
--- a/src/llamafactory/extras/packages.py
+++ b/src/llamafactory/extras/packages.py
@@ -49,6 +49,8 @@ def is_fastapi_available():
 def is_galore_available():
     return _is_package_available("galore_torch")
 
+def is_apollo_available():
+    return _is_package_available("apollo_torch")
 
 def is_gradio_available():
     return _is_package_available("gradio")
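Reviewer note: a minimal sketch of trying this change end to end, assuming an editable install from the repository root; `llamafactory-cli train` is the package's existing entry point, and the recipe path is the YAML file added above.

```bash
# Install the optional dependency declared by the new "apollo" extra,
# then launch the example recipe from this PR.
pip install -e ".[apollo]"
llamafactory-cli train examples/extras/apollo/llama3_full_sft.yaml
```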
diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 5770e395..9996ef51 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -251,6 +251,59 @@ class GaloreArguments:
     )
 
 
+@dataclass
+class ApolloArguments:
+    r"""
+    Arguments pertaining to the APOLLO algorithm.
+    """
+
+    use_apollo: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the APOLLO optimizer."},
+    )
+    apollo_target: str = field(
+        default="all",
+        metadata={
+            "help": (
+                "Name(s) of modules to apply APOLLO. Use commas to separate multiple modules. "
+                "Use `all` to specify all the linear modules."
+            )
+        },
+    )
+    apollo_rank: int = field(
+        default=16,
+        metadata={"help": "The rank of APOLLO gradients."},
+    )
+    apollo_update_interval: int = field(
+        default=200,
+        metadata={"help": "Number of steps to update the APOLLO projection."},
+    )
+    apollo_scale: float = field(
+        default=1.0,
+        metadata={"help": "APOLLO scaling coefficient."},
+    )
+    apollo_proj: Literal["svd", "random"] = field(
+        default="random",
+        metadata={"help": "Type of APOLLO low-rank projection algorithm (svd or random)."},
+    )
+    apollo_proj_type: Literal["std", "right", "left"] = field(
+        default="std",
+        metadata={"help": "Type of APOLLO projection."},
+    )
+    apollo_scale_type: Literal["channel", "tensor"] = field(
+        default="channel",
+        metadata={"help": "Type of APOLLO scaling (channel or tensor)."},
+    )
+    apollo_layerwise: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
+    )
+    apollo_scale_front: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use the norm-growth limiter in front of gradient scaling."},
+    )
+
+
 @dataclass
 class BAdamArgument:
     r"""
@@ -334,7 +387,7 @@ class SwanLabArguments:
 
 @dataclass
 class FinetuningArguments(
-    FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument, SwanLabArguments
+    FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, ApolloArguments, BAdamArgument, SwanLabArguments
 ):
     r"""
     Arguments pertaining to which techniques we are going to fine-tune with.
     """
@@ -401,6 +454,7 @@ class FinetuningArguments(
         self.lora_target: List[str] = split_arg(self.lora_target)
         self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
         self.galore_target: List[str] = split_arg(self.galore_target)
+        self.apollo_target: List[str] = split_arg(self.apollo_target)
         self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
         self.freeze_multi_modal_projector = self.freeze_multi_modal_projector and not self.train_mm_proj_only
         self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
@@ -421,12 +475,18 @@ class FinetuningArguments(
         if self.use_llama_pro and self.finetuning_type == "full":
             raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")
 
-        if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
-            raise ValueError("Cannot use LoRA with GaLore or BAdam together.")
+        if self.finetuning_type == "lora" and (self.use_galore or self.use_badam or self.use_apollo):
+            raise ValueError("Cannot use LoRA with GaLore, APOLLO or BAdam together.")
 
         if self.use_galore and self.use_badam:
             raise ValueError("Cannot use GaLore with BAdam together.")
 
+        if self.use_galore and self.use_apollo:
+            raise ValueError("Cannot use GaLore with APOLLO together.")
+
+        if self.use_badam and self.use_apollo:
+            raise ValueError("Cannot use BAdam with APOLLO together.")
+
         if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
             raise ValueError("Cannot use PiSSA for current training stage.")
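Reviewer note: for reference, the full set of new keys exposed by `ApolloArguments`, shown as a YAML sketch with the dataclass defaults (every key except `use_apollo` can be omitted); the inline constraints are the ones enforced by this PR's validation code.

```yaml
use_apollo: true
apollo_target: all            # comma-separated module names, or `all` for every linear module
apollo_rank: 16               # rank of the low-rank gradient projection
apollo_update_interval: 200   # steps between projection updates
apollo_scale: 1.0             # scaling coefficient
apollo_proj: random           # projection algorithm: random or svd
apollo_proj_type: std         # std, right or left
apollo_scale_type: channel    # channel or tensor
apollo_layerwise: false       # layer-wise updates; incompatible with gradient accumulation and DDP
apollo_scale_front: false     # apply the norm-growth limiter before gradient scaling
```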
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 36de8b50..a7c2f3ec 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -139,6 +139,9 @@ def _check_extra_dependencies(
     if finetuning_args.use_galore:
         check_version("galore_torch", mandatory=True)
 
+    if finetuning_args.use_apollo:
+        check_version("apollo_torch", mandatory=True)
+
     if finetuning_args.use_badam:
         check_version("badam>=1.2.1", mandatory=True)
@@ -262,6 +265,13 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
     ):
         raise ValueError("Distributed training does not support layer-wise GaLore.")
 
+    if (
+        finetuning_args.use_apollo
+        and finetuning_args.apollo_layerwise
+        and training_args.parallel_mode == ParallelMode.DISTRIBUTED
+    ):
+        raise ValueError("Distributed training does not support layer-wise APOLLO.")
+
     if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
         if finetuning_args.badam_mode == "ratio":
             raise ValueError("Ratio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
@@ -271,6 +281,9 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
     if finetuning_args.use_galore and training_args.deepspeed is not None:
         raise ValueError("GaLore is not compatible with DeepSpeed yet.")
 
+    if finetuning_args.use_apollo and training_args.deepspeed is not None:
+        raise ValueError("APOLLO is not compatible with DeepSpeed yet.")
+
     if model_args.infer_backend == "vllm":
         raise ValueError("vLLM backend is only available for API, CLI and Web.")
@@ -306,6 +319,11 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _TRAIN_CLS:
             "Using GaLore with mixed precision training may significantly increase GPU memory usage."
         )
 
+    if training_args.do_train and finetuning_args.use_apollo and not finetuning_args.pure_bf16:
+        logger.warning_rank0(
+            "Using APOLLO with mixed precision training may significantly increase GPU memory usage."
+        )
+
     if (not training_args.do_train) and model_args.quantization_bit is not None:
         logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")
diff --git a/src/llamafactory/model/model_utils/misc.py b/src/llamafactory/model/model_utils/misc.py
index 5c5178d4..6d626f33 100644
--- a/src/llamafactory/model/model_utils/misc.py
+++ b/src/llamafactory/model/model_utils/misc.py
@@ -27,7 +27,7 @@ logger = logging.get_logger(__name__)
 def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
     r"""
-    Finds all available modules to apply lora or galore.
+    Finds all available modules to apply LoRA, GaLore or APOLLO.
     """
     model_type = getattr(model.config, "model_type", None)
     forbidden_modules = {"lm_head"}
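Reviewer note: since `apollo_target: all` is resolved through `find_all_linear_modules`, a self-contained sketch of that resolution logic may help; the toy module tree below is a stand-in for a real `PreTrainedModel`, and the names `q_proj`/`up_proj`/`lm_head` are illustrative only.

```python
# Mimics find_all_linear_modules: collect the suffix of every Linear
# module while skipping the output head, as the docstring above describes.
import torch

toy = torch.nn.ModuleDict({
    "layer0": torch.nn.ModuleDict({
        "q_proj": torch.nn.Linear(64, 64),
        "up_proj": torch.nn.Linear(64, 128),
    }),
    "lm_head": torch.nn.Linear(64, 100),
})

forbidden_modules = {"lm_head"}
module_names = set()
for name, module in toy.named_modules():
    if any(forbidden in name for forbidden in forbidden_modules):
        continue
    if isinstance(module, torch.nn.Linear):
        module_names.add(name.split(".")[-1])

print(module_names)  # {'q_proj', 'up_proj'}
```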
""" def __init__( @@ -275,6 +277,90 @@ def _create_galore_optimizer( logger.info_rank0("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.") return optimizer +def _create_apollo_optimizer( + model: "PreTrainedModel", + training_args: "TrainingArguments", + finetuning_args: "FinetuningArguments", +) -> "torch.optim.Optimizer": + if len(finetuning_args.apollo_target) == 1 and finetuning_args.apollo_target[0] == "all": + apollo_targets = find_all_linear_modules(model, finetuning_args.freeze_vision_tower) + else: + apollo_targets = finetuning_args.apollo_target + + apollo_params: List["torch.nn.Parameter"] = [] + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear) and any(target in name for target in apollo_targets): + for param in module.parameters(): + if param.requires_grad and len(param.shape) > 1: + apollo_params.append(param) + + apollo_kwargs = { + "rank": finetuning_args.apollo_rank, + "proj": finetuning_args.apollo_proj, + "proj_type": finetuning_args.apollo_proj_type, + "update_proj_gap": finetuning_args.apollo_update_interval, + "scale": finetuning_args.apollo_scale, + "scale_type": finetuning_args.apollo_scale_type, + "scale_front": finetuning_args.apollo_scale_front, + } + + print(apollo_kwargs) + + id_apollo_params = {id(param) for param in apollo_params} + decay_params, nodecay_params = [], [] # they are non-galore parameters + trainable_params: List["torch.nn.Parameter"] = [] # galore_params + decay_params + nodecay_params + decay_param_names = _get_decay_parameter_names(model) + for name, param in model.named_parameters(): + if param.requires_grad: + trainable_params.append(param) + if id(param) not in id_apollo_params: + if name in decay_param_names: + decay_params.append(param) + else: + nodecay_params.append(param) + + _, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args) + + if training_args.optim == "adamw_torch": + optim_class = APOLLOAdamW + else: + raise NotImplementedError(f"Unknow optim: {training_args.optim}") + + if finetuning_args.apollo_layerwise: + if training_args.gradient_accumulation_steps != 1: + raise ValueError("Per-layer APOLLO does not support gradient accumulation.") + + optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {} + for param in nodecay_params: + param_groups = [dict(params=[param], weight_decay=0.0)] + optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) + for param in decay_params: + param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)] + optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) + for param in apollo_params: # galore params have weight decay + param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **apollo_kwargs)] + optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) + + def optimizer_hook(param: "torch.nn.Parameter"): + if param.grad is not None: + optimizer_dict[param].step() + optimizer_dict[param].zero_grad() + + for param in trainable_params: + param.register_post_accumulate_grad_hook(optimizer_hook) + + optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict) + else: + param_groups = [ + dict(params=nodecay_params, weight_decay=0.0), + dict(params=decay_params, weight_decay=training_args.weight_decay), + dict(params=apollo_params, weight_decay=training_args.weight_decay, **apollo_kwargs), + ] + optimizer = optim_class(param_groups, **optim_kwargs) + + logger.info_rank0("Using APOLLO optimizer.") + return 
diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py
index 05fa810b..62bd66c5 100644
--- a/src/llamafactory/webui/components/train.py
+++ b/src/llamafactory/webui/components/train.py
@@ -250,6 +250,25 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             )
         )
 
+    with gr.Accordion(open=False) as apollo_tab:
+        with gr.Row():
+            use_apollo = gr.Checkbox()
+            apollo_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
+            apollo_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1)
+            apollo_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01)
+            apollo_target = gr.Textbox(value="all")
+
+    input_elems.update({use_apollo, apollo_rank, apollo_update_interval, apollo_scale, apollo_target})
+    elem_dict.update(
+        dict(
+            apollo_tab=apollo_tab,
+            use_apollo=use_apollo,
+            apollo_rank=apollo_rank,
+            apollo_update_interval=apollo_update_interval,
+            apollo_scale=apollo_scale,
+            apollo_target=apollo_target,
+        )
+    )
+
     with gr.Accordion(open=False) as badam_tab:
         with gr.Row():
             use_badam = gr.Checkbox()
diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py
index 9b78e6e9..1dd1810b 100644
--- a/src/llamafactory/webui/locales.py
+++ b/src/llamafactory/webui/locales.py
@@ -1249,6 +1249,110 @@ LOCALES = {
             "info": "GaLore를 적용할 모듈의 이름. 모듈 간에는 쉼표(,)로 구분하십시오.",
         },
    },
+    "apollo_tab": {
+        "en": {
+            "label": "APOLLO configurations",
+        },
+        "ru": {
+            "label": "Конфигурации APOLLO",
+        },
+        "zh": {
+            "label": "APOLLO 参数设置",
+        },
+        "ko": {
+            "label": "APOLLO 구성",
+        },
+    },
+    "use_apollo": {
+        "en": {
+            "label": "Use APOLLO",
+            "info": "Enable gradient low-rank projection.",
+        },
+        "ru": {
+            "label": "Использовать APOLLO",
+            "info": "Включить проекцию градиента на низкоранговое пространство.",
+        },
+        "zh": {
+            "label": "使用 APOLLO",
+            "info": "使用梯度低秩投影。",
+        },
+        "ko": {
+            "label": "APOLLO 사용",
+            "info": "그레디언트 로우 랭크 프로젝션을 활성화합니다.",
+        },
+    },
+    "apollo_rank": {
+        "en": {
+            "label": "APOLLO rank",
+            "info": "The rank of APOLLO gradients.",
+        },
+        "ru": {
+            "label": "Ранг APOLLO",
+            "info": "Ранг градиентов APOLLO.",
+        },
+        "zh": {
+            "label": "APOLLO 秩",
+            "info": "APOLLO 梯度的秩大小。",
+        },
+        "ko": {
+            "label": "APOLLO 랭크",
+            "info": "APOLLO 그레디언트의 랭크.",
+        },
+    },
+    "apollo_update_interval": {
+        "en": {
+            "label": "Update interval",
+            "info": "Number of steps to update the APOLLO projection.",
+        },
+        "ru": {
+            "label": "Интервал обновления",
+            "info": "Количество шагов для обновления проекции APOLLO.",
+        },
+        "zh": {
+            "label": "更新间隔",
+            "info": "相邻两次投影更新的步数。",
+        },
+        "ko": {
+            "label": "업데이트 간격",
+            "info": "APOLLO 프로젝션을 업데이트할 간격의 스텝 수.",
+        },
+    },
+    "apollo_scale": {
+        "en": {
+            "label": "APOLLO scale",
+            "info": "APOLLO scaling coefficient.",
+        },
+        "ru": {
+            "label": "Масштаб APOLLO",
+            "info": "Коэффициент масштабирования APOLLO.",
+        },
+        "zh": {
+            "label": "APOLLO 缩放系数",
+            "info": "APOLLO 缩放系数大小。",
+        },
+        "ko": {
+            "label": "APOLLO 스케일",
+            "info": "APOLLO 스케일링 계수.",
+        },
+    },
+    "apollo_target": {
+        "en": {
+            "label": "APOLLO modules",
+            "info": "Name(s) of modules to apply APOLLO. Use commas to separate multiple modules.",
+        },
+        "ru": {
+            "label": "Модули APOLLO",
+            "info": "Имена модулей для применения APOLLO. Используйте запятые для разделения нескольких модулей.",
+        },
+        "zh": {
+            "label": "APOLLO 作用模块",
+            "info": "应用 APOLLO 的模块名称。使用英文逗号分隔多个名称。",
+        },
+        "ko": {
+            "label": "APOLLO 모듈",
+            "info": "APOLLO를 적용할 모듈의 이름. 모듈 간에는 쉼표(,)로 구분하십시오.",
+        },
+    },
     "badam_tab": {
         "en": {
             "label": "BAdam configurations",
diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py
index fe1e643c..bc010992 100644
--- a/src/llamafactory/webui/runner.py
+++ b/src/llamafactory/webui/runner.py
@@ -147,6 +147,7 @@ class Runner:
             shift_attn=get("train.shift_attn"),
             report_to="all" if get("train.report_to") else "none",
             use_galore=get("train.use_galore"),
+            use_apollo=get("train.use_apollo"),
             use_badam=get("train.use_badam"),
             use_swanlab=get("train.use_swanlab"),
             output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")),
@@ -223,6 +224,13 @@ class Runner:
             args["galore_update_interval"] = get("train.galore_update_interval")
             args["galore_scale"] = get("train.galore_scale")
             args["galore_target"] = get("train.galore_target")
+
+        # apollo config
+        if args["use_apollo"]:
+            args["apollo_rank"] = get("train.apollo_rank")
+            args["apollo_update_interval"] = get("train.apollo_update_interval")
+            args["apollo_scale"] = get("train.apollo_scale")
+            args["apollo_target"] = get("train.apollo_target")
 
         # badam config
         if args["use_badam"]: