From c87910ada348dce0b9e27b9b0838f0c8b4d59edc Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 5 Jul 2023 15:00:06 +0800
Subject: [PATCH] support falcon model #72

Former-commit-id: c136f362c1aa75d3374b151188ba4a55d9313a59
---
 README.md           |  5 ++++
 src/export_model.py |  1 -
 src/utils/common.py | 62 +++++++++++++++++++++++++++------------------
 src/utils/config.py |  5 ++--
 src/utils/other.py  |  2 +-
 5 files changed, 46 insertions(+), 29 deletions(-)

diff --git a/README.md b/README.md
index 0c3fcdcf..f45a312b 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,8 @@
 
 ## Changelog
 
+[23/07/05] Now we support training the Falcon-7B/40B models in this repo. Try the `--model_name_or_path tiiuae/falcon-7b` and `--lora_target query_key_value` arguments to use the Falcon models.
+
 [23/06/29] We provide a reproducible example of training a chat model using instruction-following datasets; see this [HuggingFace Repo](https://huggingface.co/hiyouga/baichuan-7b-sft) for details.
 
 [23/06/22] Now we align the [demo API](src/api_demo.py) with the [OpenAI](https://platform.openai.com/docs/api-reference/chat) format, so you can insert the fine-tuned model into arbitrary ChatGPT-based applications.
@@ -23,6 +25,7 @@
 
 - [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
 - [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
+- [Falcon](https://huggingface.co/tiiuae/falcon-7b) (7B/40B)
 - [baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B)
 
 ## Supported Training Approaches
@@ -283,6 +286,8 @@ Please follow the [Model Card](https://github.com/facebookresearch/llama/blob/ma
 
 Please follow the [RAIL License](https://huggingface.co/spaces/bigscience/license) to use the BLOOM & BLOOMZ models.
 
+Please follow the [Apache-2.0 License](LICENSE) to use the Falcon models.
+
 Please follow the [baichuan-7B License](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) to use the baichuan-7B model.
 
 ## Citation
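Note on the new default LoRA target: unlike LLaMA's separate `q_proj`/`k_proj`/`v_proj` layers, Falcon's attention uses a single fused `query_key_value` projection, which is why the changelog entry pairs `--model_name_or_path tiiuae/falcon-7b` with `--lora_target query_key_value`. The sketch below is illustrative only: it calls peft directly rather than this repo's training entry point, it assumes `trust_remote_code=True` (the Falcon checkpoints still shipped custom modeling code at the time), and the hyperparameters roughly mirror this repo's LoRA defaults.

```python
# Illustrative sketch, not part of this patch: what `--lora_target query_key_value`
# amounts to when LoRA is attached to Falcon via peft.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,                                  # lora_rank default in src/utils/config.py
    lora_alpha=32.0,
    lora_dropout=0.1,
    target_modules=["query_key_value"],   # Falcon's fused Q/K/V projection
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```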
diff --git a/src/export_model.py b/src/export_model.py
index e36d5c82..71985180 100644
--- a/src/export_model.py
+++ b/src/export_model.py
@@ -13,7 +13,6 @@ def main():
     model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
     tokenizer.save_pretrained(training_args.output_dir)
     print("model and tokenizer have been saved at:", training_args.output_dir)
-    print("Remember to copy the *.py files from the original directory.")
 
 
 if __name__ == "__main__":
diff --git a/src/utils/common.py b/src/utils/common.py
index b2448150..1086be35 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -183,6 +183,7 @@ def load_pretrained(
                 load_in_8bit=True,
                 llm_int8_threshold=6.0
             )
+
         elif model_args.quantization_bit == 4:
             require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
             require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
@@ -195,6 +196,7 @@ def load_pretrained(
                 bnb_4bit_use_double_quant=model_args.double_quantization,
                 bnb_4bit_quant_type=model_args.quantization_type
             )
+
         is_mergeable = False
         config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
@@ -211,10 +213,20 @@
     model = AutoModelForCausalLM.from_pretrained(
         model_to_load,
         config=config,
-        torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
+        torch_dtype=model_args.compute_dtype,
         low_cpu_mem_usage=True,
         **config_kwargs
     )
+
+    # Register auto class to save the custom code files.
+    if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
+        config.__class__.register_for_auto_class()
+    if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
+        tokenizer.__class__.register_for_auto_class()
+    if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map:
+        model.__class__.register_for_auto_class()
+
+    # Initialize adapters
     model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
     model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
 
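The registration block just added to `load_pretrained` is also what makes the removed reminder in `src/export_model.py` unnecessary: for checkpoints that ship custom modeling code (as the Falcon checkpoints did at the time), `config.auto_map` lists the custom classes, and once those classes are registered for the auto classes, `save_pretrained` copies the corresponding `*.py` files into the output directory by itself. A rough standalone illustration of the same idea (not this repo's code; the checkpoint and output path are just examples):

```python
# Standalone illustration of auto-class registration; checkpoint and paths are examples.
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_name = "tiiuae/falcon-7b"  # a checkpoint whose config carries an `auto_map`
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, config=config, trust_remote_code=True)

for obj, auto_class in ((config, "AutoConfig"), (tokenizer, "AutoTokenizer"), (model, "AutoModelForCausalLM")):
    if hasattr(config, "auto_map") and auto_class in config.auto_map:
        obj.__class__.register_for_auto_class(auto_class)

model.save_pretrained("falcon-export")      # the custom *.py files are written alongside the weights
tokenizer.save_pretrained("falcon-export")
```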
@@ -487,49 +499,49 @@ def preprocess_data(
         # for input with history, we build multiple input-label pairs just like:
         # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
         model_inputs = {"input_ids": [], "labels": []}
+        max_length = data_args.max_source_length + data_args.max_target_length
+
         for dialog in get_dialog(examples):
             input_ids, labels = [], []
 
             for i in range(len(dialog) // 2):
-                source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=False)
+                source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=True)
                 target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
 
-                if len(source_ids) > data_args.max_source_length - 1: # bos token
-                    source_ids = source_ids[:data_args.max_source_length - 1]
+                if len(source_ids) > data_args.max_source_length:
+                    source_ids = source_ids[:data_args.max_source_length]
                 if len(target_ids) > data_args.max_target_length - 1: # eos token
                     target_ids = target_ids[:data_args.max_target_length - 1]
 
-                input_ids += [tokenizer.bos_token_id] + source_ids + target_ids + [tokenizer.eos_token_id]
-                labels += [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [tokenizer.eos_token_id]
+                if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
+                    break
 
-            if len(input_ids) > data_args.max_source_length + data_args.max_target_length:
-                input_ids = input_ids[:data_args.max_source_length + data_args.max_target_length]
-            if len(labels) > data_args.max_source_length + data_args.max_target_length:
-                labels = labels[:data_args.max_source_length + data_args.max_target_length]
+                input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
+                labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
 
             model_inputs["input_ids"].append(input_ids)
             model_inputs["labels"].append(labels)
+
         return model_inputs
 
     def preprocess_unsupervised_dataset(examples):
         # build inputs with format `<bos> X` and labels with format `<bos> Y`
         model_inputs = {"input_ids": [], "labels": []}
+
         for dialog in get_dialog(examples):
             prompt, answer = "".join(dialog[:-1]), dialog[-1]
 
-            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
-            target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
+            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
+            target_ids = tokenizer.encode(text=answer, add_special_tokens=True)
 
-            if len(source_ids) > data_args.max_source_length - 1: # bos token
-                source_ids = source_ids[:data_args.max_source_length - 1]
-            if len(target_ids) > data_args.max_target_length - 1: # bos token
-                target_ids = target_ids[:data_args.max_target_length - 1]
+            if len(source_ids) > data_args.max_source_length:
+                source_ids = source_ids[:data_args.max_source_length]
+            if len(target_ids) > data_args.max_target_length:
+                target_ids = target_ids[:data_args.max_target_length]
 
-            input_ids = [tokenizer.bos_token_id] + source_ids
-            labels = [tokenizer.bos_token_id] + target_ids
+            model_inputs["input_ids"].append(source_ids)
+            model_inputs["labels"].append(target_ids)
 
-            model_inputs["input_ids"].append(input_ids)
-            model_inputs["labels"].append(labels)
         return model_inputs
 
     def preprocess_pairwise_dataset(examples):
@@ -538,19 +550,19 @@ def preprocess_data(
         for dialog in get_dialog(examples):
             prompt, answer = "".join(dialog[:-1]), dialog[-1]
 
-            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
+            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
             accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
             reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
 
-            if len(source_ids) > data_args.max_source_length - 1: # bos token
-                source_ids = source_ids[:data_args.max_source_length - 1]
+            if len(source_ids) > data_args.max_source_length:
+                source_ids = source_ids[:data_args.max_source_length]
             if len(accept_ids) > data_args.max_target_length - 1: # eos token
                 accept_ids = accept_ids[:data_args.max_target_length - 1]
             if len(reject_ids) > data_args.max_target_length - 1: # eos token
                 reject_ids = reject_ids[:data_args.max_target_length - 1]
 
-            accept_ids = [tokenizer.bos_token_id] + source_ids + accept_ids + [tokenizer.eos_token_id]
-            reject_ids = [tokenizer.bos_token_id] + source_ids + reject_ids + [tokenizer.eos_token_id]
+            accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
+            reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]
 
             model_inputs["accept_ids"].append(accept_ids)
             model_inputs["reject_ids"].append(reject_ids)
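The common thread in the preprocessing changes is that nothing prepends `tokenizer.bos_token_id` by hand anymore; any special tokens a model expects now come from the tokenizer itself via `add_special_tokens=True`. That matters for Falcon, whose tokenizer (as far as I can tell) defines no BOS token at all, so the old code would have pushed `None` into the input ids. Below is a toy walk-through of the new supervised packing, with made-up length limits and a made-up prompt string in place of this repo's `get_dialog` templates:

```python
# Toy walk-through of the new supervised packing; the dialog text and length limits
# are hypothetical, only the packing logic mirrors the patched preprocess function.
from transformers import AutoTokenizer

IGNORE_INDEX = -100
max_source_length, max_target_length = 512, 512
max_length = max_source_length + max_target_length

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b", trust_remote_code=True)
dialog = ["Human: Hi, who are you?\nAssistant: ", "I am a helpful assistant."]

input_ids, labels = [], []
for i in range(len(dialog) // 2):
    # special tokens (if the tokenizer defines any) are added here, not hard-coded
    source_ids = tokenizer.encode(text=dialog[2 * i], add_special_tokens=True)
    target_ids = tokenizer.encode(text=dialog[2 * i + 1], add_special_tokens=False)
    source_ids = source_ids[:max_source_length]
    target_ids = target_ids[:max_target_length - 1]  # reserve room for the EOS token
    if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
        break  # stop adding turns instead of truncating a turn in the middle
    input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
    labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
```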
diff --git a/src/utils/config.py b/src/utils/config.py
index c07066c1..f0e63d8e 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -198,6 +198,7 @@ class FinetuningArguments:
         metadata={"help": "Number of decoder blocks in the model. \
                   LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
                   BLOOM choices: [\"24\", \"30\", \"70\"], \
+                  Falcon choices: [\"32\", \"60\"], \
                   Baichuan choices: [\"32\"]"}
     )
     num_layer_trainable: Optional[int] = field(
@@ -208,7 +209,7 @@ class FinetuningArguments:
         default="mlp",
         metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
                   LLaMA choices: [\"mlp\", \"self_attn\"], \
-                  BLOOM choices: [\"mlp\", \"self_attention\"], \
+                  BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
                   Baichuan choices: [\"mlp\", \"self_attn\"]"}
     )
     lora_rank: Optional[int] = field(
@@ -227,7 +228,7 @@ class FinetuningArguments:
         default="q_proj,v_proj",
         metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
                   LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
-                  BLOOM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
+                  BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
                   Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
     )
 
diff --git a/src/utils/other.py b/src/utils/other.py
index 21a56ea2..ce780ab9 100644
--- a/src/utils/other.py
+++ b/src/utils/other.py
@@ -74,7 +74,7 @@ def prepare_model_for_training(
     finetuning_type: str,
     output_embedding_layer_name: Optional[str] = "lm_head",
     use_gradient_checkpointing: Optional[bool] = True,
-    layer_norm_names: Optional[List[str]] = ["norm", "ln_f"] # for LLaMA and BLOOM setting
+    layer_norm_names: Optional[List[str]] = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
 ) -> PreTrainedModel:
 
     for name, param in model.named_parameters():
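For reference, `layer_norm_names` is matched as a substring against parameter names so that layer-norm weights stay in fp32 during mixed-precision or quantized training; adding `ln_attn` and `ln_mlp` covers Falcon-40B's parallel-attention blocks, while Falcon-7B's `input_layernorm` is already caught by the `norm` substring. The body of `prepare_model_for_training` is not shown in this hunk, so the following is only a sketch of the usual upcasting loop (in the spirit of peft's prepare-for-k-bit-training recipe), not the repo's exact code:

```python
# Sketch of the kind of loop that `layer_norm_names` typically drives; not copied from this repo.
from typing import List, Optional

import torch
from transformers import PreTrainedModel


def cast_layer_norms_to_fp32(
    model: PreTrainedModel,
    layer_norm_names: Optional[List[str]] = ["norm", "ln_f", "ln_attn", "ln_mlp"],
) -> PreTrainedModel:
    for name, param in model.named_parameters():
        # layer-norm weights and biases are 1-D; keeping them in fp32 helps
        # training stability under fp16/bf16 or quantized weights
        if param.ndim == 1 and any(ln_name in name for ln_name in layer_norm_names):
            param.data = param.data.to(torch.float32)
    return model
```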