Former-commit-id: 6f5bb3b8e5b6eb7fdfd7b0ca8eba789ab741a7b6
This commit is contained in:
hiyouga 2024-12-30 05:55:15 +00:00
parent 951d845af2
commit 813f5919a3
7 changed files with 26 additions and 11 deletions

View File

@@ -54,7 +54,7 @@ extra_require = {
     "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
-    "vllm": ["vllm>=0.4.3,<0.6.5"],
+    "vllm": ["vllm>=0.4.3,<0.6.7"],
     "galore": ["galore-torch"],
     "badam": ["badam>=1.2.1"],
     "adam-mini": ["adam-mini"],

View File

@@ -171,7 +171,7 @@ class HuggingfaceEngine(BaseEngine):
             elif not isinstance(value, torch.Tensor):
                 value = torch.tensor(value)
-            if torch.is_floating_point(value):
+            if torch.is_floating_point(value):  # cast data dtype for paligemma
                 value = value.to(model.dtype)
             gen_kwargs[key] = value.to(model.device)
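
The cast introduced above matters for multimodal models such as PaliGemma, whose `pixel_values` arrive as float32 while the model itself may run in bfloat16/float16. Below is a minimal standalone sketch of the same idea; the helper name and the way the dtype/device are derived are illustrative, not the engine's actual code:

```python
import torch

def cast_inputs_to_model(gen_kwargs: dict, model: torch.nn.Module) -> dict:
    """Hypothetical helper mirroring the hunk above: floating-point values
    (e.g. pixel_values for PaliGemma) are cast to the model dtype so they do
    not clash with a bf16/fp16 model, then everything is moved to its device."""
    param = next(model.parameters())
    for key, value in gen_kwargs.items():
        if not isinstance(value, torch.Tensor):
            value = torch.tensor(value)
        if torch.is_floating_point(value):
            value = value.to(param.dtype)
        gen_kwargs[key] = value.to(param.device)
    return gen_kwargs
```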

View File

@@ -168,6 +168,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
         if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
             features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)

+        for key, value in features.items():  # cast data dtype for paligemma
+            if torch.is_tensor(value) and torch.is_floating_point(value):
+                features[key] = value.to(self.compute_dtype)
+
         return features
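
The same dtype alignment is applied on the collator side. A hedged, self-contained sketch of what the added loop does to a collated batch (the feature names in the example batch are assumptions):

```python
import torch

def cast_batch_floats(features: dict, compute_dtype: torch.dtype) -> dict:
    """Cast every floating-point tensor in the batch (e.g. pixel_values) to the
    compute dtype, leaving integer tensors such as input_ids untouched."""
    for key, value in features.items():
        if torch.is_tensor(value) and torch.is_floating_point(value):
            features[key] = value.to(compute_dtype)
    return features

# Example: fp32 pixel values get cast, int64 token ids are left alone.
batch = {"input_ids": torch.ones(2, 8, dtype=torch.long),
         "pixel_values": torch.rand(2, 3, 224, 224)}
batch = cast_batch_floats(batch, torch.bfloat16)
assert batch["pixel_values"].dtype == torch.bfloat16
assert batch["input_ids"].dtype == torch.long
```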

View File

@@ -79,12 +79,13 @@ def check_dependencies() -> None:
     """
     if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
         logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
-    else:
-        require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
-        require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
-        require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
-        require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
-        require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
+        return
+
+    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
+    require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
+    require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
+    require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
+    require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")


 def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
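
The refactor above swaps the else branch for an early return: when DISABLE_VERSION_CHECK is set, the function warns and exits before any pin is evaluated. A rough sketch of that control flow from the caller's side; the environment variable name and the pins are taken from the hunk, everything else is illustrative:

```python
import os
from transformers.utils.versions import require_version

# Opt out of the pin checks before the library runs its dependency check.
os.environ["DISABLE_VERSION_CHECK"] = "1"

def check_pins() -> None:
    # Illustrative stand-in for check_dependencies(): warn and bail out early
    # instead of nesting the require_version calls under an else branch.
    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
        print("Version checking has been disabled, may lead to unexpected behaviors.")
        return

    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
    require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")

check_pins()  # prints the warning and returns without evaluating any pin
```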

View File

@@ -112,6 +112,10 @@ def _check_extra_dependencies(
     finetuning_args: "FinetuningArguments",
     training_args: Optional["Seq2SeqTrainingArguments"] = None,
 ) -> None:
+    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
+        logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
+        return
+
     if model_args.use_unsloth:
         require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")

@@ -122,7 +126,7 @@ def _check_extra_dependencies(
         require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")

     if model_args.infer_backend == "vllm":
-        require_version("vllm>=0.4.3,<0.6.5", "To fix: pip install vllm>=0.4.3,<0.6.5")
+        require_version("vllm>=0.4.3,<0.6.7", "To fix: pip install vllm>=0.4.3,<0.6.7")

     if finetuning_args.use_galore:
         require_version("galore_torch", "To fix: pip install galore_torch")
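
With the bump above, an installed vLLM 0.6.5 or 0.6.6 now satisfies the check, while 0.6.7 and later are still rejected. A small sketch of how the pin behaves at runtime; whether vllm is installed in the current environment is, of course, an assumption:

```python
from transformers.utils.versions import require_version

try:
    # Passes silently for any installed vllm in [0.4.3, 0.6.7);
    # raises with the "To fix" hint when vllm is missing or out of range.
    require_version("vllm>=0.4.3,<0.6.7", "To fix: pip install vllm>=0.4.3,<0.6.7")
except Exception as err:
    print(f"vLLM backend unavailable: {err}")
```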

View File

@@ -29,7 +29,7 @@ from trl.trainer import disable_dropout_in_model
 from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46
+from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps

@@ -50,6 +50,9 @@ class CustomDPOTrainer(DPOTrainer):
         disable_dropout: bool = True,
         **kwargs,
     ):
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
+
         if disable_dropout:
             disable_dropout_in_model(model)
             if ref_model is not None:
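
The new guard adapts the trainer to transformers 4.46+, where Trainer subclasses accept processing_class and the tokenizer keyword is deprecated. A hedged, self-contained sketch of the same rename; the helper below is an illustrative stand-in, not the project's is_transformers_version_greater_than:

```python
import transformers
from packaging import version

def _transformers_at_least(target: str) -> bool:
    # Illustrative stand-in for extras.packages.is_transformers_version_greater_than.
    return version.parse(transformers.__version__) >= version.parse(target)

def adapt_trainer_kwargs(kwargs: dict) -> dict:
    """Rename the deprecated `tokenizer` kwarg to `processing_class` on
    transformers >= 4.46 so the same call site works across versions."""
    if _transformers_at_least("4.46") and "tokenizer" in kwargs:
        kwargs["processing_class"] = kwargs.pop("tokenizer")
    return kwargs
```

The KTO trainer in the next file applies the identical rename before forwarding kwargs to its parent constructor.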

View File

@@ -28,7 +28,7 @@ from trl.trainer import disable_dropout_in_model
 from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46
+from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps

@@ -50,6 +50,9 @@ class CustomKTOTrainer(KTOTrainer):
         disable_dropout: bool = True,
         **kwargs,
     ):
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
+
         if disable_dropout:
             disable_dropout_in_model(model)
             if ref_model is not None: