diff --git a/scripts/cal_mfu.py b/scripts/cal_mfu.py
new file mode 100644
index 00000000..0ae4dd42
--- /dev/null
+++ b/scripts/cal_mfu.py
@@ -0,0 +1,149 @@
+# coding=utf-8
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+
+import fire
+import torch
+from transformers import AutoConfig
+
+from llamafactory.train.tuner import run_exp
+
+
+BASE = 2  # gemm (add + mul)
+
+
+def compute_model_flops(
+    model_name_or_path: str,
+    batch_size: int,
+    seq_length: int,
+    include_backward: bool = True,
+    include_recompute: bool = False,
+    include_flashattn: bool = False,
+) -> int:
+    r"""
+    Calculates the FLOPs of the model per forward/backward pass.
+    """
+    config = AutoConfig.from_pretrained(model_name_or_path)
+    hidden_size = getattr(config, "hidden_size", None)
+    vocab_size = getattr(config, "vocab_size", None)
+    intermediate_size = getattr(config, "intermediate_size", None)
+    num_attention_heads = getattr(config, "num_attention_heads", None)
+    num_key_value_heads = getattr(config, "num_key_value_heads", None)
+    num_hidden_layers = getattr(config, "num_hidden_layers", None)
+    tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
+
+    # mlp module
+    mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size  # up, gate, down
+    mlp_flops = batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
+
+    # attn projector module
+    q_flops_per_token = BASE * hidden_size * hidden_size
+    o_flops_per_token = BASE * hidden_size * hidden_size
+    k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
+    v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
+    attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
+    attn_proj_flops = batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
+
+    # attn sdpa module
+    sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length  # (q * k^T) * v
+    sdpa_flops = batch_size * num_hidden_layers * sdpa_flops_per_layer
+
+    # embedding module
+    embedding_flops_per_token = hidden_size * vocab_size
+    embedding_flops = batch_size * seq_length * embedding_flops_per_token
+    if tie_word_embeddings is False:
+        embedding_flops *= 2
+
+    non_embedding_flops = mlp_flops + attn_proj_flops + sdpa_flops
+    non_embedding_coeff, embedding_coeff = 1, 1
+    if include_backward:
+        non_embedding_coeff += 2
+        embedding_coeff += 2
+
+    if include_recompute:
+        non_embedding_coeff += 1
+
+    total_flops = non_embedding_coeff * non_embedding_flops + embedding_coeff * embedding_flops
+
+    if include_flashattn:
+        total_flops += sdpa_flops
+
+    return total_flops
+
+
+def compute_device_flops() -> float:
+    device_name = torch.cuda.get_device_name()
+    device_count = torch.cuda.device_count()
+    if "H100" in device_name or "H800" in device_name:
+        return 989 * 1e12 * device_count
+    elif "A100" in device_name or "A800" in device_name:
+        return 312 * 1e12 * device_count
+    elif "V100" in device_name:
+        return 125 * 1e12 * device_count
+    elif "4090" in device_name:
+        return 98 * 1e12 * device_count
+    else:
+        raise NotImplementedError("Device not supported: {}.".format(device_name))
+
+
+def compute_mfu(
+    model_name_or_path: str,
+    batch_size: int,
+    seq_length: int,
+    finetuning_type: str = "lora",
+    flash_attn: str = "auto",
+    deepspeed_stage: int = 0,
+    disable_gc: bool = False,
+    liger_kernel: bool = False,
+) -> None:
+    r"""
+    Computes MFU for the given model and hyper-params.
+    Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
+    """
+    args = {
+        "model_name_or_path": model_name_or_path,
+        "flash_attn": flash_attn,
+        "disable_gradient_checkpointing": disable_gc,
+        "enable_liger_kernel": liger_kernel,
+        "stage": "pt",
+        "do_train": True,
+        "finetuning_type": finetuning_type,
+        "dataset": "c4_demo",
+        "cutoff_len": seq_length,
+        "output_dir": os.path.join("saves", "test_mfu"),
+        "overwrite_output_dir": True,
+        "per_device_train_batch_size": batch_size,
+        "max_steps": 100,
+        "bf16": True,
+    }
+    if deepspeed_stage in [2, 3]:
+        args["deepspeed"] = "examples/deepspeed/ds_z{}_config.json".format(deepspeed_stage)
+
+    run_exp(args)
+    with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
+        result = json.load(f)
+
+    mfu_value = (
+        result["train_steps_per_second"]
+        * compute_model_flops(model_name_or_path, batch_size, seq_length)
+        / compute_device_flops()
+    )
+    print("MFU: {:.2f}%".format(mfu_value * 100))
+
+
+if __name__ == "__main__":
+    fire.Fire(compute_mfu)