mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 03:40:34 +08:00
support HQQ/EETQ #4113
This commit is contained in:
@@ -22,7 +22,7 @@ from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
|
||||
from ..extras.misc import is_gpu_or_npu_available, torch_gc
|
||||
from ..extras.packages import is_gradio_available
|
||||
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config
|
||||
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
|
||||
from .locales import ALERTS, LOCALES
|
||||
from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
|
||||
|
||||
@@ -104,6 +104,11 @@ class Runner:
|
||||
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
|
||||
user_config = load_config()
|
||||
|
||||
if get("top.quantization_bit") in QUANTIZATION_BITS:
|
||||
quantization_bit = int(get("top.quantization_bit"))
|
||||
else:
|
||||
quantization_bit = None
|
||||
|
||||
args = dict(
|
||||
stage=TRAINING_STAGES[get("train.training_stage")],
|
||||
do_train=True,
|
||||
@@ -111,7 +116,8 @@ class Runner:
|
||||
cache_dir=user_config.get("cache_dir", None),
|
||||
preprocessing_num_workers=16,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
|
||||
quantization_bit=quantization_bit,
|
||||
quantization_method=get("top.quantization_method"),
|
||||
template=get("top.template"),
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
|
||||
@@ -234,13 +240,19 @@ class Runner:
|
||||
model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
|
||||
user_config = load_config()
|
||||
|
||||
if get("top.quantization_bit") in QUANTIZATION_BITS:
|
||||
quantization_bit = int(get("top.quantization_bit"))
|
||||
else:
|
||||
quantization_bit = None
|
||||
|
||||
args = dict(
|
||||
stage="sft",
|
||||
model_name_or_path=get("top.model_path"),
|
||||
cache_dir=user_config.get("cache_dir", None),
|
||||
preprocessing_num_workers=16,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
|
||||
quantization_bit=quantization_bit,
|
||||
quantization_method=get("top.quantization_method"),
|
||||
template=get("top.template"),
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
|
||||
|
||||
Reference in New Issue
Block a user