mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 20:00:36 +08:00
[example] add bash usage (#7794)
This commit is contained in:
@@ -65,14 +65,16 @@ class BaseModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
|
||||
)
|
||||
new_special_tokens: Optional[str] = field(
|
||||
add_tokens: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens."
|
||||
},
|
||||
)
|
||||
add_special_tokens: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
|
||||
)
|
||||
new_normal_tokens: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Normal tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
@@ -180,11 +182,11 @@ class BaseModelArguments:
|
||||
if self.adapter_name_or_path is not None: # support merging multiple lora weights
|
||||
self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
|
||||
|
||||
if self.new_normal_tokens is not None: # support multiple normal tokens
|
||||
self.new_normal_tokens = [token.strip() for token in self.new_normal_tokens.split(",")]
|
||||
if self.add_tokens is not None: # support multiple tokens
|
||||
self.add_tokens = [token.strip() for token in self.add_tokens.split(",")]
|
||||
|
||||
if self.new_special_tokens is not None: # support multiple special tokens
|
||||
self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
|
||||
if self.add_special_tokens is not None: # support multiple special tokens
|
||||
self.add_special_tokens = [token.strip() for token in self.add_special_tokens.split(",")]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -153,7 +153,7 @@ def _check_extra_dependencies(
|
||||
elif model_args.infer_backend == EngineName.SGLANG:
|
||||
check_version("sglang>=0.4.4")
|
||||
check_version("sglang", mandatory=True)
|
||||
|
||||
|
||||
if finetuning_args.use_galore:
|
||||
check_version("galore_torch", mandatory=True)
|
||||
|
||||
|
||||
@@ -124,6 +124,7 @@ def configure_quantization(
|
||||
|
||||
try:
|
||||
from optimum.gptq import utils as gq_utils
|
||||
|
||||
if "language_model.model.layers" not in gq_utils.BLOCK_PATTERNS:
|
||||
gq_utils.BLOCK_PATTERNS.insert(0, "language_model.model.layers")
|
||||
except ImportError:
|
||||
|
||||
@@ -54,26 +54,22 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument
|
||||
if model_args.model_max_length is not None and tokenizer.model_max_length < model_args.model_max_length:
|
||||
tokenizer.model_max_length = model_args.model_max_length # enlarge the tokenizer max length
|
||||
|
||||
if model_args.new_special_tokens is not None:
|
||||
num_added_special_tokens = tokenizer.add_special_tokens(
|
||||
dict(additional_special_tokens=model_args.new_special_tokens),
|
||||
replace_additional_special_tokens=False,
|
||||
if model_args.add_tokens is not None:
|
||||
num_added_tokens = tokenizer.add_tokens(new_tokens=model_args.add_tokens, special_tokens=False)
|
||||
logger.info_rank0("Add tokens {} to tokenizer's vocabulary.".format(",".join(model_args.add_tokens)))
|
||||
if num_added_tokens > 0 and not model_args.resize_vocab:
|
||||
model_args.resize_vocab = True
|
||||
logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
|
||||
|
||||
if model_args.add_special_tokens is not None:
|
||||
num_added_special_tokens = tokenizer.add_tokens(new_tokens=model_args.add_special_tokens, special_tokens=True)
|
||||
logger.info_rank0(
|
||||
"Add special tokens {} to tokenizer's vocabulary.".format(",".join(model_args.add_special_tokens))
|
||||
)
|
||||
logger.info_rank0("Add special tokens {} to vocab.".format(",".join(model_args.new_special_tokens)))
|
||||
if num_added_special_tokens > 0 and not model_args.resize_vocab:
|
||||
model_args.resize_vocab = True
|
||||
logger.warning_rank0("New special tokens have been added, changed `resize_vocab` to True.")
|
||||
|
||||
if model_args.new_normal_tokens is not None:
|
||||
num_added_normal_tokens = tokenizer.add_tokens(
|
||||
new_tokens=model_args.new_normal_tokens,
|
||||
special_tokens=False,
|
||||
)
|
||||
logger.info_rank0("Add normal tokens {} to vocab.".format(",".join(model_args.new_normal_tokens)))
|
||||
if num_added_normal_tokens > 0 and not model_args.resize_vocab:
|
||||
model_args.resize_vocab = True
|
||||
logger.warning_rank0("New normal tokens have been added, changed `resize_vocab` to True.")
|
||||
|
||||
|
||||
def patch_processor(
|
||||
processor: "ProcessorMixin",
|
||||
|
||||
0
src/llamafactory/third_party/__init__.py
vendored
Normal file
0
src/llamafactory/third_party/__init__.py
vendored
Normal file
30
src/llamafactory/third_party/muon/muon.py
vendored
30
src/llamafactory/third_party/muon/muon.py
vendored
@@ -2,6 +2,8 @@
|
||||
#
|
||||
# This code is based on the MoonshotAI's Moonlight library.
|
||||
# https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
|
||||
# and the Keller Jordan's Muon library.
|
||||
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -18,6 +20,7 @@
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2025 Moonshot AI
|
||||
# Copyright (c) 2024 Keller Jordan
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
@@ -36,22 +39,20 @@
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# This code snippet is a modified version adapted from the following GitHub repository:
|
||||
# https://github.com/KellerJordan/Muon/blob/master/muon.py
|
||||
@torch.compile
|
||||
def zeropower_via_newtonschulz5(G, steps):
|
||||
def zeropower_via_newtonschulz5(G: "torch.Tensor", steps: int) -> "torch.Tensor":
|
||||
"""Newton-Schulz iteration to compute the zeroth power / orthogonalization of G.
|
||||
|
||||
We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero.
|
||||
For the purpose of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
|
||||
zero even beyond the point where the iteration no longer converges all the way to one everywhere
|
||||
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
|
||||
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
||||
For the purpose of minimizing steps, it turns out to be empirically effective to keep increasing
|
||||
the slope at zero even beyond the point where the iteration no longer converges all the way to
|
||||
one everywhere on the interval. This iteration therefore does not produce UV^T but rather something
|
||||
like US'V^T where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
|
||||
performance at all relative to UV^T, where USV^T = G is the SVD.
|
||||
"""
|
||||
assert len(G.shape) == 2
|
||||
@@ -133,7 +134,7 @@ class Muon(torch.optim.Optimizer):
|
||||
# Do not use Muon for parameters in adamw_params
|
||||
self.state[p]["use_muon"] = False
|
||||
|
||||
def adjust_lr_for_muon(self, lr, param_shape):
|
||||
def adjust_lr_for_muon(self, lr: float, param_shape: list[int]) -> float:
|
||||
A, B = param_shape[:2]
|
||||
# We adjust the learning rate and weight decay based on the size of the parameter matrix
|
||||
# as describted in the paper
|
||||
@@ -154,12 +155,8 @@ class Muon(torch.optim.Optimizer):
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
############################
|
||||
# Muon #
|
||||
############################
|
||||
|
||||
# Muon loop
|
||||
params = [p for p in group["params"] if self.state[p]["use_muon"]]
|
||||
# import pdb; pdb.set_trace()
|
||||
lr = group["lr"]
|
||||
wd = group["wd"]
|
||||
momentum = group["momentum"]
|
||||
@@ -195,10 +192,7 @@ class Muon(torch.optim.Optimizer):
|
||||
# apply update
|
||||
p.data.add_(u, alpha=-adjusted_lr)
|
||||
|
||||
############################
|
||||
# AdamW backup #
|
||||
############################
|
||||
|
||||
# Adam backup
|
||||
params = [p for p in group["params"] if not self.state[p]["use_muon"]]
|
||||
lr = group["lr"]
|
||||
beta1, beta2 = group["adamw_betas"]
|
||||
|
||||
@@ -489,16 +489,14 @@ def _create_adam_mini_optimizer(
|
||||
logger.info_rank0("Using Adam-mini optimizer.")
|
||||
return optimizer
|
||||
|
||||
|
||||
def _create_muon_optimizer(
|
||||
model: "PreTrainedModel",
|
||||
training_args: "TrainingArguments",
|
||||
) -> "torch.optim.Optimizer":
|
||||
from llamafactory.third_party.muon import Muon # type: ignore
|
||||
|
||||
# Separate parameters for Muon (2D parameters) and AdamW (others)
|
||||
muon_params = []
|
||||
adamw_params = []
|
||||
|
||||
from ..third_party.muon import Muon
|
||||
|
||||
muon_params, adamw_params = [], []
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
# Use Muon for 2D parameters that aren't embeddings or heads
|
||||
@@ -506,34 +504,26 @@ def _create_muon_optimizer(
|
||||
muon_params.append(param)
|
||||
else:
|
||||
adamw_params.append(param)
|
||||
|
||||
# Get optimizer settings from training_args
|
||||
ns_steps = getattr(training_args, "ns_steps", 5)
|
||||
|
||||
# Create Muon optimizer
|
||||
|
||||
optimizer = Muon(
|
||||
lr=training_args.learning_rate,
|
||||
wd=training_args.weight_decay,
|
||||
muon_params=muon_params,
|
||||
momentum=0.95, # default momentum for Muon
|
||||
nesterov=True, # default nesterov for Muon
|
||||
ns_steps=ns_steps,
|
||||
adamw_params=adamw_params,
|
||||
adamw_betas=(training_args.adam_beta1, training_args.adam_beta2),
|
||||
adamw_eps=training_args.adam_epsilon,
|
||||
)
|
||||
|
||||
logger.info_rank0(f"Using Muon optimizer with {len(muon_params)} Muon params and {len(adamw_params)} AdamW params.")
|
||||
logger.info_rank0(
|
||||
f"Using Muon optimizer with {len(muon_params)} Muon params and {len(adamw_params)} AdamW params."
|
||||
)
|
||||
return optimizer
|
||||
|
||||
|
||||
def create_custom_optimizer(
|
||||
model: "PreTrainedModel",
|
||||
training_args: "TrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
) -> Optional["torch.optim.Optimizer"]:
|
||||
if finetuning_args.use_muon:
|
||||
return _create_muon_optimizer(model, training_args)
|
||||
|
||||
if finetuning_args.use_galore:
|
||||
return _create_galore_optimizer(model, training_args, finetuning_args)
|
||||
|
||||
@@ -549,6 +539,9 @@ def create_custom_optimizer(
|
||||
if finetuning_args.use_adam_mini:
|
||||
return _create_adam_mini_optimizer(model, training_args)
|
||||
|
||||
if finetuning_args.use_muon:
|
||||
return _create_muon_optimizer(model, training_args)
|
||||
|
||||
|
||||
def create_custom_scheduler(
|
||||
training_args: "TrainingArguments",
|
||||
|
||||
Reference in New Issue
Block a user