Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-22 22:02:51 +08:00)

commit 7b3c1f29ff (parent a38ff842d0)

fix packing for eager/sdpa attn

Former-commit-id: 6fd6aa4530f81a2ed306eeb2a5167607288b62c6
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
+from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask
 from .data_utils import Role, split_dataset
 from .loader import get_dataset
 from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
@@ -21,6 +21,7 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
 __all__ = [
     "KTODataCollatorWithPadding",
     "PairwiseDataCollatorWithPadding",
+    "SFTDataCollatorWith4DAttentionMask",
     "Role",
     "split_dataset",
     "get_dataset",
@@ -16,7 +16,7 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import Any, Dict, Sequence
+from typing import Any, Dict, Literal, Sequence

 import torch
 from transformers import DataCollatorForSeq2Seq
@@ -62,13 +62,31 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
     return attention_mask_4d


+@dataclass
+class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
+    r"""
+    Data collator for 4d attention mask.
+    """
+
+    block_diag_attn: bool = False
+    attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
+    compute_dtype: "torch.dtype" = torch.float32
+
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+        features = super().__call__(features)
+        if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
+            features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
+
+        return features
+
+
 @dataclass
 class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
     r"""
     Data collator for pairwise data.
     """

-    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
         r"""
         Pads batched data to the longest sequence in the batch.

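Note: `prepare_4d_attention_mask` itself is unchanged in this hunk (only its `return attention_mask_4d` tail is visible as context), so the sketch below is only an illustration of the mask the new collator expects it to produce for eager/sdpa attention, not the project's implementation; the helper name `block_diag_4d_mask` is made up. A packed row carries per-token sequence ids (0 = padding), and the 4D result lets a token attend only causally and only within its own packed sequence.

import torch

def block_diag_4d_mask(mask_with_indices: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # mask_with_indices: [bsz, seq_len], values are packed-sequence ids (0 = padding)
    bsz, seq_len = mask_with_indices.size()
    min_value = torch.finfo(dtype).min
    key_ids = mask_with_indices[:, None, None, :]            # broadcast over query positions
    query_ids = mask_with_indices[:, None, :, None]          # broadcast over key positions
    same_seq = (key_ids == query_ids) & (key_ids != 0)       # block-diagonal, padding masked out
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    allowed = same_seq & causal
    mask_4d = torch.full((bsz, 1, seq_len, seq_len), min_value, dtype=dtype)
    return mask_4d.masked_fill(allowed, 0)                   # 0 where attention is allowed

packed = torch.tensor([[1, 1, 2, 2, 2, 0]])                  # two sequences packed into one row
print(block_diag_4d_mask(packed, torch.float32)[0, 0])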
@@ -100,7 +118,7 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
     Data collator for KTO data.
     """

-    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
         target_features = []
         kl_features = []
         kto_tags = []
@@ -79,6 +79,7 @@ TRAINING_STAGES = {
 STAGES_USE_PAIR_DATA = {"rm", "dpo"}

 SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
+    "cohere",
     "falcon",
     "gemma",
     "gemma2",
@@ -112,6 +112,3 @@ class DataArguments:

         if self.streaming and self.max_samples is not None:
             raise ValueError("`max_samples` is incompatible with `streaming`.")
-
-        if self.neat_packing and not self.packing:
-            raise ValueError("`neat_packing` requires `packing` is True.")
@@ -376,14 +376,21 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         if self.use_galore and self.use_badam:
             raise ValueError("Cannot use GaLore with BAdam together.")

-        if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
-            raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
-
-        if self.pissa_init and self.finetuning_type != "lora":
-            raise ValueError("`pissa_init` is only valid for LoRA training.")
-
         if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
             raise ValueError("Cannot use PiSSA for current training stage.")

         if self.train_mm_proj_only and self.finetuning_type != "full":
             raise ValueError("`train_mm_proj_only` is only valid for full training.")
+
+        if self.finetuning_type != "lora":
+            if self.loraplus_lr_ratio is not None:
+                raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
+
+            if self.use_rslora:
+                raise ValueError("`use_rslora` is only valid for LoRA training.")
+
+            if self.use_dora:
+                raise ValueError("`use_dora` is only valid for LoRA training.")
+
+            if self.pissa_init:
+                raise ValueError("`pissa_init` is only valid for LoRA training.")
@@ -233,6 +233,10 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if model_args.use_unsloth and is_deepspeed_zero3_enabled():
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

+    if data_args.neat_packing and not data_args.packing:
+        logger.warning("`neat_packing` requires `packing` is True. Change it to True.")
+        data_args.packing = True
+
     _verify_model_args(model_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)

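With the hard check removed from DataArguments above, get_train_args now downgrades the mismatch to a warning and flips `packing` on. A minimal stand-alone sketch of that coercion (the dataclass below is a stand-in for illustration, not the project's DataArguments):

import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)

@dataclass
class DataArgsStub:  # stand-in carrying only the two flags used here
    packing: bool = False
    neat_packing: bool = False

def coerce_packing(data_args: DataArgsStub) -> None:
    # mirrors the new parser behavior: warn and auto-enable instead of raising
    if data_args.neat_packing and not data_args.packing:
        logger.warning("`neat_packing` requires `packing` is True. Change it to True.")
        data_args.packing = True

args = DataArgsStub(neat_packing=True)
coerce_packing(args)
assert args.packing is True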
@@ -115,7 +115,9 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor


 def patch_for_block_diag_attn(model_type: str) -> None:
-    if model_type == "falcon":
+    if model_type == "cohere":
+        transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
+    elif model_type == "falcon":
         transformers.models.falcon.modeling_falcon._get_unpad_data = get_unpad_data
     elif model_type == "gemma":
         transformers.models.gemma.modeling_gemma._get_unpad_data = get_unpad_data
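For flash_attention_2 the collator leaves the index-valued mask untouched; the monkey-patched `_get_unpad_data` is what turns it into sequence boundaries. The body of `get_unpad_data` is not part of this hunk, so the following is only an illustration of the kind of triple (token indices, cumulative sequence lengths, max length) such a helper returns; for brevity it assumes a single packed row, and the function name is hypothetical.

from typing import Tuple

import torch
import torch.nn.functional as F

def unpad_packed_row(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
    # attention_mask: [1, seq_len] of packed-sequence ids (0 = padding), e.g. [[1, 1, 2, 2, 2, 0]]
    flat = attention_mask.flatten()
    seqlens = torch.bincount(flat)[1:].to(torch.int32)        # length of each packed sequence
    indices = torch.nonzero(flat, as_tuple=True)[0]           # positions of real (non-padding) tokens
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, int(seqlens.max())

print(unpad_packed_row(torch.tensor([[1, 1, 2, 2, 2, 0]])))
# -> (tensor([0, 1, 2, 3, 4]), tensor([0, 2, 5], dtype=torch.int32), 3)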
@@ -79,9 +79,8 @@ def fix_valuehead_checkpoint(
         if name.startswith("v_head."):
             v_head_state_dict[name] = param
         else:
-            decoder_state_dict[name.replace("pretrained_model.", "")] = param
+            decoder_state_dict[name.replace("pretrained_model.", "", count=1)] = param

-    os.remove(path_to_checkpoint)
     model.pretrained_model.save_pretrained(
         output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
     )
@@ -91,6 +90,7 @@ def fix_valuehead_checkpoint(
     else:
         torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

+    os.remove(path_to_checkpoint)
     logger.info("Value head model saved at: {}".format(output_dir))


@@ -17,9 +17,7 @@

 from typing import TYPE_CHECKING, List, Optional

-from transformers import DataCollatorForSeq2Seq
-
-from ...data import get_dataset, split_dataset
+from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, split_dataset
 from ...extras.constants import IGNORE_INDEX
 from ...extras.misc import get_logits_processor
 from ...extras.ploting import plot_loss
@@ -54,10 +52,13 @@ def run_sft(
     if getattr(model, "is_quantized", False) and not training_args.do_train:
         setattr(model, "_hf_peft_config_loaded", True)  # hack here: make model compatible with prediction

-    data_collator = DataCollatorForSeq2Seq(
+    data_collator = SFTDataCollatorWith4DAttentionMask(
         tokenizer=tokenizer,
         pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None,  # for shift short attention
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
+        block_diag_attn=model_args.block_diag_attn,
+        attn_implementation=getattr(model.config, "_attn_implementation", None),
+        compute_dtype=model_args.compute_dtype,
     )

     # Override the decoding parameters of Seq2SeqTrainer
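Taken together, run_sft now builds the packing-aware collator directly. A hedged usage sketch follows (the top-level import path `llamafactory.data` and the `gpt2` tokenizer are assumptions for illustration; the keyword arguments are the ones added in this hunk):

import torch
from transformers import AutoTokenizer

from llamafactory.data import SFTDataCollatorWith4DAttentionMask  # exported in the __init__ hunk above

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

collator = SFTDataCollatorWith4DAttentionMask(
    tokenizer=tokenizer,
    block_diag_attn=True,            # neat packing: one attention block per packed sequence
    attn_implementation="sdpa",      # eager/sdpa -> attention_mask is expanded to a 4D float mask
    compute_dtype=torch.bfloat16,
)
# With attn_implementation="flash_attention_2" the index-valued 1D mask is passed through
# unchanged and consumed by the patched _get_unpad_data instead.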
|
Loading…
x
Reference in New Issue
Block a user