Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-14 19:06:26 +08:00)

[misc] upgrade format to py39 (#7256)
@@ -29,8 +29,8 @@ def calculate_flops(
     seq_length: int = 512,
     flash_attn: str = "auto",
 ):
-    r"""
-    Calculates the flops of pre-trained models.
+    r"""Calculate the flops of pre-trained models.
+
     Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
     """
     with get_accelerator().device(0):
@@ -45,8 +45,8 @@ def calculate_lr(
     is_mistral_or_gemma: bool = False,  # mistral and gemma models opt for a smaller learning rate,
     packing: bool = False,
 ):
-    r"""
-    Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
+    r"""Calculate the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
+
     Usage:
     python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en_demo --cutoff_len 1024 --batch_size 16
     """
@@ -89,9 +89,8 @@ def calculate_lr(
     lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS)  # lr ~ sqrt(batch_size)
     lr = lr / 6.0 if is_mistral_or_gemma else lr
     print(
-        "Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective token batch size {:.2f}".format(
-            lr, valid_ratio * 100, token_batch_size
-        )
+        f"Optimal learning rate is {lr:.2e} for valid ratio% {valid_ratio * 100:.2f} "
+        f"and effective token batch size {token_batch_size:.2f}"
     )


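The square-root scaling rule reported by this print statement can be exercised on its own. Below is a minimal, standalone sketch (not part of the commit); the BASE_LR and BASE_BS constants here are illustrative placeholders, not necessarily the values defined in cal_lr.py.

import math

# Illustrative placeholders; cal_lr.py defines its own BASE_LR / BASE_BS.
BASE_LR = 3e-4
BASE_BS = 4_000_000  # reference number of tokens per optimization step


def scaled_lr(token_batch_size: float, is_mistral_or_gemma: bool = False) -> float:
    """Scale the base learning rate by sqrt(batch ratio); lr ~ sqrt(batch_size)."""
    lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS)
    return lr / 6.0 if is_mistral_or_gemma else lr


print(f"Optimal learning rate is {scaled_lr(16 * 1024):.2e}")  # e.g. batch_size 16, cutoff_len 1024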
@@ -34,9 +34,7 @@ def compute_model_flops(
     include_recompute: bool = False,
     include_flashattn: bool = False,
 ) -> int:
-    r"""
-    Calculates the FLOPs of model per forward/backward pass.
-    """
+    r"""Calculate the FLOPs of model per forward/backward pass."""
     config = AutoConfig.from_pretrained(model_name_or_path)
     hidden_size = getattr(config, "hidden_size", None)
     vocab_size = getattr(config, "vocab_size", None)
@@ -86,9 +84,7 @@ def compute_model_flops(


 def compute_device_flops(world_size: int) -> float:
-    r"""
-    Calculates the FLOPs of the device capability per second.
-    """
+    r"""Calculate the FLOPs of the device capability per second."""
     device_name = torch.cuda.get_device_name()
     if "H100" in device_name or "H800" in device_name:
         return 989 * 1e12 * world_size
@@ -114,8 +110,8 @@ def calculate_mfu(
     liger_kernel: bool = False,
     unsloth_gc: bool = False,
 ) -> float:
-    r"""
-    Calculates MFU for given model and hyper-params.
+    r"""Calculate MFU for given model and hyper-params.
+
     Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
     """
     args = {
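For context on what calculate_mfu reports: MFU (model FLOPs utilization) is the achieved FLOPs throughput divided by the device's peak throughput, with compute_device_flops supplying the peak (e.g. 989 TFLOPS per H100/H800 above). A minimal sketch with hypothetical numbers, not the script's actual implementation:

def compute_mfu(model_flops_per_step: float, step_time_s: float, device_peak_flops: float) -> float:
    """MFU = achieved FLOPs per second / peak FLOPs per second."""
    return (model_flops_per_step / step_time_s) / device_peak_flops


# Hypothetical example: 1.5e15 FLOPs per training step, 2.0 s per step, one H100 at 989 TFLOPS.
print(f"MFU: {compute_mfu(1.5e15, 2.0, 989e12):.2%}")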
@@ -13,8 +13,9 @@
 # limitations under the License.

 import json
+from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Any, Dict, Literal, Optional, Sequence
+from typing import Any, Literal, Optional

 import fire
 import torch
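The import change above reflects PEP 585, available since Python 3.9: built-in containers such as dict are subscriptable in annotations, and abstract container types come from collections.abc, so typing.Dict and typing.Sequence are no longer needed. A small standalone illustration of the new style (not taken from the repo):

from collections.abc import Sequence
from typing import Any, Optional


def batch_stats(features: Sequence[dict[str, Any]], key: Optional[str] = None) -> dict[str, int]:
    """Built-in generics (dict[...]) replace typing.Dict on Python 3.9+."""
    key = key or "input_ids"
    return {"num_examples": len(features), "max_len": max(len(f[key]) for f in features)}


print(batch_stats([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]))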
@@ -30,16 +31,12 @@ from llamafactory.model import load_model, load_tokenizer

 @dataclass
 class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
-    r"""
-    Data collator for pairwise data.
-    """
+    r"""Data collator for pairwise data."""

     train_on_prompt: bool = False

-    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
-        r"""
-        Pads batched data to the longest sequence in the batch.
-        """
+    def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, torch.Tensor]:
+        r"""Pad batched data to the longest sequence in the batch."""
         chosen_features = []
         for feature in features:
             chosen_features.append(
@@ -68,8 +65,8 @@ def calculate_ppl(
     max_samples: Optional[int] = None,
     train_on_prompt: bool = False,
 ):
-    r"""
-    Calculates the ppl on the dataset of the pre-trained models.
+    r"""Calculate the ppl on the dataset of the pre-trained models.
+
     Usage: export CUDA_VISIBLE_DEVICES=0
     python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json
     """
@@ -111,17 +108,17 @@ def calculate_ppl(
     criterion = torch.nn.CrossEntropyLoss(reduction="none")
     total_ppl = 0
     perplexities = []
-    batch: Dict[str, "torch.Tensor"]
+    batch: dict[str, torch.Tensor]
     with torch.no_grad():
         for batch in tqdm(dataloader, desc="Computing perplexities"):
             batch = batch.to(model.device)
             outputs = model(**batch)
-            shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
-            shift_labels: "torch.Tensor" = batch["labels"][..., 1:]
+            shift_logits: torch.Tensor = outputs["logits"][..., :-1, :]
+            shift_labels: torch.Tensor = batch["labels"][..., 1:]
             loss_mask = shift_labels != IGNORE_INDEX
             flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
             flatten_labels = shift_labels.contiguous().view(-1)
-            token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
+            token_logps: torch.Tensor = criterion(flatten_logits, flatten_labels)
             token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
             sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
             total_ppl += sentence_logps.exp().sum().item()
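The loop above computes a per-sequence perplexity: logits are shifted so position t predicts token t + 1, padded/prompt positions (IGNORE_INDEX) are masked out, and the mean token negative log-likelihood is exponentiated. A minimal standalone sketch of the same arithmetic on dummy tensors, assuming IGNORE_INDEX is -100 (the default ignore_index of CrossEntropyLoss):

import torch

IGNORE_INDEX = -100  # assumed; matches the usual HF-style labels convention

criterion = torch.nn.CrossEntropyLoss(reduction="none")

# Dummy batch: 2 sequences, 5 positions, vocabulary of 11 tokens.
logits = torch.randn(2, 5, 11)
labels = torch.tensor([[2, 5, 7, IGNORE_INDEX, IGNORE_INDEX], [1, 3, 4, 6, 9]])

# Shift so position t predicts token t + 1.
shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]
loss_mask = shift_labels != IGNORE_INDEX

# Per-token negative log-likelihood; ignored positions yield 0 loss and are masked out below.
token_nll = criterion(
    shift_logits.contiguous().view(-1, shift_logits.size(-1)),
    shift_labels.contiguous().view(-1),
).view(shift_labels.size(0), -1)

# Perplexity per sequence = exp(mean NLL over non-ignored tokens).
sentence_nll = (token_nll * loss_mask).sum(-1) / loss_mask.sum(-1)
print(sentence_nll.exp())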
@@ -29,8 +29,8 @@ def length_cdf(
     template: str = "default",
     interval: int = 1000,
 ):
-    r"""
-    Calculates the distribution of the input lengths in the dataset.
+    r"""Calculate the distribution of the input lengths in the dataset.
+
     Usage: export CUDA_VISIBLE_DEVICES=0
     python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default
     """