diff --git a/src/llmtuner/model/utils/visual.py b/src/llmtuner/model/utils/visual.py index d1556bb3..79a6570e 100644 --- a/src/llmtuner/model/utils/visual.py +++ b/src/llmtuner/model/utils/visual.py @@ -43,10 +43,13 @@ class LlavaMultiModalProjectorYiVL(nn.Module): self.linear_3 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) self.linear_4 = nn.LayerNorm(config.text_config.hidden_size, bias=True) self.act = nn.GELU() - self.proj = nn.Sequential(*[self.linear_1, self.linear_2, self.act, self.linear_3, self.linear_4]) def forward(self, image_features): - hidden_states = self.proj(image_features) + hidden_states = self.linear_1(image_features) + hidden_states = self.linear_2(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_3(hidden_states) + hidden_states = self.linear_4(hidden_states) return hidden_states