diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index 32d6247a..47c72de1 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -102,11 +102,11 @@ def load_model_and_tokenizer(
                     )
 
                 current_max_length = getattr(config, "max_position_embeddings", None)
-                if current_max_length and model_args.model_max_length <= current_max_length:
+                if current_max_length and model_args.model_max_length > current_max_length:
+                    scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
+                else:
                     logger.warning("Input length is smaller than max length. Consider increase input length.")
                     scaling_factor = 1.0
-                else:
-                    scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
             else:
                 scaling_factor = 2.0
 
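The hunk inverts a backwards comparison: previously the scaling factor was computed when the requested length already fit inside the model's native window, and the "input length is smaller" warning fired when it did not. Below is a minimal sketch of the corrected branch logic, pulled out into a standalone helper for illustration; the function name `compute_rope_scaling_factor` is hypothetical, while `model_max_length`, `current_max_length`, and the `is_trainable` fallback of 2.0 mirror the code around this hunk.

```python
import math
from typing import Optional

def compute_rope_scaling_factor(model_max_length: int,
                                current_max_length: Optional[int],
                                is_trainable: bool = True) -> float:
    # Hypothetical standalone mirror of the fixed logic in loader.py.
    if not is_trainable:
        return 2.0  # inference-time default taken from the surrounding code
    if current_max_length and model_max_length > current_max_length:
        # Requested context exceeds the native window: scale RoPE up,
        # rounding to the next whole multiple, e.g. ceil(8192 / 4096) == 2.
        return float(math.ceil(model_max_length / current_max_length))
    # Requested context already fits (or the config has no limit): no scaling.
    return 1.0

assert compute_rope_scaling_factor(8192, 4096) == 2.0  # extend a 4k model to 8k
assert compute_rope_scaling_factor(2048, 4096) == 1.0  # fits natively, warn instead
```

With the old `<=` condition the two asserts above would both fail: a 4k model asked for an 8k context got `scaling_factor = 1.0` plus the warning, while a request that already fit triggered the (then division-heavy) scaling path.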