[kt] refactor ktransformers integration (#9632)

Author: mrhaoxx
Date: 2025-12-18 21:26:04 +08:00
Committed by: GitHub
Parent: 9fd4b094d4
Commit: 964569751f
4 changed files with 37 additions and 149 deletions


@@ -68,6 +68,12 @@ def run_sft(
     # Metric utils
     metric_module = {}
+    if model_args.use_kt:
+        if training_args.predict_with_generate:
+            raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet.")
+        elif finetuning_args.compute_accuracy:
+            raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet.")
+
     if training_args.predict_with_generate:
         metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer)
     elif finetuning_args.compute_accuracy:
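
Read on its own, the guard added in this hunk fails fast when the KTransformers path is combined with metric options it cannot serve yet. The following is a minimal, self-contained sketch of that check; the dataclasses are stand-ins for LLaMA-Factory's real argument objects and only reproduce the fields referenced in the diff.

# Self-contained sketch of the fail-fast guard above; ModelArgs, TrainingArgs,
# and FinetuningArgs are placeholder dataclasses, not the real argument classes.
from dataclasses import dataclass

@dataclass
class ModelArgs:
    use_kt: bool = False

@dataclass
class TrainingArgs:
    predict_with_generate: bool = False

@dataclass
class FinetuningArgs:
    compute_accuracy: bool = False

def check_kt_metric_support(model_args, training_args, finetuning_args):
    """Reject metric options that the KTransformers SFT path does not support yet."""
    if not model_args.use_kt:
        return
    if training_args.predict_with_generate:
        raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet.")
    if finetuning_args.compute_accuracy:
        raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet.")

# The unsupported combination is rejected before any training work starts.
try:
    check_kt_metric_support(ModelArgs(use_kt=True), TrainingArgs(predict_with_generate=True), FinetuningArgs())
except NotImplementedError as err:
    print(err)
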
@@ -92,17 +98,36 @@ def run_sft(
     gen_kwargs["pad_token_id"] = tokenizer.pad_token_id

     # Initialize our Trainer
-    trainer = CustomSeq2SeqTrainer(
-        model=model,
-        args=training_args,
-        finetuning_args=finetuning_args,
-        data_collator=data_collator,
-        callbacks=callbacks,
-        gen_kwargs=gen_kwargs,
-        **dataset_module,
-        **tokenizer_module,
-        **metric_module,
-    )
+    if model_args.use_kt:
+        from ktransformers.util.globals import GLOBAL_CONFIG
+        from ktransformers.sft.lora import KTrainer
+
+        GLOBAL_CONFIG._config["mod"] = "sft"
+        trainer = KTrainer(
+            model=model,
+            args=training_args,
+            tokenizer=tokenizer_module,
+            data_collator=data_collator,
+            callbacks=callbacks,
+            **dataset_module,
+            **metric_module,
+        )
+        trainer.model_accepts_loss_kwargs = False
+        model.config.use_cache = False
+    else:
+        trainer = CustomSeq2SeqTrainer(
+            model=model,
+            args=training_args,
+            finetuning_args=finetuning_args,
+            data_collator=data_collator,
+            callbacks=callbacks,
+            gen_kwargs=gen_kwargs,
+            **dataset_module,
+            **tokenizer_module,
+            **metric_module,
+        )

     # Training
     if training_args.do_train:
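
This hunk turns the single trainer constructor into a dispatch on model_args.use_kt: the KTransformers branch switches the library's global mode to "sft", builds a KTrainer, and disables loss-kwargs forwarding and the KV cache, while the default branch keeps the existing CustomSeq2SeqTrainer. The sketch below mirrors that control flow in isolation; KTrainer, CustomSeq2SeqTrainer, and GLOBAL_CONFIG are stubs standing in for the objects imported in the diff, not the real implementations.

# Schematic sketch of the trainer dispatch above, using placeholder classes.
from types import SimpleNamespace

class KTrainer:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.model_accepts_loss_kwargs = True

class CustomSeq2SeqTrainer:
    def __init__(self, **kwargs):
        self.kwargs = kwargs

GLOBAL_CONFIG = SimpleNamespace(_config={})

def build_trainer(use_kt, model, training_args, **modules):
    """Mirror the branch added in this commit: KTrainer when use_kt, else the default trainer."""
    if use_kt:
        GLOBAL_CONFIG._config["mod"] = "sft"      # put ktransformers into SFT mode
        trainer = KTrainer(model=model, args=training_args, **modules)
        trainer.model_accepts_loss_kwargs = False  # tell the trainer not to forward loss kwargs
        model.config.use_cache = False             # disable the KV cache during training
        return trainer
    return CustomSeq2SeqTrainer(model=model, args=training_args, **modules)

# Example: the KT branch mutates global state and the model config as a side effect.
model = SimpleNamespace(config=SimpleNamespace(use_cache=True))
trainer = build_trainer(True, model, training_args=None)
print(type(trainer).__name__, GLOBAL_CONFIG._config["mod"], model.config.use_cache)
# -> KTrainer sft False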