mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-22 22:02:51 +08:00
parent
a225b5a70c
commit
ad0304e147
@ -12,6 +12,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -19,6 +21,7 @@ from transformers import PreTrainedModel
|
|||||||
|
|
||||||
from ..data import get_template_and_fix_tokenizer
|
from ..data import get_template_and_fix_tokenizer
|
||||||
from ..extras.callbacks import LogCallback
|
from ..extras.callbacks import LogCallback
|
||||||
|
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
from ..hparams import get_infer_args, get_train_args
|
from ..hparams import get_infer_args, get_train_args
|
||||||
from ..model import load_model, load_tokenizer
|
from ..model import load_model, load_tokenizer
|
||||||
@ -98,6 +101,25 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
|
|||||||
safe_serialization=(not model_args.export_legacy_format),
|
safe_serialization=(not model_args.export_legacy_format),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if finetuning_args.stage == "rm":
|
||||||
|
if model_args.adapter_name_or_path is not None:
|
||||||
|
vhead_path = model_args.adapter_name_or_path[-1]
|
||||||
|
else:
|
||||||
|
vhead_path = model_args.model_name_or_path
|
||||||
|
|
||||||
|
if os.path.exists(os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME)):
|
||||||
|
shutil.copy(
|
||||||
|
os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME),
|
||||||
|
os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME),
|
||||||
|
)
|
||||||
|
logger.info("Copied valuehead to {}.".format(model_args.export_dir))
|
||||||
|
elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)):
|
||||||
|
shutil.copy(
|
||||||
|
os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME),
|
||||||
|
os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME),
|
||||||
|
)
|
||||||
|
logger.info("Copied valuehead to {}.".format(model_args.export_dir))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tokenizer.padding_side = "left" # restore padding side
|
tokenizer.padding_side = "left" # restore padding side
|
||||||
tokenizer.init_kwargs["padding_side"] = "left"
|
tokenizer.init_kwargs["padding_side"] = "left"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user