[misc] fix grad ckpt func (#6916)

Former-commit-id: e34c3c06da706f80c74c20800f19110e9ad6b82a
hoshi-hiyouga 2025-02-13 00:17:18 +08:00 committed by GitHub
parent bae934dea3
commit 036fb0d561
3 changed files with 17 additions and 13 deletions


@@ -142,21 +142,23 @@ def calculate_mfu(
         args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"

     run_exp(args)
-    with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
-        result = json.load(f)

     if dist.is_initialized():
+        dist.barrier()
         world_size = dist.get_world_size()
     else:
         world_size = 1

-    total_batch_size = batch_size * world_size
-    mfu_value = (
-        result["train_steps_per_second"]
-        * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
-        / compute_device_flops(world_size)
-    )
-    print(f"MFU: {mfu_value * 100:.2f}%")
+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
+            result = json.load(f)
+
+        total_batch_size = batch_size * world_size
+        mfu_value = (
+            result["train_steps_per_second"]
+            * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
+            / compute_device_flops(world_size)
+        )
+        print(f"MFU: {mfu_value * 100:.2f}%")


 if __name__ == "__main__":
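
This hunk gates the result parsing and MFU report to the primary process and adds a barrier so no rank reads all_results.json before every rank has finished training. Below is a minimal sketch of that pattern in isolation, assuming a torchrun-style launcher that sets LOCAL_RANK; `model_flops` and `device_flops_per_gpu` are hypothetical stand-ins for the script's compute_model_flops / compute_device_flops helpers:

```python
import json
import os

import torch.distributed as dist


def report_mfu(result_path: str, model_flops: float, device_flops_per_gpu: float) -> None:
    """Print model FLOPs utilization from the primary process only."""
    if dist.is_initialized():
        dist.barrier()  # wait until every rank has finished training
        world_size = dist.get_world_size()
    else:
        world_size = 1

    # LOCAL_RANK is set by torchrun; single-process runs fall back to "0".
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        with open(result_path, encoding="utf-8") as f:
            result = json.load(f)

        # MFU = achieved FLOPs per second / aggregate peak FLOPs of all devices
        mfu_value = result["train_steps_per_second"] * model_flops / (device_flops_per_gpu * world_size)
        print(f"MFU: {mfu_value * 100:.2f}%")
```

Without the LOCAL_RANK guard, every rank would open the same results file and print its own MFU line; with it, a single process does the post-processing after the barrier.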


@@ -87,9 +87,10 @@ def _parse_args(

 def _set_transformers_logging() -> None:
-    transformers.utils.logging.set_verbosity_info()
-    transformers.utils.logging.enable_default_handler()
-    transformers.utils.logging.enable_explicit_format()
+    if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]:
+        transformers.utils.logging.set_verbosity_info()
+        transformers.utils.logging.enable_default_handler()
+        transformers.utils.logging.enable_explicit_format()


 def _verify_model_args(
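
The logging tweak only turns on transformers' verbose handlers when the LLAMAFACTORY_VERBOSITY environment variable (defaulting to INFO) requests it. A self-contained sketch of the same gate (the transformers.utils.logging calls are the real API; the surrounding function mirrors the hunk above):

```python
import os

import transformers


def _set_transformers_logging() -> None:
    # Only enable verbose HF logging when the user asked for DEBUG or INFO output.
    if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]:
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
```

Setting LLAMAFACTORY_VERBOSITY to anything else (e.g. WARN) now leaves transformers at its default, quieter verbosity instead of forcing INFO-level output.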


@@ -89,6 +89,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
             for arg in args:
                 if torch.is_tensor(arg) and torch.is_floating_point(arg):
                     arg.requires_grad_(True)
+                    break  # assume the first tensor is always the hidden states

         return gradient_checkpointing_func(func, *args, **kwargs)
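
The added break stops the loop after the first floating-point argument, on the assumption that it is the hidden states; the remaining tensor arguments (masks, position ids, cached values) are left untouched. For reference, here is a minimal sketch of a gradient-checkpointing wrapper built around the lines in the hunk, assuming the surrounding structure (the outer factory and the inner closure are illustrative, not the project's exact code):

```python
from functools import wraps
from typing import Callable

import torch


def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
    """Wrap a gradient_checkpointing_func so recomputation still produces gradients."""

    @wraps(gradient_checkpointing_func)
    def custom_gradient_checkpointing_func(func: Callable, *args, **kwargs):
        # Mark only the hidden states as requiring grad so the checkpointed
        # segment stays connected to the autograd graph even when its inputs
        # come from frozen layers.
        for arg in args:
            if torch.is_tensor(arg) and torch.is_floating_point(arg):
                arg.requires_grad_(True)
                break  # assume the first floating-point tensor is the hidden states

        return gradient_checkpointing_func(func, *args, **kwargs)

    return custom_gradient_checkpointing_func
```

Breaking after the first match avoids flipping requires_grad on other floating-point inputs (for example a float attention mask), which appears to be what this commit fixes.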