From 20a9565e363c0830f92b6af546a0f512a9fd9be7 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Fri, 3 Jan 2025 10:50:32 +0000
Subject: [PATCH] update scripts

Former-commit-id: dd44c65d7f60cb6f5d0e0d8ee5f4e7643defb89b
---
 scripts/llama_pro.py             | 18 ++++++++++--------
 scripts/stat_utils/cal_lr.py     |  5 +++--
 scripts/stat_utils/cal_ppl.py    |  5 +++--
 scripts/stat_utils/length_cdf.py |  3 ++-
 scripts/vllm_infer.py            |  1 +
 5 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/scripts/llama_pro.py b/scripts/llama_pro.py
index b086583d..447890f4 100644
--- a/scripts/llama_pro.py
+++ b/scripts/llama_pro.py
@@ -24,7 +24,7 @@ import fire
 import torch
 from safetensors.torch import save_file
 from tqdm import tqdm
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
 from transformers.modeling_utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
@@ -35,7 +35,7 @@ from transformers.modeling_utils import (
 
 
 if TYPE_CHECKING:
-    from transformers import PretrainedConfig, PreTrainedModel
+    from transformers import PretrainedConfig
 
 
 def change_name(name: str, old_index: int, new_index: int) -> str:
@@ -61,17 +61,18 @@ def block_expansion(
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
     tokenizer.save_pretrained(output_dir)
 
-    config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)  # load the original one
+    config = AutoConfig.from_pretrained(model_name_or_path)  # load the original one
     if save_safetensors:
         setattr(config, "tie_word_embeddings", False)  # safetensors does not allow shared weights
 
-    model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained(
+    model = AutoModelForCausalLM.from_pretrained(
         model_name_or_path,
         config=config,
         torch_dtype="auto",
         trust_remote_code=True,
         low_cpu_mem_usage=True,
     )
+    assert isinstance(model, PreTrainedModel)  # type hint
     state_dict = model.state_dict()
 
     if num_layers % num_expand != 0:
@@ -85,7 +86,7 @@ def block_expansion(
             if f".{i:d}." in key:
                 output_state_dict[change_name(key, i, layer_cnt)] = value
 
-        print(f"Add layer {layer_cnt} copied from layer {i}")
+        print(f"Add layer {layer_cnt} copied from layer {i}.")
         layer_cnt += 1
         if (i + 1) % split == 0:
             for key, value in state_dict.items():
@@ -95,7 +96,7 @@ def block_expansion(
                 else:
                     output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)
 
-            print(f"Add layer {layer_cnt} expanded from layer {i}")
+            print(f"Add layer {layer_cnt} expanded from layer {i}.")
             layer_cnt += 1
 
     for key, value in state_dict.items():
@@ -112,12 +113,13 @@ def block_expansion(
             torch.save(shard, os.path.join(output_dir, shard_file))
 
     if index is None:
-        print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
+        print(f"Model weights saved in {os.path.join(output_dir, weights_name)}.")
     else:
         index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
         with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
             json.dump(index, f, indent=2, sort_keys=True)
-        print(f"Model weights saved in {output_dir}")
+
+        print(f"Model weights saved in {output_dir}.")
 
     print("- Fine-tune this model with:")
     print(f"model_name_or_path: {output_dir}")
diff --git a/scripts/stat_utils/cal_lr.py b/scripts/stat_utils/cal_lr.py
index a76d5827..21206a28 100644
--- a/scripts/stat_utils/cal_lr.py
+++ b/scripts/stat_utils/cal_lr.py
@@ -41,7 +41,7 @@ def calculate_lr(
     dataset: str = "alpaca_en_demo",
     dataset_dir: str = "data",
     template: str = "default",
-    cutoff_len: int = 1024,  # i.e. maximum input length during training
+    cutoff_len: int = 2048,  # i.e. maximum input length during training
     is_mistral_or_gemma: bool = False,  # mistral and gemma models opt for a smaller learning rate,
     packing: bool = False,
 ):
@@ -59,6 +59,7 @@ def calculate_lr(
             template=template,
             cutoff_len=cutoff_len,
             packing=packing,
+            preprocessing_num_workers=16,
             output_dir="dummy_dir",
             overwrite_cache=True,
             do_train=True,
@@ -79,7 +80,7 @@ def calculate_lr(
 
     dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
     valid_tokens, total_tokens = 0, 0
-    for batch in tqdm(dataloader):
+    for batch in tqdm(dataloader, desc="Collecting valid tokens"):
         valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
         total_tokens += torch.numel(batch["labels"])
 
diff --git a/scripts/stat_utils/cal_ppl.py b/scripts/stat_utils/cal_ppl.py
index 03b25d9b..32d50e64 100644
--- a/scripts/stat_utils/cal_ppl.py
+++ b/scripts/stat_utils/cal_ppl.py
@@ -63,7 +63,7 @@ def calculate_ppl(
     dataset: str = "alpaca_en_demo",
     dataset_dir: str = "data",
     template: str = "default",
-    cutoff_len: int = 1024,
+    cutoff_len: int = 2048,
     max_samples: Optional[int] = None,
     train_on_prompt: bool = False,
 ):
@@ -82,6 +82,7 @@ def calculate_ppl(
             cutoff_len=cutoff_len,
             max_samples=max_samples,
             train_on_prompt=train_on_prompt,
+            preprocessing_num_workers=16,
             output_dir="dummy_dir",
             overwrite_cache=True,
             do_train=True,
@@ -111,7 +112,7 @@ def calculate_ppl(
     perplexities = []
     batch: Dict[str, "torch.Tensor"]
     with torch.no_grad():
-        for batch in tqdm(dataloader):
+        for batch in tqdm(dataloader, desc="Computing perplexities"):
             batch = batch.to(model.device)
             outputs = model(**batch)
             shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
diff --git a/scripts/stat_utils/length_cdf.py b/scripts/stat_utils/length_cdf.py
index 4b2b5349..5cf25347 100644
--- a/scripts/stat_utils/length_cdf.py
+++ b/scripts/stat_utils/length_cdf.py
@@ -42,6 +42,7 @@ def length_cdf(
             dataset_dir=dataset_dir,
             template=template,
             cutoff_len=1_000_000,
+            preprocessing_num_workers=16,
             output_dir="dummy_dir",
             overwrite_cache=True,
             do_train=True,
@@ -52,7 +53,7 @@ def length_cdf(
     trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"]
     total_num = len(trainset)
     length_dict = defaultdict(int)
-    for sample in tqdm(trainset["input_ids"]):
+    for sample in tqdm(trainset["input_ids"], desc="Collecting lengths"):
         length_dict[len(sample) // interval * interval] += 1
 
     length_tuples = list(length_dict.items())
diff --git a/scripts/vllm_infer.py b/scripts/vllm_infer.py
index 063f457c..c9a8cfb6 100644
--- a/scripts/vllm_infer.py
+++ b/scripts/vllm_infer.py
@@ -64,6 +64,7 @@ def vllm_infer(
             template=template,
             cutoff_len=cutoff_len,
             max_samples=max_samples,
+            preprocessing_num_workers=16,
             vllm_config=vllm_config,
             temperature=temperature,
             top_p=top_p,
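
Notes on the pattern used above: the `preprocessing_num_workers=16` argument that recurs
across the four scripts runs dataset tokenization in 16 worker processes instead of one.
In `llama_pro.py`, the string annotations (`config: "PretrainedConfig"`,
`model: "PreTrainedModel"`) give way to a runtime `assert isinstance(model, PreTrainedModel)`:
`AutoModelForCausalLM.from_pretrained` is loosely typed, so the assert both fails fast on an
unexpected return value and narrows the type for static checkers. A minimal sketch of that
idiom, assuming any causal LM checkpoint (the "gpt2" id below is a placeholder, not part of
this patch):

    from transformers import AutoModelForCausalLM, PreTrainedModel

    # from_pretrained() is annotated loosely, so type checkers cannot see the
    # concrete class of the result; the assert validates the loaded object at
    # runtime and narrows its type for static analysis in one line.
    model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder checkpoint id
    assert isinstance(model, PreTrainedModel)  # type hint + sanity check
    print(f"{model.config.model_type}: {model.num_parameters():,} parameters")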