mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00

Former-commit-id: b0ed0dec5e6788a0344c09a6cc58d1116265fd68
This commit is contained in:
  parent 327e14d3ea
  commit a46f277477
src/llmtuner/dsets/loader.py

@@ -1,48 +1,25 @@
 import os
-import hashlib
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Union
 
-from datasets import Value, concatenate_datasets, interleave_datasets, load_dataset
+from datasets import concatenate_datasets, interleave_datasets, load_dataset
 
+from llmtuner.dsets.utils import checksum, EXT2TYPE
 from llmtuner.extras.logging import get_logger
 
 if TYPE_CHECKING:
-    from datasets import Dataset
+    from datasets import Dataset, IterableDataset
     from llmtuner.hparams import ModelArguments, DataArguments
 
 
 logger = get_logger(__name__)
 
 
-EXT2TYPE = {
-    "csv": "csv",
-    "json": "json",
-    "jsonl": "json",
-    "txt": "text"
-}
-
-
-def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
-    if file_sha1 is None:
-        logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
-        return
-
-    if len(data_files) != 1:
-        logger.warning("Checksum failed: too many files.")
-        return
-
-    with open(data_files[0], "rb") as f:
-        sha1 = hashlib.sha1(f.read()).hexdigest()
-        if sha1 != file_sha1:
-            logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
-
-
 def get_dataset(
     model_args: "ModelArguments",
     data_args: "DataArguments"
-) -> "Dataset":
+) -> Union["Dataset", "IterableDataset"]:
     max_samples = data_args.max_samples
-    all_datasets: List["Dataset"] = [] # support multiple datasets
+    all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
 
     for dataset_attr in data_args.dataset_list:
         logger.info("Loading dataset {}...".format(dataset_attr))
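The return annotation widens to Union["Dataset", "IterableDataset"] because datasets.load_dataset yields an IterableDataset, not a Dataset, when called with streaming=True. A minimal sketch of the distinction (the local JSON path is a hypothetical placeholder):

    from datasets import Dataset, IterableDataset, load_dataset

    data_files = {"train": "data/example.json"}  # hypothetical file

    regular = load_dataset("json", data_files=data_files, split="train")
    streamed = load_dataset("json", data_files=data_files, split="train", streaming=True)

    assert isinstance(regular, Dataset)           # map-style, random access
    assert isinstance(streamed, IterableDataset)  # lazy, single-pass iteration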
@@ -94,9 +71,7 @@ def get_dataset(
 
         if dataset_attr.system_prompt: # add system prompt
             if data_args.streaming:
-                features = dataset.features
-                features["system"] = Value(dtype="string", id=None)
-                dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt}, features=features)
+                dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt})
             else:
                 dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset))
 
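The streaming branch no longer copies and mutates dataset.features: on a lazily loaded IterableDataset the schema may still be unresolved (features can be None), and a plain map attaches the constant column on its own, while add_column stays in the other branch because it exists only on Dataset. A sketch of both paths under those assumptions (toy data; to_iterable_dataset stands in for streaming=True):

    from datasets import Dataset

    system_prompt = "You are a helpful assistant."  # illustrative value
    dataset = Dataset.from_dict({"instruction": ["hi", "bye"]})

    # Map-style path: add_column materializes the new column eagerly.
    eager = dataset.add_column("system", [system_prompt] * len(dataset))

    # Streaming path: map merges the returned dict into each example lazily.
    lazy = dataset.to_iterable_dataset().map(lambda _: {"system": system_prompt})

    print(next(iter(lazy))["system"])  # -> "You are a helpful assistant."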
src/llmtuner/dsets/preprocess.py

@@ -1,25 +1,25 @@
 import tiktoken
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
 from itertools import chain
 
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.template import get_template_and_fix_tokenizer
 
 if TYPE_CHECKING:
-    from datasets import Dataset
+    from datasets import Dataset, IterableDataset
     from transformers import Seq2SeqTrainingArguments
     from transformers.tokenization_utils import PreTrainedTokenizer
     from llmtuner.hparams import DataArguments
 
 
 def preprocess_dataset(
-    dataset: "Dataset",
+    dataset: Union["Dataset", "IterableDataset"],
     tokenizer: "PreTrainedTokenizer",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo"]
-) -> "Dataset":
-    column_names = list(dataset.column_names)
+) -> Union["Dataset", "IterableDataset"]:
+    column_names = list(next(iter(dataset)).keys())
     template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
 
     def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
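column_names is now read by pulling one example and listing its keys, since an IterableDataset may not expose column_names before its schema is resolved; plain iteration works for both dataset kinds, and re-iterating an IterableDataset restarts from the beginning. A toy sketch:

    from datasets import Dataset

    dataset = Dataset.from_dict({"prompt": ["p"], "response": ["r"]})  # toy data

    # Works for Dataset and IterableDataset alike: peek at one example.
    column_names = list(next(iter(dataset)).keys())
    print(column_names)  # -> ['prompt', 'response']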
src/llmtuner/dsets/utils.py

@@ -1,4 +1,7 @@
-from typing import TYPE_CHECKING, Dict, Union
+import hashlib
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
+
+from llmtuner.extras.logging import get_logger
 
 if TYPE_CHECKING:
     from datasets import Dataset, IterableDataset
@@ -6,6 +9,32 @@ if TYPE_CHECKING:
     from llmtuner.hparams import DataArguments
 
 
+logger = get_logger(__name__)
+
+
+EXT2TYPE = {
+    "csv": "csv",
+    "json": "json",
+    "jsonl": "json",
+    "txt": "text"
+}
+
+
+def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
+    if file_sha1 is None:
+        logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
+        return
+
+    if len(data_files) != 1:
+        logger.warning("Checksum failed: too many files.")
+        return
+
+    with open(data_files[0], "rb") as f:
+        sha1 = hashlib.sha1(f.read()).hexdigest()
+        if sha1 != file_sha1:
+            logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
+
+
 def split_dataset(
     dataset: Union["Dataset", "IterableDataset"],
     data_args: "DataArguments",
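The relocated checksum helper guards a single data file against the SHA-1 recorded in dataset_info.json, warning rather than raising on any failure. A self-contained sketch of the same check (file name and hash are placeholders):

    import hashlib

    def sha1_of(path: str) -> str:
        # Hash the whole file in one read, mirroring checksum().
        with open(path, "rb") as f:
            return hashlib.sha1(f.read()).hexdigest()

    expected = "0123456789abcdef0123456789abcdef01234567"  # placeholder hash
    if sha1_of("data/example.json") != expected:           # hypothetical file
        print("Checksum failed: mismatched SHA-1 hash value.")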
src/llmtuner/tuner/sft/metric.py

@@ -25,7 +25,7 @@ class ComputeMetrics:
         Uses the model predictions to compute metrics.
         """
         preds, labels = eval_preds
-        score_dict = {"accuracy": [], "rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
+        score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
 
         preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
         labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
@@ -49,6 +49,5 @@ class ComputeMetrics:
 
             bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
             score_dict["bleu-4"].append(round(bleu_score * 100, 4))
-            score_dict["accuracy"].append(float(len(label) != 0 and pred[:len(label)] == label))
 
         return {k: float(np.mean(v)) for k, v in score_dict.items()}
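The removed accuracy entry scored whether the non-empty reference was an exact prefix of the prediction; pred[:len(label)] == label is a plain elementwise sequence comparison (here apparently on decoded strings, given the surrounding list(pred) character splitting). A toy illustration of what that line computed:

    def prefix_accuracy(pred: str, label: str) -> float:
        # 1.0 when the non-empty reference is an exact prefix of the prediction.
        return float(len(label) != 0 and pred[:len(label)] == label)

    print(prefix_accuracy("hello world", "hello"))  # -> 1.0
    print(prefix_accuracy("help", "hello"))         # -> 0.0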