Mirror of https://github.com/hiyouga/LLaMA-Factory.git
[misc] upgrade format to py39 (#7256)
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING
 
 from ...extras import logging
 from .visual import COMPOSITE_MODELS
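For context on the import change above: the commit retargets annotations at Python 3.9, where PEP 585 lets built-in containers such as list be used as generic types directly, so the typing.List alias is no longer needed. A minimal, illustrative comparison (the function below is made up for demonstration and is not part of the repository):

# Python 3.9+ style adopted by this commit: built-in generics in annotations.
def collect_names_py39(names: set[str]) -> list[str]:
    return sorted(names)

# Pre-3.9 equivalent, which required the typing aliases:
# from typing import List, Set
# def collect_names_legacy(names: Set[str]) -> List[str]:
#     return sorted(names)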
@@ -25,10 +25,8 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
 
 
-def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
-    r"""
-    Finds all available modules to apply LoRA, GaLore or APOLLO.
-    """
+def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> list[str]:
+    r"""Find all available modules to apply LoRA, GaLore or APOLLO."""
     model_type = getattr(model.config, "model_type", None)
     forbidden_modules = {"lm_head"}
     if model_type == "chatglm":
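For readers skimming the diff, find_all_linear_modules gathers the names of linear layers that LoRA, GaLore, or APOLLO can target, excluding forbidden modules such as lm_head (and the vision tower when freeze_vision_tower is set). The hunk only shows the top of the function, so here is a rough sketch of the general technique, assuming a torch.nn model; the helper name and details are illustrative, not the repository's exact code:

from torch import nn

def find_linear_module_names_sketch(model: nn.Module, forbidden: set[str]) -> list[str]:
    # Record the last path component of each nn.Linear (e.g. "q_proj"),
    # skipping any module whose name contains a forbidden keyword.
    module_names: set[str] = set()
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and not any(key in name for key in forbidden):
            module_names.add(name.split(".")[-1])
    return list(module_names)

# Hypothetical usage on a toy model:
toy = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
print(find_linear_module_names_sketch(toy, forbidden={"lm_head"}))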
@@ -54,10 +52,8 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
     return list(module_names)
 
 
-def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]:
-    r"""
-    Finds the modules in the expanded blocks to apply lora.
-    """
+def find_expanded_modules(model: "PreTrainedModel", target_modules: list[str], num_layer_trainable: int) -> list[str]:
+    r"""Find the modules in the expanded blocks to apply lora."""
     num_layers = getattr(model.config, "num_hidden_layers", None)
     if not num_layers:
         raise ValueError("Model was not supported.")
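The second function touched above, find_expanded_modules, restricts the given target modules to the transformer blocks that should remain trainable when only num_layer_trainable of the model's num_hidden_layers are expanded; the hunk shows it reading num_hidden_layers and rejecting models that lack it. A minimal sketch of that layer-selection idea (the stride rule and the "layers.{idx}." name pattern are assumptions for illustration, not the repository's code):

def expand_target_modules_sketch(
    target_modules: list[str], num_layers: int, num_layer_trainable: int
) -> list[str]:
    # Choose evenly spaced trainable layer indices, then qualify each target
    # module with a layer-specific prefix such as "layers.31.q_proj".
    stride = num_layers // num_layer_trainable
    trainable_ids = range(stride - 1, num_layers, stride)
    return [f"layers.{idx}.{name}" for idx in trainable_ids for name in target_modules]

# Hypothetical usage: adapt only 2 of 32 decoder layers.
print(expand_target_modules_sketch(["q_proj", "v_proj"], num_layers=32, num_layer_trainable=2))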