support unsloth 2024.4

Former-commit-id: 7dc72fb58cb988418323f63821a21a184ecf0718
This commit is contained in:
hiyouga 2024-04-16 00:25:03 +08:00
parent bd2b758b48
commit b40f266617
2 changed files with 8 additions and 2 deletions

View File

@ -145,18 +145,22 @@ def init_adapter(
"lora_alpha": finetuning_args.lora_alpha, "lora_alpha": finetuning_args.lora_alpha,
"lora_dropout": finetuning_args.lora_dropout, "lora_dropout": finetuning_args.lora_dropout,
"use_rslora": finetuning_args.use_rslora, "use_rslora": finetuning_args.use_rslora,
"modules_to_save": finetuning_args.additional_target,
} }
if model_args.use_unsloth: if model_args.use_unsloth:
from unsloth import FastLanguageModel # type: ignore from unsloth import FastLanguageModel # type: ignore
unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length} unsloth_peft_kwargs = {
"model": model,
"max_seq_length": model_args.model_max_length,
"use_gradient_checkpointing": "unsloth",
}
model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
else: else:
lora_config = LoraConfig( lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM, task_type=TaskType.CAUSAL_LM,
inference_mode=False, inference_mode=False,
modules_to_save=finetuning_args.additional_target,
use_dora=finetuning_args.use_dora, use_dora=finetuning_args.use_dora,
**peft_kwargs, **peft_kwargs,
) )

View File

@ -82,6 +82,8 @@ def load_model(
"token": model_args.hf_hub_token, "token": model_args.hf_hub_token,
"device_map": {"": get_current_device()}, "device_map": {"": get_current_device()},
"rope_scaling": getattr(config, "rope_scaling", None), "rope_scaling": getattr(config, "rope_scaling", None),
"fix_tokenizer": False,
"trust_remote_code": True,
} }
try: try:
model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)