mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-22 13:42:51 +08:00
Update preprocess.py
Former-commit-id: 7f3bd35c0ead92710036064bf306740e8ee901c7
This commit is contained in:
parent
15b7182418
commit
3257df2fdb
@ -8,15 +8,26 @@ from .utils import Role
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import Seq2SeqTrainingArguments
|
from PIL import Image
|
||||||
from transformers.tokenization_utils import AutoProcessor, PreTrainedTokenizer
|
from transformers import ProcessorMixin, Seq2SeqTrainingArguments
|
||||||
|
from transformers.image_processing_utils import BaseImageProcessor
|
||||||
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
from ..hparams import DataArguments
|
from ..hparams import DataArguments
|
||||||
from .template import Template
|
from .template import Template
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _preprocess_visual_inputs(model_inputs: Dict[str, Any], processor: "ProcessorMixin", image: "Image") -> None:
|
||||||
|
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||||
|
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"][0]
|
||||||
|
if "pixel_values" not in model_inputs:
|
||||||
|
model_inputs["pixel_values"] = []
|
||||||
|
model_inputs["pixel_values"].append(pixel_values)
|
||||||
|
|
||||||
|
|
||||||
def preprocess_pretrain_dataset(
|
def preprocess_pretrain_dataset(
|
||||||
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
|
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
|
||||||
) -> Dict[str, List[List[int]]]:
|
) -> Dict[str, List[List[int]]]:
|
||||||
@ -47,10 +58,10 @@ def preprocess_pretrain_dataset(
|
|||||||
|
|
||||||
def preprocess_supervised_dataset(
|
def preprocess_supervised_dataset(
|
||||||
examples: Dict[str, List[Any]],
|
examples: Dict[str, List[Any]],
|
||||||
tokenizer: "PreTrainedTokenizer",
|
|
||||||
template: "Template",
|
template: "Template",
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
processor: "AutoProcessor" = None,
|
|
||||||
) -> Dict[str, List[List[int]]]:
|
) -> Dict[str, List[List[int]]]:
|
||||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||||
@ -90,17 +101,15 @@ def preprocess_supervised_dataset(
|
|||||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||||
model_inputs["labels"].append(labels)
|
model_inputs["labels"].append(labels)
|
||||||
if processor is not None and "images" in examples:
|
if processor is not None and "images" in examples:
|
||||||
pixel_values = processor.image_processor(examples["images"][0], return_tensors="pt")["pixel_values"][0]
|
_preprocess_visual_inputs(model_inputs, processor, examples["images"][i][0])
|
||||||
if "pixel_values" not in model_inputs:
|
|
||||||
model_inputs["pixel_values"] = []
|
|
||||||
model_inputs["pixel_values"].append(pixel_values)
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
|
|
||||||
def preprocess_packed_supervised_dataset(
|
def preprocess_packed_supervised_dataset(
|
||||||
examples: Dict[str, List[Any]],
|
examples: Dict[str, List[Any]],
|
||||||
tokenizer: "PreTrainedTokenizer",
|
|
||||||
template: "Template",
|
template: "Template",
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
) -> Dict[str, List[List[int]]]:
|
) -> Dict[str, List[List[int]]]:
|
||||||
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
||||||
@ -145,8 +154,9 @@ def preprocess_packed_supervised_dataset(
|
|||||||
|
|
||||||
def preprocess_unsupervised_dataset(
|
def preprocess_unsupervised_dataset(
|
||||||
examples: Dict[str, List[Any]],
|
examples: Dict[str, List[Any]],
|
||||||
tokenizer: "PreTrainedTokenizer",
|
|
||||||
template: "Template",
|
template: "Template",
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
) -> Dict[str, List[List[int]]]:
|
) -> Dict[str, List[List[int]]]:
|
||||||
# build inputs with format `<bos> X` and labels with format `Y <eos>`
|
# build inputs with format `<bos> X` and labels with format `Y <eos>`
|
||||||
@ -176,14 +186,17 @@ def preprocess_unsupervised_dataset(
|
|||||||
model_inputs["input_ids"].append(input_ids)
|
model_inputs["input_ids"].append(input_ids)
|
||||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||||
model_inputs["labels"].append(labels)
|
model_inputs["labels"].append(labels)
|
||||||
|
if processor is not None and "images" in examples:
|
||||||
|
_preprocess_visual_inputs(model_inputs, processor, examples["images"][i][0])
|
||||||
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
|
|
||||||
def preprocess_pairwise_dataset(
|
def preprocess_pairwise_dataset(
|
||||||
examples: Dict[str, List[Any]],
|
examples: Dict[str, List[Any]],
|
||||||
tokenizer: "PreTrainedTokenizer",
|
|
||||||
template: "Template",
|
template: "Template",
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
) -> Dict[str, List[List[int]]]:
|
) -> Dict[str, List[List[int]]]:
|
||||||
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
||||||
@ -218,6 +231,8 @@ def preprocess_pairwise_dataset(
|
|||||||
model_inputs["prompt_ids"].append(prompt_ids)
|
model_inputs["prompt_ids"].append(prompt_ids)
|
||||||
model_inputs["chosen_ids"].append(chosen_ids)
|
model_inputs["chosen_ids"].append(chosen_ids)
|
||||||
model_inputs["rejected_ids"].append(rejected_ids)
|
model_inputs["rejected_ids"].append(rejected_ids)
|
||||||
|
if processor is not None and "images" in examples:
|
||||||
|
_preprocess_visual_inputs(model_inputs, processor, examples["images"][i][0])
|
||||||
|
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
@ -248,12 +263,12 @@ def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer:
|
|||||||
|
|
||||||
|
|
||||||
def get_preprocess_and_print_func(
|
def get_preprocess_and_print_func(
|
||||||
tokenizer: "PreTrainedTokenizer",
|
|
||||||
template: "Template",
|
|
||||||
data_args: "DataArguments",
|
data_args: "DataArguments",
|
||||||
training_args: "Seq2SeqTrainingArguments",
|
training_args: "Seq2SeqTrainingArguments",
|
||||||
stage: Literal["pt", "sft", "rm", "ppo"],
|
stage: Literal["pt", "sft", "rm", "ppo"],
|
||||||
processor: Optional["AutoProcessor"] = None,
|
template: "Template",
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
) -> Tuple[Callable, Callable]:
|
) -> Tuple[Callable, Callable]:
|
||||||
if stage == "pt":
|
if stage == "pt":
|
||||||
preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
|
preprocess_func = partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args)
|
||||||
@ -261,25 +276,38 @@ def get_preprocess_and_print_func(
|
|||||||
elif stage == "sft" and not training_args.predict_with_generate:
|
elif stage == "sft" and not training_args.predict_with_generate:
|
||||||
if data_args.packing:
|
if data_args.packing:
|
||||||
preprocess_func = partial(
|
preprocess_func = partial(
|
||||||
preprocess_packed_supervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
preprocess_packed_supervised_dataset,
|
||||||
|
template=template,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
data_args=data_args,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
preprocess_func = partial(
|
preprocess_func = partial(
|
||||||
preprocess_supervised_dataset,
|
preprocess_supervised_dataset,
|
||||||
tokenizer=tokenizer,
|
|
||||||
template=template,
|
template=template,
|
||||||
data_args=data_args,
|
tokenizer=tokenizer,
|
||||||
processor=processor,
|
processor=processor,
|
||||||
|
data_args=data_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
|
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
|
||||||
elif stage == "rm":
|
elif stage == "rm":
|
||||||
preprocess_func = partial(
|
preprocess_func = partial(
|
||||||
preprocess_pairwise_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
preprocess_pairwise_dataset,
|
||||||
|
template=template,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
processor=processor,
|
||||||
|
data_args=data_args,
|
||||||
)
|
)
|
||||||
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
|
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
|
||||||
else:
|
else:
|
||||||
preprocess_func = partial(
|
preprocess_func = partial(
|
||||||
preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
|
preprocess_unsupervised_dataset,
|
||||||
|
template=template,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
processor=processor,
|
||||||
|
data_args=data_args,
|
||||||
)
|
)
|
||||||
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
|
||||||
|
|
||||||
return preprocess_func, print_function
|
return preprocess_func, print_function
|
||||||
|
Loading…
x
Reference in New Issue
Block a user