diff --git a/src/llamafactory/data/preprocess.py b/src/llamafactory/data/preprocess.py index a22c7c11..2ea2fa1d 100644 --- a/src/llamafactory/data/preprocess.py +++ b/src/llamafactory/data/preprocess.py @@ -15,8 +15,6 @@ from functools import partial from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple -from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence - from .processors.feedback import preprocess_feedback_dataset from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example from .processors.pretrain import preprocess_pretrain_dataset @@ -53,6 +51,7 @@ def get_preprocess_and_print_func( elif stage == "sft" and not training_args.predict_with_generate: if data_args.packing: if data_args.neat_packing: + from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence def __init__(self, data, **kwargs): return TypedSequence.__init__(