fix streaming in pt stage #548 #549

Former-commit-id: b0ed0dec5e6788a0344c09a6cc58d1116265fd68
This commit is contained in:
hiyouga 2023-08-17 17:59:26 +08:00
parent 327e14d3ea
commit a46f277477
4 changed files with 43 additions and 40 deletions

View File

@ -1,48 +1,25 @@
import os
import hashlib
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, List, Union
from datasets import Value, concatenate_datasets, interleave_datasets, load_dataset
from datasets import concatenate_datasets, interleave_datasets, load_dataset
from llmtuner.dsets.utils import checksum, EXT2TYPE
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from datasets import Dataset
from datasets import Dataset, IterableDataset
from llmtuner.hparams import ModelArguments, DataArguments
logger = get_logger(__name__)
EXT2TYPE = {
"csv": "csv",
"json": "json",
"jsonl": "json",
"txt": "text"
}
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
if file_sha1 is None:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
return
if len(data_files) != 1:
logger.warning("Checksum failed: too many files.")
return
with open(data_files[0], "rb") as f:
sha1 = hashlib.sha1(f.read()).hexdigest()
if sha1 != file_sha1:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
def get_dataset(
model_args: "ModelArguments",
data_args: "DataArguments"
) -> "Dataset":
) -> Union["Dataset", "IterableDataset"]:
max_samples = data_args.max_samples
all_datasets: List["Dataset"] = [] # support multiple datasets
all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
for dataset_attr in data_args.dataset_list:
logger.info("Loading dataset {}...".format(dataset_attr))
@ -94,9 +71,7 @@ def get_dataset(
if dataset_attr.system_prompt: # add system prompt
if data_args.streaming:
features = dataset.features
features["system"] = Value(dtype="string", id=None)
dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt}, features=features)
dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt})
else:
dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset))

View File

@ -1,25 +1,25 @@
import tiktoken
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
from itertools import chain
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.template import get_template_and_fix_tokenizer
if TYPE_CHECKING:
from datasets import Dataset
from datasets import Dataset, IterableDataset
from transformers import Seq2SeqTrainingArguments
from transformers.tokenization_utils import PreTrainedTokenizer
from llmtuner.hparams import DataArguments
def preprocess_dataset(
dataset: "Dataset",
dataset: Union["Dataset", "IterableDataset"],
tokenizer: "PreTrainedTokenizer",
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo"]
) -> "Dataset":
column_names = list(dataset.column_names)
) -> Union["Dataset", "IterableDataset"]:
column_names = list(next(iter(dataset)).keys())
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:

View File

@ -1,4 +1,7 @@
from typing import TYPE_CHECKING, Dict, Union
import hashlib
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from llmtuner.extras.logging import get_logger
if TYPE_CHECKING:
from datasets import Dataset, IterableDataset
@ -6,6 +9,32 @@ if TYPE_CHECKING:
from llmtuner.hparams import DataArguments
logger = get_logger(__name__)
EXT2TYPE = {
"csv": "csv",
"json": "json",
"jsonl": "json",
"txt": "text"
}
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
if file_sha1 is None:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
return
if len(data_files) != 1:
logger.warning("Checksum failed: too many files.")
return
with open(data_files[0], "rb") as f:
sha1 = hashlib.sha1(f.read()).hexdigest()
if sha1 != file_sha1:
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
def split_dataset(
dataset: Union["Dataset", "IterableDataset"],
data_args: "DataArguments",

View File

@ -25,7 +25,7 @@ class ComputeMetrics:
Uses the model predictions to compute metrics.
"""
preds, labels = eval_preds
score_dict = {"accuracy": [], "rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
@ -49,6 +49,5 @@ class ComputeMetrics:
bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
score_dict["accuracy"].append(float(len(label) != 0 and pred[:len(label)] == label))
return {k: float(np.mean(v)) for k, v in score_dict.items()}