Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)
Commit 8516054e4d: update scripts
Parent: d1a8cd67d2
Former-commit-id: 05aa52adde8905ca892f1ed5847d6f90b1992848
Changes to the block expansion script (the block_expansion function; this appears to be scripts/llama_pro.py upstream). PreTrainedModel moves from a type-checking-only import to a runtime import so it can be used in an isinstance assertion:

```diff
@@ -24,7 +24,7 @@ import fire
 import torch
 from safetensors.torch import save_file
 from tqdm import tqdm
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
 from transformers.modeling_utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
```
```diff
@@ -35,7 +35,7 @@ from transformers.modeling_utils import (
 
 
 if TYPE_CHECKING:
-    from transformers import PretrainedConfig, PreTrainedModel
+    from transformers import PretrainedConfig
 
 
 def change_name(name: str, old_index: int, new_index: int) -> str:
```
```diff
@@ -61,17 +61,18 @@ def block_expansion(
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
     tokenizer.save_pretrained(output_dir)
 
-    config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)  # load the original one
+    config = AutoConfig.from_pretrained(model_name_or_path)  # load the original one
     if save_safetensors:
         setattr(config, "tie_word_embeddings", False)  # safetensors does not allow shared weights
 
-    model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained(
+    model = AutoModelForCausalLM.from_pretrained(
         model_name_or_path,
         config=config,
         torch_dtype="auto",
         trust_remote_code=True,
         low_cpu_mem_usage=True,
     )
+    assert isinstance(model, PreTrainedModel)  # type hint
     state_dict = model.state_dict()
 
     if num_layers % num_expand != 0:
```
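The string annotations on config and model are dropped in favor of a runtime isinstance assertion, which narrows the type for static checkers and fails fast if trust_remote_code=True hands back an unexpected class. A minimal sketch of the pattern, with a placeholder checkpoint name:

```python
from transformers import AutoModelForCausalLM, PreTrainedModel

# from_pretrained is typed loosely (it can return custom remote-code classes),
# so an isinstance assert both validates the object at runtime and narrows
# the static type, replacing the old `model: "PreTrainedModel" = ...` hint.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder model; any causal LM works
    torch_dtype="auto",
    low_cpu_mem_usage=True,
)
assert isinstance(model, PreTrainedModel)  # type hint
print(type(model).__name__, model.config.num_hidden_layers)
```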
The next two hunks only punctuate the progress messages in the copy/expand loop:

```diff
@@ -85,7 +86,7 @@ def block_expansion(
             if f".{i:d}." in key:
                 output_state_dict[change_name(key, i, layer_cnt)] = value
 
-        print(f"Add layer {layer_cnt} copied from layer {i}")
+        print(f"Add layer {layer_cnt} copied from layer {i}.")
         layer_cnt += 1
         if (i + 1) % split == 0:
             for key, value in state_dict.items():
```
```diff
@@ -95,7 +96,7 @@ def block_expansion(
                     else:
                         output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)
 
-            print(f"Add layer {layer_cnt} expanded from layer {i}")
+            print(f"Add layer {layer_cnt} expanded from layer {i}.")
             layer_cnt += 1
 
     for key, value in state_dict.items():
```
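For context, this loop implements LLaMA Pro-style block expansion: every original layer is copied into the output state dict under a new index, and after each group of split = num_layers // num_expand layers an extra copy of the last layer is appended (in LLaMA Pro the new block's output projections are zero-initialized so it starts as a near-identity map). A standalone toy sketch of the index bookkeeping; change_name here is a plausible reimplementation, not the upstream code:

```python
# Toy sketch of the interleaved copy/expand index mapping used by
# block_expansion. The real script walks a model state_dict; here we
# only trace which source layer each output layer comes from.
def change_name(name: str, old_index: int, new_index: int) -> str:
    # plausible reimplementation: rewrite ".{old}." to ".{new}." in a key
    # like "model.layers.3.self_attn.q_proj.weight"
    return name.replace(f".{old_index:d}.", f".{new_index:d}.", 1)

num_layers, num_expand = 8, 2      # expand an 8-layer model with 2 new blocks
split = num_layers // num_expand   # one new block after every 4 layers

layer_cnt = 0
for i in range(num_layers):
    print(f"Add layer {layer_cnt} copied from layer {i}.")
    layer_cnt += 1
    if (i + 1) % split == 0:
        # the expanded block duplicates layer i's weights; upstream zeroes
        # selected projections so training starts from an identity-like block
        print(f"Add layer {layer_cnt} expanded from layer {i}.")
        layer_cnt += 1

print(change_name("model.layers.3.self_attn.q_proj.weight", 3, 4))
```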
```diff
@@ -112,12 +113,13 @@ def block_expansion(
             torch.save(shard, os.path.join(output_dir, shard_file))
 
     if index is None:
-        print(f"Model weights saved in {os.path.join(output_dir, weights_name)}")
+        print(f"Model weights saved in {os.path.join(output_dir, weights_name)}.")
     else:
         index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
         with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
             json.dump(index, f, indent=2, sort_keys=True)
-        print(f"Model weights saved in {output_dir}")
+
+        print(f"Model weights saved in {output_dir}.")
 
     print("- Fine-tune this model with:")
     print(f"model_name_or_path: {output_dir}")
```
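The index written here is the standard transformers sharded-checkpoint index: a JSON file (SAFE_WEIGHTS_INDEX_NAME or WEIGHTS_INDEX_NAME) mapping every tensor name to the shard file that stores it, so from_pretrained can locate weights lazily. A minimal sketch with hypothetical tensors, using the non-safetensors file names:

```python
import json
import os

import torch

output_dir = "expanded_model"  # hypothetical output directory
os.makedirs(output_dir, exist_ok=True)

# Two hypothetical shards; the real script splits the expanded state dict
# with transformers' sharding helper before saving.
shards = {
    "pytorch_model-00001-of-00002.bin": {"model.layers.0.mlp.down_proj.weight": torch.zeros(4, 4)},
    "pytorch_model-00002-of-00002.bin": {"model.layers.1.mlp.down_proj.weight": torch.zeros(4, 4)},
}
for shard_file, shard in shards.items():
    torch.save(shard, os.path.join(output_dir, shard_file))

# WEIGHTS_INDEX_NAME resolves to "pytorch_model.bin.index.json"; the
# weight_map tells from_pretrained which shard holds each tensor.
index = {
    "metadata": {"total_size": sum(t.numel() * t.element_size() for s in shards.values() for t in s.values())},
    "weight_map": {name: shard_file for shard_file, shard in shards.items() for name in shard},
}
with open(os.path.join(output_dir, "pytorch_model.bin.index.json"), "w", encoding="utf-8") as f:
    json.dump(index, f, indent=2, sort_keys=True)
```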
Changes to the learning-rate estimator (the calculate_lr function; apparently scripts/cal_lr.py upstream). The default cutoff_len doubles to 2048, preprocessing_num_workers=16 is passed explicitly, and the token-counting loop gets a tqdm description:

```diff
@@ -41,7 +41,7 @@ def calculate_lr(
     dataset: str = "alpaca_en_demo",
     dataset_dir: str = "data",
     template: str = "default",
-    cutoff_len: int = 1024,  # i.e. maximum input length during training
+    cutoff_len: int = 2048,  # i.e. maximum input length during training
     is_mistral_or_gemma: bool = False,  # mistral and gemma models opt for a smaller learning rate,
     packing: bool = False,
 ):
```
```diff
@@ -59,6 +59,7 @@ def calculate_lr(
             template=template,
             cutoff_len=cutoff_len,
             packing=packing,
+            preprocessing_num_workers=16,
             output_dir="dummy_dir",
             overwrite_cache=True,
             do_train=True,
```
```diff
@@ -79,7 +80,7 @@ def calculate_lr(
 
     dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
     valid_tokens, total_tokens = 0, 0
-    for batch in tqdm(dataloader):
+    for batch in tqdm(dataloader, desc="Collecting valid tokens"):
         valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
         total_tokens += torch.numel(batch["labels"])
 
```
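The valid/total token counts feed the script's square-root LR scaling rule: the learning rate is scaled by the ratio of supervised tokens per batch to a reference batch size. The constants and formula below follow my reading of the upstream script and should be treated as assumptions:

```python
import math

# Sketch of the square-root LR scaling rule that calculate_lr builds on.
# BASE_LR and BASE_BS reflect my reading of the upstream defaults
# (roughly 3e-4 at ~4M tokens per batch); treat both as assumptions.
BASE_LR = 3e-4       # assumed reference learning rate
BASE_BS = 4_000_000  # assumed reference tokens per global batch

def scaled_lr(valid_tokens: int, total_tokens: int, batch_size: int, cutoff_len: int) -> float:
    valid_ratio = valid_tokens / total_tokens        # share of tokens that are supervised
    tokens_per_batch = batch_size * cutoff_len * valid_ratio
    return BASE_LR * math.sqrt(tokens_per_batch / BASE_BS)

# e.g. half the tokens supervised, 512 sequences of up to 2048 tokens per batch
print(f"{scaled_lr(500_000, 1_000_000, 512, 2048):.2e}")
```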
Changes to the perplexity calculator (the calculate_ppl function; apparently scripts/cal_ppl.py upstream), mirroring the cal_lr updates:

```diff
@@ -63,7 +63,7 @@ def calculate_ppl(
     dataset: str = "alpaca_en_demo",
     dataset_dir: str = "data",
     template: str = "default",
-    cutoff_len: int = 1024,
+    cutoff_len: int = 2048,
     max_samples: Optional[int] = None,
     train_on_prompt: bool = False,
 ):
```
```diff
@@ -82,6 +82,7 @@ def calculate_ppl(
             cutoff_len=cutoff_len,
             max_samples=max_samples,
             train_on_prompt=train_on_prompt,
+            preprocessing_num_workers=16,
             output_dir="dummy_dir",
             overwrite_cache=True,
             do_train=True,
```
```diff
@@ -111,7 +112,7 @@ def calculate_ppl(
     perplexities = []
     batch: Dict[str, "torch.Tensor"]
     with torch.no_grad():
-        for batch in tqdm(dataloader):
+        for batch in tqdm(dataloader, desc="Computing perplexities"):
             batch = batch.to(model.device)
             outputs = model(**batch)
             shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
```
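Shifting the logits by one position aligns each prediction with the next token; per-sequence perplexity is then the exponential of the mean negative log-likelihood over supervised tokens. A self-contained sketch of that computation with toy tensors standing in for a model batch:

```python
import torch
from torch.nn import CrossEntropyLoss

IGNORE_INDEX = -100  # label value masked out of the loss, as in the script

# Toy tensors standing in for one batch; real code gets logits from the model.
logits = torch.randn(1, 6, 32000)                    # (batch, seq_len, vocab)
labels = torch.tensor([[1, 2, IGNORE_INDEX, 4, 5, 6]])

shift_logits = logits[..., :-1, :]                   # predict token t+1 from position t
shift_labels = labels[..., 1:]
loss_fct = CrossEntropyLoss(reduction="none")
token_nll = loss_fct(shift_logits.transpose(1, 2), shift_labels)  # (batch, seq_len-1)

# Average NLL over supervised tokens only, then exponentiate per sequence.
mask = shift_labels != IGNORE_INDEX
seq_nll = (token_nll * mask).sum(-1) / mask.sum(-1)
perplexities = torch.exp(seq_nll)
print(perplexities)
```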
Changes to the length-distribution script (the length_cdf function; apparently scripts/length_cdf.py upstream):

```diff
@@ -42,6 +42,7 @@ def length_cdf(
             dataset_dir=dataset_dir,
             template=template,
             cutoff_len=1_000_000,
+            preprocessing_num_workers=16,
             output_dir="dummy_dir",
             overwrite_cache=True,
             do_train=True,
```
```diff
@@ -52,7 +53,7 @@ def length_cdf(
     trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"]
     total_num = len(trainset)
     length_dict = defaultdict(int)
-    for sample in tqdm(trainset["input_ids"]):
+    for sample in tqdm(trainset["input_ids"], desc="Collecting lengths"):
         length_dict[len(sample) // interval * interval] += 1
 
     length_tuples = list(length_dict.items())
```
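The bucketing line floors each sample length to a multiple of interval; the script then turns the bucket counts into a cumulative distribution. A standalone sketch of that step (the sample lengths and output format here are illustrative, not the upstream output):

```python
from collections import defaultdict

# Toy sequence lengths standing in for tokenized training samples.
lengths = [120, 480, 512, 900, 1500, 1900, 2100, 3000]
interval = 1000
total_num = len(lengths)

length_dict = defaultdict(int)
for n in lengths:
    length_dict[n // interval * interval] += 1  # floor each length to a bucket

# Sort buckets and accumulate counts into a CDF: the share of samples that
# fit within each length bound, useful for choosing cutoff_len.
count = 0
for length, num in sorted(length_dict.items()):
    count += num
    print(f"{count / total_num * 100:.2f}% of samples have length < {length + interval}")
```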
Changes to the batch-inference script (the vllm_infer function; apparently scripts/vllm_infer.py upstream):

```diff
@@ -64,6 +64,7 @@ def vllm_infer(
             template=template,
             cutoff_len=cutoff_len,
             max_samples=max_samples,
+            preprocessing_num_workers=16,
             vllm_config=vllm_config,
             temperature=temperature,
             top_p=top_p,
```
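The temperature and top_p arguments flow into vLLM's sampling parameters. A minimal sketch of the vLLM generation step the script wraps; the model path is a placeholder, and the real script builds prompts from a LLaMA-Factory dataset and chat template rather than raw strings:

```python
from vllm import LLM, SamplingParams

# Placeholder checkpoint; the script passes the user-supplied model path.
llm = LLM(model="meta-llama/Llama-2-7b-hf")

# temperature/top_p mirror the vllm_infer arguments shown in the hunk above.
sampling_params = SamplingParams(temperature=0.95, top_p=0.7, max_tokens=1024)

# vLLM batches and schedules the prompts internally.
outputs = llm.generate(["What is block expansion?"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```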