Former-commit-id: ae045c884f8ac2aa0ea27592e0757b7bca2dba13
hiyouga 2024-10-29 10:47:04 +00:00
parent 0d8aa6e6ef
commit 825ea1c72d
8 changed files with 58 additions and 6 deletions

View File

@@ -10,6 +10,7 @@ DISABLE_VERSION_CHECK=
FORCE_CHECK_IMPORTS=
LLAMAFACTORY_VERBOSITY=
USE_MODELSCOPE_HUB=
USE_OPENMIND_HUB=
RECORD_VRAM=
# torchrun
FORCE_TORCHRUN=

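The new USE_OPENMIND_HUB entry mirrors USE_MODELSCOPE_HUB as a plain environment flag for selecting an alternative model hub. A minimal sketch of how such a flag could be read (the helper below is an illustration, not the project's actual parsing code):

import os

def _is_env_enabled(name: str) -> bool:
    # Hypothetical helper: treat "1", "true", or "yes" (any casing) as enabled.
    return os.getenv(name, "").strip().lower() in {"1", "true", "yes"}

if _is_env_enabled("USE_OPENMIND_HUB"):
    print("downloading models from the openMind hub")
else:
    print("falling back to the default hub")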
.gitignore
View File

@@ -162,6 +162,7 @@ cython_debug/
# custom .gitignore
ms_cache/
hf_cache/
om_cache/
cache/
config/
saves/

View File

@@ -18,4 +18,4 @@ style:
	ruff format $(check_dirs)
test:
	CUDA_VISIBLE_DEVICES= pytest tests/
	CUDA_VISIBLE_DEVICES= WANDB_DISABLED=true pytest tests/

View File

@@ -79,6 +79,11 @@ def is_transformers_version_greater_than_4_43():
    return _get_package_version("transformers") >= version.parse("4.43.0")
@lru_cache
def is_transformers_version_equal_to_4_46():
    return _get_package_version("transformers") == version.parse("4.46.0")
def is_uvicorn_available():
    return _is_package_available("uvicorn")

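For context, the helper added here follows the existing cached version-check pattern in packages.py. A self-contained sketch of that pattern, assuming _get_package_version wraps importlib.metadata as the surrounding helpers suggest:

from functools import lru_cache
import importlib.metadata

from packaging import version

@lru_cache
def _get_package_version(name: str) -> "version.Version":
    # Assumed shape of the project's helper: read the installed version,
    # falling back to 0.0.0 when the package is absent.
    try:
        return version.parse(importlib.metadata.version(name))
    except importlib.metadata.PackageNotFoundError:
        return version.parse("0.0.0")

@lru_cache
def is_transformers_version_equal_to_4_46() -> bool:
    # Exact equality check, matching the helper added in this commit.
    return _get_package_version("transformers") == version.parse("4.46.0")

print(is_transformers_version_equal_to_4_46())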
View File

@@ -29,6 +29,7 @@ from trl.trainer import disable_dropout_in_model
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
@@ -118,6 +119,13 @@ class CustomDPOTrainer(DPOTrainer):
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)
    @override
    def get_batch_samples(self, epoch_iterator, num_batches):
        r"""
        Replaces the method of DPO Trainer with the one of the standard Trainer.
        """
        return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
    def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
        r"""
        Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
@@ -258,3 +266,15 @@ class CustomDPOTrainer(DPOTrainer):
        metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).detach().mean().cpu()
        return losses.mean(), metrics
    @override
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        r"""
        Fixes the loss value for transformers 4.46.0.
        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
        """
        loss = super().compute_loss(model, inputs, return_outputs)
        if kwargs.pop("num_items_in_batch", False) and is_transformers_version_equal_to_4_46():
            loss /= self.args.gradient_accumulation_steps
        return loss

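The two overrides added to CustomDPOTrainer target transformers 4.46.0. get_batch_samples delegates to Trainer.get_batch_samples, whose (epoch_iterator, num_batches) signature the 4.46 training loop calls, sidestepping the TRL trainers' own method of the same name. compute_loss then compensates for 4.46.0 skipping its usual division by gradient_accumulation_steps when num_items_in_batch is passed, which would otherwise report a loss that is already a per-batch mean as too large by that factor; the same guard is repeated in the KTO and pairwise trainers below. A standalone sketch of the guard with made-up numbers (illustration only, not the project's code):

def rescale_loss(loss, kwargs, gradient_accumulation_steps, is_transformers_4_46_0):
    # Mirrors the guard above: rescale only when the Trainer passed the new
    # num_items_in_batch keyword and transformers 4.46.0 is installed.
    if kwargs.pop("num_items_in_batch", False) and is_transformers_4_46_0:
        loss = loss / gradient_accumulation_steps
    return loss

print(rescale_loss(0.8, {"num_items_in_batch": 32}, 4, True))   # 0.2
print(rescale_loss(0.8, {}, 4, False))                          # 0.8, unchanged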
View File

@@ -28,6 +28,7 @@ from trl.trainer import disable_dropout_in_model
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
@@ -120,6 +121,13 @@ class CustomKTOTrainer(KTOTrainer):
        """
        return Trainer._get_train_sampler(self)
    @override
    def get_batch_samples(self, epoch_iterator, num_batches):
        r"""
        Replaces the method of KTO Trainer with the one of the standard Trainer.
        """
        return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
    @override
    def forward(
        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
@@ -231,3 +239,15 @@ class CustomKTOTrainer(KTOTrainer):
        metrics["kl"] = kl.item()
        return losses, metrics
    @override
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        r"""
        Fixes the loss value for transformers 4.46.0.
        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
        """
        loss = super().compute_loss(model, inputs, return_outputs)
        if kwargs.pop("num_items_in_batch", False) and is_transformers_version_equal_to_4_46():
            loss /= self.args.gradient_accumulation_steps
        return loss

View File

@@ -25,6 +25,7 @@ from transformers import Trainer
from typing_extensions import override
from ...extras.logging import get_logger
from ...extras.packages import is_transformers_version_equal_to_4_46
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -79,7 +80,7 @@ class PairwiseTrainer(Trainer):
    @override
    def compute_loss(
        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False
        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
        r"""
        Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
@@ -98,6 +99,10 @@ class PairwiseTrainer(Trainer):
        chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()
        loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
        if kwargs.pop("num_items_in_batch", False) and is_transformers_version_equal_to_4_46():
            loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0
        if return_outputs:
            return loss, (loss, chosen_scores, rejected_scores)
        else:

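The pairwise objective above is the usual Bradley-Terry ranking loss on reward scores. A tiny numeric sketch with made-up scores (not taken from the change) showing the same expression the trainer uses:

import torch

# Made-up reward scores for three chosen/rejected pairs.
chosen_scores = torch.tensor([1.2, 0.3, 2.0])
rejected_scores = torch.tensor([0.5, 0.9, 1.0])

# Same expression as the trainer: -log sigmoid(chosen - rejected), averaged.
loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
print(loss.item())  # smaller when chosen scores consistently exceed rejected scores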
View File

@@ -44,7 +44,6 @@ INFER_ARGS = {
    "finetuning_type": "lora",
    "template": "llama3",
    "infer_dtype": "float16",
    "export_dir": "llama3_export",
}
OS_NAME = os.environ.get("OS_NAME", "")
@@ -61,11 +60,12 @@ OS_NAME = os.environ.get("OS_NAME", "")
    ],
)
def test_run_exp(stage: str, dataset: str):
    output_dir = f"train_{stage}"
    output_dir = os.path.join("output", f"train_{stage}")
    run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
    assert os.path.exists(output_dir)
def test_export():
    export_model(INFER_ARGS)
    assert os.path.exists("llama3_export")
    export_dir = os.path.join("output", "llama3_export")
    export_model({"export_dir": export_dir, **INFER_ARGS})
    assert os.path.exists(export_dir)
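Both tests now write their artifacts under a shared output/ directory instead of the repository root. A quick sketch of the resulting paths, using a hypothetical stage value:

import os

stage = "sft"  # hypothetical stage name, for illustration only
print(os.path.join("output", f"train_{stage}"))   # output/train_sft on POSIX systems
print(os.path.join("output", "llama3_export"))    # output/llama3_export on POSIX systems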