mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-29 10:10:35 +08:00
[v1] add models & accelerator (#9579)
This commit is contained in:
@@ -19,7 +19,6 @@ from typing import Any, Literal, Optional, Union
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from ...config.data_args import DataArguments
|
||||
from ...extras.types import DatasetInfo, HFDataset
|
||||
|
||||
|
||||
@@ -27,9 +26,6 @@ from ...extras.types import DatasetInfo, HFDataset
|
||||
class DataLoaderPlugin:
|
||||
"""Plugin for loading dataset."""
|
||||
|
||||
args: DataArguments
|
||||
"""Data arguments."""
|
||||
|
||||
def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
|
||||
"""Get dataset builder name.
|
||||
|
||||
@@ -42,7 +38,7 @@ class DataLoaderPlugin:
|
||||
return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text")
|
||||
|
||||
def auto_load_data(self, dataset_info: DatasetInfo) -> HFDataset:
|
||||
dataset_dir = dataset_info.get("dataset_dir", self.args.dataset_dir)
|
||||
dataset_dir = dataset_info.get("dataset_dir", ".")
|
||||
split = dataset_info.get("split", "train")
|
||||
streaming = dataset_info.get("streaming", False)
|
||||
if "file_name" in dataset_info:
|
||||
|
||||
Reference in New Issue
Block a user