mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
parent dabd40750c
commit dc8714a003
@@ -1,5 +1,6 @@
 import torch
 from typing import TYPE_CHECKING
+from transformers.integrations import is_deepspeed_zero3_enabled
 from peft import PeftModel, TaskType, LoraConfig, get_peft_model

 from llmtuner.extras.logging import get_logger
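
The only change in this hunk is the new transformers.integrations import. As a minimal sketch of what the flag reports (assuming only that transformers is installed): it is True only while a DeepSpeed ZeRO stage-3 config is active in the process, e.g. one set up by the HF Trainer, and False otherwise.

# Minimal sketch: is_deepspeed_zero3_enabled() reports whether a DeepSpeed
# ZeRO stage-3 config is currently live in this process.
from transformers.integrations import is_deepspeed_zero3_enabled

print(is_deepspeed_zero3_enabled())  # False in a plain (non-DeepSpeed) run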
@@ -71,6 +72,10 @@ def init_adapter(
             assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
             is_mergeable = False

+        if is_deepspeed_zero3_enabled():
+            assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
+            is_mergeable = False
+
         if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
             adapter_to_merge = model_args.adapter_name_or_path[:-1]
             adapter_to_resume = model_args.adapter_name_or_path[-1]
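
Taken together, this hunk makes adapter merging conditional on not running under ZeRO-3, whose parameter sharding makes an in-place merge unsafe, mirroring the existing restriction for quantized weights. Below is a standalone sketch of the resulting control flow; split_adapters and its scalar arguments are simplified stand-ins for llmtuner's model_args/finetuning_args objects, and the ZeRO-3 check is stubbed so the example runs without DeepSpeed.

# A minimal, hedged sketch of the merge-gating logic this commit adds.
# Only the if/assert structure is taken from the diff; everything else
# (function name, arguments, stub) is illustrative.
from typing import List, Optional, Tuple

def is_deepspeed_zero3_enabled() -> bool:
    # Stub standing in for transformers.integrations.is_deepspeed_zero3_enabled.
    return False

def split_adapters(
    adapter_name_or_path: List[str],
    quantized: bool,
    is_trainable: bool,
    create_new_adapter: bool,
) -> Tuple[List[str], Optional[str]]:
    is_mergeable = True
    if quantized:  # merging LoRA into quantized weights is unstable
        assert len(adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
        is_mergeable = False

    if is_deepspeed_zero3_enabled():  # ZeRO-3 shards parameters, so merging is unsafe
        assert len(adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
        is_mergeable = False

    if (is_trainable and not create_new_adapter) or (not is_mergeable):
        # Merge all adapters except the last; resume training from the last one.
        return adapter_name_or_path[:-1], adapter_name_or_path[-1]
    # Otherwise every listed adapter gets merged and a fresh one is created.
    return adapter_name_or_path, None

print(split_adapters(["path/to/lora"], quantized=False, is_trainable=True, create_new_adapter=False))
# -> ([], 'path/to/lora'): nothing is merged; the single adapter is resumed for training.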
@@ -3,7 +3,7 @@ import math
 import torch
 import random
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 from datasets import load_dataset

 from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase