modify some style

This commit is contained in:
BUAADreamer
2024-04-25 21:58:18 +08:00
parent 2cab2d42fb
commit 2d4ded535f
6 changed files with 26 additions and 158 deletions

View File

@@ -17,12 +17,7 @@ from .trainer import CustomSeq2SeqTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from ...hparams import (
DataArguments,
FinetuningArguments,
GeneratingArguments,
ModelArguments,
)
from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
def run_sft(
@@ -36,14 +31,7 @@ def run_sft(
tokenizer_modules = load_tokenizer(model_args)
tokenizer = tokenizer_modules["tokenizer"]
processor = tokenizer_modules["processor"]
dataset = get_dataset(
tokenizer,
model_args,
data_args,
training_args,
stage="sft",
processor=processor,
)
dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft", processor=processor)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
if training_args.predict_with_generate:
@@ -54,7 +42,7 @@ def run_sft(
data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
pad_to_multiple_of=(8 if tokenizer.padding_side == "right" else None), # for shift short attention
pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention
label_pad_token_id=(IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id),
)
@@ -72,7 +60,7 @@ def run_sft(
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
compute_metrics=(ComputeMetrics(tokenizer) if training_args.predict_with_generate else None),
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
**split_dataset(dataset, data_args, training_args),
)