	improve aligner
Former-commit-id: cc7296b92e10c24967fc753393275b71d300683f
parent a41fa6e730
commit 1955a8ea5a
@@ -174,6 +174,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
 - [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
 - [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
 - [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
+- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
 - [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
 - [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
 - [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
@@ -174,6 +174,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 - [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
 - [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
 - [OpenOrca (en)](https://huggingface.co/datasets/Open-Orca/OpenOrca)
+- [SlimOrca (en)](https://huggingface.co/datasets/Open-Orca/SlimOrca)
 - [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
 - [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
 - [Wiki QA (en)](https://huggingface.co/datasets/wiki_qa)
@@ -11,7 +11,7 @@ If you are using a custom dataset, please provide your dataset definition in the
   "folder": "the name of the folder of the dataset repository on the Hugging Face hub. (optional, default: None)",
   "ranking": "whether the dataset is a preference dataset or not. (default: false)",
   "formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
-  "columns": {
+  "columns (optional)": {
     "prompt": "the column name in the dataset containing the prompts. (default: instruction)",
     "query": "the column name in the dataset containing the queries. (default: input)",
     "response": "the column name in the dataset containing the responses. (default: output)",
@@ -20,14 +20,14 @@ If you are using a custom dataset, please provide your dataset definition in the
     "system": "the column name in the dataset containing the system prompts. (default: None)",
     "tools": "the column name in the dataset containing the tool description. (default: None)"
   },
-  "tags": {
+  "tags (optional, used for the sharegpt format)": {
     "role_tag": "the key in the message represents the identity. (default: from)",
     "content_tag": "the key in the message represents the content. (default: value)",
     "user_tag": "the value of the role_tag represents the user. (default: human)",
     "assistant_tag": "the value of the role_tag represents the assistant. (default: gpt)",
     "observation_tag": "the value of the role_tag represents the tool results. (default: observation)",
     "function_tag": "the value of the role_tag represents the function call. (default: function_call)",
-    "system_tag": "the value of the role_tag represents the system prompt. (default: None) incompatible with system column"
+    "system_tag": "the value of the role_tag represents the system prompt. (default: system, can override system column)"
   }
 }
 ```
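For orientation, here is a minimal sketch of what such an entry in `dataset_info.json` might look like for a sharegpt-format dataset under the defaults above. The dataset key, file name, and values are made up for illustration; only the field names come from the README.

```python
import json

# Hypothetical dataset_info.json entry mirroring the fields documented above.
# "my_sharegpt_data" and its file name are invented; they are not shipped with the repo.
dataset_info_entry = {
    "my_sharegpt_data": {
        "file_name": "my_sharegpt_data.json",
        "formatting": "sharegpt",
        "columns": {
            "messages": "conversations"
        },
        "tags": {
            "role_tag": "from",
            "content_tag": "value",
            "user_tag": "human",
            "assistant_tag": "gpt",
            "system_tag": "system"  # default; a leading system turn overrides the system column
        }
    }
}

print(json.dumps(dataset_info_entry, indent=2))
```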
@@ -11,7 +11,7 @@
   "folder": "Hugging Face 仓库的文件夹名称(可选,默认:None)",
   "ranking": "是否为偏好数据集(可选,默认:False)",
   "formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)",
-  "columns": {
+  "columns(可选)": {
     "prompt": "数据集代表提示词的表头名称(默认:instruction)",
     "query": "数据集代表请求的表头名称(默认:input)",
     "response": "数据集代表回答的表头名称(默认:output)",
@@ -20,13 +20,14 @@
     "system": "数据集代表系统提示的表头名称(默认:None)",
     "tools": "数据集代表工具描述的表头名称(默认:None)"
   },
-  "tags": {
+  "tags(可选,用于 sharegpt 格式)": {
     "role_tag": "消息中代表发送者身份的键名(默认:from)",
     "content_tag": "消息中代表文本内容的键名(默认:value)",
     "user_tag": "消息中代表用户的 role_tag(默认:human)",
     "assistant_tag": "消息中代表助手的 role_tag(默认:gpt)",
     "observation_tag": "消息中代表工具返回结果的 role_tag(默认:observation)",
-    "function_tag": "消息中代表工具调用的 role_tag(默认:function_call)"
+    "function_tag": "消息中代表工具调用的 role_tag(默认:function_call)",
+    "system_tag": "消息中代表系统提示的 role_tag(默认:system,会覆盖 system 列)"
   }
 }
 ```
@@ -49,40 +49,32 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
         dataset_attr.function_tag: Role.FUNCTION,
         dataset_attr.system_tag: Role.SYSTEM,
     }
+    odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
+    even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
+    accept_tags = (odd_tags, even_tags)
     for i, messages in enumerate(examples[dataset_attr.messages]):
-        if len(messages) <= 1:
+        if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
+            system = messages[0][dataset_attr.content_tag]
+            messages = messages[1:]
+        else:
+            system = examples[dataset_attr.system][i] if dataset_attr.system else ""
+
+        messages = messages[: len(messages) // 2 * 2]  # should be multiples of 2
+        if len(messages) == 0:
             continue

-        prompt = []
-        response = []
-        n_sys = 0
+        aligned_messages = []
         for turn_idx, message in enumerate(messages):
-            if dataset_attr.system_tag and message[dataset_attr.role_tag] == dataset_attr.system_tag:
-                outputs["system"].append(message[dataset_attr.content_tag])
-                n_sys = 1
-                continue
-
-            if (turn_idx - n_sys) % 2 == 0:
-                accept_tags = [dataset_attr.user_tag, dataset_attr.observation_tag]
-            else:
-                accept_tags = [dataset_attr.assistant_tag, dataset_attr.function_tag]
-
-            if message[dataset_attr.role_tag] not in accept_tags:
+            if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
                 raise ValueError("Invalid role tag in {}.".format(messages))

-            prompt.append(
+            aligned_messages.append(
                 {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
             )

-        if len(prompt) % 2 == 1:
-            # Last message was neither from assistant nor function
-            prompt.pop(-1)
-        last_message = prompt.pop(-1)
-        response.append(last_message)
-        outputs["prompt"].append(prompt)
-        outputs["response"].append(response)
-        if n_sys == 0:
-            outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
+        outputs["prompt"].append(aligned_messages[:-1])
+        outputs["response"].append(aligned_messages[-1:])
+        outputs["system"].append(system)
         outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")

     return outputs
@@ -93,8 +85,8 @@ def align_dataset(
 ) -> Union["Dataset", "IterableDataset"]:
     r"""
     Aligned dataset:
-        prompt: [{"role": "user", "content": "..."}]
-        response: [{"role": "assistant", "content": "..."}]
+        prompt: [{"role": "user", "content": "..."}] * (2T - 1)
+        response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
         system: "..."
         tools: "..."
     """
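To make these shapes concrete, here is a hand-written illustration (the conversation content is invented) of how a sharegpt record with a leading system turn maps onto the aligned fields described in this docstring:

```python
# A made-up sharegpt record: one leading system turn, then two user/assistant rounds (T = 2).
raw_example = {
    "conversations": [
        {"from": "system", "value": "You are a helpful assistant."},
        {"from": "human", "value": "What is 2 + 2?"},
        {"from": "gpt", "value": "4."},
        {"from": "human", "value": "And doubled?"},
        {"from": "gpt", "value": "8."},
    ]
}

# Under the reworked convert_sharegpt above, the leading system turn becomes the "system"
# field, the last assistant turn becomes the response, and the 2T - 1 messages in between
# become the prompt.
expected_aligned = {
    "prompt": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4."},
        {"role": "user", "content": "And doubled?"},
    ],
    "response": [{"role": "assistant", "content": "8."}],
    "system": "You are a helpful assistant.",
    "tools": "",
}

print(expected_aligned["prompt"])
print(expected_aligned["response"])
```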
@@ -30,6 +30,7 @@ def load_single_dataset(
     model_args: "ModelArguments",
     data_args: "DataArguments",
 ):
+    logger.info("Loading dataset {}...".format(dataset_attr))
     data_path, data_name, data_dir, data_files = None, None, None, None
     if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
         data_path = dataset_attr.dataset_name
@@ -60,7 +61,7 @@ def load_single_dataset(
         if data_path is None:
             raise ValueError("File extension must be txt, csv, json or jsonl.")

-        checksum(data_files, dataset_attr.dataset_sha1)
+        checksum(data_files, dataset_attr.file_sha1)
     else:
         raise NotImplementedError

@@ -157,7 +158,7 @@ def get_dataset(

     with training_args.main_process_first(desc="load dataset"):
         all_datasets = []
-        for dataset_attr in get_dataset_list(data_args):  # TODO: add split
+        for dataset_attr in get_dataset_list(data_args):
             all_datasets.append(load_single_dataset(dataset_attr, model_args, data_args))
         dataset = merge_dataset(all_datasets, data_args, training_args)

@@ -185,6 +186,6 @@ def get_dataset(
             try:
                 print_function(next(iter(dataset)))
             except StopIteration:
-                raise RuntimeError("Empty dataset!")
+                raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")

         return dataset
@@ -1,7 +1,7 @@
 import json
 import os
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Literal, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional

 from ..extras.constants import DATA_CONFIG
 from ..extras.misc import use_modelscope
@@ -13,38 +13,44 @@ if TYPE_CHECKING:

 @dataclass
 class DatasetAttr:
+    r"""
+    Dataset attributes.
+    """
+
+    """ basic configs """
     load_from: Literal["hf_hub", "ms_hub", "script", "file"]
     dataset_name: Optional[str] = None
-    dataset_sha1: Optional[str] = None
+    """ extra configs """
+    file_sha1: Optional[str] = None
     subset: Optional[str] = None
     folder: Optional[str] = None
     ranking: Optional[bool] = False
     formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
-
+    """ columns """
     system: Optional[str] = None
-
+    """ columns for the alpaca format """
     prompt: Optional[str] = "instruction"
     query: Optional[str] = "input"
     response: Optional[str] = "output"
     history: Optional[str] = None
-
+    """ columns for the sharegpt format """
     messages: Optional[str] = "conversations"
     tools: Optional[str] = None
-
+    """ tags for the sharegpt format """
     role_tag: Optional[str] = "from"
     content_tag: Optional[str] = "value"
     user_tag: Optional[str] = "human"
     assistant_tag: Optional[str] = "gpt"
     observation_tag: Optional[str] = "observation"
     function_tag: Optional[str] = "function_call"
-    system_tag: Optional[str] = None
-
-    assert system_tag is None or system is None, f"Can not provide both system message (system_tag={system_tag}) and system column(system={system})"
-
+    system_tag: Optional[str] = "system"

     def __repr__(self) -> str:
         return self.dataset_name

+    def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
+        setattr(self, key, obj.get(key, default))
+

 def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
     dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
@@ -77,30 +83,36 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
         elif "script_url" in dataset_info[name]:
             dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
         else:
-            dataset_attr = DatasetAttr(
-                "file",
-                dataset_name=dataset_info[name]["file_name"],
-                dataset_sha1=dataset_info[name].get("file_sha1", None),
-            )
+            dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])

-        dataset_attr.subset = dataset_info[name].get("subset", None)
-        dataset_attr.folder = dataset_info[name].get("folder", None)
-        dataset_attr.ranking = dataset_info[name].get("ranking", False)
-        dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
+        dataset_attr.set_attr("file_sha1", dataset_info[name])
+        dataset_attr.set_attr("subset", dataset_info[name])
+        dataset_attr.set_attr("folder", dataset_info[name])
+        dataset_attr.set_attr("ranking", dataset_info[name], default=False)
+        dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")

         if "columns" in dataset_info[name]:
+            column_names = ["system"]
             if dataset_attr.formatting == "alpaca":
-                column_names = ["prompt", "query", "response", "history"]
+                column_names.extend(["prompt", "query", "response", "history"])
             else:
-                column_names = ["messages", "tools"]
+                column_names.extend(["messages", "tools"])

-            column_names += ["system"]
             for column_name in column_names:
-                setattr(dataset_attr, column_name, dataset_info[name]["columns"].get(column_name, None))
+                dataset_attr.set_attr(column_name, dataset_info[name]["columns"])

         if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
-            for tag in ["role_tag", "content_tag", "user_tag", "assistant_tag", "observation_tag", "function_tag", "system_tag"]:
-                setattr(dataset_attr, tag, dataset_info[name]["tags"].get(tag, None))
+            tag_names = (
+                "role_tag",
+                "content_tag",
+                "user_tag",
+                "assistant_tag",
+                "observation_tag",
+                "function_tag",
+                "system_tag",
+            )
+            for tag in tag_names:
+                dataset_attr.set_attr(tag, dataset_info[name]["tags"])

         dataset_list.append(dataset_attr)

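For reference, here is a self-contained sketch of what the new `set_attr` helper does with optional keys and defaults; `MiniAttr` and the sample entry below are made up for illustration, not part of the repository:

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


# Minimal stand-in mirroring the set_attr helper added above: read a key from a
# dataset_info entry if present, otherwise fall back to the given default (or None).
@dataclass
class MiniAttr:
    ranking: Optional[bool] = False
    formatting: Optional[str] = "alpaca"
    folder: Optional[str] = None

    def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
        setattr(self, key, obj.get(key, default))


info = {"formatting": "sharegpt"}  # made-up dataset_info entry with only one key set

attr = MiniAttr()
attr.set_attr("ranking", info, default=False)        # absent -> False
attr.set_attr("formatting", info, default="alpaca")  # present -> "sharegpt"
attr.set_attr("folder", info)                        # absent, no default -> None

print(attr)  # MiniAttr(ranking=False, formatting='sharegpt', folder=None)
```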
@@ -247,7 +247,7 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str)
         logger.info("Replace eos token: {}".format(tokenizer.eos_token))

     if is_oov:
-        logger.warning("New token is added, you must enable `resize_vocab` to activate it.")
+        logger.warning("New tokens have been added, make sure `resize_vocab` is True.")


 def get_template_and_fix_tokenizer(
@@ -19,9 +19,9 @@ logger = get_logger(__name__)
 class Role(str, Enum):
     USER = "user"
     ASSISTANT = "assistant"
-    SYSTEM = "system"
     OBSERVATION = "observation"
     FUNCTION = "function"
+    SYSTEM = "system"


 def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
@@ -67,7 +67,7 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
             raise ValueError("Quantized model only accepts a single adapter. Merge them first.")

     if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
-        raise ValueError("Only LoRA method has adapters.")
+        raise ValueError("Adapter is only valid for the LoRA method.")


 def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
@@ -125,6 +125,14 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:

     _verify_model_args(model_args, finetuning_args)

+    if (
+        training_args.do_train
+        and finetuning_args.finetuning_type == "lora"
+        and model_args.resize_vocab
+        and finetuning_args.additional_target is None
+    ):
+        logger.warning("Add token embeddings to `additional_target` to make the added tokens trainable.")
+
     if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
         logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
