mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-18 04:40:35 +08:00
[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -34,9 +34,7 @@ def compute_model_flops(
|
||||
include_recompute: bool = False,
|
||||
include_flashattn: bool = False,
|
||||
) -> int:
|
||||
r"""
|
||||
Calculates the FLOPs of model per forward/backward pass.
|
||||
"""
|
||||
r"""Calculate the FLOPs of model per forward/backward pass."""
|
||||
config = AutoConfig.from_pretrained(model_name_or_path)
|
||||
hidden_size = getattr(config, "hidden_size", None)
|
||||
vocab_size = getattr(config, "vocab_size", None)
|
||||
@@ -86,9 +84,7 @@ def compute_model_flops(
|
||||
|
||||
|
||||
def compute_device_flops(world_size: int) -> float:
|
||||
r"""
|
||||
Calculates the FLOPs of the device capability per second.
|
||||
"""
|
||||
r"""Calculate the FLOPs of the device capability per second."""
|
||||
device_name = torch.cuda.get_device_name()
|
||||
if "H100" in device_name or "H800" in device_name:
|
||||
return 989 * 1e12 * world_size
|
||||
@@ -114,8 +110,8 @@ def calculate_mfu(
|
||||
liger_kernel: bool = False,
|
||||
unsloth_gc: bool = False,
|
||||
) -> float:
|
||||
r"""
|
||||
Calculates MFU for given model and hyper-params.
|
||||
r"""Calculate MFU for given model and hyper-params.
|
||||
|
||||
Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
|
||||
"""
|
||||
args = {
|
||||
|
||||
Reference in New Issue
Block a user