mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 11:50:35 +08:00
update hparams
This commit is contained in:
@@ -83,9 +83,7 @@ class DataArguments:
|
||||
)
|
||||
ignore_pad_token_for_loss: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
|
||||
},
|
||||
metadata={"help": "Whether or not to ignore the tokens corresponding to the pad tokens in loss computation."},
|
||||
)
|
||||
val_size: float = field(
|
||||
default=0.0,
|
||||
@@ -93,9 +91,11 @@ class DataArguments:
|
||||
)
|
||||
packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
|
||||
},
|
||||
metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
|
||||
)
|
||||
neat_packing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable sequence packing without cross-attention."},
|
||||
)
|
||||
tool_format: Optional[str] = field(
|
||||
default=None,
|
||||
@@ -112,3 +112,6 @@ class DataArguments:
|
||||
|
||||
if self.streaming and self.max_samples is not None:
|
||||
raise ValueError("`max_samples` is incompatible with `streaming`.")
|
||||
|
||||
if self.neat_packing and not self.packing:
|
||||
raise ValueError("`neat_packing` requires `packing` is True.")
|
||||
|
||||
@@ -109,12 +109,6 @@ class ModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
|
||||
)
|
||||
efficient_packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether or not to pack the sequences without cross-contamination attention for efficient training."
|
||||
},
|
||||
)
|
||||
mixture_of_depths: Optional[Literal["convert", "load"]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
|
||||
@@ -232,6 +226,7 @@ class ModelArguments:
|
||||
self.compute_dtype: Optional["torch.dtype"] = None
|
||||
self.device_map: Optional[Union[str, Dict[str, Any]]] = None
|
||||
self.model_max_length: Optional[int] = None
|
||||
self.block_diag_attn: bool = False
|
||||
|
||||
if self.split_special_tokens and self.use_fast_tokenizer:
|
||||
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
||||
@@ -259,4 +254,5 @@ class ModelArguments:
|
||||
new_arg.compute_dtype = old_arg.compute_dtype
|
||||
new_arg.device_map = old_arg.device_map
|
||||
new_arg.model_max_length = old_arg.model_max_length
|
||||
new_arg.block_diag_attn = old_arg.block_diag_attn
|
||||
return new_arg
|
||||
|
||||
@@ -158,6 +158,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
if finetuning_args.stage != "sft" and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
|
||||
|
||||
if finetuning_args.stage != "sft" and data_args.neat_packing:
|
||||
raise ValueError("`neat_packing` cannot be set as True except SFT.")
|
||||
|
||||
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
|
||||
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
|
||||
|
||||
@@ -170,9 +173,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
if finetuning_args.stage == "ppo" and model_args.shift_attn:
|
||||
raise ValueError("PPO training is incompatible with S^2-Attn.")
|
||||
|
||||
if finetuning_args.stage != "sft" and model_args.efficient_packing:
|
||||
raise ValueError("`efficient_packing` cannot be set as True except SFT.")
|
||||
|
||||
if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
|
||||
raise ValueError("Unsloth does not support lora reward model.")
|
||||
|
||||
@@ -314,6 +314,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
|
||||
model_args.device_map = {"": get_current_device()}
|
||||
model_args.model_max_length = data_args.cutoff_len
|
||||
model_args.block_diag_attn = data_args.neat_packing
|
||||
data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
|
||||
|
||||
# Log on each process the small summary
|
||||
|
||||
Reference in New Issue
Block a user