mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	Support a custom number of samples per dataset
Former-commit-id: fa8325401df27595de4611a89dfcc14644956abd
This commit is contained in:
		
							parent
							
								
									dc540dfaa8
								
							
						
					
					
						commit
						7cdc16abdf
					
				@ -27,8 +27,9 @@ If you are using a custom dataset, please provide your dataset definition in the
 | 
			
		||||
    "assistant_tag": "the value of the role_tag represents the assistant. (default: gpt)",
 | 
			
		||||
    "observation_tag": "the value of the role_tag represents the tool results. (default: observation)",
 | 
			
		||||
    "function_tag": "the value of the role_tag represents the function call. (default: function_call)",
 | 
			
		||||
    "system_tag": "the value of the role_tag represents the system prompt. (default: system, can override system column)"
 | 
			
		||||
  }
 | 
			
		||||
    "system_tag": "the value of the role_tag represents the system prompt. (default: system, can override system column)",
 | 
			
		||||
  },
 | 
			
		||||
  "sample_num": "the number of samples to draw from this dataset; it may exceed the dataset size, in which case examples are drawn repeatedly. (default: None)"
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -28,7 +28,8 @@
 | 
			
		||||
    "observation_tag": "消息中代表工具返回结果的 role_tag(默认:observation)",
 | 
			
		||||
    "function_tag": "消息中代表工具调用的 role_tag(默认:function_call)",
 | 
			
		||||
    "system_tag": "消息中代表系统提示的 role_tag(默认:system,会覆盖 system 列)"
 | 
			
		||||
  }
 | 
			
		||||
  },
 | 
			
		||||
  "sample_num": "从该数据集采样的数量,可大于该数据集总量(默认:None)"
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,7 @@
 | 
			
		||||
import inspect
 | 
			
		||||
import os
 | 
			
		||||
import numpy as np
 | 
			
		||||
from numpy.random import RandomState
 | 
			
		||||
from typing import TYPE_CHECKING, Literal, Union
 | 
			
		||||
 | 
			
		||||
from datasets import load_dataset, load_from_disk
 | 
			
		||||
@ -108,6 +110,17 @@ def load_single_dataset(
 | 
			
		||||
        num_samples = min(data_args.max_samples, len(dataset))
 | 
			
		||||
        dataset = dataset.select(range(num_samples))
 | 
			
		||||
 | 
			
		||||
    if dataset_attr.sample_num:
 | 
			
		||||
        dataset_sample_num = dataset_attr.sample_num
 | 
			
		||||
        logger.info(f"从 {dataset_attr.dataset_name} 采样 {dataset_sample_num} 条训练样本")
 | 
			
		||||
        random_state = RandomState(42)
 | 
			
		||||
        idx = random_state.permutation(len(dataset))[:dataset_sample_num]
 | 
			
		||||
        dataset_sample_num -= len(idx)
 | 
			
		||||
        if dataset_sample_num > 0:
 | 
			
		||||
            idx2 = random_state.choice(len(dataset), dataset_sample_num)
 | 
			
		||||
            idx = np.concatenate([idx, idx2], axis=0)
 | 
			
		||||
        dataset = dataset.select(idx)
 | 
			
		||||
 | 
			
		||||
    return align_dataset(dataset, dataset_attr, data_args)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -44,6 +44,7 @@ class DatasetAttr:
 | 
			
		||||
    observation_tag: Optional[str] = "observation"
 | 
			
		||||
    function_tag: Optional[str] = "function_call"
 | 
			
		||||
    system_tag: Optional[str] = "system"
 | 
			
		||||
    sample_num: Optional[int] = None
 | 
			
		||||
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        return self.dataset_name
 | 
			
		||||
@ -90,7 +91,8 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
 | 
			
		||||
        dataset_attr.set_attr("folder", dataset_info[name])
 | 
			
		||||
        dataset_attr.set_attr("ranking", dataset_info[name], default=False)
 | 
			
		||||
        dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
 | 
			
		||||
 | 
			
		||||
        dataset_attr.set_attr("sample_num", dataset_info[name])
 | 
			
		||||
        
 | 
			
		||||
        if "columns" in dataset_info[name]:
 | 
			
		||||
            column_names = ["system"]
 | 
			
		||||
            if dataset_attr.formatting == "alpaca":
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user