mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-01 03:02:51 +08:00
[optim] add support to APOLLO (#6617)
Former-commit-id: d9189f9f0b23ff6929044919208e0e813ca95b1c
This commit is contained in:
parent
57043fb4e6
commit
763f9b9df0
45
examples/extras/apollo/llama3_full_sft.yaml
Normal file
45
examples/extras/apollo/llama3_full_sft.yaml
Normal file
@ -0,0 +1,45 @@
|
||||
# Example: full-parameter SFT of Llama-3-8B-Instruct with the APOLLO optimizer.

### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: full  # APOLLO is rejected together with LoRA, GaLore or BAdam
use_apollo: true  # requires the `apollo-torch` package (the `apollo` extra in setup.py); not usable with DeepSpeed
# Layer-wise updates are rejected for distributed training and require
# gradient_accumulation_steps to be 1 (see the train section).
apollo_layerwise: true
# Comma-separated name fragments; a linear module is selected when any fragment
# occurs in its module name. `all` selects every linear module.
apollo_target: mlp,self_attn
apollo_rank: 128  # rank of the APOLLO gradients
apollo_scale: 32.0  # APOLLO scaling coefficient
apollo_scale_type: channel  # `channel` or `tensor`

### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/llama3-8b/apollo_full-scale32/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 1  # must stay 1 while apollo_layerwise is enabled
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
pure_bf16: true  # without it, a warning about increased GPU memory usage is logged for APOLLO
ddp_timeout: 180000000

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
|
1
setup.py
1
setup.py
@ -56,6 +56,7 @@ extra_require = {
|
||||
"aqlm": ["aqlm[gpu]>=1.1.0"],
|
||||
"vllm": ["vllm>=0.4.3,<=0.6.5"],
|
||||
"galore": ["galore-torch"],
|
||||
"apollo": ["apollo-torch"],
|
||||
"badam": ["badam>=1.2.1"],
|
||||
"adam-mini": ["adam-mini"],
|
||||
"qwen": ["transformers_stream_generator"],
|
||||
|
@ -49,6 +49,8 @@ def is_fastapi_available():
|
||||
def is_galore_available():
    r"""
    Reports whether the `galore_torch` package is available, as determined by `_is_package_available`.
    """
    return _is_package_available("galore_torch")
|
||||
|
||||
def is_apollo_available():
    r"""
    Reports whether the `apollo_torch` package is available, as determined by `_is_package_available`.
    """
    return _is_package_available("apollo_torch")
|
||||
|
||||
def is_gradio_available():
    r"""
    Reports whether the `gradio` package is available, as determined by `_is_package_available`.
    """
    return _is_package_available("gradio")
|
||||
|
@ -251,6 +251,59 @@ class GaloreArguments:
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class ApolloArguments:
    r"""
    Options that configure the APOLLO optimizer.
    """

    # Master switch.
    use_apollo: bool = field(
        default=False, metadata=dict(help="Whether or not to use the APOLLO optimizer.")
    )
    # Comma-separated module names, or `all`.
    apollo_target: str = field(
        default="all",
        metadata=dict(
            help=(
                "Name(s) of modules to apply APOLLO. Use commas to separate multiple modules. "
                "Use `all` to specify all the linear modules."
            )
        ),
    )
    # Projection settings.
    apollo_rank: int = field(default=16, metadata=dict(help="The rank of APOLLO gradients."))
    apollo_update_interval: int = field(
        default=200, metadata=dict(help="Number of steps to update the APOLLO projection.")
    )
    apollo_scale: float = field(default=1.0, metadata=dict(help="APOLLO scaling coefficient."))
    apollo_proj: Literal["svd", "random"] = field(
        default="random",
        metadata=dict(help="Type of APOLLO low-rank projection algorithm (svd or random)."),
    )
    apollo_proj_type: Literal["std", "right", "left"] = field(
        default="std", metadata=dict(help="Type of APOLLO projection.")
    )
    apollo_scale_type: Literal["channel", "tensor"] = field(
        default="channel", metadata=dict(help="Type of APOLLO scaling (channel or tensor).")
    )
    # Memory / stability toggles.
    apollo_layerwise: bool = field(
        default=False,
        metadata=dict(help="Whether or not to enable layer-wise update to further save memory."),
    )
    apollo_scale_front: bool = field(
        default=False,
        metadata=dict(help="Whether or not to use the norm-growth limiter in front of gradient scaling."),
    )
|
||||
|
||||
|
||||
@dataclass
|
||||
class BAdamArgument:
|
||||
r"""
|
||||
@ -334,7 +387,7 @@ class SwanLabArguments:
|
||||
|
||||
@dataclass
|
||||
class FinetuningArguments(
|
||||
FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument, SwanLabArguments
|
||||
FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, ApolloArguments, BAdamArgument, SwanLabArguments
|
||||
):
|
||||
r"""
|
||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||
@ -401,6 +454,7 @@ class FinetuningArguments(
|
||||
self.lora_target: List[str] = split_arg(self.lora_target)
|
||||
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
|
||||
self.galore_target: List[str] = split_arg(self.galore_target)
|
||||
self.apollo_target: List[str] = split_arg(self.apollo_target)
|
||||
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
|
||||
self.freeze_multi_modal_projector = self.freeze_multi_modal_projector and not self.train_mm_proj_only
|
||||
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
|
||||
@ -421,12 +475,18 @@ class FinetuningArguments(
|
||||
if self.use_llama_pro and self.finetuning_type == "full":
|
||||
raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")
|
||||
|
||||
if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
|
||||
# LoRA cannot be combined with any of the optimizer-level methods; the message names all three.
if self.finetuning_type == "lora" and (self.use_galore or self.use_badam or self.use_apollo):
    raise ValueError("Cannot use LoRA with GaLore, APOLLO or BAdam together.")
|
||||
|
||||
if self.use_galore and self.use_badam:
|
||||
raise ValueError("Cannot use GaLore with BAdam together.")
|
||||
|
||||
if self.use_galore and self.use_apollo:
|
||||
raise ValueError("Cannot use GaLore with APOLLO together.")
|
||||
|
||||
if self.use_badam and self.use_apollo:
|
||||
raise ValueError("Cannot use BAdam with APOLLO together.")
|
||||
|
||||
if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
|
||||
raise ValueError("Cannot use PiSSA for current training stage.")
|
||||
|
||||
|
@ -139,6 +139,9 @@ def _check_extra_dependencies(
|
||||
if finetuning_args.use_galore:
|
||||
check_version("galore_torch", mandatory=True)
|
||||
|
||||
if finetuning_args.use_apollo:
|
||||
check_version("apollo_torch", mandatory=True)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
check_version("badam>=1.2.1", mandatory=True)
|
||||
|
||||
@ -262,6 +265,13 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
|
||||
):
|
||||
raise ValueError("Distributed training does not support layer-wise GaLore.")
|
||||
|
||||
if (
|
||||
finetuning_args.use_apollo
|
||||
and finetuning_args.apollo_layerwise
|
||||
and training_args.parallel_mode == ParallelMode.DISTRIBUTED
|
||||
):
|
||||
raise ValueError("Distributed training does not support layer-wise APOLLO.")
|
||||
|
||||
if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
|
||||
if finetuning_args.badam_mode == "ratio":
|
||||
raise ValueError("Radio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
|
||||
@ -271,6 +281,9 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
|
||||
if finetuning_args.use_galore and training_args.deepspeed is not None:
|
||||
raise ValueError("GaLore is incompatible with DeepSpeed yet.")
|
||||
|
||||
if finetuning_args.use_apollo and training_args.deepspeed is not None:
|
||||
raise ValueError("APOLLO is incompatible with DeepSpeed yet.")
|
||||
|
||||
if model_args.infer_backend == "vllm":
|
||||
raise ValueError("vLLM backend is only available for API, CLI and Web.")
|
||||
|
||||
@ -306,6 +319,11 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
|
||||
"Using GaLore with mixed precision training may significantly increases GPU memory usage."
|
||||
)
|
||||
|
||||
if training_args.do_train and finetuning_args.use_apollo and not finetuning_args.pure_bf16:
|
||||
logger.warning_rank0(
|
||||
"Using APOLLO with mixed precision training may significantly increases GPU memory usage."
|
||||
)
|
||||
|
||||
if (not training_args.do_train) and model_args.quantization_bit is not None:
|
||||
logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")
|
||||
|
||||
|
@ -27,7 +27,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
|
||||
r"""
|
||||
Finds all available modules to apply lora or galore.
|
||||
Finds all available modules to apply lora or galore or apollo.
|
||||
"""
|
||||
model_type = getattr(model.config, "model_type", None)
|
||||
forbidden_modules = {"lm_head"}
|
||||
|
@ -32,7 +32,7 @@ from typing_extensions import override
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.constants import IGNORE_INDEX
|
||||
from ..extras.packages import is_galore_available, is_ray_available
|
||||
from ..extras.packages import is_galore_available, is_ray_available, is_apollo_available
|
||||
from ..hparams import FinetuningArguments, ModelArguments
|
||||
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
|
||||
|
||||
@ -40,6 +40,8 @@ from ..model import find_all_linear_modules, load_model, load_tokenizer, load_va
|
||||
if is_galore_available():
|
||||
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit # type: ignore
|
||||
|
||||
if is_apollo_available():
|
||||
from apollo_torch import APOLLOAdamW # type: ignore
|
||||
|
||||
if is_ray_available():
|
||||
from ray.train import RunConfig, ScalingConfig
|
||||
@ -58,7 +60,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
class DummyOptimizer(torch.optim.Optimizer):
|
||||
r"""
|
||||
A dummy optimizer used for the GaLore algorithm.
|
||||
A dummy optimizer used for the GaLore or APOLLO algorithm.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -275,6 +277,90 @@ def _create_galore_optimizer(
|
||||
logger.info_rank0("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
|
||||
return optimizer
|
||||
|
||||
def _create_apollo_optimizer(
    model: "PreTrainedModel",
    training_args: "TrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
    r"""
    Builds the optimizer used when `finetuning_args.use_apollo` is enabled.

    Trainable parameters with more than one dimension that belong to the targeted `torch.nn.Linear`
    modules are placed in a parameter group carrying the APOLLO options. Every other trainable
    parameter goes into a plain group, with or without weight decay.

    Args:
        model: the model whose trainable parameters are optimized.
        training_args: provides `optim`, `weight_decay`, `learning_rate` and `gradient_accumulation_steps`.
        finetuning_args: provides the `apollo_*` options and `freeze_vision_tower`.

    Returns:
        An `APOLLOAdamW` optimizer, or, when `apollo_layerwise` is set, a `DummyOptimizer` holding
        one `APOLLOAdamW` per trainable parameter.

    Raises:
        NotImplementedError: if `training_args.optim` is not `adamw_torch`.
        ValueError: if layer-wise mode is requested with `gradient_accumulation_steps != 1`.
    """
    if len(finetuning_args.apollo_target) == 1 and finetuning_args.apollo_target[0] == "all":
        apollo_targets = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
    else:
        apollo_targets = finetuning_args.apollo_target

    # A linear module is targeted when any target string occurs in its module name.
    apollo_params: List["torch.nn.Parameter"] = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and any(target in name for target in apollo_targets):
            for param in module.parameters():
                if param.requires_grad and len(param.shape) > 1:
                    apollo_params.append(param)

    apollo_kwargs = {
        "rank": finetuning_args.apollo_rank,
        "proj": finetuning_args.apollo_proj,
        "proj_type": finetuning_args.apollo_proj_type,
        "update_proj_gap": finetuning_args.apollo_update_interval,
        "scale": finetuning_args.apollo_scale,
        "scale_type": finetuning_args.apollo_scale_type,
        "scale_front": finetuning_args.apollo_scale_front,
    }

    id_apollo_params = {id(param) for param in apollo_params}
    decay_params, nodecay_params = [], []  # they are non-apollo parameters
    trainable_params: List["torch.nn.Parameter"] = []  # apollo_params + decay_params + nodecay_params
    decay_param_names = _get_decay_parameter_names(model)
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params.append(param)
            if id(param) not in id_apollo_params:
                if name in decay_param_names:
                    decay_params.append(param)
                else:
                    nodecay_params.append(param)

    _, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)

    if training_args.optim == "adamw_torch":
        optim_class = APOLLOAdamW
    else:
        raise NotImplementedError(f"Unknown optim: {training_args.optim}")

    if finetuning_args.apollo_layerwise:
        if training_args.gradient_accumulation_steps != 1:
            raise ValueError("Per-layer APOLLO does not support gradient accumulation.")

        # One optimizer per parameter, keyed by the parameter itself.
        optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
        for param in nodecay_params:
            param_groups = [dict(params=[param], weight_decay=0.0)]
            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
        for param in decay_params:
            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
        for param in apollo_params:  # apollo params have weight decay
            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **apollo_kwargs)]
            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)

        def optimizer_hook(param: "torch.nn.Parameter"):
            # Step and clear this parameter's own optimizer from the gradient hook.
            if param.grad is not None:
                optimizer_dict[param].step()
                optimizer_dict[param].zero_grad()

        for param in trainable_params:
            param.register_post_accumulate_grad_hook(optimizer_hook)

        optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
    else:
        param_groups = [
            dict(params=nodecay_params, weight_decay=0.0),
            dict(params=decay_params, weight_decay=training_args.weight_decay),
            dict(params=apollo_params, weight_decay=training_args.weight_decay, **apollo_kwargs),
        ]
        optimizer = optim_class(param_groups, **optim_kwargs)

    # Report the effective settings through the rank-0 logger instead of a bare print().
    logger.info_rank0(f"Using APOLLO optimizer with args: {apollo_kwargs}.")
    return optimizer
|
||||
|
||||
|
||||
def _create_loraplus_optimizer(
|
||||
model: "PreTrainedModel",
|
||||
@ -410,6 +496,9 @@ def create_custom_optimizer(
|
||||
if finetuning_args.use_galore:
|
||||
return _create_galore_optimizer(model, training_args, finetuning_args)
|
||||
|
||||
if finetuning_args.use_apollo:
|
||||
return _create_apollo_optimizer(model, training_args, finetuning_args)
|
||||
|
||||
if finetuning_args.loraplus_lr_ratio is not None:
|
||||
return _create_loraplus_optimizer(model, training_args, finetuning_args)
|
||||
|
||||
|
@ -250,6 +250,25 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
)
|
||||
)
|
||||
|
||||
with gr.Accordion(open=False) as apollo_tab:
    with gr.Row():
        use_apollo = gr.Checkbox()
        apollo_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
        apollo_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1)
        # The `apollo_scale` argument defaults to 1.0 and the example config uses 32.0, so the
        # slider must reach well beyond 1; its default mirrors the argument default.
        apollo_scale = gr.Slider(minimum=0, maximum=100, value=1.0, step=0.1)
        apollo_target = gr.Textbox(value="all")

# Register the APOLLO widgets as inputs and expose them by name.
input_elems.update({use_apollo, apollo_rank, apollo_update_interval, apollo_scale, apollo_target})
elem_dict.update(
    dict(
        apollo_tab=apollo_tab,
        use_apollo=use_apollo,
        apollo_rank=apollo_rank,
        apollo_update_interval=apollo_update_interval,
        apollo_scale=apollo_scale,
        apollo_target=apollo_target,
    )
)
|
||||
|
||||
with gr.Accordion(open=False) as badam_tab:
|
||||
with gr.Row():
|
||||
use_badam = gr.Checkbox()
|
||||
|
@ -1249,6 +1249,110 @@ LOCALES = {
|
||||
"info": "GaLore를 적용할 모듈의 이름. 모듈 간에는 쉼표(,)로 구분하십시오.",
|
||||
},
|
||||
},
|
||||
"apollo_tab": {
|
||||
"en": {
|
||||
"label": "APOLLO configurations",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Конфигурации APOLLO",
|
||||
},
|
||||
"zh": {
|
||||
"label": "APOLLO 参数设置",
|
||||
},
|
||||
"ko": {
|
||||
"label": "APOLLO 구성",
|
||||
},
|
||||
},
|
||||
"use_apollo": {
|
||||
"en": {
|
||||
"label": "Use APOLLO",
|
||||
"info": "Enable gradient low-Rank projection.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Использовать APOLLO",
|
||||
"info": "Включить проекцию градиента на низкоранговое пространство.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "使用 APOLLO",
|
||||
"info": "使用梯度低秩投影。",
|
||||
},
|
||||
"ko": {
|
||||
"label": "APOLLO 사용",
|
||||
"info": "그레디언트 로우 랭크 프로젝션을 활성화합니다.",
|
||||
},
|
||||
},
|
||||
"apollo_rank": {
|
||||
"en": {
|
||||
"label": "APOLLO rank",
|
||||
"info": "The rank of APOLLO gradients.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Ранг APOLLO",
|
||||
"info": "Ранг градиентов APOLLO.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "APOLLO 秩",
|
||||
"info": "APOLLO 梯度的秩大小。",
|
||||
},
|
||||
"ko": {
|
||||
"label": "APOLLO 랭크",
|
||||
"info": "APOLLO 그레디언트의 랭크.",
|
||||
},
|
||||
},
|
||||
"apollo_update_interval": {
|
||||
"en": {
|
||||
"label": "Update interval",
|
||||
"info": "Number of steps to update the APOLLO projection.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Интервал обновления",
|
||||
"info": "Количество шагов для обновления проекции APOLLO.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "更新间隔",
|
||||
"info": "相邻两次投影更新的步数。",
|
||||
},
|
||||
"ko": {
|
||||
"label": "업데이트 간격",
|
||||
"info": "APOLLO 프로젝션을 업데이트할 간격의 스텝 수.",
|
||||
},
|
||||
},
|
||||
"apollo_scale": {
|
||||
"en": {
|
||||
"label": "APOLLO scale",
|
||||
"info": "APOLLO scaling coefficient.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Масштаб APOLLO",
|
||||
"info": "Коэффициент масштабирования APOLLO.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "APOLLO 缩放系数",
|
||||
"info": "APOLLO 缩放系数大小。",
|
||||
},
|
||||
"ko": {
|
||||
"label": "APOLLO 스케일",
|
||||
"info": "APOLLO 스케일링 계수.",
|
||||
},
|
||||
},
|
||||
"apollo_target": {
|
||||
"en": {
|
||||
"label": "APOLLO modules",
|
||||
"info": "Name(s) of modules to apply APOLLO. Use commas to separate multiple modules.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Модули APOLLO",
|
||||
"info": "Имена модулей для применения APOLLO. Используйте запятые для разделения нескольких модулей.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "APOLLO 作用模块",
|
||||
"info": "应用 APOLLO 的模块名称。使用英文逗号分隔多个名称。",
|
||||
},
|
||||
"ko": {
|
||||
"label": "APOLLO 모듈",
|
||||
"info": "APOLLO를 적용할 모듈의 이름. 모듈 간에는 쉼표(,)로 구분하십시오.",
|
||||
},
|
||||
},
|
||||
"badam_tab": {
|
||||
"en": {
|
||||
"label": "BAdam configurations",
|
||||
|
@ -147,6 +147,7 @@ class Runner:
|
||||
shift_attn=get("train.shift_attn"),
|
||||
report_to="all" if get("train.report_to") else "none",
|
||||
use_galore=get("train.use_galore"),
|
||||
use_apollo=get("train.use_apollo"),
|
||||
use_badam=get("train.use_badam"),
|
||||
use_swanlab=get("train.use_swanlab"),
|
||||
output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")),
|
||||
@ -223,6 +224,13 @@ class Runner:
|
||||
args["galore_update_interval"] = get("train.galore_update_interval")
|
||||
args["galore_scale"] = get("train.galore_scale")
|
||||
args["galore_target"] = get("train.galore_target")
|
||||
|
||||
# apollo config
|
||||
if args["use_apollo"]:
|
||||
args["apollo_rank"] = get("train.apollo_rank")
|
||||
args["apollo_update_interval"] = get("train.apollo_update_interval")
|
||||
args["apollo_scale"] = get("train.apollo_scale")
|
||||
args["apollo_target"] = get("train.apollo_target")
|
||||
|
||||
# badam config
|
||||
if args["use_badam"]:
|
||||
|
Loading…
x
Reference in New Issue
Block a user