support DoRA, AWQ, AQLM #2512

Former-commit-id: 6614cc1f08aa944db083e27e451bbdd733f7dd97
hiyouga
2024-02-28 19:53:28 +08:00
parent 1e7962dfc4
commit b392e6cfb9
9 changed files with 40 additions and 9 deletions


@@ -24,7 +24,7 @@ if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
     from trl import AutoModelForCausalLMWithValueHead
-    from ..hparams import ModelArguments
+    from ..hparams import FinetuningArguments, ModelArguments
 logger = get_logger(__name__)
@@ -253,6 +253,7 @@ def patch_config(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
     config_kwargs: Dict[str, Any],
     is_trainable: bool,
 ) -> None:
@@ -273,6 +274,9 @@ def patch_config(
     _configure_quantization(config, tokenizer, model_args, config_kwargs)
+
+    if finetuning_args.use_dora:
+        config_kwargs["device_map"] = {"": get_current_device()}
 def patch_model(
     model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool