Update workflow.py

Former-commit-id: e16f128dc397d71e74398fa38818e57e80cc32f4
This commit is contained in:
hoshi-hiyouga 2024-04-26 03:29:12 +08:00 committed by GitHub
parent 9c69cc1e16
commit 268c0efd67

View File

@@ -28,11 +28,10 @@ def run_sft(
     generating_args: "GeneratingArguments",
     callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    tokenizer_modules = load_tokenizer(model_args)
-    tokenizer = tokenizer_modules["tokenizer"]
-    processor = tokenizer_modules["processor"]
-    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft", processor=processor)
-    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
+    tokenizer_module = load_tokenizer(model_args)
+    dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
+    tokenizer = tokenizer_module["tokenizer"]
+    model = load_model(tokenizer, model_args, finetuning_args, is_trainable=training_args.do_train)
     if training_args.predict_with_generate:
         tokenizer.padding_side = "left"  # use left-padding in generation
@@ -49,8 +48,7 @@ def run_sft(
     # Override the decoding parameters of Seq2SeqTrainer
     training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
     training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
-    if model_args.use_mllm:
-        training_args.remove_unused_columns = False
+    training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns
     # Initialize our Trainer
     trainer = CustomSeq2SeqTrainer(