From 0229263fbe2eb7661e67e83c83c5147e5a341ef7 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Sun, 8 Sep 2024 23:18:08 +0800
Subject: [PATCH] tiny fix

Former-commit-id: c9b3870adb60a2aca8cfd82c1a8b8044319bacbc
---
 scripts/cal_mfu.py                     | 35 +++++----
 .../model/model_utils/checkpointing.py | 71 ++++++++++---------
 2 files changed, 59 insertions(+), 47 deletions(-)

diff --git a/scripts/cal_mfu.py b/scripts/cal_mfu.py
index f4d5376b..a04c388a 100644
--- a/scripts/cal_mfu.py
+++ b/scripts/cal_mfu.py
@@ -18,6 +18,7 @@ import os
 
 import fire
 import torch
+import torch.distributed as dist
 from transformers import AutoConfig
 
 from llamafactory.train.tuner import run_exp
@@ -28,7 +29,7 @@ BASE = 2  # gemm (add + mul)
 
 def compute_model_flops(
     model_name_or_path: str,
-    batch_size: int,
+    total_batch_size: int,
     seq_length: int,
     include_backward: bool = True,
     include_recompute: bool = False,
@@ -48,7 +49,7 @@ def compute_model_flops(
 
     # mlp module
     mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size  # up, gate, down
-    mlp_flops = batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
+    mlp_flops = total_batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
 
     # attn projector module
     q_flops_per_token = BASE * hidden_size * hidden_size
@@ -56,15 +57,15 @@ def compute_model_flops(
     k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
     v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
     attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
-    attn_proj_flops = batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
+    attn_proj_flops = total_batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
 
     # attn sdpa module
     sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length  # (q * k^T) * v
-    sdpa_flops = batch_size * num_hidden_layers * sdpa_flops_per_layer
+    sdpa_flops = total_batch_size * num_hidden_layers * sdpa_flops_per_layer
 
     # embedding module
     embedding_flops_per_token = hidden_size * vocab_size
-    embedding_flops = batch_size * seq_length * embedding_flops_per_token
+    embedding_flops = total_batch_size * seq_length * embedding_flops_per_token
     if tie_word_embeddings is False:
         embedding_flops *= 2
 
@@ -85,17 +86,19 @@
     return total_flops
 
 
-def compute_device_flops() -> float:
+def compute_device_flops(world_size: int) -> float:
+    r"""
+    Calculates the FLOPs of the device capability per second.
+    """
     device_name = torch.cuda.get_device_name()
-    device_count = torch.cuda.device_count()
     if "H100" in device_name or "H800" in device_name:
-        return 989 * 1e12 * device_count
+        return 989 * 1e12 * world_size
     elif "A100" in device_name or "A800" in device_name:
-        return 312 * 1e12 * device_count
+        return 312 * 1e12 * world_size
     elif "V100" in device_name:
-        return 125 * 1e12 * device_count
+        return 125 * 1e12 * world_size
     elif "4090" in device_name:
-        return 98 * 1e12 * device_count
+        return 98 * 1e12 * world_size
     else:
         raise NotImplementedError("Device not supported: {}.".format(device_name))
 
@@ -140,10 +143,16 @@
     with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
         result = json.load(f)
 
+    if dist.is_initialized():
+        world_size = dist.get_world_size()
+    else:
+        world_size = 1
+
+    total_batch_size = batch_size * world_size
     mfu_value = (
         result["train_steps_per_second"]
-        * compute_model_flops(model_name_or_path, batch_size, seq_length)
-        / compute_device_flops()
+        * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
+        / compute_device_flops(world_size)
     )
     print("MFU: {:.2f}%".format(mfu_value * 100))
 
diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py
index 0c5c98ec..412b4779 100644
--- a/src/llamafactory/model/model_utils/checkpointing.py
+++ b/src/llamafactory/model/model_utils/checkpointing.py
@@ -21,7 +21,7 @@
 import inspect
 from functools import partial, wraps
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
 
@@ -38,48 +38,51 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-class UnslothGradientCheckpointing(torch.autograd.Function):
-    r"""
-    Saves VRAM by smartly offloading to RAM.
-    """
+def get_unsloth_gradient_checkpointing_func() -> Callable:
+    class UnslothGradientCheckpointing(torch.autograd.Function):
+        r"""
+        Saves VRAM by smartly offloading to RAM.
+        """
 
-    @staticmethod
-    @torch.cuda.amp.custom_fwd
-    def forward(
-        ctx: "torch.autograd.Function",
-        forward_function: "torch.Module",
-        hidden_states: "torch.Tensor",
-        *args: Union["torch.Tensor", Any],
-    ) -> "torch.Tensor":
-        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
-        with torch.no_grad():
-            output = forward_function(hidden_states, *args)
+        @staticmethod
+        @torch.cuda.amp.custom_fwd
+        def forward(
+            ctx: "torch.autograd.Function",
+            forward_function: "torch.Module",
+            hidden_states: "torch.Tensor",
+            *args: Union["torch.Tensor", Any],
+        ) -> "torch.Tensor":
+            saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
+            with torch.no_grad():
+                output = forward_function(hidden_states, *args)
 
-        ctx.save_for_backward(saved_hidden_states)
-        ctx.forward_function = forward_function
-        ctx.args = args
-        return output
+            ctx.save_for_backward(saved_hidden_states)
+            ctx.forward_function = forward_function
+            ctx.args = args
+            return output
 
-    @staticmethod
-    @torch.cuda.amp.custom_bwd
-    def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
-        (hidden_states,) = ctx.saved_tensors
-        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
-        hidden_states.requires_grad_(True)
-        with torch.enable_grad():
-            (output,) = ctx.forward_function(hidden_states, *ctx.args)
+        @staticmethod
+        @torch.cuda.amp.custom_bwd
+        def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
+            (hidden_states,) = ctx.saved_tensors
+            hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
+            hidden_states.requires_grad_(True)
+            with torch.enable_grad():
+                (output,) = ctx.forward_function(hidden_states, *ctx.args)
 
-        torch.autograd.backward(output, grad_output)
-        return (None, hidden_states.grad) + (None,) * len(ctx.args)
+            torch.autograd.backward(output, grad_output)
+            return (None, hidden_states.grad) + (None,) * len(ctx.args)
+
+    return UnslothGradientCheckpointing.apply
 
 
-def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
+def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
     r"""
     Only applies gradient checkpointing to trainable layers.
     """
 
     @wraps(gradient_checkpointing_func)
-    def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
+    def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
         module: "torch.nn.Module" = func.__self__
 
         if any(param.requires_grad for param in module.parameters()):
@@ -89,7 +92,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
 
         return gradient_checkpointing_func(func, *args, **kwargs)
 
-    if hasattr(gradient_checkpointing_func, "__self__"):  # fix test case
+    if hasattr(gradient_checkpointing_func, "__self__"):  # fix unsloth gc test case
         custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
 
     return custom_gradient_checkpointing_func
@@ -114,7 +117,7 @@ def _gradient_checkpointing_enable(
         gradient_checkpointing_kwargs = {"use_reentrant": True}
 
     if use_unsloth_gc:
-        gradient_checkpointing_func = UnslothGradientCheckpointing.apply
+        gradient_checkpointing_func = get_unsloth_gradient_checkpointing_func()
     else:
         gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
 
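
Reviewer note (illustration, not part of the patch): the MFU printed by calculate_mfu is train_steps_per_second times the model FLOPs per optimizer step, divided by the aggregate device peak FLOPs. With this change the model FLOPs are computed from the global batch (per-device batch_size * world_size) while compute_device_flops scales the per-device peak by world_size. The sketch below redoes that arithmetic outside the script using the same FLOP terms that appear in compute_model_flops (MLP, attention projections, SDPA, embeddings); the model dimensions, throughput number, and the 3x forward+backward factor are assumptions for illustration only, not values taken from the patch.

# Illustrative sketch of the MFU arithmetic (hypothetical numbers throughout).
BASE = 2  # gemm (add + mul), same convention as cal_mfu.py

def forward_flops(total_batch_size, seq_length, hidden, inter, layers, heads, kv_heads, vocab,
                  tie_word_embeddings=False):
    mlp = total_batch_size * seq_length * layers * (3 * BASE * hidden * inter)  # up, gate, down
    q = o = BASE * hidden * hidden
    k = v = BASE * hidden * hidden * kv_heads // heads
    attn_proj = total_batch_size * seq_length * layers * (q + o + k + v)
    sdpa = total_batch_size * layers * (2 * BASE * hidden * seq_length * seq_length)  # (q * k^T) * v
    embed = total_batch_size * seq_length * hidden * vocab * (1 if tie_word_embeddings else 2)
    return mlp + attn_proj + sdpa + embed

# Assumed setup: Llama-2-7B-like dims, 8 x A100, per-device batch 1, 4k sequences.
world_size, batch_size, seq_length = 8, 1, 4096
model_flops = 3 * forward_flops(batch_size * world_size, seq_length, hidden=4096, inter=11008,
                                layers=32, heads=32, kv_heads=32, vocab=32000)  # backward ~= 2x forward
train_steps_per_second = 0.25           # in the script this comes from saves/test_mfu/all_results.json
device_flops = 312 * 1e12 * world_size  # A100/A800 peak, mirroring compute_device_flops
print("MFU: {:.2f}%".format(train_steps_per_second * model_flops / device_flops * 100))  # ~15% here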
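
Reviewer note (illustration, not part of the patch): get_custom_gradient_checkpointing_func wraps whichever checkpointing function was selected (the unsloth offloading variant or torch.utils.checkpoint) so that only modules with at least one trainable parameter are actually checkpointed; frozen layers simply run a normal forward. A minimal standalone sketch of that rule follows, calling torch.utils.checkpoint directly; the helper name checkpoint_if_trainable is made up here and simplifies away the bound-method (func.__self__) plumbing used in the real wrapper.

# Minimal sketch: checkpoint a layer only when it has trainable parameters (illustrative helper).
import torch
from torch.utils.checkpoint import checkpoint

def checkpoint_if_trainable(module: torch.nn.Module, *args: torch.Tensor) -> torch.Tensor:
    if any(param.requires_grad for param in module.parameters()):
        # trainable layer: recompute its forward during backward to save activation memory
        return checkpoint(module, *args, use_reentrant=True)
    return module(*args)  # frozen layer: nothing to train, a plain forward is cheaper

# usage: a frozen backbone block vs. a trainable block
frozen = torch.nn.Linear(16, 16).requires_grad_(False)
trainable = torch.nn.Linear(16, 16)
x = torch.randn(2, 16, requires_grad=True)
out = checkpoint_if_trainable(trainable, checkpoint_if_trainable(frozen, x))
out.sum().backward()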