Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-02 03:32:50 +08:00
[misc] fix grad ckpt func (#6916)

Former-commit-id: e34c3c06da706f80c74c20800f19110e9ad6b82a
Parent: bae934dea3
Commit: 036fb0d561
@@ -142,14 +142,16 @@ def calculate_mfu(
     args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"

     run_exp(args)
-    with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
-        result = json.load(f)
+    if dist.is_initialized():
+        dist.barrier()
+        world_size = dist.get_world_size()
+    else:
+        world_size = 1

+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
+            result = json.load(f)
+
+        total_batch_size = batch_size * world_size
+        mfu_value = (
+            result["train_steps_per_second"]
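For context, the truncated expression above feeds a model-FLOPs-utilization (MFU) estimate. Below is a minimal standalone sketch of the pattern the new test code follows: synchronize all ranks, let only the local main process read the results file, then scale throughput by the world size. The MFU formula uses the common 6·N·D approximation, and the names `num_params`, `seq_len`, and `peak_flops_per_device` are illustrative assumptions, not taken from the repository:

    import json
    import os

    import torch.distributed as dist


    def read_results_on_main_rank(path: str) -> "tuple[dict | None, int]":
        """Wait for every rank, then read the JSON results on the main process only."""
        if dist.is_initialized():
            dist.barrier()  # ensure all ranks have finished writing before reading
            world_size = dist.get_world_size()
        else:
            world_size = 1

        if int(os.getenv("LOCAL_RANK", "0")) != 0:
            return None, world_size  # non-main processes skip the file read

        with open(path, encoding="utf-8") as f:
            return json.load(f), world_size


    def estimate_mfu(steps_per_second: float, total_batch_size: int, seq_len: int,
                     num_params: float, world_size: int, peak_flops_per_device: float) -> float:
        # ~6 * N FLOPs per token for a combined forward + backward pass
        # (the standard approximation; not necessarily the repository's exact formula).
        achieved_flops = steps_per_second * total_batch_size * seq_len * 6 * num_params
        return achieved_flops / (world_size * peak_flops_per_device)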
@@ -87,6 +87,7 @@ def _parse_args(


 def _set_transformers_logging() -> None:
-    transformers.utils.logging.set_verbosity_info()
+    if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]:
+        transformers.utils.logging.set_verbosity_info()
     transformers.utils.logging.enable_default_handler()
     transformers.utils.logging.enable_explicit_format()
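The gate means `set_verbosity_info()` only fires for the default and debug settings; setting `LLAMAFACTORY_VERBOSITY` to anything else (e.g. `WARN`) leaves the `transformers` default verbosity of WARNING in place. For reference, the `transformers.utils.logging` helpers used here, shown standalone (the comments paraphrase their behavior rather than quote exact output):

    import transformers

    transformers.utils.logging.set_verbosity_info()      # show INFO-level messages and above
    transformers.utils.logging.enable_default_handler()  # attach the library's default stream handler
    transformers.utils.logging.enable_explicit_format()  # prefix records with level, file, line, and time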
@@ -89,6 +89,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
         for arg in args:
             if torch.is_tensor(arg) and torch.is_floating_point(arg):
                 arg.requires_grad_(True)
+                break  # assume the first tensor is always the hidden states

         return gradient_checkpointing_func(func, *args, **kwargs)
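The added `break` also documents why the flag is set at all: reentrant checkpointing in `torch.utils.checkpoint` only tracks gradients through the recomputed block when at least one tensor input requires grad, so when the preceding layers are frozen (e.g. LoRA fine-tuning with frozen embeddings) the incoming hidden states must be marked manually, and marking just the first floating-point tensor suffices. A hedged, self-contained demo of that failure mode and the fix, where a plain `Linear` layer stands in for a checkpointed transformer block:

    import torch
    from torch.utils.checkpoint import checkpoint

    layer = torch.nn.Linear(4, 4)
    hidden = torch.randn(2, 4)  # like frozen-embedding output: requires_grad=False

    # Without this line, reentrant checkpointing warns that no input requires
    # grad and the trainable layer would receive no gradients in backward.
    hidden.requires_grad_(True)

    out = checkpoint(layer, hidden, use_reentrant=True)
    out.sum().backward()
    assert layer.weight.grad is not None  # gradients flow through the checkpointed block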