Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-03 04:02:49 +08:00)

Commit a0bcac80c0

Merge pull request #6547 from hiyouga/hiyouga/fix_pixtral_dpo

[trainer] fix pixtral dpo

Former-commit-id: c973f32849b979a3ebb80caa01029b43fbb620ac
src/llamafactory/train/dpo/trainer.py

@@ -31,7 +31,7 @@ from typing_extensions import override
 from ...extras.constants import IGNORE_INDEX
 from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach
 
 
 if TYPE_CHECKING:
@@ -193,7 +193,7 @@ class CustomDPOTrainer(DPOTrainer):
         Otherwise the average log probabilities.
         """
         if self.finetuning_args.use_ref_model:
-            batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
+            batch = nested_detach(batch, clone=True)  # avoid error
 
         all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
         all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
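The one-line change above (repeated in the KTO trainer below) replaces a flat per-key detach with the new nested_detach helper. The old dict comprehension assumes every batch value is a tensor, which presumably breaks for multimodal batches such as Pixtral's, where a value can be a nested list of image tensors. A minimal sketch of the failure mode, using an illustrative batch layout rather than the project's actual collator output:

import torch

# Illustrative batch: pixel values stored as a list of differently sized image tensors.
batch = {
    "input_ids": torch.randint(0, 100, (2, 8)),
    "pixel_values": [torch.randn(3, 16, 16), torch.randn(3, 32, 32)],
}

try:
    # Old approach: every value must expose .detach(); a plain list does not.
    detached = {k: v.detach().clone() for k, v in batch.items()}
except AttributeError as err:
    print(f"flat detach fails: {err}")  # 'list' object has no attribute 'detach'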
src/llamafactory/train/kto/trainer.py

@@ -30,7 +30,7 @@ from typing_extensions import override
 from ...extras.constants import IGNORE_INDEX
 from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach
 
 
 if TYPE_CHECKING:
@@ -142,7 +142,7 @@ class CustomKTOTrainer(KTOTrainer):
         r"""
         Runs forward pass and computes the log probabilities.
         """
-        batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
+        batch = nested_detach(batch, clone=True)  # avoid error
         model_inputs = {
             "input_ids": batch[f"{prefix}input_ids"],
             "attention_mask": batch[f"{prefix}attention_mask"],
src/llamafactory/train/trainer_utils.py

@@ -17,6 +17,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections.abc import Mapping
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -36,7 +37,7 @@ from ..model import find_all_linear_modules, load_model, load_tokenizer, load_va
 
 
 if is_galore_available():
-    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
+    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit  # type: ignore
 
 
 if TYPE_CHECKING:
@@ -330,7 +331,7 @@ def _create_badam_optimizer(
     ]
 
     if finetuning_args.badam_mode == "layer":
-        from badam import BlockOptimizer
+        from badam import BlockOptimizer  # type: ignore
 
         base_optimizer = optim_class(param_groups, **optim_kwargs)
         optimizer = BlockOptimizer(
@@ -350,7 +351,7 @@ def _create_badam_optimizer(
         )
 
     elif finetuning_args.badam_mode == "ratio":
-        from badam import BlockOptimizerRatio
+        from badam import BlockOptimizerRatio  # type: ignore
 
         assert finetuning_args.badam_update_ratio > 1e-6
         optimizer = BlockOptimizerRatio(
@@ -374,7 +375,7 @@ def _create_adam_mini_optimizer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
 ) -> "torch.optim.Optimizer":
-    from adam_mini import Adam_mini
+    from adam_mini import Adam_mini  # type: ignore
 
     hidden_size = getattr(model.config, "hidden_size", None)
     num_q_head = getattr(model.config, "num_attention_heads", None)
@@ -459,12 +460,33 @@ def get_batch_logps(
     return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
 
 
+def nested_detach(
+    tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]],
+    clone: bool = False,
+):
+    r"""
+    Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
+    """
+    if isinstance(tensors, (list, tuple)):
+        return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
+    elif isinstance(tensors, Mapping):
+        return type(tensors)({k: nested_detach(t, clone=clone) for k, t in tensors.items()})
+
+    if isinstance(tensors, torch.Tensor):
+        if clone:
+            return tensors.detach().clone()
+        else:
+            return tensors.detach()
+    else:
+        return tensors
+
+
 def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
     r"""
     Gets the callback for logging to SwanLab.
     """
-    import swanlab
-    from swanlab.integration.transformers import SwanLabCallback
+    import swanlab  # type: ignore
+    from swanlab.integration.transformers import SwanLabCallback  # type: ignore
 
     if finetuning_args.swanlab_api_key is not None:
         swanlab.login(api_key=finetuning_args.swanlab_api_key)
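A short usage sketch of the nested_detach helper added above. The import path follows this repository's package layout and the tensors are illustrative, not taken from the PR:

import torch

from llamafactory.train.trainer_utils import nested_detach  # assumes the package is installed (pip install -e .)

batch = {
    "input_ids": torch.randint(0, 100, (2, 8)),
    "pixel_values": [torch.randn(3, 16, 16, requires_grad=True), torch.randn(3, 32, 32, requires_grad=True)],
}

detached = nested_detach(batch, clone=True)  # same call the DPO/KTO trainers now make

# Tensors are detached (and copied, because clone=True); container types are preserved.
assert isinstance(detached["pixel_values"], list)
assert not detached["pixel_values"][0].requires_grad
assert detached["input_ids"].data_ptr() != batch["input_ids"].data_ptr()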