Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-10-15 16:18:10 +08:00)

support system column #1765

Former-commit-id: f425584a511c5e42bae8b3ba090eaa898b28adad
parent c27675f70d
commit 934d00ea1e
@@ -57,7 +57,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

 [23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)** in our framework. See hardware requirement [here](#hardware-requirement).

-[23/12/01] We supported downloading pre-trained models from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-models-optional) for usage.
+[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)** for Chinese mainland users. See [this tutorial](#use-modelscope-hub-optional) for usage.

 <details><summary>Full Changelog</summary>
@@ -242,9 +242,9 @@ If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you wi
 pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
 ```

-### Use ModelScope Models (optional)
+### Use ModelScope Hub (optional)

-If you have trouble with downloading models from Hugging Face, you can use LLaMA-Factory together with ModelScope in the following manner.
+If you have trouble with downloading models and datasets from Hugging Face, you can use LLaMA-Factory together with ModelScope in the following manner.

 ```bash
 export USE_MODELSCOPE_HUB=1 # `set USE_MODELSCOPE_HUB=1` for Windows
@@ -258,7 +258,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     ... # arguments (same as above)
 ```

-LLaMA Board also supports using the models on the ModelScope Hub.
+LLaMA Board also supports using the models and datasets on the ModelScope Hub.

 ```bash
 CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
@@ -57,7 +57,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846

 [23/12/12] We supported fine-tuning the latest MoE model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**. See hardware requirements [here](#硬件依赖).

-[23/12/01] We supported downloading pre-trained models from the **[ModelScope Hub](https://modelscope.cn/models)**. See [this tutorial](#使用魔搭社区可跳过) for usage.
+[23/12/01] We supported downloading pre-trained models and datasets from the **[ModelScope Hub](https://modelscope.cn/models)**. See [this tutorial](#使用魔搭社区可跳过) for usage.

 <details><summary>Full Changelog</summary>
@@ -244,7 +244,7 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl

 ### Use ModelScope Hub (optional)

-If you have trouble with downloading models from Hugging Face, you can use the ModelScope Hub in the following manner.
+If you have trouble with downloading models and datasets from Hugging Face, you can use the ModelScope Hub in the following manner.

 ```bash
 export USE_MODELSCOPE_HUB=1 # use `set USE_MODELSCOPE_HUB=1` for Windows
@@ -258,7 +258,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     ... # arguments (same as above)
 ```

-LLaMA Board also supports downloading models from the ModelScope Hub.
+LLaMA Board also supports downloading models and datasets from the ModelScope Hub.

 ```bash
 CUDA_VISIBLE_DEVICES=0 USE_MODELSCOPE_HUB=1 python src/train_web.py
@@ -17,7 +17,8 @@ If you are using a custom dataset, please provide your dataset definition in the
     "history": "the column name in the dataset containing the histories. (default: None, for alpaca)",
     "messages": "the column name in the dataset containing the messages. (default: conversations, for sharegpt)",
     "role": "the key in the message represents the identity. (default: from, for sharegpt)",
-    "content": "the key in the message represents the content. (default: value, for sharegpt)"
+    "content": "the key in the message represents the content. (default: value, for sharegpt)",
+    "system": "the column name in the dataset containing the system prompts. (default: None, for both)"
   }
 }
 ```
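For orientation, a complete `dataset_info.json` entry wiring up the new `system` column might look like the sketch below; the dataset name and `file_name` value are hypothetical and not part of this commit.

```json
{
  "my_dataset": {
    "file_name": "my_dataset.json",
    "columns": {
      "prompt": "instruction",
      "query": "input",
      "response": "output",
      "system": "system",
      "history": "history"
    }
  }
}
```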
@@ -32,6 +33,7 @@ Currently we support dataset in **alpaca** or **sharegpt** format, the dataset i
     "instruction": "user instruction (required)",
     "input": "user input (optional)",
     "output": "model response (required)",
+    "system": "system prompt (optional)",
     "history": [
       ["user instruction in the first round (optional)", "model response in the first round (optional)"],
       ["user instruction in the second round (optional)", "model response in the second round (optional)"]
@@ -48,6 +50,7 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
     "prompt": "instruction",
     "query": "input",
     "response": "output",
+    "system": "system",
     "history": "history"
   }
 }
@@ -55,7 +58,7 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:

 where the `prompt` and `response` columns should contain non-empty values, represent instruction and response respectively. The `query` column will be concatenated with the `prompt` column and used as input for the model.

-The `history` column is a list consisting string tuples representing query-response pairs in history. Note that the responses **in each round will be used for training**.
+The `system` column will be used as the system prompt in the template. The `history` column is a list consisting string tuples representing query-response pairs in history. Note that the responses **in each round will be used for training**.

 For the pre-training datasets, only the `prompt` column will be used for training.
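Concretely, a single alpaca-format record that exercises the new `system` field could look like this; all values are illustrative.

```json
{
  "instruction": "Summarize the text below.",
  "input": "LLaMA-Factory supports LoRA, QLoRA and full-parameter fine-tuning.",
  "output": "LLaMA-Factory is a framework offering several fine-tuning methods.",
  "system": "You are a concise assistant.",
  "history": [
    ["What is LoRA?", "A parameter-efficient fine-tuning method."]
  ]
}
```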
@@ -86,7 +89,8 @@ The dataset in sharegpt format should follow the below format:
         "from": "gpt",
         "value": "model response"
       }
-    ]
+    ],
+    "system": "system prompt (optional)"
   }
 ]
 ```
@@ -98,7 +102,8 @@ Regarding the above dataset, the `columns` in `dataset_info.json` should be:
   "columns": {
     "messages": "conversations",
     "role": "from",
-    "content": "value"
+    "content": "value",
+    "system": "system"
   }
 }
 ```
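Likewise, a sharegpt-format record under the new schema might look like this; the values are illustrative, while `conversations`, `from` and `value` are the documented defaults.

```json
[
  {
    "conversations": [
      {"from": "human", "value": "user instruction"},
      {"from": "gpt", "value": "model response"}
    ],
    "system": "You are a helpful assistant."
  }
]
```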
@@ -17,7 +17,8 @@
     "history": "the column name of the chat history (default: None, for the alpaca format)",
     "messages": "the column name of the message list (default: conversations, for the sharegpt format)",
     "role": "the key in the message indicating the sender role (default: from, for the sharegpt format)",
-    "content": "the key in the message holding the text content (default: value, for the sharegpt format)"
+    "content": "the key in the message holding the text content (default: value, for the sharegpt format)",
+    "system": "the column name of the system prompts (default: None, for both formats)"
   }
 }
 ```
@@ -32,6 +33,7 @@
     "instruction": "user instruction (required)",
     "input": "user input (optional)",
     "output": "model response (required)",
+    "system": "system prompt (optional)",
     "history": [
       ["user instruction in the first round (optional)", "model response in the first round (optional)"],
       ["user instruction in the second round (optional)", "model response in the second round (optional)"]
@@ -48,6 +50,7 @@
     "prompt": "instruction",
     "query": "input",
     "response": "output",
+    "system": "system",
     "history": "history"
   }
 }
@@ -55,7 +58,7 @@

 where the `prompt` and `response` columns must be non-empty strings, representing the user instruction and the model response respectively. The content of the `query` column will be concatenated with the `prompt` column as the model input.

-The `history` column is a list of string 2-tuples, each representing the instruction and response of one round in the chat history. Note that the model responses **of every round will be used for training**.
+The `system` column is the system prompt used in the template. The `history` column is a list of string 2-tuples, each representing the instruction and response of one round in the chat history. Note that the model responses **of every round will be used for training**.

 For pre-training datasets, only the content of the `prompt` column will be used for training.
@@ -86,7 +89,8 @@
         "from": "gpt",
         "value": "model response"
       }
-    ]
+    ],
+    "system": "system prompt (optional)"
   }
 ]
 ```
@@ -98,7 +102,8 @@
   "columns": {
     "messages": "conversations",
     "role": "from",
-    "content": "value"
+    "content": "value",
+    "system": "system"
   }
 }
 ```
@@ -1 +0,0 @@
-38c89869c6aeca2a3af9ea1e09afe460f9b46810
@@ -30,7 +30,6 @@ class ChatModel:
         self.tokenizer.padding_side = "left" if self.can_generate else "right"
         self.model = dispatch_model(self.model)
         self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
-        self.system_prompt = data_args.system_prompt

     def _process_args(
         self,
@@ -39,7 +38,6 @@ class ChatModel:
         system: Optional[str] = None,
         **input_kwargs
     ) -> Tuple[Dict[str, Any], int]:
-        system = system or self.system_prompt
         prompt, _ = self.template.encode_oneturn(
             tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
         )
@@ -83,7 +83,7 @@ def get_dataset(
             streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
         )

-        if data_args.streaming and (dataset_attr.load_from == "file"):
+        if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
             dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter

         if max_samples is not None: # truncate dataset
@@ -91,8 +91,8 @@ def get_dataset(

         def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
             # convert dataset from sharegpt format to alpaca format
-            outputs = {"prompt": [], "query": [], "response": [], "history": []}
-            for msg_list in examples[dataset_attr.messages]:
+            outputs = {"prompt": [], "query": [], "response": [], "history": [], "system": []}
+            for i, msg_list in enumerate(examples[dataset_attr.messages]):
                 msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2
                 if len(msg_list) == 0:
                     continue
@@ -116,6 +116,7 @@ def get_dataset(
                 outputs["query"].append("")
                 outputs["response"].append(msg_pairs[-1][1])
                 outputs["history"].append(msg_pairs[:-1])
+                outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")

             return outputs

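To make the conversion concrete: for a one-turn sharegpt record with a mapped `system` column, `convert_format` should now yield roughly the following alpaca-style columns. This is a hand-worked illustration, not output produced by this commit; the handling of `prompt` is inferred from surrounding code not shown in the diff.

```json
{
  "sharegpt_input": {
    "conversations": [
      {"from": "human", "value": "Hi"},
      {"from": "gpt", "value": "Hello!"}
    ],
    "system": "You are a helpful assistant."
  },
  "converted_columns": {
    "prompt": "Hi",
    "query": "",
    "response": "Hello!",
    "history": [],
    "system": "You are a helpful assistant."
  }
}
```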
@@ -136,17 +137,10 @@ def get_dataset(
                 **kwargs
             )
         else:
-            for column_name in ["prompt", "query", "response", "history"]: # align dataset
+            for column_name in ["prompt", "query", "response", "history", "system"]: # align dataset
                 if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
                     dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)

-        if dataset_attr.system_prompt: # add system prompt
-            system_prompt = dataset_attr.system_prompt
-            if data_args.streaming:
-                dataset = dataset.map(lambda _: {"system": system_prompt})
-            else:
-                dataset = dataset.add_column("system", [system_prompt] * len(dataset))
-
         all_datasets.append(dataset)

         if len(data_args.dataset_list) == 1:
@@ -17,7 +17,6 @@ class DatasetAttr:
     load_from: Literal["hf_hub", "ms_hub", "script", "file"]
     dataset_name: Optional[str] = None
     dataset_sha1: Optional[str] = None
-    system_prompt: Optional[str] = None
     subset: Optional[str] = None
     folder: Optional[str] = None
     ranking: Optional[bool] = False
@@ -30,6 +29,7 @@ class DatasetAttr:
     messages: Optional[str] = "conversations"
     role: Optional[str] = "from"
     content: Optional[str] = "value"
+    system: Optional[str] = None

     def __repr__(self) -> str:
         return self.dataset_name
@@ -104,10 +104,6 @@ class DataArguments:
         default=True,
         metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
     )
-    system_prompt: Optional[str] = field(
-        default=None,
-        metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."}
-    )
     val_size: Optional[float] = field(
         default=0,
         metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
@@ -145,15 +141,11 @@ class DataArguments:
                 raise ValueError("Cannot open {} due to {}.".format(os.path.join(self.dataset_dir, DATA_CONFIG), str(err)))
             dataset_info = None

-        prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
-        prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
-        assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."
-
        if self.interleave_probs is not None:
             self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]

         self.dataset_list: List[DatasetAttr] = []
-        for i, name in enumerate(dataset_names):
+        for name in dataset_names:
             if name not in dataset_info:
                 raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))

@@ -191,10 +183,10 @@ class DataArguments:
                 dataset_attr.messages = dataset_info[name]["columns"].get("messages", None)
                 dataset_attr.role = dataset_info[name]["columns"].get("role", None)
                 dataset_attr.content = dataset_info[name]["columns"].get("content", None)
+                dataset_attr.system = dataset_info[name]["columns"].get("system", None)

             dataset_attr.subset = dataset_info[name].get("subset", None)
             dataset_attr.folder = dataset_info[name].get("folder", None)
             dataset_attr.ranking = dataset_info[name].get("ranking", False)
             dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
-            dataset_attr.system_prompt = prompt_list[i]
             self.dataset_list.append(dataset_attr)
@@ -217,7 +217,7 @@ def load_model_and_tokenizer(
     # Prepare model for inference
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
-        model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model
+        model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
         model.eval()
     else:
         model.train()
@@ -77,7 +77,6 @@ class WebChatModel(ChatModel):
             finetuning_type=get("top.finetuning_type"),
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
-            system_prompt=get("top.system_prompt"),
             flash_attn=get("top.flash_attn"),
             shift_attn=get("top.shift_attn"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None
@@ -25,16 +25,13 @@ def create_top() -> Dict[str, "Component"]:

     with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
         with gr.Row():
-            quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=1)
-            template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
-            system_prompt = gr.Textbox(scale=2)
+            quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
+            template = gr.Dropdown(choices=list(templates.keys()), value="default")
+            rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")

-    with gr.Accordion(label="Model config (LLaMA only)", open=False) as llama_tab:
-        with gr.Row():
             with gr.Column():
                 flash_attn = gr.Checkbox(value=False)
                 shift_attn = gr.Checkbox(value=False)
-                rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")

     model_name.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
@@ -66,9 +63,7 @@ def create_top() -> Dict[str, "Component"]:
         advanced_tab=advanced_tab,
         quantization_bit=quantization_bit,
         template=template,
-        system_prompt=system_prompt,
-        llama_tab=llama_tab,
+        rope_scaling=rope_scaling,
         flash_attn=flash_attn,
-        shift_attn=shift_attn,
-        rope_scaling=rope_scaling
+        shift_attn=shift_attn
     )
@@ -77,22 +77,12 @@ LOCALES = {
             "info": "构建提示词时使用的模板"
         }
     },
-    "system_prompt": {
+    "rope_scaling": {
         "en": {
-            "label": "System prompt (optional)",
-            "info": "A sequence used as the default system prompt."
+            "label": "RoPE scaling"
         },
         "zh": {
-            "label": "系统提示词(非必填)",
-            "info": "默认使用的系统提示词"
-        }
-    },
-    "llama_tab": {
-        "en": {
-            "label": "Model configurations (LLaMA only)"
-        },
-        "zh": {
-            "label": "模型设置(仅LLaMA)"
+            "label": "RoPE 插值方法"
         }
     },
     "flash_attn": {
@@ -111,14 +101,6 @@ LOCALES = {
             "label": "使用 shift short attention (S^2-Attn)"
         }
     },
-    "rope_scaling": {
-        "en": {
-            "label": "RoPE scaling"
-        },
-        "zh": {
-            "label": "RoPE 插值方法"
-        }
-    },
     "training_stage": {
         "en": {
             "label": "Stage",
@@ -25,7 +25,6 @@ class Manager:
             self.all_elems["top"]["finetuning_type"],
             self.all_elems["top"]["quantization_bit"],
             self.all_elems["top"]["template"],
-            self.all_elems["top"]["system_prompt"],
             self.all_elems["top"]["flash_attn"],
             self.all_elems["top"]["shift_attn"],
             self.all_elems["top"]["rope_scaling"]
@@ -102,7 +102,6 @@ class Runner:
             finetuning_type=get("top.finetuning_type"),
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
-            system_prompt=get("top.system_prompt"),
             flash_attn=get("top.flash_attn"),
             shift_attn=get("top.shift_attn"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
@@ -176,7 +175,6 @@ class Runner:
             finetuning_type=get("top.finetuning_type"),
             quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
             template=get("top.template"),
-            system_prompt=get("top.system_prompt"),
             flash_attn=get("top.flash_attn"),
             shift_attn=get("top.shift_attn"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,