update scripts

This commit is contained in:
hiyouga
2024-08-09 19:16:23 +08:00
parent c87023d539
commit 86f7099fa3
8 changed files with 29 additions and 16 deletions

View File

@@ -19,7 +19,7 @@
import json
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
import fire
import torch
@@ -47,8 +47,8 @@ def block_expansion(
model_name_or_path: str,
output_dir: str,
num_expand: int,
shard_size: Optional[str] = "2GB",
save_safetensors: Optional[bool] = False,
shard_size: str = "2GB",
save_safetensors: bool = True,
):
r"""
Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.