From 3eaabe12aac7d6c750d1cdf74b7a1c4ecb82fc07 Mon Sep 17 00:00:00 2001
From: ShaneTian <42370681+ShaneTian@users.noreply.github.com>
Date: Thu, 21 Dec 2023 21:25:20 +0800
Subject: [PATCH] Fix slow model initialization in bfloat16 dtype.

Former-commit-id: d032daa4bd598dd0d71b43eb68a614de77a699a6
---
 src/llmtuner/model/loader.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 404da42a..fe821224 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -67,10 +67,10 @@ def load_model_and_tokenizer(
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         config=config,
-        torch_dtype=model_args.compute_dtype,
         low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
         **config_kwargs
     )
+    model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
     patcher.patch_model(model)
     register_autoclass(config, model, tokenizer)
     if not is_deepspeed_zero3_enabled():
@@ -95,7 +95,6 @@ def load_model_and_tokenizer(
 
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
-        model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
         model.eval()
     else:
         model.train()
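
The change above drops torch_dtype from the from_pretrained() call and instead casts the fully loaded module to model_args.compute_dtype, skipping the cast when transformers has marked the model as quantized. Below is a minimal standalone sketch of the same pattern in plain transformers; the load_in_compute_dtype helper and the "gpt2" checkpoint are illustrative placeholders, not part of this patch or of llmtuner.

# Standalone sketch of the loading pattern applied in this patch, assuming a
# plain transformers environment. Names below are illustrative only.
import torch
from transformers import AutoModelForCausalLM


def load_in_compute_dtype(model_name_or_path: str, compute_dtype: torch.dtype = torch.bfloat16):
    # Load without torch_dtype so the weights are materialized in the default
    # dtype first, then cast the finished module to the compute dtype.
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        low_cpu_mem_usage=True,
    )
    # Quantized models must keep their packed weights, so skip the cast when
    # transformers has recorded a quantization method on the model.
    if not getattr(model, "quantization_method", None):
        model = model.to(compute_dtype)
    return model


model = load_in_compute_dtype("gpt2", torch.bfloat16)

The guard mirrors the patched loader.py: quantized weights stay in their packed dtype, while all other models are cast once, right after loading, regardless of whether they are trainable.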