Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)
	[optim] add support to APOLLO (#6617)
Former-commit-id: 5a252e5a458457adbd19da3b68a3897ad2962824
This commit is contained in:
parent 66184762e8
commit c2120432db
examples/extras/apollo/llama3_full_sft.yaml (new file, 45 lines)

@@ -0,0 +1,45 @@
### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: full
use_apollo: true
apollo_layerwise: true
apollo_target: mlp,self_attn
apollo_rank: 128
apollo_scale: 32.0
apollo_scale_type: channel

### dataset
dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/llama3-8b/apollo_full-scale32/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
pure_bf16: true
ddp_timeout: 180000000

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
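Note: assuming the standard CLI entry point, this example can be launched with llamafactory-cli train examples/extras/apollo/llama3_full_sft.yaml. The apollo_* keys map onto the new ApolloArguments fields introduced below; options not set here (apollo_proj, apollo_proj_type, apollo_update_interval, apollo_scale_front) fall back to their defaults.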
							
								
								
									
setup.py (1 line added)

@@ -56,6 +56,7 @@ extra_require = {
    "aqlm": ["aqlm[gpu]>=1.1.0"],
    "vllm": ["vllm>=0.4.3,<=0.6.5"],
    "galore": ["galore-torch"],
    "apollo": ["apollo-torch"],
    "badam": ["badam>=1.2.1"],
    "adam-mini": ["adam-mini"],
    "qwen": ["transformers_stream_generator"],
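Note: with this extra declared, the dependency can presumably be installed either directly (pip install apollo-torch) or through the optional extra (pip install -e ".[apollo]" from the repository root), mirroring how the existing GaLore and BAdam extras are used.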
@@ -49,6 +49,8 @@ def is_fastapi_available():
def is_galore_available():
    return _is_package_available("galore_torch")

def is_apollo_available():
    return _is_package_available("apollo_torch")

def is_gradio_available():
    return _is_package_available("gradio")
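Note that the PyPI distribution is named apollo-torch (see setup.py above) while the importable module is apollo_torch, which is what _is_package_available here and the check_version call below probe for.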
@@ -251,6 +251,59 @@ class GaloreArguments:
    )


@dataclass
class ApolloArguments:
    r"""
    Arguments pertaining to the APOLLO algorithm.
    """

    use_apollo: bool = field(
        default=False,
        metadata={"help": "Whether or not to use the APOLLO optimizer."},
    )
    apollo_target: str = field(
        default="all",
        metadata={
            "help": (
                "Name(s) of modules to apply APOLLO. Use commas to separate multiple modules. "
                "Use `all` to specify all the linear modules."
            )
        },
    )
    apollo_rank: int = field(
        default=16,
        metadata={"help": "The rank of APOLLO gradients."},
    )
    apollo_update_interval: int = field(
        default=200,
        metadata={"help": "Number of steps to update the APOLLO projection."},
    )
    apollo_scale: float = field(
        default=1.0,
        metadata={"help": "APOLLO scaling coefficient."},
    )
    apollo_proj: Literal["svd", "random"] = field(
        default="random",
        metadata={"help": "Type of APOLLO low-rank projection algorithm (svd or random)."},
    )
    apollo_proj_type: Literal["std", "right", "left"] = field(
        default="std",
        metadata={"help": "Type of APOLLO projection."},
    )
    apollo_scale_type: Literal["channel", "tensor"] = field(
        default="channel",
        metadata={"help": "Type of APOLLO scaling (channel or tensor)."},
    )
    apollo_layerwise: bool = field(
        default=False,
        metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
    )
    apollo_scale_front: bool = field(
        default=False,
        metadata={"help": "Whether or not to use the norm-growth limiter in front of gradient scaling."},
    )


@dataclass
class BAdamArgument:
    r"""
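For orientation, a minimal hypothetical sketch of exercising these new fields on their own with transformers' HfArgumentParser; the import path of ApolloArguments is an assumption (the file name is not shown in this view), and a transformers version whose HfArgumentParser accepts Literal-typed fields is required:

# Hypothetical usage sketch, not part of this commit.
from transformers import HfArgumentParser

from llamafactory.hparams.finetuning_args import ApolloArguments  # assumed import path

parser = HfArgumentParser(ApolloArguments)
(apollo_args,) = parser.parse_args_into_dataclasses(
    args=["--use_apollo", "true", "--apollo_rank", "128", "--apollo_scale", "32.0", "--apollo_target", "mlp,self_attn"]
)
print(apollo_args.apollo_rank, apollo_args.apollo_scale)  # 128 32.0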
@@ -334,7 +387,7 @@ class SwanLabArguments:

@dataclass
class FinetuningArguments(
-    FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument, SwanLabArguments
+    FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, ApolloArguments, BAdamArgument, SwanLabArguments
):
    r"""
    Arguments pertaining to which techniques we are going to fine-tune with.
@@ -401,6 +454,7 @@ class FinetuningArguments(
        self.lora_target: List[str] = split_arg(self.lora_target)
        self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
        self.galore_target: List[str] = split_arg(self.galore_target)
        self.apollo_target: List[str] = split_arg(self.apollo_target)
        self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
        self.freeze_multi_modal_projector = self.freeze_multi_modal_projector and not self.train_mm_proj_only
        self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
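split_arg itself is not part of this diff; presumably it just splits on commas, so the comma-separated apollo_target from the example config becomes a list of module names:

# Assumed behavior of split_arg (illustration only, not the actual helper).
apollo_target = "mlp,self_attn"
parsed = [item.strip() for item in apollo_target.split(",")]
assert parsed == ["mlp", "self_attn"]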
@@ -421,12 +475,18 @@ class FinetuningArguments(
        if self.use_llama_pro and self.finetuning_type == "full":
            raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")

-        if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
+        if self.finetuning_type == "lora" and (self.use_galore or self.use_badam or self.use_apollo):
            raise ValueError("Cannot use LoRA with GaLore or BAdam together.")

        if self.use_galore and self.use_badam:
            raise ValueError("Cannot use GaLore with BAdam together.")

        if self.use_galore and self.use_apollo:
            raise ValueError("Cannot use GaLore with APOLLO together.")

        if self.use_badam and self.use_apollo:
            raise ValueError("Cannot use BAdam with APOLLO together.")

        if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
            raise ValueError("Cannot use PiSSA for current training stage.")
@@ -139,6 +139,9 @@ def _check_extra_dependencies(
    if finetuning_args.use_galore:
        check_version("galore_torch", mandatory=True)

    if finetuning_args.use_apollo:
        check_version("apollo_torch", mandatory=True)

    if finetuning_args.use_badam:
        check_version("badam>=1.2.1", mandatory=True)
@@ -262,6 +265,13 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
    ):
        raise ValueError("Distributed training does not support layer-wise GaLore.")

    if (
        finetuning_args.use_apollo
        and finetuning_args.apollo_layerwise
        and training_args.parallel_mode == ParallelMode.DISTRIBUTED
    ):
        raise ValueError("Distributed training does not support layer-wise APOLLO.")

    if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED:
        if finetuning_args.badam_mode == "ratio":
            raise ValueError("Ratio-based BAdam does not yet support distributed training, use layer-wise BAdam.")
@@ -271,6 +281,9 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
    if finetuning_args.use_galore and training_args.deepspeed is not None:
        raise ValueError("GaLore is incompatible with DeepSpeed yet.")

    if finetuning_args.use_apollo and training_args.deepspeed is not None:
        raise ValueError("APOLLO is incompatible with DeepSpeed yet.")

    if model_args.infer_backend == "vllm":
        raise ValueError("vLLM backend is only available for API, CLI and Web.")
@@ -306,6 +319,11 @@ def get_train_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _
            "Using GaLore with mixed precision training may significantly increase GPU memory usage."
        )

    if training_args.do_train and finetuning_args.use_apollo and not finetuning_args.pure_bf16:
        logger.warning_rank0(
            "Using APOLLO with mixed precision training may significantly increase GPU memory usage."
        )

    if (not training_args.do_train) and model_args.quantization_bit is not None:
        logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")
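This warning is why the example config above sets pure_bf16: true; like GaLore, APOLLO under ordinary mixed-precision training is flagged as potentially using significantly more GPU memory.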
@@ -27,7 +27,7 @@ logger = logging.get_logger(__name__)

def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
    r"""
-    Finds all available modules to apply lora or galore.
+    Finds all available modules to apply lora or galore or apollo.
    """
    model_type = getattr(model.config, "model_type", None)
    forbidden_modules = {"lm_head"}
@@ -32,7 +32,7 @@ from typing_extensions import override

from ..extras import logging
from ..extras.constants import IGNORE_INDEX
-from ..extras.packages import is_galore_available, is_ray_available
+from ..extras.packages import is_galore_available, is_ray_available, is_apollo_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params

@@ -40,6 +40,8 @@ from ..model import find_all_linear_modules, load_model, load_tokenizer, load_va
if is_galore_available():
    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit  # type: ignore

if is_apollo_available():
    from apollo_torch import APOLLOAdamW  # type: ignore

if is_ray_available():
    from ray.train import RunConfig, ScalingConfig
@@ -58,7 +60,7 @@ logger = logging.get_logger(__name__)

class DummyOptimizer(torch.optim.Optimizer):
    r"""
-    A dummy optimizer used for the GaLore algorithm.
+    A dummy optimizer used for the GaLore or APOLLO algorithm.
    """

    def __init__(
@@ -275,6 +277,90 @@ def _create_galore_optimizer(
    logger.info_rank0("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
    return optimizer


def _create_apollo_optimizer(
    model: "PreTrainedModel",
    training_args: "TrainingArguments",
    finetuning_args: "FinetuningArguments",
) -> "torch.optim.Optimizer":
    if len(finetuning_args.apollo_target) == 1 and finetuning_args.apollo_target[0] == "all":
        apollo_targets = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
    else:
        apollo_targets = finetuning_args.apollo_target

    apollo_params: List["torch.nn.Parameter"] = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and any(target in name for target in apollo_targets):
            for param in module.parameters():
                if param.requires_grad and len(param.shape) > 1:
                    apollo_params.append(param)

    apollo_kwargs = {
        "rank": finetuning_args.apollo_rank,
        "proj": finetuning_args.apollo_proj,
        "proj_type": finetuning_args.apollo_proj_type,
        "update_proj_gap": finetuning_args.apollo_update_interval,
        "scale": finetuning_args.apollo_scale,
        "scale_type": finetuning_args.apollo_scale_type,
        "scale_front": finetuning_args.apollo_scale_front,
    }

    id_apollo_params = {id(param) for param in apollo_params}
    decay_params, nodecay_params = [], []  # they are non-apollo parameters
    trainable_params: List["torch.nn.Parameter"] = []  # apollo_params + decay_params + nodecay_params
    decay_param_names = _get_decay_parameter_names(model)
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params.append(param)
            if id(param) not in id_apollo_params:
                if name in decay_param_names:
                    decay_params.append(param)
                else:
                    nodecay_params.append(param)

    _, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)

    if training_args.optim == "adamw_torch":
        optim_class = APOLLOAdamW
    else:
        raise NotImplementedError(f"Unknown optim: {training_args.optim}")

    if finetuning_args.apollo_layerwise:
        if training_args.gradient_accumulation_steps != 1:
            raise ValueError("Per-layer APOLLO does not support gradient accumulation.")

        optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {}
        for param in nodecay_params:
            param_groups = [dict(params=[param], weight_decay=0.0)]
            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
        for param in decay_params:
            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)]
            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)
        for param in apollo_params:  # apollo params have weight decay
            param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **apollo_kwargs)]
            optimizer_dict[param] = optim_class(param_groups, **optim_kwargs)

        def optimizer_hook(param: "torch.nn.Parameter"):
            if param.grad is not None:
                optimizer_dict[param].step()
                optimizer_dict[param].zero_grad()

        for param in trainable_params:
            param.register_post_accumulate_grad_hook(optimizer_hook)

        optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict)
    else:
        param_groups = [
            dict(params=nodecay_params, weight_decay=0.0),
            dict(params=decay_params, weight_decay=training_args.weight_decay),
            dict(params=apollo_params, weight_decay=training_args.weight_decay, **apollo_kwargs),
        ]
        optimizer = optim_class(param_groups, **optim_kwargs)

    logger.info_rank0("Using APOLLO optimizer.")
    return optimizer


def _create_loraplus_optimizer(
    model: "PreTrainedModel",
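The layer-wise branch above hands every trainable parameter its own optimizer and steps it from a gradient hook, so no monolithic optimizer.step() is needed (hence the DummyOptimizer). A minimal standalone sketch of that mechanism, assuming PyTorch 2.1+ where Tensor.register_post_accumulate_grad_hook is available:

import torch

# One tiny optimizer per parameter, stepped as soon as that parameter's
# gradient has been accumulated, so per-parameter state and gradients never
# all have to be live at the same time.
model = torch.nn.Linear(8, 4)
optimizer_dict = {p: torch.optim.SGD([p], lr=0.1) for p in model.parameters()}

def optimizer_hook(param: torch.nn.Parameter) -> None:
    if param.grad is not None:
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad()

for param in model.parameters():
    param.register_post_accumulate_grad_hook(optimizer_hook)

loss = model(torch.randn(2, 8)).sum()
loss.backward()  # each hook fires once its parameter's grad is ready; no global optimizer.step()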
@@ -410,6 +496,9 @@ def create_custom_optimizer(
    if finetuning_args.use_galore:
        return _create_galore_optimizer(model, training_args, finetuning_args)

    if finetuning_args.use_apollo:
        return _create_apollo_optimizer(model, training_args, finetuning_args)

    if finetuning_args.loraplus_lr_ratio is not None:
        return _create_loraplus_optimizer(model, training_args, finetuning_args)
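Note that the dispatch order here never has to break a tie: the checks added to FinetuningArguments above already reject every pairwise combination of use_galore, use_apollo, and use_badam.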
@@ -250,6 +250,25 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
        )
    )

    with gr.Accordion(open=False) as apollo_tab:
        with gr.Row():
            use_apollo = gr.Checkbox()
            apollo_rank = gr.Slider(minimum=1, maximum=1024, value=16, step=1)
            apollo_update_interval = gr.Slider(minimum=1, maximum=1024, value=200, step=1)
            apollo_scale = gr.Slider(minimum=0, maximum=1, value=0.25, step=0.01)
            apollo_target = gr.Textbox(value="all")
    input_elems.update({use_apollo, apollo_rank, apollo_update_interval, apollo_scale, apollo_target})
    elem_dict.update(
        dict(
            apollo_tab=apollo_tab,
            use_apollo=use_apollo,
            apollo_rank=apollo_rank,
            apollo_update_interval=apollo_update_interval,
            apollo_scale=apollo_scale,
            apollo_target=apollo_target,
        )
    )

    with gr.Accordion(open=False) as badam_tab:
        with gr.Row():
            use_badam = gr.Checkbox()
@@ -1249,6 +1249,110 @@ LOCALES = {
            "info": "GaLore를 적용할 모듈의 이름. 모듈 간에는 쉼표(,)로 구분하십시오.",
        },
    },
    "apollo_tab": {
        "en": {
            "label": "APOLLO configurations",
        },
        "ru": {
            "label": "Конфигурации APOLLO",
        },
        "zh": {
            "label": "APOLLO 参数设置",
        },
        "ko": {
            "label": "APOLLO 구성",
        },
    },
    "use_apollo": {
        "en": {
            "label": "Use APOLLO",
            "info": "Enable gradient low-Rank projection.",
        },
        "ru": {
            "label": "Использовать APOLLO",
            "info": "Включить проекцию градиента на низкоранговое пространство.",
        },
        "zh": {
            "label": "使用 APOLLO",
            "info": "使用梯度低秩投影。",
        },
        "ko": {
            "label": "APOLLO 사용",
            "info": "그레디언트 로우 랭크 프로젝션을 활성화합니다.",
        },
    },
    "apollo_rank": {
        "en": {
            "label": "APOLLO rank",
            "info": "The rank of APOLLO gradients.",
        },
        "ru": {
            "label": "Ранг APOLLO",
            "info": "Ранг градиентов APOLLO.",
        },
        "zh": {
            "label": "APOLLO 秩",
            "info": "APOLLO 梯度的秩大小。",
        },
        "ko": {
            "label": "APOLLO 랭크",
            "info": "APOLLO 그레디언트의 랭크.",
        },
    },
    "apollo_update_interval": {
        "en": {
            "label": "Update interval",
            "info": "Number of steps to update the APOLLO projection.",
        },
        "ru": {
            "label": "Интервал обновления",
            "info": "Количество шагов для обновления проекции APOLLO.",
        },
        "zh": {
            "label": "更新间隔",
            "info": "相邻两次投影更新的步数。",
        },
        "ko": {
            "label": "업데이트 간격",
            "info": "APOLLO 프로젝션을 업데이트할 간격의 스텝 수.",
        },
    },
    "apollo_scale": {
        "en": {
            "label": "APOLLO scale",
            "info": "APOLLO scaling coefficient.",
        },
        "ru": {
            "label": "Масштаб APOLLO",
            "info": "Коэффициент масштабирования APOLLO.",
        },
        "zh": {
            "label": "APOLLO 缩放系数",
            "info": "APOLLO 缩放系数大小。",
        },
        "ko": {
            "label": "APOLLO 스케일",
            "info": "APOLLO 스케일링 계수.",
        },
    },
    "apollo_target": {
        "en": {
            "label": "APOLLO modules",
            "info": "Name(s) of modules to apply APOLLO. Use commas to separate multiple modules.",
        },
        "ru": {
            "label": "Модули APOLLO",
            "info": "Имена модулей для применения APOLLO. Используйте запятые для разделения нескольких модулей.",
        },
        "zh": {
            "label": "APOLLO 作用模块",
            "info": "应用 APOLLO 的模块名称。使用英文逗号分隔多个名称。",
        },
        "ko": {
            "label": "APOLLO 모듈",
            "info": "APOLLO를 적용할 모듈의 이름. 모듈 간에는 쉼표(,)로 구분하십시오.",
        },
    },
    "badam_tab": {
        "en": {
            "label": "BAdam configurations",
@@ -147,6 +147,7 @@ class Runner:
            shift_attn=get("train.shift_attn"),
            report_to="all" if get("train.report_to") else "none",
            use_galore=get("train.use_galore"),
            use_apollo=get("train.use_apollo"),
            use_badam=get("train.use_badam"),
            use_swanlab=get("train.use_swanlab"),
            output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")),
@@ -223,6 +224,13 @@ class Runner:
            args["galore_update_interval"] = get("train.galore_update_interval")
            args["galore_scale"] = get("train.galore_scale")
            args["galore_target"] = get("train.galore_target")

        # apollo config
        if args["use_apollo"]:
            args["apollo_rank"] = get("train.apollo_rank")
            args["apollo_update_interval"] = get("train.apollo_update_interval")
            args["apollo_scale"] = get("train.apollo_scale")
            args["apollo_target"] = get("train.apollo_target")

        # badam config
        if args["use_badam"]: