fix chat engine, update webui

Former-commit-id: 5d956e2a5167201aecdfce2794c25d8a2d84e234
hiyouga 2024-03-08 03:01:53 +08:00
parent 0a0959facf
commit 7443ac3116
9 changed files with 250 additions and 83 deletions

View File

@@ -1,4 +1,5 @@
 import asyncio
+from threading import Thread
 from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence

 from ..hparams import get_infer_args
@@ -10,21 +11,24 @@ if TYPE_CHECKING:
     from .base_engine import BaseEngine, Response


+def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
+    asyncio.set_event_loop(loop)
+    loop.run_forever()
+
+
 class ChatModel:
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
-        if model_args.infer_backend == "hf":
+        if model_args.infer_backend == "huggingface":
             self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
         elif model_args.infer_backend == "vllm":
             self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
         else:
             raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))

-    def _get_event_loop():
-        try:
-            return asyncio.get_running_loop()
-        except RuntimeError:
-            return asyncio.new_event_loop()
+        self._loop = asyncio.new_event_loop()
+        self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
+        self._thread.start()

     def chat(
         self,
@@ -33,8 +37,8 @@ class ChatModel:
         tools: Optional[str] = None,
         **input_kwargs,
     ) -> List["Response"]:
-        loop = self._get_event_loop()
-        return loop.run_until_complete(self.achat(messages, system, tools, **input_kwargs))
+        task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, **input_kwargs), self._loop)
+        return task.result()

     async def achat(
         self,
@@ -52,11 +56,11 @@ class ChatModel:
         tools: Optional[str] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
-        loop = self._get_event_loop()
         generator = self.astream_chat(messages, system, tools, **input_kwargs)
         while True:
             try:
-                yield loop.run_until_complete(generator.__anext__())
+                task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
+                yield task.result()
             except StopAsyncIteration:
                 break
@@ -75,8 +79,8 @@ class ChatModel:
         batch_input: List[str],
         **input_kwargs,
     ) -> List[float]:
-        loop = self._get_event_loop()
-        return loop.run_until_complete(self.aget_scores(batch_input, **input_kwargs))
+        task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
+        return task.result()

     async def aget_scores(
         self,
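
This is the actual chat engine fix: the old per-call lookup could return a loop that was already running and then fail inside run_until_complete, whereas the new code keeps one dedicated event loop alive on a daemon thread for the lifetime of the ChatModel, and the synchronous chat/stream_chat/get_scores wrappers merely submit coroutines to it. A minimal sketch of the same pattern, using only the standard library (class and method names here are illustrative, not from the commit):

import asyncio
from threading import Thread


def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
    asyncio.set_event_loop(loop)
    loop.run_forever()  # serves coroutines submitted from other threads


class SyncFacade:
    def __init__(self) -> None:
        # one loop for the object's lifetime; daemon=True lets the process exit freely
        self._loop = asyncio.new_event_loop()
        Thread(target=_start_background_loop, args=(self._loop,), daemon=True).start()

    async def agreet(self, name: str) -> str:
        await asyncio.sleep(0.1)  # stands in for real async work
        return "hello " + name

    def greet(self, name: str) -> str:
        # safe to call from any thread, even one that already runs its own event loop
        task = asyncio.run_coroutine_threadsafe(self.agreet(name), self._loop)
        return task.result()  # blocks until the coroutine completes


print(SyncFacade().greet("world"))  # hello world

run_coroutine_threadsafe returns a concurrent.futures.Future, so task.result() blocks only the calling thread while the background loop keeps serving other coroutines.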

View File

@@ -147,7 +147,7 @@ class HuggingfaceEngine(BaseEngine):
         )
         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer
-        thread = Thread(target=model.generate, kwargs=gen_kwargs)
+        thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
         thread.start()

         def stream():
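
Marking the generation thread daemon=True means a half-consumed stream can no longer keep the interpreter alive at shutdown. For reference, the surrounding code is the stock transformers streaming pattern; a self-contained sketch (the gpt2 checkpoint is just a small example):

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any causal LM works
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = dict(**inputs, max_new_tokens=32, streamer=streamer)

# generate() blocks, so it runs on a worker thread while the caller consumes the streamer
thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
thread.start()
for new_text in streamer:  # yields decoded text chunks as they are produced
    print(new_text, end="", flush=True)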

View File

@@ -155,20 +155,20 @@ class RLHFArguments:
 @dataclass
 class GaloreArguments:
     r"""
-    Arguments pertaining to the GaLore optimization.
+    Arguments pertaining to the GaLore algorithm.
     """

     use_galore: bool = field(
         default=False,
-        metadata={"help": "Whether or not to use galore optimizer."},
+        metadata={"help": "Whether or not to use gradient low-Rank projection."},
     )
     galore_target: str = field(
         default="mlp,attn",
-        metadata={"help": "Name(s) of modules to apply GaLore."},
+        metadata={"help": "Name(s) of modules to apply GaLore. Use commas to separate multiple modules."},
     )
     galore_rank: int = field(
         default=16,
-        metadata={"help": "GaLore rank."},
+        metadata={"help": "The rank of GaLore gradients."},
     )
     galore_update_interval: int = field(
         default=200,
@@ -176,7 +176,7 @@ class GaloreArguments:
     )
     galore_scale: float = field(
         default=0.25,
-        metadata={"help": "GaLore scale."},
+        metadata={"help": "GaLore scaling coefficient."},
     )
     galore_proj_type: Literal["std", "reverse_std", "right", "left", "full"] = field(
         default="std",

View File

@@ -81,8 +81,8 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
     )
-    infer_backend: Literal["hf", "vllm"] = field(
-        default="hf",
+    infer_backend: Literal["huggingface", "vllm"] = field(
+        default="huggingface",
         metadata={"help": "Backend engine used at inference."},
     )
     vllm_maxlen: int = field(
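
The rename from "hf" to "huggingface" is user-visible: any args dict that previously passed infer_backend="hf" must switch to the new literal, or ChatModel.__init__ raises NotImplementedError. A hedged usage sketch (the import path is assumed from the repo layout, the model path and template name are placeholders, and the OpenAI-style message schema is an assumption):

from llmtuner.chat import ChatModel  # import path assumed

chat_model = ChatModel(
    {
        "model_name_or_path": "path/to/model",  # placeholder checkpoint
        "template": "default",                  # template name is illustrative
        "infer_backend": "huggingface",         # was "hf" before this commit
    }
)
messages = [{"role": "user", "content": "Hi!"}]  # assumed OpenAI-style schema
for new_text in chat_model.stream_chat(messages):
    print(new_text, end="", flush=True)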

View File

@@ -75,6 +75,7 @@ class WebChatModel(ChatModel):
             flash_attn=(get("top.booster") == "flash_attn"),
             use_unsloth=(get("top.booster") == "unsloth"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            infer_backend=get("infer.infer_backend"),
         )

         super().__init__(args)

View File

@@ -15,12 +15,15 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
     input_elems = engine.manager.get_base_elems()
     elem_dict = dict()

+    infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
+
     with gr.Row():
         load_btn = gr.Button()
         unload_btn = gr.Button()

     info_box = gr.Textbox(show_label=False, interactive=False)

-    elem_dict.update(dict(load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
+    input_elems.update({infer_backend})
+    elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))

     chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
     elem_dict.update(dict(chat_box=chat_box, **chat_elems))
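
The two registries updated here play different roles: input_elems is the set of components whose live values are collected when the user triggers an action, while elem_dict maps stable names to components so LOCALES can retitle them per language. A stripped-down sketch of the same wiring (the gr.Blocks scaffold is added only for self-containment):

import gradio as gr

input_elems, elem_dict = set(), dict()

with gr.Blocks():
    # value is read at submit time; the label is injected later from LOCALES
    infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
    load_btn = gr.Button()

input_elems.update({infer_backend})  # polled for values when an action fires
elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn))  # named for relabeling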

View File

@@ -34,38 +34,38 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))

     with gr.Row():
-        cutoff_len = gr.Slider(value=1024, minimum=4, maximum=16384, step=1)
         learning_rate = gr.Textbox(value="5e-5")
         num_train_epochs = gr.Textbox(value="3.0")
+        max_grad_norm = gr.Textbox(value="1.0")
         max_samples = gr.Textbox(value="100000")
-        compute_type = gr.Radio(choices=["fp16", "bf16", "fp32"], value="fp16")
+        compute_type = gr.Dropdown(choices=["fp16", "bf16", "fp32", "pure_bf16"], value="fp16")

-    input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type})
+    input_elems.update({learning_rate, num_train_epochs, max_grad_norm, max_samples, compute_type})
     elem_dict.update(
         dict(
-            cutoff_len=cutoff_len,
             learning_rate=learning_rate,
             num_train_epochs=num_train_epochs,
+            max_grad_norm=max_grad_norm,
             max_samples=max_samples,
             compute_type=compute_type,
         )
     )

     with gr.Row():
+        cutoff_len = gr.Slider(value=1024, minimum=4, maximum=16384, step=1)
         batch_size = gr.Slider(value=2, minimum=1, maximum=1024, step=1)
         gradient_accumulation_steps = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
-        lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
-        max_grad_norm = gr.Textbox(value="1.0")
         val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
+        lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")

-    input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size})
+    input_elems.update({cutoff_len, batch_size, gradient_accumulation_steps, val_size, lr_scheduler_type})
     elem_dict.update(
         dict(
+            cutoff_len=cutoff_len,
             batch_size=batch_size,
             gradient_accumulation_steps=gradient_accumulation_steps,
-            lr_scheduler_type=lr_scheduler_type,
-            max_grad_norm=max_grad_norm,
             val_size=val_size,
+            lr_scheduler_type=lr_scheduler_type,
         )
     )
@@ -75,12 +75,14 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
             warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
             neftune_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)
+            optim = gr.Textbox(value="adamw_torch")

         with gr.Row():
             resize_vocab = gr.Checkbox()
             sft_packing = gr.Checkbox()
             upcast_layernorm = gr.Checkbox()
             use_llama_pro = gr.Checkbox()
+            shift_attn = gr.Checkbox()

     input_elems.update(
         {
@@ -88,10 +90,12 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             save_steps,
             warmup_steps,
             neftune_alpha,
+            optim,
             resize_vocab,
             sft_packing,
             upcast_layernorm,
             use_llama_pro,
+            shift_attn,
         }
     )
     elem_dict.update(
@@ -101,10 +105,12 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             save_steps=save_steps,
             warmup_steps=warmup_steps,
             neftune_alpha=neftune_alpha,
+            optim=optim,
             resize_vocab=resize_vocab,
             sft_packing=sft_packing,
             upcast_layernorm=upcast_layernorm,
             use_llama_pro=use_llama_pro,
+            shift_attn=shift_attn,
         )
     )
@@ -169,6 +175,26 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn)
     )

+    with gr.Accordion(label="GaLore config", open=False) as galore_tab:
+        with gr.Row():
+            use_galore = gr.Checkbox(scale=1)
+            galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1, scale=2)
+            galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1, scale=2)
+            galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01, scale=2)
+            galore_target = gr.Textbox(value="mlp,attn", scale=3)
+
+    input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
+    elem_dict.update(
+        dict(
+            galore_tab=galore_tab,
+            use_galore=use_galore,
+            galore_rank=galore_rank,
+            galore_update_interval=galore_update_interval,
+            galore_scale=galore_scale,
+            galore_target=galore_target,
+        )
+    )
+
     with gr.Row():
         cmd_preview_btn = gr.Button()
         start_btn = gr.Button()
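
In the new GaLore accordion the scale= arguments apportion row width (1:2:2:2:3 here), keeping the checkbox narrow and the module textbox widest. A runnable, stripped-down version of that block, with labels hard-coded instead of injected from LOCALES:

import gradio as gr

with gr.Blocks() as demo:
    with gr.Accordion(label="GaLore config", open=False):
        with gr.Row():
            # scale= sets relative widths inside the row (1 : 2 : 2 : 2 : 3)
            use_galore = gr.Checkbox(label="Use GaLore", scale=1)
            galore_rank = gr.Slider(label="GaLore rank", value=16, minimum=1, maximum=1024, step=1, scale=2)
            galore_update_interval = gr.Slider(label="Update interval", value=200, minimum=1, maximum=1024, step=1, scale=2)
            galore_scale = gr.Slider(label="GaLore scale", value=0.25, minimum=0, maximum=1, step=0.01, scale=2)
            galore_target = gr.Textbox(label="GaLore modules", value="mlp,attn", scale=3)

demo.launch()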

View File

@@ -245,20 +245,6 @@ LOCALES = {
             "label": "样例",
         },
     },
-    "cutoff_len": {
-        "en": {
-            "label": "Cutoff length",
-            "info": "Max tokens in input sequence.",
-        },
-        "ru": {
-            "label": "Длина обрезки",
-            "info": "Максимальное количество токенов во входной последовательности.",
-        },
-        "zh": {
-            "label": "截断长度",
-            "info": "输入序列分词后的最大长度。",
-        },
-    },
     "learning_rate": {
         "en": {
             "label": "Learning rate",
@@ -287,6 +273,20 @@ LOCALES = {
             "info": "需要执行的训练总轮数。",
         },
     },
+    "max_grad_norm": {
+        "en": {
+            "label": "Maximum gradient norm",
+            "info": "Norm for gradient clipping.",
+        },
+        "ru": {
+            "label": "Максимальная норма градиента",
+            "info": "Норма для обрезки градиента.",
+        },
+        "zh": {
+            "label": "最大梯度范数",
+            "info": "用于梯度裁剪的范数。",
+        },
+    },
     "max_samples": {
         "en": {
             "label": "Max samples",
@@ -304,15 +304,29 @@ LOCALES = {
     "compute_type": {
         "en": {
             "label": "Compute type",
-            "info": "Whether to use mixed precision training (fp16 or bf16).",
+            "info": "Whether to use mixed precision training.",
         },
         "ru": {
             "label": "Тип вычислений",
-            "info": "Использовать ли обучение смешанной точности fp16 или bf16.",
+            "info": "Использовать ли обучение смешанной точности.",
         },
         "zh": {
             "label": "计算类型",
-            "info": "是否使用混合精度训练（fp16 或 bf16）。",
+            "info": "是否使用混合精度训练。",
         },
     },
+    "cutoff_len": {
+        "en": {
+            "label": "Cutoff length",
+            "info": "Max tokens in input sequence.",
+        },
+        "ru": {
+            "label": "Длина обрезки",
+            "info": "Максимальное количество токенов во входной последовательности.",
+        },
+        "zh": {
+            "label": "截断长度",
+            "info": "输入序列分词后的最大长度。",
+        },
+    },
     "batch_size": {
@@ -343,34 +357,6 @@ LOCALES = {
             "info": "梯度累积的步数。",
         },
     },
-    "lr_scheduler_type": {
-        "en": {
-            "label": "LR scheduler",
-            "info": "Name of the learning rate scheduler.",
-        },
-        "ru": {
-            "label": "Планировщик скорости обучения",
-            "info": "Название планировщика скорости обучения.",
-        },
-        "zh": {
-            "label": "学习率调节器",
-            "info": "学习率调度器的名称。",
-        },
-    },
-    "max_grad_norm": {
-        "en": {
-            "label": "Maximum gradient norm",
-            "info": "Norm for gradient clipping.",
-        },
-        "ru": {
-            "label": "Максимальная норма градиента",
-            "info": "Норма для обрезки градиента.",
-        },
-        "zh": {
-            "label": "最大梯度范数",
-            "info": "用于梯度裁剪的范数。",
-        },
-    },
     "val_size": {
         "en": {
             "label": "Val size",
@@ -385,6 +371,20 @@ LOCALES = {
             "info": "验证集占全部样本的百分比。",
         },
     },
+    "lr_scheduler_type": {
+        "en": {
+            "label": "LR scheduler",
+            "info": "Name of the learning rate scheduler.",
+        },
+        "ru": {
+            "label": "Планировщик скорости обучения",
+            "info": "Название планировщика скорости обучения.",
+        },
+        "zh": {
+            "label": "学习率调节器",
+            "info": "学习率调度器的名称。",
+        },
+    },
     "extra_tab": {
         "en": {
             "label": "Extra configurations",
@@ -452,6 +452,20 @@ LOCALES = {
             "info": "嵌入向量所添加的噪声大小。",
         },
     },
+    "optim": {
+        "en": {
+            "label": "Optimizer",
+            "info": "The optimizer to use: adamw_torch, adamw_8bit or adafactor.",
+        },
+        "ru": {
+            "label": "Оптимизатор",
+            "info": "Оптимизатор для использования: adamw_torch, adamw_8bit или adafactor.",
+        },
+        "zh": {
+            "label": "优化器",
+            "info": "使用的优化器：adamw_torch、adamw_8bit 或 adafactor。",
+        },
+    },
     "resize_vocab": {
         "en": {
             "label": "Resize token embeddings",
@@ -508,6 +522,20 @@ LOCALES = {
             "info": "仅训练块扩展后的参数。",
         },
     },
+    "shift_attn": {
+        "en": {
+            "label": "Enable S^2 Attention",
+            "info": "Use shift short attention proposed by LongLoRA.",
+        },
+        "ru": {
+            "label": "Включить S^2 внимание",
+            "info": "Использовать сдвиг внимания на короткие дистанции, предложенный LongLoRA.",
+        },
+        "zh": {
+            "label": "使用 S^2 Attention",
+            "info": "使用 LongLoRA 提出的 shift short attention。",
+        },
+    },
     "freeze_tab": {
         "en": {
             "label": "Freeze tuning configurations",
@@ -569,16 +597,16 @@ LOCALES = {
         },
         "zh": {
             "label": "LoRA 秩",
-            "info": "LoRA 矩阵的秩。",
+            "info": "LoRA 矩阵的秩大小。",
         },
     },
     "lora_alpha": {
         "en": {
-            "label": "LoRA Alpha",
+            "label": "LoRA alpha",
             "info": "Lora scaling coefficient.",
         },
         "ru": {
-            "label": "LoRA Alpha",
+            "label": "LoRA alpha",
             "info": "Коэффициент масштабирования LoRA.",
         },
         "zh": {
@@ -588,7 +616,7 @@ LOCALES = {
     },
     "lora_dropout": {
         "en": {
-            "label": "LoRA Dropout",
+            "label": "LoRA dropout",
             "info": "Dropout ratio of LoRA weights.",
         },
         "ru": {
@@ -603,15 +631,15 @@ LOCALES = {
     "lora_target": {
         "en": {
             "label": "LoRA modules (optional)",
-            "info": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules.",
+            "info": "Name(s) of modules to apply LoRA. Use commas to separate multiple modules.",
         },
         "ru": {
             "label": "Модули LoRA (опционально)",
-            "info": "Имена целевых модулей для применения LoRA. Используйте запятые для разделения нескольких модулей.",
+            "info": "Имена модулей для применения LoRA. Используйте запятые для разделения нескольких модулей.",
         },
         "zh": {
             "label": "LoRA 作用模块（非必填）",
-            "info": "应用 LoRA 的目标模块名称。使用英文逗号分隔多个名称。",
+            "info": "应用 LoRA 的模块名称。使用英文逗号分隔多个名称。",
         },
     },
     "use_rslora": {
@@ -659,7 +687,10 @@ LOCALES = {
     "additional_target": {
         "en": {
             "label": "Additional modules (optional)",
-            "info": "Name(s) of modules apart from LoRA layers to be set as trainable. Use commas to separate multiple modules.",
+            "info": (
+                "Name(s) of modules apart from LoRA layers to be set as trainable. "
+                "Use commas to separate multiple modules."
+            ),
         },
         "ru": {
             "label": "Дополнительные модули (опционально)",
@@ -726,6 +757,87 @@ LOCALES = {
             "info": "PPO 训练中奖励模型的适配器路径。（需要刷新适配器）",
         },
     },
+    "galore_tab": {
+        "en": {
+            "label": "GaLore configurations",
+        },
+        "ru": {
+            "label": "Конфигурации GaLore",
+        },
+        "zh": {
+            "label": "GaLore 参数设置",
+        },
+    },
+    "use_galore": {
+        "en": {
+            "label": "Use GaLore",
+            "info": "Enable gradient low-Rank projection.",
+        },
+        "ru": {
+            "label": "Использовать GaLore",
+            "info": "Включить проекцию градиента на низкоранговое пространство.",
+        },
+        "zh": {
+            "label": "使用 GaLore",
+            "info": "使用梯度低秩投影。",
+        },
+    },
+    "galore_rank": {
+        "en": {
+            "label": "GaLore rank",
+            "info": "The rank of GaLore gradients.",
+        },
+        "ru": {
+            "label": "Ранг GaLore",
+            "info": "Ранг градиентов GaLore.",
+        },
+        "zh": {
+            "label": "GaLore 秩",
+            "info": "GaLore 梯度的秩大小。",
+        },
+    },
+    "galore_update_interval": {
+        "en": {
+            "label": "Update interval",
+            "info": "Number of steps to update the GaLore projection.",
+        },
+        "ru": {
+            "label": "Интервал обновления",
+            "info": "Количество шагов для обновления проекции GaLore.",
+        },
+        "zh": {
+            "label": "更新间隔",
+            "info": "相邻两次投影更新的步数。",
+        },
+    },
+    "galore_scale": {
+        "en": {
+            "label": "GaLore scale",
+            "info": "GaLore scaling coefficient.",
+        },
+        "ru": {
+            "label": "Масштаб GaLore",
+            "info": "Коэффициент масштабирования GaLore.",
+        },
+        "zh": {
+            "label": "GaLore 缩放系数",
+            "info": "GaLore 缩放系数大小。",
+        },
+    },
+    "galore_target": {
+        "en": {
+            "label": "GaLore modules",
+            "info": "Name(s) of modules to apply GaLore. Use commas to separate multiple modules.",
+        },
+        "ru": {
+            "label": "Модули GaLore",
+            "info": "Имена модулей для применения GaLore. Используйте запятые для разделения нескольких модулей.",
+        },
+        "zh": {
+            "label": "GaLore 作用模块",
+            "info": "应用 GaLore 的模块名称。使用英文逗号分隔多个名称。",
+        },
+    },
     "cmd_preview_btn": {
         "en": {
             "value": "Preview command",
@@ -806,6 +918,17 @@ LOCALES = {
             "label": "保存预测结果",
         },
     },
+    "infer_backend": {
+        "en": {
+            "label": "Inference engine",
+        },
+        "ru": {
+            "label": "Инференс движок",
+        },
+        "zh": {
+            "label": "推理引擎",
+        },
+    },
     "load_btn": {
         "en": {
             "value": "Load model",

View File

@@ -129,13 +129,17 @@ class Runner:
             save_steps=get("train.save_steps"),
             warmup_steps=get("train.warmup_steps"),
             neftune_noise_alpha=get("train.neftune_alpha") or None,
+            optim=get("train.optim"),
             resize_vocab=get("train.resize_vocab"),
             sft_packing=get("train.sft_packing"),
             upcast_layernorm=get("train.upcast_layernorm"),
             use_llama_pro=get("train.use_llama_pro"),
+            shift_attn=get("train.shift_attn"),
+            use_galore=get("train.use_galore"),
             output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
             fp16=(get("train.compute_type") == "fp16"),
             bf16=(get("train.compute_type") == "bf16"),
+            pure_bf16=(get("train.compute_type") == "pure_bf16"),
         )

         args["disable_tqdm"] = True
@@ -175,6 +179,12 @@ class Runner:
             args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
             args["load_best_model_at_end"] = args["stage"] not in ["rm", "ppo"]

+        if args["use_galore"]:
+            args["galore_rank"] = get("train.galore_rank")
+            args["galore_update_interval"] = get("train.galore_update_interval")
+            args["galore_scale"] = get("train.galore_scale")
+            args["galore_target"] = get("train.galore_target")
+
         return args

     def _parse_eval_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
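
Since compute_type arrives as a single dropdown value but the training arguments take three mutually exclusive booleans, the mapping reduces to equality tests: fp16 and bf16 select mixed precision, while pure_bf16 requests full-bf16 training instead. A condensed sketch of the same mapping (the dict shape is assumed for illustration):

def precision_flags(compute_type: str) -> dict:
    # exactly one flag is set; "fp32" leaves all three False
    return {
        "fp16": compute_type == "fp16",
        "bf16": compute_type == "bf16",
        "pure_bf16": compute_type == "pure_bf16",
    }

assert precision_flags("pure_bf16") == {"fp16": False, "bf16": False, "pure_bf16": True}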