From 036fb0d561d7e0f2a3637c901beb0782f4e8c8e7 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Thu, 13 Feb 2025 00:17:18 +0800
Subject: [PATCH] [misc] fix grad ckpt func (#6916)

Former-commit-id: e34c3c06da706f80c74c20800f19110e9ad6b82a
---
 scripts/stat_utils/cal_mfu.py                 | 22 ++++++++++++----------
 src/llamafactory/hparams/parser.py            |  7 ++++---
 .../model/model_utils/checkpointing.py        |  1 +
 3 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/scripts/stat_utils/cal_mfu.py b/scripts/stat_utils/cal_mfu.py
index ef5672d2..b1aea710 100644
--- a/scripts/stat_utils/cal_mfu.py
+++ b/scripts/stat_utils/cal_mfu.py
@@ -142,21 +142,23 @@ def calculate_mfu(
         args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"
 
     run_exp(args)
-    with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
-        result = json.load(f)
-
     if dist.is_initialized():
+        dist.barrier()
         world_size = dist.get_world_size()
     else:
         world_size = 1
 
-    total_batch_size = batch_size * world_size
-    mfu_value = (
-        result["train_steps_per_second"]
-        * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
-        / compute_device_flops(world_size)
-    )
-    print(f"MFU: {mfu_value * 100:.2f}%")
+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
+            result = json.load(f)
+
+        total_batch_size = batch_size * world_size
+        mfu_value = (
+            result["train_steps_per_second"]
+            * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
+            / compute_device_flops(world_size)
+        )
+        print(f"MFU: {mfu_value * 100:.2f}%")
 
 
 if __name__ == "__main__":
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 19708156..911b78ec 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -87,9 +87,10 @@ def _parse_args(
 
 
 def _set_transformers_logging() -> None:
-    transformers.utils.logging.set_verbosity_info()
-    transformers.utils.logging.enable_default_handler()
-    transformers.utils.logging.enable_explicit_format()
+    if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]:
+        transformers.utils.logging.set_verbosity_info()
+        transformers.utils.logging.enable_default_handler()
+        transformers.utils.logging.enable_explicit_format()
 
 
 def _verify_model_args(
diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py
index 80a50f3e..bbd44ba4 100644
--- a/src/llamafactory/model/model_utils/checkpointing.py
+++ b/src/llamafactory/model/model_utils/checkpointing.py
@@ -89,6 +89,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
         for arg in args:
             if torch.is_tensor(arg) and torch.is_floating_point(arg):
                 arg.requires_grad_(True)
+                break  # assume the first tensor is always the hidden states
 
         return gradient_checkpointing_func(func, *args, **kwargs)
 