mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-04 02:35:59 +08:00
[v1] add batch generator (#9744)
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from .arg_parser import InputArgument, get_args
|
||||
from .arg_utils import ModelClass, SampleBackend
|
||||
from .arg_utils import BatchingStrategy, ModelClass, SampleBackend
|
||||
from .data_args import DataArguments
|
||||
from .model_args import ModelArguments
|
||||
from .sample_args import SampleArguments
|
||||
@@ -21,6 +21,7 @@ from .training_args import TrainingArguments
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BatchingStrategy",
|
||||
"DataArguments",
|
||||
"InputArgument",
|
||||
"ModelArguments",
|
||||
|
||||
@@ -50,6 +50,14 @@ class SampleBackend(StrEnum):
|
||||
VLLM = "vllm"
|
||||
|
||||
|
||||
@unique
|
||||
class BatchingStrategy(StrEnum):
|
||||
NORMAL = "normal"
|
||||
PADDING_FREE = "padding_free"
|
||||
DYNAMIC_BATCHING = "dynamic_batching"
|
||||
DYNAMIC_PADDING_FREE = "dynamic_padding_free"
|
||||
|
||||
|
||||
def _convert_str_dict(data: dict) -> dict:
|
||||
"""Parse string representation inside the dictionary.
|
||||
|
||||
|
||||
@@ -22,7 +22,3 @@ class DataArguments:
|
||||
default=None,
|
||||
metadata={"help": "Path to the dataset."},
|
||||
)
|
||||
cutoff_len: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "Cutoff length for the dataset."},
|
||||
)
|
||||
|
||||
@@ -16,7 +16,7 @@ import os
|
||||
from dataclasses import dataclass, field
|
||||
from uuid import uuid4
|
||||
|
||||
from .arg_utils import PluginConfig, get_plugin_config
|
||||
from .arg_utils import BatchingStrategy, PluginConfig, get_plugin_config
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -29,18 +29,30 @@ class TrainingArguments:
|
||||
default=1,
|
||||
metadata={"help": "Micro batch size for training."},
|
||||
)
|
||||
global_batch_size: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Global batch size for training."},
|
||||
global_batch_size: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Global batch size for training, default to DP size * micro batch size."},
|
||||
)
|
||||
learning_rate: float = field(
|
||||
default=1e-4,
|
||||
metadata={"help": "Learning rate for training."},
|
||||
)
|
||||
cutoff_len: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "Maximum sequence length for training."},
|
||||
)
|
||||
bf16: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use bf16 for training."},
|
||||
)
|
||||
batching_strategy: BatchingStrategy = field(
|
||||
default=BatchingStrategy.NORMAL,
|
||||
metadata={"help": "Batching strategy for training."},
|
||||
)
|
||||
batching_workers: int = field(
|
||||
default=16,
|
||||
metadata={"help": "Number of workers for batching."},
|
||||
)
|
||||
dist_config: PluginConfig | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Distribution configuration for training."},
|
||||
|
||||
Reference in New Issue
Block a user