diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py
index 0e497ebc..87d3729e 100644
--- a/src/llamafactory/data/aligner.py
+++ b/src/llamafactory/data/aligner.py
@@ -33,41 +33,47 @@ logger = get_logger(__name__)
 
 
 def _convert_images(
-    images: Sequence["ImageInput"],
+    images: Union["ImageInput", Sequence["ImageInput"]],
     dataset_attr: "DatasetAttr",
     data_args: "DataArguments",
 ) -> Optional[List["ImageInput"]]:
     r"""
     Optionally concatenates image path to dataset dir when loading from local disk.
     """
-    if len(images) == 0:
+    if not isinstance(images, list):
+        images = [images]
+    elif len(images) == 0:
         return None
+    else:
+        images = images[:]
 
-    images = images[:]
     if dataset_attr.load_from in ["script", "file"]:
         for i in range(len(images)):
-            if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])):
-                images[i] = os.path.join(data_args.dataset_dir, images[i])
+            if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.image_dir, images[i])):
+                images[i] = os.path.join(data_args.image_dir, images[i])
 
     return images
 
 
 def _convert_videos(
-    videos: Sequence["VideoInput"],
+    videos: Union["VideoInput", Sequence["VideoInput"]],
     dataset_attr: "DatasetAttr",
     data_args: "DataArguments",
 ) -> Optional[List["VideoInput"]]:
     r"""
     Optionally concatenates video path to dataset dir when loading from local disk.
     """
-    if len(videos) == 0:
+    if not isinstance(videos, list):
+        videos = [videos]
+    elif len(videos) == 0:
         return None
+    else:
+        videos = videos[:]
 
-    videos = videos[:]
     if dataset_attr.load_from in ["script", "file"]:
         for i in range(len(videos)):
-            if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, videos[i])):
-                videos[i] = os.path.join(data_args.dataset_dir, videos[i])
+            if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.image_dir, videos[i])):
+                videos[i] = os.path.join(data_args.image_dir, videos[i])
 
     return videos
diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py
index 1adcf2d0..7c89c016 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -41,6 +41,10 @@ class DataArguments:
         default="data",
         metadata={"help": "Path to the folder containing the datasets."},
     )
+    image_dir: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the folder containing the images or videos. Defaults to `dataset_dir`."},
+    )
     cutoff_len: int = field(
         default=1024,
         metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
@@ -111,7 +115,13 @@ class DataArguments:
     )
     tokenized_path: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to save or load the tokenized datasets."},
+        metadata={
+            "help": (
+                "Path to save or load the tokenized datasets. "
+                "If the path does not exist, the tokenized datasets will be saved to it. "
+                "If the path exists, the tokenized datasets will be loaded from it."
+            )
+        },
     )
 
     def __post_init__(self):
@@ -123,6 +133,9 @@ class DataArguments:
         self.dataset = split_arg(self.dataset)
         self.eval_dataset = split_arg(self.eval_dataset)
 
+        if self.image_dir is None:
+            self.image_dir = self.dataset_dir
+
         if self.dataset is None and self.val_size > 1e-6:
             raise ValueError("Cannot specify `val_size` if `dataset` is None.")
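
For orientation, here is a self-contained sketch (not part of the patch) of the media-path handling that the first diff introduces: a bare image or video may now be passed instead of a list, an empty list still yields `None`, and relative string paths are resolved against the new `image_dir` option, which `DataArguments.__post_init__` defaults to `dataset_dir` when unset. The standalone function name `convert_media_paths` and its simplified argument types are hypothetical stand-ins for the real `_convert_images`/`_convert_videos` helpers and their `DatasetAttr`/`DataArguments` parameters.

```python
import os
from typing import List, Optional, Sequence, Union

# Simplified stand-in for the real ImageInput/VideoInput types in the diff above.
MediaInput = Union[str, bytes]


def convert_media_paths(
    media: Union[MediaInput, Sequence[MediaInput]],
    image_dir: str,
    load_from: str = "file",
) -> Optional[List[MediaInput]]:
    """Mirrors the patched _convert_images/_convert_videos control flow."""
    if not isinstance(media, list):
        media = [media]  # a bare image/video is wrapped into a one-element list
    elif len(media) == 0:
        return None  # an empty list still maps to None, as before the patch
    else:
        media = media[:]  # copy so the caller's list is not mutated in place

    if load_from in ["script", "file"]:
        for i in range(len(media)):
            # Only string entries are treated as paths; join against image_dir
            # and keep the joined path only if that file actually exists.
            if isinstance(media[i], str) and os.path.isfile(os.path.join(image_dir, media[i])):
                media[i] = os.path.join(image_dir, media[i])

    return media


# Illustrative call: a single relative filename comes back as a one-element list,
# resolved against image_dir only when the file exists on disk.
# convert_media_paths("cat.jpg", image_dir="data/media")  ->  ["data/media/cat.jpg"]
```

Splitting `image_dir` out of `dataset_dir` lets the dataset JSON live in one folder while the referenced media files live elsewhere; leaving `image_dir` unset preserves the previous behaviour because `__post_init__` falls back to `dataset_dir`.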