Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)
[misc] fix grad ckpt func (#6916)
Former-commit-id: e34c3c06da706f80c74c20800f19110e9ad6b82a
Parent: bae934dea3
Commit: 036fb0d561
@@ -142,21 +142,23 @@ def calculate_mfu(
     args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"
 
     run_exp(args)
 
-    with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
-        result = json.load(f)
-
     if dist.is_initialized():
+        dist.barrier()
         world_size = dist.get_world_size()
     else:
         world_size = 1
 
-    total_batch_size = batch_size * world_size
-    mfu_value = (
-        result["train_steps_per_second"]
-        * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
-        / compute_device_flops(world_size)
-    )
-    print(f"MFU: {mfu_value * 100:.2f}%")
+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
+            result = json.load(f)
+
+        total_batch_size = batch_size * world_size
+        mfu_value = (
+            result["train_steps_per_second"]
+            * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
+            / compute_device_flops(world_size)
+        )
+        print(f"MFU: {mfu_value * 100:.2f}%")
 
 
 if __name__ == "__main__":
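To make the arithmetic in this hunk easier to follow, here is a minimal standalone sketch of the same MFU formula: achieved FLOPs per second (steps per second times FLOPs per step) divided by the theoretical peak of all devices. compute_model_flops and compute_device_flops are the repository's own helpers seen in the diff; the stand-ins below (the "6 x parameters x tokens" rule of thumb and a hard-coded A100 BF16 peak) are illustrative assumptions, not the project's exact implementation.

# Minimal sketch of the MFU calculation, under the assumptions noted inline.
def approx_model_flops(num_params: float, total_batch_size: int, seq_length: int) -> float:
    # Rule of thumb: ~6 FLOPs per parameter per trained token (forward + backward).
    # Stand-in for the repository's compute_model_flops().
    return 6.0 * num_params * total_batch_size * seq_length


def approx_device_flops(world_size: int, per_device_peak: float = 312e12) -> float:
    # 312 TFLOPS is the A100 BF16 peak; adjust for your hardware.
    # Stand-in for the repository's compute_device_flops().
    return per_device_peak * world_size


def approx_mfu(train_steps_per_second: float, num_params: float,
               total_batch_size: int, seq_length: int, world_size: int) -> float:
    flops_per_step = approx_model_flops(num_params, total_batch_size, seq_length)
    return train_steps_per_second * flops_per_step / approx_device_flops(world_size)


# Example: 7B model, global batch 8, 2048-token sequences, 0.5 steps/s on 8 GPUs.
print(f"MFU: {approx_mfu(0.5, 7e9, 8, 2048, 8) * 100:.2f}%")

The barrier plus the LOCAL_RANK check in the hunk serve the same goal in the distributed case: every rank waits until training results are written, but only rank 0 reads the JSON and prints the figure.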
@@ -87,9 +87,10 @@ def _parse_args(
 
 
 def _set_transformers_logging() -> None:
-    transformers.utils.logging.set_verbosity_info()
-    transformers.utils.logging.enable_default_handler()
-    transformers.utils.logging.enable_explicit_format()
+    if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]:
+        transformers.utils.logging.set_verbosity_info()
+        transformers.utils.logging.enable_default_handler()
+        transformers.utils.logging.enable_explicit_format()
 
 
 def _verify_model_args(
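The change above gates transformers' verbose logging behind the LLAMAFACTORY_VERBOSITY environment variable. As a rough, self-contained illustration of the pattern (the name _demo_set_logging below is hypothetical, not the repository's function):

import os
import transformers

def _demo_set_logging() -> None:
    # Only turn up transformers' logging when the user has not lowered the verbosity.
    if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]:
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()

_demo_set_logging()

Running with, say, LLAMAFACTORY_VERBOSITY=WARN would leave transformers at its defaults, while the default INFO setting enables the explicit handler and format as before.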
@@ -89,6 +89,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
         for arg in args:
             if torch.is_tensor(arg) and torch.is_floating_point(arg):
                 arg.requires_grad_(True)
+                break  # assume the first tensor is always the hidden states
 
         return gradient_checkpointing_func(func, *args, **kwargs)
 
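The one-line fix stops the loop after the first floating-point tensor, on the assumption that it is the hidden states. A hedged sketch of the surrounding wrapper, with LLaMA-Factory's module and trainability checks omitted (the name make_custom_ckpt_func is hypothetical):

from functools import wraps
from typing import Callable

import torch


def make_custom_ckpt_func(gradient_checkpointing_func: Callable) -> Callable:
    # Mark only the first floating-point tensor argument (assumed to be the
    # hidden states) as requiring grad, so the checkpointed block participates
    # in backward even when its inputs come from frozen layers.
    @wraps(gradient_checkpointing_func)
    def custom_func(func, *args, **kwargs):
        for arg in args:
            if torch.is_tensor(arg) and torch.is_floating_point(arg):
                arg.requires_grad_(True)
                break  # skip any later float tensors such as masks

        return gradient_checkpointing_func(func, *args, **kwargs)

    return custom_func

With torch.utils.checkpoint.checkpoint as the wrapped function, this mirrors what the commit's break guarantees: only the hidden-states tensor is flipped to requires_grad=True, rather than every floating-point argument.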