mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
parent: bd6213331f
commit: 6378864390
src/llmtuner/extras/callbacks.py

@@ -1,78 +1,23 @@
 import os
 import json
 import time
-import torch
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING
 from datetime import timedelta
-from transformers import PreTrainedModel, TrainerCallback
-from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
+from transformers import TrainerCallback
 from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
-from peft import PeftModel
 
-from llmtuner.extras.constants import LOG_FILE_NAME, V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
+from llmtuner.extras.constants import LOG_FILE_NAME
 from llmtuner.extras.logging import get_logger
+from llmtuner.extras.misc import fix_valuehead_checkpoint
 
 
 if TYPE_CHECKING:
     from transformers import TrainingArguments, TrainerState, TrainerControl
-    from trl import AutoModelForCausalLMWithValueHead
 
 
 logger = get_logger(__name__)
 
 
-def _fix_valuehead_checkpoint(
-    model: "AutoModelForCausalLMWithValueHead",
-    output_dir: str,
-    safe_serialization: bool
-) -> None:
-    r"""
-    The model is already unwrapped.
-
-    There are three cases:
-    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
-    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
-    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
-
-    We assume `stage3_gather_16bit_weights_on_model_save=true`.
-    """
-    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
-        return
-
-    if safe_serialization:
-        from safetensors import safe_open
-        from safetensors.torch import save_file
-        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
-        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
-            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
-    else:
-        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
-        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
-
-    decoder_state_dict = {}
-    v_head_state_dict = {}
-    for name, param in state_dict.items():
-        if name.startswith("v_head."):
-            v_head_state_dict[name] = param
-        else:
-            decoder_state_dict[name.replace("pretrained_model.", "")] = param
-
-    os.remove(path_to_checkpoint)
-    model.pretrained_model.save_pretrained(
-        output_dir,
-        state_dict=decoder_state_dict or None,
-        safe_serialization=safe_serialization
-    )
-
-    if safe_serialization:
-        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
-    else:
-        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
-
-    logger.info("Value head model saved at: {}".format(output_dir))
-
-
 class FixValueHeadModelCallback(TrainerCallback):
 
     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
@@ -80,21 +25,12 @@ class FixValueHeadModelCallback(TrainerCallback):
         Event called after a checkpoint save.
         """
         if args.should_save:
-            _fix_valuehead_checkpoint(
+            fix_valuehead_checkpoint(
                 model=kwargs.pop("model"),
                 output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
                 safe_serialization=args.save_safetensors
             )
 
-    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
-        r"""
-        Event called at the end of training.
-        """
-        if args.should_save:
-            _fix_valuehead_checkpoint(
-                model=kwargs.pop("model"), output_dir=args.output_dir, safe_serialization=args.save_safetensors
-            )
-
 
 class LogCallback(TrainerCallback):
 
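
Note on the callback change above: the on_train_end hook is removed, so FixValueHeadModelCallback now only rewrites the intermediate checkpoint-<step> folders; the final output_dir is handled directly in the PPO/RM workflows below. For orientation, a minimal sketch (not the project's code) of how the per-checkpoint directory passed to fix_valuehead_checkpoint is formed; DummyState stands in for transformers.TrainerState, and the hard-coded "checkpoint" prefix mirrors what PREFIX_CHECKPOINT_DIR resolves to in transformers.

import os
from dataclasses import dataclass


@dataclass
class DummyState:
    # stand-in for transformers.TrainerState; only the field on_save needs here
    global_step: int


def checkpoint_dir(output_dir: str, state: DummyState) -> str:
    # transformers names intermediate checkpoints "<PREFIX_CHECKPOINT_DIR>-<step>",
    # i.e. "checkpoint-<step>"; this is the folder the callback rewrites on each save
    return os.path.join(output_dir, "{}-{}".format("checkpoint", state.global_step))


print(checkpoint_dir("saves/ppo", DummyState(global_step=500)))  # saves/ppo/checkpoint-500
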
src/llmtuner/extras/misc.py

@@ -1,14 +1,21 @@
 import gc
 import os
 import torch
-from typing import TYPE_CHECKING, Tuple
-from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
+from typing import TYPE_CHECKING, Dict, Tuple
+from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
 from transformers.utils import (
+    WEIGHTS_NAME,
+    SAFE_WEIGHTS_NAME,
     is_torch_bf16_gpu_available,
     is_torch_cuda_available,
     is_torch_npu_available,
     is_torch_xpu_available
 )
+from peft import PeftModel
 
+from llmtuner.extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
+from llmtuner.extras.logging import get_logger
+
 
 _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
 try:
@@ -18,9 +25,13 @@ except:
 
 
 if TYPE_CHECKING:
+    from trl import AutoModelForCausalLMWithValueHead
     from llmtuner.hparams import ModelArguments
 
 
+logger = get_logger(__name__)
+
+
 class AverageMeter:
     r"""
     Computes and stores the average and current value.
@@ -63,6 +74,57 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
     return trainable_params, all_param
 
 
+def fix_valuehead_checkpoint(
+    model: "AutoModelForCausalLMWithValueHead",
+    output_dir: str,
+    safe_serialization: bool
+) -> None:
+    r"""
+    The model is already unwrapped.
+
+    There are three cases:
+    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
+    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
+    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
+
+    We assume `stage3_gather_16bit_weights_on_model_save=true`.
+    """
+    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
+        return
+
+    if safe_serialization:
+        from safetensors import safe_open
+        from safetensors.torch import save_file
+        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
+        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
+            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
+    else:
+        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
+        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
+
+    decoder_state_dict = {}
+    v_head_state_dict = {}
+    for name, param in state_dict.items():
+        if name.startswith("v_head."):
+            v_head_state_dict[name] = param
+        else:
+            decoder_state_dict[name.replace("pretrained_model.", "")] = param
+
+    os.remove(path_to_checkpoint)
+    model.pretrained_model.save_pretrained(
+        output_dir,
+        state_dict=decoder_state_dict or None,
+        safe_serialization=safe_serialization
+    )
+
+    if safe_serialization:
+        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
+    else:
+        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
+
+    logger.info("Value head model saved at: {}".format(output_dir))
+
+
 def get_current_device() -> torch.device:
     r"""
     Gets the current available device.
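
The docstring's three cases all reduce to one key-splitting rule: keys under "v_head." go to a separate value-head file, everything else is a decoder weight, and the DeepSpeed ZeRO-3 wrapper prefix "pretrained_model." is stripped. A self-contained toy sketch of that rule (plain strings stand in for tensors; this is an illustration, not the function above):

from typing import Dict, Tuple


def split_valuehead_state_dict(state_dict: Dict[str, str]) -> Tuple[Dict[str, str], Dict[str, str]]:
    # same rule as fix_valuehead_checkpoint: route "v_head.*" keys to the value-head
    # dict and strip the "pretrained_model." wrapper prefix from decoder keys
    decoder, v_head = {}, {}
    for name, param in state_dict.items():
        if name.startswith("v_head."):
            v_head[name] = param
        else:
            decoder[name.replace("pretrained_model.", "")] = param
    return decoder, v_head


# case 1: full tuning without ZeRO-3
print(split_valuehead_state_dict({"model.layers.0.weight": "w0", "v_head.summary.weight": "vw"}))
# case 2: LoRA tuning without ZeRO-3 (decoder dict is empty, hence `state_dict=decoder_state_dict or None`)
print(split_valuehead_state_dict({"v_head.summary.weight": "vw"}))
# case 3: DeepSpeed ZeRO-3 (the "pretrained_model." prefix is removed)
print(split_valuehead_state_dict({"pretrained_model.model.layers.0.weight": "w0", "v_head.summary.weight": "vw"}))
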
src/llmtuner/train/ppo/workflow.py

@@ -9,6 +9,7 @@ from transformers.optimization import get_scheduler
 
 from llmtuner.data import get_dataset, preprocess_dataset
 from llmtuner.extras.callbacks import FixValueHeadModelCallback
+from llmtuner.extras.misc import fix_valuehead_checkpoint
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.model import load_model_and_tokenizer
 from llmtuner.train.utils import create_ref_model, create_reward_model
@@ -95,6 +96,8 @@ def run_ppo(
     if training_args.do_train:
         ppo_trainer.ppo_train(resume_from_checkpoint=training_args.resume_from_checkpoint)
         ppo_trainer.save_model()
+        if training_args.should_save:
+            fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
         ppo_trainer.save_state() # must be called after save_model to have a folder
         if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
             plot_loss(training_args.output_dir, keys=["loss", "reward"])
src/llmtuner/train/rm/workflow.py

@@ -5,6 +5,7 @@ from transformers import Seq2SeqTrainingArguments
 
 from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
 from llmtuner.extras.callbacks import FixValueHeadModelCallback
+from llmtuner.extras.misc import fix_valuehead_checkpoint
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.model import load_model_and_tokenizer
 from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding
@@ -49,6 +50,8 @@ def run_rm(
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
         trainer.save_model()
+        if training_args.should_save:
+            fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
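
After run_ppo or run_rm finishes, training_args.output_dir holds the decoder checkpoint written by save_pretrained plus the small value-head file produced by fix_valuehead_checkpoint. A hedged sketch of reading that file back into a dict of tensors; the loading path shown here is an assumption for illustration, not necessarily how llmtuner restores the value head:

import os
from typing import Dict

import torch
from llmtuner.extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME


def load_valuehead_state_dict(output_dir: str) -> Dict[str, torch.Tensor]:
    # prefer the safetensors file, written when fix_valuehead_checkpoint ran with safe_serialization=True
    safe_path = os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME)
    if os.path.exists(safe_path):
        from safetensors.torch import load_file
        return load_file(safe_path, device="cpu")
    return torch.load(os.path.join(output_dir, V_HEAD_WEIGHTS_NAME), map_location="cpu")


# hypothetical usage: keys keep their "v_head." prefix, matching the split above
# state_dict = load_valuehead_state_dict("saves/rm")
# print(list(state_dict.keys()))  # e.g. ["v_head.summary.weight", "v_head.summary.bias"]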