allow non-packing pretraining

Former-commit-id: bdb496644ce2c18806fc4fdae1fedcb3e5b5f808
This commit is contained in:
hiyouga 2024-03-09 22:21:46 +08:00
parent 1173441661
commit 868444e124
22 changed files with 64 additions and 67 deletions

View File

@@ -59,7 +59,7 @@ class ChatCompletionMessage(BaseModel):
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[ChatMessage]
-    tools: Optional[list] = []
+    tools: list = []
     do_sample: bool = True
     temperature: Optional[float] = None
     top_p: Optional[float] = None

View File

@@ -21,8 +21,11 @@ logger = get_logger(__name__)
 def preprocess_pretrain_dataset(
     examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
 ) -> Dict[str, List[List[int]]]:
-    # build grouped texts with format `X1 X2 X3 ...`
+    # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
     text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
+    if not data_args.packing:
+        return tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
     tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
     concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
     total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
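
The early return above keeps one (possibly truncated) sample per document when packing is disabled, while the surviving branch still concatenates every document and slices the result into fixed-size blocks. A minimal sketch of the two behaviours, using a toy whitespace tokenizer instead of a real PreTrainedTokenizer and an illustrative cutoff_len of 4:

from itertools import chain
from typing import Dict, List

def toy_tokenize(text: str) -> List[int]:
    # stand-in for tokenizer(...)["input_ids"]; the real code uses a PreTrainedTokenizer
    return [hash(word) % 1000 for word in text.split()]

def preprocess(texts: List[str], cutoff_len: int, packing: bool) -> Dict[str, List[List[int]]]:
    if not packing:
        # one sample per document, truncated to cutoff_len (mirrors the new early return)
        return {"input_ids": [toy_tokenize(t)[:cutoff_len] for t in texts]}
    # packing: concatenate every document, then slice into cutoff_len-sized blocks
    concatenated = list(chain.from_iterable(toy_tokenize(t) for t in texts))
    total_length = (len(concatenated) // cutoff_len) * cutoff_len  # drop the trailing remainder
    return {"input_ids": [concatenated[i : i + cutoff_len] for i in range(0, total_length, cutoff_len)]}

texts = ["a b c", "d e f g h"]
print(preprocess(texts, cutoff_len=4, packing=False))  # 2 samples: lengths 3 and 4 (second truncated)
print(preprocess(texts, cutoff_len=4, packing=True))   # 2 packed blocks of exactly 4 tokens each

With packing disabled, short documents stay short and padding is left to the data collator; with packing enabled, every block is exactly cutoff_len tokens and any trailing remainder is dropped.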
@@ -245,7 +248,7 @@ def get_preprocess_and_print_func(
         preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
     elif stage == "sft" and not training_args.predict_with_generate:
-        if data_args.sft_packing:
+        if data_args.packing:
             preprocess_func = partial(
                 preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
             )

View File

@@ -36,8 +36,8 @@ class Template:
         messages: List[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        cutoff_len: Optional[int] = 1_000_000,
-        reserved_label_len: Optional[int] = 1,
+        cutoff_len: int = 1_000_000,
+        reserved_label_len: int = 1,
     ) -> Tuple[List[int], List[int]]:
         r"""
         Returns a single pair of token ids representing prompt and response respectively.
@@ -56,8 +56,8 @@ class Template:
         messages: List[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        cutoff_len: Optional[int] = 1_000_000,
-        reserved_label_len: Optional[int] = 1,
+        cutoff_len: int = 1_000_000,
+        reserved_label_len: int = 1,
     ) -> Sequence[Tuple[List[int], List[int]]]:
         r"""
         Returns multiple pairs of token ids representing prompts and responses respectively.
@@ -207,11 +207,11 @@ def _register_template(
     format_observation: Optional["Formatter"] = None,
     format_tools: Optional["Formatter"] = None,
     format_separator: Optional["Formatter"] = None,
-    default_system: Optional[str] = "",
-    stop_words: Optional[List[str]] = [],
-    efficient_eos: Optional[bool] = False,
-    replace_eos: Optional[bool] = False,
-    force_system: Optional[bool] = False,
+    default_system: str = "",
+    stop_words: List[str] = [],
+    efficient_eos: bool = False,
+    replace_eos: bool = False,
+    force_system: bool = False,
 ) -> None:
     r"""
     Registers a chat template.
@@ -279,9 +279,7 @@ def _jinja_escape(content: str) -> str:
     return content.replace("\n", r"\n").replace("'", r"\'")
-def _convert_slots_to_jinja(
-    slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: Optional[str] = "content"
-) -> str:
+def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
     slot_items = []
     for slot in slots:
         if isinstance(slot, str):

View File

@@ -1,7 +1,7 @@
 import json
 import math
 import os
-from typing import List, Optional
+from typing import List
 from transformers.trainer import TRAINER_STATE_NAME
@@ -30,7 +30,7 @@ def smooth(scalars: List[float]) -> List[float]:
     return smoothed
-def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
+def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
     with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
         data = json.load(f)

View File

@@ -78,9 +78,11 @@ class DataArguments:
         default=0.0,
         metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."},
     )
-    sft_packing: bool = field(
-        default=False,
-        metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."},
+    packing: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
+        },
     )
     cache_path: Optional[str] = field(
         default=None,

View File

@@ -135,7 +135,6 @@ class ModelArguments:
     )
     def __post_init__(self):
-        self.aqlm_optimization = None
         self.compute_dtype = None
         self.device_map = None
         self.model_max_length = None

View File

@@ -230,7 +230,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         model_args.compute_dtype = torch.float16
     model_args.model_max_length = data_args.cutoff_len
-    model_args.aqlm_optimization = not training_args.predict_with_generate
+    data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
     # Log on each process the small summary:
     logger.info(
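
Together with the aqlm_optimization bookkeeping removed below, this line is the heart of the change: an explicit packing value wins, otherwise packing defaults to on only for the "pt" stage. A minimal sketch of that resolution rule (stage names as used by the finetuning arguments):

from typing import Optional

def resolve_packing(packing: Optional[bool], stage: str) -> bool:
    # mirrors: data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt"
    return packing if packing is not None else stage == "pt"

assert resolve_packing(None, "pt") is True     # pre-training packs by default
assert resolve_packing(None, "sft") is False   # SFT stays unpacked unless requested
assert resolve_packing(False, "pt") is False   # the point of this commit: pre-training without packing
assert resolve_packing(True, "sft") is True    # explicit opt-in for packed SFT still works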
@@ -253,7 +253,6 @@ def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
     _set_transformers_logging()
     _verify_model_args(model_args, finetuning_args)
-    model_args.aqlm_optimization = False
     model_args.device_map = "auto"
     if data_args.template is None:
@@ -267,7 +266,6 @@ def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
     _set_transformers_logging()
     _verify_model_args(model_args, finetuning_args)
-    model_args.aqlm_optimization = True
     model_args.device_map = "auto"
     if data_args.template is None:

View File

@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Tuple
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from trl import AutoModelForCausalLMWithValueHead
@@ -52,8 +52,8 @@ def load_model(
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
-    is_trainable: Optional[bool] = False,
-    add_valuehead: Optional[bool] = False,
+    is_trainable: bool = False,
+    add_valuehead: bool = False,
 ) -> "PreTrainedModel":
     r"""
     Loads pretrained model. Must after load_tokenizer.
@@ -137,8 +137,8 @@ def load_model(
 def load_model_and_tokenizer(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
-    is_trainable: Optional[bool] = False,
-    add_valuehead: Optional[bool] = False,
+    is_trainable: bool = False,
+    add_valuehead: bool = False,
 ) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
     r"""
     Loads pretrained model and tokenizer.

View File

@@ -3,7 +3,7 @@ import os
 import random
 from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple
 import torch
 from datasets import load_dataset
@@ -219,7 +219,7 @@ def _configure_quantization(
 def _prepare_model_for_training(
-    model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: Optional[str] = "lm_head"
+    model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
 ) -> None:
     r"""
     Includes:

View File

@@ -22,7 +22,7 @@ class CustomDPOTrainer(DPOTrainer):
         ftx_gamma: float,
         model: Union["PreTrainedModel", torch.nn.Module],
         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
-        disable_dropout: Optional[bool] = True,
+        disable_dropout: bool = True,
         **kwargs,
     ):
         if disable_dropout:
@@ -95,7 +95,7 @@ class CustomDPOTrainer(DPOTrainer):
         self,
         model: "PreTrainedModel",
         batch: Dict[str, torch.Tensor],
-        train_eval: Optional[Literal["train", "eval"]] = "train",
+        train_eval: Literal["train", "eval"] = "train",
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         r"""
         Computes the DPO loss and other metrics for the given batch of inputs for train or test.

View File

@@ -292,7 +292,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         queries: torch.Tensor,
         responses: torch.Tensor,
         model_inputs: dict,
-        return_logits: Optional[bool] = False,
+        return_logits: bool = False,
         response_masks: Optional[torch.Tensor] = None,
     ):
         r"""

View File

@@ -1,6 +1,6 @@
 import json
 import os
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Tuple, Union
 import torch
 from transformers import Trainer
@@ -26,7 +26,7 @@ class PairwiseTrainer(Trainer):
         self.can_return_loss = True  # override property to return eval_loss
     def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: Optional[bool] = False
+        self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
         r"""
         Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.

View File

@@ -46,7 +46,7 @@ def create_modelcard_and_push(
 def create_ref_model(
-    model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: Optional[bool] = False
+    model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False
 ) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
     r"""
     Creates reference model for PPO/DPO training. Evaluation mode is not supported.

View File

@@ -18,9 +18,7 @@ if TYPE_CHECKING:
 class WebChatModel(ChatModel):
-    def __init__(
-        self, manager: "Manager", demo_mode: Optional[bool] = False, lazy_init: Optional[bool] = True
-    ) -> None:
+    def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
         self.manager = manager
         self.demo_mode = demo_mode
         self.engine: Optional["BaseEngine"] = None

View File

@@ -104,10 +104,12 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
         return {}
-def list_dataset(
-    dataset_dir: Optional[str] = None, training_stage: Optional[str] = list(TRAINING_STAGES.keys())[0]
-) -> Dict[str, Any]:
+def list_dataset(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Dict[str, Any]:
     dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
     ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"]
     datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
     return gr.update(value=[], choices=datasets)
+def autoset_packing(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Dict[str, Any]:
+    return gr.update(value=(TRAINING_STAGES[training_stage] == "pt"))
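
For the WebUI, the new autoset_packing helper emits a Gradio update that ticks the packing checkbox exactly when the selected training stage maps to "pt". A rough, self-contained sketch of that behaviour, with TRAINING_STAGES abbreviated to two illustrative entries and a plain dict standing in for the object gr.update produces:

from typing import Any, Dict

TRAINING_STAGES = {"Supervised Fine-Tuning": "sft", "Pre-Training": "pt"}  # abridged; the real mapping has more stages

def autoset_packing_demo(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> Dict[str, Any]:
    # plain dict standing in for gr.update(value=...): only the checkbox value changes
    return {"value": TRAINING_STAGES[training_stage] == "pt"}

assert autoset_packing_demo("Pre-Training")["value"] is True
assert autoset_packing_demo("Supervised Fine-Tuning")["value"] is False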

View File

@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Tuple
 import gradio as gr
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
 def create_chat_box(
-    engine: "Engine", visible: Optional[bool] = False
+    engine: "Engine", visible: bool = False
 ) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
     with gr.Box(visible=visible) as chat_box:
         chatbot = gr.Chatbot()

View File

@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, Dict, Tuple
 import gradio as gr
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
     from gradio.components import Component
-def create_top() -> Dict[str, "Component"]:
+def create_top() -> Tuple["gr.Dropdown", Dict[str, "Component"]]:
     available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
     with gr.Row():
@@ -44,7 +44,7 @@ def create_top() -> Dict[str, "Component"]:
     refresh_btn.click(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False)
-    return dict(
+    return lang, dict(
         lang=lang,
         model_name=model_name,
         model_path=model_path,

View File

@@ -4,7 +4,7 @@ import gradio as gr
 from transformers.trainer_utils import SchedulerType
 from ...extras.constants import TRAINING_STAGES
-from ..common import DEFAULT_DATA_DIR, list_adapters, list_dataset
+from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
 from ..components.data import create_preview_box
 from ..utils import gen_plot
@@ -78,7 +78,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
         with gr.Row():
             resize_vocab = gr.Checkbox()
-            sft_packing = gr.Checkbox()
+            packing = gr.Checkbox()
             upcast_layernorm = gr.Checkbox()
             use_llama_pro = gr.Checkbox()
             shift_attn = gr.Checkbox()
@@ -91,7 +91,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             neftune_alpha,
             optim,
             resize_vocab,
-            sft_packing,
+            packing,
             upcast_layernorm,
             use_llama_pro,
             shift_attn,
@@ -106,7 +106,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             neftune_alpha=neftune_alpha,
             optim=optim,
             resize_vocab=resize_vocab,
-            sft_packing=sft_packing,
+            packing=packing,
             upcast_layernorm=upcast_layernorm,
             use_llama_pro=use_llama_pro,
             shift_attn=shift_attn,
@@ -166,7 +166,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             [engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
             [reward_model],
             queue=False,
-        )
+        ).then(autoset_packing, [training_stage], [packing], queue=False)
         input_elems.update({dpo_beta, dpo_ftx, reward_model})
         elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, dpo_ftx=dpo_ftx, reward_model=reward_model))
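
The last hunk keeps the packing checkbox in sync with the selected stage by chaining autoset_packing onto the callback that already refreshes the reward model list. A simplified sketch of this .then()-chaining pattern, assuming the gradio 3.38 pinned by the project (component labels, choices, and the reward-model refresh body here are illustrative, not the exact ones in the tab):

import gradio as gr

TRAINING_STAGES = {"Supervised Fine-Tuning": "sft", "Pre-Training": "pt"}  # abridged

def autoset_packing(training_stage: str) -> dict:
    return gr.update(value=(TRAINING_STAGES[training_stage] == "pt"))

with gr.Blocks() as demo:
    training_stage = gr.Dropdown(choices=list(TRAINING_STAGES.keys()), value="Supervised Fine-Tuning")
    reward_model = gr.Dropdown(choices=[], label="Reward model")
    packing = gr.Checkbox(label="Pack sequences")
    # refresh the reward model list first, then auto-toggle packing, as the chained call above does
    training_stage.change(
        lambda stage: gr.update(choices=[]), [training_stage], [reward_model], queue=False
    ).then(autoset_packing, [training_stage], [packing], queue=False)

# demo.launch()  # uncomment to try the toggle locally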

View File

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Generator, Optional
+from typing import Any, Dict, Generator
 import gradio as gr
 from gradio.components import Component  # cannot use TYPE_CHECKING here
@@ -12,7 +12,7 @@ from .utils import get_time
 class Engine:
-    def __init__(self, demo_mode: Optional[bool] = False, pure_chat: Optional[bool] = False) -> None:
+    def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None:
         self.demo_mode = demo_mode
         self.pure_chat = pure_chat
         self.manager = Manager()

View File

@@ -1,5 +1,3 @@
-from typing import Optional
 import gradio as gr
 from transformers.utils.versions import require_version
@@ -19,7 +17,7 @@ from .engine import Engine
 require_version("gradio>=3.38.0,<4.0.0", 'To fix: pip install "gradio>=3.38.0,<4.0.0"')
-def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
+def create_ui(demo_mode: bool = False) -> gr.Blocks:
     engine = Engine(demo_mode=demo_mode, pure_chat=False)
     with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
@@ -31,8 +29,7 @@ def create_ui(demo_mode: Optional[bool] = False) -> gr.Blocks:
             )
             gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
-        engine.manager.all_elems["top"] = create_top()
-        lang: "gr.Dropdown" = engine.manager.get_elem_by_name("top.lang")
+        lang, engine.manager.all_elems["top"] = create_top()
         with gr.Tab("Train"):
             engine.manager.all_elems["train"] = create_train_tab(engine)

View File

@@ -480,18 +480,18 @@ LOCALES = {
             "info": "更改分词器词表和嵌入层的大小。",
         },
     },
-    "sft_packing": {
+    "packing": {
         "en": {
             "label": "Pack sequences",
-            "info": "Pack sequences into samples of fixed length in supervised fine-tuning.",
+            "info": "Pack sequences into samples of fixed length.",
         },
         "ru": {
             "label": "Упаковка последовательностей",
-            "info": "Упаковка последовательностей в образцы фиксированной длины при контролируемой тонкой настройке.",
+            "info": "Упаковка последовательностей в образцы фиксированной длины.",
         },
         "zh": {
             "label": "序列打包",
-            "info": "在指令监督微调时将序列打包为等长样本。",
+            "info": "将序列打包为等长样本。",
         },
     },
     "upcast_layernorm": {

View File

@@ -2,7 +2,7 @@ import logging
 import os
 import time
 from threading import Thread
-from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Generator, Tuple
 import gradio as gr
 import transformers
@@ -25,7 +25,7 @@ if TYPE_CHECKING:
 class Runner:
-    def __init__(self, manager: "Manager", demo_mode: Optional[bool] = False) -> None:
+    def __init__(self, manager: "Manager", demo_mode: bool = False) -> None:
         self.manager = manager
         self.demo_mode = demo_mode
         """ Resume """
@@ -136,7 +136,7 @@ class Runner:
             neftune_noise_alpha=get("train.neftune_alpha") or None,
             optim=get("train.optim"),
             resize_vocab=get("train.resize_vocab"),
-            sft_packing=get("train.sft_packing"),
+            packing=get("train.packing"),
             upcast_layernorm=get("train.upcast_layernorm"),
             use_llama_pro=get("train.use_llama_pro"),
             shift_attn=get("train.shift_attn"),