diff --git a/assets/wechat.jpg b/assets/wechat.jpg
index e054274e..37ca0ab9 100644
Binary files a/assets/wechat.jpg and b/assets/wechat.jpg differ
diff --git a/data/comparison_gpt4_data_en.json.REMOVED.git-id b/data/comparison_gpt4_data_en.json.REMOVED.git-id
index d7c6f987..884ac974 100644
--- a/data/comparison_gpt4_data_en.json.REMOVED.git-id
+++ b/data/comparison_gpt4_data_en.json.REMOVED.git-id
@@ -1 +1 @@
-f437d58b7791609ee91f064551c5c5734a0fd97a
\ No newline at end of file
+f5cb08305ff5dc9c17a09809c54c8c8834aadc70
\ No newline at end of file
diff --git a/data/comparison_gpt4_data_zh.json.REMOVED.git-id b/data/comparison_gpt4_data_zh.json.REMOVED.git-id
index 6007329e..dbc830e7 100644
--- a/data/comparison_gpt4_data_zh.json.REMOVED.git-id
+++ b/data/comparison_gpt4_data_zh.json.REMOVED.git-id
@@ -1 +1 @@
-0e346cf70e633456c7e83f68765361016005447a
\ No newline at end of file
+aee47b7b443496e37808d7f34ef10403ff99bcc3
\ No newline at end of file
diff --git a/data/dataset_info.json b/data/dataset_info.json
index 302d607e..a00d2021 100644
--- a/data/dataset_info.json
+++ b/data/dataset_info.json
@@ -79,11 +79,11 @@
   },
   "comparison_gpt4_en": {
     "file_name": "comparison_gpt4_data_en.json",
-    "file_sha1": "eeb295ce0ab011c37af52596460c8a57d07ad19f"
+    "file_sha1": "96fa18313544e22444fe20eead7754b17da452ae"
   },
   "comparison_gpt4_zh": {
     "file_name": "comparison_gpt4_data_zh.json",
-    "file_sha1": "b99a41c1c864019d9b0c07dbcd5df0560cf33ce0"
+    "file_sha1": "515b18ed497199131ddcc1af950345c11dc5c7fd"
   },
   "hh_rlhf_en": {
     "script_url": "hh_rlhf_en",
@@ -103,14 +103,5 @@
       "response": "",
       "history": ""
     }
-  },
-  "pretrain_data": {
-    "file_name": "pretrain_data",
-    "columns": {
-      "prompt": "content",
-      "query": "",
-      "response": "",
-      "history": ""
-    }
-  }
   }
 }
diff --git a/requirements.txt b/requirements.txt
index 0afa59ca..4d583e8d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,11 +2,11 @@ torch>=1.13.1
 protobuf
 cpm_kernels
 sentencepiece
-transformers>=4.27.4
-datasets>=2.10.0
-accelerate>=0.18.0
+transformers>=4.29.1
+datasets>=2.12.0
+accelerate>=0.19.0
 peft>=0.3.0
-trl>=0.4.1
+trl>=0.4.4
 jieba
 rouge_chinese
 nltk
diff --git a/src/api_demo.py b/src/api_demo.py
index 81763cda..0797da84 100644
--- a/src/api_demo.py
+++ b/src/api_demo.py
@@ -42,7 +42,7 @@ app = FastAPI()
 
 @app.post("/")
 async def create_item(request: Request):
-    global model, tokenizer, prompt_template
+    global model, tokenizer, prompt_template, generating_args
 
     # Parse the request JSON
    json_post_raw = await request.json()
@@ -56,16 +56,9 @@ async def create_item(request: Request):
     input_ids = input_ids.to(model.device)
 
     # Generation arguments
-    gen_kwargs = {
-        "input_ids": input_ids,
-        "do_sample": True,
-        "top_p": 0.7,
-        "temperature": 0.95,
-        "num_beams": 1,
-        "max_new_tokens": 512,
-        "repetition_penalty": 1.0,
-        "logits_processor": get_logits_processor()
-    }
+    gen_kwargs = generating_args.to_dict()
+    gen_kwargs["input_ids"] = input_ids
+    gen_kwargs["logits_processor"] = get_logits_processor()
 
     # Generate response
     with torch.no_grad():
@@ -95,7 +88,7 @@ async def create_item(request: Request):
 
 
 if __name__ == "__main__":
-    model_args, data_args, finetuning_args = prepare_infer_args()
+    model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
     model, tokenizer = load_pretrained(model_args, finetuning_args)
 
     prompt_template = Template(data_args.prompt_template)
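
The two demo entry points now take their decoding defaults from the parsed GeneratingArguments instead of a hard-coded dict. A minimal sketch of that flow, assuming a trimmed stand-in dataclass and an illustrative gpt2 checkpoint (the real class defined later in this patch also renames infer_num_beams to num_beams inside to_dict()):

# Minimal sketch of the new gen_kwargs flow; the dataclass and model name are stand-ins.
import torch
from dataclasses import dataclass, asdict
from transformers import AutoModelForCausalLM, AutoTokenizer

@dataclass
class GeneratingArguments:            # trimmed stand-in for src/utils/config.py
    do_sample: bool = True
    temperature: float = 0.95
    top_p: float = 0.7
    top_k: int = 50
    num_beams: int = 1
    max_new_tokens: int = 512
    repetition_penalty: float = 1.0

    def to_dict(self):
        return asdict(self)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
generating_args = GeneratingArguments(max_new_tokens=64)

input_ids = tokenizer("Hello, world", return_tensors="pt")["input_ids"]
gen_kwargs = generating_args.to_dict()    # decoding defaults come from the dataclass
gen_kwargs["input_ids"] = input_ids       # per-request tensors are layered on top
with torch.no_grad():
    output = model.generate(**gen_kwargs)
print(tokenizer.decode(output[0], skip_special_tokens=True))
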
diff --git a/src/cli_demo.py b/src/cli_demo.py
index 31299c30..9bcfec53 100644
--- a/src/cli_demo.py
+++ b/src/cli_demo.py
@@ -15,7 +15,7 @@ from transformers import TextIteratorStreamer
 
 
 def main():
-    model_args, data_args, finetuning_args = prepare_infer_args()
+    model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
     model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
     model, tokenizer = load_pretrained(model_args, finetuning_args)
@@ -25,17 +25,10 @@ def main():
     def predict_and_print(query, history: list):
         input_ids = tokenizer([prompt_template.get_prompt(query, history)], return_tensors="pt")["input_ids"]
         input_ids = input_ids.to(model.device)
-        gen_kwargs = {
-            "input_ids": input_ids,
-            "do_sample": True,
-            "top_p": 0.7,
-            "temperature": 0.95,
-            "num_beams": 1,
-            "max_new_tokens": 512,
-            "repetition_penalty": 1.0,
-            "logits_processor": get_logits_processor(),
-            "streamer": streamer
-        }
+        gen_kwargs = generating_args.to_dict()
+        gen_kwargs["input_ids"] = input_ids
+        gen_kwargs["logits_processor"] = get_logits_processor()
+        gen_kwargs["streamer"] = streamer
         thread = Thread(target=model.generate, kwargs=gen_kwargs)
         thread.start()
         response = ""
diff --git a/src/train_ppo.py b/src/train_ppo.py
index 9dbe9c0e..e3d2d403 100644
--- a/src/train_ppo.py
+++ b/src/train_ppo.py
@@ -6,18 +6,17 @@ import math
 
 from torch.optim import AdamW
-
 from transformers.optimization import get_scheduler
 from trl import PPOConfig
 
 from utils import (
-    prepare_args,
-    prepare_data,
-    load_pretrained,
-    preprocess_data,
     DynamicDataCollatorWithPadding,
     PPOPeftTrainer,
     LogCallback,
+    load_pretrained,
+    prepare_args,
+    prepare_data,
+    preprocess_data,
     plot_loss
 )
@@ -29,7 +28,7 @@ def main():
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo")
-    data_collator = DynamicDataCollatorWithPadding(tokenizer, model.pretrained_model)
+    data_collator = DynamicDataCollatorWithPadding(tokenizer)
 
     ppo_config = PPOConfig(
         model_name=model_args.model_name_or_path,
diff --git a/src/train_pt.py b/src/train_pt.py
index 7f5a4779..6fedf931 100644
--- a/src/train_pt.py
+++ b/src/train_pt.py
@@ -5,14 +5,15 @@ import math
 
+
 from utils import (
+    DynamicDataCollatorWithPadding,
+    PeftTrainer,
+    LogCallback,
     load_pretrained,
     prepare_args,
     prepare_data,
     preprocess_data,
-    DynamicDataCollatorWithPadding,
-    PeftTrainer,
-    LogCallback,
     plot_loss
 )
@@ -24,7 +25,7 @@ def main():
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="pt")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="pt")
-    data_collator = DynamicDataCollatorWithPadding(tokenizer, model, data_args.ignore_pad_token_for_loss)
+    data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)
 
     # Split the dataset
     if training_args.do_train:
diff --git a/src/train_rm.py b/src/train_rm.py
index 11aec993..117aa13d 100644
--- a/src/train_rm.py
+++ b/src/train_rm.py
@@ -6,13 +6,14 @@
 
 from utils import (
-    prepare_args,
-    prepare_data,
-    load_pretrained,
-    preprocess_data,
     PairwiseDataCollatorWithPadding,
     PairwisePeftTrainer,
     LogCallback,
+    load_pretrained,
+    prepare_args,
+    prepare_data,
+    preprocess_data,
+    compute_accuracy,
     plot_loss
 )
@@ -23,7 +24,7 @@ def main():
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm")
stage="rm") - data_collator = PairwiseDataCollatorWithPadding(tokenizer, model.pretrained_model) + data_collator = PairwiseDataCollatorWithPadding(tokenizer) training_args.remove_unused_columns = False # important for pairwise dataset @@ -45,6 +46,7 @@ def main(): tokenizer=tokenizer, data_collator=data_collator, callbacks=[LogCallback()], + compute_metrics=compute_accuracy, **trainer_kwargs ) diff --git a/src/train_sft.py b/src/train_sft.py index bad98c8a..971fbb19 100644 --- a/src/train_sft.py +++ b/src/train_sft.py @@ -5,14 +5,14 @@ from utils import ( - load_pretrained, - prepare_args, - prepare_data, - preprocess_data, DynamicDataCollatorWithPadding, Seq2SeqPeftTrainer, ComputeMetrics, LogCallback, + load_pretrained, + prepare_args, + prepare_data, + preprocess_data, get_logits_processor, plot_loss ) @@ -25,7 +25,7 @@ def main(): dataset = prepare_data(model_args, data_args) model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft") dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft") - data_collator = DynamicDataCollatorWithPadding(tokenizer, model, data_args.ignore_pad_token_for_loss) + data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss) # Override the decoding parameters of Seq2SeqTrainer training_args.generation_max_length = training_args.generation_max_length if \ diff --git a/src/utils/__init__.py b/src/utils/__init__.py index f2db999d..3c08e3d2 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -11,7 +11,7 @@ from .data_collator import DynamicDataCollatorWithPadding from .peft_trainer import PeftTrainer, LogCallback from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer -from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer +from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer, compute_accuracy from .ppo import PPOPeftTrainer from .template import Template diff --git a/src/utils/common.py b/src/utils/common.py index 71a7ebe7..fb515b9a 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -36,7 +36,8 @@ from trl import AutoModelForCausalLMWithValueHead from .config import ( ModelArguments, DataTrainingArguments, - FinetuningArguments + FinetuningArguments, + GeneratingArguments ) from .template import Template @@ -54,7 +55,8 @@ check_min_version("4.29.1") require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0") require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0") -require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1") +require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4") + logger = get_logger(__name__) @@ -91,12 +93,10 @@ def _init_adapter( if model_args.checkpoint_dir is not None: if finetuning_args.finetuning_type != "lora": - assert is_mergeable and len( - model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints." - load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods + assert is_mergeable and len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints." + load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods else: - assert is_mergeable or len( - model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint." 
+            assert is_mergeable or len(model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."
 
     if finetuning_args.finetuning_type == "lora":
         logger.info("Fine-tuning method: LoRA")
@@ -106,8 +106,7 @@ def _init_adapter(
             assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
                 "The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
 
-            if (is_trainable and model_args.resume_lora_training) or (
-                    not is_mergeable): # continually train on the lora weights
+            if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
                 checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
             else:
                 checkpoints_to_merge = model_args.checkpoint_dir
@@ -119,10 +118,10 @@ def _init_adapter(
             if len(checkpoints_to_merge) > 0:
                 logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
 
-            if lastest_checkpoint is not None: # resume lora training or quantized inference
+            if lastest_checkpoint is not None: # resume lora training or quantized inference
                 model = PeftModel.from_pretrained(model, lastest_checkpoint, is_trainable=is_trainable)
 
-        if is_trainable and lastest_checkpoint is None: # create new lora weights while training
+        if is_trainable and lastest_checkpoint is None: # create new lora weights while training
             lora_config = LoraConfig(
                 task_type=TaskType.CAUSAL_LM,
                 inference_mode=False,
@@ -170,7 +169,7 @@ def load_pretrained(
         padding_side="left",
         **config_kwargs
     )
-    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
+    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
 
     config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
     is_mergeable = True
@@ -186,11 +185,9 @@
         )
     elif model_args.quantization_bit == 4:
         require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
-        require_version("transformers>=4.30.0.dev0",
-                        "To fix: pip install git+https://github.com/huggingface/transformers.git")
+        require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
+        require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
         require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
-        require_version("accelerate>=0.20.0.dev0",
-                        "To fix: pip install git+https://github.com/huggingface/accelerate.git")
         config_kwargs["load_in_4bit"] = True
         config_kwargs["quantization_config"] = BitsAndBytesConfig(
             load_in_4bit=True,
@@ -201,10 +198,10 @@
         else:
             raise NotImplementedError
         is_mergeable = False
-        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK") or 0)}
+        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
-    if not is_trainable:
+    if not is_trainable: # `device_map=auto` should be used for inference only
         config_kwargs["device_map"] = "auto"
 
     # Load and prepare pretrained models (without valuehead).
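
For reference, the 4-bit branch above now pins released package versions and routes a BitsAndBytesConfig through config_kwargs. A rough standalone sketch of an equivalent load is below, assuming a CUDA machine with bitsandbytes installed; the checkpoint name and compute dtype are placeholders, not values taken from this patch.

# Rough sketch of the 4-bit quantized load configured above (names are illustrative).
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,   # corresponds to ModelArguments.compute_dtype
    bnb_4bit_use_double_quant=True,         # corresponds to double_quantization
    bnb_4bit_quant_type="nf4"               # corresponds to quantization_type
)

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",                  # placeholder checkpoint
    quantization_config=quantization_config,
    device_map={"": 0}                      # single device; training maps by LOCAL_RANK instead
)
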
@@ -218,24 +215,26 @@ load_pretrained(
     model = prepare_model_for_training(model) if is_trainable else model
     model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
 
-    if stage == "rm" or stage == "ppo": # add value head
+    if stage == "rm" or stage == "ppo": # add value head
         model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
 
-    if stage == "ppo": # load reward model
+    if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
+        load_valuehead_params(model, model_args.checkpoint_dir[0])
+        model.v_head.load_state_dict({
+            "summary.weight": getattr(model, "reward_head_weight"),
+            "summary.bias": getattr(model, "reward_head_bias")
+        })
+
+    if stage == "ppo": # load reward model
         assert is_trainable, "PPO stage cannot be performed at evaluation."
         assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
         logger.info("Load reward model from {}".format(model_args.reward_model))
         model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
         load_valuehead_params(model, model_args.reward_model)
 
-    # Set the parameter _is_int8_training_enabled for the AutoModelForCausalLMWithValueHead model
-    # To meet the compliance requirements of the transformers library
-    if model_args.quantization_bit is not None:
-        model._is_int8_training_enabled = True
-
     if not is_trainable:
-        model.requires_grad_(False) # fix all model params
-        model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
+        model.requires_grad_(False) # fix all model params
+        model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
 
     print_trainable_params(model)
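
The new stage == "rm" branch restores a previously trained value head so the reward model can be scored at evaluation time. A rough sketch of the idea follows; the value_head.bin file name and the summary.* key names are assumptions about the checkpoint layout rather than something this patch guarantees.

# Illustrative only: restore a saved value head onto a model that has a v_head module.
import os
import torch

def load_valuehead_demo(model, checkpoint_dir: str) -> None:
    # Assumed layout: <checkpoint_dir>/value_head.bin holding "summary.weight"/"summary.bias".
    state_dict = torch.load(os.path.join(checkpoint_dir, "value_head.bin"), map_location="cpu")
    model.v_head.load_state_dict({
        "summary.weight": state_dict["summary.weight"],
        "summary.bias": state_dict["summary.bias"]
    })
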
@@ -245,11 +244,11 @@
 def prepare_args(
     stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
+
     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))
 
-    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
-        model_args, data_args, training_args, finetuning_args = parser.parse_json_file(
-            json_file=os.path.abspath(sys.argv[1]))
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
+        model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
     else:
         model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()
@@ -290,7 +289,7 @@
         logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
         training_args.ddp_find_unused_parameters = False
 
-    training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
+    training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
 
     if model_args.quantization_bit is not None:
         if training_args.fp16:
@@ -313,13 +312,14 @@
     return model_args, data_args, training_args, finetuning_args
 
 
-def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments]:
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments))
+def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments]:
 
-    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
-        model_args, data_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments))
+
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
+        model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
     else:
-        model_args, data_args, finetuning_args = parser.parse_args_into_dataclasses()
+        model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
 
     if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
         raise ValueError("Quantization is only compatible with the LoRA method.")
@@ -327,13 +327,14 @@ def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, Finetun
     if data_args.prompt_template == "alpaca":
         logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")
 
-    return model_args, data_args, finetuning_args
+    return model_args, data_args, finetuning_args, generating_args
 
 
 def prepare_data(
     model_args: ModelArguments,
     data_args: DataTrainingArguments
 ) -> Dataset:
+
     def checksum(file_path, hash):
         with open(file_path, "rb") as datafile:
             binary_data = datafile.read()
@@ -342,7 +343,7 @@
             logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
 
     max_samples = data_args.max_samples
-    all_datasets: List[Dataset] = [] # support multiple datasets
+    all_datasets: List[Dataset] = [] # support multiple datasets
 
     for dataset_attr in data_args.dataset_list:
@@ -358,10 +359,12 @@
         elif dataset_attr.load_from == "file":
             data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name)
             extension = dataset_attr.file_name.split(".")[-1]
+
             if dataset_attr.file_sha1 is not None:
                 checksum(data_file, dataset_attr.file_sha1)
             else:
                 logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
+
             raw_datasets = load_dataset(
                 extension if extension in ["csv", "json"] else "text",
                 data_files=data_file,
@@ -383,11 +386,11 @@
             ("query_column", "query"),
             ("response_column", "response"),
             ("history_column", "history")
-        ]: # every dataset will have 4 columns same as each other
+        ]: # every dataset will have 4 columns same as each other
             if getattr(dataset_attr, column_name) != target_name:
                 if getattr(dataset_attr, column_name):
                     dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
-                else: # None or empty string
+                else: # None or empty string
                     dataset = dataset.add_column(target_name, dummy_data)
 
         all_datasets.append(dataset)
@@ -406,6 +409,7 @@
     training_args: Seq2SeqTrainingArguments,
     stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> Dataset:
+
     column_names = list(dataset.column_names)
     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
     prompt_template = Template(data_args.prompt_template)
@@ -442,9 +446,9 @@
             source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
             target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
 
-            if len(source_ids) > data_args.max_source_length - 1: # bos token
+            if len(source_ids) > data_args.max_source_length - 1: # bos token
                 source_ids = source_ids[:data_args.max_source_length - 1]
-            if len(target_ids) > data_args.max_target_length - 1: # eos token
+            if len(target_ids) > data_args.max_target_length - 1: # eos token
                 target_ids = target_ids[:data_args.max_target_length - 1]
 
             input_ids = source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
@@ -461,9 +465,9 @@
             source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
             target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
 
-            if len(source_ids) > data_args.max_source_length - 1: # bos token
+            if len(source_ids) > data_args.max_source_length - 1: # bos token
                 source_ids = source_ids[:data_args.max_source_length - 1]
-            if len(target_ids) > data_args.max_target_length - 1: # bos token
+            if len(target_ids) > data_args.max_target_length - 1: # bos token
                 target_ids = target_ids[:data_args.max_target_length - 1]
 
             input_ids = source_ids + [tokenizer.bos_token_id]
@@ -481,11 +485,11 @@
             accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
             reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
 
-            if len(source_ids) > data_args.max_source_length - 1: # bos token
+            if len(source_ids) > data_args.max_source_length - 1: # bos token
                 source_ids = source_ids[:data_args.max_source_length - 1]
-            if len(accept_ids) > data_args.max_target_length - 1: # eos token
+            if len(accept_ids) > data_args.max_target_length - 1: # eos token
                 accept_ids = accept_ids[:data_args.max_target_length - 1]
-            if len(reject_ids) > data_args.max_target_length - 1: # eos token
+            if len(reject_ids) > data_args.max_target_length - 1: # eos token
                reject_ids = reject_ids[:data_args.max_target_length - 1]
 
             accept_ids = source_ids + [tokenizer.bos_token_id] + accept_ids + [tokenizer.eos_token_id]
diff --git a/src/utils/config.py b/src/utils/config.py
index 7a75aa4f..c0f89217 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -1,12 +1,13 @@
 import os
 import json
 import torch
-from typing import List, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional
 from dataclasses import asdict, dataclass, field
 
 
 @dataclass
 class DatasetAttr:
+
     load_from: str
     dataset_name: Optional[str] = None
     file_name: Optional[str] = None
@@ -55,11 +56,11 @@ class ModelArguments:
     )
     quantization_type: Optional[Literal["fp4", "nf4"]] = field(
         default="nf4",
-        metadata={"help": "Quantization data type to use."}
+        metadata={"help": "Quantization data type to use in int4 training."}
    )
     double_quantization: Optional[bool] = field(
         default=True,
-        metadata={"help": "Compress the quantization statistics through double quantization."}
+        metadata={"help": "Whether to use double quantization in int4 training or not."}
     )
     compute_dtype: Optional[torch.dtype] = field(
         default=None,
@@ -67,8 +68,7 @@
     )
     checkpoint_dir: Optional[str] = field(
         default=None,
-        metadata={
-            "help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
+        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
     )
     reward_model: Optional[str] = field(
         default=None,
@@ -76,8 +76,7 @@
     )
     resume_lora_training: Optional[bool] = field(
         default=True,
-        metadata={
-            "help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
+        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
     )
     plot_loss: Optional[bool] = field(
         default=False,
@@ -85,7 +84,7 @@
     )
 
     def __post_init__(self):
-        if self.checkpoint_dir is not None: # support merging multiple lora weights
+        if self.checkpoint_dir is not None: # support merging multiple lora weights
             self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
@@ -147,7 +146,7 @@ class DataTrainingArguments:
         metadata={"help": "Which template to use for constructing prompts in training and inference."}
     )
 
-    def __post_init__(self): # support mixing multiple datasets
+    def __post_init__(self): # support mixing multiple datasets
         dataset_names = [ds.strip() for ds in self.dataset.split(",")]
         with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
             dataset_info = json.load(f)
@@ -156,42 +155,25 @@
         for name in dataset_names:
             if name not in dataset_info:
                 raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
-            dataset_attrs = []
-            dataset_attr = None
+
             if "hf_hub_url" in dataset_info[name]:
                 dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
             elif "script_url" in dataset_info[name]:
                 dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
-            elif os.path.isfile(os.path.join(self.dataset_dir, dataset_info[name]["file_name"])):
+            else:
                 dataset_attr = DatasetAttr(
                     "file",
                     file_name=dataset_info[name]["file_name"],
                     file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None
                 )
-            else:
-                # Support Directory
-                for file_name in os.listdir(os.path.join(self.dataset_dir, dataset_info[name]["file_name"])):
dataset_info[name]["file_name"])): - path = os.path.join(dataset_info[name]["file_name"], file_name) - dataset_attrs.append(DatasetAttr( - "file", - file_name=path, - file_sha1=dataset_info[name]["file_sha1"] if "file_sha1" in dataset_info[name] else None - )) - if dataset_attr is not None: - if "columns" in dataset_info[name]: - dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None) - dataset_attr.query_column = dataset_info[name]["columns"].get("query", None) - dataset_attr.response_column = dataset_info[name]["columns"].get("response", None) - dataset_attr.history_column = dataset_info[name]["columns"].get("history", None) - self.dataset_list.append(dataset_attr) - else: - for i, dataset_attr in enumerate(dataset_attrs): - if "columns" in dataset_info[name]: - dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None) - dataset_attr.query_column = dataset_info[name]["columns"].get("query", None) - dataset_attr.response_column = dataset_info[name]["columns"].get("response", None) - dataset_attr.history_column = dataset_info[name]["columns"].get("history", None) - self.dataset_list.append(dataset_attr) + + if "columns" in dataset_info[name]: + dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None) + dataset_attr.query_column = dataset_info[name]["columns"].get("query", None) + dataset_attr.response_column = dataset_info[name]["columns"].get("response", None) + dataset_attr.history_column = dataset_info[name]["columns"].get("history", None) + + self.dataset_list.append(dataset_attr) @dataclass @@ -228,22 +210,20 @@ class FinetuningArguments: lora_target: Optional[str] = field( default="q_proj,v_proj", metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \ - LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"down_proj\"], \ - BLOOM choices: [\"query_key_value\", \"dense\", \"dense_\"]"} + LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"gate_proj\", \"down_proj\"], \ + BLOOM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"]"} ) def __post_init__(self): - if isinstance(self.lora_target, str): - self.lora_target = [target.strip() for target in - self.lora_target.split(",")] # support custom target modules of LoRA + if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA + self.lora_target = [target.strip() for target in self.lora_target.split(",")] - if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 + if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 trainable_layer_ids = [27 - k for k in range(self.num_layer_trainable)] - else: # fine-tuning the first n layers if num_layer_trainable < 0 + else: # fine-tuning the first n layers if num_layer_trainable < 0 trainable_layer_ids = [k for k in range(-self.num_layer_trainable)] - self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in - trainable_layer_ids] + self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids] assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method." 
@@ -259,3 +239,44 @@
         with open(json_path, "r", encoding="utf-8") as f:
             text = f.read()
         return cls(**json.loads(text))
+
+
+@dataclass
+class GeneratingArguments:
+    """
+    Arguments pertaining to specify the decoding parameters.
+    """
+    do_sample: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
+    )
+    temperature: Optional[float] = field(
+        default=0.95,
+        metadata={"help": "The value used to modulate the next token probabilities."}
+    )
+    top_p: Optional[float] = field(
+        default=0.7,
+        metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
+    )
+    top_k: Optional[int] = field(
+        default=50,
+        metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
+    )
+    infer_num_beams: Optional[int] = field(
+        default=1,
+        metadata={"help": "Number of beams for beam search. 1 means no beam search."}
+    )
+    max_new_tokens: Optional[int] = field(
+        default=512,
+        metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
+    )
+    repetition_penalty: Optional[float] = field(
+        default=1.0,
+        metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
+    )
+
+    def to_dict(self) -> Dict[str, Any]:
+        data_dict = asdict(self)
+        num_beams = data_dict.pop("infer_num_beams")
+        data_dict["num_beams"] = num_beams
+        return data_dict
diff --git a/src/utils/data_collator.py b/src/utils/data_collator.py
index 27d1b7ac..dbfc34f0 100644
--- a/src/utils/data_collator.py
+++ b/src/utils/data_collator.py
@@ -3,7 +3,6 @@ import torch
 from typing import Dict, Optional, Sequence, Union
 
 from transformers import DataCollatorWithPadding, BatchEncoding
-from transformers.modeling_utils import PreTrainedModel
 from transformers.tokenization_utils import PreTrainedTokenizer
 
 from .other import IGNORE_INDEX
@@ -16,11 +15,9 @@ class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
     def __init__(
         self,
         tokenizer: PreTrainedTokenizer,
-        model: PreTrainedModel,
         ignore_pad_token_for_loss: Optional[bool] = False
     ):
         super().__init__(tokenizer, padding=True)
-        self.model = model
         self.label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id
 
     def get_attention_masks(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
diff --git a/src/utils/pairwise.py b/src/utils/pairwise.py
index 317599be..57729793 100644
--- a/src/utils/pairwise.py
+++ b/src/utils/pairwise.py
@@ -1,5 +1,6 @@
 import torch
-from typing import Dict, Sequence, Union
+import numpy as np
+from typing import Dict, Sequence, Tuple, Union
 
 from .data_collator import DynamicDataCollatorWithPadding
@@ -10,6 +11,12 @@ from .other import get_logger
 
 logger = get_logger(__name__)
 
+def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
+    preds, _ = eval_preds
+    preds = np.array(preds)
+    return {"accuracy": (preds[:, 0] > preds[:, 1]).sum() / len(preds)}
+
+
 class PairwiseDataCollatorWithPadding(DynamicDataCollatorWithPadding):
     r"""
     Data collator for pairwise data.
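
compute_accuracy assumes the trainer hands back the stacked (chosen, rejected) rewards produced by the updated compute_loss further down. A toy check with fabricated scores:

# Toy check of the pairwise-accuracy metric defined above.
import numpy as np

def compute_accuracy(eval_preds):
    preds, _ = eval_preds
    preds = np.array(preds)
    return {"accuracy": (preds[:, 0] > preds[:, 1]).sum() / len(preds)}

# Column 0 holds the reward for the chosen response, column 1 for the rejected one.
fake_preds = np.array([[1.2, 0.3],    # chosen scored higher -> correct
                       [0.1, 0.9],    # rejected scored higher -> incorrect
                       [2.0, 1.5]])   # correct
print(compute_accuracy((fake_preds, None)))   # {'accuracy': 0.666...}
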
@@ -47,5 +54,4 @@ class PairwisePeftTrainer(PeftTrainer):
         _, _, values = model(**inputs)
         r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
         loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
-        outputs = {"r_accept": r_accept, "r_reject": r_reject}
-        return (loss, outputs) if return_outputs else loss
+        return (loss, torch.stack((r_accept, r_reject), dim=-1)) if return_outputs else loss
diff --git a/src/utils/template.py b/src/utils/template.py
index 5a2f49bd..64134e0c 100644
--- a/src/utils/template.py
+++ b/src/utils/template.py
@@ -14,27 +14,25 @@ class Template:
         return getattr(self, "_format_{}".format(self.name))(query, history, prefix)
 
     def _format_vanilla(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
-        prompt = prefix
-        if history:
-            for old_query, response in history:
-                prompt += old_query + "\n" + response + "\n"
-        prompt += query
-        return prompt
+        r"""
+        Use for language model inference without histories.
+        """
+        return query
 
     def _format_alpaca(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
         r"""
         Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
+                  https://github.com/ymcui/Chinese-LLaMA-Alpaca
         """
         if prefix:
             prompt = prefix
         else:
             prompt = "Below is an instruction that describes a task. "
             prompt += "Write a response that appropriately completes the request.\n\n"
-            prompt += "Instruction:\n"
         if history:
             for old_query, response in history:
-                prompt += "Human:\n{}\n\nAssistant:\n{}\n\n".format(old_query, response)
-        prompt += "Human:\n{}\n\nAssistant:".format(query)
+                prompt += "### Instruction:\n{}\n\n### Response:\n{}\n\n".format(old_query, response)
+        prompt += "### Instruction:\n{}\n\n### Response:\n".format(query)
         return prompt
 
     def _format_vicuna(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
diff --git a/src/web_demo.py b/src/web_demo.py
index 8a8934fb..35ffcf56 100644
--- a/src/web_demo.py
+++ b/src/web_demo.py
@@ -21,7 +21,7 @@ from transformers.utils.versions import require_version
 
 require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")
 
-model_args, data_args, finetuning_args = prepare_infer_args()
+model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
 model, tokenizer = load_pretrained(model_args, finetuning_args)
 
 prompt_template = Template(data_args.prompt_template)
@@ -87,9 +87,9 @@ def predict(query, chatbot, max_length, top_p, temperature, history):
         "do_sample": True,
         "top_p": top_p,
         "temperature": temperature,
-        "num_beams": 1,
+        "num_beams": generating_args.infer_num_beams,
         "max_length": max_length,
-        "repetition_penalty": 1.0,
+        "repetition_penalty": generating_args.repetition_penalty,
         "logits_processor": get_logits_processor(),
         "streamer": streamer
     }
@@ -133,8 +133,8 @@ with gr.Blocks() as demo:
         with gr.Column(scale=1):
             emptyBtn = gr.Button("Clear History")
             max_length = gr.Slider(0, 2048, value=1024, step=1.0, label="Maximum length", interactive=True)
-            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
-            temperature = gr.Slider(0, 1.5, value=0.95, step=0.01, label="Temperature", interactive=True)
+            top_p = gr.Slider(0, 1, value=generating_args.top_p, step=0.01, label="Top P", interactive=True)
+            temperature = gr.Slider(0, 1.5, value=generating_args.temperature, step=0.01, label="Temperature", interactive=True)
 
     history = gr.State([])
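
For a quick sanity check of the reworked Alpaca-style template, the sketch below mirrors _format_alpaca with a made-up query and history; it is a simplified rewrite for illustration, not the Template class itself.

# Simplified rewrite of the new Alpaca-style prompt construction (inputs are made up).
def format_alpaca(query, history=None, prefix=""):
    prompt = prefix or ("Below is an instruction that describes a task. "
                        "Write a response that appropriately completes the request.\n\n")
    if history:
        for old_query, response in history:
            prompt += "### Instruction:\n{}\n\n### Response:\n{}\n\n".format(old_query, response)
    prompt += "### Instruction:\n{}\n\n### Response:\n".format(query)
    return prompt

print(format_alpaca("Name three primary colors.", history=[("Say hi.", "Hi!")]))
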