mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-04 18:02:19 +08:00
	tiny fix
Former-commit-id: 76177039c8f9ef5a63724a339dae6195d89fa215
parent 3259397f89
commit 3cbc9109ea
@@ -18,6 +18,7 @@ import os

 import fire
 import torch
+import torch.distributed as dist
 from transformers import AutoConfig

 from llamafactory.train.tuner import run_exp
@@ -28,7 +29,7 @@ BASE = 2  # gemm (add + mul)

 def compute_model_flops(
     model_name_or_path: str,
-    batch_size: int,
+    total_batch_size: int,
     seq_length: int,
     include_backward: bool = True,
     include_recompute: bool = False,
@@ -48,7 +49,7 @@ def compute_model_flops(

     # mlp module
     mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size  # up, gate, down
-    mlp_flops = batch_size * seq_length * num_hidden_layers * mlp_flops_per_token
+    mlp_flops = total_batch_size * seq_length * num_hidden_layers * mlp_flops_per_token

     # attn projector module
     q_flops_per_token = BASE * hidden_size * hidden_size
@@ -56,15 +57,15 @@ def compute_model_flops(
     k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
     v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
     attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
-    attn_proj_flops = batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token
+    attn_proj_flops = total_batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token

     # attn sdpa module
     sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length  # (q * k^T) * v
-    sdpa_flops = batch_size * num_hidden_layers * sdpa_flops_per_layer
+    sdpa_flops = total_batch_size * num_hidden_layers * sdpa_flops_per_layer

     # embedding module
     embedding_flops_per_token = hidden_size * vocab_size
-    embedding_flops = batch_size * seq_length * embedding_flops_per_token
+    embedding_flops = total_batch_size * seq_length * embedding_flops_per_token
     if tie_word_embeddings is False:
         embedding_flops *= 2
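For context, these hunks keep the script's GEMM accounting (BASE = 2 FLOPs per multiply-add) and only swap the per-device batch size for the global one. A minimal standalone sketch of the same forward-pass estimate, using hypothetical LLaMA-7B-like config values rather than numbers read from AutoConfig, could look like this:

BASE = 2  # one multiply-add in a GEMM counted as 2 FLOPs, as in the script

# Hypothetical config values for illustration only (roughly LLaMA-2-7B sized).
hidden_size, intermediate_size = 4096, 11008
num_hidden_layers, num_attention_heads, num_key_value_heads = 32, 32, 32
vocab_size, tie_word_embeddings = 32000, False
total_batch_size, seq_length = 64, 1024  # global batch = per-device batch * world size

# mlp module: up, gate and down projections per token, per layer
mlp_flops = total_batch_size * seq_length * num_hidden_layers * 3 * BASE * hidden_size * intermediate_size

# attn projector module: q/o are full-width, k/v shrink with the GQA ratio
qo_flops_per_token = 2 * BASE * hidden_size * hidden_size
kv_flops_per_token = 2 * BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
attn_proj_flops = total_batch_size * seq_length * num_hidden_layers * (qo_flops_per_token + kv_flops_per_token)

# attn sdpa module: (q @ k^T) @ v is quadratic in the sequence length
sdpa_flops = total_batch_size * num_hidden_layers * 2 * BASE * hidden_size * seq_length * seq_length

# embedding module: doubled when input and output embeddings are untied
embedding_flops = total_batch_size * seq_length * hidden_size * vocab_size
if not tie_word_embeddings:
    embedding_flops *= 2

forward_flops = mlp_flops + attn_proj_flops + sdpa_flops + embedding_flops
print(f"forward FLOPs per optimizer step: {forward_flops:.3e}")

The surrounding function's include_backward and include_recompute flags presumably scale this forward count (a backward pass costs roughly twice the forward), which is why only the batch-size factor needed to change in this commit.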
@@ -85,17 +86,19 @@ def compute_model_flops(
     return total_flops


-def compute_device_flops() -> float:
+def compute_device_flops(world_size: int) -> float:
     r"""
     Calculates the FLOPs of the device capability per second.
     """
     device_name = torch.cuda.get_device_name()
-    device_count = torch.cuda.device_count()
     if "H100" in device_name or "H800" in device_name:
-        return 989 * 1e12 * device_count
+        return 989 * 1e12 * world_size
     elif "A100" in device_name or "A800" in device_name:
-        return 312 * 1e12 * device_count
+        return 312 * 1e12 * world_size
     elif "V100" in device_name:
-        return 125 * 1e12 * device_count
+        return 125 * 1e12 * world_size
     elif "4090" in device_name:
-        return 98 * 1e12 * device_count
+        return 98 * 1e12 * world_size
     else:
         raise NotImplementedError("Device not supported: {}.".format(device_name))
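The signature change matters for multi-node runs: torch.cuda.device_count() only sees GPUs on the local machine, while the distributed world size counts every participating device. A tiny illustration using the A100 peak hard-coded above:

A100_PEAK_FLOPS = 312 * 1e12  # per-device peak used by the script for A100/A800

# Aggregate peak scales with the number of participating devices, not the local count.
for world_size in (1, 8, 16):  # e.g. single GPU, one 8-GPU node, two 8-GPU nodes
    print(f"{world_size} device(s): {A100_PEAK_FLOPS * world_size:.2e} FLOP/s peak")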
@@ -140,10 +143,16 @@ def calculate_mfu(
     with open(os.path.join("saves", "test_mfu", "all_results.json"), "r", encoding="utf-8") as f:
         result = json.load(f)

+    if dist.is_initialized():
+        world_size = dist.get_world_size()
+    else:
+        world_size = 1
+
+    total_batch_size = batch_size * world_size
     mfu_value = (
         result["train_steps_per_second"]
-        * compute_model_flops(model_name_or_path, batch_size, seq_length)
-        / compute_device_flops()
+        * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
+        / compute_device_flops(world_size)
     )
     print("MFU: {:.2f}%".format(mfu_value * 100))
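Putting the helpers together, MFU is achieved model FLOPs per second divided by the aggregate hardware peak; the global batch size must therefore be paired with the global device count, which is exactly what this hunk fixes. A sketch of the arithmetic with made-up numbers (the variable names mirror the script, the values are not from a real run):

# Hypothetical values for illustration only.
train_steps_per_second = 0.5             # would come from all_results.json
model_flops_per_step = 2.8e15            # compute_model_flops(..., total_batch_size, seq_length)
world_size = 8
device_peak_flops = 989e12 * world_size  # compute_device_flops(world_size) on H100/H800

mfu_value = train_steps_per_second * model_flops_per_step / device_peak_flops
print("MFU: {:.2f}%".format(mfu_value * 100))  # ~17.7% with these toy numbers

The remaining hunks below belong to the gradient-checkpointing utilities rather than the MFU script.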
@@ -21,7 +21,7 @@
 import inspect
 from functools import partial, wraps
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

 import torch
@@ -38,48 +38,51 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


-class UnslothGradientCheckpointing(torch.autograd.Function):
-    r"""
-    Saves VRAM by smartly offloading to RAM.
-    """
+def get_unsloth_gradient_checkpointing_func() -> Callable:
+    class UnslothGradientCheckpointing(torch.autograd.Function):
+        r"""
+        Saves VRAM by smartly offloading to RAM.
+        """

-    @staticmethod
-    @torch.cuda.amp.custom_fwd
-    def forward(
-        ctx: "torch.autograd.Function",
-        forward_function: "torch.Module",
-        hidden_states: "torch.Tensor",
-        *args: Union["torch.Tensor", Any],
-    ) -> "torch.Tensor":
-        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
-        with torch.no_grad():
-            output = forward_function(hidden_states, *args)
+        @staticmethod
+        @torch.cuda.amp.custom_fwd
+        def forward(
+            ctx: "torch.autograd.Function",
+            forward_function: "torch.Module",
+            hidden_states: "torch.Tensor",
+            *args: Union["torch.Tensor", Any],
+        ) -> "torch.Tensor":
+            saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
+            with torch.no_grad():
+                output = forward_function(hidden_states, *args)

-        ctx.save_for_backward(saved_hidden_states)
-        ctx.forward_function = forward_function
-        ctx.args = args
-        return output
+            ctx.save_for_backward(saved_hidden_states)
+            ctx.forward_function = forward_function
+            ctx.args = args
+            return output

-    @staticmethod
-    @torch.cuda.amp.custom_bwd
-    def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
-        (hidden_states,) = ctx.saved_tensors
-        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
-        hidden_states.requires_grad_(True)
-        with torch.enable_grad():
-            (output,) = ctx.forward_function(hidden_states, *ctx.args)
+        @staticmethod
+        @torch.cuda.amp.custom_bwd
+        def backward(ctx: "torch.autograd.Function", grad_output: "torch.Tensor") -> "torch.Tensor":
+            (hidden_states,) = ctx.saved_tensors
+            hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
+            hidden_states.requires_grad_(True)
+            with torch.enable_grad():
+                (output,) = ctx.forward_function(hidden_states, *ctx.args)

-        torch.autograd.backward(output, grad_output)
-        return (None, hidden_states.grad) + (None,) * len(ctx.args)
+            torch.autograd.backward(output, grad_output)
+            return (None, hidden_states.grad) + (None,) * len(ctx.args)
+
+    return UnslothGradientCheckpointing.apply


-def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
+def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable) -> Callable:
     r"""
     Only applies gradient checkpointing to trainable layers.
     """

     @wraps(gradient_checkpointing_func)
-    def custom_gradient_checkpointing_func(func, *args: Union["torch.Tensor", Any], **kwargs):
+    def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
         module: "torch.nn.Module" = func.__self__

         if any(param.requires_grad for param in module.parameters()):
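The bulk of this hunk is a re-indentation: the torch.autograd.Function is no longer defined at module import time but inside a factory that hands back its .apply. A stripped-down, device-agnostic sketch of the same offload-and-recompute pattern (not the project's exact implementation; the torch.cuda.amp decorators and Unsloth specifics are omitted, and all names here are illustrative):

from typing import Any, Callable, Union

import torch


def get_offload_gradient_checkpointing_func() -> Callable:
    # Building the class inside a factory means each caller gets a fresh Function
    # object and only ever sees the returned .apply callable.
    class OffloadGradientCheckpointing(torch.autograd.Function):
        @staticmethod
        def forward(ctx, forward_function: Callable, hidden_states: torch.Tensor, *args: Union[torch.Tensor, Any]):
            # Offload the checkpointed activation to CPU instead of keeping it on device.
            ctx.save_for_backward(hidden_states.to("cpu", non_blocking=True))
            ctx.forward_function, ctx.args = forward_function, args
            with torch.no_grad():
                return forward_function(hidden_states, *args)

        @staticmethod
        def backward(ctx, grad_output: torch.Tensor):
            (hidden_states,) = ctx.saved_tensors
            # Bring the activation back and recompute the forward pass under grad.
            hidden_states = hidden_states.to(grad_output.device, non_blocking=True).detach().requires_grad_(True)
            with torch.enable_grad():
                output = ctx.forward_function(hidden_states, *ctx.args)
            torch.autograd.backward(output, grad_output)
            return (None, hidden_states.grad) + (None,) * len(ctx.args)

    return OffloadGradientCheckpointing.apply


# Toy usage on CPU: recompute a linear layer's forward during backward.
layer = torch.nn.Linear(4, 4)
checkpoint_func = get_offload_gradient_checkpointing_func()
x = torch.randn(2, 4, requires_grad=True)
y = checkpoint_func(layer.forward, x)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 4])

Returning the .apply callable keeps the public surface identical to before, which is why _gradient_checkpointing_enable below only needs a one-line change.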
@@ -89,7 +92,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):

         return gradient_checkpointing_func(func, *args, **kwargs)

-    if hasattr(gradient_checkpointing_func, "__self__"):  # fix test case
+    if hasattr(gradient_checkpointing_func, "__self__"):  # fix unsloth gc test case
         custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__

     return custom_gradient_checkpointing_func
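The comment tweak sits next to the actual point of this block: functools.wraps copies __name__, __doc__ and friends but not __self__, so when the wrapped checkpointing function is a bound method the wrapper re-attaches __self__ by hand (apparently so the Unsloth gradient-checkpointing test case can still find it). A small demonstration of that gap, with hypothetical names:

from functools import wraps


class Checkpointer:
    def checkpoint(self, func, *args):
        return func(*args)


bound = Checkpointer().checkpoint  # a bound method, so it carries __self__


@wraps(bound)
def wrapper(func, *args, **kwargs):
    return bound(func, *args, **kwargs)


print(hasattr(wrapper, "__self__"))        # False: wraps() does not copy __self__
if hasattr(bound, "__self__"):
    wrapper.__self__ = bound.__self__      # the manual fix mirrored in the hunk above
print(wrapper.__self__ is bound.__self__)  # True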
@@ -114,7 +117,7 @@ def _gradient_checkpointing_enable(
         gradient_checkpointing_kwargs = {"use_reentrant": True}

     if use_unsloth_gc:
-        gradient_checkpointing_func = UnslothGradientCheckpointing.apply
+        gradient_checkpointing_func = get_unsloth_gradient_checkpointing_func()
     else:
         gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)
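This last hunk is the only call-site change the factory needs. For orientation, a hedged sketch of how such an enable function is typically bound onto a model elsewhere in the module (the helper name and wiring here are assumptions, not part of this diff; only _gradient_checkpointing_enable and the partial/MethodType imports appear above):

from functools import partial
from types import MethodType


def apply_custom_gradient_checkpointing(model, use_unsloth_gc: bool = False) -> None:
    # Hypothetical wiring: bind the patched enable function onto the model so that
    # model.gradient_checkpointing_enable(...) routes through the custom
    # (optionally Unsloth-style, CPU-offloading) checkpointing function.
    enable_func = partial(_gradient_checkpointing_enable, use_unsloth_gc=use_unsloth_gc)
    model.gradient_checkpointing_enable = MethodType(enable_func, model)
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})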