add docstrings, refactor logger

This commit is contained in:
hiyouga
2024-09-08 00:56:56 +08:00
parent 8eac1b929f
commit 54c6905937
30 changed files with 334 additions and 57 deletions

View File

@@ -32,6 +32,7 @@ from transformers.utils import (
WEIGHTS_NAME,
is_safetensors_available,
)
from typing_extensions import override
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import LoggerHandler, get_logger
@@ -95,6 +96,7 @@ def fix_valuehead_checkpoint(
class FixValueHeadModelCallback(TrainerCallback):
@override
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a checkpoint save.
@@ -114,6 +116,7 @@ class SaveProcessorCallback(TrainerCallback):
"""
self.processor = processor
@override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
@@ -127,6 +130,7 @@ class PissaConvertCallback(TrainerCallback):
Initializes a callback for converting the PiSSA adapter to a normal one.
"""
@override
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
@@ -141,6 +145,7 @@ class PissaConvertCallback(TrainerCallback):
model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
@override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
@@ -226,6 +231,7 @@ class LogCallback(TrainerCallback):
self.thread_pool.shutdown(wait=True)
self.thread_pool = None
@override
def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of the initialization of the `Trainer`.
@@ -238,6 +244,7 @@ class LogCallback(TrainerCallback):
logger.warning("Previous trainer log in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
@override
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the beginning of training.
@@ -247,12 +254,14 @@ class LogCallback(TrainerCallback):
self._reset(max_steps=state.max_steps)
self._create_thread_pool(output_dir=args.output_dir)
@override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
self._close_thread_pool()
@override
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of an substep during gradient accumulation.
@@ -261,6 +270,7 @@ class LogCallback(TrainerCallback):
control.should_epoch_stop = True
control.should_training_stop = True
@override
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of a training step.
@@ -269,6 +279,7 @@ class LogCallback(TrainerCallback):
control.should_epoch_stop = True
control.should_training_stop = True
@override
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after an evaluation phase.
@@ -276,6 +287,7 @@ class LogCallback(TrainerCallback):
if not self.do_train:
self._close_thread_pool()
@override
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after a successful prediction.
@@ -283,6 +295,7 @@ class LogCallback(TrainerCallback):
if not self.do_train:
self._close_thread_pool()
@override
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called after logging the last logs.
@@ -325,6 +338,7 @@ class LogCallback(TrainerCallback):
if self.thread_pool is not None:
self.thread_pool.submit(self._write_log, args.output_dir, logs)
@override
def on_prediction_step(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
):

View File

@@ -26,6 +26,7 @@ import torch.nn.functional as F
from transformers import Trainer
from trl import DPOTrainer
from trl.trainer import disable_dropout_in_model
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
@@ -104,11 +105,13 @@ class CustomDPOTrainer(DPOTrainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
@override
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":
@@ -164,6 +167,7 @@ class CustomDPOTrainer(DPOTrainer):
return losses, chosen_rewards, rejected_rewards
@override
def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@@ -186,6 +190,7 @@ class CustomDPOTrainer(DPOTrainer):
chosen_length, _ = valid_length.split(batch_size, dim=0)
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
@override
def compute_reference_log_probs(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
@@ -207,6 +212,7 @@ class CustomDPOTrainer(DPOTrainer):
return reference_chosen_logps, reference_rejected_logps
@override
def get_batch_loss_metrics(
self,
model: "PreTrainedModel",

View File

@@ -25,6 +25,7 @@ import torch
from transformers import Trainer
from trl import KTOTrainer
from trl.trainer import disable_dropout_in_model
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
from ..callbacks import SaveProcessorCallback
@@ -99,23 +100,27 @@ class CustomKTOTrainer(KTOTrainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
@override
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
@override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
r"""
Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
"""
return Trainer._get_train_sampler(self)
@override
def forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
) -> Tuple["torch.Tensor", "torch.Tensor"]:
@@ -140,6 +145,7 @@ class CustomKTOTrainer(KTOTrainer):
logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
return logps, logps / valid_length
@override
def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@@ -155,6 +161,7 @@ class CustomKTOTrainer(KTOTrainer):
chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg
@override
def compute_reference_log_probs(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@@ -175,6 +182,7 @@ class CustomKTOTrainer(KTOTrainer):
return reference_chosen_logps, reference_rejected_logps, reference_kl_logps
@override
def get_batch_loss_metrics(
self,
model: "PreTrainedModel",

View File

@@ -35,6 +35,7 @@ from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from trl import PPOConfig, PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits
from trl.models.utils import unwrap_model_for_generation
from typing_extensions import override
from ...extras.logging import get_logger
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
@@ -298,6 +299,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.callback_handler.on_train_end(self.args, self.state, self.control)
@override
def create_optimizer(
self,
model: "AutoModelForCausalLMWithValueHead",
@@ -324,6 +326,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
return optimizer
@override
def create_scheduler(
self, training_args: "Seq2SeqTrainingArguments", num_training_steps: int, optimizer: "torch.optim.Optimizer"
) -> "torch.optim.lr_scheduler.LRScheduler":
@@ -410,6 +413,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
return rewards.float().detach() # use fp32 type
@override
@PPODecorators.empty_device_cache()
def batched_forward_pass(
self,
@@ -478,6 +482,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
torch.cat(all_masks)[:, :-1],
)
@override
def save_model(self, output_dir: Optional[str] = None) -> None:
r"""
Saves model checkpoint.

View File

@@ -16,6 +16,7 @@ from types import MethodType
from typing import TYPE_CHECKING, Optional
from transformers import Trainer
from typing_extensions import override
from ...extras.logging import get_logger
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
@@ -55,11 +56,13 @@ class CustomTrainer(Trainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
@override
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":

View File

@@ -26,6 +26,10 @@ if TYPE_CHECKING:
@dataclass
class ComputeAccuracy:
r"""
Computes reward accuracy and supports `batch_eval_metrics`.
"""
def _dump(self) -> Optional[Dict[str, float]]:
result = None
if hasattr(self, "score_dict"):

View File

@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from transformers import Trainer
from typing_extensions import override
from ...extras.logging import get_logger
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
@@ -63,17 +64,20 @@ class PairwiseTrainer(Trainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
@override
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:

View File

@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import Seq2SeqTrainer
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
@@ -64,17 +65,20 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer()
@override
def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
@override
def prediction_step(
self,
model: "torch.nn.Module",

View File

@@ -26,6 +26,7 @@ from transformers.modeling_utils import is_fsdp_enabled
from transformers.optimization import get_scheduler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
from typing_extensions import override
from ..extras.constants import IGNORE_INDEX
from ..extras.logging import get_logger
@@ -60,9 +61,11 @@ class DummyOptimizer(torch.optim.Optimizer):
self.optimizer_dict = optimizer_dict
super().__init__([dummy_tensor], {"lr": lr})
@override
def zero_grad(self, set_to_none: bool = True) -> None:
pass
@override
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
pass