Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-28 00:32:48 +08:00)

Commit 9a26819a58: Merge branch 'main' into feat/support_ms
Former-commit-id: 00f5c9ee1608b98ab8f40bcafdc3edc71833257f

.gitignore (vendored)
@@ -157,4 +157,9 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+.idea/
+
+# custom .gitignore
+user.config
+saves/
+cache/
README.md

@@ -94,7 +94,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
 | [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
 | [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
-| [Qwen](https://github.com/QwenLM/Qwen) | 7B/14B | c_attn | qwen |
+| [Qwen](https://github.com/QwenLM/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
 | [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |

 > [!NOTE]
@@ -158,6 +158,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
 - [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
 - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
 - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
+- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
 - [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
 - [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
 - [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
@@ -173,6 +174,7 @@ Please refer to [constants.py](src/llmtuner/extras/constants.py) for a full list
 - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
 - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
 - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)

 </details>

@@ -201,8 +203,8 @@ huggingface-cli login
 | Full | 16 | 140GB | 240GB | 520GB | 1200GB |
 | Freeze | 16 | 20GB | 40GB | 120GB | 240GB |
 | LoRA | 16 | 16GB | 32GB | 80GB | 160GB |
-| LoRA | 8 | 10GB | 16GB | 40GB | 80GB |
-| LoRA | 4 | 6GB | 12GB | 24GB | 48GB |
+| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB |
+| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB |

 ## Getting Started

|
README_zh.md (the same edits applied to the Chinese README; in the table below, 全参数 = full-tuning, 部分参数 = freeze-tuning, and 如何使用 = "Getting Started")

@@ -94,7 +94,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 | [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
 | [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
 | [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
-| [Qwen](https://github.com/QwenLM/Qwen) | 7B/14B | c_attn | qwen |
+| [Qwen](https://github.com/QwenLM/Qwen) | 1.8B/7B/14B/72B | c_attn | qwen |
 | [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |

 > [!NOTE]
@@ -158,6 +158,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 - [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
 - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
 - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
+- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)
 - [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
 - [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
 - [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
@@ -173,6 +174,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846
 - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
 - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
 - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+- [Nectar (en)](https://huggingface.co/datasets/berkeley-nest/Nectar)

 </details>

@@ -201,8 +203,8 @@ huggingface-cli login
 | 全参数 | 16 | 140GB | 240GB | 520GB | 1200GB |
 | 部分参数 | 16 | 20GB | 40GB | 120GB | 240GB |
 | LoRA | 16 | 16GB | 32GB | 80GB | 160GB |
-| LoRA | 8 | 10GB | 16GB | 40GB | 80GB |
-| LoRA | 4 | 6GB | 12GB | 24GB | 48GB |
+| QLoRA | 8 | 10GB | 16GB | 40GB | 80GB |
+| QLoRA | 4 | 6GB | 12GB | 24GB | 48GB |

 ## 如何使用

|
data/dataset_info.json

@@ -134,6 +134,9 @@
   "webnovel": {
     "hf_hub_url": "zxbsmk/webnovel_cn"
   },
+  "nectar_sft": {
+    "hf_hub_url": "mlinmg/SFT-Nectar"
+  },
   "adgen": {
     "hf_hub_url": "HasturOfficial/adgen",
     "columns": {
@@ -216,6 +219,10 @@
     "file_sha1": "515b18ed497199131ddcc1af950345c11dc5c7fd",
     "ranking": true
   },
+  "nectar_rm": {
+    "hf_hub_url": "mlinmg/RLAIF-Nectar",
+    "ranking": true
+  },
   "wiki_demo": {
     "file_name": "wiki_demo.txt",
     "file_sha1": "e70375e28eda542a90c68213640cc371898ce181",
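
Reviewer note: the two Nectar entries declare only a Hub repository, so no local `file_name`/`file_sha1` pair is needed. A minimal sketch of what an `hf_hub_url` entry amounts to, using the `datasets` library directly (the repo's own loader layers caching, streaming, and column mapping on top); assumes it is run from the repository root:

```python
# Minimal sketch: resolving an "hf_hub_url" entry from dataset_info.json.
import json

from datasets import load_dataset

with open("data/dataset_info.json", "r", encoding="utf-8") as f:
    dataset_info = json.load(f)

attr = dataset_info["nectar_sft"]           # {"hf_hub_url": "mlinmg/SFT-Nectar"}
dataset = load_dataset(attr["hf_hub_url"])  # downloads from the Hugging Face Hub
print(dataset)
```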
|
src/llmtuner/extras/template.py

@@ -408,18 +408,31 @@ register_template(
         "{{system}}"
     ],
     prompt=[
-        "### Instruction:\n{{query}}\n\n### Response:\n"
+        "User: {{query}}\n\nAssistant:"
+    ],
+    system="",
+    sep=[]
+)
+
+
+register_template(
+    name="deepseekcoder",
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "### Instruction:\n{{query}}\n### Response:\n"
     ],
     system=(
         "You are an AI programming assistant, utilizing the Deepseek Coder model, "
         "developed by Deepseek Company, and you only answer questions related to computer science. "
         "For politically sensitive questions, security and privacy issues, "
-        "and other non-computer science questions, you will refuse to answer."
+        "and other non-computer science questions, you will refuse to answer\n"
     ),
     sep=[
         "\n",
         {"token": "<|EOT|>"},
-        "\n\n"
+        "\n"
    ],
     stop_words=[
         "<|EOT|>"
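
For orientation, here is roughly what one `deepseekcoder` turn looks like after this change. This is a hand-assembled sketch, not the real `Template` encoding logic, which also handles multi-turn history and encodes `{"token": "<|EOT|>"}` as a single special token:

```python
# Hand-assembled sketch of a single deepseekcoder turn, following the fields above.
system = (
    "You are an AI programming assistant, utilizing the Deepseek Coder model, "
    "developed by Deepseek Company, and you only answer questions related to computer science. "
    "For politically sensitive questions, security and privacy issues, "
    "and other non-computer science questions, you will refuse to answer\n"
)
query = "Write a quicksort in Python."  # hypothetical user input
prompt = system + "### Instruction:\n{}\n### Response:\n".format(query)
print(prompt)  # the model's reply is expected to end with the <|EOT|> stop word
```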
|
src/llmtuner/extras/callbacks.py

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
 from datetime import timedelta

 from transformers import TrainerCallback
+from transformers.modeling_utils import custom_object_save, unwrap_model
 from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR

 from llmtuner.extras.constants import LOG_FILE_NAME
@@ -18,6 +19,16 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


+def _save_model_with_valuehead(model: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None:
+    model.pretrained_model.config.save_pretrained(output_dir)
+    if model.pretrained_model.can_generate():
+        model.pretrained_model.generation_config.save_pretrained(output_dir)
+    if getattr(model, "is_peft_model", False):
+        model.pretrained_model.save_pretrained(output_dir)
+    elif getattr(model.pretrained_model, "_auto_class", None):  # must not be a peft model
+        custom_object_save(model.pretrained_model, output_dir, config=model.pretrained_model.config)
+
+
 class SavePeftModelCallback(TrainerCallback):

     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
@@ -25,25 +36,17 @@ class SavePeftModelCallback(TrainerCallback):
         Event called after a checkpoint save.
         """
         if args.should_save:
-            output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
-            model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
-            model.pretrained_model.config.save_pretrained(output_dir)
-            if model.pretrained_model.can_generate():
-                model.pretrained_model.generation_config.save_pretrained(output_dir)
-            if getattr(model, "is_peft_model", False):
-                model.pretrained_model.save_pretrained(output_dir)
+            _save_model_with_valuehead(
+                model=unwrap_model(kwargs.pop("model")),
+                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
+            )

     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of training.
         """
         if args.should_save:
-            model: "AutoModelForCausalLMWithValueHead" = kwargs.pop("model")
-            model.pretrained_model.config.save_pretrained(args.output_dir)
-            if model.pretrained_model.can_generate():
-                model.pretrained_model.generation_config.save_pretrained(args.output_dir)
-            if getattr(model, "is_peft_model", False):
-                model.pretrained_model.save_pretrained(args.output_dir)
+            _save_model_with_valuehead(model=unwrap_model(kwargs.pop("model")), output_dir=args.output_dir)


 class LogCallback(TrainerCallback):
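
The two callback hooks now share one `_save_model_with_valuehead` helper, and the model is passed through `unwrap_model` first. That matters because under DDP-style wrapping the callback receives an outer module, and attributes such as `pretrained_model` live one level down on `.module`. A self-contained sketch of the unwrap behavior, with `torch.nn.DataParallel` as a stand-in wrapper:

```python
# Sketch: unwrap_model follows .module attributes down to the real model,
# so the helper above can reach pretrained_model even on wrapped models.
import torch
from transformers.modeling_utils import unwrap_model

inner = torch.nn.Linear(4, 4)
wrapped = torch.nn.DataParallel(inner)  # stand-in for a DDP/DeepSpeed wrapper
assert unwrap_model(wrapped) is inner   # callbacks must unwrap before saving
```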
|
src/llmtuner/extras/misc.py

@@ -69,11 +69,12 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:

 def get_current_device() -> str:
     import accelerate
-    dummy_accelerator = accelerate.Accelerator()
     if accelerate.utils.is_xpu_available():
-        return "xpu:{}".format(dummy_accelerator.local_process_index)
+        return "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+    elif accelerate.utils.is_npu_available() or torch.cuda.is_available():
+        return os.environ.get("LOCAL_RANK", "0")
     else:
-        return dummy_accelerator.local_process_index if torch.cuda.is_available() else "cpu"
+        return "cpu"


 def get_logits_processor() -> "LogitsProcessorList":
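
Dropping the throwaway `accelerate.Accelerator()` avoids initializing distributed state just to read a process index; the `LOCAL_RANK` variable that launchers such as torchrun set carries the same information. A quick sketch of the resulting behavior (assumes the llmtuner package is importable; the printed value depends on the hardware present):

```python
# Sketch: get_current_device now keys off LOCAL_RANK instead of building a
# throwaway Accelerator just to read local_process_index.
import os

os.environ["LOCAL_RANK"] = "1"  # normally set by torchrun, not by hand
from llmtuner.extras.misc import get_current_device

print(get_current_device())  # "xpu:1" on XPU, "1" with NPU/CUDA, "cpu" otherwise
```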
|
src/llmtuner/hparams/data_args.py

@@ -4,6 +4,9 @@ from typing import List, Literal, Optional
 from dataclasses import dataclass, field


+DATA_CONFIG = "dataset_info.json"
+
+
 @dataclass
 class DatasetAttr:

@@ -130,11 +133,11 @@ class DataArguments:
         self.seed = seed
         dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
         try:
-            with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
+            with open(os.path.join(self.dataset_dir, DATA_CONFIG), "r") as f:
                 dataset_info = json.load(f)
-        except Exception:
+        except Exception as err:
             if self.dataset is not None:
-                raise ValueError("Cannot find dataset_info.json in `dataset_dir`.")
+                raise ValueError("Cannot open {} due to {}.".format(os.path.join(self.dataset_dir, DATA_CONFIG), str(err)))
             dataset_info = None

         prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
@@ -147,7 +150,7 @@ class DataArguments:
         self.dataset_list: List[DatasetAttr] = []
         for i, name in enumerate(dataset_names):
             if name not in dataset_info:
-                raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
+                raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))

             if "hf_hub_url" in dataset_info[name]:
                 dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
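
A small quality-of-life change runs through all three hunks: the config filename lives in one `DATA_CONFIG` constant, and failures surface the real cause. A stdlib-only sketch of the new message format:

```python
# Sketch of the improved failure mode: the ValueError now carries the exact
# path plus the underlying OS error, instead of a generic "cannot find" hint.
import os

DATA_CONFIG = "dataset_info.json"
path = os.path.join("no_such_dir", DATA_CONFIG)  # hypothetical missing directory
try:
    with open(path, "r") as f:
        pass
except Exception as err:
    print("Cannot open {} due to {}.".format(path, str(err)))
# Cannot open no_such_dir/dataset_info.json due to [Errno 2] No such file or directory: ...
```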
|
src/llmtuner/tuner/ppo/trainer.py

@@ -6,7 +6,9 @@ from tqdm import tqdm
 from typing import TYPE_CHECKING, List, Optional, Tuple

 from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
+from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.trainer_pt_utils import remove_dummy_checkpoint

 from trl import PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
@@ -55,6 +57,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):

         self.state = TrainerState()
         self.control = TrainerControl()
+        self.is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
+            self.accelerator.state, "deepspeed_plugin"
+        )
         self.log_callback, self.save_callback = callbacks[0], callbacks[1]
         assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)

@@ -62,10 +67,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             logger.info("max_steps is given, it will override any value given in num_train_epochs")

         if reward_model is not None:
-            is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
-                self.accelerator.state, "deepspeed_plugin"
-            )
-            if is_deepspeed_enabled:
+            if self.is_deepspeed_enabled:
                 if not (
                     getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
                     or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
@@ -298,7 +300,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype):  # support bf16
             logits, _, values = model(**input_kwargs)

-        if getattr(model.config, "model_type", None) == "chatglm":
+        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
+        if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
             values = torch.transpose(values, 0, 1)

         logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
@@ -344,4 +347,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         Subclass and override to inject custom behavior.
         """
         if self.args.should_save:
-            self._save(output_dir)
+            try:
+                self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model))
+            except ValueError:
+                logger.warning(
+                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
+                    " zero_to_fp32.py to recover weights"
+                )
+                self._save(output_dir, state_dict={})
+                remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
+                self.model.save_checkpoint(output_dir)  # wrapped model
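
The new save path first tries to gather a consolidated state dict, and falls back to a DeepSpeed-native checkpoint when gathering fails, at which point `zero_to_fp32.py` can reconstruct full weights offline. Whether gathering is possible under ZeRO-3 is governed by a real DeepSpeed option; for reference, a config fragment expressed as a Python dict (a form `TrainingArguments(deepspeed=...)` also accepts):

```python
# ZeRO-3 fragment: with the flag True, a consolidated 16-bit state dict can be
# gathered at save time; with False, the trainer above falls back to
# save_checkpoint() plus offline zero_to_fp32.py recovery.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "stage3_gather_16bit_weights_on_model_save": True,
    }
}
```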
|
src/llmtuner/tuner/rm/trainer.py

@@ -40,7 +40,8 @@ class PairwiseTrainer(Trainer):
         # Compute rewards
         _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)

-        if getattr(model.config, "model_type", None) == "chatglm":
+        unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
+        if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
             values = torch.transpose(values, 0, 1)

         # Split the inputs and rewards into two parts, chosen and rejected
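
The same fix lands in both the PPO and pairwise trainers: the `model_type` probe now goes through `self.accelerator.unwrap_model(self.model)`, so it still finds the config when the model is wrapped by DDP or DeepSpeed. The transpose itself is untouched; it appears to compensate for ChatGLM emitting values in (seq_len, batch) order. A toy illustration:

```python
# Toy illustration: a (seq_len, batch) value tensor is transposed back into
# the (batch, seq_len) layout the reward/PPO loss code expects.
import torch

values = torch.zeros(512, 4)           # (seq_len, batch), chatglm-style
values = torch.transpose(values, 0, 1)
print(values.shape)                    # torch.Size([4, 512])
```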
|
src/llmtuner/webui/common.py

@@ -11,14 +11,22 @@ from transformers.utils import (
     ADAPTER_SAFE_WEIGHTS_NAME
 )

-from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, SUPPORTED_MODELS, ALL_OFFICIAL_MODELS, TRAINING_STAGES
+from llmtuner.extras.constants import (
+    DEFAULT_MODULE,
+    DEFAULT_TEMPLATE,
+    SUPPORTED_MODELS,
+    ALL_OFFICIAL_MODELS,
+    TRAINING_STAGES
+)
+from llmtuner.hparams.data_args import DATA_CONFIG


 DEFAULT_CACHE_DIR = "cache"
 DEFAULT_DATA_DIR = "data"
 DEFAULT_SAVE_DIR = "saves"
 USER_CONFIG = "user.config"
-DATA_CONFIG = "dataset_info.json"
 CKPT_NAMES = [
     WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
@@ -92,12 +100,12 @@ def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
     return gr.update(value=[], choices=checkpoints)


-def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
+def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
     try:
         with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
             return json.load(f)
-    except:
-        print("Cannot find {} in {}.".format(DATA_CONFIG, dataset_dir))
+    except Exception as err:
+        print("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(err)))
         return {}

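
With the return annotation tightened to `Dict[str, Dict[str, Any]]`, callers can rely on every value being an attribute mapping, and an unreadable file degrades to an empty dict instead of crashing the UI. A usage sketch, assuming this file is src/llmtuner/webui/common.py:

```python
# Usage sketch: every value returned is itself a mapping of dataset attributes.
from llmtuner.webui.common import load_dataset_info

dataset_info = load_dataset_info("data")  # {} if dataset_info.json is unreadable
for name, attr in dataset_info.items():
    print(name, attr.get("hf_hub_url") or attr.get("file_name"))
```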
|
src/llmtuner/webui/components/eval.py

@@ -38,10 +38,11 @@ def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
         max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
         top_p = gr.Slider(0.01, 1, value=0.7, step=0.01)
         temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
+        output_dir = gr.Textbox()

-    input_elems.update({max_new_tokens, top_p, temperature})
+    input_elems.update({max_new_tokens, top_p, temperature, output_dir})
     elem_dict.update(dict(
-        max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+        max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature, output_dir=output_dir
     ))

     with gr.Row():
|
src/llmtuner/webui/engine.py

@@ -49,7 +49,10 @@ class Engine:
             else:
                 yield self._form_dict({"eval.resume_btn": {"value": True}})
         else:
-            yield self._form_dict({"train.output_dir": {"value": get_time()}})
+            yield self._form_dict({
+                "train.output_dir": {"value": "train_" + get_time()},
+                "eval.output_dir": {"value": "eval_" + get_time()},
+            })

     def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
         return {
|
src/llmtuner/webui/locales.py

@@ -132,7 +132,7 @@ LOCALES = {
     "dataset_dir": {
         "en": {
             "label": "Data dir",
-            "info": "Path of the data directory."
+            "info": "Path to the data directory."
         },
         "zh": {
             "label": "数据路径",
@@ -475,12 +475,12 @@ LOCALES = {
     },
     "output_dir": {
         "en": {
-            "label": "Checkpoint name",
-            "info": "Directory to save checkpoint."
+            "label": "Output dir",
+            "info": "Directory for saving results."
         },
         "zh": {
-            "label": "断点名称",
-            "info": "保存模型断点的文件夹名称。"
+            "label": "输出目录",
+            "info": "保存结果的路径。"
         }
     },
     "output_box": {
|
src/llmtuner/webui/runner.py

@@ -87,9 +87,9 @@ class Runner:
         user_config = load_config()

         if get("top.checkpoints"):
-            checkpoint_dir = ",".join([get_save_dir(
-                get("top.model_name"), get("top.finetuning_type"), ckpt
-            ) for ckpt in get("top.checkpoints")])
+            checkpoint_dir = ",".join([
+                get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
+            ])
         else:
             checkpoint_dir = None

@@ -160,15 +160,11 @@ class Runner:
         user_config = load_config()

         if get("top.checkpoints"):
-            checkpoint_dir = ",".join([get_save_dir(
-                get("top.model_name"), get("top.finetuning_type"), ckpt
-            ) for ckpt in get("top.checkpoints")])
-            output_dir = get_save_dir(
-                get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
-            )
+            checkpoint_dir = ",".join([
+                get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
+            ])
         else:
             checkpoint_dir = None
-            output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), "eval_base")

         args = dict(
             stage="sft",
@@ -192,7 +188,7 @@ class Runner:
             max_new_tokens=get("eval.max_new_tokens"),
             top_p=get("eval.top_p"),
             temperature=get("eval.temperature"),
-            output_dir=output_dir
+            output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("eval.output_dir"))
         )

         if get("eval.predict"):
@@ -242,6 +238,7 @@ class Runner:
         output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get(
             "{}.output_dir".format("train" if self.do_train else "eval")
         ))
+
         while self.thread.is_alive():
             time.sleep(2)
             if self.aborted:
src/llmtuner/webui/utils.py

@@ -44,7 +44,8 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
 def gen_cmd(args: Dict[str, Any]) -> str:
     args.pop("disable_tqdm", None)
     args["plot_loss"] = args.get("do_train", None)
-    cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python src/train_bash.py "]
+    current_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
+    cmd_lines = ["CUDA_VISIBLE_DEVICES={} python src/train_bash.py ".format(current_devices)]
     for k, v in args.items():
         if v is not None and v != "":
             cmd_lines.append(" --{} {} ".format(k, str(v)))
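
`gen_cmd` now echoes whatever `CUDA_VISIBLE_DEVICES` the WebUI process itself was launched with, instead of hard-coding device 0. A sketch with a hypothetical argument dict; `quantization_bit` is the knob behind the QLoRA rows in the README tables above:

```python
# Sketch: regenerating the preview command with a hypothetical args dict.
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # as if the server was started this way
from llmtuner.webui.utils import gen_cmd

print(gen_cmd({"stage": "sft", "do_train": True, "quantization_bit": 4}))
# roughly: CUDA_VISIBLE_DEVICES=0,1 python src/train_bash.py --stage sft
#          --do_train True --quantization_bit 4 --plot_loss True
```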