delete file

Former-commit-id: 479d0af2dc4ab8282b9d55aba1b03ab3a54f400b
This commit is contained in:
hiyouga 2023-11-07 16:20:12 +08:00
parent d843efc413
commit 1f2c56bff9
2 changed files with 0 additions and 34 deletions

View File

@ -1,21 +0,0 @@
import os
import torch
from transformers.trainer import WEIGHTS_NAME
from llmtuner.extras.logging import get_logger
logger = get_logger(__name__)
def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
vhead_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
if not os.path.exists(vhead_file):
logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
return False
vhead_params = torch.load(vhead_file, map_location="cpu")
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
return True

View File

@ -1,13 +0,0 @@
from typing import Literal, Optional
from dataclasses import dataclass, field
@dataclass
class GeneralArguments:
r"""
Arguments pertaining to which stage we are going to perform.
"""
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
default="sft",
metadata={"help": "Which stage will be performed in training."}
)