# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import torch
from transformers import PreTrainedModel

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
from .dpo import run_dpo
from .kto import run_kto
from .ppo import run_ppo
from .pt import run_pt
from .rm import run_rm
from .sft import run_sft
from .trainer_utils import get_swanlab_callback


if TYPE_CHECKING:
    from transformers import TrainerCallback


logger = logging.get_logger(__name__)


def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None) -> None:
    r"""Parse the training arguments and dispatch to the trainer of the requested stage."""
    callbacks = callbacks if callbacks is not None else []  # avoid a mutable default argument shared across calls
    callbacks.append(LogCallback())
    model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)

    if finetuning_args.pissa_convert:
        callbacks.append(PissaConvertCallback())

    if finetuning_args.use_swanlab:
        callbacks.append(get_swanlab_callback(finetuning_args))

    callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # appended last

    if finetuning_args.stage == "pt":
        run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
    elif finetuning_args.stage == "sft":
        run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
    elif finetuning_args.stage == "rm":
        run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
    elif finetuning_args.stage == "ppo":
        run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
    elif finetuning_args.stage == "dpo":
        run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
    elif finetuning_args.stage == "kto":
        run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
    else:
        raise ValueError(f"Unknown task: {finetuning_args.stage}.")


def export_model(args: Optional[Dict[str, Any]] = None) -> None:
    r"""Export the (optionally adapter-merged) model, tokenizer, and processor to `export_dir`,
    and push them to the Hugging Face Hub if `export_hub_model_id` is set.
    """
    model_args, data_args, finetuning_args, _ = get_infer_args(args)

    if model_args.export_dir is None:
        raise ValueError("Please specify `export_dir` to save model.")

    if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
        raise ValueError("Please merge adapters before quantizing the model.")

    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    processor = tokenizer_module["processor"]
    get_template_and_fix_tokenizer(tokenizer, data_args)
    model = load_model(tokenizer, model_args, finetuning_args)  # must be loaded after fixing the tokenizer to resize the vocab

    if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None:
        raise ValueError("Cannot merge adapters to a quantized model.")

    if not isinstance(model, PreTrainedModel):
        raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
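
    # Choose the dtype recorded in the exported checkpoint: quantized models are
    # always stamped as float16; otherwise honor `infer_dtype`, falling back to
    # the dtype already stored in the model config when it is "auto".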
export aborted.") if getattr(model, "quantization_method", None) is not None: # quantized model adopts float16 type setattr(model.config, "torch_dtype", torch.float16) else: if model_args.infer_dtype == "auto": output_dtype = getattr(model.config, "torch_dtype", torch.float16) else: output_dtype = getattr(torch, model_args.infer_dtype) setattr(model.config, "torch_dtype", output_dtype) model = model.to(output_dtype) logger.info_rank0(f"Convert model dtype to: {output_dtype}.") model.save_pretrained( save_directory=model_args.export_dir, max_shard_size=f"{model_args.export_size}GB", safe_serialization=(not model_args.export_legacy_format), ) if model_args.export_hub_model_id is not None: model.push_to_hub( model_args.export_hub_model_id, token=model_args.hf_hub_token, max_shard_size=f"{model_args.export_size}GB", safe_serialization=(not model_args.export_legacy_format), ) if finetuning_args.stage == "rm": if model_args.adapter_name_or_path is not None: vhead_path = model_args.adapter_name_or_path[-1] else: vhead_path = model_args.model_name_or_path if os.path.exists(os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME)): shutil.copy( os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME), os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME), ) logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.") elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)): shutil.copy( os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME), os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME), ) logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.") try: tokenizer.padding_side = "left" # restore padding side tokenizer.init_kwargs["padding_side"] = "left" tokenizer.save_pretrained(model_args.export_dir) if model_args.export_hub_model_id is not None: tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token) if processor is not None: processor.save_pretrained(model_args.export_dir) if model_args.export_hub_model_id is not None: processor.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token) except Exception as e: logger.warning_rank0(f"Cannot save tokenizer, please copy the files manually: {e}.")
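
# Usage sketch: these entry points are normally driven by `llamafactory-cli`,
# which parses a YAML/CLI config into the dict consumed by get_train_args /
# get_infer_args. A hypothetical direct call (placeholder values only, not a
# tested configuration) could look like:
#
#     run_exp({"stage": "sft", "do_train": True, "model_name_or_path": "path/to/base_model",
#              "dataset": "alpaca_en_demo", "template": "default", "finetuning_type": "lora",
#              "output_dir": "path/to/output_dir"})
#     export_model({"model_name_or_path": "path/to/base_model",
#                   "adapter_name_or_path": "path/to/lora_adapter", "template": "default",
#                   "export_dir": "path/to/export_dir"})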