Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)

Commit 0229263fbe (parent ec6b85d8f9): tiny fix

Former-commit-id: c9b3870adb60a2aca8cfd82c1a8b8044319bacbc
@@ -18,6 +18,7 @@ import os
 
 import fire
 import torch
+import torch.distributed as dist
 from transformers import AutoConfig
 
 from llamafactory.train.tuner import run_exp
@@ -28,7 +29,7 @@ BASE = 2  # gemm (add + mul)
 
 def compute_model_flops(
     model_name_or_path: str,
-    batch_size: int,
+    total_batch_size: int,
     seq_length: int,
     include_backward: bool = True,
     include_recompute: bool = False,
@@ -48,7 +49,7 @@ def compute_model_flops(
 
     # mlp module
     mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size  # up, gate, down
-    mlp_flops = batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
+    mlp_flops = total_batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
 
     # attn projector module
     q_flops_per_token = BASE * hidden_size * hidden_size
@@ -56,15 +57,15 @@ def compute_model_flops(
     k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
     v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
     attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
-    attn_proj_flops = batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
+    attn_proj_flops = total_batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
 
     # attn sdpa module
     sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length  # (q * k^T) * v
-    sdpa_flops = batch_size * num_hidden_layers * sdpa_flops_per_layer
+    sdpa_flops = total_batch_size * num_hidden_layers * sdpa_flops_per_layer
 
     # embedding module
     embedding_flops_per_token = hidden_size * vocab_size
-    embedding_flops = batch_size * seq_length * embedding_flops_per_token
+    embedding_flops = total_batch_size * seq_length * embedding_flops_per_token
     if tie_word_embeddings is False:
         embedding_flops *= 2
 
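As a sanity check on the per-token accounting above, here is a small illustrative calculation with assumed Llama-7B-like hyperparameters; the numbers are assumptions for the sketch, not values read from this diff, and only show the formulas in action:

# Illustrative FLOP accounting per token, mirroring the formulas in the diff.
# The hyperparameters below are assumed (roughly Llama-7B-like), not taken
# from any config in this commit.
BASE = 2  # gemm counts one multiply + one add per MAC

hidden_size = 4096
intermediate_size = 11008
num_attention_heads = 32
num_key_value_heads = 32

# mlp: up, gate and down projections
mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size

# attention projections: q and o are full-size, k and v shrink under GQA
q_flops_per_token = BASE * hidden_size * hidden_size
o_flops_per_token = BASE * hidden_size * hidden_size
k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token

print(f"mlp: {mlp_flops_per_token / 1e6:.1f} MFLOPs/token")              # 270.5
print(f"attn proj: {attn_proj_flops_per_token / 1e6:.1f} MFLOPs/token")  # 134.2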
@@ -85,17 +86,19 @@ def compute_model_flops(
     return total_flops
 
 
-def compute_device_flops() -> float:
+def compute_device_flops(world_size: int) -> float:
+    r"""
+    Calculates the FLOPs of the device capability per second.
+    """
     device_name = torch.cuda.get_device_name()
-    device_count = torch.cuda.device_count()
     if "H100" in device_name or "H800" in device_name:
-        return 989 * 1e12 * device_count
+        return 989 * 1e12 * world_size
     elif "A100" in device_name or "A800" in device_name:
-        return 312 * 1e12 * device_count
+        return 312 * 1e12 * world_size
     elif "V100" in device_name:
-        return 125 * 1e12 * device_count
+        return 125 * 1e12 * world_size
     elif "4090" in device_name:
-        return 98 * 1e12 * device_count
+        return 98 * 1e12 * world_size
     else:
         raise NotImplementedError("Device not supported: {}.".format(device_name))
 
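Because compute_device_flops now multiplies a per-device peak by world_size, the denominator of the MFU estimate becomes the aggregate peak of the whole data-parallel job. A quick illustrative calculation, assuming 8 ranks on A100s (both numbers are assumptions for the example):

# Aggregate peak throughput for a hypothetical 8-GPU A100 job, mirroring the
# A100/A800 branch above (312 TFLOP/s per device).
per_device_peak = 312 * 1e12   # FLOP/s, from the lookup table in the diff
world_size = 8                 # assumed: one training process per GPU
aggregate_peak = per_device_peak * world_size
print(f"{aggregate_peak:.3e} FLOP/s")  # 2.496e+15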
@@ -140,10 +143,16 @@ def calculate_mfu(
     with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
         result = json.load(f)
 
+    if dist.is_initialized():
+        world_size = dist.get_world_size()
+    else:
+        world_size = 1
+
+    total_batch_size = batch_size * world_size
     mfu_value = (
         result["train_steps_per_second"]
-        * compute_model_flops(model_name_or_path, batch_size, seq_length)
-        / compute_device_flops()
+        * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
+        / compute_device_flops(world_size)
     )
     print("MFU: {:.2f}%".format(mfu_value * 100))
 
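Putting the two helpers together, the script now compares achieved FLOP/s against that aggregate peak. A condensed sketch of the estimate with made-up numbers (in the real script the values come from all_results.json and from the helpers patched above):

# MFU = achieved FLOP/s / peak FLOP/s, where achieved FLOP/s is the FLOPs of
# one optimization step over the global batch times the measured step rate.
train_steps_per_second = 0.5    # assumed value of result["train_steps_per_second"]
model_flops_per_step = 2.0e15   # assumed output of compute_model_flops(model, total_batch_size, seq_length)
peak_flops = 312 * 1e12 * 8     # assumed 8x A100, i.e. compute_device_flops(world_size=8)

mfu_value = train_steps_per_second * model_flops_per_step / peak_flops
print("MFU: {:.2f}%".format(mfu_value * 100))  # MFU: 40.06%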
@@ -21,7 +21,7 @@
 import inspect
 from functools import partial, wraps
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
 
 import torch
 
@@ -38,48 +38,51 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-class UnslothGradientCheckpointing(torch.autograd.Function):
-    r"""
-    Saves VRAM by smartly offloading to RAM.
-    """
-
-    @staticmethod
-    @torch.cuda.amp.custom_fwd
-    def forward(
-        ctx: "torch.autograd.Function",
-        forward_function: "torch.Module",
-        hidden_states: "torch.Tensor",
-        *args: Union["torch.Tensor", Any],
-    ) -> "torch.Tensor":
-        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
-        with torch.no_grad():
-            output = forward_function(hidden_states, *args)
-
-        ctx.save_for_backward(saved_hidden_states)
-        ctx.forward_function = forward_function
-        ctx.args = args
-        return output
-
-    @staticmethod
-    @torch.cuda.amp.custom_bwd
-    def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
-        (hidden_states,) = ctx.saved_tensors
-        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
-        hidden_states.requires_grad_(True)
-        with torch.enable_grad():
-            (output,) = ctx.forward_function(hidden_states, *ctx.args)
-
-        torch.autograd.backward(output, grad_output)
-        return (None, hidden_states.grad) + (None,) * len(ctx.args)
+def get_unsloth_gradient_checkpointing_func() -> Callable:
+    class UnslothGradientCheckpointing(torch.autograd.Function):
+        r"""
+        Saves VRAM by smartly offloading to RAM.
+        """
+
+        @staticmethod
+        @torch.cuda.amp.custom_fwd
+        def forward(
+            ctx: "torch.autograd.Function",
+            forward_function: "torch.Module",
+            hidden_states: "torch.Tensor",
+            *args: Union["torch.Tensor", Any],
+        ) -> "torch.Tensor":
+            saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
+            with torch.no_grad():
+                output = forward_function(hidden_states, *args)
+
+            ctx.save_for_backward(saved_hidden_states)
+            ctx.forward_function = forward_function
+            ctx.args = args
+            return output
+
+        @staticmethod
+        @torch.cuda.amp.custom_bwd
+        def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
+            (hidden_states,) = ctx.saved_tensors
+            hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
+            hidden_states.requires_grad_(True)
+            with torch.enable_grad():
+                (output,) = ctx.forward_function(hidden_states, *ctx.args)
+
+            torch.autograd.backward(output, grad_output)
+            return (None, hidden_states.grad) + (None,) * len(ctx.args)
+
+    return UnslothGradientCheckpointing.apply
 
 
-def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
+def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
     r"""
     Only applies gradient checkpointing to trainable layers.
     """
 
     @wraps(gradient_checkpointing_func)
-    def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
+    def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
         module: "torch.nn.Module" = func.__self__
 
         if any(param.requires_grad for param in module.parameters()):
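For context, the callable returned by get_unsloth_gradient_checkpointing_func follows the same calling convention as torch.utils.checkpoint.checkpoint: it receives a layer's forward function plus its inputs, offloads the saved hidden states to CPU during the forward pass, and recomputes on the GPU during backward. A minimal usage sketch (ToyBlock is a hypothetical stand-in for a decoder layer, a CUDA device is required, and get_unsloth_gradient_checkpointing_func is assumed to be importable from the patched module):

import torch

# Hypothetical toy layer whose forward returns a 1-tuple, matching the
# unpacking `(output,) = ctx.forward_function(...)` in backward() above.
class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, hidden_states):
        return (torch.relu(self.proj(hidden_states)),)

unsloth_ckpt = get_unsloth_gradient_checkpointing_func()  # assumed to be in scope
block = ToyBlock().cuda()
x = torch.randn(2, 16, device="cuda", requires_grad=True)

(out,) = unsloth_ckpt(block.forward, x)  # hidden states are offloaded to CPU here
out.sum().backward()                     # forward is recomputed on the GPU here
print(x.grad.shape)                      # torch.Size([2, 16])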
@@ -89,7 +92,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func)
 
         return gradient_checkpointing_func(func, *args, **kwargs)
 
-    if hasattr(gradient_checkpointing_func, "__self__"):  # fix test case
+    if hasattr(gradient_checkpointing_func, "__self__"):  # fix unsloth gc test case
         custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
 
     return custom_gradient_checkpointing_func
@@ -114,7 +117,7 @@ def _gradient_checkpointing_enable(
         gradient_checkpointing_kwargs = {"use_reentrant": True}
 
     if use_unsloth_gc:
-        gradient_checkpointing_func = UnslothGradientCheckpointing.apply
+        gradient_checkpointing_func = get_unsloth_gradient_checkpointing_func()
     else:
         gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
 
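In the surrounding function (only partially shown in this hunk), the selected callable is presumably wrapped so that only trainable layers are checkpointed and then attached to the model. A rough sketch of that selection flow with hypothetical glue code (use_unsloth_gc is an illustrative flag; both helper functions are the ones defined earlier in this diff):

from functools import partial

from torch.utils.checkpoint import checkpoint

# Pick the Unsloth CPU-offloading variant or the stock PyTorch checkpoint,
# then wrap it so only trainable layers are actually checkpointed.
use_unsloth_gc = True  # illustrative flag, not the real argument plumbing

if use_unsloth_gc:
    gradient_checkpointing_func = get_unsloth_gradient_checkpointing_func()
else:
    gradient_checkpointing_func = partial(checkpoint, use_reentrant=True)

gradient_checkpointing_func = get_custom_gradient_checkpointing_func(gradient_checkpointing_func)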