diff --git a/src/llmtuner/extras/patches/llama_patch.py b/src/llmtuner/extras/patches/llama_patch.py
index cc22041d..42691194 100644
--- a/src/llmtuner/extras/patches/llama_patch.py
+++ b/src/llmtuner/extras/patches/llama_patch.py
@@ -1,7 +1,3 @@
-# coding=utf-8
-# Modified from:
-# [1] https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
-
 import math
 import torch
 import torch.nn as nn
@@ -19,6 +15,7 @@ except ImportError:
 logger = logging.get_logger(__name__)
 
 
+# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
 class LlamaShiftShortAttention(LlamaAttention):
 
     def forward(