fix streaming in pt stage #548 #549

Author: hiyouga
Date: 2023-08-17 17:59:26 +08:00
Parent: ff0aa793b6
Commit: b0ed0dec5e

4 changed files with 43 additions and 40 deletions
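The hunks below come from the dataset loader module (the file header was not captured in this view). They widen get_dataset() so it can return either a map-style Dataset or a streaming IterableDataset, move the checksum/EXT2TYPE helpers out to llmtuner.dsets.utils, and drop the explicit Value("string") feature schema when the system prompt column is attached in streaming mode. A minimal sketch of the underlying distinction, not taken from the repository (the file name data.json and the prompt string are placeholders):

from datasets import load_dataset

system_prompt = "You are a helpful assistant."

# Non-streaming: a map-style Dataset has a known length, so a constant
# column can be attached directly.
dataset = load_dataset("json", data_files="data.json", split="train")
dataset = dataset.add_column("system", [system_prompt] * len(dataset))

# Streaming: an IterableDataset has no length (and no add_column() in the
# datasets releases of that time), so the column is injected lazily through
# map(), which is what the new branch does.
stream = load_dataset("json", data_files="data.json", split="train", streaming=True)
stream = stream.map(lambda _: {"system": system_prompt})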


@@ -1,48 +1,25 @@
 import os
-import hashlib
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Union
-from datasets import Value, concatenate_datasets, interleave_datasets, load_dataset
+from datasets import concatenate_datasets, interleave_datasets, load_dataset
+from llmtuner.dsets.utils import checksum, EXT2TYPE
 from llmtuner.extras.logging import get_logger
 
 if TYPE_CHECKING:
-    from datasets import Dataset
+    from datasets import Dataset, IterableDataset
     from llmtuner.hparams import ModelArguments, DataArguments
 
 logger = get_logger(__name__)
 
-EXT2TYPE = {
-    "csv": "csv",
-    "json": "json",
-    "jsonl": "json",
-    "txt": "text"
-}
-
-def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
-    if file_sha1 is None:
-        logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
-        return
-
-    if len(data_files) != 1:
-        logger.warning("Checksum failed: too many files.")
-        return
-
-    with open(data_files[0], "rb") as f:
-        sha1 = hashlib.sha1(f.read()).hexdigest()
-        if sha1 != file_sha1:
-            logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
-
 def get_dataset(
     model_args: "ModelArguments",
     data_args: "DataArguments"
-) -> "Dataset":
+) -> Union["Dataset", "IterableDataset"]:
     max_samples = data_args.max_samples
-    all_datasets: List["Dataset"] = [] # support multiple datasets
+    all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
     for dataset_attr in data_args.dataset_list:
         logger.info("Loading dataset {}...".format(dataset_attr))
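Since get_dataset() may now hand back either a Dataset or an IterableDataset, code consuming its result has to avoid len() and random access in the streaming case. A hypothetical caller-side check, not part of this commit:

from typing import Union

from datasets import Dataset, IterableDataset

def describe(dataset: Union[Dataset, IterableDataset]) -> None:
    # Map-style datasets support len() and indexing;
    # streaming datasets can only be iterated.
    if isinstance(dataset, Dataset):
        print("map-style dataset with {} examples".format(len(dataset)))
    else:
        print("streaming dataset, first example: {}".format(next(iter(dataset))))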
@@ -94,9 +71,7 @@ def get_dataset(
         if dataset_attr.system_prompt: # add system prompt
             if data_args.streaming:
-                features = dataset.features
-                features["system"] = Value(dtype="string", id=None)
-                dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt}, features=features)
+                dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt})
             else:
                 dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset))
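The helpers deleted above are now imported from llmtuner.dsets.utils. That module is one of the other changed files and is not shown in this excerpt; based on the removed lines, it presumably carries over the same definitions, roughly:

import hashlib
from typing import List, Optional

from llmtuner.extras.logging import get_logger

logger = get_logger(__name__)

# Maps data-file extensions to the builder names understood by load_dataset().
EXT2TYPE = {
    "csv": "csv",
    "json": "json",
    "jsonl": "json",
    "txt": "text"
}

def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
    # Warn (rather than raise) when the SHA-1 recorded in dataset_info.json
    # is missing or does not match the single local data file.
    if file_sha1 is None:
        logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
        return

    if len(data_files) != 1:
        logger.warning("Checksum failed: too many files.")
        return

    with open(data_files[0], "rb") as f:
        sha1 = hashlib.sha1(f.read()).hexdigest()
        if sha1 != file_sha1:
            logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))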