add code for reading from multiple files in one directory

Former-commit-id: b7ebb83a96619e5111b0faa9da9d0feb8d9cdff0
BUAADreamer 2023-06-10 15:53:47 +08:00
parent 03c92c79ff
commit ef6c5ae18a
2 changed files with 76 additions and 54 deletions
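For illustration, a hypothetical dataset_info.json entry that uses this change: its "file_name" points at a sub-directory of the dataset folder instead of a single file, and every file inside that directory is loaded as part of the dataset. The entry name, directory name, and column values below are made up, and the entry is shown as a Python dict only for readability.

    example_dataset_info = {
        "my_multi_file_data": {              # made-up dataset name
            "file_name": "my_data_dir",      # a directory such as data/my_data_dir/ holding part_a.json, part_b.json, ...
            "columns": {                     # optional column mapping, same keys as for single-file entries
                "prompt": "instruction",
                "query": "input",
                "response": "output"
            }
        }
    }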

View File

@@ -56,7 +56,6 @@ require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
 require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
 require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1")

 logger = get_logger(__name__)
@@ -92,10 +91,12 @@ def _init_adapter(
     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
-            assert is_mergeable and len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
+            assert is_mergeable and len(
+                model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
             load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
         else:
-            assert is_mergeable or len(model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."
+            assert is_mergeable or len(
+                model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."

     if finetuning_args.finetuning_type == "lora":
         logger.info("Fine-tuning method: LoRA")
@@ -105,7 +106,8 @@ def _init_adapter(
             assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
                 "The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."

-            if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
+            if (is_trainable and model_args.resume_lora_training) or (
+                    not is_mergeable): # continually train on the lora weights
                 checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
             else:
                 checkpoints_to_merge = model_args.checkpoint_dir
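For context, a minimal sketch (not the project's code) of what the branch above does with several --checkpoint_dir values when resume_lora_training is enabled: all but the last adapter are merged into the base weights, and training resumes from the last one. The paths are made up.

    checkpoint_dir = ["saves/lora-1000", "saves/lora-2000", "saves/lora-3000"]   # made-up paths
    checkpoints_to_merge, lastest_checkpoint = checkpoint_dir[:-1], checkpoint_dir[-1]
    print(checkpoints_to_merge)   # ['saves/lora-1000', 'saves/lora-2000'] -> merged into the base model
    print(lastest_checkpoint)     # 'saves/lora-3000' -> LoRA training continues from this adapter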
@@ -184,9 +186,11 @@ def load_pretrained(
             )
         elif model_args.quantization_bit == 4:
             require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
-            require_version("transformers>=4.30.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git")
+            require_version("transformers>=4.30.0.dev0",
+                            "To fix: pip install git+https://github.com/huggingface/transformers.git")
             require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
-            require_version("accelerate>=0.20.0.dev0", "To fix: pip install git+https://github.com/huggingface/accelerate.git")
+            require_version("accelerate>=0.20.0.dev0",
+                            "To fix: pip install git+https://github.com/huggingface/accelerate.git")
             config_kwargs["load_in_4bit"] = True
             config_kwargs["quantization_config"] = BitsAndBytesConfig(
                 load_in_4bit=True,
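For context, a minimal sketch of loading a model in 4-bit with the requirements pinned above, assuming the standard BitsAndBytesConfig API from transformers; the model name and compute dtype are placeholders, not the project's defaults.

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # Sketch only: mirrors the config_kwargs assembled above.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,   # placeholder compute dtype
    )
    model = AutoModelForCausalLM.from_pretrained(
        "huggyllama/llama-7b",                   # placeholder model name
        quantization_config=bnb_config,
        device_map="auto",
    )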
@@ -241,11 +245,11 @@ def load_pretrained(
 def prepare_args(
         stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:

     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))

     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
-        model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+        model_args, data_args, training_args, finetuning_args = parser.parse_json_file(
+            json_file=os.path.abspath(sys.argv[1]))
     else:
         model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()
@@ -310,7 +314,6 @@ def prepare_args(
 def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments]:

     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments))

     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
@@ -331,7 +334,6 @@ def prepare_data(
     model_args: ModelArguments,
     data_args: DataTrainingArguments
 ) -> Dataset:

     def checksum(file_path, hash):
         with open(file_path, "rb") as datafile:
             binary_data = datafile.read()
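To show where the truncated helper above is heading, a sketch of a SHA-1 checksum routine under the same signature; the exact comparison and warning message are assumptions, not the project's code.

    import hashlib

    def checksum(file_path, hash):
        # Read the raw bytes and compare their SHA-1 digest against the expected value.
        with open(file_path, "rb") as datafile:
            binary_data = datafile.read()
        sha1 = hashlib.sha1(binary_data).hexdigest()
        if sha1 != hash:
            print("Warning: checksum mismatch for {}.".format(file_path))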
@@ -361,7 +363,7 @@ def prepare_data(
                 checksum(data_file, dataset_attr.file_sha1)
             else:
                 logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
+            print(extension)
             raw_datasets = load_dataset(
                 extension if extension in ["csv", "json"] else "text",
                 data_files=data_file,
@@ -406,7 +408,6 @@ def preprocess_data(
     training_args: Seq2SeqTrainingArguments,
     stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> Dataset:

     column_names = list(dataset.column_names)
     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
     prompt_template = Template(data_args.prompt_template)
@@ -429,7 +430,8 @@ def preprocess_data(
         # we drop the small remainder, and if the total_length < block_size, we exclude this batch
         total_length = (total_length // data_args.max_source_length) * data_args.max_source_length
         # split by chunks of max_source_length
-        result = [concatenated_ids[i: i+data_args.max_source_length] for i in range(0, total_length, data_args.max_source_length)]
+        result = [concatenated_ids[i: i + data_args.max_source_length] for i in
+                  range(0, total_length, data_args.max_source_length)]
         return {
             "input_ids": result,
             "labels": result.copy()

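To make the chunking arithmetic above concrete, a small worked example with arbitrary numbers:

    block_size = 4                       # stands in for data_args.max_source_length
    concatenated_ids = list(range(10))   # 10 tokens after concatenating the tokenized examples
    total_length = (len(concatenated_ids) // block_size) * block_size   # 8: the 2-token remainder is dropped
    result = [concatenated_ids[i: i + block_size] for i in range(0, total_length, block_size)]
    print(result)   # [[0, 1, 2, 3], [4, 5, 6, 7]] -> each chunk serves as both input_ids and labels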
View File

@@ -7,7 +7,6 @@ from dataclasses import asdict, dataclass, field
 @dataclass
 class DatasetAttr:

     load_from: str
     dataset_name: Optional[str] = None
     file_name: Optional[str] = None
@@ -68,7 +67,8 @@ class ModelArguments:
     )
     checkpoint_dir: Optional[str] = field(
         default=None,
-        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
+        metadata={
+            "help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
     )
     reward_model: Optional[str] = field(
         default=None,
@@ -76,7 +76,8 @@ class ModelArguments:
     )
     resume_lora_training: Optional[bool] = field(
         default=True,
-        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
+        metadata={
+            "help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
     )
     plot_loss: Optional[bool] = field(
         default=False,
@@ -155,24 +156,41 @@ class DataTrainingArguments:
         for name in dataset_names:
             if name not in dataset_info:
                 raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
+            dataset_attrs = []
+            dataset_attr = None
             if "hf_hub_url" in dataset_info[name]:
                 dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
             elif "script_url" in dataset_info[name]:
                 dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
-            else:
+            elif os.path.isfile(os.path.join(self.dataset_dir, dataset_info[name]["file_name"])):
                 dataset_attr = DatasetAttr(
                     "file",
                     file_name=dataset_info[name]["file_name"],
                     file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
                 )
+            else:
+                # Support Directory
+                for file_name in os.listdir(os.path.join(self.dataset_dir, dataset_info[name]["file_name"])):
+                    path = os.path.join(dataset_info[name]["file_name"], file_name)
+                    dataset_attrs.append(DatasetAttr(
+                        "file",
+                        file_name=path,
+                        file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
+                    ))
-            if "columns" in dataset_info[name]:
-                dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
-                dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
-                dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
-                dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
-            self.dataset_list.append(dataset_attr)
+            if dataset_attr is not None:
+                if "columns" in dataset_info[name]:
+                    dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
+                    dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
+                    dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
+                    dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
+                self.dataset_list.append(dataset_attr)
+            else:
+                for i, dataset_attr in enumerate(dataset_attrs):
+                    if "columns" in dataset_info[name]:
+                        dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
+                        dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
+                        dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
+                        dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
+                    self.dataset_list.append(dataset_attr)
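To illustrate the new directory branch: when "file_name" resolves to a directory rather than a file, one DatasetAttr of type "file" is created per file inside it. A minimal sketch with made-up paths (note that os.listdir order is not guaranteed, so sorting the file names may be desirable):

    import os

    dataset_dir = "data"                 # stands in for self.dataset_dir
    entry_file_name = "my_data_dir"      # dataset_info[name]["file_name"] pointing at data/my_data_dir/
    paths = [
        os.path.join(entry_file_name, f)
        for f in sorted(os.listdir(os.path.join(dataset_dir, entry_file_name)))
    ]
    # e.g. ['my_data_dir/part_a.json', 'my_data_dir/part_b.json'],
    # and each path becomes DatasetAttr("file", file_name=path, file_sha1=...)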
@@ -216,14 +234,16 @@ class FinetuningArguments:
     def __post_init__(self):
         if isinstance(self.lora_target, str):
-            self.lora_target = [target.strip() for target in self.lora_target.split(",")] # support custom target modules of LoRA
+            self.lora_target = [target.strip() for target in
+                                self.lora_target.split(",")] # support custom target modules of LoRA

         if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
             trainable_layer_ids = [27 - k for k in range(self.num_layer_trainable)]
         else: # fine-tuning the first n layers if num_layer_trainable < 0
             trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]

-        self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
+        self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in
+                                 trainable_layer_ids]

         assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
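A small worked example of the layer-selection logic above; the hard-coded 27 assumes a model with 28 transformer layers, and the module name below is a placeholder.

    num_layer_trainable = 3                  # fine-tune the last three layers
    name_module_trainable = "mlp"            # placeholder module name
    trainable_layer_ids = [27 - k for k in range(num_layer_trainable)]   # [27, 26, 25]
    trainable_layers = ["layers.{:d}.{}".format(idx, name_module_trainable) for idx in trainable_layer_ids]
    print(trainable_layers)   # ['layers.27.mlp', 'layers.26.mlp', 'layers.25.mlp']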