Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-17 12:20:37 +08:00)
add docstrings, refactor logger
@@ -32,6 +32,7 @@ from transformers.utils import (
    WEIGHTS_NAME,
    is_safetensors_available,
)
from typing_extensions import override

from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import LoggerHandler, get_logger
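The common thread in this commit: every overridden trainer and callback method now carries `typing_extensions.override`. A minimal, self-contained sketch of what the decorator provides (class names here are illustrative):

```python
# Sketch of what `typing_extensions.override` buys: static type checkers flag
# any @override-marked method that does not actually override a base-class
# attribute. At runtime the decorator only sets `__override__` on the function.
from typing_extensions import override


class Base:
    def on_save(self) -> None:
        print("base on_save")


class Child(Base):
    @override
    def on_save(self) -> None:  # checker-verified: Base defines on_save
        print("child on_save")


Child().on_save()                  # prints "child on_save"
print(Child.on_save.__override__)  # True
```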
@@ -95,6 +96,7 @@ def fix_valuehead_checkpoint(


class FixValueHeadModelCallback(TrainerCallback):
    @override
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a checkpoint save.
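Context for the hunk above: a value-head model's checkpoint mixes `v_head.*` tensors with backbone weights, and `fix_valuehead_checkpoint` re-saves them separately (note the `V_HEAD_*_WEIGHTS_NAME` constants in the import hunk). A hedged sketch of the splitting idea; the repository's exact key handling may differ:

```python
# Illustrative split of a merged value-head state dict: `v_head.*` tensors go
# to one file, backbone tensors (stored under `pretrained_model.*`) to another.
from typing import Dict, Tuple

import torch


def split_valuehead_state_dict(
    state_dict: Dict[str, "torch.Tensor"]
) -> Tuple[Dict[str, "torch.Tensor"], Dict[str, "torch.Tensor"]]:
    decoder_state, v_head_state = {}, {}
    for name, param in state_dict.items():
        if name.startswith("v_head."):
            v_head_state[name] = param
        else:
            decoder_state[name.replace("pretrained_model.", "", 1)] = param
    return decoder_state, v_head_state
```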
@@ -114,6 +116,7 @@ class SaveProcessorCallback(TrainerCallback):
        """
        self.processor = processor

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of training.
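The pattern behind `SaveProcessorCallback`, sketched under the assumption that `processor` exposes `save_pretrained()` as Hugging Face processors do:

```python
# Hedged sketch: keep a reference to a processor/tokenizer and persist it
# alongside the model when training ends.
from transformers import TrainerCallback


class ProcessorSaver(TrainerCallback):
    def __init__(self, processor) -> None:
        self.processor = processor

    def on_train_end(self, args, state, control, **kwargs):
        r"""
        Event called at the end of training.
        """
        if args.should_save:
            self.processor.save_pretrained(args.output_dir)
```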
@@ -127,6 +130,7 @@ class PissaConvertCallback(TrainerCallback):
    Initializes a callback for converting the PiSSA adapter to a normal one.
    """

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the beginning of training.
@@ -141,6 +145,7 @@ class PissaConvertCallback(TrainerCallback):
            model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
            setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of training.
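The `setattr` lines above implement a save-and-restore idiom around the PiSSA init adapter. A sketch, assuming `model` is a PeftModel with a "default" adapter; the surrounding checkpointing logic is omitted:

```python
# Remember the configured `init_lora_weights`, force it while saving the
# PiSSA init adapter, then put the original value back.
def save_pissa_init(model, pissa_init_dir: str, safe_serialization: bool = True) -> None:
    init_lora_weights = model.peft_config["default"].init_lora_weights
    setattr(model.peft_config["default"], "init_lora_weights", True)
    model.save_pretrained(pissa_init_dir, safe_serialization=safe_serialization)
    setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
```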
@@ -226,6 +231,7 @@ class LogCallback(TrainerCallback):
            self.thread_pool.shutdown(wait=True)
            self.thread_pool = None

    @override
    def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of the initialization of the `Trainer`.
@@ -238,6 +244,7 @@ class LogCallback(TrainerCallback):
            logger.warning("Previous trainer log in this folder will be deleted.")
            os.remove(os.path.join(args.output_dir, TRAINER_LOG))

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the beginning of training.
@@ -247,12 +254,14 @@ class LogCallback(TrainerCallback):
            self._reset(max_steps=state.max_steps)
            self._create_thread_pool(output_dir=args.output_dir)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of training.
        """
        self._close_thread_pool()

    @override
    def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of a substep during gradient accumulation.
@@ -261,6 +270,7 @@ class LogCallback(TrainerCallback):
            control.should_epoch_stop = True
            control.should_training_stop = True

    @override
    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of a training step.
@@ -269,6 +279,7 @@ class LogCallback(TrainerCallback):
            control.should_epoch_stop = True
            control.should_training_stop = True

    @override
    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after an evaluation phase.
@@ -276,6 +287,7 @@ class LogCallback(TrainerCallback):
        if not self.do_train:
            self._close_thread_pool()

    @override
    def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a successful prediction.
@@ -283,6 +295,7 @@ class LogCallback(TrainerCallback):
        if not self.do_train:
            self._close_thread_pool()

    @override
    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after logging the last logs.
@@ -325,6 +338,7 @@ class LogCallback(TrainerCallback):
        if self.thread_pool is not None:
            self.thread_pool.submit(self._write_log, args.output_dir, logs)

    @override
    def on_prediction_step(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ):
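`LogCallback` funnels log records through a single-worker thread pool so file writes never block a training step. A runnable sketch of that mechanism; the file name and record layout are assumptions (`TRAINER_LOG` is defined in `extras.constants`):

```python
import json
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, Optional


class AsyncLogWriter:
    def __init__(self) -> None:
        self.thread_pool: Optional[ThreadPoolExecutor] = None

    def _create_thread_pool(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        self.thread_pool = ThreadPoolExecutor(max_workers=1)  # one worker keeps writes ordered

    def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
        with open(os.path.join(output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def _close_thread_pool(self) -> None:
        if self.thread_pool is not None:
            self.thread_pool.shutdown(wait=True)  # flush pending writes
            self.thread_pool = None

    def log(self, output_dir: str, logs: Dict[str, Any]) -> None:
        if self.thread_pool is not None:
            self.thread_pool.submit(self._write_log, output_dir, logs)
```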
@@ -26,6 +26,7 @@ import torch.nn.functional as F
from transformers import Trainer
from trl import DPOTrainer
from trl.trainer import disable_dropout_in_model
from typing_extensions import override

from ...extras.constants import IGNORE_INDEX
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
@@ -104,11 +105,13 @@ class CustomDPOTrainer(DPOTrainer):
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
@@ -164,6 +167,7 @@ class CustomDPOTrainer(DPOTrainer):

        return losses, chosen_rewards, rejected_rewards

    @override
    def concatenated_forward(
        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@@ -186,6 +190,7 @@ class CustomDPOTrainer(DPOTrainer):
        chosen_length, _ = valid_length.split(batch_size, dim=0)
        return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length

    @override
    def compute_reference_log_probs(
        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
    ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
@@ -207,6 +212,7 @@ class CustomDPOTrainer(DPOTrainer):

        return reference_chosen_logps, reference_rejected_logps

    @override
    def get_batch_loss_metrics(
        self,
        model: "PreTrainedModel",
@@ -25,6 +25,7 @@ import torch
from transformers import Trainer
from trl import KTOTrainer
from trl.trainer import disable_dropout_in_model
from typing_extensions import override

from ...extras.constants import IGNORE_INDEX
from ..callbacks import SaveProcessorCallback
@@ -99,23 +100,27 @@ class CustomKTOTrainer(KTOTrainer):
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
        r"""
        Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
        """
        return Trainer._get_train_sampler(self)

    @override
    def forward(
        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
    ) -> Tuple["torch.Tensor", "torch.Tensor"]:
@@ -140,6 +145,7 @@ class CustomKTOTrainer(KTOTrainer):
        logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
        return logps, logps / valid_length

    @override
    def concatenated_forward(
        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@@ -155,6 +161,7 @@ class CustomKTOTrainer(KTOTrainer):
        chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
        return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg

    @override
    def compute_reference_log_probs(
        self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@@ -175,6 +182,7 @@ class CustomKTOTrainer(KTOTrainer):

        return reference_chosen_logps, reference_rejected_logps, reference_kl_logps

    @override
    def get_batch_loss_metrics(
        self,
        model: "PreTrainedModel",
@@ -35,6 +35,7 @@ from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from trl import PPOConfig, PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits
from trl.models.utils import unwrap_model_for_generation
from typing_extensions import override

from ...extras.logging import get_logger
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
@@ -298,6 +299,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):

        self.callback_handler.on_train_end(self.args, self.state, self.control)

    @override
    def create_optimizer(
        self,
        model: "AutoModelForCausalLMWithValueHead",
@@ -324,6 +326,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):

        return optimizer

    @override
    def create_scheduler(
        self, training_args: "Seq2SeqTrainingArguments", num_training_steps: int, optimizer: "torch.optim.Optimizer"
    ) -> "torch.optim.lr_scheduler.LRScheduler":
@@ -410,6 +413,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
        rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
        return rewards.float().detach()  # use fp32 type

    @override
    @PPODecorators.empty_device_cache()
    def batched_forward_pass(
        self,
@@ -478,6 +482,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
            torch.cat(all_masks)[:, :-1],
        )

    @override
    def save_model(self, output_dir: Optional[str] = None) -> None:
        r"""
        Saves model checkpoint.
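The `values.gather(...)` line above picks the value-head score at each sequence's last non-padded token. In miniature, with toy numbers:

```python
import torch

values = torch.tensor([[0.1, 0.2, 0.9], [0.4, 0.7, 0.0]])  # (batch, seq_len)
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])      # 0 marks padding
last_index = attention_mask.sum(dim=-1, keepdim=True) - 1  # index of final real token
rewards = values.gather(dim=-1, index=last_index)
print(rewards.float().detach().squeeze(-1))  # tensor([0.9000, 0.7000])
```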
@@ -16,6 +16,7 @@ from types import MethodType
from typing import TYPE_CHECKING, Optional

from transformers import Trainer
from typing_extensions import override

from ...extras.logging import get_logger
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
@@ -55,11 +56,13 @@ class CustomTrainer(Trainer):
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
@@ -26,6 +26,10 @@ if TYPE_CHECKING:

@dataclass
class ComputeAccuracy:
    r"""
    Computes reward accuracy and supports `batch_eval_metrics`.
    """

    def _dump(self) -> Optional[Dict[str, float]]:
        result = None
        if hasattr(self, "score_dict"):
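`ComputeAccuracy` follows the accumulate-then-dump shape that `batch_eval_metrics` expects: the metric is called once per eval batch and flushes its running state when asked for the final result. A simplified sketch (the real `__call__` signature differs):

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional

import numpy as np


@dataclass
class RewardAccuracy:
    score_dict: Dict[str, List[float]] = field(default_factory=lambda: {"accuracy": []})

    def _dump(self) -> Optional[Dict[str, float]]:
        result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
        self.score_dict = {"accuracy": []}  # reset the running state
        return result

    def __call__(self, chosen_scores: np.ndarray, rejected_scores: np.ndarray,
                 compute_result: bool = True) -> Optional[Dict[str, float]]:
        self.score_dict["accuracy"].extend((chosen_scores > rejected_scores).tolist())
        return self._dump() if compute_result else None
```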
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
from transformers import Trainer
from typing_extensions import override

from ...extras.logging import get_logger
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
@@ -63,17 +64,20 @@ class PairwiseTrainer(Trainer):
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def compute_loss(
        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False
    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
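For reference, the objective a pairwise reward-model `compute_loss` of this kind typically minimizes, shown with toy reward values:

```python
# -log(sigmoid(r_chosen - r_rejected)): small when chosen rewards exceed
# rejected ones, large otherwise.
import torch
import torch.nn.functional as F

chosen_rewards = torch.tensor([1.2, 0.3])
rejected_rewards = torch.tensor([0.4, 0.9])
loss = -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
print(loss)
```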
@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import Seq2SeqTrainer
from typing_extensions import override

from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
@@ -64,17 +65,20 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
        return super().create_optimizer()

    @override
    def create_scheduler(
        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        create_custom_scheduler(self.args, num_training_steps, optimizer)
        return super().create_scheduler(num_training_steps, optimizer)

    @override
    def prediction_step(
        self,
        model: "torch.nn.Module",
@@ -26,6 +26,7 @@ from transformers.modeling_utils import is_fsdp_enabled
from transformers.optimization import get_scheduler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
from typing_extensions import override

from ..extras.constants import IGNORE_INDEX
from ..extras.logging import get_logger
@@ -60,9 +61,11 @@ class DummyOptimizer(torch.optim.Optimizer):
        self.optimizer_dict = optimizer_dict
        super().__init__([dummy_tensor], {"lr": lr})

    @override
    def zero_grad(self, set_to_none: bool = True) -> None:
        pass

    @override
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        pass
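Why a no-op optimizer at all: with layer-wise methods such as BAdam or GaLore, each parameter can own a real optimizer stepped from a gradient hook, while the `Trainer` only ever sees the dummy object. A hedged sketch of that wiring (requires torch >= 2.1); it illustrates the idea, not the repository's exact code:

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer_dict = {p: torch.optim.SGD([p], lr=0.1) for p in model.parameters()}


def make_hook(param: "torch.nn.Parameter"):
    def hook(*_):
        optimizer_dict[param].step()       # update this parameter right away
        optimizer_dict[param].zero_grad()  # and release its gradient
    return hook


for param in model.parameters():
    param.register_post_accumulate_grad_hook(make_hook(param))

loss = model(torch.randn(3, 4)).sum()
loss.backward()  # hooks fire here; DummyOptimizer.step() stays a no-op
```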