support rank0 logger

This commit is contained in:
hiyouga
2024-11-02 18:31:04 +08:00
parent bd08b8c441
commit c38aa29336
42 changed files with 316 additions and 252 deletions

View File

@@ -25,8 +25,8 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
import torch
from ...extras import logging
from ...extras.constants import LAYERNORM_NAMES
from ...extras.logging import get_logger
if TYPE_CHECKING:
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
from ...hparams import ModelArguments
logger = get_logger(__name__)
logger = logging.get_logger(__name__)
def get_unsloth_gradient_checkpointing_func() -> Callable:
@@ -122,7 +122,7 @@ def _gradient_checkpointing_enable(
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
self.apply(partial(self._set_gradient_checkpointing, value=True))
self.enable_input_require_grads()
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
else: # have already enabled input require gradients
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
@@ -141,14 +141,14 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
(3) add the upcasting of the lm_head in fp32
"""
if model_args.upcast_layernorm:
logger.info("Upcasting layernorm weights in float32.")
logger.info_rank0("Upcasting layernorm weights in float32.")
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
param.data = param.data.to(torch.float32)
if not model_args.disable_gradient_checkpointing:
if not getattr(model, "supports_gradient_checkpointing", False):
logger.warning("Current model does not support gradient checkpointing.")
logger.warning_rank0("Current model does not support gradient checkpointing.")
else:
# use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
# According to: https://github.com/huggingface/transformers/issues/28339
@@ -158,10 +158,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled
logger.info("Gradient checkpointing enabled.")
logger.info_rank0("Gradient checkpointing enabled.")
if model_args.upcast_lmhead_output:
output_layer = model.get_output_embeddings()
if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
logger.info("Upcasting lm_head outputs in float32.")
logger.info_rank0("Upcasting lm_head outputs in float32.")
output_layer.register_forward_hook(_fp32_forward_post_hook)