mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00

delete file

Former-commit-id: 479d0af2dc4ab8282b9d55aba1b03ab3a54f400b
parent d843efc413
commit 1f2c56bff9
@@ -1,21 +0,0 @@
-import os
-import torch
-from transformers.trainer import WEIGHTS_NAME
-
-from llmtuner.extras.logging import get_logger
-
-
-logger = get_logger(__name__)
-
-
-def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
-    vhead_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
-    if not os.path.exists(vhead_file):
-        logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
-        return False
-    vhead_params = torch.load(vhead_file, map_location="cpu")
-    model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
-    model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
-    model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
-    model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
-    return True
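For context, a minimal sketch of how the deleted load_valuehead_params helper would typically be called. The base-model and checkpoint paths and the use of trl's AutoModelForCausalLMWithValueHead are assumptions for illustration, not taken from this commit:

from trl import AutoModelForCausalLMWithValueHead

# Hypothetical paths; the checkpoint directory is assumed to contain a
# pytorch_model.bin (WEIGHTS_NAME) holding "v_head.summary.*" tensors,
# e.g. one saved by a reward-modeling ("rm") run.
model = AutoModelForCausalLMWithValueHead.from_pretrained("path/to/base_model")

if load_valuehead_params(model, "path/to/reward_model"):
    # Registered buffers are accessible as attributes on the model.
    print(model.reward_head_weight.shape)

Registering the head tensors as non-persistent buffers rather than parameters means they move with the model across devices but stay out of the optimizer and out of state_dict(), so the frozen reward head is never re-saved or re-trained.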
@@ -1,13 +0,0 @@
-from typing import Literal, Optional
-from dataclasses import dataclass, field
-
-
-@dataclass
-class GeneralArguments:
-    r"""
-    Arguments pertaining to which stage we are going to perform.
-    """
-    stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
-        default="sft",
-        metadata={"help": "Which stage will be performed in training."}
-    )
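A minimal sketch of how such an argument dataclass is typically consumed via transformers.HfArgumentParser; the script name and CLI invocation below are assumptions for illustration:

from transformers import HfArgumentParser

# Parses e.g. "python train.py --stage ppo"; the Literal choices become
# argparse choices, so a stage outside pt/sft/rm/ppo/dpo is rejected.
parser = HfArgumentParser(GeneralArguments)
(general_args,) = parser.parse_args_into_dataclasses()
print(general_args.stage)  # defaults to "sft" when --stage is omitted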