mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-04 18:02:19 +08:00
Update cal_mfu.py
Former-commit-id: 0c391b2e59943b0ca9dd4e8561398e7c856a4b29
This commit is contained in:
parent 9f2427907e
commit 3021b31cf3

@@ -11,116 +11,139 @@
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
-# limitations under the License
+# limitations under the License.
+import json
+import os
+
+import fire
 import torch
 from transformers import AutoConfig
-import fire
+
+from llamafactory.train.tuner import run_exp
+
+
+BASE = 2  # gemm (add + mul)
+
+
-def model_flops_counter(
+def compute_model_flops(
+    model_name_or_path: str,
     batch_size: int,
-    seqlen: int,
-    model_config: dict,
-    is_backward: bool = True,
-    is_recompute: bool = False,
-    is_flashattn: bool = False,
-) -> float:
-    """
-    calculate the FLOPs of model per iteration
+    seq_length: int,
+    include_backward: bool = True,
+    include_recompute: bool = False,
+    include_flashattn: bool = False,
+) -> int:
+    r"""
+    Calculates the FLOPs of model per forward/backward pass.
     """
-    hidden_size = model_config.hidden_size
-    num_attention_heads = model_config.num_attention_heads
-    num_key_value_heads = model_config.num_key_value_heads
-    vocab_size = model_config.vocab_size
-    intermediate_size = model_config.intermediate_size
-    num_hidden_layers = model_config.num_hidden_layers
-    """
-    B: batch_size
-    S: seqlen
-    L: num_hidden_layers
-    H: hidden_size
-    V: vocab_size
-    I: intermediate_size
-    """
-    ### MLP calculation
-    per_mlp_calculation = 2 * hidden_size * intermediate_size
-    mlp_calculation_per_layer = per_mlp_calculation * 3
-    mlp_calculation = batch_size * seqlen * mlp_calculation_per_layer * num_hidden_layers
+    config = AutoConfig.from_pretrained(model_name_or_path)
+    hidden_size = getattr(config, "hidden_size", None)
+    vocab_size = getattr(config, "vocab_size", None)
+    intermediate_size = getattr(config, "intermediate_size", None)
+    num_attention_heads = getattr(config, "num_attention_heads", None)
+    num_key_value_heads = getattr(config, "num_key_value_heads", None)
+    num_hidden_layers = getattr(config, "num_hidden_layers", None)
+    tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
+
-    ### Attention calculation
-    Q_calculation = 2 * hidden_size * hidden_size
-    O_calculation = 2 * hidden_size * hidden_size
-    K_calculation = 2 * hidden_size * hidden_size * num_key_value_heads / num_attention_heads
-    V_calculation = 2 * hidden_size * hidden_size * num_key_value_heads / num_attention_heads
-
-    QKVO_calculation = Q_calculation + O_calculation + K_calculation + V_calculation # 8H^2 / coe
-    self_attn_calculation = seqlen * hidden_size * 2 * 2  # (4 * S * H)
-    attention_calculation = batch_size * seqlen * num_hidden_layers * (QKVO_calculation + self_attn_calculation) # BSL(8H^2/coe + 4S * H)
-
-    #Embedding and LMhead calculation
-    embedding_calculation = hidden_size * vocab_size
-    lmhead_calculation = hidden_size * vocab_size
-    IO_calculation = 3 * batch_size * seqlen * (embedding_calculation + lmhead_calculation) # 2 *(1+2)BSHV
-    E = attention_calculation + mlp_calculation
-    coefficient = 3
-    fix_term = 0
-    if(is_recompute):
-        coefficient = 4
-    if(is_flashattn):
-        fix_term = batch_size *seqlen * self_attn_calculation
-
-    total_calculation = coefficient * E + IO_calculation + fix_term
-
-    return total_calculation
+    # mlp module
+    mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size  # up, gate, down
+    mlp_flops = batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
+
+    # attn projector module
+    q_flops_per_token = BASE * hidden_size * hidden_size
+    o_flops_per_token = BASE * hidden_size * hidden_size
+    k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
+    v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
+    attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
+    attn_proj_flops = batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
+
+    # attn sdpa module
+    sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length  # (q * k^T) * v
+    sdpa_flops = batch_size * num_hidden_layers * sdpa_flops_per_layer
+
+    # embedding module
+    embedding_flops_per_token = hidden_size * vocab_size
+    embedding_flops = batch_size * seq_length * embedding_flops_per_token
+    if tie_word_embeddings is False:
+        embedding_flops *= 2
+
+    non_embedding_flops = mlp_flops + attn_proj_flops + sdpa_flops
+    non_embedding_coeff, embedding_coeff = 1, 1
+    if include_backward:
+        non_embedding_coeff += 2
+        embedding_coeff += 2
+
+    if include_recompute:
+        non_embedding_coeff += 1
+
+    total_flops = non_embedding_coeff * non_embedding_flops + embedding_coeff * embedding_flops
+
+    if include_flashattn:
+        total_flops += sdpa_flops
+
+    return total_flops
+
+
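For context, the new compute_model_flops() boils down to a dense-GEMM estimate: 3 * BASE * H * I per token for the MLP, (2 + 2 * kv_heads / heads) * BASE * H^2 per token for the attention projections, 2 * BASE * H * S^2 per layer for the attention scores, plus the embedding and LM-head matmuls, all scaled by 3 when the backward pass is counted. The following is a minimal sanity-check sketch of that estimate, not part of the patch; the model dimensions are assumptions in the spirit of Llama-3-8B rather than values read from a real config.

# Sanity-check sketch only: dimensions below are illustrative assumptions.
BASE = 2  # one multiply-accumulate in a GEMM counts as two FLOPs
batch_size, seq_length = 8, 1024
hidden_size, intermediate_size, num_hidden_layers = 4096, 14336, 32
num_attention_heads, num_key_value_heads = 32, 8
vocab_size, tie_word_embeddings = 128256, False

mlp_flops = batch_size * seq_length * num_hidden_layers * 3 * BASE * hidden_size * intermediate_size
attn_proj_flops = batch_size * seq_length * num_hidden_layers * (
    2 * BASE * hidden_size * hidden_size
    + 2 * BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
)
sdpa_flops = batch_size * num_hidden_layers * 2 * BASE * hidden_size * seq_length * seq_length
embedding_flops = batch_size * seq_length * hidden_size * vocab_size * (1 if tie_word_embeddings else 2)

# forward + backward (coefficient 3), no recomputation, no flash-attn correction
total_flops = 3 * (mlp_flops + attn_proj_flops + sdpa_flops) + 3 * embedding_flops
print(f"~{total_flops / 1e12:.0f} TFLOPs per training step at this batch size and sequence length")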
-def hardware_flops_counter(
-    seconds: float, # seconds used in given iterations
-    num_gpus: int = 1,
-) -> float:
-    if "A100" in torch.cuda.get_device_name():
-        return 312 * 1e12 * seconds * num_gpus
-    elif "V100" in torch.cuda.get_device_name():
-        return 125 * 1e12 * seconds * num_gpus
+def compute_device_flops() -> float:
+    device_name = torch.cuda.get_device_name()
+    device_count = torch.cuda.device_count()
+    if "H100" in device_name or "H800" in device_name:
+        return 989 * 1e12 * device_count
+    elif "A100" in device_name or "A800" in device_name:
+        return 312 * 1e12 * device_count
+    elif "V100" in device_name:
+        return 125 * 1e12 * device_count
+    elif "4090" in device_name:
+        return 98 * 1e12 * device_count
+    else:
+        raise NotImplementedError("Device not supported: {}.".format(device_name))
+
+
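For reference, the MFU reported below follows the usual definition: the FLOPs the model actually achieves per second divided by the aggregate peak throughput of the available devices, so a value near 100% means the GPUs are close to compute-bound. compute_device_flops() supplies that denominator from the per-device peak of the detected GPU type multiplied by torch.cuda.device_count().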
 def compute_mfu(
+    model_name_or_path: str,
     batch_size: int,
-    seqlen: int,
-    model_config: dict,
-    num_iter: int,
-    seconds: float,
-    num_gpus: int = 1,
+    seq_length: int,
+    finetuning_type: str = "lora",
+    flash_attn: str = "auto",
+    deepspeed_stage: int = 0,
+    disable_gc: bool = False,
+    liger_kernel: bool = False,
 ) -> float:
-    """
-    compute MFU given model configuration, training config and training information
+    r"""
+    Computes MFU for given model and hyper-params.
+    Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
     """
-    percentage = (num_iter * model_flops_counter(batch_size,seqlen,model_config)) / hardware_flops_counter(seconds, num_gpus)
-
-    print(f"MFU : {percentage* 100:.2f}%")
-    return percentage
-
-# User input
+    args = {
+        "model_name_or_path": model_name_or_path,
+        "flash_attn": flash_attn,
+        "disable_gradient_checkpointing": disable_gc,
+        "enable_liger_kernel": liger_kernel,
+        "stage": "pt",
+        "do_train": True,
+        "finetuning_type": finetuning_type,
+        "dataset": "c4_demo",
+        "cutoff_len": seq_length,
+        "output_dir": os.path.join("saves", "test_mfu"),
+        "overwrite_output_dir": True,
+        "per_device_train_batch_size": batch_size,
+        "max_steps": 100,
+        "bf16": True,
+    }
+    if deepspeed_stage in [2, 3]:
+        args["deepspeed"] = "examples/deepspeed/ds_z{}_config.json".format(deepspeed_stage)
+
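Note that only stages 2 and 3 are wired up here: passing --deepspeed_stage 2 or 3 selects the bundled examples/deepspeed/ds_z2_config.json or ds_z3_config.json, while any other value runs the benchmark without DeepSpeed.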
-### model_name
-model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
+    run_exp(args)
+    with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
+        result = json.load(f)
+
-### training config
-batch_size = 8
-seqlen = 1*1024
-num_gpus = 1
-
-### training information
-num_iter = 225
-seconds = 605 # time used in {num_iter} iterations
-
-model_config = AutoConfig.from_pretrained(model_name)
-if __name__ == "__main__":
-    fire.Fire(
-        compute_mfu(
-            batch_size=batch_size,
-            seqlen=seqlen,
-            model_config=model_config,
-            num_iter=num_iter,
-            seconds=seconds,
-            num_gpus=num_gpus
-        )
-    )
+    mfu_value = (
+        result["train_samples_per_second"]
+        * compute_model_flops(model_name_or_path, batch_size, seq_length)
+        / compute_device_flops()
+    )
+    print("MFU: {:.2f}%".format(mfu_value * 100))
+
+
+if __name__ == "__main__":
+    fire.Fire(compute_mfu)
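Since the refactored script is driven by fire, every keyword argument of compute_mfu becomes a command-line flag, as its docstring notes. An illustrative invocation (model path and flag values are placeholders, not part of the commit):

python cal_mfu.py --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct --batch_size 8 --seq_length 1024 --finetuning_type full --flash_attn fa2 --deepspeed_stage 2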