Mirror of https://github.com/hiyouga/LLaMA-Factory.git
Merge pull request #6547 from hiyouga/hiyouga/fix_pixtral_dpo

[trainer] fix pixtral dpo

Former-commit-id: c973f32849b979a3ebb80caa01029b43fbb620ac
Commit: a0bcac80c0
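Summary of the diff: the DPO and KTO trainers previously detached batches with a flat dict comprehension (`{k: v.detach().clone() for k, v in batch.items()}`), which assumes every batch value is a tensor. This commit replaces that with a new `nested_detach` helper in `trainer_utils` that recurses into lists, tuples, and mappings and passes non-tensor values through unchanged; it also imports `Mapping` for the helper and adds `# type: ignore` to the optional `galore_torch`, `badam`, `adam_mini`, and `swanlab` imports. The sketch below is not from the repository; it illustrates the likely failure mode, and the batch layout with a list-valued `pixel_values` entry is an assumption about Pixtral-style multimodal collators, not something stated in the diff.

# Hypothetical batch layout (assumption): multimodal collators may return
# list-valued entries such as variable-sized image tensors.
import torch

batch = {
    "input_ids": torch.tensor([[1, 2, 3]]),
    "pixel_values": [torch.randn(3, 336, 336), torch.randn(3, 224, 224)],
}

try:
    # Old code path: assumes every value is a single tensor.
    detached = {k: v.detach().clone() for k, v in batch.items()}
except AttributeError as err:
    # A plain Python list has no .detach(), so the comprehension raises here.
    print(f"flat detach fails: {err}")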
DPO trainer:

@@ -31,7 +31,7 @@ from typing_extensions import override
 from ...extras.constants import IGNORE_INDEX
 from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach


 if TYPE_CHECKING:
@@ -193,7 +193,7 @@ class CustomDPOTrainer(DPOTrainer):
         Otherwise the average log probabilities.
         """
         if self.finetuning_args.use_ref_model:
-            batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
+            batch = nested_detach(batch, clone=True)  # avoid error

         all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
         all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
KTO trainer:

@@ -30,7 +30,7 @@ from typing_extensions import override
 from ...extras.constants import IGNORE_INDEX
 from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach


 if TYPE_CHECKING:
@@ -142,7 +142,7 @@ class CustomKTOTrainer(KTOTrainer):
         r"""
         Runs forward pass and computes the log probabilities.
         """
-        batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
+        batch = nested_detach(batch, clone=True)  # avoid error
         model_inputs = {
             "input_ids": batch[f"{prefix}input_ids"],
             "attention_mask": batch[f"{prefix}attention_mask"],
Trainer utilities:

@@ -17,6 +17,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections.abc import Mapping
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

 import torch
@@ -36,7 +37,7 @@ from ..model import find_all_linear_modules, load_model, load_tokenizer, load_va


 if is_galore_available():
-    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
+    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit  # type: ignore


 if TYPE_CHECKING:
@@ -330,7 +331,7 @@ def _create_badam_optimizer(
     ]

     if finetuning_args.badam_mode == "layer":
-        from badam import BlockOptimizer
+        from badam import BlockOptimizer  # type: ignore

         base_optimizer = optim_class(param_groups, **optim_kwargs)
         optimizer = BlockOptimizer(
@@ -350,7 +351,7 @@ def _create_badam_optimizer(
         )

     elif finetuning_args.badam_mode == "ratio":
-        from badam import BlockOptimizerRatio
+        from badam import BlockOptimizerRatio  # type: ignore

         assert finetuning_args.badam_update_ratio > 1e-6
         optimizer = BlockOptimizerRatio(
@@ -374,7 +375,7 @@ def _create_adam_mini_optimizer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
 ) -> "torch.optim.Optimizer":
-    from adam_mini import Adam_mini
+    from adam_mini import Adam_mini  # type: ignore

     hidden_size = getattr(model.config, "hidden_size", None)
     num_q_head = getattr(model.config, "num_attention_heads", None)
@@ -459,12 +460,33 @@ def get_batch_logps(
     return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)


+def nested_detach(
+    tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]],
+    clone: bool = False,
+):
+    r"""
+    Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
+    """
+    if isinstance(tensors, (list, tuple)):
+        return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
+    elif isinstance(tensors, Mapping):
+        return type(tensors)({k: nested_detach(t, clone=clone) for k, t in tensors.items()})
+
+    if isinstance(tensors, torch.Tensor):
+        if clone:
+            return tensors.detach().clone()
+        else:
+            return tensors.detach()
+    else:
+        return tensors
+
+
 def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
     r"""
     Gets the callback for logging to SwanLab.
     """
-    import swanlab
-    from swanlab.integration.transformers import SwanLabCallback
+    import swanlab  # type: ignore
+    from swanlab.integration.transformers import SwanLabCallback  # type: ignore

     if finetuning_args.swanlab_api_key is not None:
         swanlab.login(api_key=finetuning_args.swanlab_api_key)
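A usage sketch for the new helper. The behavior follows directly from the definition above; the example values are illustrative, and the import path assumes the repository's `llamafactory` package layout.

import torch

from llamafactory.train.trainer_utils import nested_detach  # import path assumed

x = torch.randn(2, 4, requires_grad=True)
batch = {
    "features": x,              # a plain tensor
    "chunks": [x[:1], x[1:]],   # a nested list of tensor views
    "meta": "kept-as-is",       # a non-tensor value
}

out = nested_detach(batch, clone=True)
assert not out["features"].requires_grad                 # detached from the autograd graph
assert out["features"].data_ptr() != x.data_ptr()        # clone=True copies the storage
assert all(not t.requires_grad for t in out["chunks"])   # recursion covers nested lists
assert out["meta"] == "kept-as-is"                       # non-tensors pass through unchanged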