[refactor] Add KTransformers AMX MoE SFT support via Accelerate (#10430)

Co-authored-by: mrhaoxx <mr.haoxx@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Peilin Li
2026-05-01 01:47:58 +08:00
committed by GitHub
parent 6b08b948c9
commit 887ee2b121
39 changed files with 287 additions and 1968 deletions

View File

@@ -103,37 +103,18 @@ def run_sft(
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
# Initialize our Trainer
if model_args.use_kt:
from ktransformers.sft.lora import KTrainer # type: ignore
from ktransformers.util.globals import GLOBAL_CONFIG # type: ignore
GLOBAL_CONFIG._config["mod"] = "sft"
trainer = KTrainer(
model=model,
args=training_args,
tokenizer=tokenizer_module,
data_collator=data_collator,
callbacks=callbacks,
**dataset_module,
**metric_module,
)
trainer.model_accepts_loss_kwargs = False
model.config.use_cache = False
else:
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
gen_kwargs=gen_kwargs,
ref_model=ref_model,
**dataset_module,
**tokenizer_module,
**metric_module,
)
trainer = CustomSeq2SeqTrainer(
model=model,
args=training_args,
finetuning_args=finetuning_args,
data_collator=data_collator,
callbacks=callbacks,
gen_kwargs=gen_kwargs,
ref_model=ref_model,
**dataset_module,
**tokenizer_module,
**metric_module,
)
# Training
if training_args.do_train: