mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-03 20:22:49 +08:00
support ms
Former-commit-id: d38a2e7341100902b6c761895b1fe6191c905d06
This commit is contained in:
parent
8ec3617f19
commit
e08e0e5814
@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
from collections import defaultdict, OrderedDict
|
from collections import defaultdict, OrderedDict
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
@ -20,6 +21,8 @@ SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
|
|||||||
|
|
||||||
SUPPORTED_MODELS = OrderedDict()
|
SUPPORTED_MODELS = OrderedDict()
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS = OrderedDict()
|
||||||
|
|
||||||
TRAINING_STAGES = {
|
TRAINING_STAGES = {
|
||||||
"Supervised Fine-Tuning": "sft",
|
"Supervised Fine-Tuning": "sft",
|
||||||
"Reward Modeling": "rm",
|
"Reward Modeling": "rm",
|
||||||
@ -40,7 +43,11 @@ def register_model_group(
|
|||||||
prefix = name.split("-")[0]
|
prefix = name.split("-")[0]
|
||||||
else:
|
else:
|
||||||
assert prefix == name.split("-")[0], "prefix should be identical."
|
assert prefix == name.split("-")[0], "prefix should be identical."
|
||||||
|
|
||||||
SUPPORTED_MODELS[name] = path
|
SUPPORTED_MODELS[name] = path
|
||||||
|
if os.environ.get('USE_MODELSCOPE_HUB', False) and name in MODELSCOPE_MODELS:
|
||||||
|
# Use ModelScope modelhub
|
||||||
|
SUPPORTED_MODELS[name] = MODELSCOPE_MODELS[name]
|
||||||
if module is not None:
|
if module is not None:
|
||||||
DEFAULT_MODULE[prefix] = module
|
DEFAULT_MODULE[prefix] = module
|
||||||
if template is not None:
|
if template is not None:
|
||||||
@ -58,6 +65,13 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"Baichuan-7B-Base": "baichuan-inc/baichuan-7B",
|
||||||
|
"Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
|
||||||
|
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Base"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Baichuan2-7B-Base": "baichuan-inc/Baichuan2-7B-Base",
|
"Baichuan2-7B-Base": "baichuan-inc/Baichuan2-7B-Base",
|
||||||
@ -70,6 +84,14 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"Baichuan2-7B-Base": "baichuan-inc/Baichuan2-7B-Base",
|
||||||
|
"Baichuan2-13B-Base": "baichuan-inc/Baichuan2-13B-Base",
|
||||||
|
"Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
|
||||||
|
"Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"BLOOM-560M": "bigscience/bloom-560m",
|
"BLOOM-560M": "bigscience/bloom-560m",
|
||||||
@ -80,6 +102,13 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"BLOOM-560M": "AI-ModelScope/bloom-560m",
|
||||||
|
"BLOOM-3B": "bigscience/bloom-3b",
|
||||||
|
"BLOOM-7B1": "bigscience/bloom-7b1"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"BLOOMZ-560M": "bigscience/bloomz-560m",
|
"BLOOMZ-560M": "bigscience/bloomz-560m",
|
||||||
@ -90,6 +119,13 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"BLOOMZ-560M": "bigscience/bloomz-560m",
|
||||||
|
"BLOOMZ-3B": "bigscience/bloomz-3b",
|
||||||
|
"BLOOMZ-7B1-mt": "AI-ModelScope/bloomz-7b1-mt"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"BlueLM-7B-Base": "vivo-ai/BlueLM-7B-Base",
|
"BlueLM-7B-Base": "vivo-ai/BlueLM-7B-Base",
|
||||||
@ -99,6 +135,12 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"BlueLM-7B-Base": "vivo-ai/BlueLM-7B-Base",
|
||||||
|
"BlueLM-7B-Chat": "vivo-ai/BlueLM-7B-Chat"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
|
"ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
|
||||||
@ -108,6 +150,11 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"ChatGLM2-6B-Chat": "ZhipuAI/chatglm2-6b"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base",
|
"ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base",
|
||||||
@ -118,6 +165,12 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"ChatGLM3-6B-Base": "ZhipuAI/chatglm3-6b-base",
|
||||||
|
"ChatGLM3-6B-Chat": "ZhipuAI/chatglm3-6b"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"ChineseLLaMA2-1.3B": "hfl/chinese-llama-2-1.3b",
|
"ChineseLLaMA2-1.3B": "hfl/chinese-llama-2-1.3b",
|
||||||
@ -131,6 +184,16 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"ChineseLLaMA2-1.3B": "hfl/chinese-llama-2-1.3b",
|
||||||
|
"ChineseLLaMA2-7B": "hfl/chinese-llama-2-7b",
|
||||||
|
"ChineseLLaMA2-13B": "hfl/chinese-llama-2-13b",
|
||||||
|
"ChineseLLaMA2-1.3B-Chat": "hfl/chinese-alpaca-2-1.3b",
|
||||||
|
"ChineseLLaMA2-7B-Chat": "hfl/chinese-alpaca-2-7b",
|
||||||
|
"ChineseLLaMA2-13B-Chat": "hfl/chinese-alpaca-2-13b"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Falcon-7B": "tiiuae/falcon-7b",
|
"Falcon-7B": "tiiuae/falcon-7b",
|
||||||
@ -145,6 +208,16 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"Falcon-7B": "tiiuae/falcon-7b",
|
||||||
|
"Falcon-40B": "tiiuae/falcon-40b",
|
||||||
|
"Falcon-180B": "tiiuae/falcon-180B",
|
||||||
|
"Falcon-7B-Chat": "AI-ModelScope/falcon-7b-instruct",
|
||||||
|
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
|
||||||
|
"Falcon-180B-Chat": "tiiuae/falcon-180B-chat"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"InternLM-7B": "internlm/internlm-7b",
|
"InternLM-7B": "internlm/internlm-7b",
|
||||||
@ -156,6 +229,14 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"InternLM-7B": "Shanghai_AI_Laboratory/internlm-7b",
|
||||||
|
"InternLM-20B": "Shanghai_AI_Laboratory/internlm-20b",
|
||||||
|
"InternLM-7B-Chat": "Shanghai_AI_Laboratory/internlm-chat-7b",
|
||||||
|
"InternLM-20B-Chat": "Shanghai_AI_Laboratory/internlm-chat-20b"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"LingoWhale-8B": "deeplang-ai/LingoWhale-8B"
|
"LingoWhale-8B": "deeplang-ai/LingoWhale-8B"
|
||||||
@ -164,6 +245,11 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"LingoWhale-8B": "DeepLang/LingoWhale-8B"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"LLaMA-7B": "huggyllama/llama-7b",
|
"LLaMA-7B": "huggyllama/llama-7b",
|
||||||
@ -174,6 +260,14 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"LLaMA-7B": "skyline2006/llama-7b",
|
||||||
|
"LLaMA-13B": "skyline2006/llama-13b",
|
||||||
|
"LLaMA-30B": "skyline2006/llama-30b",
|
||||||
|
"LLaMA-65B": "skyline2006/llama-65b"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
|
"LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
|
||||||
@ -187,6 +281,16 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"LLaMA2-7B": "modelscope/Llama-2-7b-ms",
|
||||||
|
"LLaMA2-13B": "modelscope/Llama-2-13b-ms",
|
||||||
|
"LLaMA2-70B": "modelscope/Llama-2-70b-ms",
|
||||||
|
"LLaMA2-7B-Chat": "modelscope/Llama-2-7b-chat-ms",
|
||||||
|
"LLaMA2-13B-Chat": "modelscope/Llama-2-13b-chat-ms",
|
||||||
|
"LLaMA2-70B-Chat": "modelscope/Llama-2-70b-chat-ms"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Mistral-7B": "mistralai/Mistral-7B-v0.1",
|
"Mistral-7B": "mistralai/Mistral-7B-v0.1",
|
||||||
@ -196,6 +300,12 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"Mistral-7B": "AI-ModelScope/Mistral-7B-v0.1",
|
||||||
|
"Mistral-7B-Chat": "AI-ModelScope/Mistral-7B-Instruct-v0.1"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"OpenChat3.5-7B-Chat": "openchat/openchat_3.5"
|
"OpenChat3.5-7B-Chat": "openchat/openchat_3.5"
|
||||||
@ -204,6 +314,11 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"OpenChat3.5-7B-Chat": "openchat/openchat_3.5"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Phi1.5-1.3B": "microsoft/phi-1_5"
|
"Phi1.5-1.3B": "microsoft/phi-1_5"
|
||||||
@ -212,6 +327,11 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"Phi1.5-1.3B": "microsoft/phi-1_5"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Qwen-7B": "Qwen/Qwen-7B",
|
"Qwen-7B": "Qwen/Qwen-7B",
|
||||||
@ -228,6 +348,18 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"Qwen-7B": "qwen/Qwen-7B",
|
||||||
|
"Qwen-14B": "qwen/Qwen-14B",
|
||||||
|
"Qwen-7B-Chat": "qwen/Qwen-7B-Chat",
|
||||||
|
"Qwen-14B-Chat": "qwen/Qwen-14B-Chat",
|
||||||
|
"Qwen-7B-int8-Chat": "qwen/Qwen-7B-Chat-Int8",
|
||||||
|
"Qwen-7B-int4-Chat": "qwen/Qwen-7B-Chat-Int4",
|
||||||
|
"Qwen-14B-int8-Chat": "qwen/Qwen-14B-Chat-Int8",
|
||||||
|
"Qwen-14B-int4-Chat": "qwen/Qwen-14B-Chat-Int4"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Skywork-13B-Base": "Skywork/Skywork-13B-base"
|
"Skywork-13B-Base": "Skywork/Skywork-13B-base"
|
||||||
@ -235,6 +367,11 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"Skywork-13B-Base": "skywork/Skywork-13B-base"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Vicuna1.5-7B-Chat": "lmsys/vicuna-7b-v1.5",
|
"Vicuna1.5-7B-Chat": "lmsys/vicuna-7b-v1.5",
|
||||||
@ -244,6 +381,12 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"Vicuna1.5-7B-Chat": "AI-ModelScope/vicuna-7b-v1.5",
|
||||||
|
"Vicuna1.5-13B-Chat": "lmsys/vicuna-13b-v1.5"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"XVERSE-7B": "xverse/XVERSE-7B",
|
"XVERSE-7B": "xverse/XVERSE-7B",
|
||||||
@ -256,6 +399,15 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"XVERSE-7B": "xverse/XVERSE-7B",
|
||||||
|
"XVERSE-13B": "xverse/XVERSE-13B",
|
||||||
|
"XVERSE-65B": "xverse/XVERSE-65B",
|
||||||
|
"XVERSE-7B-Chat": "xverse/XVERSE-7B-Chat",
|
||||||
|
"XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Yayi-7B": "wenge-research/yayi-7b-llama2",
|
"Yayi-7B": "wenge-research/yayi-7b-llama2",
|
||||||
@ -265,6 +417,12 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"Yayi-7B": "wenge-research/yayi-7b-llama2",
|
||||||
|
"Yayi-13B": "wenge-research/yayi-13b-llama2"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Yi-6B": "01-ai/Yi-6B",
|
"Yi-6B": "01-ai/Yi-6B",
|
||||||
@ -276,6 +434,14 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"Yi-6B": "01ai/Yi-6B",
|
||||||
|
"Yi-34B": "01ai/Yi-34B",
|
||||||
|
"Yi-34B-Chat": "01ai/Yi-34B-Chat",
|
||||||
|
"Yi-34B-int8-Chat": "01ai/Yi-34B-Chat-8bits"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Zephyr-7B-Alpha-Chat": "HuggingFaceH4/zephyr-7b-alpha",
|
"Zephyr-7B-Alpha-Chat": "HuggingFaceH4/zephyr-7b-alpha",
|
||||||
@ -283,3 +449,9 @@ register_model_group(
|
|||||||
},
|
},
|
||||||
template="zephyr"
|
template="zephyr"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODELSCOPE_MODELS.update({
|
||||||
|
"Zephyr-7B-Alpha-Chat": "HuggingFaceH4/zephyr-7b-alpha",
|
||||||
|
"Zephyr-7B-Beta-Chat": "modelscope/zephyr-7b-beta"
|
||||||
|
})
|
||||||
|
@ -8,7 +8,8 @@ class ModelArguments:
|
|||||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
|
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
|
||||||
"""
|
"""
|
||||||
model_name_or_path: str = field(
|
model_name_or_path: str = field(
|
||||||
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
|
metadata={"help": "Path to pretrained model or model identifier "
|
||||||
|
"from huggingface.co/models or modelscope.cn/models."}
|
||||||
)
|
)
|
||||||
cache_dir: Optional[str] = field(
|
cache_dir: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
import math
|
import math
|
||||||
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from types import MethodType
|
from types import MethodType
|
||||||
from typing import TYPE_CHECKING, Literal, Optional, Tuple
|
from typing import TYPE_CHECKING, Literal, Optional, Tuple
|
||||||
@ -63,6 +65,8 @@ def load_model_and_tokenizer(
|
|||||||
"token": model_args.hf_hub_token
|
"token": model_args.hf_hub_token
|
||||||
}
|
}
|
||||||
|
|
||||||
|
try_download_model_from_ms(model_args)
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_args.model_name_or_path,
|
model_args.model_name_or_path,
|
||||||
use_fast=model_args.use_fast_tokenizer,
|
use_fast=model_args.use_fast_tokenizer,
|
||||||
@ -228,3 +232,13 @@ def load_model_and_tokenizer(
|
|||||||
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
|
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
|
||||||
|
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def try_download_model_from_ms(model_args):
|
||||||
|
if os.environ.get('USE_MODELSCOPE_HUB', False) and not os.path.exists(model_args.model_name_or_path):
|
||||||
|
try:
|
||||||
|
from modelscope import snapshot_download
|
||||||
|
model_args.model_name_or_path = snapshot_download(model_args.model_name_or_path, model_args.model_revision)
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(f'You are using `USE_MODELSCOPE_HUB=True` but you have no modelscope sdk installed. '
|
||||||
|
f'Please install it by `pip install modelscope -U`') from e
|
||||||
|
Loading…
x
Reference in New Issue
Block a user