Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-10-14 23:58:11 +08:00)
parent 248d5daaff
commit e2748fa967
@@ -10,6 +10,7 @@ DISABLE_VERSION_CHECK=
 FORCE_CHECK_IMPORTS=
 LLAMAFACTORY_VERBOSITY=
 USE_MODELSCOPE_HUB=
+USE_OPENMIND_HUB=
 RECORD_VRAM=
 # torchrun
 FORCE_TORCHRUN=
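The new USE_OPENMIND_HUB switch sits next to USE_MODELSCOPE_HUB and routes model and dataset downloads to the openMind hub; the om_cache/ entry added to .gitignore below is presumably its matching local cache directory. As a rough sketch of how such a boolean environment toggle is typically read (the helper names and accepted truthy values here are illustrative, not the project's actual API):

import os


def _is_env_enabled(name: str, default: str = "0") -> bool:
    # Common truthy spellings count as "on"; unset or empty falls back to the default.
    return os.getenv(name, default).strip().lower() in {"1", "true", "y", "yes"}


def use_openmind() -> bool:
    # Illustrative helper: pull from the openMind hub when the flag is set.
    return _is_env_enabled("USE_OPENMIND_HUB")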
.gitignore (vendored): 1 change

@@ -162,6 +162,7 @@ cython_debug/
 # custom .gitignore
 ms_cache/
 hf_cache/
+om_cache/
 cache/
 config/
 saves/
Makefile: 2 changes

@@ -18,4 +18,4 @@ style:
 	ruff format $(check_dirs)
 
 test:
-	CUDA_VISIBLE_DEVICES= pytest tests/
+	CUDA_VISIBLE_DEVICES= WANDB_DISABLED=true pytest tests/
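Adding WANDB_DISABLED=true next to the empty CUDA_VISIBLE_DEVICES keeps `make test` self-contained: the suite runs on CPU and the Weights & Biases integration stays inactive, so pytest does not try to create or log runs during testing.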
@@ -79,6 +79,11 @@ def is_transformers_version_greater_than_4_43():
     return _get_package_version("transformers") >= version.parse("4.43.0")
 
 
+@lru_cache
+def is_transformers_version_equal_to_4_46():
+    return _get_package_version("transformers") == version.parse("4.46.0")
+
+
 def is_uvicorn_available():
     return _is_package_available("uvicorn")
 
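The new helper follows the same pattern as the surrounding version probes: resolve the installed transformers version once, cache the result, and compare it against a pinned release. A self-contained sketch of that pattern, with _get_package_version reimplemented here purely for illustration (it may differ from the project's own helper):

from functools import lru_cache
import importlib.metadata

from packaging import version


def _get_package_version(name: str) -> "version.Version":
    # Illustrative stand-in: fall back to 0.0.0 when the package is not installed.
    try:
        return version.parse(importlib.metadata.version(name))
    except importlib.metadata.PackageNotFoundError:
        return version.parse("0.0.0")


@lru_cache
def is_transformers_version_equal_to_4_46() -> bool:
    # Exact match on purpose: the loss-scaling quirk patched below is specific to 4.46.0.
    return _get_package_version("transformers") == version.parse("4.46.0")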
@@ -29,6 +29,7 @@ from trl.trainer import disable_dropout_in_model
 from typing_extensions import override
 
 from ...extras.constants import IGNORE_INDEX
+from ...extras.packages import is_transformers_version_equal_to_4_46
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
 
@@ -118,6 +119,13 @@ class CustomDPOTrainer(DPOTrainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
 
+    @override
+    def get_batch_samples(self, epoch_iterator, num_batches):
+        r"""
+        Replaces the method of KTO Trainer with the one of the standard Trainer.
+        """
+        return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
+
     def odds_ratio_loss(self, chosen_logps: "torch.Tensor", rejected_logps: "torch.Tensor") -> "torch.Tensor":
         r"""
        Computes ORPO's odds ratio (OR) loss for batched log probabilities of the policy model.
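The get_batch_samples override appears to work around a name clash: TRL's preference trainers already define a get_batch_samples of their own (used to generate completions for logging), while the Trainer in transformers 4.46 introduced a method of the same name that the training loop calls to prefetch micro-batches. Delegating back to Trainer.get_batch_samples keeps the new training loop working; the docstring mentions the KTO trainer even in the DPO hunk, presumably a copy-paste slip in the original commit.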
@@ -258,3 +266,15 @@ class CustomDPOTrainer(DPOTrainer):
             metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).detach().mean().cpu()
 
         return losses.mean(), metrics
+
+    @override
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        r"""
+        Fixes the loss value for transformers 4.46.0.
+        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        """
+        loss = super().compute_loss(model, inputs, return_outputs)
+        if kwargs.pop("num_items_in_batch", False) and is_transformers_version_equal_to_4_46():
+            loss /= self.args.gradient_accumulation_steps
+
+        return loss
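The reasoning behind the division, as far as the hunk shows it: transformers 4.46.0 starts passing num_items_in_batch into compute_loss and, when it does, skips its own division of the loss by gradient_accumulation_steps on the assumption that the loss will be normalized by token count instead. These custom trainers return a plain mean loss, so without the extra division the accumulated loss would come out gradient_accumulation_steps times too large. A toy arithmetic check (numbers are illustrative, not LLaMA-Factory code):

# 4 accumulation steps, each micro-batch reporting a mean loss of 2.0.
gradient_accumulation_steps = 4
micro_batch_loss = 2.0

# If the 1/gradient_accumulation_steps scaling is skipped, the per-step
# contributions sum to 4 * 2.0 = 8.0 instead of the intended 2.0.
without_fix = sum(micro_batch_loss for _ in range(gradient_accumulation_steps))

# Dividing each contribution by the accumulation steps restores the average.
with_fix = sum(micro_batch_loss / gradient_accumulation_steps for _ in range(gradient_accumulation_steps))

print(without_fix, with_fix)  # 8.0 2.0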
@@ -28,6 +28,7 @@ from trl.trainer import disable_dropout_in_model
 from typing_extensions import override
 
 from ...extras.constants import IGNORE_INDEX
+from ...extras.packages import is_transformers_version_equal_to_4_46
 from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
 
@@ -120,6 +121,13 @@ class CustomKTOTrainer(KTOTrainer):
         """
         return Trainer._get_train_sampler(self)
 
+    @override
+    def get_batch_samples(self, epoch_iterator, num_batches):
+        r"""
+        Replaces the method of KTO Trainer with the one of the standard Trainer.
+        """
+        return Trainer.get_batch_samples(self, epoch_iterator, num_batches)
+
     @override
     def forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
@@ -231,3 +239,15 @@ class CustomKTOTrainer(KTOTrainer):
         metrics["kl"] = kl.item()
 
         return losses, metrics
+
+    @override
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        r"""
+        Fixes the loss value for transformers 4.46.0.
+        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        """
+        loss = super().compute_loss(model, inputs, return_outputs)
+        if kwargs.pop("num_items_in_batch", False) and is_transformers_version_equal_to_4_46():
+            loss /= self.args.gradient_accumulation_steps
+
+        return loss
@@ -25,6 +25,7 @@ from transformers import Trainer
 from typing_extensions import override
 
 from ...extras.logging import get_logger
+from ...extras.packages import is_transformers_version_equal_to_4_46
 from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
@@ -79,7 +80,7 @@ class PairwiseTrainer(Trainer):
 
     @override
     def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
     ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
         Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
@@ -98,6 +99,10 @@ class PairwiseTrainer(Trainer):
         chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()
 
         loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
+
+        if kwargs.pop("num_items_in_batch", False) and is_transformers_version_equal_to_4_46():
+            loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0
+
         if return_outputs:
             return loss, (loss, chosen_scores, rejected_scores)
         else:
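For reference, the objective in this hunk is the standard pairwise (Bradley-Terry style) reward-model loss, -log sigmoid(r_chosen - r_rejected), averaged over the batch. A self-contained toy computation of the same expression (the scores are illustrative; in the trainer they come from the reward model's value head):

import torch

# Three chosen/rejected score pairs; in the trainer the first n examples of the
# batch are the chosen responses and the last n are the rejected ones.
chosen_scores = torch.tensor([1.2, 0.3, 2.0])
rejected_scores = torch.tensor([0.5, 0.9, 1.0])

# -log(sigmoid(chosen - rejected)) is small when the chosen response outscores
# the rejected one and grows as the ranking flips.
loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
print(loss.item())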
@@ -44,7 +44,6 @@ INFER_ARGS = {
     "finetuning_type": "lora",
     "template": "llama3",
     "infer_dtype": "float16",
-    "export_dir": "llama3_export",
 }
 
 OS_NAME = os.environ.get("OS_NAME", "")
@ -61,11 +60,12 @@ OS_NAME = os.environ.get("OS_NAME", "")
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_run_exp(stage: str, dataset: str):
|
def test_run_exp(stage: str, dataset: str):
|
||||||
output_dir = f"train_{stage}"
|
output_dir = os.path.join("output", f"train_{stage}")
|
||||||
run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
|
run_exp({"stage": stage, "dataset": dataset, "output_dir": output_dir, **TRAIN_ARGS})
|
||||||
assert os.path.exists(output_dir)
|
assert os.path.exists(output_dir)
|
||||||
|
|
||||||
|
|
||||||
def test_export():
|
def test_export():
|
||||||
export_model(INFER_ARGS)
|
export_dir = os.path.join("output", "llama3_export")
|
||||||
assert os.path.exists("llama3_export")
|
export_model({"export_dir": export_dir, **INFER_ARGS})
|
||||||
|
assert os.path.exists(export_dir)
|
||||||
|
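Two small test cleanups ride along with the fix: export_dir is now injected per call instead of living in the shared INFER_ARGS dict, and both the training and export tests write their artifacts under a common output/ directory rather than scattering train_* and llama3_export folders in the working directory.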
Loading…
x
Reference in New Issue
Block a user