[v1] support reward training stage (#10431)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
浮梦
2026-05-20 20:46:52 +08:00
committed by GitHub
parent 40e786d016
commit 8b5ea65770
9 changed files with 217 additions and 17 deletions

View File

@@ -134,6 +134,9 @@ class BaseTrainer:
global_step=self.global_step,
epoch=self._resume_epoch,
)
# Keep callback state aligned with checkpoint-resumed trainer counters.
self.state.global_step = self.global_step
self.state.epoch = self._resume_epoch
if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
# qwen3.5 is not supported because of the different attention implementation, which will be supported in the future.
@@ -303,7 +306,7 @@ class BaseTrainer:
if self.global_step % self.args.logging_steps == 0:
logs = {
"epoch": epoch,
"step": self.global_step,
"step": self.state.global_step,
"loss": step_loss,
"grad_norm": grad_norm,
"learning_rate": current_lr,
@@ -335,7 +338,9 @@ class BaseTrainer:
)
else:
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
model_to_save.save_pretrained(
self.args.output_dir, state_dict=model_to_save.state_dict(), max_shard_size="4GB"
)
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
logger.info_rank0(f"Model saved to {self.args.output_dir}")

View File

@@ -143,6 +143,12 @@ class ModelEngine:
elif self.args.model_class == ModelClass.CLS:
from transformers import AutoModelForTokenClassification
self.model_config.num_labels = 1
self.model_config.classifier_dropout = 0.0
text_config = getattr(self.model_config, "text_config", None)
if text_config is not None:
text_config.num_labels = 1
text_config.classifier_dropout = 0.0
AutoClass = AutoModelForTokenClassification
else:
from transformers import AutoModel

View File

@@ -137,8 +137,8 @@ class BatchGenerator(Iterator):
else:
raise NotImplementedError("Iterable dataset is not supported yet.")
generato_seed = torch.Generator()
generato_seed.manual_seed(self.seed)
generator_seed = torch.Generator()
generator_seed.manual_seed(self.seed)
self._data_provider = StatefulDataLoader(
self.dataset,
@@ -149,7 +149,7 @@ class BatchGenerator(Iterator):
pin_memory=self.pin_memory,
pin_memory_device=DistributedInterface().current_device.type,
drop_last=self.drop_last,
generator=generato_seed,
generator=generator_seed,
)
if self.batching_strategy == BatchingStrategy.NORMAL:
self._length = len(self._data_provider)

View File

@@ -172,7 +172,7 @@ def _save_standard_training_states(
if rank == 0:
model_to_save = model.module if hasattr(model, "module") else model
model_dir = os.path.join(ckpt_dir, "model")
model_to_save.save_pretrained(model_dir, max_shard_size="4GB")
model_to_save.save_pretrained(model_dir, state_dict=model_to_save.state_dict(), max_shard_size="4GB")
processor.save_pretrained(model_dir)
os.makedirs(os.path.join(ckpt_dir, "optimizer"), exist_ok=True)
@@ -212,7 +212,11 @@ def _load_standard_training_states(
for f in sorted(glob.glob(os.path.join(model_dir, "*.bin"))):
state_dict.update(torch.load(f, map_location="cpu", weights_only=True))
if state_dict:
model_to_load.load_state_dict(state_dict)
incompatible_keys = model_to_load.load_state_dict(state_dict, strict=False)
if incompatible_keys.missing_keys:
raise RuntimeError(
f"Unexpected missing keys when loading checkpoint model weights: {incompatible_keys.missing_keys}."
)
else:
logger.warning_rank0(f"No model weights found in {model_dir}, skipping model state restore.")