Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-01 11:12:50 +08:00)
parent 951d845af2
commit 813f5919a3
setup.py
@@ -54,7 +54,7 @@ extra_require = {
     "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
-    "vllm": ["vllm>=0.4.3,<0.6.5"],
+    "vllm": ["vllm>=0.4.3,<0.6.7"],
     "galore": ["galore-torch"],
     "badam": ["badam>=1.2.1"],
     "adam-mini": ["adam-mini"],
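Note: the only change in this hunk is the upper bound of the optional vllm extra, raised from <0.6.5 to <0.6.7. A quick sketch, outside this diff and assuming the packaging library is available (pip ships it), showing which vllm releases the new bound admits:

from packaging.specifiers import SpecifierSet
from packaging.version import Version

old_spec = SpecifierSet(">=0.4.3,<0.6.5")  # bound before this commit
new_spec = SpecifierSet(">=0.4.3,<0.6.7")  # bound after this commit
for candidate in ("0.6.4", "0.6.5", "0.6.6"):
    v = Version(candidate)
    print(candidate, v in old_spec, v in new_spec)
# 0.6.4 True True
# 0.6.5 False True
# 0.6.6 False True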
@@ -171,7 +171,7 @@ class HuggingfaceEngine(BaseEngine):
             elif not isinstance(value, torch.Tensor):
                 value = torch.tensor(value)

-            if torch.is_floating_point(value):
+            if torch.is_floating_point(value):  # cast data dtype for paligemma
                 value = value.to(model.dtype)

             gen_kwargs[key] = value.to(model.device)
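Note: the change here is only the trailing comment, documenting that floating-point inputs (for example PaliGemma's pixel_values) are cast to the model dtype so a bf16/fp16 model is not fed fp32 image tensors. A standalone sketch of the same pattern, with plain stand-ins for model.dtype and model.device:

import torch

model_dtype = torch.bfloat16  # stand-in for `model.dtype`
device = "cpu"                # stand-in for `model.device`

gen_kwargs = {
    "pixel_values": torch.rand(1, 3, 224, 224),   # float32 image tensor
    "input_ids": torch.tensor([[1, 2, 3]]),       # integer tensor, left untouched
}

for key, value in gen_kwargs.items():
    if not isinstance(value, torch.Tensor):
        value = torch.tensor(value)
    if torch.is_floating_point(value):  # cast data dtype for paligemma
        value = value.to(model_dtype)
    gen_kwargs[key] = value.to(device)

print({k: v.dtype for k, v in gen_kwargs.items()})
# {'pixel_values': torch.bfloat16, 'input_ids': torch.int64}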
@@ -168,6 +168,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
         if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
             features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)

+        for key, value in features.items():  # cast data dtype for paligemma
+            if torch.is_tensor(value) and torch.is_floating_point(value):
+                features[key] = value.to(self.compute_dtype)
+
         return features


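Note: the added loop casts every floating-point feature tensor in the batch to compute_dtype, again for PaliGemma-style image inputs. A stripped-down sketch of just that step, using a hypothetical ToyCollator rather than the project's real collator:

from dataclasses import dataclass
from typing import Any, Dict

import torch


@dataclass
class ToyCollator:
    compute_dtype: torch.dtype = torch.bfloat16

    def __call__(self, features: Dict[str, Any]) -> Dict[str, Any]:
        for key, value in features.items():  # cast data dtype for paligemma
            if torch.is_tensor(value) and torch.is_floating_point(value):
                features[key] = value.to(self.compute_dtype)
        return features


batch = ToyCollator()({"pixel_values": torch.rand(2, 3, 224, 224), "labels": torch.tensor([[0, 1]])})
print(batch["pixel_values"].dtype, batch["labels"].dtype)  # torch.bfloat16 torch.int64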
@@ -79,12 +79,13 @@ def check_dependencies() -> None:
     """
     if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
         logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
-    else:
-        require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
-        require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
-        require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
-        require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
-        require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
+        return
+
+    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
+    require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
+    require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
+    require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
+    require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")


 def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
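Note: the guard is refactored from an if/else into an early return, so the pinned-version checks sit at a single indentation level. A runnable sketch of the resulting control flow, with a print standing in for the five require_version(...) calls in the real function:

import os


def check_dependencies() -> None:
    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
        print("Version checking has been disabled, may lead to unexpected behaviors.")
        return

    print("pinned-version checks for transformers/datasets/accelerate/peft/trl run here")


os.environ["DISABLE_VERSION_CHECK"] = "1"
check_dependencies()  # warns and returns early
os.environ.pop("DISABLE_VERSION_CHECK")
check_dependencies()  # falls through to the checks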
@@ -112,6 +112,10 @@ def _check_extra_dependencies(
     finetuning_args: "FinetuningArguments",
     training_args: Optional["Seq2SeqTrainingArguments"] = None,
 ) -> None:
+    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
+        logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
+        return
+
     if model_args.use_unsloth:
         require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth")

@@ -122,7 +126,7 @@ def _check_extra_dependencies(
         require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")

     if model_args.infer_backend == "vllm":
-        require_version("vllm>=0.4.3,<0.6.5", "To fix: pip install vllm>=0.4.3,<0.6.5")
+        require_version("vllm>=0.4.3,<0.6.7", "To fix: pip install vllm>=0.4.3,<0.6.7")

     if finetuning_args.use_galore:
         require_version("galore_torch", "To fix: pip install galore_torch")
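Note: the same DISABLE_VERSION_CHECK early return is added to the backend-specific checks, and the runtime vllm requirement is bumped to match setup.py. For context, a minimal sketch of what a require_version-style runtime check does (illustration only; the real helper is imported by this module, not defined in it):

from importlib.metadata import PackageNotFoundError, version

from packaging.requirements import Requirement


def require(requirement: str, hint: str) -> None:
    # Parse "name>=x,<y", look up the installed distribution, compare, and fail with the hint.
    req = Requirement(requirement)
    try:
        installed = version(req.name)
    except PackageNotFoundError:
        raise ImportError(f"{req.name} is not installed. {hint}")
    if installed not in req.specifier:
        raise ImportError(f"found {req.name}=={installed}, need {req.specifier}. {hint}")


require("packaging>=20.0", "To fix: pip install packaging>=20.0")  # passes on any modern environment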
@@ -29,7 +29,7 @@ from trl.trainer import disable_dropout_in_model
 from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46
+from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps

@@ -50,6 +50,9 @@ class CustomDPOTrainer(DPOTrainer):
         disable_dropout: bool = True,
         **kwargs,
     ):
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
+
         if disable_dropout:
             disable_dropout_in_model(model)
             if ref_model is not None:
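Note: newer transformers releases (4.46 and later) expect the tokenizer under the processing_class keyword of Trainer, so the DPO trainer remaps the argument before handing kwargs to the parent constructor. A sketch of the version gate, assuming is_transformers_version_greater_than is roughly a packaging comparison against transformers.__version__ (requires transformers installed):

import transformers
from packaging.version import Version


def is_transformers_version_greater_than(v: str) -> bool:  # assumed behavior of the imported helper
    return Version(transformers.__version__) >= Version(v)


kwargs = {"tokenizer": "<tokenizer object>", "beta": 0.1}  # illustrative kwargs only
if is_transformers_version_greater_than("4.46"):
    kwargs["processing_class"] = kwargs.pop("tokenizer")  # keyword expected by Trainer on 4.46+

print(sorted(kwargs))  # ['beta', 'processing_class'] on 4.46+, ['beta', 'tokenizer'] otherwise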
@@ -28,7 +28,7 @@ from trl.trainer import disable_dropout_in_model
 from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46
+from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps

@@ -50,6 +50,9 @@ class CustomKTOTrainer(KTOTrainer):
         disable_dropout: bool = True,
         **kwargs,
     ):
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
+
         if disable_dropout:
             disable_dropout_in_model(model)
             if ref_model is not None:
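Note: the KTO trainer receives the identical remapping. The two inlined lines could equally be read as a small helper shared by both trainers; a hypothetical refactor, shown only to illustrate the behavior and not part of this commit:

from typing import Any, Dict


def remap_tokenizer_kwarg(kwargs: Dict[str, Any], on_new_transformers: bool) -> Dict[str, Any]:
    # Forward `tokenizer` as `processing_class` when running on transformers >= 4.46.
    if on_new_transformers and "tokenizer" in kwargs:
        kwargs["processing_class"] = kwargs.pop("tokenizer")
    return kwargs


print(sorted(remap_tokenizer_kwarg({"tokenizer": object()}, on_new_transformers=True)))   # ['processing_class']
print(sorted(remap_tokenizer_kwarg({"tokenizer": object()}, on_new_transformers=False)))  # ['tokenizer']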