Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)
Commit 1efe525df7 — [model] support yarn (#6693)
Parent: ee0b3b1e1a
Former-commit-id: 1f47b6186c267de86cbdbd47ba2adbf1f9db7f39
@@ -33,7 +33,7 @@ def preprocess_pretrain_dataset(
     text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
 
     if not data_args.packing:
-        if data_args.template == "gemma":
+        if getattr(tokenizer, "add_bos_token", False):
             text_examples = [tokenizer.bos_token + example for example in text_examples]
 
         result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=data_args.cutoff_len)
@@ -47,7 +47,7 @@ def preprocess_pretrain_dataset(
             k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
             for k, t in concatenated_examples.items()
         }
-        if data_args.template == "gemma":
+        if getattr(tokenizer, "add_bos_token", False):
             for i in range(len(result["input_ids"])):
                 result["input_ids"][i][0] = tokenizer.bos_token_id
 
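
The two hunks above replace the hard-coded gemma template check with a check on the tokenizer's own add_bos_token attribute, in both the non-packed and the packed branch. A minimal standalone sketch of the same idea (the helper name and signature are hypothetical, not part of the repository):

from typing import List, Optional


def apply_bos(
    text_examples: List[str],
    tokenizer,
    packing: bool,
    packed_input_ids: Optional[List[List[int]]] = None,
):
    """Prepend BOS only when the tokenizer itself declares add_bos_token,
    instead of special-casing a single template such as gemma."""
    if not getattr(tokenizer, "add_bos_token", False):
        return text_examples, packed_input_ids

    if not packing:
        # non-packed branch: prepend the BOS string before tokenization
        text_examples = [tokenizer.bos_token + example for example in text_examples]
    else:
        # packed branch: overwrite the first token of every fixed-size block
        for block in packed_input_ids:
            block[0] = tokenizer.bos_token_id

    return text_examples, packed_input_ids
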
@@ -201,7 +201,7 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments,
         default=True,
         metadata={"help": "Whether or not to use memory-efficient model loading."},
     )
-    rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
+    rope_scaling: Optional[Literal["linear", "dynamic", "yarn", "llama3"]] = field(
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
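
With the widened Literal, "yarn" and "llama3" become accepted values of the rope_scaling hyperparameter alongside "linear" and "dynamic". A toy stand-in for the relevant field (not the real ModelArguments, which carries many more options):

from dataclasses import dataclass, field
from typing import Literal, Optional


@dataclass
class RopeArgs:
    # mirrors only the rope_scaling field shown in the diff above
    rope_scaling: Optional[Literal["linear", "dynamic", "yarn", "llama3"]] = field(
        default=None,
        metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
    )


args = RopeArgs(rope_scaling="yarn")
assert args.rope_scaling in (None, "linear", "dynamic", "yarn", "llama3")
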
@@ -86,20 +86,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
     except Exception as e:
         raise OSError("Failed to load tokenizer.") from e
 
-    if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
-        tokenizer.model_max_length = model_args.model_max_length
-
-    if model_args.new_special_tokens is not None:
-        num_added_tokens = tokenizer.add_special_tokens(
-            dict(additional_special_tokens=model_args.new_special_tokens),
-            replace_additional_special_tokens=False,
-        )
-        logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
-        if num_added_tokens > 0 and not model_args.resize_vocab:
-            model_args.resize_vocab = True
-            logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
-
-    patch_tokenizer(tokenizer)
+    patch_tokenizer(tokenizer, model_args)
     try:
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
         patch_processor(processor, config, tokenizer, model_args)
@@ -39,6 +39,7 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         logger.warning_rank0("Current model does not support RoPE scaling.")
         return
 
+    rope_kwargs = {}
     if model_args.model_max_length is not None:
         if is_trainable and model_args.rope_scaling == "dynamic":
             logger.warning_rank0(
@@ -50,14 +51,21 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         if current_max_length and model_args.model_max_length > current_max_length:
             logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
             setattr(config, "max_position_embeddings", model_args.model_max_length)
-            scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
+            rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
         else:
             logger.warning_rank0("Input length is smaller than max length. Consider increase input length.")
-            scaling_factor = 1.0
-    else:
-        scaling_factor = 2.0
+            rope_kwargs["factor"] = 1.0
 
-    setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
+        if model_args.rope_scaling == "dynamic":
+            rope_kwargs["original_max_position_embeddings"] = current_max_length
+        elif model_args.rope_scaling == "llama3":
+            rope_kwargs["original_max_position_embeddings"] = current_max_length
+            rope_kwargs["low_freq_factor"] = 1.0
+            rope_kwargs["high_freq_factor"] = 4.0
+    else:
+        rope_kwargs["factor"] = 2.0
+
+    setattr(config, "rope_scaling", {"rope_type": model_args.rope_scaling, **rope_kwargs})
     logger.info_rank0(
-        f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}"
+        f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {rope_kwargs['factor']}."
     )
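
After this change, configure_rope assembles a rope_kwargs dict and writes it to config.rope_scaling under the HF-style "rope_type" key. A rough sketch of the resulting dict, assuming an 8192-token model stretched to 32768 tokens (build_rope_scaling is an illustrative helper, not the repository function):

import math


def build_rope_scaling(rope_scaling: str, model_max_length: int, current_max_length: int) -> dict:
    rope_kwargs = {}
    if model_max_length > current_max_length:
        # factor is the ceiling of the length ratio, as in the diff above
        rope_kwargs["factor"] = float(math.ceil(model_max_length / current_max_length))
    else:
        rope_kwargs["factor"] = 1.0

    if rope_scaling == "dynamic":
        rope_kwargs["original_max_position_embeddings"] = current_max_length
    elif rope_scaling == "llama3":
        rope_kwargs["original_max_position_embeddings"] = current_max_length
        rope_kwargs["low_freq_factor"] = 1.0
        rope_kwargs["high_freq_factor"] = 4.0

    return {"rope_type": rope_scaling, **rope_kwargs}


print(build_rope_scaling("yarn", 32768, 8192))
# {'rope_type': 'yarn', 'factor': 4.0}
print(build_rope_scaling("llama3", 32768, 8192))
# {'rope_type': 'llama3', 'factor': 4.0, 'original_max_position_embeddings': 8192,
#  'low_freq_factor': 1.0, 'high_freq_factor': 4.0}
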
@@ -53,10 +53,23 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
 
 
-def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
+def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> None:
     if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
         tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
 
+    if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
+        tokenizer.model_max_length = model_args.model_max_length
+
+    if model_args.new_special_tokens is not None:
+        num_added_tokens = tokenizer.add_special_tokens(
+            dict(additional_special_tokens=model_args.new_special_tokens),
+            replace_additional_special_tokens=False,
+        )
+        logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
+        if num_added_tokens > 0 and not model_args.resize_vocab:
+            model_args.resize_vocab = True
+            logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
+
 
 def patch_processor(
     processor: "ProcessorMixin",
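
This hunk is the other half of the load_tokenizer change: the max-length override and special-token registration now live in patch_tokenizer, which therefore receives model_args. A rough standalone sketch of the moved logic, using a SimpleNamespace in place of ModelArguments (only the fields read here), "gpt2" as an arbitrary example checkpoint, and made-up special tokens:

from types import SimpleNamespace

from transformers import AutoTokenizer

# stand-in for ModelArguments; only the three fields the moved code reads
model_args = SimpleNamespace(
    model_max_length=4096,
    new_special_tokens=["<tool_call>", "</tool_call>"],
    resize_vocab=False,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# mirror of the logic now living in patch_tokenizer
if model_args.model_max_length is not None and tokenizer.model_max_length != model_args.model_max_length:
    tokenizer.model_max_length = model_args.model_max_length

if model_args.new_special_tokens is not None:
    num_added_tokens = tokenizer.add_special_tokens(
        dict(additional_special_tokens=model_args.new_special_tokens),
        replace_additional_special_tokens=False,
    )
    if num_added_tokens > 0 and not model_args.resize_vocab:
        model_args.resize_vocab = True  # the embedding matrix must be resized later
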
@@ -16,12 +16,14 @@ import json
 import os
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
 
+from transformers.utils import is_torch_npu_available
+
 from ..chat import ChatModel
 from ..data import Role
 from ..extras.constants import PEFT_METHODS
 from ..extras.misc import torch_gc
 from ..extras.packages import is_gradio_available
-from .common import QUANTIZATION_BITS, get_save_dir
+from .common import get_save_dir, load_config
 from .locales import ALERTS
 
@@ -59,6 +61,8 @@ class WebChatModel(ChatModel):
         get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
         lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
         finetuning_type, checkpoint_path = get("top.finetuning_type"), get("top.checkpoint_path")
+        user_config = load_config()
+
         error = ""
         if self.loaded:
             error = ALERTS["err_exists"][lang]
@@ -74,26 +78,22 @@ class WebChatModel(ChatModel):
             yield error
             return
 
-        if get("top.quantization_bit") in QUANTIZATION_BITS:
-            quantization_bit = int(get("top.quantization_bit"))
-        else:
-            quantization_bit = None
-
         yield ALERTS["info_loading"][lang]
         args = dict(
             model_name_or_path=model_path,
+            cache_dir=user_config.get("cache_dir", None),
             finetuning_type=finetuning_type,
-            quantization_bit=quantization_bit,
             quantization_method=get("top.quantization_method"),
             template=get("top.template"),
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
-            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             enable_liger_kernel=(get("top.booster") == "liger_kernel"),
             infer_backend=get("infer.infer_backend"),
             infer_dtype=get("infer.infer_dtype"),
             trust_remote_code=True,
         )
 
         # checkpoints
         if checkpoint_path:
             if finetuning_type in PEFT_METHODS:  # list
                 args["adapter_name_or_path"] = ",".join(
@@ -102,6 +102,12 @@ class WebChatModel(ChatModel):
             else:  # str
                 args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)
 
+        # quantization
+        if get("top.quantization_bit") != "none":
+            args["quantization_bit"] = int(get("top.quantization_bit"))
+            args["quantization_method"] = get("top.quantization_method")
+            args["double_quantization"] = not is_torch_npu_available()
+
         super().__init__(args)
         yield ALERTS["info_loaded"][lang]
 
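
Both WebUI code paths (chat loading here, training and evaluation in the Runner hunks below) now derive the quantization arguments from the dropdown value directly instead of a QUANTIZATION_BITS whitelist. A small sketch of that pattern (quantization_args is an illustrative helper, not WebUI code):

from transformers.utils import is_torch_npu_available


def quantization_args(quantization_bit: str, quantization_method: str) -> dict:
    args = {}
    if quantization_bit != "none":
        args["quantization_bit"] = int(quantization_bit)
        args["quantization_method"] = quantization_method
        # double quantization stays off on NPU devices, as in the diff
        args["double_quantization"] = not is_torch_npu_available()
    return args


print(quantization_args("4", "bitsandbytes"))
# on a non-NPU host: {'quantization_bit': 4, 'quantization_method': 'bitsandbytes', 'double_quantization': True}
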
@@ -47,8 +47,6 @@ DEFAULT_CONFIG_DIR = "config"
 DEFAULT_DATA_DIR = "data"
 DEFAULT_SAVE_DIR = "saves"
 USER_CONFIG = "user_config.yaml"
-QUANTIZATION_BITS = ["8", "6", "5", "4", "3", "2", "1"]
-GPTQ_BITS = ["8", "4", "3", "2"]
 
 
 def get_save_dir(*paths: str) -> os.PathLike:
@@ -18,7 +18,7 @@ from ...extras.constants import PEFT_METHODS
 from ...extras.misc import torch_gc
 from ...extras.packages import is_gradio_available
 from ...train.tuner import export_model
-from ..common import GPTQ_BITS, get_save_dir
+from ..common import get_save_dir, load_config
 from ..locales import ALERTS
 
@@ -32,6 +32,9 @@ if TYPE_CHECKING:
     from ..engine import Engine
 
 
+GPTQ_BITS = ["8", "4", "3", "2"]
+
+
 def can_quantize(checkpoint_path: Union[str, List[str]]) -> "gr.Dropdown":
     if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
         return gr.Dropdown(value="none", interactive=False)
@@ -54,6 +57,7 @@ def save_model(
     export_dir: str,
     export_hub_model_id: str,
 ) -> Generator[str, None, None]:
+    user_config = load_config()
     error = ""
     if not model_name:
         error = ALERTS["err_no_model"][lang]
@@ -75,6 +79,7 @@ def save_model(
 
     args = dict(
         model_name_or_path=model_path,
+        cache_dir=user_config.get("cache_dir", None),
         finetuning_type=finetuning_type,
         template=template,
         export_dir=export_dir,
@@ -41,11 +41,11 @@ def create_top() -> Dict[str, "Component"]:
         checkpoint_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=6)
 
     with gr.Row():
-        quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=2)
-        quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=2)
-        template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2)
-        rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
-        booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=5)
+        quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True)
+        quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes")
+        template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default")
+        rope_scaling = gr.Dropdown(choices=["none", "linear", "dynamic", "yarn", "llama3"], value="none")
+        booster = gr.Dropdown(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto")
 
     model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
@@ -15,10 +15,10 @@
 LOCALES = {
     "lang": {
         "en": {
-            "label": "Lang",
+            "label": "Language",
         },
         "ru": {
-            "label": "язык",
+            "label": "Язык",
         },
         "zh": {
             "label": "语言",
@@ -30,11 +30,11 @@ LOCALES = {
     "model_name": {
         "en": {
             "label": "Model name",
-            "info": "Input the name prefix to search for the model.",
+            "info": "Input the initial name to search for the model.",
         },
         "ru": {
             "label": "Название модели",
-            "info": "Введите префикс имени для поиска модели.",
+            "info": "Введите начальное имя для поиска модели.",
         },
         "zh": {
             "label": "模型名称",
@@ -42,7 +42,7 @@ LOCALES = {
         },
         "ko": {
             "label": "모델 이름",
-            "info": "모델을 검색하기 위해 이름 접두어를 입력하세요.",
+            "info": "모델을 검색할 초기 이름을 입력하세요.",
         },
     },
     "model_path": {
@@ -129,48 +129,50 @@ LOCALES = {
     },
     "template": {
         "en": {
-            "label": "Prompt template",
-            "info": "The template used in constructing prompts.",
+            "label": "Chat template",
+            "info": "The chat template used in constructing prompts.",
         },
         "ru": {
-            "label": "Шаблон запроса",
-            "info": "Шаблон, используемый при формировании запросов.",
+            "label": "Шаблон чата",
+            "info": "Шаблон чата используемый для составления подсказок.",
         },
         "zh": {
-            "label": "提示模板",
+            "label": "对话模板",
             "info": "构建提示词时使用的模板。",
         },
         "ko": {
-            "label": "프롬프트 템플릿",
-            "info": "프롬프트 구성에 사용될 템플릿.",
+            "label": "채팅 템플릿",
+            "info": "프롬프트 작성에 사용되는 채팅 템플릿.",
         },
     },
     "rope_scaling": {
         "en": {
             "label": "RoPE scaling",
+            "info": "RoPE scaling method to use.",
         },
         "ru": {
             "label": "Масштабирование RoPE",
+            "info": "Метод масштабирования RoPE для использования.",
         },
-        "zh": {
-            "label": "RoPE 插值方法",
-        },
+        "zh": {"label": "RoPE 插值方法", "info": "RoPE 插值时使用的方法。"},
         "ko": {
             "label": "RoPE 스케일링",
+            "info": "사용할 RoPE 스케일링 방법.",
         },
     },
     "booster": {
         "en": {
             "label": "Booster",
+            "info": "Approach used to boost training speed.",
         },
         "ru": {
             "label": "Ускоритель",
+            "info": "Подход, используемый для ускорения обучения.",
         },
-        "zh": {
-            "label": "加速方式",
-        },
+        "zh": {"label": "加速方式", "info": "使用的加速方法。"},
         "ko": {
             "label": "부스터",
+            "info": "훈련 속도를 향상시키기 위해 사용된 접근 방식.",
         },
     },
     "training_stage": {
@@ -24,7 +24,7 @@ from transformers.utils import is_torch_npu_available
 from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
 from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
 from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
-from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
+from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config
 from .locales import ALERTS, LOCALES
 from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
 
@@ -120,7 +120,7 @@ class Runner:
             preprocessing_num_workers=16,
             finetuning_type=finetuning_type,
             template=get("top.template"),
-            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
             enable_liger_kernel=(get("top.booster") == "liger_kernel"),
@@ -170,7 +170,7 @@ class Runner:
                 args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
 
         # quantization
-        if get("top.quantization_bit") in QUANTIZATION_BITS:
+        if get("top.quantization_bit") != "none":
             args["quantization_bit"] = int(get("top.quantization_bit"))
             args["quantization_method"] = get("top.quantization_method")
             args["double_quantization"] = not is_torch_npu_available()
@@ -280,7 +280,7 @@ class Runner:
             finetuning_type=finetuning_type,
             quantization_method=get("top.quantization_method"),
             template=get("top.template"),
-            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
             dataset_dir=get("eval.dataset_dir"),
@@ -311,9 +311,10 @@ class Runner:
                 args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))
 
         # quantization
-        if get("top.quantization_bit") in QUANTIZATION_BITS:
+        if get("top.quantization_bit") != "none":
             args["quantization_bit"] = int(get("top.quantization_bit"))
             args["quantization_method"] = get("top.quantization_method")
+            args["double_quantization"] = not is_torch_npu_available()
 
         return args
 