support lora target auto find

This commit is contained in:
hiyouga
2023-09-09 15:38:37 +08:00
parent d8d82ca281
commit bca1a247bc
11 changed files with 117 additions and 72 deletions

View File

@@ -16,8 +16,8 @@ USER_CONFIG = "user.config"
DATA_CONFIG = "dataset_info.json"
def get_save_dir(model_name: str) -> str:
return os.path.join(DEFAULT_SAVE_DIR, os.path.split(model_name)[-1])
def get_save_dir(*args) -> os.PathLike:
return os.path.join(DEFAULT_SAVE_DIR, *args)
def get_config_path() -> os.PathLike:
@@ -29,7 +29,7 @@ def load_config() -> Dict[str, Any]:
with open(get_config_path(), "r", encoding="utf-8") as f:
return json.load(f)
except:
return {"lang": "", "last_model": "", "path_dict": {}}
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
def save_config(lang: str, model_name: str, model_path: str) -> None:
@@ -56,7 +56,7 @@ def get_template(model_name: str) -> str:
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
checkpoints = []
save_dir = os.path.join(get_save_dir(model_name), finetuning_type)
save_dir = get_save_dir(model_name, finetuning_type)
if save_dir and os.path.isdir(save_dir):
for checkpoint in os.listdir(save_dir):
if (