fix badam configs

Former-commit-id: 9433c8c215881692f318b89df03af97b4eda4dd5
hiyouga committed 2024-05-02 02:47:04 +08:00
parent 931a30c7b8
commit ed8d9e0881
5 changed files with 44 additions and 69 deletions

View File

@@ -221,16 +221,18 @@ class BAdamArgument:
         default=None,
         metadata={"help": "The starting block index for layer-wise BAdam."},
     )
-    badam_switch_block_every: Optional[int] = field(
-        default=50,
-        metadata={"help": "How often to switch model's block update. Set to -1 to disable the block update."},
-    )
     badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
         default="ascending",
         metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
     )
+    badam_switch_interval: Optional[int] = field(
+        default=50,
+        metadata={
+            "help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
+        },
+    )
     badam_update_ratio: float = field(
-        default=0.0,
+        default=0.05,
         metadata={"help": "The ratio of the update for ratio-wise BAdam."},
     )
     badam_mask_mode: Literal["adjacent", "scatter"] = field(
@@ -308,6 +310,9 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
         if self.use_galore and self.finetuning_type == "lora":
             raise ValueError("Cannot use LoRA with GaLore together.")
 
+        if self.use_galore and self.use_badam:
+            raise ValueError("Cannot use GaLore with BAdam together.")
+
         if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
             raise ValueError("`loraplus_lr_ratio` is only valid for the LoRA training.")
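
For reference, a minimal usage sketch of the renamed options; only the argument names and defaults come from the diff above, while the HfArgumentParser wiring and the import path are assumptions:

# Hypothetical sketch, not part of this commit.
from transformers import HfArgumentParser

from llmtuner.hparams import FinetuningArguments  # import path assumed

parser = HfArgumentParser(FinetuningArguments)
(finetuning_args,) = parser.parse_args_into_dataclasses(
    args=[
        "--finetuning_type", "full",      # BAdam is typically paired with full tuning
        "--use_badam", "true",
        "--badam_mode", "layer",
        "--badam_switch_interval", "50",  # replaces --badam_switch_block_every
        "--badam_update_ratio", "0.05",   # matches the new default
    ]
)
# The new check in __post_init__ raises if use_galore and use_badam are both set.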

View File

@ -317,14 +317,14 @@ def _create_badam_optimizer(
base_optimizer=base_optimizer, base_optimizer=base_optimizer,
named_parameters_list=list(model.named_parameters()), named_parameters_list=list(model.named_parameters()),
block_prefix_list=None, block_prefix_list=None,
switch_block_every=finetuning_args.badam_switch_block_every, switch_block_every=finetuning_args.badam_switch_interval,
start_block=finetuning_args.badam_start_block, start_block=finetuning_args.badam_start_block,
switch_mode=finetuning_args.badam_switch_mode, switch_mode=finetuning_args.badam_switch_mode,
verbose=finetuning_args.badam_verbose, verbose=finetuning_args.badam_verbose,
) )
logger.info( logger.info(
f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, " f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
f"switch block every {finetuning_args.badam_switch_block_every} steps, " f"switch block every {finetuning_args.badam_switch_interval} steps, "
f"default start block is {finetuning_args.badam_start_block}" f"default start block is {finetuning_args.badam_start_block}"
) )
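
As a side note, a toy sketch (not library code) of what the switch interval controls in layer-wise mode under the "ascending" switch strategy; the block prefixes are made up:

# Layer-wise BAdam keeps Adam updates on one block of layers at a time and
# rotates the active block every `badam_switch_interval` optimizer steps.
blocks = ["model.layers.0.", "model.layers.1.", "model.layers.2."]  # assumed prefixes
switch_interval = 50

def active_block(step: int) -> str:
    # "ascending" mode: walk the blocks in order, wrapping around
    return blocks[(step // switch_interval) % len(blocks)]

assert active_block(0) == "model.layers.0."
assert active_block(49) == "model.layers.0."
assert active_block(50) == "model.layers.1."
assert active_block(150) == "model.layers.0."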

View File

@@ -215,17 +215,17 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             use_badam = gr.Checkbox()
             badam_mode = gr.Dropdown(choices=["layer", "ratio"], value="layer")
             badam_switch_mode = gr.Dropdown(choices=["ascending", "descending", "random", "fixed"], value="ascending")
-            badam_switch_block_every = gr.Slider(value=50, minimum=-1, maximum=200, step=1)
-            badam_update_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.01)
+            badam_switch_interval = gr.Slider(value=50, minimum=1, maximum=1024, step=1)
+            badam_update_ratio = gr.Slider(value=0.05, minimum=0, maximum=1, step=0.01)
 
-    input_elems.update({use_badam, badam_mode, badam_switch_mode, badam_switch_block_every, badam_update_ratio})
+    input_elems.update({use_badam, badam_mode, badam_switch_mode, badam_switch_interval, badam_update_ratio})
     elem_dict.update(
         dict(
             badam_tab=badam_tab,
             use_badam=use_badam,
             badam_mode=badam_mode,
             badam_switch_mode=badam_switch_mode,
-            badam_switch_block_every=badam_switch_block_every,
+            badam_switch_interval=badam_switch_interval,
             badam_update_ratio=badam_update_ratio,
         )
     )
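
For context, a standalone sketch of the adjusted controls; the slider arguments mirror this diff, while the labels and Blocks scaffolding are illustrative:

import gradio as gr  # assumes a gradio version compatible with the repo

with gr.Blocks() as demo:
    # minimum was -1 (meaning "never switch") and maximum 200 before this commit
    badam_switch_interval = gr.Slider(value=50, minimum=1, maximum=1024, step=1, label="Switch interval")
    # default raised from 0 so that ratio-wise mode updates something out of the box
    badam_update_ratio = gr.Slider(value=0.05, minimum=0, maximum=1, step=0.01, label="Update ratio")

# demo.launch()  # uncomment to render the two sliders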

View File

@@ -905,15 +905,15 @@ LOCALES = {
     "use_badam": {
         "en": {
             "label": "Use BAdam",
-            "info": "Enable the block coordinate optimization with Adam.",
+            "info": "Enable the BAdam optimizer.",
         },
         "ru": {
             "label": "Использовать BAdam",
-            "info": "Включите блочную оптимизацию координат с Adam.",
+            "info": "Включите оптимизатор BAdam.",
         },
         "zh": {
             "label": "使用 BAdam",
-            "info": "使用多Block协同的Adam优化器。",
+            "info": "使用 BAdam 优化器。",
         },
     },
     "badam_mode": {
@@ -923,25 +923,11 @@ LOCALES = {
         },
         "ru": {
             "label": "Режим BAdam",
-            "info": "Использовать оптимизатор BAdam с обработкой слоев или с обработкой коэффициентов.",
+            "info": "Использовать ли оптимизатор BAdam с послойной или пропорциональной настройкой.",
         },
         "zh": {
             "label": "BAdam 模式",
-            "info": "使用layer或者ratio比例模式。",
+            "info": "使用 layer-wise 或 ratio-wise BAdam 优化器。",
         },
     },
-    "badam_switch_block_every": {
-        "en": {
-            "label": "Switch block frequency",
-            "info": "How often to switch model's block update. Set to -1 to disable the block update.",
-        },
-        "ru": {
-            "label": "Частота переключения",
-            "info": "Как часто переключать обновление блока модели. Установите -1, чтобы отключить обновление блока.",
-        },
-        "zh": {
-            "label": "切换block的频率",
-            "info": "控制切换block的频率,如果是-1,则不切换。",
-        },
-    },
     "badam_switch_mode": {
@@ -950,12 +936,26 @@ LOCALES = {
             "info": "The strategy of picking block to update for layer-wise BAdam.",
         },
         "ru": {
-            "label": "Переключить режим",
-            "info": "Стратегия выбора блока для обновления в методе BAdam по слоям.",
+            "label": "Режим переключения",
+            "info": "Стратегия выбора блока для обновления для послойного BAdam.",
         },
         "zh": {
-            "label": "Block切换策略",
-            "info": "如果是layer类型的训练模式,如何切换block。",
+            "label": "切换策略",
+            "info": "Layer-wise BAdam 优化器的块切换策略。",
         },
     },
+    "badam_switch_interval": {
+        "en": {
+            "label": "Switch interval",
+            "info": "Number of steps to update the block for layer-wise BAdam.",
+        },
+        "ru": {
+            "label": "Интервал переключения",
+            "info": "количество шагов для обновления блока для пошагового BAdam.",
+        },
+        "zh": {
+            "label": "切换频率",
+            "info": "Layer-wise BAdam 优化器的块切换频率。",
+        },
+    },
     "badam_update_ratio": {
@@ -965,39 +965,11 @@ LOCALES = {
         },
         "ru": {
             "label": "Коэффициент обновления",
-            "info": "Коэффициент обновления для метода BAdam, основанного на коэффициентах.",
+            "info": "Коэффициент обновления для BAdam с учётом соотношений.",
         },
         "zh": {
-            "label": "Block更新比例",
-            "info": "如果是比例类型的训练模式,block每次更新的范围比例。",
+            "label": "Block 更新比例",
+            "info": "Ratio-wise BAdam 优化器的更新比例。",
         },
     },
-    "badam_mask_mode": {
-        "en": {
-            "label": "Mask mode",
-            "info": "The mode of the mask for BAdam optimizer.",
-        },
-        "ru": {
-            "label": "Режим маски",
-            "info": "Режим маски для оптимизатора BAdam.",
-        },
-        "zh": {
-            "label": "Mask模式",
-            "info": "BAdam优化器内训练参数的mask关系。",
-        },
-    },
-    "badam_verbose": {
-        "en": {
-            "label": "Verbosity level",
-            "info": "0 for no print, 1 for print the block prefix, 2 for print trainable parameters.",
-        },
-        "ru": {
-            "label": "Уровень многословности",
-            "info": "0 для отсутствия печати, 1 для печати префикса блока, 2 для печати обучаемых параметров.",
-        },
-        "zh": {
-            "label": "输出日志级别",
-            "info": "0:不输出,1:输出block前缀,2:输出可训练的参数。",
-        },
-    },
     "cmd_preview_btn": {

View File

@@ -147,11 +147,11 @@ class Runner:
             shift_attn=get("train.shift_attn"),
             report_to="all" if get("train.report_to") else "none",
             use_galore=get("train.use_galore"),
-            use_badam=get("train.use_badam"),
             output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
             fp16=(get("train.compute_type") == "fp16"),
             bf16=(get("train.compute_type") == "bf16"),
             pure_bf16=(get("train.compute_type") == "pure_bf16"),
+            use_badam=get("train.use_badam"),
         )
         args["disable_tqdm"] = True
@@ -201,11 +201,9 @@ class Runner:
         if args["use_badam"]:
             args["badam_mode"] = get("train.badam_mode")
-            args["badam_switch_block_every"] = get("train.badam_switch_block_every")
             args["badam_switch_mode"] = get("train.badam_switch_mode")
+            args["badam_switch_interval"] = get("train.badam_switch_interval")
             args["badam_update_ratio"] = get("train.badam_update_ratio")
-            args["badam_mask_mode"] = get("train.badam_mask_mode")
-            args["badam_verbose"] = get("train.badam_verbose")
 
         return args
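
Taken together, a sketch of the BAdam options the Web UI runner now forwards when the checkbox is ticked; the keys come from this diff, the values are placeholders:

# Illustrative result for a layer-wise BAdam run.
args = {
    "use_badam": True,
    "badam_mode": "layer",
    "badam_switch_mode": "ascending",
    "badam_switch_interval": 50,  # renamed from badam_switch_block_every
    "badam_update_ratio": 0.05,
}
# badam_mask_mode and badam_verbose are no longer set from the Web UI,
# so they fall back to their dataclass defaults.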