Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-15 11:20:35 +08:00)
support DoRA, AWQ, AQLM #2512
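The hunks below cover only the DoRA plumbing in `patch_config`; the AWQ and AQLM handling named in the title lives in `_configure_quantization`, which these hunks merely call. For orientation, loading a pre-quantized AWQ checkpoint through transformers needs little more than the backend package, because the quantization settings are read from the checkpoint's config. This is a generic sketch, not LLaMA-Factory code; the checkpoint id is a placeholder and the `autoawq` package (or `aqlm` for AQLM checkpoints) is assumed to be installed.

```python
# Generic example of loading a pre-quantized AWQ checkpoint with transformers.
# Assumptions: a transformers release with AWQ integration and the autoawq
# package installed; the model id below is only a placeholder.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TheBloke/Llama-2-7B-AWQ"  # any AWQ-quantized checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
```

AQLM checkpoints load the same way once the `aqlm` package is available alongside a sufficiently recent transformers release.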
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
     from trl import AutoModelForCausalLMWithValueHead
 
-    from ..hparams import ModelArguments
+    from ..hparams import FinetuningArguments, ModelArguments
 
 
 logger = get_logger(__name__)
@@ -253,6 +253,7 @@ def patch_config(
     config: "PretrainedConfig",
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
     config_kwargs: Dict[str, Any],
     is_trainable: bool,
 ) -> None:
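Because of the new parameter, every caller of `patch_config` has to pass the `FinetuningArguments` object through. The call site is not part of this commit's visible hunks, so the snippet below is only a hypothetical illustration of the updated signature; the surrounding loader function and its variable names are assumptions.

```python
# Hypothetical caller, illustrating the new patch_config signature only.
from typing import Any, Dict


def prepare_model_config(config, tokenizer, model_args, finetuning_args, is_trainable: bool):
    config_kwargs: Dict[str, Any] = {}  # extra kwargs later forwarded to from_pretrained

    # finetuning_args is new in this commit; patch_config reads flags such as use_dora from it.
    patch_config(config, tokenizer, model_args, finetuning_args, config_kwargs, is_trainable)
    return config_kwargs
```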
@@ -273,6 +274,9 @@ def patch_config(
 
     _configure_quantization(config, tokenizer, model_args, config_kwargs)
 
+    if finetuning_args.use_dora:
+        config_kwargs["device_map"] = {"": get_current_device()}
+
 
 def patch_model(
     model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", is_trainable: bool
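The functional change is small: when `finetuning_args.use_dora` is set, the loader forces `device_map={"": get_current_device()}`, placing the whole model on the current accelerator instead of sharding it with a device map such as "auto". The sketch below mirrors that idea together with PEFT's `LoraConfig(use_dora=True)`; it is not LLaMA-Factory's loader, the checkpoint name is a placeholder, and `current_device()` is a stand-in for the repo's `get_current_device()` helper.

```python
# Minimal sketch: pin the model to one device and attach a DoRA adapter via PEFT.
# Assumptions: peft >= 0.9.0 (use_dora flag), transformers and torch installed,
# and "meta-llama/Llama-2-7b-hf" is only an example checkpoint.
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM


def current_device():
    """Stand-in for LLaMA-Factory's get_current_device(): local GPU index or 'cpu'."""
    return torch.cuda.current_device() if torch.cuda.is_available() else "cpu"


# Mirror of the patched behaviour: one device for the whole model when DoRA is on.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.bfloat16,
    device_map={"": current_device()},
)

# DoRA is switched on through the LoraConfig flag rather than a separate adapter class.
dora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    use_dora=True,
)
model = get_peft_model(model, dora_config)
model.print_trainable_parameters()
```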