From ef6c5ae18a581e82d1023b3aa067c8ffb5047d97 Mon Sep 17 00:00:00 2001 From: BUAADreamer <1428195643@qq.com> Date: Sat, 10 Jun 2023 15:53:47 +0800 Subject: [PATCH] add code for reading from multi files in one directory Former-commit-id: b7ebb83a96619e5111b0faa9da9d0feb8d9cdff0 --- src/utils/common.py | 70 +++++++++++++++++++++++---------------------- src/utils/config.py | 60 +++++++++++++++++++++++++------------- 2 files changed, 76 insertions(+), 54 deletions(-) diff --git a/src/utils/common.py b/src/utils/common.py index 2f77cbac..5af5e8cc 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -56,7 +56,6 @@ require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0") require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0") require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1") - logger = get_logger(__name__) @@ -92,10 +91,12 @@ def _init_adapter( if model_args.checkpoint_dir is not None: if finetuning_args.finetuning_type != "lora": - assert is_mergeable and len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints." - load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods + assert is_mergeable and len( + model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints." + load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods else: - assert is_mergeable or len(model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint." + assert is_mergeable or len( + model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint." if finetuning_args.finetuning_type == "lora": logger.info("Fine-tuning method: LoRA") @@ -105,7 +106,8 @@ def _init_adapter( assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \ "The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead." - if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights + if (is_trainable and model_args.resume_lora_training) or ( + not is_mergeable): # continually train on the lora weights checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1] else: checkpoints_to_merge = model_args.checkpoint_dir @@ -117,10 +119,10 @@ def _init_adapter( if len(checkpoints_to_merge) > 0: logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge))) - if lastest_checkpoint is not None: # resume lora training or quantized inference + if lastest_checkpoint is not None: # resume lora training or quantized inference model = PeftModel.from_pretrained(model, lastest_checkpoint, is_trainable=is_trainable) - if is_trainable and lastest_checkpoint is None: # create new lora weights while training + if is_trainable and lastest_checkpoint is None: # create new lora weights while training lora_config = LoraConfig( task_type=TaskType.CAUSAL_LM, inference_mode=False, @@ -168,7 +170,7 @@ def load_pretrained( padding_side="left", **config_kwargs ) - tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the token + tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the token config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) is_mergeable = True @@ -184,9 +186,11 @@ def load_pretrained( ) elif model_args.quantization_bit == 4: require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") - require_version("transformers>=4.30.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git") + require_version("transformers>=4.30.0.dev0", + "To fix: pip install git+https://github.com/huggingface/transformers.git") require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git") - require_version("accelerate>=0.20.0.dev0", "To fix: pip install git+https://github.com/huggingface/accelerate.git") + require_version("accelerate>=0.20.0.dev0", + "To fix: pip install git+https://github.com/huggingface/accelerate.git") config_kwargs["load_in_4bit"] = True config_kwargs["quantization_config"] = BitsAndBytesConfig( load_in_4bit=True, @@ -214,10 +218,10 @@ def load_pretrained( model = prepare_model_for_training(model) if is_trainable else model model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable) - if stage == "rm" or stage == "ppo": # add value head + if stage == "rm" or stage == "ppo": # add value head model = AutoModelForCausalLMWithValueHead.from_pretrained(model) - if stage == "ppo": # load reward model + if stage == "ppo": # load reward model assert is_trainable, "PPO stage cannot be performed at evaluation." assert model_args.reward_model is not None, "Reward model is necessary for PPO training." logger.info("Load reward model from {}".format(model_args.reward_model)) @@ -230,8 +234,8 @@ def load_pretrained( model._is_int8_training_enabled = True if not is_trainable: - model.requires_grad_(False) # fix all model params - model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16 + model.requires_grad_(False) # fix all model params + model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16 print_trainable_params(model) @@ -241,11 +245,11 @@ def load_pretrained( def prepare_args( stage: Literal["pt", "sft", "rm", "ppo"] ) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]: - parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments)) - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file. - model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file. + model_args, data_args, training_args, finetuning_args = parser.parse_json_file( + json_file=os.path.abspath(sys.argv[1])) else: model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses() @@ -286,7 +290,7 @@ def prepare_args( logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.") training_args.ddp_find_unused_parameters = False - training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning + training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning if model_args.quantization_bit is not None: if training_args.fp16: @@ -310,10 +314,9 @@ def prepare_args( def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments]: - parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments)) - if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file. + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file. model_args, data_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) else: model_args, data_args, finetuning_args = parser.parse_args_into_dataclasses() @@ -331,7 +334,6 @@ def prepare_data( model_args: ModelArguments, data_args: DataTrainingArguments ) -> Dataset: - def checksum(file_path, hash): with open(file_path, "rb") as datafile: binary_data = datafile.read() @@ -340,7 +342,7 @@ def prepare_data( logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path)) max_samples = data_args.max_samples - all_datasets: List[Dataset] = [] # support multiple datasets + all_datasets: List[Dataset] = [] # support multiple datasets for dataset_attr in data_args.dataset_list: @@ -361,7 +363,7 @@ def prepare_data( checksum(data_file, dataset_attr.file_sha1) else: logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.") - + print(extension) raw_datasets = load_dataset( extension if extension in ["csv", "json"] else "text", data_files=data_file, @@ -383,11 +385,11 @@ def prepare_data( ("query_column", "query"), ("response_column", "response"), ("history_column", "history") - ]: # every dataset will have 4 columns same as each other + ]: # every dataset will have 4 columns same as each other if getattr(dataset_attr, column_name) != target_name: if getattr(dataset_attr, column_name): dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name) - else: # None or empty string + else: # None or empty string dataset = dataset.add_column(target_name, dummy_data) all_datasets.append(dataset) @@ -406,7 +408,6 @@ def preprocess_data( training_args: Seq2SeqTrainingArguments, stage: Literal["pt", "sft", "rm", "ppo"] ) -> Dataset: - column_names = list(dataset.column_names) prefix = data_args.source_prefix if data_args.source_prefix is not None else "" prompt_template = Template(data_args.prompt_template) @@ -429,7 +430,8 @@ def preprocess_data( # we drop the small remainder, and if the total_length < block_size, we exclude this batch total_length = (total_length // data_args.max_source_length) * data_args.max_source_length # split by chunks of max_source_length - result = [concatenated_ids[i: i+data_args.max_source_length] for i in range(0, total_length, data_args.max_source_length)] + result = [concatenated_ids[i: i + data_args.max_source_length] for i in + range(0, total_length, data_args.max_source_length)] return { "input_ids": result, "labels": result.copy() @@ -442,9 +444,9 @@ def preprocess_data( source_ids = tokenizer.encode(text=prompt, add_special_tokens=False) target_ids = tokenizer.encode(text=answer, add_special_tokens=False) - if len(source_ids) > data_args.max_source_length - 1: # bos token + if len(source_ids) > data_args.max_source_length - 1: # bos token source_ids = source_ids[:data_args.max_source_length - 1] - if len(target_ids) > data_args.max_target_length - 1: # eos token + if len(target_ids) > data_args.max_target_length - 1: # eos token target_ids = target_ids[:data_args.max_target_length - 1] input_ids = source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id] @@ -461,9 +463,9 @@ def preprocess_data( source_ids = tokenizer.encode(text=prompt, add_special_tokens=False) target_ids = tokenizer.encode(text=answer, add_special_tokens=False) - if len(source_ids) > data_args.max_source_length - 1: # bos token + if len(source_ids) > data_args.max_source_length - 1: # bos token source_ids = source_ids[:data_args.max_source_length - 1] - if len(target_ids) > data_args.max_target_length - 1: # bos token + if len(target_ids) > data_args.max_target_length - 1: # bos token target_ids = target_ids[:data_args.max_target_length - 1] input_ids = source_ids + [tokenizer.bos_token_id] @@ -481,11 +483,11 @@ def preprocess_data( accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False) reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False) - if len(source_ids) > data_args.max_source_length - 1: # bos token + if len(source_ids) > data_args.max_source_length - 1: # bos token source_ids = source_ids[:data_args.max_source_length - 1] - if len(accept_ids) > data_args.max_target_length - 1: # eos token + if len(accept_ids) > data_args.max_target_length - 1: # eos token accept_ids = accept_ids[:data_args.max_target_length - 1] - if len(reject_ids) > data_args.max_target_length - 1: # eos token + if len(reject_ids) > data_args.max_target_length - 1: # eos token reject_ids = reject_ids[:data_args.max_target_length - 1] accept_ids = source_ids + [tokenizer.bos_token_id] + accept_ids + [tokenizer.eos_token_id] diff --git a/src/utils/config.py b/src/utils/config.py index ccf3c736..7a75aa4f 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -7,7 +7,6 @@ from dataclasses import asdict, dataclass, field @dataclass class DatasetAttr: - load_from: str dataset_name: Optional[str] = None file_name: Optional[str] = None @@ -68,7 +67,8 @@ class ModelArguments: ) checkpoint_dir: Optional[str] = field( default=None, - metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."} + metadata={ + "help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."} ) reward_model: Optional[str] = field( default=None, @@ -76,7 +76,8 @@ class ModelArguments: ) resume_lora_training: Optional[bool] = field( default=True, - metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."} + metadata={ + "help": "Whether to resume training from the last LoRA weights or create new weights after merging them."} ) plot_loss: Optional[bool] = field( default=False, @@ -84,7 +85,7 @@ class ModelArguments: ) def __post_init__(self): - if self.checkpoint_dir is not None: # support merging multiple lora weights + if self.checkpoint_dir is not None: # support merging multiple lora weights self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] @@ -146,7 +147,7 @@ class DataTrainingArguments: metadata={"help": "Which template to use for constructing prompts in training and inference."} ) - def __post_init__(self): # support mixing multiple datasets + def __post_init__(self): # support mixing multiple datasets dataset_names = [ds.strip() for ds in self.dataset.split(",")] with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f: dataset_info = json.load(f) @@ -155,25 +156,42 @@ class DataTrainingArguments: for name in dataset_names: if name not in dataset_info: raise ValueError("Undefined dataset {} in dataset_info.json.".format(name)) - + dataset_attrs = [] + dataset_attr = None if "hf_hub_url" in dataset_info[name]: dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) elif "script_url" in dataset_info[name]: dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) - else: + elif os.path.isfile(os.path.join(self.dataset_dir, dataset_info[name]["file_name"])): dataset_attr = DatasetAttr( "file", file_name=dataset_info[name]["file_name"], file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None ) - - if "columns" in dataset_info[name]: - dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None) - dataset_attr.query_column = dataset_info[name]["columns"].get("query", None) - dataset_attr.response_column = dataset_info[name]["columns"].get("response", None) - dataset_attr.history_column = dataset_info[name]["columns"].get("history", None) - - self.dataset_list.append(dataset_attr) + else: + # Support Directory + for file_name in os.listdir(os.path.join(self.dataset_dir, dataset_info[name]["file_name"])): + path = os.path.join(dataset_info[name]["file_name"], file_name) + dataset_attrs.append(DatasetAttr( + "file", + file_name=path, + file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None + )) + if dataset_attr is not None: + if "columns" in dataset_info[name]: + dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None) + dataset_attr.query_column = dataset_info[name]["columns"].get("query", None) + dataset_attr.response_column = dataset_info[name]["columns"].get("response", None) + dataset_attr.history_column = dataset_info[name]["columns"].get("history", None) + self.dataset_list.append(dataset_attr) + else: + for i, dataset_attr in enumerate(dataset_attrs): + if "columns" in dataset_info[name]: + dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None) + dataset_attr.query_column = dataset_info[name]["columns"].get("query", None) + dataset_attr.response_column = dataset_info[name]["columns"].get("response", None) + dataset_attr.history_column = dataset_info[name]["columns"].get("history", None) + self.dataset_list.append(dataset_attr) @dataclass @@ -216,14 +234,16 @@ class FinetuningArguments: def __post_init__(self): if isinstance(self.lora_target, str): - self.lora_target = [target.strip() for target in self.lora_target.split(",")] # support custom target modules of LoRA + self.lora_target = [target.strip() for target in + self.lora_target.split(",")] # support custom target modules of LoRA - if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 - trainable_layer_ids = [27-k for k in range(self.num_layer_trainable)] - else: # fine-tuning the first n layers if num_layer_trainable < 0 + if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 + trainable_layer_ids = [27 - k for k in range(self.num_layer_trainable)] + else: # fine-tuning the first n layers if num_layer_trainable < 0 trainable_layer_ids = [k for k in range(-self.num_layer_trainable)] - self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids] + self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in + trainable_layer_ids] assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."