mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-23 23:30:36 +08:00
[kt] refactor ktransformers integration (#9632)
This commit is contained in:
@@ -68,6 +68,12 @@ def run_sft(
|
||||
|
||||
# Metric utils
|
||||
metric_module = {}
|
||||
if model_args.use_kt:
|
||||
if training_args.predict_with_generate:
|
||||
raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet.")
|
||||
elif finetuning_args.compute_accuracy:
|
||||
raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet.")
|
||||
|
||||
if training_args.predict_with_generate:
|
||||
metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer)
|
||||
elif finetuning_args.compute_accuracy:
|
||||
@@ -92,17 +98,36 @@ def run_sft(
|
||||
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = CustomSeq2SeqTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
gen_kwargs=gen_kwargs,
|
||||
**dataset_module,
|
||||
**tokenizer_module,
|
||||
**metric_module,
|
||||
)
|
||||
if model_args.use_kt:
|
||||
from ktransformers.util.globals import GLOBAL_CONFIG
|
||||
from ktransformers.sft.lora import KTrainer
|
||||
|
||||
GLOBAL_CONFIG._config["mod"] = "sft"
|
||||
|
||||
trainer = KTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer_module,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
**dataset_module,
|
||||
**metric_module,
|
||||
)
|
||||
trainer.model_accepts_loss_kwargs = False
|
||||
model.config.use_cache = False
|
||||
|
||||
else:
|
||||
trainer = CustomSeq2SeqTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
gen_kwargs=gen_kwargs,
|
||||
**dataset_module,
|
||||
**tokenizer_module,
|
||||
**metric_module,
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
|
||||
Reference in New Issue
Block a user