mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-06 05:32:50 +08:00)
parent ba6d290d0b
commit 3d291a82d3
@@ -6,7 +6,9 @@ from tqdm import tqdm
 from typing import TYPE_CHECKING, List, Optional, Tuple
 from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
+from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.trainer_pt_utils import remove_dummy_checkpoint
 from trl import PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
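The two new imports feed the DeepSpeed-aware save_model fallback added at the end of this commit. As a quick orientation sketch (an assumed snippet, not part of the patch): the constants are the standard checkpoint file names defined by transformers, and remove_dummy_checkpoint from transformers.trainer_pt_utils deletes placeholder files with exactly those names.

    # Orientation sketch (assumption: values as defined by the installed transformers release).
    from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME

    print(WEIGHTS_NAME)       # "pytorch_model.bin"
    print(SAFE_WEIGHTS_NAME)  # "model.safetensors"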
@@ -55,6 +57,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.state = TrainerState()
         self.control = TrainerControl()
+        self.is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
+            self.accelerator.state, "deepspeed_plugin"
+        )
         self.log_callback, self.save_callback = callbacks[0], callbacks[1]
         assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
@@ -62,10 +67,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             logger.info("max_steps is given, it will override any value given in num_train_epochs")

         if reward_model is not None:
-            is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
-                self.accelerator.state, "deepspeed_plugin"
-            )
-            if is_deepspeed_enabled:
+            if self.is_deepspeed_enabled:
                 if not (
                     getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
                     or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
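Together, the two hunks above replace the local DeepSpeed check in the reward-model branch with the self.is_deepspeed_enabled attribute computed once in __init__, so the same flag can also drive the new save_model logic below. A minimal standalone sketch of the detection pattern (assumption: a plain HF Accelerate setup; the names mirror the diff, not the trainer's own wiring):

    # Sketch of the DeepSpeed detection used above (assumed standalone setup).
    from accelerate import Accelerator

    accelerator = Accelerator()
    is_deepspeed_enabled = accelerator.distributed_type == "DEEPSPEED" and hasattr(
        accelerator.state, "deepspeed_plugin"
    )
    # True only when Accelerate was launched with a DeepSpeed plugin or config.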
@@ -345,4 +347,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         Subclass and override to inject custom behavior.
         """
         if self.args.should_save:
-            self._save(output_dir)
+            try:
+                self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model))
+            except ValueError:
+                logger.warning(
+                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
+                    " zero_to_fp32.py to recover weights"
+                )
+                self._save(output_dir, state_dict={})
+                remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
+                self.model.save_checkpoint(output_dir)  # wrapped model