Merge pull request #5871 from hiyouga/dev

[loss&ui] fix incorrect loss of vlms, add extra args to ui

Former-commit-id: b2d4b9a7a870aba92ed6e74b7805d606ddc0edbc
commit efda735f32
Author: hoshi-hiyouga (committed by GitHub)
Date: 2024-10-30 17:13:17 +08:00
19 changed files with 91 additions and 82 deletions

View File

@@ -1,5 +1,5 @@
-transformers>=4.41.2,<=4.46.0
-datasets>=2.16.0,<=2.21.0
+transformers>=4.41.2,<=4.46.1
+datasets>=2.16.0,<=3.0.2
 accelerate>=0.34.0,<=1.0.1
 peft>=0.11.1,<=0.12.0
 trl>=0.8.6,<=0.9.6

View File

@@ -20,17 +20,17 @@ Level:
 Dependency graph:
   main:
-    transformers>=4.41.2,<=4.46.0
-    datasets>=2.16.0,<=2.21.0
+    transformers>=4.41.2,<=4.46.1
+    datasets>=2.16.0,<=3.0.2
     accelerate>=0.34.0,<=1.0.1
     peft>=0.11.1,<=0.12.0
     trl>=0.8.6,<=0.9.6
   attention:
     transformers>=4.42.4 (gemma+fa2)
   longlora:
-    transformers>=4.41.2,<=4.46.0
+    transformers>=4.41.2,<=4.46.1
   packing:
-    transformers>=4.41.2,<=4.46.0
+    transformers>=4.41.2,<=4.46.1
 Disable version checking: DISABLE_VERSION_CHECK=1
 Enable VRAM recording: RECORD_VRAM=1

View File

@@ -69,25 +69,24 @@ def _load_single_dataset(
         if os.path.isdir(local_path):  # is directory
             for file_name in os.listdir(local_path):
                 data_files.append(os.path.join(local_path, file_name))
-                if data_path is None:
-                    data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
-                elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
-                    raise ValueError("File types should be identical.")
         elif os.path.isfile(local_path):  # is file
             data_files.append(local_path)
-            data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
         else:
             raise ValueError(f"File {local_path} not found.")

+        data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
         if data_path is None:
             raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))

+        if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
+            raise ValueError("File types should be identical.")
     else:
         raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")

     if dataset_attr.load_from == "ms_hub":
         require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
-        from modelscope import MsDataset
-        from modelscope.utils.config_ds import MS_DATASETS_CACHE
+        from modelscope import MsDataset  # type: ignore
+        from modelscope.utils.config_ds import MS_DATASETS_CACHE  # type: ignore

         cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
         dataset = MsDataset.load(
@@ -98,15 +97,15 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=cache_dir,
             token=model_args.ms_hub_token,
-            use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            use_streaming=data_args.streaming,
         )
         if isinstance(dataset, MsDataset):
             dataset = dataset.to_hf_dataset()

     elif dataset_attr.load_from == "om_hub":
         require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
-        from openmind import OmDataset
-        from openmind.utils.hub import OM_DATASETS_CACHE
+        from openmind import OmDataset  # type: ignore
+        from openmind.utils.hub import OM_DATASETS_CACHE  # type: ignore

         cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
         dataset = OmDataset.load_dataset(
@@ -117,7 +116,7 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=cache_dir,
             token=model_args.om_hub_token,
-            streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            streaming=data_args.streaming,
         )
     else:
         dataset = load_dataset(
@@ -128,13 +127,10 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
-            streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            streaming=data_args.streaming,
             trust_remote_code=True,
         )

-    if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
-        dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter
-
     if dataset_attr.num_samples is not None and not data_args.streaming:
         target_num = dataset_attr.num_samples
         indexes = np.random.permutation(len(dataset))[:target_num]  # all samples should be included
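The refactor above collects every file path first, infers the dataset type once from the first file, and then verifies the rest agree, instead of re-deriving the type inside the directory loop. A minimal standalone sketch of the new detection logic, assuming an illustrative subset of llamafactory's FILEEXT2TYPE mapping:

```python
import os

# Illustrative subset of the extension-to-loader mapping; the real table
# lives in llamafactory's extras.
FILEEXT2TYPE = {"csv": "csv", "json": "json", "jsonl": "json", "parquet": "parquet", "txt": "text"}


def detect_data_path(data_files):
    # os.path.splitext("train.jsonl") -> ("train", ".jsonl"); [1:] drops the dot.
    data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
    if data_path is None:
        raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))

    # A single pass confirms every remaining file maps to the same loader.
    if any(data_path != FILEEXT2TYPE.get(os.path.splitext(f)[-1][1:], None) for f in data_files):
        raise ValueError("File types should be identical.")

    return data_path


print(detect_data_path(["train.json", "eval.jsonl"]))  # both map to "json", prints "json"
```

Since .json and .jsonl map to the same loader, the check compares loader types rather than raw extensions, so a directory mixing the two still loads.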

View File

@@ -471,9 +471,7 @@ class PixtralPlugin(BasePlugin):
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
                 if image_input_sizes is None:
-                    raise ValueError(
-                        "The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)
-                    )
+                    raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")

                 image_size = image_input_sizes[0][num_image_tokens]
                 height, width = image_size
@@ -489,7 +487,7 @@ class PixtralPlugin(BasePlugin):
             message["content"] = content

         if len(images) != num_image_tokens:
-            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")

         return messages

View File

@@ -356,10 +356,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
     r"""
     Gets chat template and fixes the tokenizer.
     """
-    if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
-        require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
-        require_version("accelerate>=0.34.0", "To fix: pip install accelerate>=0.34.0")
-
     if data_args.template is None:
         template = TEMPLATES["empty"]  # placeholder
     else:
@@ -367,6 +363,9 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
         if template is None:
             raise ValueError(f"Template {data_args.template} does not exist.")

+    if template.mm_plugin.__class__.__name__ != "BasePlugin":
+        require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
+
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")

View File

@@ -79,8 +79,8 @@ def check_dependencies() -> None:
     if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
         logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
     else:
-        require_version("transformers>=4.41.2,<=4.46.0", "To fix: pip install transformers>=4.41.2,<=4.46.0")
-        require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
+        require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
+        require_version("datasets>=2.16.0,<=3.0.2", "To fix: pip install datasets>=2.16.0,<=3.0.2")
         require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
         require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
         require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
@@ -237,7 +237,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
     if use_modelscope():
         require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
-        from modelscope import snapshot_download
+        from modelscope import snapshot_download  # type: ignore

         revision = "master" if model_args.model_revision == "main" else model_args.model_revision
         return snapshot_download(
@@ -248,7 +248,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
     if use_openmind():
         require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
-        from openmind.utils.hub import snapshot_download
+        from openmind.utils.hub import snapshot_download  # type: ignore

         return snapshot_download(
             model_args.model_name_or_path,

View File

@@ -81,7 +81,7 @@ def is_transformers_version_greater_than_4_43():
 @lru_cache
 def is_transformers_version_equal_to_4_46():
-    return _get_package_version("transformers") == version.parse("4.46.0")
+    return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1")


 def is_uvicorn_available():
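Despite its name, the helper now matches the whole 4.46.0 to 4.46.1 range, since 4.46.1 exhibits the same loss-scaling behavior. A standalone sketch of the comparison using packaging.version (the helper name below is illustrative):

```python
from packaging import version


def is_in_affected_range(installed: str) -> bool:
    # Version objects compare semantically, so "4.46.1" falls inside the
    # inclusive [4.46.0, 4.46.1] interval while "4.45.2" and "4.47.0" do not.
    return version.parse("4.46.0") <= version.parse(installed) <= version.parse("4.46.1")


assert is_in_affected_range("4.46.1")
assert not is_in_affected_range("4.45.2")
```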

View File

@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.46.0", "To fix: pip install transformers>=4.41.2,<=4.46.0")
+    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward

View File

@@ -114,7 +114,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
 def _patch_for_block_diag_attn(model_type: str) -> None:
-    require_version("transformers>=4.41.2,<=4.46.0", "To fix: pip install transformers>=4.41.2,<=4.46.0")
+    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
     if is_transformers_version_greater_than_4_43():
         import transformers.modeling_flash_attention_utils

View File

@@ -101,7 +101,7 @@ class CustomDPOTrainer(DPOTrainer):
             self.callback_handler.add_callback(PissaConvertCallback)

         if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
@@ -274,7 +274,7 @@ class CustomDPOTrainer(DPOTrainer):
         https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
         """
         loss = super().compute_loss(model, inputs, return_outputs)
-        if kwargs.pop("num_items_in_batch", False) and is_transformers_version_equal_to_4_46():
+        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
             loss /= self.args.gradient_accumulation_steps
         return loss
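For context: transformers 4.46.0/4.46.1 forward num_items_in_batch to compute_loss and skip the trainer's own division by gradient_accumulation_steps, so a custom loss that ignores the kwarg comes back inflated by the accumulation factor. A toy illustration of the arithmetic (not trainer code):

```python
# With gradient accumulation, each micro-batch loss used to be divided by the
# number of accumulation steps so the summed gradient matched one large batch.
grad_accum_steps = 4
micro_losses = [1.0, 1.0, 1.0, 1.0]

expected = sum(loss / grad_accum_steps for loss in micro_losses)  # 1.0
inflated = sum(micro_losses)  # 4.0: what happens when the division is skipped

assert inflated == grad_accum_steps * expected  # hence loss /= grad_accum_steps
```

Swapping the operands also matters: since Python's "and" short-circuits, kwargs.pop("num_items_in_batch") no longer fires on transformers versions where the fix is unnecessary.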

View File

@@ -96,7 +96,7 @@ class CustomKTOTrainer(KTOTrainer):
             self.add_callback(SaveProcessorCallback(processor))

         if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
@@ -247,7 +247,7 @@ class CustomKTOTrainer(KTOTrainer):
         https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
         """
         loss = super().compute_loss(model, inputs, return_outputs)
-        if kwargs.pop("num_items_in_batch", False) and is_transformers_version_equal_to_4_46():
+        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
             loss /= self.args.gradient_accumulation_steps
         return loss

View File

@@ -181,7 +181,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.add_callback(SaveProcessorCallback(processor))

         if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

View File

@@ -19,6 +19,7 @@ from transformers import Trainer
 from typing_extensions import override

 from ...extras.logging import get_logger
+from ...extras.packages import is_transformers_version_equal_to_4_46
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -51,7 +52,7 @@ class CustomTrainer(Trainer):
             self.add_callback(PissaConvertCallback)

         if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
@@ -68,3 +69,15 @@ class CustomTrainer(Trainer):
     ) -> "torch.optim.lr_scheduler.LRScheduler":
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
+
+    @override
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        r"""
+        Fixes the loss value for transformers 4.46.0.
+        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        """
+        loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
+        if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
+            loss /= self.args.gradient_accumulation_steps  # other model should not scale the loss
+        return loss
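Unlike the DPO/KTO/RM overrides, this one forwards **kwargs to the parent and keys on model_accepts_loss_kwargs: when the wrapped model's forward accepts loss kwargs such as num_items_in_batch, transformers already scales the loss correctly, so only the remaining models need the manual division. A toy sketch of the guard, with illustrative class names:

```python
class _Args:
    gradient_accumulation_steps = 4


class _TrainerLike:
    args = _Args()
    # transformers sets this flag from the model's forward() signature;
    # False here means the loss was not scaled via num_items_in_batch.
    model_accepts_loss_kwargs = False

    def fix_loss(self, loss):
        if not getattr(self, "model_accepts_loss_kwargs", False):
            loss /= self.args.gradient_accumulation_steps
        return loss


assert _TrainerLike().fix_loss(4.0) == 1.0  # manual rescale applied
```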

View File

@@ -60,7 +60,7 @@ class PairwiseTrainer(Trainer):
             self.add_callback(PissaConvertCallback)

         if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
@@ -100,7 +100,7 @@ class PairwiseTrainer(Trainer):
         loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
-        if kwargs.pop("num_items_in_batch", False) and is_transformers_version_equal_to_4_46():
+        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
             loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0

         if return_outputs:

View File

@@ -27,6 +27,7 @@ from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
+from ...extras.packages import is_transformers_version_equal_to_4_46
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -60,7 +61,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.add_callback(PissaConvertCallback)

         if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore

             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
@@ -78,6 +79,18 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)

+    @override
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        r"""
+        Fixes the loss value for transformers 4.46.0.
+        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        """
+        loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
+        if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
+            loss /= self.args.gradient_accumulation_steps  # other model should not scale the loss
+        return loss
+
     @override
     def prediction_step(
         self,

View File

@@ -41,13 +41,12 @@ def create_top() -> Dict[str, "Component"]:
         finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
         checkpoint_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=6)

-    with gr.Accordion(open=False) as advanced_tab:
-        with gr.Row():
-            quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=2)
-            quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=2)
-            template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2)
-            rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
-            booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=5)
+    with gr.Row():
+        quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=2)
+        quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=2)
+        template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2)
+        rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
+        booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=5)

     model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
@@ -66,7 +65,6 @@ def create_top() -> Dict[str, "Component"]:
         model_path=model_path,
         finetuning_type=finetuning_type,
         checkpoint_path=checkpoint_path,
-        advanced_tab=advanced_tab,
         quantization_bit=quantization_bit,
         quantization_method=quantization_method,
         template=template,

View File

@@ -91,7 +91,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
                 save_steps = gr.Slider(minimum=10, maximum=5000, value=100, step=10)
                 warmup_steps = gr.Slider(minimum=0, maximum=5000, value=0, step=1)
                 neftune_alpha = gr.Slider(minimum=0, maximum=10, value=0, step=0.1)
-                optim = gr.Textbox(value="adamw_torch")
+                extra_args = gr.Textbox(value='{"optim": "adamw_torch"}')

         with gr.Row():
             with gr.Column():
@@ -116,7 +116,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             save_steps,
             warmup_steps,
             neftune_alpha,
-            optim,
+            extra_args,
             packing,
             neat_packing,
             train_on_prompt,
@@ -134,7 +134,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             save_steps=save_steps,
             warmup_steps=warmup_steps,
             neftune_alpha=neftune_alpha,
-            optim=optim,
+            extra_args=extra_args,
             packing=packing,
             neat_packing=neat_packing,
             train_on_prompt=train_on_prompt,

View File

@@ -87,20 +87,6 @@ LOCALES = {
             "label": "체크포인트 경로",
         },
     },
-    "advanced_tab": {
-        "en": {
-            "label": "Advanced configurations",
-        },
-        "ru": {
-            "label": "Расширенные конфигурации",
-        },
-        "zh": {
-            "label": "高级设置",
-        },
-        "ko": {
-            "label": "고급 설정",
-        },
-    },
     "quantization_bit": {
         "en": {
             "label": "Quantization bit",
@@ -581,11 +567,11 @@ LOCALES = {
     },
     "neftune_alpha": {
         "en": {
-            "label": "NEFTune Alpha",
+            "label": "NEFTune alpha",
             "info": "Magnitude of noise adding to embedding vectors.",
         },
         "ru": {
-            "label": "NEFTune Alpha",
+            "label": "NEFTune alpha",
             "info": "Величина шума, добавляемого к векторам вложений.",
         },
         "zh": {
@@ -597,22 +583,22 @@ LOCALES = {
             "info": "임베딩 벡터에 추가되는 노이즈의 크기.",
         },
     },
-    "optim": {
+    "extra_args": {
         "en": {
-            "label": "Optimizer",
-            "info": "The optimizer to use: adamw_torch, adamw_8bit or adafactor.",
+            "label": "Extra arguments",
+            "info": "Extra arguments passed to the trainer in JSON format.",
         },
         "ru": {
-            "label": "Оптимизатор",
-            "info": "Оптимизатор для использования: adamw_torch, adamw_8bit или adafactor.",
+            "label": "Дополнительные аргументы",
+            "info": "Дополнительные аргументы, которые передаются тренеру в формате JSON.",
         },
         "zh": {
-            "label": "优化器",
-            "info": "使用的优化器：adamw_torch、adamw_8bit 或 adafactor。",
+            "label": "额外参数",
+            "info": "以 JSON 格式传递给训练器的额外参数。",
         },
         "ko": {
-            "label": "옵티마이저",
-            "info": "사용할 옵티마이저: adamw_torch, adamw_8bit 또는 adafactor 등.",
+            "label": "추가 인수",
+            "info": "JSON 형식으로 트레이너에게 전달할 추가 인수입니다.",
         },
     },
     "packing": {

View File

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 import os
 from copy import deepcopy
 from subprocess import Popen, TimeoutExpired
@@ -78,6 +79,11 @@ class Runner:
         if not get("train.output_dir"):
             return ALERTS["err_no_output_dir"][lang]

+        try:
+            json.loads(get("train.extra_args"))
+        except json.JSONDecodeError:
+            return ALERTS["err_json_schema"][lang]
+
         stage = TRAINING_STAGES[get("train.training_stage")]
         if stage == "ppo" and not get("train.reward_model"):
             return ALERTS["err_no_reward_model"][lang]
@@ -130,7 +136,6 @@ class Runner:
             save_steps=get("train.save_steps"),
             warmup_steps=get("train.warmup_steps"),
             neftune_noise_alpha=get("train.neftune_alpha") or None,
-            optim=get("train.optim"),
             packing=get("train.packing") or get("train.neat_packing"),
             neat_packing=get("train.neat_packing"),
             train_on_prompt=get("train.train_on_prompt"),
@@ -148,6 +153,7 @@ class Runner:
             plot_loss=True,
             ddp_timeout=180000000,
             include_num_input_tokens_seen=True,
+            **json.loads(get("train.extra_args")),
         )

         # checkpoints
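Taken together, the UI changes replace the single optimizer textbox with a free-form JSON field: the Runner validates the string up front (surfacing the err_json_schema alert) and then splats it into the training-argument dict. A minimal sketch of the flow; build_args and its defaults are illustrative, not the Runner's real API:

```python
import json


def build_args(extra_args_text: str) -> dict:
    # Validate early so the UI can show a clear alert instead of crashing later.
    extra_args = json.loads(extra_args_text)  # raises json.JSONDecodeError

    args = dict(
        plot_loss=True,
        ddp_timeout=180000000,
    )
    args.update(extra_args)  # UI-supplied JSON overrides the defaults
    return args


print(build_args('{"optim": "adamw_8bit", "save_total_limit": 2}'))
```

Removing the old optim=get("train.optim") keyword in the same hunk is what makes the ** expansion safe: supplying optim both explicitly and via **json.loads(...) would raise "TypeError: got multiple values for keyword argument 'optim'".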