mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-14 19:06:26 +08:00
@@ -1,7 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Modified from:
|
||||
# [1] https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -19,6 +15,7 @@ except ImportError:
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||
class LlamaShiftShortAttention(LlamaAttention):
|
||||
|
||||
def forward(
|
||||
|
||||
Reference in New Issue
Block a user