mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
fix callback
Former-commit-id: 22d9a9c2af6674eb832ae4aee80d679f19b7006f
This commit is contained in:
parent
a696148d6b
commit
70b5232f9a
@ -46,7 +46,7 @@ class LogCallback(TrainerCallback):
|
|||||||
r"""
|
r"""
|
||||||
Event called after logging the last logs.
|
Event called after logging the last logs.
|
||||||
"""
|
"""
|
||||||
if "current_steps" not in state.log_history[-1]:
|
if "step" not in state.log_history[-1]:
|
||||||
return
|
return
|
||||||
cur_time = time.time()
|
cur_time = time.time()
|
||||||
cur_steps = state.log_history[-1].get("step")
|
cur_steps = state.log_history[-1].get("step")
|
||||||
|
@ -10,7 +10,7 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
from transformers.utils import check_min_version
|
from transformers.utils import check_min_version
|
||||||
from transformers.utils.versions import require_version
|
from transformers.utils.versions import require_version
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
from trl import AutoModelForCausalLMWithValueHead
|
from trl import AutoModelForCausalLMWithValueHead
|
||||||
|
|
||||||
@ -113,11 +113,11 @@ def load_model_and_tokenizer(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Register auto class to save the custom code files.
|
# Register auto class to save the custom code files.
|
||||||
if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
|
if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map and isinstance(config, PretrainedConfig):
|
||||||
config.__class__.register_for_auto_class()
|
config.__class__.register_for_auto_class()
|
||||||
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
|
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map and isinstance(tokenizer, PreTrainedTokenizer):
|
||||||
tokenizer.__class__.register_for_auto_class()
|
tokenizer.__class__.register_for_auto_class()
|
||||||
if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map:
|
if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map and isinstance(model, PreTrainedModel):
|
||||||
model.__class__.register_for_auto_class()
|
model.__class__.register_for_auto_class()
|
||||||
|
|
||||||
# Initialize adapters
|
# Initialize adapters
|
||||||
|
@ -23,6 +23,9 @@ class PeftTrainer(Seq2SeqTrainer):
|
|||||||
def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
|
def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.finetuning_args = finetuning_args
|
self.finetuning_args = finetuning_args
|
||||||
|
self._remove_log()
|
||||||
|
|
||||||
|
def _remove_log(self):
|
||||||
if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
|
if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
|
||||||
logger.warning("Previous log file in this folder will be deleted.")
|
logger.warning("Previous log file in this folder will be deleted.")
|
||||||
os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
|
os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
|
||||||
|
@ -40,6 +40,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
|||||||
self.state = TrainerState()
|
self.state = TrainerState()
|
||||||
self.control = TrainerControl()
|
self.control = TrainerControl()
|
||||||
self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer
|
self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer
|
||||||
|
self._remove_log()
|
||||||
|
|
||||||
def ppo_train(self, max_target_length: int) -> None:
|
def ppo_train(self, max_target_length: int) -> None:
|
||||||
r"""
|
r"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user