mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00

support prefixes, loading multiple local files

Former-commit-id: cec9760eb890d37b733d8da73d0f3dbf924ca4ef

This commit is contained in:
  parent e4a869dc42
  commit 8f1d99c926
@@ -34,6 +34,21 @@ async def lifespan(app: FastAPI): # collects GPU memory
 app = FastAPI(lifespan=lifespan)


+class ModelCard(BaseModel):
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: str = "owner"
+    root: Optional[str] = None
+    parent: Optional[str] = None
+    permission: Optional[list] = None
+
+
+class ModelList(BaseModel):
+    object: str = "list"
+    data: List[ModelCard] = []
+
+
 class ChatMessage(BaseModel):
     role: Literal["user", "assistant", "system"]
     content: str
@@ -73,6 +88,13 @@ class ChatCompletionResponse(BaseModel):
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))


+@app.get("/v1/models", response_model=ModelList)
+async def list_models():
+    global model_args
+    model_card = ModelCard(id="gpt-3.5-turbo")
+    return ModelList(data=[model_card])
+
+
 @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
 async def create_chat_completion(request: ChatCompletionRequest):
     global model, tokenizer, source_prefix, generating_args
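
Note (illustration, not part of this commit): the two hunks above add an OpenAI-style /v1/models endpoint backed by the new ModelCard/ModelList schemas. A minimal sketch of querying it, assuming the demo server is running locally on port 8000 (host and port are assumptions, not taken from the diff):

# toy client for the new endpoint (illustration only)
import requests

resp = requests.get("http://localhost:8000/v1/models")
print(resp.json())
# Expected shape, following the ModelList/ModelCard definitions above:
# {"object": "list", "data": [{"id": "gpt-3.5-turbo", "object": "model", ...}]}
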
@@ -344,6 +344,13 @@ def prepare_data(
         if sha1 != hash:
             logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))


+    ext2type = {
+        "csv": "csv",
+        "json": "json",
+        "jsonl": "json",
+        "txt": "text"
+    }
+
     max_samples = data_args.max_samples
     all_datasets: List[Dataset] = [] # support multiple datasets
@@ -352,37 +359,44 @@ def prepare_data(
         logger.info("Loading dataset {}...".format(dataset_attr))

         if dataset_attr.load_from == "hf_hub":
-            raw_datasets = load_dataset(dataset_attr.dataset_name, cache_dir=model_args.cache_dir)
+            data_path = dataset_attr.dataset_name
+            data_files = None
         elif dataset_attr.load_from == "script":
-            raw_datasets = load_dataset(
-                os.path.join(data_args.dataset_dir, dataset_attr.dataset_name),
-                cache_dir=model_args.cache_dir
-            )
+            data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
+            data_files = None
         elif dataset_attr.load_from == "file":
-            data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name)
-
-            extension = dataset_attr.file_name.split(".")[-1]
-            if extension == "csv":
-                file_type = "csv"
-            elif extension == "json" or extension == "jsonl":
-                file_type = "json"
+            data_path = None
+            data_files: List[str] = []
+
+            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
+                for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
+                    data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
+                    if data_path is None:
+                        data_path = ext2type.get(data_files[0].split(".")[-1], None)
+                    else:
+                        assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
+            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
+                data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
+                data_path = ext2type.get(data_files[0].split(".")[-1], None)
             else:
-                file_type = "text"
+                raise ValueError("File not found.")

-            if dataset_attr.file_sha1 is not None:
-                checksum(data_file, dataset_attr.file_sha1)
+            assert data_path, "File extension must be txt, csv, json or jsonl."
+
+            if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
+                checksum(data_files[0], dataset_attr.dataset_sha1)
             else:
-                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
+                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")

-            raw_datasets = load_dataset(
-                file_type,
-                data_files=data_file,
-                cache_dir=model_args.cache_dir,
-                use_auth_token=True if model_args.use_auth_token else None
-            )
         else:
             raise NotImplementedError

+        raw_datasets = load_dataset(
+            data_path,
+            data_files=data_files,
+            cache_dir=model_args.cache_dir,
+            use_auth_token=True if model_args.use_auth_token else None
+        )
         dataset = raw_datasets[data_args.split]

         if max_samples is not None:
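
Note (illustration, not part of this commit): the "file" branch now accepts either a single local file or a whole directory, and derives the load_dataset builder name from the file extension via ext2type, requiring every file in a directory to share one type. A self-contained sketch of that resolution step; the helper name and example paths are hypothetical:

import os
from typing import List, Optional, Tuple

EXT2TYPE = {"csv": "csv", "json": "json", "jsonl": "json", "txt": "text"}

def resolve_local_dataset(path: str) -> Tuple[Optional[str], List[str]]:
    data_path: Optional[str] = None  # builder name passed to datasets.load_dataset
    data_files: List[str] = []       # concrete files to load
    if os.path.isdir(path):          # a directory: collect every file, types must match
        for file_name in os.listdir(path):
            data_files.append(os.path.join(path, file_name))
            file_type = EXT2TYPE.get(file_name.split(".")[-1], None)
            if data_path is None:
                data_path = file_type
            else:
                assert data_path == file_type, "file type does not match."
    elif os.path.isfile(path):       # a single file
        data_files.append(path)
        data_path = EXT2TYPE.get(path.split(".")[-1], None)
    else:
        raise ValueError("File not found.")
    assert data_path, "File extension must be txt, csv, json or jsonl."
    return data_path, data_files

# e.g. resolve_local_dataset("data/my_corpus") might return ("json", ["data/my_corpus/a.json", ...]),
# which then feeds load_dataset(data_path, data_files=data_files, ...) as in the hunk above.
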
@@ -390,6 +404,7 @@ def prepare_data(
             dataset = dataset.select(range(max_samples_temp))

         dummy_data = [None] * len(dataset)
+        prefix_data = [dataset_attr.source_prefix] * len(dataset)
         for column_name, target_name in [
             ("prompt_column", "prompt"),
             ("query_column", "query"),
@@ -401,6 +416,7 @@ def prepare_data(
                 dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
             else: # None or empty string
                 dataset = dataset.add_column(target_name, dummy_data)
+        dataset = dataset.add_column("prefix", prefix_data)
         all_datasets.append(dataset)

         if len(data_args.dataset_list) == 1:
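
Note (illustration, not part of this commit): each dataset's prefix is broadcast into a new "prefix" column so it stays attached to its own examples after multiple datasets are concatenated. A toy run of the same datasets API calls; the dataset contents and prefix string are made up:

from datasets import Dataset

toy = Dataset.from_dict({"prompt": ["hi", "hello"], "response": ["a", "b"]})
prefix_data = ["You are a helpful assistant."] * len(toy)  # plays the role of dataset_attr.source_prefix
toy = toy.add_column("prefix", prefix_data)
print(toy.column_names)   # ['prompt', 'response', 'prefix']
print(toy[0]["prefix"])   # 'You are a helpful assistant.'
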
@@ -420,7 +436,6 @@ def preprocess_data(
 ) -> Dataset:

     column_names = list(dataset.column_names)
-    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
     prompt_template = Template(data_args.prompt_template)

     # support question with a single answer or multiple answers
@@ -429,6 +444,7 @@ def preprocess_data(
             if examples["prompt"][i] and examples["response"][i]:
                 query, answer = examples["prompt"][i], examples["response"][i]
                 query = query + "\n" + examples["query"][i] if examples["query"][i] else query
+                prefix = examples["prefix"][i] if examples["prefix"][i] else ""
                 dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
                 yield dialog

@@ -513,21 +529,22 @@ def preprocess_data(

     def print_supervised_dataset_example(example):
         print("input_ids:\n{}".format(example["input_ids"]))
-        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
+        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
         print("label_ids:\n{}".format(example["labels"]))
         print("labels:\n{}".format(
-            tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]]))
-        )
+            tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
+                             skip_special_tokens=False)
+        ))

     def print_pairwise_dataset_example(example):
         print("accept_ids:\n{}".format(example["accept_ids"]))
-        print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"])))
+        print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
         print("reject_ids:\n{}".format(example["reject_ids"]))
-        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"])))
+        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))

     def print_unsupervised_dataset_example(example):
         print("input_ids:\n{}".format(example["input_ids"]))
-        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
+        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

     if stage == "pt":
         preprocess_function = preprocess_pretrain_dataset
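
Note (illustration, not part of this commit): passing skip_special_tokens=False keeps special tokens visible in these debug printouts, which makes it easy to check where BOS/EOS and padding actually land. A standalone example with an arbitrary Hugging Face tokenizer (gpt2 is used here only because it has an EOS token; it is not the model used by the repo):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
ids = tok.encode("Hello") + [tok.eos_token_id]
print(tok.decode(ids, skip_special_tokens=True))   # 'Hello'
print(tok.decode(ids, skip_special_tokens=False))  # 'Hello<|endoftext|>'
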
@@ -10,14 +10,11 @@ class DatasetAttr:

     load_from: str
     dataset_name: Optional[str] = None
-    file_name: Optional[str] = None
-    file_sha1: Optional[str] = None
+    dataset_sha1: Optional[str] = None
+    source_prefix: Optional[str] = None

     def __repr__(self) -> str:
-        if self.dataset_name is not None:
-            return self.dataset_name
-        else:
-            return self.file_name
+        return self.dataset_name

     def __post_init__(self):
         self.prompt_column = "instruction"
@@ -137,7 +134,7 @@ class DataTrainingArguments:
     )
     source_prefix: Optional[str] = field(
         default=None,
-        metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
+        metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes."}
     )
     dev_ratio: Optional[float] = field(
         default=0,
@@ -153,8 +150,15 @@ class DataTrainingArguments:
         with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
             dataset_info = json.load(f)

+        if self.source_prefix is not None:
+            prefix_list = self.source_prefix.split("|")
+            prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
+            assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
+        else:
+            prefix_list = [None] * len(dataset_names)
+
         self.dataset_list: List[DatasetAttr] = []
-        for name in dataset_names:
+        for i, name in enumerate(dataset_names):
             if name not in dataset_info:
                 raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))

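
Note (illustration, not part of this commit): source_prefix can now carry one prefix per dataset, separated by "|". A single prefix is broadcast to every dataset; any other mismatch between the number of prefixes and datasets is rejected. The rule in isolation (the helper name and dataset names are made up):

from typing import List, Optional

def map_prefixes(source_prefix: Optional[str], dataset_names: List[str]) -> List[Optional[str]]:
    if source_prefix is not None:
        prefix_list = source_prefix.split("|")
        prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
        assert len(prefix_list) == len(dataset_names), \
            "The number of prefixes should be either identical with datasets or 1."
    else:
        prefix_list = [None] * len(dataset_names)
    return prefix_list

print(map_prefixes("A|B", ["dataset_one", "dataset_two"]))  # ['A', 'B']
print(map_prefixes("A", ["dataset_one", "dataset_two"]))    # ['A', 'A']
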
@@ -165,10 +169,12 @@ class DataTrainingArguments:
             else:
                 dataset_attr = DatasetAttr(
                     "file",
-                    file_name=dataset_info[name]["file_name"],
-                    file_sha1=dataset_info[name].get("file_sha1", None)
+                    dataset_name=dataset_info[name]["file_name"],
+                    dataset_sha1=dataset_info[name].get("file_sha1", None)
                 )

+            dataset_attr.source_prefix = prefix_list[i]
+
             if "columns" in dataset_info[name]:
                 dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
                 dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)