mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-29 03:18:56 +08:00
[v1] support reward training stage (#10431)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -134,6 +134,9 @@ class BaseTrainer:
|
||||
global_step=self.global_step,
|
||||
epoch=self._resume_epoch,
|
||||
)
|
||||
# Keep callback state aligned with checkpoint-resumed trainer counters.
|
||||
self.state.global_step = self.global_step
|
||||
self.state.epoch = self._resume_epoch
|
||||
|
||||
if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
|
||||
# qwen3.5 is not supported because of the different attention implementation, which will be supported in the future.
|
||||
@@ -303,7 +306,7 @@ class BaseTrainer:
|
||||
if self.global_step % self.args.logging_steps == 0:
|
||||
logs = {
|
||||
"epoch": epoch,
|
||||
"step": self.global_step,
|
||||
"step": self.state.global_step,
|
||||
"loss": step_loss,
|
||||
"grad_norm": grad_norm,
|
||||
"learning_rate": current_lr,
|
||||
@@ -335,7 +338,9 @@ class BaseTrainer:
|
||||
)
|
||||
else:
|
||||
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
|
||||
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
||||
model_to_save.save_pretrained(
|
||||
self.args.output_dir, state_dict=model_to_save.state_dict(), max_shard_size="4GB"
|
||||
)
|
||||
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
||||
logger.info_rank0(f"Model saved to {self.args.output_dir}")
|
||||
|
||||
|
||||
@@ -143,6 +143,12 @@ class ModelEngine:
|
||||
elif self.args.model_class == ModelClass.CLS:
|
||||
from transformers import AutoModelForTokenClassification
|
||||
|
||||
self.model_config.num_labels = 1
|
||||
self.model_config.classifier_dropout = 0.0
|
||||
text_config = getattr(self.model_config, "text_config", None)
|
||||
if text_config is not None:
|
||||
text_config.num_labels = 1
|
||||
text_config.classifier_dropout = 0.0
|
||||
AutoClass = AutoModelForTokenClassification
|
||||
else:
|
||||
from transformers import AutoModel
|
||||
|
||||
@@ -137,8 +137,8 @@ class BatchGenerator(Iterator):
|
||||
else:
|
||||
raise NotImplementedError("Iterable dataset is not supported yet.")
|
||||
|
||||
generato_seed = torch.Generator()
|
||||
generato_seed.manual_seed(self.seed)
|
||||
generator_seed = torch.Generator()
|
||||
generator_seed.manual_seed(self.seed)
|
||||
|
||||
self._data_provider = StatefulDataLoader(
|
||||
self.dataset,
|
||||
@@ -149,7 +149,7 @@ class BatchGenerator(Iterator):
|
||||
pin_memory=self.pin_memory,
|
||||
pin_memory_device=DistributedInterface().current_device.type,
|
||||
drop_last=self.drop_last,
|
||||
generator=generato_seed,
|
||||
generator=generator_seed,
|
||||
)
|
||||
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||
self._length = len(self._data_provider)
|
||||
|
||||
@@ -172,7 +172,7 @@ def _save_standard_training_states(
|
||||
if rank == 0:
|
||||
model_to_save = model.module if hasattr(model, "module") else model
|
||||
model_dir = os.path.join(ckpt_dir, "model")
|
||||
model_to_save.save_pretrained(model_dir, max_shard_size="4GB")
|
||||
model_to_save.save_pretrained(model_dir, state_dict=model_to_save.state_dict(), max_shard_size="4GB")
|
||||
processor.save_pretrained(model_dir)
|
||||
|
||||
os.makedirs(os.path.join(ckpt_dir, "optimizer"), exist_ok=True)
|
||||
@@ -212,7 +212,11 @@ def _load_standard_training_states(
|
||||
for f in sorted(glob.glob(os.path.join(model_dir, "*.bin"))):
|
||||
state_dict.update(torch.load(f, map_location="cpu", weights_only=True))
|
||||
if state_dict:
|
||||
model_to_load.load_state_dict(state_dict)
|
||||
incompatible_keys = model_to_load.load_state_dict(state_dict, strict=False)
|
||||
if incompatible_keys.missing_keys:
|
||||
raise RuntimeError(
|
||||
f"Unexpected missing keys when loading checkpoint model weights: {incompatible_keys.missing_keys}."
|
||||
)
|
||||
else:
|
||||
logger.warning_rank0(f"No model weights found in {model_dir}, skipping model state restore.")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user