diff --git a/scripts/cal_mfu.py b/scripts/cal_mfu.py
index 2f408497..c4e851d7 100644
--- a/scripts/cal_mfu.py
+++ b/scripts/cal_mfu.py
@@ -11,116 +11,139 @@
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
-# limitations under the License
+# limitations under the License.
 
+import json
+import os
+
+import fire
 import torch
 from transformers import AutoConfig
-import fire
-def model_flops_counter(
+
+from llamafactory.train.tuner import run_exp
+
+
+BASE = 2  # gemm (add + mul)
+
+
+def compute_model_flops(
+    model_name_or_path: str,
     batch_size: int,
-    seqlen: int,
-    model_config: dict,
-    is_backward: bool = True,
-    is_recompute: bool = False,
-    is_flashattn: bool = False,
-) -> float:
+    seq_length: int,
+    include_backward: bool = True,
+    include_recompute: bool = False,
+    include_flashattn: bool = False,
+) -> int:
+    r"""
+    Calculates the FLOPs of model per forward/backward pass.
     """
-    calculate the FLOPs of model per iteration
-    """
-    hidden_size = model_config.hidden_size
-    num_attention_heads = model_config.num_attention_heads
-    num_key_value_heads = model_config.num_key_value_heads
-    vocab_size = model_config.vocab_size
-    intermediate_size = model_config.intermediate_size
-    num_hidden_layers = model_config.num_hidden_layers
-    """
-    B: batch_size
-    S: seqlen
-    L: num_hidden_layers
-    H: hidden_size
-    V: vocab_size
-    I: intermediate_size
-    """
-    ### MLP calculation
-    per_mlp_calculation = 2 * hidden_size * intermediate_size
-    mlp_calculation_per_layer = per_mlp_calculation * 3
-    mlp_calculation = batch_size * seqlen * mlp_calculation_per_layer * num_hidden_layers
+    config = AutoConfig.from_pretrained(model_name_or_path)
+    hidden_size = getattr(config, "hidden_size", None)
+    vocab_size = getattr(config, "vocab_size", None)
+    intermediate_size = getattr(config, "intermediate_size", None)
+    num_attention_heads = getattr(config, "num_attention_heads", None)
+    num_key_value_heads = getattr(config, "num_key_value_heads", None)
+    num_hidden_layers = getattr(config, "num_hidden_layers", None)
+    tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
 
-    ### Attention calculation
-    Q_calculation = 2 * hidden_size * hidden_size
-    O_calculation = 2 * hidden_size * hidden_size
-    K_calculation = 2 * hidden_size * hidden_size * num_key_value_heads / num_attention_heads
-    V_calculation = 2 * hidden_size * hidden_size * num_key_value_heads / num_attention_heads
-
-    QKVO_calculation = Q_calculation + O_calculation + K_calculation + V_calculation # 8H^2 / coe
-    self_attn_calculation = seqlen * hidden_size * 2 * 2 # (4 * S * H)
-    attention_calculation = batch_size * seqlen * num_hidden_layers * (QKVO_calculation + self_attn_calculation) # BSL(8H^2/coe + 4S * H)
-
-    #Embedding and LMhead calculation
-    embedding_calculation = hidden_size * vocab_size
-    lmhead_calculation = hidden_size * vocab_size
-    IO_calculation = 3 * batch_size * seqlen * (embedding_calculation + lmhead_calculation) # 2 *(1+2)BSHV
-    E = attention_calculation + mlp_calculation
-    coefficient = 3
-    fix_term = 0
-    if(is_recompute):
-        coefficient = 4
-    if(is_flashattn):
-        fix_term = batch_size *seqlen * self_attn_calculation
-
-    total_calculation = coefficient * E + IO_calculation + fix_term
-
-    return total_calculation
+    # mlp module
+    mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size  # up, gate, down
+    mlp_flops = batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
+
+    # attn projector module
+    q_flops_per_token = BASE * hidden_size * hidden_size
+    o_flops_per_token = BASE * hidden_size * hidden_size
+    k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
+    v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
+    attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
+    attn_proj_flops = batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
+
+    # attn sdpa module
+    sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length  # (q * k^T) * v
+    sdpa_flops = batch_size * num_hidden_layers * sdpa_flops_per_layer
+
+    # embedding module
+    embedding_flops_per_token = hidden_size * vocab_size
+    embedding_flops = batch_size * seq_length * embedding_flops_per_token
+    if tie_word_embeddings is False:
+        embedding_flops *= 2
+
+    non_embedding_flops = mlp_flops + attn_proj_flops + sdpa_flops
+    non_embedding_coeff, embedding_coeff = 1, 1
+    if include_backward:
+        non_embedding_coeff += 2
+        embedding_coeff += 2
+
+    if include_recompute:
+        non_embedding_coeff += 1
+
+    total_flops = non_embedding_coeff * non_embedding_flops + embedding_coeff * embedding_flops
+
+    if include_flashattn:
+        total_flops += sdpa_flops
+
+    return total_flops
 
 
-def hardware_flops_counter(
-    seconds: float, # seconds used in given iterations
-    num_gpus: int = 1,
-) -> float:
-    if "A100" in torch.cuda.get_device_name():
-        return 312 * 1e12 * seconds * num_gpus
-    elif "V100" in torch.cuda.get_device_name():
-        return 125 * 1e12 * seconds * num_gpus
+def compute_device_flops() -> float:
+    device_name = torch.cuda.get_device_name()
+    device_count = torch.cuda.device_count()
+    if "H100" in device_name or "H800" in device_name:
+        return 989 * 1e12 * device_count
+    elif "A100" in device_name or "A800" in device_name:
+        return 312 * 1e12 * device_count
+    elif "V100" in device_name:
+        return 125 * 1e12 * device_count
+    elif "4090" in device_name:
+        return 98 * 1e12 * device_count
+    else:
+        raise NotImplementedError("Device not supported: {}.".format(device_name))
+
 
 def compute_mfu(
+    model_name_or_path: str,
     batch_size: int,
-    seqlen: int,
-    model_config: dict,
-    num_iter: int,
-    seconds: float,
-    num_gpus: int = 1,
+    seq_length: int,
+    finetuning_type: str = "lora",
+    flash_attn: str = "auto",
+    deepspeed_stage: int = 0,
+    disable_gc: bool = False,
+    liger_kernel: bool = False,
 ) -> float:
+    r"""
+    Computes MFU for given model and hyper-params.
+    Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
     """
-    compute MFU given model configuration, training config and training information
-    """
-    percentage = (num_iter * model_flops_counter(batch_size,seqlen,model_config)) / hardware_flops_counter(seconds, num_gpus)
-
-    print(f"MFU : {percentage* 100:.2f}%")
-    return percentage
-
-# User input
+    args = {
+        "model_name_or_path": model_name_or_path,
+        "flash_attn": flash_attn,
+        "disable_gradient_checkpointing": disable_gc,
+        "enable_liger_kernel": liger_kernel,
+        "stage": "pt",
+        "do_train": True,
+        "finetuning_type": finetuning_type,
+        "dataset": "c4_demo",
+        "cutoff_len": seq_length,
+        "output_dir": os.path.join("saves", "test_mfu"),
+        "overwrite_output_dir": True,
+        "per_device_train_batch_size": batch_size,
+        "max_steps": 100,
+        "bf16": True,
+    }
+    if deepspeed_stage in [2, 3]:
+        args["deepspeed"] = "examples/deepspeed/ds_z{}_config.json".format(deepspeed_stage)
 
-### model_name
-model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
+    run_exp(args)
+    with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
+        result = json.load(f)
 
-### training config
-batch_size = 8
-seqlen = 1*1024
-num_gpus = 1
-
-### training information
-num_iter = 225
-seconds = 605 # time used in {num_iter} iterations
-
-model_config = AutoConfig.from_pretrained(model_name)
-if __name__ == "__main__":
-    fire.Fire(
-        compute_mfu(
-            batch_size=batch_size,
-            seqlen=seqlen,
-            model_config=model_config,
-            num_iter=num_iter,
-            seconds=seconds,
-            num_gpus=num_gpus
-        )
+    mfu_value = (
+        result["train_samples_per_second"]
+        * compute_model_flops(model_name_or_path, batch_size, seq_length)
+        / compute_device_flops()
     )
+    print("MFU: {:.2f}%".format(mfu_value * 100))
+
+
+if __name__ == "__main__":
+    fire.Fire(compute_mfu)
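
For reviewers, a minimal standalone sketch of the same arithmetic compute_model_flops performs, with hypothetical Llama-3-8B-like hyper-parameters hard-coded by hand (hidden size 4096, MLP size 14336, 32 layers, 32 query / 8 KV heads, 128256-token vocabulary). The numbers are illustrative assumptions only; the patch itself reads the real values from AutoConfig.

# Back-of-the-envelope check of the FLOPs formula above (illustrative values only).
BASE = 2  # gemm (add + mul), same convention as the patch

batch_size, seq_length = 1, 1024
hidden_size, intermediate_size = 4096, 14336  # assumed Llama-3-8B-like config
num_hidden_layers, vocab_size = 32, 128256
num_attention_heads, num_key_value_heads = 32, 8

mlp_flops = batch_size * seq_length * num_hidden_layers * 3 * BASE * hidden_size * intermediate_size
q_flops = o_flops = BASE * hidden_size * hidden_size
k_flops = v_flops = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
attn_proj_flops = batch_size * seq_length * num_hidden_layers * (q_flops + o_flops + k_flops + v_flops)
sdpa_flops = batch_size * num_hidden_layers * 2 * BASE * hidden_size * seq_length * seq_length
embedding_flops = 2 * batch_size * seq_length * hidden_size * vocab_size  # untied input/output embeddings

# include_backward=True, include_recompute=False, include_flashattn=False -> coefficient 3 everywhere
total_flops = 3 * (mlp_flops + attn_proj_flops + sdpa_flops) + 3 * embedding_flops
print("estimated FLOPs per step: {:.1f} TFLOPs".format(total_flops / 1e12))

With these assumed values the estimate lands around 48 TFLOPs per step, in the same ballpark as the common 6 * (parameter count) * (tokens) rule of thumb for an ~8B model at sequence length 1024.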
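The MFU figure printed at the end is then just achieved FLOPs over peak FLOPs. Below is a sketch of that ratio with made-up throughput, assuming a single A100 (the 312 TFLOPS BF16 peak used by compute_device_flops). Since train_samples_per_second counts individual samples, pairing it with a whole-batch FLOPs estimate matches exactly when batch_size is 1, as in the documented usage.

# Illustrative numbers only: the throughput is invented, the FLOPs value is the
# estimate from the sketch above, and 312 TFLOPS is the A100 peak from the patch.
train_samples_per_second = 2.5      # hypothetical value read from all_results.json
flops_per_sample = 4.8e13           # per-step estimate at batch_size 1, seq_length 1024
peak_flops_per_second = 312 * 1e12  # one A100

mfu_value = train_samples_per_second * flops_per_sample / peak_flops_per_second
print("MFU: {:.2f}%".format(mfu_value * 100))  # ~38% with these made-up numbers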