Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-16 20:00:36 +08:00)

modify style
@@ -3,9 +3,9 @@ from typing import TYPE_CHECKING, Any, Dict, Union
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
-    AutoTokenizer,
-    AutoProcessor,
     AutoModelForVision2Seq,
+    AutoProcessor,
+    AutoTokenizer,
 )
 from trl import AutoModelForCausalLMWithValueHead
 
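For context, a minimal sketch (not part of this commit) of how these Auto classes are typically paired when loading a checkpoint: text-only models go through AutoTokenizer/AutoModelForCausalLM, while vision-language checkpoints use AutoProcessor/AutoModelForVision2Seq. The model identifiers below are placeholders.

from transformers import AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer

# Text-only LLM: tokenizer + causal LM head (placeholder model id).
tokenizer = AutoTokenizer.from_pretrained("some-org/text-model")
model = AutoModelForCausalLM.from_pretrained("some-org/text-model")

# Vision-language model: processor (tokenizer + image processor) + vision-to-text head.
processor = AutoProcessor.from_pretrained("some-org/vl-model")
vl_model = AutoModelForVision2Seq.from_pretrained("some-org/vl-model")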
@@ -17,6 +17,7 @@ from .utils.misc import load_valuehead_params, register_autoclass
 from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
 from .utils.unsloth import load_unsloth_pretrained_model
 
 
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
@@ -42,7 +43,7 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
 
 def load_tokenizer(
     model_args: "ModelArguments",
-) -> Dict[str, Union["PreTrainedTokenizer", "AutoProcesser"]]:
+) -> Dict[str, Union["PreTrainedTokenizer", "AutoProcessor"]]:
     r"""
     Loads pretrained tokenizer.
 
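A hedged usage sketch of the annotated return value: judging from the signature above and the tokenizer_modules dict in the next hunk, load_tokenizer appears to return a dict keyed by "tokenizer" and "processor" (the latter None for text-only models). The caller below is hypothetical and assumes a populated ModelArguments instance.

tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]  # PreTrainedTokenizer
processor = tokenizer_module["processor"]  # processor for vision models, otherwise None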
@@ -70,14 +71,10 @@ def load_tokenizer(
             dict(additional_special_tokens=model_args.new_special_tokens),
             replace_additional_special_tokens=False,
         )
-        logger.info(
-            "Add {} to special tokens.".format(",".join(model_args.new_special_tokens))
-        )
+        logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
         if num_added_tokens > 0 and not model_args.resize_vocab:
             model_args.resize_vocab = True
-            logger.warning(
-                "New tokens have been added, changed `resize_vocab` to True."
-            )
+            logger.warning("New tokens have been added, changed `resize_vocab` to True.")
 
     patch_tokenizer(tokenizer)
     tokenizer_modules = {"tokenizer": tokenizer, "processor": None}
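For reference, a standalone sketch of the transformers calls this block builds on: add_special_tokens returns the number of newly registered tokens, and when it is positive the embedding matrix has to grow to cover the new ids, which is presumably what setting resize_vocab triggers downstream. The checkpoint and token strings below are placeholders; replace_additional_special_tokens requires a reasonably recent transformers release.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Register extra special tokens without dropping the ones already present.
num_added_tokens = tokenizer.add_special_tokens(
    dict(additional_special_tokens=["<|tool|>", "<|observation|>"]),
    replace_additional_special_tokens=False,
)

if num_added_tokens > 0:
    # New token ids fall outside the original embedding table, so it must be resized.
    model.resize_token_embeddings(len(tokenizer))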
@@ -174,10 +171,8 @@ def load_model(
 
     trainable_params, all_param = count_parameters(model)
     if is_trainable:
-        param_stats = (
-            "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
-                trainable_params, all_param, 100 * trainable_params / all_param
-            )
-        )
+        param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
+            trainable_params, all_param, 100 * trainable_params / all_param
+        )
     else:
         param_stats = "all params: {:d}".format(all_param)
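For reference, a minimal sketch of what a count_parameters helper along these lines can look like. This is not the project's implementation (the real one may special-case quantized or ZeRO-partitioned weights); it only illustrates the trainable/all split and the percentage printed above.

from typing import Tuple

import torch.nn as nn


def count_parameters(model: nn.Module) -> Tuple[int, int]:
    # Return (trainable, total) parameter counts.
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    return trainable_params, all_param


# Usage mirroring the log line above:
# trainable_params, all_param = count_parameters(model)
# print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
#     trainable_params, all_param, 100 * trainable_params / all_param))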