mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-09-14 17:12:48 +08:00)

commit 7443ac3116 (parent 0a0959facf)
Former-commit-id: 5d956e2a5167201aecdfce2794c25d8a2d84e234

    fix chat engine, update webui
@@ -1,4 +1,5 @@
 import asyncio
+from threading import Thread
 from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
 
 from ..hparams import get_infer_args
@@ -10,21 +11,24 @@ if TYPE_CHECKING:
     from .base_engine import BaseEngine, Response
 
 
+def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
+    asyncio.set_event_loop(loop)
+    loop.run_forever()
+
+
 class ChatModel:
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
-        if model_args.infer_backend == "hf":
+        if model_args.infer_backend == "huggingface":
             self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
         elif model_args.infer_backend == "vllm":
             self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
         else:
             raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
 
-    def _get_event_loop():
-        try:
-            return asyncio.get_running_loop()
-        except RuntimeError:
-            return asyncio.new_event_loop()
+        self._loop = asyncio.new_event_loop()
+        self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
+        self._thread.start()
 
     def chat(
         self,
@@ -33,8 +37,8 @@ class ChatModel:
         tools: Optional[str] = None,
         **input_kwargs,
     ) -> List["Response"]:
-        loop = self._get_event_loop()
-        return loop.run_until_complete(self.achat(messages, system, tools, **input_kwargs))
+        task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, **input_kwargs), self._loop)
+        return task.result()
 
     async def achat(
         self,
@@ -52,11 +56,11 @@ class ChatModel:
         tools: Optional[str] = None,
         **input_kwargs,
    ) -> Generator[str, None, None]:
-        loop = self._get_event_loop()
         generator = self.astream_chat(messages, system, tools, **input_kwargs)
         while True:
             try:
-                yield loop.run_until_complete(generator.__anext__())
+                task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
+                yield task.result()
             except StopAsyncIteration:
                 break
 
@@ -75,8 +79,8 @@ class ChatModel:
         batch_input: List[str],
         **input_kwargs,
     ) -> List[float]:
-        loop = self._get_event_loop()
-        return loop.run_until_complete(self.aget_scores(batch_input, **input_kwargs))
+        task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
+        return task.result()
 
     async def aget_scores(
         self,
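
Note on the pattern: the removed per-call `_get_event_loop()` helper breaks when the caller already sits inside a running event loop (as the Gradio webui does), because `run_until_complete()` cannot be invoked on a running loop. The fix keeps one background loop alive for the lifetime of the ChatModel and submits coroutines to it via `asyncio.run_coroutine_threadsafe`. A minimal self-contained sketch of the same pattern, with a hypothetical `SyncFacade.greet`/`agreet` standing in for `ChatModel.chat`/`achat`:

    import asyncio
    from threading import Thread


    def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
        asyncio.set_event_loop(loop)
        loop.run_forever()


    class SyncFacade:
        """Owns a long-lived background loop, like ChatModel after this commit."""

        def __init__(self) -> None:
            self._loop = asyncio.new_event_loop()
            self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
            self._thread.start()

        async def agreet(self, name: str) -> str:  # stands in for achat()
            await asyncio.sleep(0.01)  # pretend to await an inference engine
            return "hello, " + name

        def greet(self, name: str) -> str:  # stands in for chat()
            # submit to the background loop and block for the result; this is
            # safe even if the caller itself runs inside another event loop
            task = asyncio.run_coroutine_threadsafe(self.agreet(name), self._loop)
            return task.result()


    print(SyncFacade().greet("world"))  # hello, world
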
@@ -147,7 +147,7 @@ class HuggingfaceEngine(BaseEngine):
         )
         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer
-        thread = Thread(target=model.generate, kwargs=gen_kwargs)
+        thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
         thread.start()
 
         def stream():
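
The `daemon=True` added here means a `generate()` thread that is still producing tokens cannot keep the process alive once the consumer abandons the stream (e.g. a client disconnects mid-generation). A rough sketch of the streamer pattern, using an arbitrary tiny checkpoint purely for illustration:

    from threading import Thread

    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")  # tiny test model
    model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

    inputs = tokenizer("Hello", return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(**inputs, max_new_tokens=8, streamer=streamer)

    # daemon=True: a generate() call that is never drained cannot block exit
    thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
    thread.start()

    for new_text in streamer:  # yields decoded chunks as tokens are produced
        print(new_text, end="", flush=True)
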
@@ -155,20 +155,20 @@ class RLHFArguments:
 @dataclass
 class GaloreArguments:
     r"""
-    Arguments pertaining to the GaLore optimization.
+    Arguments pertaining to the GaLore algorithm.
     """
 
     use_galore: bool = field(
         default=False,
-        metadata={"help": "Whether or not to use galore optimizer."},
+        metadata={"help": "Whether or not to use gradient low-Rank projection."},
     )
     galore_target: str = field(
         default="mlp,attn",
-        metadata={"help": "Name(s) of modules to apply GaLore."},
+        metadata={"help": "Name(s) of modules to apply GaLore. Use commas to separate multiple modules."},
     )
     galore_rank: int = field(
         default=16,
-        metadata={"help": "GaLore rank."},
+        metadata={"help": "The rank of GaLore gradients."},
     )
     galore_update_interval: int = field(
         default=200,
@@ -176,7 +176,7 @@ class GaloreArguments:
     )
     galore_scale: float = field(
         default=0.25,
-        metadata={"help": "GaLore scale."},
+        metadata={"help": "GaLore scaling coefficient."},
     )
     galore_proj_type: Literal["std", "reverse_std", "right", "left", "full"] = field(
         default="std",
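
These dataclass fields surface as CLI flags through transformers' `HfArgumentParser`. A small sketch with a trimmed stand-in for `GaloreArguments` (not the full class from this diff):

    from dataclasses import dataclass, field

    from transformers import HfArgumentParser


    @dataclass
    class GaloreArguments:  # trimmed stand-in, not the full class
        use_galore: bool = field(
            default=False,
            metadata={"help": "Whether or not to use gradient low-Rank projection."},
        )
        galore_rank: int = field(default=16, metadata={"help": "The rank of GaLore gradients."})
        galore_target: str = field(
            default="mlp,attn",
            metadata={"help": "Name(s) of modules to apply GaLore. Use commas to separate multiple modules."},
        )


    parser = HfArgumentParser(GaloreArguments)
    (galore_args,) = parser.parse_args_into_dataclasses(["--use_galore", "--galore_rank", "8"])
    print(galore_args.use_galore, galore_args.galore_rank)  # True 8
    print(galore_args.galore_target.split(","))  # ['mlp', 'attn']
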
@@ -81,8 +81,8 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
     )
-    infer_backend: Literal["hf", "vllm"] = field(
-        default="hf",
+    infer_backend: Literal["huggingface", "vllm"] = field(
+        default="huggingface",
         metadata={"help": "Backend engine used at inference."},
     )
     vllm_maxlen: int = field(

@@ -75,6 +75,7 @@ class WebChatModel(ChatModel):
             flash_attn=(get("top.booster") == "flash_attn"),
             use_unsloth=(get("top.booster") == "unsloth"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
+            infer_backend=get("infer.infer_backend"),
         )
         super().__init__(args)
 

@@ -15,12 +15,15 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
     input_elems = engine.manager.get_base_elems()
     elem_dict = dict()
 
+    infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
     with gr.Row():
         load_btn = gr.Button()
         unload_btn = gr.Button()
 
     info_box = gr.Textbox(show_label=False, interactive=False)
-    elem_dict.update(dict(load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
+
+    input_elems.update({infer_backend})
+    elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
 
     chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
     elem_dict.update(dict(chat_box=chat_box, **chat_elems))
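
For orientation, every webui tab follows this shape: `input_elems` is a set of components whose live values are read when a button is clicked, and `elem_dict` maps stable names to components so the `LOCALES` table further below can supply per-language labels. A toy sketch with made-up components:

    import gradio as gr

    input_elems = set()
    elem_dict = dict()

    with gr.Blocks() as demo:
        infer_backend = gr.Dropdown(choices=["huggingface", "vllm"], value="huggingface")
        load_btn = gr.Button("Load model")
        info_box = gr.Textbox(show_label=False, interactive=False)

        input_elems.update({infer_backend})
        elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, info_box=info_box))

        # the click handler receives the tracked components' current values
        load_btn.click(lambda backend: "Loading with {}...".format(backend), [infer_backend], [info_box])

    demo.launch()
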
@@ -34,38 +34,38 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
     elem_dict.update(dict(training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
 
     with gr.Row():
-        cutoff_len = gr.Slider(value=1024, minimum=4, maximum=16384, step=1)
         learning_rate = gr.Textbox(value="5e-5")
         num_train_epochs = gr.Textbox(value="3.0")
+        max_grad_norm = gr.Textbox(value="1.0")
         max_samples = gr.Textbox(value="100000")
-        compute_type = gr.Radio(choices=["fp16", "bf16", "fp32"], value="fp16")
+        compute_type = gr.Dropdown(choices=["fp16", "bf16", "fp32", "pure_bf16"], value="fp16")
 
-    input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type})
+    input_elems.update({learning_rate, num_train_epochs, max_grad_norm, max_samples, compute_type})
     elem_dict.update(
         dict(
-            cutoff_len=cutoff_len,
             learning_rate=learning_rate,
             num_train_epochs=num_train_epochs,
+            max_grad_norm=max_grad_norm,
             max_samples=max_samples,
             compute_type=compute_type,
         )
     )
 
     with gr.Row():
+        cutoff_len = gr.Slider(value=1024, minimum=4, maximum=16384, step=1)
         batch_size = gr.Slider(value=2, minimum=1, maximum=1024, step=1)
         gradient_accumulation_steps = gr.Slider(value=8, minimum=1, maximum=1024, step=1)
-        lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
-        max_grad_norm = gr.Textbox(value="1.0")
         val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
+        lr_scheduler_type = gr.Dropdown(choices=[scheduler.value for scheduler in SchedulerType], value="cosine")
 
-    input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size})
+    input_elems.update({cutoff_len, batch_size, gradient_accumulation_steps, val_size, lr_scheduler_type})
     elem_dict.update(
         dict(
+            cutoff_len=cutoff_len,
             batch_size=batch_size,
             gradient_accumulation_steps=gradient_accumulation_steps,
-            lr_scheduler_type=lr_scheduler_type,
-            max_grad_norm=max_grad_norm,
             val_size=val_size,
+            lr_scheduler_type=lr_scheduler_type,
         )
     )
 
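
The LR-scheduler dropdown above is populated from transformers' `SchedulerType` enum; a quick way to inspect the available choices (the printed list varies by transformers version):

    from transformers.trainer_utils import SchedulerType

    print([scheduler.value for scheduler in SchedulerType])
    # e.g. ['linear', 'cosine', 'cosine_with_restarts', 'polynomial', 'constant', ...]
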
@@ -75,12 +75,14 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
         warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
         neftune_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)
+        optim = gr.Textbox(value="adamw_torch")
 
     with gr.Row():
         resize_vocab = gr.Checkbox()
         sft_packing = gr.Checkbox()
         upcast_layernorm = gr.Checkbox()
         use_llama_pro = gr.Checkbox()
+        shift_attn = gr.Checkbox()
 
     input_elems.update(
         {
@@ -88,10 +90,12 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             save_steps,
             warmup_steps,
             neftune_alpha,
+            optim,
             resize_vocab,
             sft_packing,
             upcast_layernorm,
             use_llama_pro,
+            shift_attn,
         }
     )
     elem_dict.update(
@@ -101,10 +105,12 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             save_steps=save_steps,
             warmup_steps=warmup_steps,
             neftune_alpha=neftune_alpha,
+            optim=optim,
             resize_vocab=resize_vocab,
             sft_packing=sft_packing,
             upcast_layernorm=upcast_layernorm,
             use_llama_pro=use_llama_pro,
+            shift_attn=shift_attn,
         )
     )
 
@@ -169,6 +175,26 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model, refresh_btn=refresh_btn)
     )
 
+    with gr.Accordion(label="GaLore config", open=False) as galore_tab:
+        with gr.Row():
+            use_galore = gr.Checkbox(scale=1)
+            galore_rank = gr.Slider(value=16, minimum=1, maximum=1024, step=1, scale=2)
+            galore_update_interval = gr.Slider(value=200, minimum=1, maximum=1024, step=1, scale=2)
+            galore_scale = gr.Slider(value=0.25, minimum=0, maximum=1, step=0.01, scale=2)
+            galore_target = gr.Textbox(value="mlp,attn", scale=3)
+
+    input_elems.update({use_galore, galore_rank, galore_update_interval, galore_scale, galore_target})
+    elem_dict.update(
+        dict(
+            galore_tab=galore_tab,
+            use_galore=use_galore,
+            galore_rank=galore_rank,
+            galore_update_interval=galore_update_interval,
+            galore_scale=galore_scale,
+            galore_target=galore_target,
+        )
+    )
+
     with gr.Row():
         cmd_preview_btn = gr.Button()
         start_btn = gr.Button()

@@ -245,20 +245,6 @@ LOCALES = {
             "label": "样例",
         },
     },
-    "cutoff_len": {
-        "en": {
-            "label": "Cutoff length",
-            "info": "Max tokens in input sequence.",
-        },
-        "ru": {
-            "label": "Длина обрезки",
-            "info": "Максимальное количество токенов во входной последовательности.",
-        },
-        "zh": {
-            "label": "截断长度",
-            "info": "输入序列分词后的最大长度。",
-        },
-    },
     "learning_rate": {
         "en": {
             "label": "Learning rate",
@@ -287,6 +273,20 @@ LOCALES = {
             "info": "需要执行的训练总轮数。",
         },
     },
+    "max_grad_norm": {
+        "en": {
+            "label": "Maximum gradient norm",
+            "info": "Norm for gradient clipping.",
+        },
+        "ru": {
+            "label": "Максимальная норма градиента",
+            "info": "Норма для обрезки градиента.",
+        },
+        "zh": {
+            "label": "最大梯度范数",
+            "info": "用于梯度裁剪的范数。",
+        },
+    },
     "max_samples": {
         "en": {
             "label": "Max samples",
@@ -304,15 +304,29 @@ LOCALES = {
     "compute_type": {
         "en": {
             "label": "Compute type",
-            "info": "Whether to use mixed precision training (fp16 or bf16).",
+            "info": "Whether to use mixed precision training.",
         },
         "ru": {
             "label": "Тип вычислений",
-            "info": "Использовать ли обучение смешанной точности fp16 или bf16.",
+            "info": "Использовать ли обучение смешанной точности.",
         },
         "zh": {
             "label": "计算类型",
-            "info": "是否使用混合精度训练(fp16 或 bf16)。",
+            "info": "是否使用混合精度训练。",
+        },
+    },
+    "cutoff_len": {
+        "en": {
+            "label": "Cutoff length",
+            "info": "Max tokens in input sequence.",
+        },
+        "ru": {
+            "label": "Длина обрезки",
+            "info": "Максимальное количество токенов во входной последовательности.",
+        },
+        "zh": {
+            "label": "截断长度",
+            "info": "输入序列分词后的最大长度。",
         },
     },
     "batch_size": {
@@ -343,34 +357,6 @@ LOCALES = {
             "info": "梯度累积的步数。",
         },
     },
-    "lr_scheduler_type": {
-        "en": {
-            "label": "LR scheduler",
-            "info": "Name of the learning rate scheduler.",
-        },
-        "ru": {
-            "label": "Планировщик скорости обучения",
-            "info": "Название планировщика скорости обучения.",
-        },
-        "zh": {
-            "label": "学习率调节器",
-            "info": "学习率调度器的名称。",
-        },
-    },
-    "max_grad_norm": {
-        "en": {
-            "label": "Maximum gradient norm",
-            "info": "Norm for gradient clipping.",
-        },
-        "ru": {
-            "label": "Максимальная норма градиента",
-            "info": "Норма для обрезки градиента.",
-        },
-        "zh": {
-            "label": "最大梯度范数",
-            "info": "用于梯度裁剪的范数。",
-        },
-    },
     "val_size": {
         "en": {
             "label": "Val size",
@@ -385,6 +371,20 @@ LOCALES = {
             "info": "验证集占全部样本的百分比。",
         },
     },
+    "lr_scheduler_type": {
+        "en": {
+            "label": "LR scheduler",
+            "info": "Name of the learning rate scheduler.",
+        },
+        "ru": {
+            "label": "Планировщик скорости обучения",
+            "info": "Название планировщика скорости обучения.",
+        },
+        "zh": {
+            "label": "学习率调节器",
+            "info": "学习率调度器的名称。",
+        },
+    },
     "extra_tab": {
         "en": {
             "label": "Extra configurations",
@@ -452,6 +452,20 @@ LOCALES = {
             "info": "嵌入向量所添加的噪声大小。",
         },
     },
+    "optim": {
+        "en": {
+            "label": "Optimizer",
+            "info": "The optimizer to use: adamw_torch, adamw_8bit or adafactor.",
+        },
+        "ru": {
+            "label": "Оптимизатор",
+            "info": "Оптимизатор для использования: adamw_torch, adamw_8bit или adafactor.",
+        },
+        "zh": {
+            "label": "优化器",
+            "info": "使用的优化器:adamw_torch、adamw_8bit 或 adafactor。",
+        },
+    },
     "resize_vocab": {
         "en": {
             "label": "Resize token embeddings",
@@ -508,6 +522,20 @@ LOCALES = {
             "info": "仅训练块扩展后的参数。",
         },
     },
+    "shift_attn": {
+        "en": {
+            "label": "Enable S^2 Attention",
+            "info": "Use shift short attention proposed by LongLoRA.",
+        },
+        "ru": {
+            "label": "Включить S^2 внимание",
+            "info": "Использовать сдвиг внимания на короткие дистанции, предложенный LongLoRA.",
+        },
+        "zh": {
+            "label": "使用 S^2 Attention",
+            "info": "使用 LongLoRA 提出的 shift short attention。",
+        },
+    },
     "freeze_tab": {
         "en": {
             "label": "Freeze tuning configurations",
@@ -569,16 +597,16 @@ LOCALES = {
         },
         "zh": {
             "label": "LoRA 秩",
-            "info": "LoRA 矩阵的秩。",
+            "info": "LoRA 矩阵的秩大小。",
         },
     },
     "lora_alpha": {
         "en": {
-            "label": "LoRA Alpha",
+            "label": "LoRA alpha",
             "info": "Lora scaling coefficient.",
         },
         "ru": {
-            "label": "LoRA Alpha",
+            "label": "LoRA alpha",
             "info": "Коэффициент масштабирования LoRA.",
         },
         "zh": {
@@ -588,7 +616,7 @@ LOCALES = {
     },
     "lora_dropout": {
         "en": {
-            "label": "LoRA Dropout",
+            "label": "LoRA dropout",
             "info": "Dropout ratio of LoRA weights.",
         },
         "ru": {
@@ -603,15 +631,15 @@ LOCALES = {
     "lora_target": {
         "en": {
             "label": "LoRA modules (optional)",
-            "info": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules.",
+            "info": "Name(s) of modules to apply LoRA. Use commas to separate multiple modules.",
         },
         "ru": {
             "label": "Модули LoRA (опционально)",
-            "info": "Имена целевых модулей для применения LoRA. Используйте запятые для разделения нескольких модулей.",
+            "info": "Имена модулей для применения LoRA. Используйте запятые для разделения нескольких модулей.",
         },
         "zh": {
             "label": "LoRA 作用模块(非必填)",
-            "info": "应用 LoRA 的目标模块名称。使用英文逗号分隔多个名称。",
+            "info": "应用 LoRA 的模块名称。使用英文逗号分隔多个名称。",
         },
     },
     "use_rslora": {
@@ -659,7 +687,10 @@ LOCALES = {
     "additional_target": {
         "en": {
             "label": "Additional modules (optional)",
-            "info": "Name(s) of modules apart from LoRA layers to be set as trainable. Use commas to separate multiple modules.",
+            "info": (
+                "Name(s) of modules apart from LoRA layers to be set as trainable. "
+                "Use commas to separate multiple modules."
+            ),
         },
         "ru": {
             "label": "Дополнительные модули (опционально)",
@@ -726,6 +757,87 @@ LOCALES = {
             "info": "PPO 训练中奖励模型的适配器路径。(需要刷新适配器)",
         },
     },
+    "galore_tab": {
+        "en": {
+            "label": "GaLore configurations",
+        },
+        "ru": {
+            "label": "Конфигурации GaLore",
+        },
+        "zh": {
+            "label": "GaLore 参数设置",
+        },
+    },
+    "use_galore": {
+        "en": {
+            "label": "Use GaLore",
+            "info": "Enable gradient low-Rank projection.",
+        },
+        "ru": {
+            "label": "Использовать GaLore",
+            "info": "Включить проекцию градиента на низкоранговое пространство.",
+        },
+        "zh": {
+            "label": "使用 GaLore",
+            "info": "使用梯度低秩投影。",
+        },
+    },
+    "galore_rank": {
+        "en": {
+            "label": "GaLore rank",
+            "info": "The rank of GaLore gradients.",
+        },
+        "ru": {
+            "label": "Ранг GaLore",
+            "info": "Ранг градиентов GaLore.",
+        },
+        "zh": {
+            "label": "GaLore 秩",
+            "info": "GaLore 梯度的秩大小。",
+        },
+    },
+    "galore_update_interval": {
+        "en": {
+            "label": "Update interval",
+            "info": "Number of steps to update the GaLore projection.",
+        },
+        "ru": {
+            "label": "Интервал обновления",
+            "info": "Количество шагов для обновления проекции GaLore.",
+        },
+        "zh": {
+            "label": "更新间隔",
+            "info": "相邻两次投影更新的步数。",
+        },
+    },
+    "galore_scale": {
+        "en": {
+            "label": "GaLore scale",
+            "info": "GaLore scaling coefficient.",
+        },
+        "ru": {
+            "label": "Масштаб GaLore",
+            "info": "Коэффициент масштабирования GaLore.",
+        },
+        "zh": {
+            "label": "GaLore 缩放系数",
+            "info": "GaLore 缩放系数大小。",
+        },
+    },
+    "galore_target": {
+        "en": {
+            "label": "GaLore modules",
+            "info": "Name(s) of modules to apply GaLore. Use commas to separate multiple modules.",
+        },
+        "ru": {
+            "label": "Модули GaLore",
+            "info": "Имена модулей для применения GaLore. Используйте запятые для разделения нескольких модулей.",
+        },
+        "zh": {
+            "label": "GaLore 作用模块",
+            "info": "应用 GaLore 的模块名称。使用英文逗号分隔多个名称。",
+        },
+    },
     "cmd_preview_btn": {
         "en": {
             "value": "Preview command",
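
Each LOCALES entry keys on a component name from elem_dict, then on a language code. A minimal sketch of how such a table can be looked up (the `get_text` helper is illustrative, not from this codebase):

    LOCALES = {
        "infer_backend": {
            "en": {"label": "Inference engine"},
            "ru": {"label": "Инференс движок"},
            "zh": {"label": "推理引擎"},
        },
    }


    def get_text(elem_name: str, lang: str, key: str = "label") -> str:
        return LOCALES[elem_name][lang][key]


    print(get_text("infer_backend", "zh"))  # 推理引擎
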
@@ -806,6 +918,17 @@ LOCALES = {
             "label": "保存预测结果",
         },
     },
+    "infer_backend": {
+        "en": {
+            "label": "Inference engine",
+        },
+        "ru": {
+            "label": "Инференс движок",
+        },
+        "zh": {
+            "label": "推理引擎",
+        },
+    },
     "load_btn": {
         "en": {
             "value": "Load model",

@@ -129,13 +129,17 @@ class Runner:
             save_steps=get("train.save_steps"),
             warmup_steps=get("train.warmup_steps"),
             neftune_noise_alpha=get("train.neftune_alpha") or None,
+            optim=get("train.optim"),
             resize_vocab=get("train.resize_vocab"),
             sft_packing=get("train.sft_packing"),
             upcast_layernorm=get("train.upcast_layernorm"),
             use_llama_pro=get("train.use_llama_pro"),
+            shift_attn=get("train.shift_attn"),
+            use_galore=get("train.use_galore"),
             output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")),
             fp16=(get("train.compute_type") == "fp16"),
             bf16=(get("train.compute_type") == "bf16"),
+            pure_bf16=(get("train.compute_type") == "pure_bf16"),
         )
         args["disable_tqdm"] = True
 
@@ -175,6 +179,12 @@ class Runner:
         args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]
         args["load_best_model_at_end"] = args["stage"] not in ["rm", "ppo"]
 
+        if args["use_galore"]:
+            args["galore_rank"] = get("train.galore_rank")
+            args["galore_update_interval"] = get("train.galore_update_interval")
+            args["galore_scale"] = get("train.galore_scale")
+            args["galore_target"] = get("train.galore_target")
+
         return args
 
     def _parse_eval_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
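
The `use_galore` gate above shows how the runner assembles a flat argument dict: optional groups are appended only when their toggle is on. A toy sketch of the same gating pattern (names are illustrative):

    def build_args(form: dict) -> dict:
        args = dict(optim=form["optim"], use_galore=form["use_galore"])
        if args["use_galore"]:  # optional group, added only when the toggle is on
            args["galore_rank"] = form["galore_rank"]
            args["galore_scale"] = form["galore_scale"]
        return args


    print(build_args({"optim": "adamw_torch", "use_galore": True, "galore_rank": 16, "galore_scale": 0.25}))
    # {'optim': 'adamw_torch', 'use_galore': True, 'galore_rank': 16, 'galore_scale': 0.25}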