Mirror of https://github.com/hiyouga/LLaMA-Factory.git
Update workflow.py
Former-commit-id: e16f128dc397d71e74398fa38818e57e80cc32f4
commit 268c0efd67 (parent 9c69cc1e16)
@@ -28,11 +28,10 @@ def run_sft(
     generating_args: "GeneratingArguments",
     callbacks: Optional[List["TrainerCallback"]] = None,
 ):
-    tokenizer_modules = load_tokenizer(model_args)
-    tokenizer = tokenizer_modules["tokenizer"]
-    processor = tokenizer_modules["processor"]
-    dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft", processor=processor)
-    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
+    tokenizer_module = load_tokenizer(model_args)
+    dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
+    tokenizer = tokenizer_module["tokenizer"]
+    model = load_model(tokenizer, model_args, finetuning_args, is_trainable=training_args.do_train)
 
     if training_args.predict_with_generate:
         tokenizer.padding_side = "left"  # use left-padding in generation
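The hunk above folds the tokenizer and processor into the single dict returned by load_tokenizer and forwards it with **-unpacking, so get_dataset receives both as keyword arguments in one call. A minimal sketch of that pattern; the stand-in signatures and bodies here are illustrative assumptions, not LLaMA-Factory's actual implementations:

    from typing import Any, Dict, Optional

    # Hypothetical stand-in for load_tokenizer; the real one returns the
    # Hugging Face tokenizer plus a processor for multimodal models.
    def load_tokenizer(model_args: Any) -> Dict[str, Any]:
        return {"tokenizer": "tok", "processor": None}

    # The dict keys match these keyword parameters, so **tokenizer_module
    # forwards tokenizer and processor together.
    def get_dataset(model_args: Any, data_args: Any, training_args: Any,
                    stage: str, tokenizer: Any = None,
                    processor: Optional[Any] = None) -> None:
        print(stage, tokenizer, processor)

    tokenizer_module = load_tokenizer(model_args=None)
    get_dataset(None, None, None, stage="sft", **tokenizer_module)
    tokenizer = tokenizer_module["tokenizer"]  # still available for load_model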
@@ -49,8 +48,7 @@ def run_sft(
     # Override the decoding parameters of Seq2SeqTrainer
     training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
     training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
-    if model_args.use_mllm:
-        training_args.remove_unused_columns = False
+    training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns
 
     # Initialize our Trainer
     trainer = CustomSeq2SeqTrainer(
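The second hunk replaces the use_mllm guard with a one-line conditional keyed on visual_inputs: the flag is forced to False only for multimodal models and the user's setting is left untouched otherwise. A sketch of that behavior, using assumed plain dataclasses in place of the real argument classes:

    from dataclasses import dataclass

    @dataclass
    class ModelArgs:       # hypothetical stand-in for ModelArguments
        visual_inputs: bool = True

    @dataclass
    class TrainingArgs:    # hypothetical stand-in for Seq2SeqTrainingArguments
        remove_unused_columns: bool = True

    model_args = ModelArgs()
    training_args = TrainingArgs()

    # Multimodal batches carry non-text columns (e.g. pixel values) that the
    # model's forward() signature does not list, so they must survive the
    # Trainer's column pruning; text-only runs keep the configured value.
    training_args.remove_unused_columns = (
        False if model_args.visual_inputs else training_args.remove_unused_columns
    )
    assert training_args.remove_unused_columns is False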