diff --git a/src/llmtuner/train/sft/workflow.py b/src/llmtuner/train/sft/workflow.py
index c5acb4bc..3ead9edf 100644
--- a/src/llmtuner/train/sft/workflow.py
+++ b/src/llmtuner/train/sft/workflow.py
@@ -28,11 +28,10 @@ def run_sft(
     generating_args: "GeneratingArguments",
     callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    tokenizer_modules = load_tokenizer(model_args)
-    tokenizer = tokenizer_modules["tokenizer"]
-    processor = tokenizer_modules["processor"]
-    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft", processor=processor)
-    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
+    tokenizer_module = load_tokenizer(model_args)
+    dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
+    tokenizer = tokenizer_module["tokenizer"]
+    model = load_model(tokenizer, model_args, finetuning_args, is_trainable=training_args.do_train)
 
     if training_args.predict_with_generate:
         tokenizer.padding_side = "left"  # use left-padding in generation
@@ -49,8 +48,7 @@ def run_sft(
     # Override the decoding parameters of Seq2SeqTrainer
     training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
     training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
-    if model_args.use_mllm:
-        training_args.remove_unused_columns = False
+    training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns
 
     # Initialize our Trainer
     trainer = CustomSeq2SeqTrainer(
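
A minimal sketch of the refactored call pattern shown in the diff, using hypothetical stubs (load_tokenizer_stub, get_dataset_stub) rather than the real llmtuner functions: load_tokenizer now returns a single dict carrying both the tokenizer and the optional multimodal processor, and that dict is forwarded to get_dataset through keyword unpacking, which is why the explicit "processor" handling is removed above. The remove_unused_columns rewrite keeps the user's own setting unless the model takes visual inputs, presumably so the Trainer does not strip columns the multimodal data pipeline still needs.

# Hypothetical stubs for illustration only; real signatures live in llmtuner.
from typing import Any, Dict


def load_tokenizer_stub(model_args: Any) -> Dict[str, Any]:
    # Stand-in for load_tokenizer: bundles both objects under fixed keys.
    return {"tokenizer": "tokenizer-object", "processor": None}


def get_dataset_stub(model_args, data_args, training_args, stage, tokenizer=None, processor=None):
    # Stand-in for get_dataset: receives tokenizer/processor via **tokenizer_module.
    return {"stage": stage, "uses_processor": processor is not None}


tokenizer_module = load_tokenizer_stub(model_args=None)
dataset = get_dataset_stub(None, None, None, stage="sft", **tokenizer_module)
tokenizer = tokenizer_module["tokenizer"]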