Former-commit-id: c9b3870adb60a2aca8cfd82c1a8b8044319bacbc
hiyouga 2024-09-08 23:18:08 +08:00
parent ec6b85d8f9
commit 0229263fbe
2 changed files with 59 additions and 47 deletions

View File

@@ -18,6 +18,7 @@ import os
 import fire
 import torch
+import torch.distributed as dist
 from transformers import AutoConfig

 from llamafactory.train.tuner import run_exp
@@ -28,7 +29,7 @@ BASE = 2  # gemm (add + mul)
 def compute_model_flops(
     model_name_or_path: str,
-    batch_size: int,
+    total_batch_size: int,
     seq_length: int,
     include_backward: bool = True,
     include_recompute: bool = False,
@@ -48,7 +49,7 @@ def compute_model_flops(
     # mlp module
     mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size  # up, gate, down
-    mlp_flops = batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
+    mlp_flops = total_batch_size * seq_length * num_hidden_layers * mlp_flops_per_token

     # attn projector module
     q_flops_per_token = BASE * hidden_size * hidden_size
@@ -56,15 +57,15 @@ def compute_model_flops(
     k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
     v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
     attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
-    attn_proj_flops = batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
+    attn_proj_flops = total_batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token

     # attn sdpa module
     sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length  # (q * k^T) * v
-    sdpa_flops = batch_size * num_hidden_layers * sdpa_flops_per_layer
+    sdpa_flops = total_batch_size * num_hidden_layers * sdpa_flops_per_layer

     # embedding module
     embedding_flops_per_token = hidden_size * vocab_size
-    embedding_flops = batch_size * seq_length * embedding_flops_per_token
+    embedding_flops = total_batch_size * seq_length * embedding_flops_per_token
     if tie_word_embeddings is False:
         embedding_flops *= 2
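The per-token counts above use BASE = 2 FLOPs per multiply-accumulate. As a rough sanity check, the forward-pass total for one step can be tallied by hand; the sketch below plugs in hypothetical LLaMA-2-7B-like hyperparameters (hidden size 4096, intermediate size 11008, 32 layers, 32 heads, vocab 32000), which are illustrative and not taken from the script. The backward and recompute multipliers applied later in compute_model_flops are omitted here.

# Hypothetical hyperparameters, for illustration only.
BASE = 2  # one multiply + one add per GEMM element
hidden_size, intermediate_size = 4096, 11008
num_hidden_layers, num_attention_heads, num_key_value_heads = 32, 32, 32
vocab_size, seq_length, total_batch_size = 32000, 4096, 8

mlp_per_token = 3 * BASE * hidden_size * intermediate_size  # up, gate, down
qo_per_token = 2 * BASE * hidden_size * hidden_size  # q and o projections
kv_per_token = 2 * BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
sdpa_per_layer = 2 * BASE * hidden_size * seq_length * seq_length  # (q @ k^T) @ v
embedding_per_token = 2 * hidden_size * vocab_size  # doubled because embeddings are untied

forward_flops = total_batch_size * (
    seq_length * num_hidden_layers * (mlp_per_token + qo_per_token + kv_per_token)
    + num_hidden_layers * sdpa_per_layer
    + seq_length * embedding_per_token
)
print("forward FLOPs per step: {:.3e}".format(forward_flops))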
@@ -85,17 +86,19 @@ def compute_model_flops(
     return total_flops


-def compute_device_flops() -> float:
+def compute_device_flops(world_size: int) -> float:
+    r"""
+    Calculates the FLOPs of the device capability per second.
+    """
     device_name = torch.cuda.get_device_name()
-    device_count = torch.cuda.device_count()
     if "H100" in device_name or "H800" in device_name:
-        return 989 * 1e12 * device_count
+        return 989 * 1e12 * world_size
     elif "A100" in device_name or "A800" in device_name:
-        return 312 * 1e12 * device_count
+        return 312 * 1e12 * world_size
     elif "V100" in device_name:
-        return 125 * 1e12 * device_count
+        return 125 * 1e12 * world_size
     elif "4090" in device_name:
-        return 98 * 1e12 * device_count
+        return 98 * 1e12 * world_size
     else:
         raise NotImplementedError("Device not supported: {}.".format(device_name))
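With world_size passed in explicitly, the denominator of the MFU ratio scales with the number of participating devices rather than with torch.cuda.device_count() on a single node. The snippet below restates the same lookup as a table, an equivalent sketch rather than the script's code, and evaluates it for a hypothetical 8-GPU H800 node; the peak values come from the branches above.

PEAK_FLOPS_PER_DEVICE = {
    "H100": 989e12, "H800": 989e12,
    "A100": 312e12, "A800": 312e12,
    "V100": 125e12, "4090": 98e12,
}

def peak_cluster_flops(device_name: str, world_size: int) -> float:
    # Substring match on the CUDA device name, mirroring the checks above.
    for key, flops in PEAK_FLOPS_PER_DEVICE.items():
        if key in device_name:
            return flops * world_size
    raise NotImplementedError("Device not supported: {}.".format(device_name))

print(peak_cluster_flops("NVIDIA H800", 8))  # 989e12 * 8 = 7.912e15 FLOP/s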
@@ -140,10 +143,16 @@ def calculate_mfu(
     with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
         result = json.load(f)

+    if dist.is_initialized():
+        world_size = dist.get_world_size()
+    else:
+        world_size = 1
+
+    total_batch_size = batch_size * world_size
     mfu_value = (
         result["train_steps_per_second"]
-        * compute_model_flops(model_name_or_path, batch_size, seq_length)
-        / compute_device_flops()
+        * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
+        / compute_device_flops(world_size)
     )
     print("MFU: {:.2f}%".format(mfu_value * 100))

View File

@@ -21,7 +21,7 @@
 import inspect
 from functools import partial, wraps
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

 import torch
@@ -38,48 +38,51 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


-class UnslothGradientCheckpointing(torch.autograd.Function):
-    r"""
-    Saves VRAM by smartly offloading to RAM.
-    """
+def get_unsloth_gradient_checkpointing_func() -> Callable:
+    class UnslothGradientCheckpointing(torch.autograd.Function):
+        r"""
+        Saves VRAM by smartly offloading to RAM.
+        """

-    @staticmethod
-    @torch.cuda.amp.custom_fwd
-    def forward(
-        ctx: "torch.autograd.Function",
-        forward_function: "torch.Module",
-        hidden_states: "torch.Tensor",
-        *args: Union["torch.Tensor", Any],
-    ) -> "torch.Tensor":
-        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
-        with torch.no_grad():
-            output = forward_function(hidden_states, *args)
+        @staticmethod
+        @torch.cuda.amp.custom_fwd
+        def forward(
+            ctx: "torch.autograd.Function",
+            forward_function: "torch.Module",
+            hidden_states: "torch.Tensor",
+            *args: Union["torch.Tensor", Any],
+        ) -> "torch.Tensor":
+            saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
+            with torch.no_grad():
+                output = forward_function(hidden_states, *args)

-        ctx.save_for_backward(saved_hidden_states)
-        ctx.forward_function = forward_function
-        ctx.args = args
-        return output
+            ctx.save_for_backward(saved_hidden_states)
+            ctx.forward_function = forward_function
+            ctx.args = args
+            return output

-    @staticmethod
-    @torch.cuda.amp.custom_bwd
-    def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
-        (hidden_states,) = ctx.saved_tensors
-        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
-        hidden_states.requires_grad_(True)
-        with torch.enable_grad():
-            (output,) = ctx.forward_function(hidden_states, *ctx.args)
+        @staticmethod
+        @torch.cuda.amp.custom_bwd
+        def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
+            (hidden_states,) = ctx.saved_tensors
+            hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
+            hidden_states.requires_grad_(True)
+            with torch.enable_grad():
+                (output,) = ctx.forward_function(hidden_states, *ctx.args)

-        torch.autograd.backward(output, grad_output)
-        return (None, hidden_states.grad) + (None,) * len(ctx.args)
+            torch.autograd.backward(output, grad_output)
+            return (None, hidden_states.grad) + (None,) * len(ctx.args)
+
+    return UnslothGradientCheckpointing.apply


-def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
+def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
     r"""
     Only applies gradient checkpointing to trainable layers.
     """

     @wraps(gradient_checkpointing_func)
-    def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
+    def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
         module: "torch.nn.Module" = func.__self__

         if any(param.requires_grad for param in module.parameters()):
@@ -89,7 +92,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
         return gradient_checkpointing_func(func, *args, **kwargs)

-    if hasattr(gradient_checkpointing_func, "__self__"):  # fix test case
+    if hasattr(gradient_checkpointing_func, "__self__"):  # fix unsloth gc test case
         custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__

     return custom_gradient_checkpointing_func
@@ -114,7 +117,7 @@ def _gradient_checkpointing_enable(
         gradient_checkpointing_kwargs = {"use_reentrant": True}

     if use_unsloth_gc:
-        gradient_checkpointing_func = UnslothGradientCheckpointing.apply
+        gradient_checkpointing_func = get_unsloth_gradient_checkpointing_func()
     else:
         gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
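For the checkpointing changes: the Unsloth offloaded-checkpointing autograd.Function is now built inside a factory that returns its .apply callable, and _gradient_checkpointing_enable calls that factory instead of referencing the class directly. The wrapper produced by get_custom_gradient_checkpointing_func relies on being handed a bound method, since it recovers the owning module through func.__self__ and only checkpoints layers with trainable parameters. The sketch below shows that calling convention with a toy linear layer and PyTorch's stock checkpoint function; the layer, shapes, and values are made up for illustration and are not the library's code.

from functools import partial

import torch
from torch.utils.checkpoint import checkpoint

# Toy stand-ins; the real code wraps Hugging Face decoder layers.
layer = torch.nn.Linear(16, 16)
hidden_states = torch.randn(2, 16, requires_grad=True)

gradient_checkpointing_func = partial(checkpoint, use_reentrant=True)

# Transformers invokes the checkpointing function with a bound method,
# which is why func.__self__ yields the module that owns the forward pass.
func = layer.__call__
module = func.__self__
if any(param.requires_grad for param in module.parameters()):
    output = gradient_checkpointing_func(func, hidden_states)  # checkpoint the trainable layer
else:
    output = func(hidden_states)  # frozen layer: run it without checkpointing

output.sum().backward()
print(hidden_states.grad.shape)  # torch.Size([2, 16])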