mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-07 03:12:13 +08:00
114 lines
4.5 KiB
Python
114 lines
4.5 KiB
Python
# Copyright 2025 the KVCache.AI team, Approaching AI, and the LlamaFactory team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from typing import TYPE_CHECKING, Optional
|
|
|
|
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
|
|
from ...extras.constants import IGNORE_INDEX
|
|
from ...extras.logging import get_logger
|
|
from ...extras.misc import calculate_tps
|
|
from ...extras.ploting import plot_loss
|
|
from ...model import load_model, load_tokenizer
|
|
from ..trainer_utils import create_modelcard_and_push
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
|
|
|
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
|
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = get_logger(__name__)
|
|
|
|
|
|
def run_sft(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
):
    """Run the supervised fine-tuning (SFT) workflow with the KTransformers ``KTrainer``.

    Steps, in order: load the tokenizer, resolve the chat template, build the
    dataset for stage ``"sft"``, load the model, set the KTransformers global
    config entry ``"mod"`` to ``"sft"``, build the data collator and the
    trainer, optionally train, and finally create the model card.

    Args:
        model_args: Passed to the tokenizer/model loaders and the dataset
            builder; also supplies ``block_diag_attn`` and ``compute_dtype``
            for the data collator.
        data_args: Passed to the template and dataset builders; its
            ``ignore_pad_token_for_loss`` flag selects the label padding id.
        training_args: Passed to the trainer; ``do_train``,
            ``predict_with_generate``, ``resume_from_checkpoint`` and
            ``output_dir`` are read here.
        finetuning_args: Passed to the model loader; ``compute_accuracy``,
            ``include_effective_tokens_per_second`` and ``plot_loss`` are
            read here.
        generating_args: Accepted but not used anywhere in this function.
        callbacks: Optional trainer callbacks, forwarded to ``KTrainer`` as is.

    Raises:
        NotImplementedError: If ``training_args.predict_with_generate`` or
            ``finetuning_args.compute_accuracy`` is truthy. Note that this is
            raised only after the tokenizer, dataset and model were loaded.
    """
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(
        template,
        model_args,
        data_args,
        training_args,
        stage="sft",
        **tokenizer_module,
    )
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

    # ktransformers is imported at function scope (not at module top), so the
    # package is only required when this workflow actually runs.
    from ktransformers.util.globals import GLOBAL_CONFIG

    GLOBAL_CONFIG._config["mod"] = "sft"

    quantized_without_training = getattr(model, "is_quantized", False) and not training_args.do_train
    if quantized_without_training:
        # hack here: make model compatible with prediction
        model._hf_peft_config_loaded = True

    # Collator arguments are resolved up front, in the same order the
    # original keyword expressions were evaluated.
    collator_model = None if training_args.predict_with_generate else model
    pad_multiple = 8 if training_args.do_train else None  # for shift short attention
    if data_args.ignore_pad_token_for_loss:
        label_pad_id = IGNORE_INDEX
    else:
        label_pad_id = tokenizer.pad_token_id

    data_collator = SFTDataCollatorWith4DAttentionMask(
        template=template,
        model=collator_model,
        pad_to_multiple_of=pad_multiple,
        label_pad_token_id=label_pad_id,
        block_diag_attn=model_args.block_diag_attn,
        attn_implementation=getattr(model.config, "_attn_implementation", None),
        compute_dtype=model_args.compute_dtype,
        **tokenizer_module,
    )

    # Metric utils: both metric-producing modes are rejected, so no metric
    # keyword arguments are ever handed to the trainer.
    metric_module = {}
    if training_args.predict_with_generate:
        raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet.")
    if finetuning_args.compute_accuracy:
        raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet.")

    # Initialize our Trainer
    from ktransformers.sft.lora import KTrainer

    trainer = KTrainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer_module,
        data_collator=data_collator,
        callbacks=callbacks,
        **dataset_module,
        **metric_module,
    )
    trainer.model_accepts_loss_kwargs = False

    # Training
    if training_args.do_train:
        model.config.use_cache = False
        train_output = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()

        if finetuning_args.include_effective_tokens_per_second:
            effective_tps = calculate_tps(dataset_module["train_dataset"], train_output.metrics, stage="sft")
            train_output.metrics["effective_tokens_per_sec"] = effective_tps

        trainer.log_metrics("train", train_output.metrics)
        trainer.save_metrics("train", train_output.metrics)
        trainer.save_state()

        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            plot_keys = ["loss"]
            eval_dataset = dataset_module.get("eval_dataset")
            if isinstance(eval_dataset, dict):
                # One loss/accuracy pair per named evaluation split.
                for split_name in eval_dataset.keys():
                    plot_keys.extend((f"eval_{split_name}_loss", f"eval_{split_name}_accuracy"))
            else:
                plot_keys.extend(("eval_loss", "eval_accuracy"))

            plot_loss(training_args.output_dir, keys=plot_keys)

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
|