support HQQ/EETQ #4113

hiyouga committed 2024-06-27 00:29:42 +08:00
parent addca926de
commit ad144c2265
16 changed files with 134 additions and 57 deletions


@@ -22,7 +22,7 @@ from transformers.trainer import TRAINING_ARGS_NAME
 from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
 from ..extras.misc import is_gpu_or_npu_available, torch_gc
 from ..extras.packages import is_gradio_available
-from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, get_save_dir, load_config
+from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
 from .locales import ALERTS, LOCALES
 from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
@@ -104,6 +104,11 @@ class Runner:
         model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
         user_config = load_config()
 
+        if get("top.quantization_bit") in QUANTIZATION_BITS:
+            quantization_bit = int(get("top.quantization_bit"))
+        else:
+            quantization_bit = None
+
         args = dict(
             stage=TRAINING_STAGES[get("train.training_stage")],
             do_train=True,
@@ -111,7 +116,8 @@
             cache_dir=user_config.get("cache_dir", None),
             preprocessing_num_workers=16,
             finetuning_type=finetuning_type,
-            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+            quantization_bit=quantization_bit,
+            quantization_method=get("top.quantization_method"),
             template=get("top.template"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
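
The two hunks above replace the hard-coded ["8", "4"] check with the shared QUANTIZATION_BITS constant imported from the webui common module, so bit widths other than 8 and 4 (as offered by HQQ) can reach the training arguments. Below is a minimal standalone sketch of that parsing rule; the concrete value of QUANTIZATION_BITS is an assumption for illustration, the real one is defined in common.py.

# Sketch only: the concrete QUANTIZATION_BITS value is assumed, not taken from this commit.
QUANTIZATION_BITS = ["8", "6", "5", "4", "3", "2", "1"]

def parse_quantization_bit(choice: str):
    """Mirror the runner logic: numeric dropdown choices become ints, anything else disables quantization."""
    if choice in QUANTIZATION_BITS:
        return int(choice)
    return None

assert parse_quantization_bit("4") == 4
assert parse_quantization_bit("none") is None
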
@@ -234,13 +240,19 @@ class Runner:
         model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
         user_config = load_config()
 
+        if get("top.quantization_bit") in QUANTIZATION_BITS:
+            quantization_bit = int(get("top.quantization_bit"))
+        else:
+            quantization_bit = None
+
         args = dict(
             stage="sft",
             model_name_or_path=get("top.model_path"),
             cache_dir=user_config.get("cache_dir", None),
             preprocessing_num_workers=16,
             finetuning_type=finetuning_type,
-            quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
+            quantization_bit=quantization_bit,
+            quantization_method=get("top.quantization_method"),
             template=get("top.template"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
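
On the loading side (not part of this diff), the new quantization_method value is what ultimately decides whether bitsandbytes, HQQ, or EETQ performs the on-the-fly quantization. The following is a rough sketch of that dispatch using the quantization config classes that transformers ships; it is an illustration under assumptions, not the LLaMA-Factory loader itself, and the method strings are assumed to match the webui dropdown choices.

# Illustration only: maps a (method, bit) pair from the webui to a transformers
# quantization config. The real loader handles more cases (GPTQ/AWQ checkpoints,
# double quantization, device maps, ...).
from typing import Optional

from transformers import BitsAndBytesConfig, EetqConfig, HqqConfig

def build_quant_config(method: str, bits: Optional[int]):
    if bits is None:
        return None  # quantization disabled in the webui
    if method == "bitsandbytes":
        if bits not in (8, 4):
            raise ValueError("bitsandbytes on-the-fly quantization supports 8 or 4 bits")
        return BitsAndBytesConfig(load_in_8bit=(bits == 8), load_in_4bit=(bits == 4))
    if method == "hqq":
        return HqqConfig(nbits=bits)  # HQQ accepts a wider range of bit widths
    if method == "eetq":
        return EetqConfig()  # EETQ is int8-only
    raise ValueError(f"Unknown quantization method: {method}")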