diff --git a/src/llmtuner/train/utils.py b/src/llmtuner/train/utils.py index 144af244..09572ff7 100644 --- a/src/llmtuner/train/utils.py +++ b/src/llmtuner/train/utils.py @@ -164,6 +164,8 @@ def _create_galore_optimizer( if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all": galore_targets = find_all_linear_modules(model) + else: + galore_targets = finetuning_args.galore_target galore_params: List["torch.nn.Parameter"] = [] for name, module in model.named_modules():