Mirror of https://github.com/hiyouga/LLaMA-Factory.git
[model] update kt code (#9406)
@@ -47,6 +47,7 @@ def run_sft(
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

     from ktransformers.util.globals import GLOBAL_CONFIG

+    GLOBAL_CONFIG._config["mod"] = "sft"

     if getattr(model, "is_quantized", False) and not training_args.do_train:
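
For context, the added line flips KTransformers' process-wide config into SFT mode right after the model is loaded, before the quantization check runs. A minimal sketch of the pattern, assuming ktransformers is installed (the helper name enable_kt_sft_mode is ours, not from the commit):

    # Sketch only: mirrors the added hunk above. Requires ktransformers;
    # the helper name is hypothetical, not part of LLaMA-Factory.
    def enable_kt_sft_mode() -> None:
        from ktransformers.util.globals import GLOBAL_CONFIG  # deferred import, as in the workflow
        GLOBAL_CONFIG._config["mod"] = "sft"  # set KT's mode flag to "sft" (supervised fine-tuning)
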
@@ -66,12 +67,13 @@ def run_sft(
     # Metric utils
     metric_module = {}
     if training_args.predict_with_generate:
-        raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet. if you do need it, please open an issue.")
+        raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet.")
     elif finetuning_args.compute_accuracy:
-        raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet. if you do need it, please open an issue.")
+        raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet.")

     # Initialize our Trainer
     from ktransformers.sft.lora import KTrainer
     trainer = KTrainer(
         model=model,
         args=training_args,
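
Read together, the second hunk shortens the two unsupported-feature errors and hands training off to KTransformers' own trainer. A hedged sketch of the resulting flow, assuming KTrainer follows the standard transformers.Trainer keyword interface (only model= and args= are visible in the hunk; train_dataset and data_collator below are assumptions):

    # Sketch only: guard unsupported options, then build the KT trainer.
    # Kwargs beyond model/args are assumptions modeled on transformers.Trainer.
    def build_kt_trainer(model, training_args, finetuning_args,
                         train_dataset=None, data_collator=None):
        if training_args.predict_with_generate:
            raise NotImplementedError("`predict_with_generate` is not supported in KTransformers SFT yet.")
        if finetuning_args.compute_accuracy:
            raise NotImplementedError("`compute_accuracy` is not supported in KTransformers SFT yet.")

        from ktransformers.sft.lora import KTrainer  # deferred import, as in the workflow

        return KTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,  # assumed kwarg
            data_collator=data_collator,  # assumed kwarg
        )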