fix ppo dataset bug #4012

Former-commit-id: 149610c636bbb974e546d13fa302884ea65a6d38
Author: hiyouga
Date:   2024-06-06 19:03:20 +08:00
Parent: e898d8bbc4
Commit: e0aadd4b34

4 changed files with 4 additions and 4 deletions

@@ -130,7 +130,7 @@ def get_dataset(
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
-    stage: Literal["pt", "sft", "rm", "kto"],
+    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"] = None,
 ) -> Union["Dataset", "IterableDataset"]:

@@ -23,7 +23,7 @@ if TYPE_CHECKING:
 def get_preprocess_and_print_func(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
-    stage: Literal["pt", "sft", "rm", "kto"],
+    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],

@@ -18,7 +18,7 @@ def preprocess_pretrain_dataset(
         if data_args.template == "gemma":
             text_examples = [tokenizer.bos_token + example for example in text_examples]

-        result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
+        result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True)
     else:
         tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
         concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
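
Note on the truncation change above: with Hugging Face tokenizers, passing max_length alone does not shorten the output; truncation has to be enabled explicitly, otherwise the tokenizer only warns and returns the full-length ids. A minimal sketch of the difference (the checkpoint and sample text here are arbitrary, not taken from this commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer works for the illustration
text = "word " * 2048  # deliberately longer than the cutoff

# max_length alone is only advisory: the full sequence comes back and transformers
# logs "Truncation was not explicitly activated but `max_length` is provided ..."
untruncated = tokenizer(text, add_special_tokens=False, max_length=1024)
print(len(untruncated["input_ids"]))  # > 1024

# truncation=True caps the ids at max_length, matching the fixed line above
truncated = tokenizer(text, add_special_tokens=False, max_length=1024, truncation=True)
print(len(truncated["input_ids"]))  # == 1024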

@@ -29,7 +29,7 @@ def run_ppo(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
+    dataset = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
     tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training