2 Commits

Author          SHA1        Message                                                              Date
Philip Ottesen  0779846513  [infer] support mixed multimodal payloads (#10225)                  2026-02-28 20:26:53 +08:00
                            Signed-off-by: Philip Ottesen <phiott256@gmail.com>
jiaqiw09        45d335c709  [v1] add seed for training and fix gradient checkpointing (#10211)  2026-02-28 18:16:06 +08:00
9 changed files with 77 additions and 56 deletions

View File

@@ -14,16 +14,12 @@ dist_config:
   name: fsdp2
   dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
-init_config:
-  name: init_on_meta
 ### data
 train_dataset: data/v1_sft_demo.yaml
 ### training
 output_dir: outputs/test_fsdp2
 micro_batch_size: 1
-global_batch_size: 1
 cutoff_len: 2048
 learning_rate: 1.0e-4
 bf16: false

View File

@@ -154,25 +154,24 @@ def vllm_infer(
         batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
         for j in range(len(batch["input_ids"])):
+            multi_modal_data = {}
+            video_metadata_kwargs = None
             if batch["images"][j] is not None:
                 image = batch["images"][j]
-                multi_modal_data = {
-                    "image": template_obj.mm_plugin._regularize_images(
+                multi_modal_data["image"] = template_obj.mm_plugin._regularize_images(
                     image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
                 )["images"]
-                }
-            elif batch["videos"][j] is not None:
-                video_metadata, video_metadata_kwargs = None, None
+            if batch["videos"][j] is not None:
                 video = batch["videos"][j]
-                multi_modal_data = {
-                    "video": template_obj.mm_plugin._regularize_videos(
+                multi_modal_data["video"] = template_obj.mm_plugin._regularize_videos(
                     video,
                     image_max_pixels=image_max_pixels,
                     image_min_pixels=image_min_pixels,
                     video_fps=video_fps,
                     video_maxlen=video_maxlen,
                 )["videos"]
-                }
                 if need_video_kwargs:
                     container = av.open(video[0], "r")
                     video_stream = next(stream for stream in container.streams if stream.type == "video")
@@ -192,18 +191,17 @@ def vllm_infer(
                         video_backend="opencv",
                     )
                     multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
-            elif batch["audios"][j] is not None:
+            if batch["audios"][j] is not None:
                 audio = batch["audios"][j]
                 audio_data = template_obj.mm_plugin._regularize_audios(
                     audio,
                     sampling_rate=16000,
                 )
-                multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
-            else:
-                multi_modal_data = None
-            vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
-            if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
+            vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data or None}
+            if video_metadata_kwargs is not None:
                 vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
             vllm_inputs.append(vllm_input_data)
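
The substantive change above is replacing the old elif chain, which kept at most one modality per sample, with a dict that accumulates every modality present and is normalized to None for text-only prompts. A minimal standalone sketch of the pattern (the sample dict and key names are hypothetical stand-ins for the regularized batch fields):

def build_mm_data(sample: dict) -> dict | None:
    mm_data = {}  # start empty and add every modality that is present
    if sample.get("images") is not None:
        mm_data["image"] = sample["images"]
    if sample.get("videos") is not None:
        mm_data["video"] = sample["videos"]
    if sample.get("audios") is not None:
        mm_data["audio"] = sample["audios"]
    return mm_data or None  # vLLM expects None, not {}, for text-only prompts

print(build_mm_data({"images": ["img.png"], "audios": ["clip.wav"]}))
# {'image': ['img.png'], 'audio': ['clip.wav']} -- both modalities survive
print(build_mm_data({}))  # None -- text-only sample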

View File

@@ -180,35 +180,32 @@ class VllmEngine(BaseEngine):
             else self.generating_args["skip_special_tokens"],
         )
+        multi_modal_data = {}
         if images is not None:  # add image features
-            multi_modal_data = {
-                "image": self.template.mm_plugin._regularize_images(
+            multi_modal_data["image"] = self.template.mm_plugin._regularize_images(
                 images,
                 image_max_pixels=self.model_args.image_max_pixels,
                 image_min_pixels=self.model_args.image_min_pixels,
             )["images"]
-            }
-        elif videos is not None:
-            multi_modal_data = {
-                "video": self.template.mm_plugin._regularize_videos(
+        if videos is not None:
+            multi_modal_data["video"] = self.template.mm_plugin._regularize_videos(
                 videos,
                 image_max_pixels=self.model_args.video_max_pixels,
                 image_min_pixels=self.model_args.video_min_pixels,
                 video_fps=self.model_args.video_fps,
                 video_maxlen=self.model_args.video_maxlen,
             )["videos"]
-            }
-        elif audios is not None:
+        if audios is not None:
             audio_data = self.template.mm_plugin._regularize_audios(
                 audios,
                 sampling_rate=self.model_args.audio_sampling_rate,
             )
-            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
-        else:
-            multi_modal_data = None
+            multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])
         result_generator = self.model.generate(
-            {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
+            {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data or None},
             sampling_params=sampling_params,
             request_id=request_id,
             lora_request=self.lora_request,
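
Both call sites end with the same "multi_modal_data or None" normalization; it relies on the falsiness of an empty dict so that text-only requests still pass None rather than {} to vLLM. A two-line illustration:

multi_modal_data = {}  # no modality was added
payload = {"prompt_token_ids": [1, 2, 3], "multi_modal_data": multi_modal_data or None}
assert payload["multi_modal_data"] is None  # {} is falsy, so "or None" kicks in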

View File

@@ -21,6 +21,7 @@ from omegaconf import OmegaConf
 from transformers import HfArgumentParser
 from ..utils.env import is_env_enabled
+from ..utils.helper import set_seed
 from .data_args import DataArguments
 from .model_args import ModelArguments
 from .sample_args import SampleArguments
@@ -56,6 +57,14 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
         print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
         raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
+    # Seed as early as possible after argument parsing so all downstream
+    # components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
+    for arg in parsed_args:
+        seed = getattr(arg, "seed", None)
+        if seed is not None:
+            set_seed(seed)
+            break
     return tuple(parsed_args)
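
The loop seeds from the first parsed argument object that exposes a seed attribute and then stops, so a single seed governs the whole run even though several dataclasses are parsed. A standalone sketch of that first-seed-wins scan (the two dataclasses are hypothetical stand-ins for the parsed argument groups):

from dataclasses import dataclass

@dataclass
class ModelArgs:      # hypothetical stand-in without a seed
    name: str = "demo"

@dataclass
class TrainingArgs:   # hypothetical stand-in carrying the seed
    seed: int = 42

parsed_args = (ModelArgs(), TrainingArgs())
for arg in parsed_args:
    seed = getattr(arg, "seed", None)  # first object exposing .seed wins
    if seed is not None:
        print(f"seeding RNGs with {seed}")  # set_seed(seed) in the real code
        break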

View File

@@ -66,7 +66,7 @@ class TrainingArguments:
         metadata={"help": "Number of workers for batching."},
     )
     enable_activation_checkpointing: bool = field(
-        default=True,
+        default=False,
         metadata={"help": "Enable activation checkpointing for training."},
     )
     dist_config: PluginConfig | None = field(
@@ -81,6 +81,10 @@ class TrainingArguments:
         default=None,
         metadata={"help": "Learning rate scheduler configuration for training."},
     )
+    seed: int = field(
+        default=42,
+        metadata={"help": "Random seed that will be set at the beginning of training."},
+    )

     def __post_init__(self) -> None:
         self.dist_config = get_plugin_config(self.dist_config)

View File

@@ -76,7 +76,7 @@ class BaseTrainer:
         if self.args.enable_activation_checkpointing:
             self.model.gradient_checkpointing_enable({"use_reentrant": False})
-        self._accelerate_engine = None
+        self._deepspeed_engine = None
         dist_name = self.args.dist_config.name if self.args.dist_config is not None else None
         if dist_name == "deepspeed":
@@ -108,6 +108,7 @@ class BaseTrainer:
             cutoff_len=self.args.cutoff_len,
             batching_workers=self.args.batching_workers,
             batching_strategy=self.args.batching_strategy,
+            seed=self.args.seed,
         )

     def _shard_model(self) -> None:
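
With the default now False, activation checkpointing is opt-in and BaseTrainer is the only place that turns it on. A hedged sketch of what that call does on a plain transformers model (the checkpoint name is illustrative; any causal LM works):

from transformers import AutoModelForCausalLM

# Non-reentrant checkpointing (use_reentrant=False) recomputes activations in
# backward and is the variant that composes with FSDP2 sharding.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")  # illustrative checkpoint
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
print(model.is_gradient_checkpointing)  # True -- the flag FSDP2Engine now keys off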

View File

@@ -26,6 +26,7 @@
 from collections.abc import Iterator
 from typing import Any
+import torch
 from torch.utils.data import default_collate
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
@@ -71,6 +72,7 @@ class BatchGenerator(Iterator):
         batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
         pin_memory: bool = True,
         drop_last: bool = True,
+        seed: int = 42,
     ) -> None:
         self.dataset = dataset
         self.renderer = renderer
@@ -82,6 +84,7 @@ class BatchGenerator(Iterator):
         self.batching_strategy = batching_strategy
         self.pin_memory = pin_memory
         self.drop_last = drop_last
+        self.seed = seed

         # TODO: support length and infinity
         dp_size = DistributedInterface().get_world_size(Dim.DP)
@@ -128,12 +131,15 @@ class BatchGenerator(Iterator):
                 num_replicas=DistributedInterface().get_world_size(Dim.DP),
                 rank=DistributedInterface().get_rank(Dim.DP),
                 shuffle=True,
-                seed=0,
+                seed=self.seed,
                 drop_last=self.drop_last,
             )
         else:
             raise NotImplementedError("Iterable dataset is not supported yet.")
+        generator = torch.Generator()
+        generator.manual_seed(self.seed)
         self._data_provider = StatefulDataLoader(
             self.dataset,
             batch_size=self.micro_batch_size * self.num_micro_batch,
@@ -143,6 +149,7 @@ class BatchGenerator(Iterator):
             pin_memory=self.pin_memory,
             pin_memory_device=DistributedInterface().current_device.type,
             drop_last=self.drop_last,
+            generator=generator,
         )
         if self.batching_strategy == BatchingStrategy.NORMAL:
             self._length = len(self._data_provider)
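
Threading the seed into both the sampler and a dedicated torch.Generator makes shuffle order reproducible across runs and restarts. A standalone illustration of the generator half, using a plain DataLoader for brevity instead of StatefulDataLoader:

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(8))

def first_epoch_order(seed: int) -> list[int]:
    g = torch.Generator()
    g.manual_seed(seed)  # same seed -> same shuffle permutation
    loader = DataLoader(ds, batch_size=1, shuffle=True, generator=g)
    return [int(batch[0]) for batch in loader]

assert first_epoch_order(42) == first_epoch_order(42)  # identical order every run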

View File

@@ -166,12 +166,11 @@ class FSDP2Engine:
             offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
         )
-        use_gradient_checkpointing = True  # Could be configurable
-        if use_gradient_checkpointing:
+        # BaseTrainer is the single source of truth for gradient checkpointing.
+        # FSDP2 only applies the input-grad compatibility hook when checkpointing is already enabled.
+        if getattr(model, "is_gradient_checkpointing", False):
             if self.rank == 0:
-                logger.info("Enabling gradient checkpointing (transformers native)...")
-            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+                logger.info("Gradient checkpointing is enabled. Applying FSDP2 input grad preparation.")
             if hasattr(model, "enable_input_require_grads"):
                 model.enable_input_require_grads()
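
The enable_input_require_grads hook matters because checkpointed blocks only get a backward pass if their inputs carry gradients, which frozen embeddings (e.g. under LoRA) do not. A minimal sketch of the failure mode and what the hook arranges:

import torch

emb = torch.nn.Embedding(10, 4)
emb.weight.requires_grad_(False)   # frozen base weights, as under LoRA
x = emb(torch.tensor([1, 2, 3]))
print(x.requires_grad)             # False: downstream checkpointed blocks would get no grads
x.requires_grad_(True)             # transformers' enable_input_require_grads does this via a forward hook
print(x.requires_grad)             # True: the grad path through recomputed blocks is restored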

View File

@@ -15,12 +15,22 @@
 import torch
 from transformers import PreTrainedTokenizer
+from transformers import set_seed as hf_set_seed
 from ..accelerator.interface import DistributedInterface
 from .constants import IGNORE_INDEX
 from .types import BatchInput, ModelInput, Processor, Tensor


+def set_seed(seed: int) -> None:
+    """Set seed for reproducibility.
+
+    Args:
+        seed: Random seed.
+    """
+    hf_set_seed(seed)
+
+
 def is_tokenizer(processor: Processor) -> bool:
     """Check if processor is tokenizer.