Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-04 04:32:50 +08:00)
support unsloth 2024.4
Former-commit-id: 7dc72fb58cb988418323f63821a21a184ecf0718
Parent: bd2b758b48
Commit: b40f266617
@@ -145,18 +145,22 @@ def init_adapter(
                 "lora_alpha": finetuning_args.lora_alpha,
                 "lora_dropout": finetuning_args.lora_dropout,
                 "use_rslora": finetuning_args.use_rslora,
+                "modules_to_save": finetuning_args.additional_target,
             }
 
             if model_args.use_unsloth:
                 from unsloth import FastLanguageModel  # type: ignore
 
-                unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
+                unsloth_peft_kwargs = {
+                    "model": model,
+                    "max_seq_length": model_args.model_max_length,
+                    "use_gradient_checkpointing": "unsloth",
+                }
                 model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
             else:
                 lora_config = LoraConfig(
                     task_type=TaskType.CAUSAL_LM,
                     inference_mode=False,
-                    modules_to_save=finetuning_args.additional_target,
                     use_dora=finetuning_args.use_dora,
                     **peft_kwargs,
                 )
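For context, a hedged sketch of the unsloth branch above flattened into a standalone script. The model name, LoRA rank, and sequence length below are placeholder assumptions rather than values from this commit; FastLanguageModel.from_pretrained and get_peft_model are unsloth's public entry points, and the new "use_gradient_checkpointing": "unsloth" key selects unsloth's own gradient-checkpointing implementation.

    # Standalone sketch of the LoRA-via-unsloth path (placeholder model and hyperparameters).
    from unsloth import FastLanguageModel  # type: ignore

    model, _tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/llama-2-7b-bnb-4bit",  # placeholder checkpoint, not from the commit
        max_seq_length=4096,
        load_in_4bit=True,
    )

    peft_kwargs = {                                # mirrors the dict built from finetuning_args
        "r": 8,
        "target_modules": ["q_proj", "v_proj"],
        "lora_alpha": 16,
        "lora_dropout": 0.0,
        "use_rslora": False,
        "modules_to_save": None,                   # finetuning_args.additional_target in the diff
    }
    unsloth_peft_kwargs = {
        "model": model,
        "max_seq_length": 4096,                    # model_args.model_max_length in the diff
        "use_gradient_checkpointing": "unsloth",   # new key: unsloth's own checkpointing
    }
    model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)

Moving "modules_to_save" into peft_kwargs means the extra trainable modules are forwarded to both branches, the unsloth call and the plain PEFT LoraConfig, instead of only the latter.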
@@ -82,6 +82,8 @@ def load_model(
             "token": model_args.hf_hub_token,
             "device_map": {"": get_current_device()},
             "rope_scaling": getattr(config, "rope_scaling", None),
+            "fix_tokenizer": False,
+            "trust_remote_code": True,
         }
         try:
             model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
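And a hedged sketch of the expanded loader call with the two new keys in place. Keys not visible in this hunk (model_name, max_seq_length, dtype, load_in_4bit) are filled in with placeholder assumptions so the snippet is self-contained. "fix_tokenizer": False presumably keeps unsloth from patching the tokenizer, since the tokenizer it returns is discarded (the `_` above), and "trust_remote_code": True lets unsloth load checkpoints that ship custom modeling code.

    # Standalone sketch of the unsloth loading path (placeholder model, dtype, and lengths).
    from unsloth import FastLanguageModel  # type: ignore
    import torch

    unsloth_kwargs = {
        "model_name": "meta-llama/Llama-2-7b-hf",  # placeholder for model_args.model_name_or_path
        "max_seq_length": 4096,                    # placeholder for model_args.model_max_length
        "dtype": torch.bfloat16,
        "load_in_4bit": True,
        "token": None,                             # model_args.hf_hub_token, for gated repos
        "device_map": {"": 0},                     # pin everything to the current device
        "rope_scaling": None,                      # forwarded from the model config
        "fix_tokenizer": False,                    # new: leave the tokenizer untouched
        "trust_remote_code": True,                 # new: allow custom modeling code
    }
    try:
        model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)  # tokenizer is discarded
    except Exception:  # the exact exception handled by the caller is not shown in this hunk
        model = None   # fall back to the regular transformers loader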