[v1] add batch generator (#9744)

This commit is contained in:
Yaowei Zheng
2026-01-10 04:24:09 +08:00
committed by GitHub
parent d7d734d54c
commit b2effbd77c
26 changed files with 604 additions and 850 deletions

View File

@@ -13,7 +13,7 @@
# limitations under the License.
from .arg_parser import InputArgument, get_args
from .arg_utils import ModelClass, SampleBackend
from .arg_utils import BatchingStrategy, ModelClass, SampleBackend
from .data_args import DataArguments
from .model_args import ModelArguments
from .sample_args import SampleArguments
@@ -21,6 +21,7 @@ from .training_args import TrainingArguments
__all__ = [
"BatchingStrategy",
"DataArguments",
"InputArgument",
"ModelArguments",

View File

@@ -50,6 +50,14 @@ class SampleBackend(StrEnum):
VLLM = "vllm"
@unique
class BatchingStrategy(StrEnum):
NORMAL = "normal"
PADDING_FREE = "padding_free"
DYNAMIC_BATCHING = "dynamic_batching"
DYNAMIC_PADDING_FREE = "dynamic_padding_free"
def _convert_str_dict(data: dict) -> dict:
"""Parse string representation inside the dictionary.

View File

@@ -22,7 +22,3 @@ class DataArguments:
default=None,
metadata={"help": "Path to the dataset."},
)
cutoff_len: int = field(
default=2048,
metadata={"help": "Cutoff length for the dataset."},
)

View File

@@ -16,7 +16,7 @@ import os
from dataclasses import dataclass, field
from uuid import uuid4
from .arg_utils import PluginConfig, get_plugin_config
from .arg_utils import BatchingStrategy, PluginConfig, get_plugin_config
@dataclass
@@ -29,18 +29,30 @@ class TrainingArguments:
default=1,
metadata={"help": "Micro batch size for training."},
)
global_batch_size: int = field(
default=1,
metadata={"help": "Global batch size for training."},
global_batch_size: int | None = field(
default=None,
metadata={"help": "Global batch size for training, default to DP size * micro batch size."},
)
learning_rate: float = field(
default=1e-4,
metadata={"help": "Learning rate for training."},
)
cutoff_len: int = field(
default=2048,
metadata={"help": "Maximum sequence length for training."},
)
bf16: bool = field(
default=False,
metadata={"help": "Use bf16 for training."},
)
batching_strategy: BatchingStrategy = field(
default=BatchingStrategy.NORMAL,
metadata={"help": "Batching strategy for training."},
)
batching_workers: int = field(
default=16,
metadata={"help": "Number of workers for batching."},
)
dist_config: PluginConfig | None = field(
default=None,
metadata={"help": "Distribution configuration for training."},