format style

This commit is contained in:
hiyouga
2024-01-20 20:15:56 +08:00
parent f6d6e00337
commit 638234ceee
73 changed files with 1492 additions and 2325 deletions

View File

@@ -1,31 +1,33 @@
import gc
import os
import torch
from typing import TYPE_CHECKING, Dict, Tuple
import torch
from peft import PeftModel
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
from transformers.utils import (
WEIGHTS_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_NAME,
is_torch_bf16_gpu_available,
is_torch_cuda_available,
is_torch_npu_available,
is_torch_xpu_available
is_torch_xpu_available,
)
from peft import PeftModel
from .constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from .logging import get_logger
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
try:
_is_bf16_available = is_torch_bf16_gpu_available()
except:
except Exception:
_is_bf16_available = False
if TYPE_CHECKING:
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.hparams import ModelArguments
@@ -36,6 +38,7 @@ class AverageMeter:
r"""
Computes and stores the average and current value.
"""
def __init__(self):
self.reset()
@@ -75,9 +78,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
def fix_valuehead_checkpoint(
model: "AutoModelForCausalLMWithValueHead",
output_dir: str,
safe_serialization: bool
model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
r"""
The model is already unwrapped.
@@ -95,6 +96,7 @@ def fix_valuehead_checkpoint(
if safe_serialization:
from safetensors import safe_open
from safetensors.torch import save_file
path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
@@ -112,9 +114,7 @@ def fix_valuehead_checkpoint(
os.remove(path_to_checkpoint)
model.pretrained_model.save_pretrained(
output_dir,
state_dict=decoder_state_dict or None,
safe_serialization=safe_serialization
output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
)
if safe_serialization:
@@ -182,11 +182,10 @@ def try_download_model_from_ms(model_args: "ModelArguments") -> None:
try:
from modelscope import snapshot_download
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
model_args.model_name_or_path = snapshot_download(
model_args.model_name_or_path,
revision=revision,
cache_dir=model_args.cache_dir
model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir
)
except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`")