support llama pro #2338, add rslora

hiyouga
2024-02-15 02:27:36 +08:00
parent 8a1b389086
commit 7924ffc55d
24 changed files with 438 additions and 203 deletions
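
For context on the rsLoRA half of this commit: rank-stabilized LoRA divides the adapter scaling by sqrt(r) instead of r, so the magnitude of the low-rank update no longer shrinks as the rank grows. PEFT exposes this as the use_rslora flag on LoraConfig; the sketch below shows the flag in isolation (the base model and target modules are illustrative assumptions, not values from this commit):

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Illustrative base model; any causal LM is handled the same way.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # assumed targets, for illustration only
    use_rslora=True,  # scale by lora_alpha / sqrt(r) instead of lora_alpha / r
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In LLaMA Factory this presumably surfaces as a new finetuning argument that is forwarded into the LoraConfig built during adapter setup.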


@@ -1,5 +1,5 @@
 import inspect
-from typing import TYPE_CHECKING, Any, Dict, List
+from typing import TYPE_CHECKING, Dict, List
 
 import torch
 from transformers import PreTrainedModel
@@ -13,7 +13,7 @@ from ..extras.misc import get_current_device
 
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedTokenizer
-    from ..hparams import DataArguments, FinetuningArguments, ModelArguments
+    from ..hparams import ModelArguments
 
 
 logger = get_logger(__name__)
@@ -76,18 +76,6 @@ def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
     return list(module_names)
 
 
-def get_modelcard_args(
-    model_args: "ModelArguments", data_args: "DataArguments", finetuning_args: "FinetuningArguments"
-) -> Dict[str, Any]:
-    return {
-        "tasks": "text-generation",
-        "license": "other",
-        "finetuned_from": model_args.model_name_or_path,
-        "dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
-        "tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else []),
-    }
-
-
 def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
     r"""
     Loads value head parameters from Hugging Face Hub or local disk.
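
On the LLaMA Pro half of the commit title: LLaMA Pro grows a model by block expansion, interleaving copies of existing decoder layers whose output projections are zero-initialized, so the expanded network starts out computing exactly what the original did; training then touches only the new blocks. A minimal sketch of the idea, assuming a LLaMA-style model with layers at model.model.layers (expand_blocks, num_expand, and the layer attribute names are illustrative, not this commit's implementation; KV-cache layer_idx bookkeeping is omitted):

from copy import deepcopy

import torch
from transformers import AutoModelForCausalLM


def expand_blocks(model, num_expand: int = 8):
    """Interleave zero-initialized copies of decoder layers (block expansion)."""
    layers = model.model.layers
    step = len(layers) // num_expand
    expanded = torch.nn.ModuleList()
    for i, layer in enumerate(layers):
        layer.requires_grad_(False)  # freeze every original block
        expanded.append(layer)
        if (i + 1) % step == 0:
            new_layer = deepcopy(layer)
            # Zero the output projections so the new block is an identity
            # mapping at initialization (only the residual stream flows).
            new_layer.self_attn.o_proj.weight.data.zero_()
            new_layer.mlp.down_proj.weight.data.zero_()
            new_layer.requires_grad_(True)  # train only the expanded blocks
            expanded.append(new_layer)
    model.model.layers = expanded
    model.config.num_hidden_layers = len(expanded)
    return model


model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = expand_blocks(model, num_expand=8)

Zero-initializing o_proj and down_proj is what makes the copy safe: both the attention and MLP outputs of the new block are zero, so its residual connection passes hidden states through unchanged until training moves the weights.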