diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 7003cbee..a3b5b2a6 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -101,6 +101,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
     )
+    train_from_scratch: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to randomly initialize the model weights."},
+    )
     infer_backend: Literal["huggingface", "vllm"] = field(
         default="huggingface",
         metadata={"help": "Backend engine used at inference."},
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 49b347d5..8f3309b3 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -131,6 +131,8 @@ def load_model(
         model = load_mod_pretrained_model(**init_kwargs)
     elif model_args.visual_inputs:
         model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
+    elif model_args.train_from_scratch:
+        model = AutoModelForCausalLM.from_config(config)
     else:
         model = AutoModelForCausalLM.from_pretrained(**init_kwargs)
 