mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	[misc] fix grad ckpt func (#6916)
Former-commit-id: 35e069a52b3d7cfd9b0107574b09265eb2290f0b
This commit is contained in:
		
							parent
							
								
									0c0cdc26bc
								
							
						
					
					
						commit
						3a3f4072e5
					
				@ -142,21 +142,23 @@ def calculate_mfu(
 | 
			
		||||
        args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"
 | 
			
		||||
 | 
			
		||||
    run_exp(args)
 | 
			
		||||
    with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
 | 
			
		||||
        result = json.load(f)
 | 
			
		||||
 | 
			
		||||
    if dist.is_initialized():
 | 
			
		||||
        dist.barrier()
 | 
			
		||||
        world_size = dist.get_world_size()
 | 
			
		||||
    else:
 | 
			
		||||
        world_size = 1
 | 
			
		||||
 | 
			
		||||
    total_batch_size = batch_size * world_size
 | 
			
		||||
    mfu_value = (
 | 
			
		||||
        result["train_steps_per_second"]
 | 
			
		||||
        * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
 | 
			
		||||
        / compute_device_flops(world_size)
 | 
			
		||||
    )
 | 
			
		||||
    print(f"MFU: {mfu_value * 100:.2f}%")
 | 
			
		||||
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
 | 
			
		||||
        with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
 | 
			
		||||
            result = json.load(f)
 | 
			
		||||
 | 
			
		||||
        total_batch_size = batch_size * world_size
 | 
			
		||||
        mfu_value = (
 | 
			
		||||
            result["train_steps_per_second"]
 | 
			
		||||
            * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
 | 
			
		||||
            / compute_device_flops(world_size)
 | 
			
		||||
        )
 | 
			
		||||
        print(f"MFU: {mfu_value * 100:.2f}%")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
 | 
			
		||||
@ -87,9 +87,10 @@ def _parse_args(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _set_transformers_logging() -> None:
 | 
			
		||||
    transformers.utils.logging.set_verbosity_info()
 | 
			
		||||
    transformers.utils.logging.enable_default_handler()
 | 
			
		||||
    transformers.utils.logging.enable_explicit_format()
 | 
			
		||||
    if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]:
 | 
			
		||||
        transformers.utils.logging.set_verbosity_info()
 | 
			
		||||
        transformers.utils.logging.enable_default_handler()
 | 
			
		||||
        transformers.utils.logging.enable_explicit_format()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _verify_model_args(
 | 
			
		||||
 | 
			
		||||
@ -89,6 +89,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
 | 
			
		||||
            for arg in args:
 | 
			
		||||
                if torch.is_tensor(arg) and torch.is_floating_point(arg):
 | 
			
		||||
                    arg.requires_grad_(True)
 | 
			
		||||
                    break  # assume the first tensor is always the hidden states
 | 
			
		||||
 | 
			
		||||
        return gradient_checkpointing_func(func, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user