support BAdam in WebUI

This commit is contained in:
codingma
2024-04-28 11:31:34 +08:00
parent e898fabbe3
commit 26f7170393
3 changed files with 144 additions and 0 deletions

View File

@@ -151,6 +151,7 @@ class Runner:
fp16=(get("train.compute_type") == "fp16"),
bf16=(get("train.compute_type") == "bf16"),
pure_bf16=(get("train.compute_type") == "pure_bf16"),
use_badam=get("train.use_badam"),
)
args["disable_tqdm"] = True
@@ -198,6 +199,14 @@ class Runner:
args["galore_scale"] = get("train.galore_scale")
args["galore_target"] = get("train.galore_target")
if args["use_badam"]:
args["badam_mode"] = get("train.badam_mode")
args["badam_switch_block_every"] = get("train.badam_switch_block_every")
args["badam_switch_mode"] = get("train.badam_switch_mode")
args["badam_update_ratio"] = get("train.badam_update_ratio")
args["badam_mask_mode"] = get("train.badam_mask_mode")
args["badam_verbose"] = get("train.badam_verbose")
return args
def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]: