mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-07 04:05:58 +08:00
[webui] support other hub (#8567)
This commit is contained in:
@@ -77,14 +77,19 @@ def load_config() -> dict[str, Union[str, dict[str, Any]]]:
|
||||
with open(_get_config_path(), encoding="utf-8") as f:
|
||||
return safe_load(f)
|
||||
except Exception:
|
||||
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
|
||||
return {"lang": None, "hub_name": None, "last_model": None, "path_dict": {}, "cache_dir": None}
|
||||
|
||||
|
||||
def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
|
||||
def save_config(
|
||||
lang: str, hub_name: Optional[str] = None, model_name: Optional[str] = None, model_path: Optional[str] = None
|
||||
) -> None:
|
||||
r"""Save user config."""
|
||||
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
|
||||
user_config = load_config()
|
||||
user_config["lang"] = lang or user_config["lang"]
|
||||
if hub_name:
|
||||
user_config["hub_name"] = hub_name
|
||||
|
||||
if model_name:
|
||||
user_config["last_model"] = model_name
|
||||
|
||||
@@ -247,7 +252,7 @@ def create_ds_config() -> None:
|
||||
"stage": 2,
|
||||
"allgather_partitions": True,
|
||||
"allgather_bucket_size": 5e8,
|
||||
"overlap_comm": True,
|
||||
"overlap_comm": False,
|
||||
"reduce_scatter": True,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"contiguous_gradients": True,
|
||||
@@ -262,7 +267,7 @@ def create_ds_config() -> None:
|
||||
|
||||
ds_config["zero_optimization"] = {
|
||||
"stage": 3,
|
||||
"overlap_comm": True,
|
||||
"overlap_comm": False,
|
||||
"contiguous_gradients": True,
|
||||
"sub_group_size": 1e9,
|
||||
"reduce_bucket_size": "auto",
|
||||
|
||||
Reference in New Issue
Block a user