support prefixes, loading multiple local files

Former-commit-id: cec9760eb890d37b733d8da73d0f3dbf924ca4ef
This commit is contained in:
hiyouga 2023-06-26 15:32:40 +08:00
parent e4a869dc42
commit 8f1d99c926
3 changed files with 84 additions and 39 deletions

View File

@ -34,6 +34,21 @@ async def lifespan(app: FastAPI): # collects GPU memory
app = FastAPI(lifespan=lifespan)
class ModelCard(BaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "owner"
root: Optional[str] = None
parent: Optional[str] = None
permission: Optional[list] = None
class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = []
class ChatMessage(BaseModel):
role: Literal["user", "assistant", "system"]
content: str
@ -73,6 +88,13 @@ class ChatCompletionResponse(BaseModel):
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
@app.get("/v1/models", response_model=ModelList)
async def list_models():
global model_args
model_card = ModelCard(id="gpt-3.5-turbo")
return ModelList(data=[model_card])
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
global model, tokenizer, source_prefix, generating_args

View File

@ -344,6 +344,13 @@ def prepare_data(
if sha1 != hash:
logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
ext2type = {
"csv": "csv",
"json": "json",
"jsonl": "json",
"txt": "text"
}
max_samples = data_args.max_samples
all_datasets: List[Dataset] = [] # support multiple datasets
@ -352,37 +359,44 @@ def prepare_data(
logger.info("Loading dataset {}...".format(dataset_attr))
if dataset_attr.load_from == "hf_hub":
raw_datasets = load_dataset(dataset_attr.dataset_name, cache_dir=model_args.cache_dir)
data_path = dataset_attr.dataset_name
data_files = None
elif dataset_attr.load_from == "script":
raw_datasets = load_dataset(
os.path.join(data_args.dataset_dir, dataset_attr.dataset_name),
cache_dir=model_args.cache_dir
)
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
data_files = None
elif dataset_attr.load_from == "file":
data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name)
data_path = None
data_files: List[str] = []
extension = dataset_attr.file_name.split(".")[-1]
if extension == "csv":
file_type = "csv"
elif extension == "json" or extension == "jsonl":
file_type = "json"
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
if data_path is None:
data_path = ext2type.get(data_files[0].split(".")[-1], None)
else:
file_type = "text"
if dataset_attr.file_sha1 is not None:
checksum(data_file, dataset_attr.file_sha1)
assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
data_path = ext2type.get(data_files[0].split(".")[-1], None)
else:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
raise ValueError("File not found.")
raw_datasets = load_dataset(
file_type,
data_files=data_file,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None
)
assert data_path, "File extension must be txt, csv, json or jsonl."
if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
checksum(data_files[0], dataset_attr.dataset_sha1)
else:
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
else:
raise NotImplementedError
raw_datasets = load_dataset(
data_path,
data_files=data_files,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None
)
dataset = raw_datasets[data_args.split]
if max_samples is not None:
@ -390,6 +404,7 @@ def prepare_data(
dataset = dataset.select(range(max_samples_temp))
dummy_data = [None] * len(dataset)
prefix_data = [dataset_attr.source_prefix] * len(dataset)
for column_name, target_name in [
("prompt_column", "prompt"),
("query_column", "query"),
@ -401,6 +416,7 @@ def prepare_data(
dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
else: # None or empty string
dataset = dataset.add_column(target_name, dummy_data)
dataset = dataset.add_column("prefix", prefix_data)
all_datasets.append(dataset)
if len(data_args.dataset_list) == 1:
@ -420,7 +436,6 @@ def preprocess_data(
) -> Dataset:
column_names = list(dataset.column_names)
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
prompt_template = Template(data_args.prompt_template)
# support question with a single answer or multiple answers
@ -429,6 +444,7 @@ def preprocess_data(
if examples["prompt"][i] and examples["response"][i]:
query, answer = examples["prompt"][i], examples["response"][i]
query = query + "\n" + examples["query"][i] if examples["query"][i] else query
prefix = examples["prefix"][i] if examples["prefix"][i] else ""
dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
yield dialog
@ -513,21 +529,22 @@ def preprocess_data(
def print_supervised_dataset_example(example):
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
print("label_ids:\n{}".format(example["labels"]))
print("labels:\n{}".format(
tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]]))
)
tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
skip_special_tokens=False)
))
def print_pairwise_dataset_example(example):
print("accept_ids:\n{}".format(example["accept_ids"]))
print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"])))
print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
print("reject_ids:\n{}".format(example["reject_ids"]))
print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"])))
print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))
def print_unsupervised_dataset_example(example):
print("input_ids:\n{}".format(example["input_ids"]))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
if stage == "pt":
preprocess_function = preprocess_pretrain_dataset

View File

@ -10,14 +10,11 @@ class DatasetAttr:
load_from: str
dataset_name: Optional[str] = None
file_name: Optional[str] = None
file_sha1: Optional[str] = None
dataset_sha1: Optional[str] = None
source_prefix: Optional[str] = None
def __repr__(self) -> str:
if self.dataset_name is not None:
return self.dataset_name
else:
return self.file_name
def __post_init__(self):
self.prompt_column = "instruction"
@ -137,7 +134,7 @@ class DataTrainingArguments:
)
source_prefix: Optional[str] = field(
default=None,
metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes."}
)
dev_ratio: Optional[float] = field(
default=0,
@ -153,8 +150,15 @@ class DataTrainingArguments:
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
dataset_info = json.load(f)
if self.source_prefix is not None:
prefix_list = self.source_prefix.split("|")
prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
else:
prefix_list = [None] * len(dataset_names)
self.dataset_list: List[DatasetAttr] = []
for name in dataset_names:
for i, name in enumerate(dataset_names):
if name not in dataset_info:
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
@ -165,10 +169,12 @@ class DataTrainingArguments:
else:
dataset_attr = DatasetAttr(
"file",
file_name=dataset_info[name]["file_name"],
file_sha1=dataset_info[name].get("file_sha1", None)
dataset_name=dataset_info[name]["file_name"],
dataset_sha1=dataset_info[name].get("file_sha1", None)
)
dataset_attr.source_prefix = prefix_list[i]
if "columns" in dataset_info[name]:
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)