support DoRA, AWQ, AQLM #2512

hiyouga
2024-02-28 19:53:28 +08:00
parent 511b15b96a
commit cfefacaa37
9 changed files with 40 additions and 9 deletions

@@ -24,7 +24,7 @@ if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
     from trl import AutoModelForCausalLMWithValueHead

-    from ..hparams import ModelArguments
+    from ..hparams import FinetuningArguments, ModelArguments


 logger = get_logger(__name__)
@@ -253,6 +253,7 @@ def patch_config(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
     config_kwargs: Dict[str, Any],
     is_trainable: bool,
 ) -> None:
@@ -273,6 +274,9 @@ def patch_config(
     _configure_quantization(config, tokenizer, model_args, config_kwargs)

+    if finetuning_args.use_dora:
+        config_kwargs["device_map"] = {"": get_current_device()}
+

 def patch_model(
     model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool