[misc] fix grad ckpt func (#6916)

Former-commit-id: e34c3c06da706f80c74c20800f19110e9ad6b82a
This commit is contained in:
hoshi-hiyouga 2025-02-13 00:17:18 +08:00 committed by GitHub
parent bae934dea3
commit 036fb0d561
3 changed files with 17 additions and 13 deletions

View File

@ -142,21 +142,23 @@ def calculate_mfu(
args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"
run_exp(args)
with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
result = json.load(f)
if dist.is_initialized():
dist.barrier()
world_size = dist.get_world_size()
else:
world_size = 1
total_batch_size = batch_size * world_size
mfu_value = (
result["train_steps_per_second"]
* compute_model_flops(model_name_or_path, total_batch_size, seq_length)
/ compute_device_flops(world_size)
)
print(f"MFU: {mfu_value * 100:.2f}%")
if int(os.getenv("LOCAL_RANK", "0")) == 0:
with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
result = json.load(f)
total_batch_size = batch_size * world_size
mfu_value = (
result["train_steps_per_second"]
* compute_model_flops(model_name_or_path, total_batch_size, seq_length)
/ compute_device_flops(world_size)
)
print(f"MFU: {mfu_value * 100:.2f}%")
if __name__ == "__main__":

View File

@ -87,9 +87,10 @@ def _parse_args(
def _set_transformers_logging() -> None:
transformers.utils.logging.set_verbosity_info()
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]:
transformers.utils.logging.set_verbosity_info()
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
def _verify_model_args(

View File

@ -89,6 +89,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
for arg in args:
if torch.is_tensor(arg) and torch.is_floating_point(arg):
arg.requires_grad_(True)
break # assume the first tensor is always the hidden states
return gradient_checkpointing_func(func, *args, **kwargs)