refactor adapter hparam

This commit is contained in:
hiyouga
2023-12-15 20:53:11 +08:00
parent d4c351f1ec
commit 0716f5e470
21 changed files with 302 additions and 311 deletions

View File

@@ -2,14 +2,7 @@ import os
import json
import gradio as gr
from typing import Any, Dict, Optional
from transformers.utils import (
WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
ADAPTER_WEIGHTS_NAME,
ADAPTER_SAFE_WEIGHTS_NAME
)
from peft.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
from llmtuner.extras.constants import (
DEFAULT_MODULE,
@@ -22,18 +15,11 @@ from llmtuner.extras.misc import use_modelscope
from llmtuner.hparams.data_args import DATA_CONFIG
ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
DEFAULT_CACHE_DIR = "cache"
DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user.config"
CKPT_NAMES = [
WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
ADAPTER_WEIGHTS_NAME,
ADAPTER_SAFE_WEIGHTS_NAME
]
def get_save_dir(*args) -> os.PathLike:
@@ -90,18 +76,18 @@ def get_template(model_name: str) -> str:
return "default"
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
checkpoints = []
if model_name:
def list_adapters(model_name: str, finetuning_type: str) -> Dict[str, Any]:
adapters = []
if model_name and finetuning_type == "lora": # full and freeze have no adapter
save_dir = get_save_dir(model_name, finetuning_type)
if save_dir and os.path.isdir(save_dir):
for checkpoint in os.listdir(save_dir):
for adapter in os.listdir(save_dir):
if (
os.path.isdir(os.path.join(save_dir, checkpoint))
and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
os.path.isdir(os.path.join(save_dir, adapter))
and any([os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES])
):
checkpoints.append(checkpoint)
return gr.update(value=[], choices=checkpoints)
adapters.append(adapter)
return gr.update(value=[], choices=adapters)
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]: