diff --git a/scripts/cal_mfu.py b/scripts/cal_mfu.py index c4e851d7..0ae4dd42 100644 --- a/scripts/cal_mfu.py +++ b/scripts/cal_mfu.py @@ -138,7 +138,7 @@ def compute_mfu( result = json.load(f) mfu_value = ( - result["train_samples_per_second"] + result["train_steps_per_second"] * compute_model_flops(model_name_or_path, batch_size, seq_length) / compute_device_flops() )