commit 9561809ce9 (parent c776cdfc3e)

improve aqlm optim

Former-commit-id: 259af60d28985b919911587716c24a3ac7f7de64
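Summary of the change (restating the hunks below): a new `aqlm_optimization` option is added to `ModelArguments`; `get_train_args` enables it whenever `predict_with_generate` is off, while `get_eval_args` always enables it; `load_model` then gates its AQLM branch on this flag instead of on `is_trainable`; and the `Evaluator` no longer calls `dispatch_model` itself, importing only `load_model_and_tokenizer`.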
@@ -14,7 +14,7 @@ from transformers.utils import cached_file
 from ..data import get_template_and_fix_tokenizer
 from ..extras.constants import CHOICES, SUBJECTS
 from ..hparams import get_eval_args
-from ..model import dispatch_model, load_model_and_tokenizer
+from ..model import load_model_and_tokenizer
 from .template import get_eval_template
 
 
@@ -23,7 +23,6 @@ class Evaluator:
         self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
         self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
         self.tokenizer.padding_side = "right"  # avoid overflow issue in batched inference for llama2
-        self.model = dispatch_model(self.model)
         self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
         self.eval_template = get_eval_template(self.eval_args.lang)
         self.choice_inputs = [
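Note: the standalone `dispatch_model` call (and its import) is dropped from the `Evaluator`; presumably device placement is now handled when the model is loaded, so the evaluator only needs `load_model_and_tokenizer`.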
@@ -121,6 +121,9 @@ class ModelArguments:
         default=False,
         metadata={"help": "For debugging purposes, print the status of the parameters in the model."},
     )
+    aqlm_optimization: Optional[bool] = field(
+        default=False, metadata={"help": "Whether or not to optimize the training performance of AQLM models."}
+    )
 
     def __post_init__(self):
         self.compute_dtype = None
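For context, a minimal self-contained sketch (not repository code) of how a dataclass field like this is surfaced, assuming transformers' HfArgumentParser is used to parse ModelArguments as elsewhere in the project; the DemoArguments class and the script around it are illustrative only:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DemoArguments:
    aqlm_optimization: Optional[bool] = field(
        default=False, metadata={"help": "Whether or not to optimize the training performance of AQLM models."}
    )


if __name__ == "__main__":
    # The field name becomes a `--aqlm_optimization` CLI/config key and the
    # metadata "help" string becomes its help text.
    parser = HfArgumentParser(DemoArguments)
    (demo_args,) = parser.parse_args_into_dataclasses()
    print(demo_args.aqlm_optimization)

As the parser hunks below show, the value is then overridden programmatically per entry point rather than left to the user.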
@@ -226,6 +226,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
     )
     model_args.model_max_length = data_args.cutoff_len
+    model_args.aqlm_optimization = not training_args.predict_with_generate
 
     # Log on each process the small summary:
     logger.info(
@@ -262,6 +263,7 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
     _set_transformers_logging()
     _verify_model_args(model_args, finetuning_args)
     _check_dependencies(disabled=finetuning_args.disable_version_checking)
+    model_args.aqlm_optimization = True
 
     if data_args.template is None:
         raise ValueError("Please specify which `template` to use.")
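Condensing the two parser hunks above into one hypothetical helper (resolve_aqlm_optimization does not exist in the repo; it only restates the two assignments): training enables the optimization unless generation is requested via predict_with_generate, and evaluation always enables it.

def resolve_aqlm_optimization(is_training: bool, predict_with_generate: bool = False) -> bool:
    if is_training:
        # get_train_args: model_args.aqlm_optimization = not training_args.predict_with_generate
        return not predict_with_generate
    # get_eval_args: model_args.aqlm_optimization = True
    return True


assert resolve_aqlm_optimization(is_training=True) is True
assert resolve_aqlm_optimization(is_training=True, predict_with_generate=True) is False
assert resolve_aqlm_optimization(is_training=False) is True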
@@ -88,7 +88,7 @@ def load_model(
 
     if model is None:
         model_init_context = nullcontext()
-        if is_trainable and getattr(config, "quantization_config", None):
+        if model_args.aqlm_optimization and getattr(config, "quantization_config", None):
             quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
             if quantization_config.get("quant_method", None) == "aqlm":
                 import aqlm  # type: ignore
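The hunk ends at the `import aqlm` line, so the body of the branch is not shown. Below is a self-contained sketch of the gating pattern, assuming the branch swaps the default nullcontext for an AQLM-provided context manager under which the model is then instantiated; pick_model_init_context is a hypothetical helper, and aqlm.optimize_for_training() is an assumption about the aqlm package rather than something shown in this diff.

from contextlib import nullcontext
from typing import Any, Dict, Optional


def pick_model_init_context(aqlm_optimization: bool, quantization_config: Optional[Dict[str, Any]]):
    """Choose the context manager under which the model should be constructed."""
    model_init_context = nullcontext()
    if aqlm_optimization and quantization_config:
        if quantization_config.get("quant_method", None) == "aqlm":
            import aqlm  # type: ignore

            # Assumed AQLM helper that speeds up training-time forward passes;
            # the name comes from the aqlm package, not from this diff.
            model_init_context = aqlm.optimize_for_training()
    return model_init_context


# Usage (illustrative): build the model inside the chosen context, e.g.
# with pick_model_init_context(model_args.aqlm_optimization, getattr(config, "quantization_config", None)):
#     model = AutoModelForCausalLM.from_pretrained(...)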