mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 11:50:35 +08:00
refactor dataset_attr, add eos in pt, fix #757
This commit is contained in:
@@ -3,10 +3,9 @@ import json
|
||||
import gradio as gr
|
||||
import matplotlib.figure
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from llmtuner.dsets.utils import EXT2TYPE
|
||||
from llmtuner.extras.ploting import smooth
|
||||
from llmtuner.tuner import export_model
|
||||
from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
|
||||
@@ -37,6 +36,7 @@ def get_time() -> str:
|
||||
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
|
||||
if (
|
||||
len(dataset) > 0
|
||||
and "file_name" in dataset_info[dataset[0]]
|
||||
@@ -47,25 +47,26 @@ def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
|
||||
return gr.update(interactive=False)
|
||||
|
||||
|
||||
def get_preview(
    dataset_dir: str, dataset: list, start: Optional[int] = 0, end: Optional[int] = 2
) -> Tuple[int, list, Dict[str, Any]]:
    """Load the first selected dataset and return a slice of it for preview.

    The dataset registry (``DATA_CONFIG`` inside ``dataset_dir``) is read to
    find the ``file_name`` of ``dataset[0]``; that file is then parsed
    according to its extension:

    * ``.json``  - the whole file is parsed as one JSON document.
    * ``.jsonl`` - every non-blank line is parsed as its own JSON document.
    * otherwise  - the raw lines are returned unchanged (newline included).

    Args:
        dataset_dir: Directory holding ``DATA_CONFIG`` and the data files.
        dataset: Selected dataset names; only the first entry is previewed.
        start: Start index of the preview slice.
        end: End index (exclusive) of the preview slice.

    Returns:
        A tuple of (total number of loaded records, ``data[start:end]``,
        ``gr.update(visible=True)``).
    """
    with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
        dataset_info = json.load(f)

    data_file: str = dataset_info[dataset[0]]["file_name"]
    with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
        if data_file.endswith(".json"):
            data = json.load(f)
        elif data_file.endswith(".jsonl"):
            # Each line is a str, so json.loads() is required here; json.load()
            # expects a file-like object and fails on a str. Blank lines
            # (e.g. a trailing newline) are skipped instead of failing to parse.
            data = [json.loads(line) for line in f if line.strip()]
        else:
            data = list(f)
    return len(data), data[start:end], gr.update(visible=True)
|
||||
|
||||
|
||||
def can_quantize(finetuning_type: str) -> Dict[str, Any]:
    """Build the ``gr.update`` payload that depends on the fine-tuning type.

    For ``"lora"`` the target component is made interactive. For any other
    fine-tuning type its value is set to the string ``"None"`` and it is made
    non-interactive.

    NOTE(review): presumably wired to the quantization selector in the web
    UI - confirm against the caller.
    """
    if finetuning_type == "lora":
        return gr.update(interactive=True)

    return gr.update(value="None", interactive=False)
|
||||
|
||||
@@ -73,7 +74,7 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
|
||||
def gen_cmd(args: Dict[str, Any]) -> str:
|
||||
if args.get("do_train", None):
|
||||
args["plot_loss"] = True
|
||||
cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "]
|
||||
cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python src/train_bash.py"]
|
||||
for k, v in args.items():
|
||||
if v is not None and v != "":
|
||||
cmd_lines.append(" --{} {} ".format(k, str(v)))
|
||||
|
||||
Reference in New Issue
Block a user