mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 11:50:35 +08:00
refactor adapter hparam
This commit is contained in:
@@ -2,14 +2,7 @@ import os
|
||||
import json
|
||||
import gradio as gr
|
||||
from typing import Any, Dict, Optional
|
||||
from transformers.utils import (
|
||||
WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
ADAPTER_WEIGHTS_NAME,
|
||||
ADAPTER_SAFE_WEIGHTS_NAME
|
||||
)
|
||||
from peft.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
|
||||
|
||||
from llmtuner.extras.constants import (
|
||||
DEFAULT_MODULE,
|
||||
@@ -22,18 +15,11 @@ from llmtuner.extras.misc import use_modelscope
|
||||
from llmtuner.hparams.data_args import DATA_CONFIG
|
||||
|
||||
|
||||
ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
|
||||
DEFAULT_CACHE_DIR = "cache"
|
||||
DEFAULT_DATA_DIR = "data"
|
||||
DEFAULT_SAVE_DIR = "saves"
|
||||
USER_CONFIG = "user.config"
|
||||
CKPT_NAMES = [
|
||||
WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
ADAPTER_WEIGHTS_NAME,
|
||||
ADAPTER_SAFE_WEIGHTS_NAME
|
||||
]
|
||||
|
||||
|
||||
def get_save_dir(*args) -> os.PathLike:
|
||||
@@ -90,18 +76,18 @@ def get_template(model_name: str) -> str:
|
||||
return "default"
|
||||
|
||||
|
||||
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
|
||||
checkpoints = []
|
||||
if model_name:
|
||||
def list_adapters(model_name: str, finetuning_type: str) -> Dict[str, Any]:
|
||||
adapters = []
|
||||
if model_name and finetuning_type == "lora": # full and freeze have no adapter
|
||||
save_dir = get_save_dir(model_name, finetuning_type)
|
||||
if save_dir and os.path.isdir(save_dir):
|
||||
for checkpoint in os.listdir(save_dir):
|
||||
for adapter in os.listdir(save_dir):
|
||||
if (
|
||||
os.path.isdir(os.path.join(save_dir, checkpoint))
|
||||
and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
|
||||
os.path.isdir(os.path.join(save_dir, adapter))
|
||||
and any([os.path.isfile(os.path.join(save_dir, adapter, name)) for name in ADAPTER_NAMES])
|
||||
):
|
||||
checkpoints.append(checkpoint)
|
||||
return gr.update(value=[], choices=checkpoints)
|
||||
adapters.append(adapter)
|
||||
return gr.update(value=[], choices=adapters)
|
||||
|
||||
|
||||
def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
|
||||
|
||||
Reference in New Issue
Block a user