refactor dataset_attr, add eos in pt, fix #757

Former-commit-id: a9d1fb72f7
This commit is contained in:
hiyouga
2023-09-01 19:00:45 +08:00
parent 173022473d
commit a4fd976048
20 changed files with 160 additions and 198 deletions

View File

@@ -6,7 +6,7 @@ import gradio as gr
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, DATASET_STAGE_MAP
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES
DEFAULT_CACHE_DIR = "cache"
@@ -78,11 +78,10 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
return {}
def list_dataset(dataset_dir: Optional[str] = None, stage: Optional[str] = None) -> Dict[str, Any]:
def list_dataset(
dataset_dir: Optional[str] = None, training_stage: Optional[str] = list(TRAINING_STAGES.keys())[0]
) -> Dict[str, Any]:
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
if stage:
dataset_stage = DATASET_STAGE_MAP[stage]
dataset_info = {key: value for key, value in dataset_info.items()
if ("stage" not in value) or value["stage"] == dataset_stage}
return gr.update(value=[], choices=list(dataset_info.keys()))
ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"]
datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
return gr.update(value=[], choices=datasets)

View File

@@ -3,7 +3,7 @@ from transformers.trainer_utils import SchedulerType
import gradio as gr
from llmtuner.extras.constants import STAGES
from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
@@ -15,7 +15,9 @@ if TYPE_CHECKING:
def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
with gr.Row():
training_stage = gr.Dropdown(choices=STAGES, value=STAGES[0], scale=2)
training_stage = gr.Dropdown(
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=2
)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
dataset = gr.Dropdown(multiselect=True, scale=4)
data_preview_btn = gr.Button(interactive=False, scale=1)

View File

@@ -8,7 +8,7 @@ from transformers.trainer import TRAINING_ARGS_NAME
from typing import Any, Dict, Generator, List, Tuple
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import DEFAULT_MODULE
from llmtuner.extras.constants import DEFAULT_MODULE, TRAINING_STAGES
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import torch_gc
from llmtuner.tuner import run_exp
@@ -106,7 +106,7 @@ class Runner:
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
args = dict(
stage="sft",
stage=TRAINING_STAGES[training_stage],
model_name_or_path=get_model_path(model_name),
do_train=True,
overwrite_cache=True,
@@ -133,26 +133,20 @@ class Runner:
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
resume_lora_training=resume_lora_training,
resume_lora_training=(
False if TRAINING_STAGES[training_stage] in ["rm", "ppo", "dpo"] else resume_lora_training
),
output_dir=output_dir
)
args[compute_type] = True
if training_stage == "Reward Modeling":
args["stage"] = "rm"
args["resume_lora_training"] = False
elif training_stage == "PPO":
args["stage"] = "ppo"
args["resume_lora_training"] = False
if args["stage"] == "ppo":
args["reward_model"] = reward_model
args["padding_side"] = "left"
val_size = 0
elif training_stage == "DPO":
args["stage"] = "dpo"
args["resume_lora_training"] = False
if args["stage"] == "dpo":
args["dpo_beta"] = dpo_beta
elif training_stage == "Pre-Training":
args["stage"] = "pt"
if val_size > 1e-6:
args["val_size"] = val_size

View File

@@ -3,10 +3,9 @@ import json
import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from datetime import datetime
from llmtuner.dsets.utils import EXT2TYPE
from llmtuner.extras.ploting import smooth
from llmtuner.tuner import export_model
from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
@@ -37,6 +36,7 @@ def get_time() -> str:
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
if (
len(dataset) > 0
and "file_name" in dataset_info[dataset[0]]
@@ -47,25 +47,26 @@ def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
return gr.update(interactive=False)
def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, Dict[str, Any]]:
def get_preview(
dataset_dir: str, dataset: list, start: Optional[int] = 0, end: Optional[int] = 2
) -> Tuple[int, list, Dict[str, Any]]:
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
dataset_info = json.load(f)
data_file = dataset_info[dataset[0]]["file_name"]
data = []
data_format = EXT2TYPE.get(data_file.split(".")[-1], None)
if data_format == "text":
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
for line in f:
data.append(line.strip())
elif data_format == "json":
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
data_file: str = dataset_info[dataset[0]]["file_name"]
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
if data_file.endswith(".json"):
data = json.load(f)
return len(data), data[:2], gr.update(visible=True)
elif data_file.endswith(".jsonl"):
data = [json.load(line) for line in f]
else:
data = [line for line in f]
return len(data), data[start:end], gr.update(visible=True)
def can_quantize(finetuning_type: str) -> Dict[str, Any]:
if finetuning_type != "lora":
return gr.update(value="", interactive=False)
return gr.update(value="None", interactive=False)
else:
return gr.update(interactive=True)
@@ -73,7 +74,7 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
def gen_cmd(args: Dict[str, Any]) -> str:
if args.get("do_train", None):
args["plot_loss"] = True
cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "]
cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python src/train_bash.py"]
for k, v in args.items():
if v is not None and v != "":
cmd_lines.append(" --{} {} ".format(k, str(v)))