mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-15 19:30:36 +08:00
format style
@@ -1,31 +1,33 @@
 import gc
 import os
-import torch
 from typing import TYPE_CHECKING, Dict, Tuple
+
+import torch
+from peft import PeftModel
 from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
 from transformers.utils import (
-    WEIGHTS_NAME,
     SAFE_WEIGHTS_NAME,
+    WEIGHTS_NAME,
     is_torch_bf16_gpu_available,
     is_torch_cuda_available,
     is_torch_npu_available,
-    is_torch_xpu_available
+    is_torch_xpu_available,
 )
-from peft import PeftModel
 
-from .constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
+from .constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from .logging import get_logger
 
 
 _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
 try:
     _is_bf16_available = is_torch_bf16_gpu_available()
-except:
+except Exception:
     _is_bf16_available = False
 
 
 if TYPE_CHECKING:
     from trl import AutoModelForCausalLMWithValueHead
+
     from llmtuner.hparams import ModelArguments
 
 
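Note on the `except:` -> `except Exception:` change above: a bare `except` also catches `KeyboardInterrupt` and `SystemExit`, so an interrupt during the bf16 probe would be silently recorded as `_is_bf16_available = False`. Narrowing to `Exception` keeps the fallback while letting interrupts propagate. Below is a minimal sketch of how flags like these are typically consumed to choose a compute dtype; the helper name `pick_compute_dtype` is hypothetical and not part of this file.

import torch
from transformers.utils import (
    is_torch_bf16_gpu_available,
    is_torch_cuda_available,
    is_torch_npu_available,
)

_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
try:
    _is_bf16_available = is_torch_bf16_gpu_available()
except Exception:  # narrower than a bare except, which would also swallow KeyboardInterrupt
    _is_bf16_available = False


def pick_compute_dtype() -> torch.dtype:  # hypothetical helper, for illustration only
    # Prefer bf16 (wider dynamic range, no loss scaling), then fp16, then fp32.
    if _is_bf16_available:
        return torch.bfloat16
    if _is_fp16_available:
        return torch.float16
    return torch.float32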
@@ -36,6 +38,7 @@ class AverageMeter:
     r"""
     Computes and stores the average and current value.
     """
+
     def __init__(self):
         self.reset()
 
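For context, the `AverageMeter` touched above follows the conventional meter pattern; the `reset` and `update` bodies below are the standard implementation and are assumed here, since this hunk only shows the docstring and `__init__`.

class AverageMeter:
    r"""
    Computes and stores the average and current value.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        # Track the latest value and a running weighted mean, e.g. of the training loss.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count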
@@ -75,9 +78,7 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
 
 
 def fix_valuehead_checkpoint(
-    model: "AutoModelForCausalLMWithValueHead",
-    output_dir: str,
-    safe_serialization: bool
+    model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
 ) -> None:
     r"""
     The model is already unwrapped.
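The signature reflow above is behavior-neutral: `fix_valuehead_checkpoint` still splits a value-head checkpoint into decoder weights and `v_head.*` weights. A sketch of that split, with key prefixes assumed from TRL's `AutoModelForCausalLMWithValueHead` naming rather than shown in this hunk:

from typing import Dict, Tuple

import torch


def split_valuehead_state_dict(
    state_dict: Dict[str, torch.Tensor]
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:  # illustrative helper
    decoder_state_dict, v_head_state_dict = {}, {}
    for name, tensor in state_dict.items():
        if name.startswith("v_head."):
            v_head_state_dict[name] = tensor  # written to V_HEAD_(SAFE_)WEIGHTS_NAME
        else:
            # TRL wraps the base model, so decoder keys carry a "pretrained_model." prefix.
            decoder_state_dict[name.replace("pretrained_model.", "", 1)] = tensor
    return decoder_state_dict, v_head_state_dict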
@@ -95,6 +96,7 @@ def fix_valuehead_checkpoint(
     if safe_serialization:
         from safetensors import safe_open
         from safetensors.torch import save_file
+
         path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
         with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
             state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
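The `safe_open(..., framework="pt", device="cpu")` pattern above loads every tensor of a safetensors checkpoint onto CPU. A self-contained round trip using the same two imports, to make the read side concrete (the file name and tensor contents are made up):

import torch
from safetensors import safe_open
from safetensors.torch import save_file

tensors = {"v_head.summary.weight": torch.zeros(1, 8), "v_head.summary.bias": torch.zeros(1)}
save_file(tensors, "value_head.safetensors", metadata={"format": "pt"})

with safe_open("value_head.safetensors", framework="pt", device="cpu") as f:
    restored = {key: f.get_tensor(key) for key in f.keys()}
assert restored.keys() == tensors.keys()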
@@ -112,9 +114,7 @@ def fix_valuehead_checkpoint(
 
     os.remove(path_to_checkpoint)
     model.pretrained_model.save_pretrained(
-        output_dir,
-        state_dict=decoder_state_dict or None,
-        safe_serialization=safe_serialization
+        output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
     )
 
     if safe_serialization:
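One detail worth noting in the joined call: `state_dict=decoder_state_dict or None` relies on an empty dict being falsy, so a checkpoint with no decoder tensors falls back to `save_pretrained`'s default of serializing the model's own state dict.

# Tiny demonstration of the `or None` idiom used above:
decoder_state_dict = {}
print(decoder_state_dict or None)  # prints None, so save_pretrained uses its default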
@@ -182,11 +182,10 @@ def try_download_model_from_ms(model_args: "ModelArguments") -> None:
 
     try:
         from modelscope import snapshot_download
+
         revision = "master" if model_args.model_revision == "main" else model_args.model_revision
         model_args.model_name_or_path = snapshot_download(
-            model_args.model_name_or_path,
-            revision=revision,
-            cache_dir=model_args.cache_dir
+            model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir
         )
     except ImportError:
         raise ImportError("Please install modelscope via `pip install modelscope -U`")
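The `revision` remap above exists because ModelScope's default branch is named `master`, while the Hugging Face Hub uses `main`. A minimal standalone download mirroring the call; the model id is only an example:

# Sketch, assuming `modelscope` is installed (`pip install modelscope -U`).
from modelscope import snapshot_download

local_path = snapshot_download("qwen/Qwen-7B", revision="master", cache_dir=None)
print(local_path)  # local directory containing the downloaded snapshot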