patch modelscope

This commit is contained in:
hiyouga
2023-12-01 22:53:15 +08:00
parent 3a64506031
commit bd42c229b0
7 changed files with 312 additions and 222 deletions

View File

@@ -23,6 +23,7 @@ except ImportError:
if TYPE_CHECKING:
from transformers import HfArgumentParser
from llmtuner.hparams import ModelArguments
class AverageMeter:
@@ -117,3 +118,23 @@ def torch_gc() -> None:
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
def try_download_model_from_ms(model_args: "ModelArguments") -> None:
if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
return
try:
from modelscope import snapshot_download # type: ignore
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
model_args.model_name_or_path = snapshot_download(
model_args.model_name_or_path,
revision=revision,
cache_dir=model_args.cache_dir
)
except ImportError:
raise ImportError("Please install modelscope via `pip install modelscope -U`")
def use_modelscope() -> bool:
return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))