Former-commit-id: c9b3870adb60a2aca8cfd82c1a8b8044319bacbc
hiyouga 2024-09-08 23:18:08 +08:00
parent ec6b85d8f9
commit 0229263fbe
2 changed files with 59 additions and 47 deletions

View File

@@ -18,6 +18,7 @@ import os
 import fire
 import torch
+import torch.distributed as dist
 from transformers import AutoConfig
 
 from llamafactory.train.tuner import run_exp
@@ -28,7 +29,7 @@ BASE = 2  # gemm (add + mul)
 def compute_model_flops(
     model_name_or_path: str,
-    batch_size: int,
+    total_batch_size: int,
     seq_length: int,
     include_backward: bool = True,
     include_recompute: bool = False,
@@ -48,7 +49,7 @@ def compute_model_flops(
     # mlp module
     mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size  # up, gate, down
-    mlp_flops = batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
+    mlp_flops = total_batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
 
     # attn projector module
     q_flops_per_token = BASE * hidden_size * hidden_size
@@ -56,15 +57,15 @@ def compute_model_flops(
     k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
     v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
     attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
-    attn_proj_flops = batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
+    attn_proj_flops = total_batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
 
     # attn sdpa module
     sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length  # (q * k^T) * v
-    sdpa_flops = batch_size * num_hidden_layers * sdpa_flops_per_layer
+    sdpa_flops = total_batch_size * num_hidden_layers * sdpa_flops_per_layer
 
     # embedding module
     embedding_flops_per_token = hidden_size * vocab_size
-    embedding_flops = batch_size * seq_length * embedding_flops_per_token
+    embedding_flops = total_batch_size * seq_length * embedding_flops_per_token
     if tie_word_embeddings is False:
         embedding_flops *= 2
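
As a rough illustration of how the per-module terms above combine, here is a minimal sketch with hypothetical config values (the numbers are made up for illustration and are not taken from the script or from any real model card):

# Hypothetical numbers, for illustration only
BASE = 2                   # each gemm entry counts a multiply and an add
hidden_size = 4096
intermediate_size = 11008
num_hidden_layers = 32
num_attention_heads = 32
num_key_value_heads = 32   # so the key/value scaling factor below is 1
total_batch_size = 8       # per-device batch size times world size
seq_length = 2048

mlp_flops = total_batch_size * seq_length * num_hidden_layers * (3 * BASE * hidden_size * intermediate_size)
attn_proj_flops = total_batch_size * seq_length * num_hidden_layers * (4 * BASE * hidden_size * hidden_size)
sdpa_flops = total_batch_size * num_hidden_layers * (2 * BASE * hidden_size * seq_length * seq_length)
print("{:.3e} {:.3e} {:.3e}".format(mlp_flops, attn_proj_flops, sdpa_flops))
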
@@ -85,17 +86,19 @@ def compute_model_flops(
     return total_flops
 
 
-def compute_device_flops() -> float:
+def compute_device_flops(world_size: int) -> float:
+    r"""
+    Calculates the FLOPs of the device capability per second.
+    """
     device_name = torch.cuda.get_device_name()
-    device_count = torch.cuda.device_count()
     if "H100" in device_name or "H800" in device_name:
-        return 989 * 1e12 * device_count
+        return 989 * 1e12 * world_size
     elif "A100" in device_name or "A800" in device_name:
-        return 312 * 1e12 * device_count
+        return 312 * 1e12 * world_size
     elif "V100" in device_name:
-        return 125 * 1e12 * device_count
+        return 125 * 1e12 * world_size
     elif "4090" in device_name:
-        return 98 * 1e12 * device_count
+        return 98 * 1e12 * world_size
     else:
         raise NotImplementedError("Device not supported: {}.".format(device_name))
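
In other words, compute_device_flops now scales the per-device peak by the caller-supplied world size instead of the locally visible device count, so multi-node runs are counted correctly. A quick sanity check using the hard-coded peak values above (hypothetical single node with 8 A100s):

world_size = 8                          # hypothetical: 8 x A100
peak_flops = 312 * 1e12 * world_size    # 2.496e+15 FLOP/s aggregate peak
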
@@ -140,10 +143,16 @@ def calculate_mfu(
     with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
         result = json.load(f)
 
+    if dist.is_initialized():
+        world_size = dist.get_world_size()
+    else:
+        world_size = 1
+
+    total_batch_size = batch_size * world_size
     mfu_value = (
         result["train_steps_per_second"]
-        * compute_model_flops(model_name_or_path, batch_size, seq_length)
-        / compute_device_flops()
+        * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
+        / compute_device_flops(world_size)
     )
     print("MFU: {:.2f}%".format(mfu_value * 100))

View File

@@ -21,7 +21,7 @@
 import inspect
 from functools import partial, wraps
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
@@ -38,6 +38,7 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
+def get_unsloth_gradient_checkpointing_func() -> Callable:
     class UnslothGradientCheckpointing(torch.autograd.Function):
         r"""
         Saves VRAM by smartly offloading to RAM.
@@ -72,14 +73,16 @@ class UnslothGradientCheckpointing(torch.autograd.Function):
             torch.autograd.backward(output, grad_output)
             return (None, hidden_states.grad) + (None,) * len(ctx.args)
 
+    return UnslothGradientCheckpointing.apply
+
 
-def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
+def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
     r"""
     Only applies gradient checkpointing to trainable layers.
     """
 
     @wraps(gradient_checkpointing_func)
-    def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
+    def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
         module: "torch.nn.Module" = func.__self__
 
         if any(param.requires_grad for param in module.parameters()):
@@ -89,7 +92,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
         return gradient_checkpointing_func(func, *args, **kwargs)
 
-    if hasattr(gradient_checkpointing_func, "__self__"):  # fix test case
+    if hasattr(gradient_checkpointing_func, "__self__"):  # fix unsloth gc test case
         custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
 
     return custom_gradient_checkpointing_func
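
The wrapper above routes a layer through gradient checkpointing only when at least one of its parameters is trainable; frozen layers are called directly and skip the recomputation. A minimal standalone sketch of that decision using torch.utils.checkpoint directly, not the project's own wrapper:

import torch
from torch.utils.checkpoint import checkpoint

def maybe_checkpoint(module: torch.nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
    # Checkpoint only if the layer has trainable parameters, mirroring the check above.
    if any(param.requires_grad for param in module.parameters()):
        return checkpoint(module, hidden_states, use_reentrant=True)
    return module(hidden_states)  # frozen layer: plain forward, nothing to recompute

layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)
maybe_checkpoint(layer, x)   # trainable: goes through checkpoint()
layer.requires_grad_(False)
maybe_checkpoint(layer, x)   # frozen: direct call
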
@@ -114,7 +117,7 @@ def _gradient_checkpointing_enable(
         gradient_checkpointing_kwargs = {"use_reentrant": True}
 
     if use_unsloth_gc:
-        gradient_checkpointing_func = UnslothGradientCheckpointing.apply
+        gradient_checkpointing_func = get_unsloth_gradient_checkpointing_func()
     else:
         gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)