mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-22 13:42:51 +08:00
Update workflow.py
Former-commit-id: e16f128dc397d71e74398fa38818e57e80cc32f4
This commit is contained in:
parent
9c69cc1e16
commit
268c0efd67
@ -28,11 +28,10 @@ def run_sft(
|
|||||||
generating_args: "GeneratingArguments",
|
generating_args: "GeneratingArguments",
|
||||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||||
):
|
):
|
||||||
tokenizer_modules = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
tokenizer = tokenizer_modules["tokenizer"]
|
dataset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
||||||
processor = tokenizer_modules["processor"]
|
tokenizer = tokenizer_module["tokenizer"]
|
||||||
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft", processor=processor)
|
model = load_model(tokenizer, model_args, finetuning_args, is_trainable=training_args.do_train)
|
||||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
|
||||||
|
|
||||||
if training_args.predict_with_generate:
|
if training_args.predict_with_generate:
|
||||||
tokenizer.padding_side = "left" # use left-padding in generation
|
tokenizer.padding_side = "left" # use left-padding in generation
|
||||||
@ -49,8 +48,7 @@ def run_sft(
|
|||||||
# Override the decoding parameters of Seq2SeqTrainer
|
# Override the decoding parameters of Seq2SeqTrainer
|
||||||
training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
|
training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
|
||||||
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
|
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
|
||||||
if model_args.use_mllm:
|
training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns
|
||||||
training_args.remove_unused_columns = False
|
|
||||||
|
|
||||||
# Initialize our Trainer
|
# Initialize our Trainer
|
||||||
trainer = CustomSeq2SeqTrainer(
|
trainer = CustomSeq2SeqTrainer(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user