diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 404da42a..fe821224 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -67,10 +67,10 @@ def load_model_and_tokenizer(
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         config=config,
-        torch_dtype=model_args.compute_dtype,
         low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
         **config_kwargs
     )
+    model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
     patcher.patch_model(model)
     register_autoclass(config, model, tokenizer)
     if not is_deepspeed_zero3_enabled():
@@ -95,7 +95,6 @@ def load_model_and_tokenizer(
 
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
-        model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
         model.eval()
     else:
         model.train()