diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py
index f3db4d1e..d43e00f0 100644
--- a/src/llmtuner/model/adapter.py
+++ b/src/llmtuner/model/adapter.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING
 
 import torch
 from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
@@ -21,11 +21,11 @@ logger = get_logger(__name__)
 
 
 def init_adapter(
     config: "PretrainedConfig",
-    model: Union["PreTrainedModel"],
+    model: "PreTrainedModel",
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
     is_trainable: bool,
-) -> Union["PreTrainedModel"]:
+) -> "PreTrainedModel":
     r"""
     Initializes the adapters.
diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 47298673..dd7eb44c 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -112,7 +112,7 @@ def load_model(
     finetuning_args: "FinetuningArguments",
     is_trainable: bool = False,
     add_valuehead: bool = False,
-) -> Union["PreTrainedModel"]:
+) -> "PreTrainedModel":
     r"""
     Loads pretrained model.
     """
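
The change is behavior-preserving: per the typing docs, a Union of a single argument vanishes (Union[X] == X), so Union["PreTrainedModel"] was always equivalent to the bare annotation. A minimal sketch demonstrating this, using a local stand-in class rather than the real transformers.PreTrainedModel:

from typing import Union


class PreTrainedModel:
    """Stand-in for transformers.PreTrainedModel, used here only for the demo."""


# A Union of a single argument collapses to that argument, so the
# Union wrapper removed in this diff was redundant all along.
assert Union[PreTrainedModel] is PreTrainedModel
print(Union[PreTrainedModel])  # <class '__main__.PreTrainedModel'>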