Update workflow.py

Former-commit-id: e16f128dc397d71e74398fa38818e57e80cc32f4
This commit is contained in:
hoshi-hiyouga 2024-04-26 03:29:12 +08:00 committed by GitHub
parent 9c69cc1e16
commit 268c0efd67

View File

@ -28,11 +28,10 @@ def run_sft(
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None,
):
tokenizer_modules = load_tokenizer(model_args)
tokenizer = tokenizer_modules["tokenizer"]
processor = tokenizer_modules["processor"]
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft", processor=processor)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
tokenizer_module = load_tokenizer(model_args)
dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
tokenizer = tokenizer_module["tokenizer"]
model = load_model(tokenizer, model_args, finetuning_args, is_trainable=training_args.do_train)
if training_args.predict_with_generate:
tokenizer.padding_side = "left" # use left-padding in generation
@ -49,8 +48,7 @@ def run_sft(
# Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
if model_args.use_mllm:
training_args.remove_unused_columns = False
training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns
# Initialize our Trainer
trainer = CustomSeq2SeqTrainer(