support rank0 logger
@@ -25,8 +25,8 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
 
+from ...extras import logging
 from ...extras.constants import LAYERNORM_NAMES
-from ...extras.logging import get_logger
 
 
 if TYPE_CHECKING:
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments
 
 
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 
 
 def get_unsloth_gradient_checkpointing_func() -> Callable:
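
The `...extras.logging` module introduced by the new import is not shown in this diff. As a rough sketch only, assuming the local process rank is exposed through the `LOCAL_RANK` environment variable, a rank0-aware logger with the `info_rank0`, `warning_rank0`, and `warning_once` methods used below could look like this (illustrative, not the project's actual implementation):

    import functools
    import logging
    import os


    class _RankZeroLogger(logging.Logger):
        def info_rank0(self, *args, **kwargs) -> None:
            # emit only on the main process (LOCAL_RANK unset or "0")
            if int(os.getenv("LOCAL_RANK", "0")) == 0:
                self.info(*args, **kwargs)

        def warning_rank0(self, *args, **kwargs) -> None:
            if int(os.getenv("LOCAL_RANK", "0")) == 0:
                self.warning(*args, **kwargs)

        @functools.lru_cache(maxsize=None)
        def warning_once(self, *args, **kwargs) -> None:
            # lru_cache ensures each distinct message is emitted at most once
            self.warning_rank0(*args, **kwargs)


    def get_logger(name: str) -> "_RankZeroLogger":
        logging.setLoggerClass(_RankZeroLogger)
        logger = logging.getLogger(name)
        logging.setLoggerClass(logging.Logger)  # restore the default class for other callers
        return logger

With such a sketch, `logger = logging.get_logger(__name__)` returns a logger whose `*_rank0` methods are no-ops on every process except rank 0, so a multi-GPU run prints each message only once.
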
@@ -122,7 +122,7 @@ def _gradient_checkpointing_enable(
     if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
         self.apply(partial(self._set_gradient_checkpointing, value=True))
         self.enable_input_require_grads()
-        logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
+        logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
     else:  # have already enabled input require gradients
         self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
 
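
The old-vs-new GC format branch above keys on whether the model's `_set_gradient_checkpointing` method still accepts a `value` parameter. A small, self-contained illustration of that signature probe (both toy classes below are invented for this example):

    import inspect


    class OldStyleModel:
        # older-style API: gradient checkpointing toggled per module via `value`
        def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
            pass


    class NewStyleModel:
        # newer API: takes an explicit gradient checkpointing function
        def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None) -> None:
            pass


    for cls in (OldStyleModel, NewStyleModel):
        params = inspect.signature(cls._set_gradient_checkpointing).parameters
        print(cls.__name__, "old GC format" if "value" in params else "new GC format")
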
@@ -141,14 +141,14 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
     (3) add the upcasting of the lm_head in fp32
     """
     if model_args.upcast_layernorm:
-        logger.info("Upcasting layernorm weights in float32.")
+        logger.info_rank0("Upcasting layernorm weights in float32.")
         for name, param in model.named_parameters():
             if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
                 param.data = param.data.to(torch.float32)
 
     if not model_args.disable_gradient_checkpointing:
         if not getattr(model, "supports_gradient_checkpointing", False):
-            logger.warning("Current model does not support gradient checkpointing.")
+            logger.warning_rank0("Current model does not support gradient checkpointing.")
         else:
             # use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
             # According to: https://github.com/huggingface/transformers/issues/28339
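
As a quick illustration of the layernorm upcast in this hunk: only 1-D parameters whose names contain a layernorm keyword are cast to float32, while everything else keeps its low-precision dtype. A toy run on a plain PyTorch layer, with the tuple below used as an assumed stand-in for the `LAYERNORM_NAMES` constant imported from `...extras.constants`:

    import torch

    LAYERNORM_NAMES = ("norm", "ln")  # assumed stand-in for ...extras.constants.LAYERNORM_NAMES

    model = torch.nn.TransformerEncoderLayer(d_model=8, nhead=2).to(torch.bfloat16)
    for name, param in model.named_parameters():
        if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
            param.data = param.data.to(torch.float32)

    for name, param in model.named_parameters():
        print(name, param.dtype)  # norm1/norm2 weights and biases become float32, the rest stay bfloat16
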
@@ -158,10 +158,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
             model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
             model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
             setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
-            logger.info("Gradient checkpointing enabled.")
+            logger.info_rank0("Gradient checkpointing enabled.")
 
     if model_args.upcast_lmhead_output:
         output_layer = model.get_output_embeddings()
         if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
-            logger.info("Upcasting lm_head outputs in float32.")
+            logger.info_rank0("Upcasting lm_head outputs in float32.")
             output_layer.register_forward_hook(_fp32_forward_post_hook)
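
`_fp32_forward_post_hook` is defined elsewhere in this file and does not appear in the diff. Assuming it simply casts the lm_head output to float32, a minimal sketch plus usage would be:

    import torch


    def _fp32_forward_post_hook(module, args, output):
        # assumed behavior: upcast the module's output to float32 after the forward pass
        return output.to(torch.float32)


    lm_head = torch.nn.Linear(8, 32, dtype=torch.bfloat16)
    lm_head.register_forward_hook(_fp32_forward_post_hook)
    print(lm_head(torch.zeros(1, 8, dtype=torch.bfloat16)).dtype)  # torch.float32
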