[model] support yarn (#6693)

Author: hoshi-hiyouga
Date: 2025-01-18 13:56:09 +08:00
Committed by: GitHub
Parent: e4046bdd1f
Commit: 87d685b59f
Former-commit-id: 8c412abc44a4c61b683465e36c6288580d980250

11 changed files with 84 additions and 64 deletions
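The web UI runner previously whitelisted only the "linear" and "dynamic" RoPE scaling types; this commit relaxes that check so newly supported types such as YaRN pass straight through to the training and evaluation arguments. For orientation, YaRN scaling on a Hugging Face model config is typically expressed as a rope_scaling dict; the snippet below is only an illustrative sketch (the model id, the scaling factor, and the exact key name, which moved from "type" toward "rope_type" across transformers releases, are assumptions, not part of this commit):

from transformers import AutoConfig

# Illustrative only: enable YaRN-style RoPE scaling on a config object.
config = AutoConfig.from_pretrained("Qwen/Qwen2-7B")  # placeholder model id
config.rope_scaling = {"type": "yarn", "factor": 4.0}  # some versions expect "rope_type"
# The usable context window grows roughly by the scaling factor.
config.max_position_embeddings = int(config.max_position_embeddings * 4.0)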


@@ -24,7 +24,7 @@ from transformers.utils import is_torch_npu_available
 from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
 from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
 from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
-from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
+from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config
 from .locales import ALERTS, LOCALES
 from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
@@ -120,7 +120,7 @@ class Runner:
             preprocessing_num_workers=16,
             finetuning_type=finetuning_type,
             template=get("top.template"),
-            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
             enable_liger_kernel=(get("top.booster") == "liger_kernel"),
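With the condition relaxed to compare against the dropdown's "none" default, every concrete scaling choice is forwarded unchanged. A minimal sketch of the behavioral difference, assuming the dropdown now also offers "yarn" and "llama3" (the exact choice list is an assumption):

# Before vs. after for each assumed dropdown value.
for value in ("none", "linear", "dynamic", "yarn", "llama3"):
    before = value if value in ["linear", "dynamic"] else None  # yarn/llama3 silently dropped
    after = value if value != "none" else None                  # any non-default value kept
    print(f"{value!r}: before={before!r}, after={after!r}")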
@@ -170,7 +170,7 @@ class Runner:
                 args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))

         # quantization
-        if get("top.quantization_bit") in QUANTIZATION_BITS:
+        if get("top.quantization_bit") != "none":
             args["quantization_bit"] = int(get("top.quantization_bit"))
             args["quantization_method"] = get("top.quantization_method")
             args["double_quantization"] = not is_torch_npu_available()
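The quantization branch now keys off the same "none" sentinel instead of the QUANTIZATION_BITS whitelist, but still builds the same arguments. For illustration only (the concrete values are made up), selecting 4-bit bitsandbytes on a GPU machine would contribute roughly:

# Hypothetical result of the branch above on a non-NPU machine.
args = {}  # illustrative stand-in for the runner's argument dict
args.update({
    "quantization_bit": 4,                  # int(get("top.quantization_bit"))
    "quantization_method": "bitsandbytes",  # get("top.quantization_method")
    "double_quantization": True,            # not is_torch_npu_available()
})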
@@ -280,7 +280,7 @@ class Runner:
             finetuning_type=finetuning_type,
             quantization_method=get("top.quantization_method"),
             template=get("top.template"),
-            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
             dataset_dir=get("eval.dataset_dir"),
@@ -311,9 +311,10 @@ class Runner:
                 args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))

         # quantization
-        if get("top.quantization_bit") in QUANTIZATION_BITS:
+        if get("top.quantization_bit") != "none":
             args["quantization_bit"] = int(get("top.quantization_bit"))
             args["quantization_method"] = get("top.quantization_method")
             args["double_quantization"] = not is_torch_npu_available()
         return args