This commit is contained in:
hiyouga
2024-06-06 03:33:44 +08:00
parent f2580ad403
commit 7daf8366db
5 changed files with 25 additions and 27 deletions

View File

@@ -174,11 +174,24 @@ def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
return str(get_arg_save_path(config_path))
def list_output_dirs(model_name: str, finetuning_type: str, initial_dir: str) -> "gr.Dropdown":
def list_config_paths(current_time: str) -> "gr.Dropdown":
r"""
Lists all the saved configuration files.
"""
config_files = ["{}.yaml".format(current_time)]
if os.path.isdir(DEFAULT_CONFIG_DIR):
for file_name in os.listdir(DEFAULT_CONFIG_DIR):
if file_name.endswith(".yaml"):
config_files.append(file_name)
return gr.Dropdown(choices=config_files)
def list_output_dirs(model_name: str, finetuning_type: str, current_time: str) -> "gr.Dropdown":
r"""
Lists all the directories that can resume from.
"""
output_dirs = [initial_dir]
output_dirs = ["train_{}".format(current_time)]
if model_name:
save_dir = get_save_dir(model_name, finetuning_type)
if save_dir and os.path.isdir(save_dir):
@@ -190,18 +203,6 @@ def list_output_dirs(model_name: str, finetuning_type: str, initial_dir: str) ->
return gr.Dropdown(choices=output_dirs)
def list_config_paths() -> "gr.Dropdown":
"""
Lists all the saved configuration files that can be loaded.
"""
if os.path.exists(DEFAULT_CONFIG_DIR) and os.path.isdir(DEFAULT_CONFIG_DIR):
config_files = [file_name for file_name in os.listdir(DEFAULT_CONFIG_DIR) if file_name.endswith(".yaml")]
else:
config_files = []
return gr.Dropdown(choices=config_files)
def check_output_dir(lang: str, model_name: str, finetuning_type: str, output_dir: str) -> None:
r"""
Check if output dir exists.