From 3fcb678d00fce818d719fa2f5e02f2d7022a01e6 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Thu, 6 Jun 2024 02:43:19 +0800
Subject: [PATCH] support train from scratch #4033 #4075

Former-commit-id: a12a506c3d2ba85975a5990c46d2e055cdfe0f2e
---
 src/llamafactory/hparams/model_args.py | 4 ++++
 src/llamafactory/model/loader.py       | 2 ++
 2 files changed, 6 insertions(+)

diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 7003cbee..a3b5b2a6 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -101,6 +101,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
     )
+    train_from_scratch: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to randomly initialize the model weights."},
+    )
     infer_backend: Literal["huggingface", "vllm"] = field(
         default="huggingface",
         metadata={"help": "Backend engine used at inference."},
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 49b347d5..8f3309b3 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -131,6 +131,8 @@ def load_model(
             model = load_mod_pretrained_model(**init_kwargs)
         elif model_args.visual_inputs:
             model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
+        elif model_args.train_from_scratch:
+            model = AutoModelForCausalLM.from_config(config)
         else:
             model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
 