Mirror of https://github.com/hiyouga/LLaMA-Factory.git
[model] support yarn (#6693)
Former-commit-id: 8c412abc44a4c61b683465e36c6288580d980250
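In effect, the Web UI no longer restricts the rope_scaling dropdown to "linear"/"dynamic": any selection other than "none" (notably "yarn") is forwarded to the training arguments, as the diff below shows. A minimal sketch of the before/after behaviour, assuming a plain string in place of the Runner's get("top.rope_scaling") accessor:

    # Minimal sketch, not the repository's code: "selected" stands in for
    # the value of the rope_scaling dropdown, i.e. get("top.rope_scaling").
    def resolve_rope_scaling(selected: str):
        # Before this commit: only the two classic modes were forwarded.
        #   return selected if selected in ["linear", "dynamic"] else None
        # After: anything except "none" (including "yarn") is forwarded.
        return selected if selected != "none" else None

    assert resolve_rope_scaling("yarn") == "yarn"
    assert resolve_rope_scaling("none") is None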
@@ -24,7 +24,7 @@ from transformers.utils import is_torch_npu_available
 from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
 from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
 from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
-from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
+from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config
 from .locales import ALERTS, LOCALES
 from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd

@@ -120,7 +120,7 @@ class Runner:
 preprocessing_num_workers=16,
 finetuning_type=finetuning_type,
 template=get("top.template"),
-rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
 flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
 use_unsloth=(get("top.booster") == "unsloth"),
 enable_liger_kernel=(get("top.booster") == "liger_kernel"),
@@ -170,7 +170,7 @@ class Runner:
 args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))

 # quantization
-if get("top.quantization_bit") in QUANTIZATION_BITS:
+if get("top.quantization_bit") != "none":
 args["quantization_bit"] = int(get("top.quantization_bit"))
 args["quantization_method"] = get("top.quantization_method")
 args["double_quantization"] = not is_torch_npu_available()
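The quantization branch above gets the same relaxation: the bit-width settings are applied whenever the dropdown is not "none", which is why the QUANTIZATION_BITS import is dropped. A rough, self-contained sketch of the resulting branch (function name and signature are illustrative, not the repository's code):

    # Illustrative sketch: "selected_bit" and "method" stand in for
    # get("top.quantization_bit") and get("top.quantization_method").
    def apply_quantization(args: dict, selected_bit: str, method: str) -> dict:
        if selected_bit != "none":
            args["quantization_bit"] = int(selected_bit)
            args["quantization_method"] = method
        return args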
@@ -280,7 +280,7 @@ class Runner:
 finetuning_type=finetuning_type,
 quantization_method=get("top.quantization_method"),
 template=get("top.template"),
-rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
 flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
 use_unsloth=(get("top.booster") == "unsloth"),
 dataset_dir=get("eval.dataset_dir"),
@@ -311,9 +311,10 @@ class Runner:
 args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))

 # quantization
-if get("top.quantization_bit") in QUANTIZATION_BITS:
+if get("top.quantization_bit") != "none":
 args["quantization_bit"] = int(get("top.quantization_bit"))
 args["quantization_method"] = get("top.quantization_method")
 args["double_quantization"] = not is_torch_npu_available()

 return args
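For reference, a hypothetical excerpt of the argument dict the Runner would emit after this change when "yarn" is selected; every concrete value below (model path, method name, bit-width) is illustrative only:

    # Hypothetical excerpt of the generated training arguments.
    example_args = {
        "model_name_or_path": "meta-llama/Llama-3.1-8B-Instruct",  # illustrative
        "finetuning_type": "lora",                # illustrative
        "rope_scaling": "yarn",                   # forwarded because it is not "none"
        "flash_attn": "auto",
        "quantization_bit": 4,                    # illustrative
        "quantization_method": "bitsandbytes",    # illustrative
    }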