Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-02 19:52:50 +08:00
fix ppo dataset bug #4012
Former-commit-id: 149610c636bbb974e546d13fa302884ea65a6d38
parent e898d8bbc4
commit e0aadd4b34
@@ -130,7 +130,7 @@ def get_dataset(
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
-    stage: Literal["pt", "sft", "rm", "kto"],
+    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"] = None,
 ) -> Union["Dataset", "IterableDataset"]:
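The widened Literal only affects static type checking; at runtime any string is still accepted. A minimal sketch, not part of this commit, of a complementary runtime guard built from the same annotation (the helper name check_stage is hypothetical):

from typing import Literal, get_args

Stage = Literal["pt", "sft", "rm", "ppo", "kto"]

def check_stage(stage: str) -> None:
    # Literal types are enforced only by type checkers such as mypy,
    # so an explicit check catches typos before data loading starts.
    if stage not in get_args(Stage):
        raise ValueError(f"unknown training stage: {stage!r}")

check_stage("ppo")  # valid now that the annotation includes "ppo"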
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
 def get_preprocess_and_print_func(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
-    stage: Literal["pt", "sft", "rm", "kto"],
+    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
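Inside this function the stage value presumably selects which preprocessing path a dataset goes through; the point of adding "ppo" is that PPO prompts should not be routed through the packed pretraining path touched in the next hunk. A hedged sketch of such a stage-keyed dispatch, where the helper names and their behavior are placeholders rather than the repository's actual functions:

from typing import Callable, Dict, List, Literal

Stage = Literal["pt", "sft", "rm", "ppo", "kto"]

def pack_for_pretraining(texts: List[str]) -> Dict[str, List[str]]:
    # "pt" style: concatenate everything into one stream and chunk it later.
    return {"text": ["".join(texts)]}

def keep_prompts_for_ppo(texts: List[str]) -> Dict[str, List[str]]:
    # "ppo" style: each prompt stays separate so the trainer can generate a response for it.
    return {"prompt": texts}

PREPROCESSORS: Dict[str, Callable[[List[str]], Dict[str, List[str]]]] = {
    "pt": pack_for_pretraining,
    "ppo": keep_prompts_for_ppo,
    # other stages omitted from this sketch
}

def get_preprocess_func(stage: Stage) -> Callable[[List[str]], Dict[str, List[str]]]:
    return PREPROCESSORS[stage]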
@@ -18,7 +18,7 @@ def preprocess_pretrain_dataset(
         if data_args.template == "gemma":
             text_examples = [tokenizer.bos_token + example for example in text_examples]

-        result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len)
+        result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True)
     else:
         tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
         concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
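The second part of the fix makes truncation explicit. Passing max_length to a Hugging Face tokenizer without truncation=True relies on implicit, warning-producing default behavior; with truncation=True every sequence is guaranteed to be capped at max_length. A small hedged demo, assuming the transformers package and using "gpt2" purely as a stand-in for the model actually being trained:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
text_examples = ["one two three four five six seven eight nine ten"]
cutoff_len = 4  # stands in for data_args.cutoff_len

result = tokenizer(
    text_examples,
    add_special_tokens=False,
    max_length=cutoff_len,
    truncation=True,  # explicit, so no sequence exceeds cutoff_len
)
assert all(len(ids) <= cutoff_len for ids in result["input_ids"])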
@@ -29,7 +29,7 @@ def run_ppo(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
+    dataset = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)

     tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
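The context line about left padding is worth unpacking: decoder-only models continue generating from the final position of each row, so for batched generation the pad tokens must sit on the left, keeping every prompt flush against the point where decoding resumes. A hedged illustration, again using "gpt2" only as a stand-in checkpoint:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # pads go before the prompt, as in the PPO workflow

batch = tokenizer(
    ["a short prompt", "a noticeably longer prompt in the same batch"],
    padding=True,
    return_tensors="pt",
)
# Each row now ends with real prompt tokens, so model.generate(**batch)
# appends new tokens directly after them instead of after padding.
print(batch["input_ids"])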