mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 06:12:50 +08:00
support prefixes, loading multiple local files
Former-commit-id: cec9760eb890d37b733d8da73d0f3dbf924ca4ef
This commit is contained in:
parent
e4a869dc42
commit
8f1d99c926
@ -34,6 +34,21 @@ async def lifespan(app: FastAPI): # collects GPU memory
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
class ModelCard(BaseModel):
|
||||
id: str
|
||||
object: str = "model"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: str = "owner"
|
||||
root: Optional[str] = None
|
||||
parent: Optional[str] = None
|
||||
permission: Optional[list] = None
|
||||
|
||||
|
||||
class ModelList(BaseModel):
|
||||
object: str = "list"
|
||||
data: List[ModelCard] = []
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Literal["user", "assistant", "system"]
|
||||
content: str
|
||||
@ -73,6 +88,13 @@ class ChatCompletionResponse(BaseModel):
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
|
||||
|
||||
@app.get("/v1/models", response_model=ModelList)
|
||||
async def list_models():
|
||||
global model_args
|
||||
model_card = ModelCard(id="gpt-3.5-turbo")
|
||||
return ModelList(data=[model_card])
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
||||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
global model, tokenizer, source_prefix, generating_args
|
||||
|
@ -344,6 +344,13 @@ def prepare_data(
|
||||
if sha1 != hash:
|
||||
logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
|
||||
|
||||
ext2type = {
|
||||
"csv": "csv",
|
||||
"json": "json",
|
||||
"jsonl": "json",
|
||||
"txt": "text"
|
||||
}
|
||||
|
||||
max_samples = data_args.max_samples
|
||||
all_datasets: List[Dataset] = [] # support multiple datasets
|
||||
|
||||
@ -352,37 +359,44 @@ def prepare_data(
|
||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||
|
||||
if dataset_attr.load_from == "hf_hub":
|
||||
raw_datasets = load_dataset(dataset_attr.dataset_name, cache_dir=model_args.cache_dir)
|
||||
data_path = dataset_attr.dataset_name
|
||||
data_files = None
|
||||
elif dataset_attr.load_from == "script":
|
||||
raw_datasets = load_dataset(
|
||||
os.path.join(data_args.dataset_dir, dataset_attr.dataset_name),
|
||||
cache_dir=model_args.cache_dir
|
||||
)
|
||||
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
||||
data_files = None
|
||||
elif dataset_attr.load_from == "file":
|
||||
data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name)
|
||||
data_path = None
|
||||
data_files: List[str] = []
|
||||
|
||||
extension = dataset_attr.file_name.split(".")[-1]
|
||||
if extension == "csv":
|
||||
file_type = "csv"
|
||||
elif extension == "json" or extension == "jsonl":
|
||||
file_type = "json"
|
||||
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
|
||||
for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
|
||||
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
|
||||
|
||||
if data_path is None:
|
||||
data_path = ext2type.get(data_files[0].split(".")[-1], None)
|
||||
else:
|
||||
assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
|
||||
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
|
||||
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
|
||||
data_path = ext2type.get(data_files[0].split(".")[-1], None)
|
||||
else:
|
||||
file_type = "text"
|
||||
raise ValueError("File not found.")
|
||||
|
||||
if dataset_attr.file_sha1 is not None:
|
||||
checksum(data_file, dataset_attr.file_sha1)
|
||||
assert data_path, "File extension must be txt, csv, json or jsonl."
|
||||
|
||||
if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
|
||||
checksum(data_files[0], dataset_attr.dataset_sha1)
|
||||
else:
|
||||
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
|
||||
|
||||
raw_datasets = load_dataset(
|
||||
file_type,
|
||||
data_files=data_file,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_auth_token=True if model_args.use_auth_token else None
|
||||
)
|
||||
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
raw_datasets = load_dataset(
|
||||
data_path,
|
||||
data_files=data_files,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_auth_token=True if model_args.use_auth_token else None
|
||||
)
|
||||
dataset = raw_datasets[data_args.split]
|
||||
|
||||
if max_samples is not None:
|
||||
@ -390,6 +404,7 @@ def prepare_data(
|
||||
dataset = dataset.select(range(max_samples_temp))
|
||||
|
||||
dummy_data = [None] * len(dataset)
|
||||
prefix_data = [dataset_attr.source_prefix] * len(dataset)
|
||||
for column_name, target_name in [
|
||||
("prompt_column", "prompt"),
|
||||
("query_column", "query"),
|
||||
@ -401,6 +416,7 @@ def prepare_data(
|
||||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
|
||||
else: # None or empty string
|
||||
dataset = dataset.add_column(target_name, dummy_data)
|
||||
dataset = dataset.add_column("prefix", prefix_data)
|
||||
all_datasets.append(dataset)
|
||||
|
||||
if len(data_args.dataset_list) == 1:
|
||||
@ -420,7 +436,6 @@ def preprocess_data(
|
||||
) -> Dataset:
|
||||
|
||||
column_names = list(dataset.column_names)
|
||||
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
||||
prompt_template = Template(data_args.prompt_template)
|
||||
|
||||
# support question with a single answer or multiple answers
|
||||
@ -429,6 +444,7 @@ def preprocess_data(
|
||||
if examples["prompt"][i] and examples["response"][i]:
|
||||
query, answer = examples["prompt"][i], examples["response"][i]
|
||||
query = query + "\n" + examples["query"][i] if examples["query"][i] else query
|
||||
prefix = examples["prefix"][i] if examples["prefix"][i] else ""
|
||||
dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
|
||||
yield dialog
|
||||
|
||||
@ -513,21 +529,22 @@ def preprocess_data(
|
||||
|
||||
def print_supervised_dataset_example(example):
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
print("labels:\n{}".format(
|
||||
tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]]))
|
||||
)
|
||||
tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
|
||||
skip_special_tokens=False)
|
||||
))
|
||||
|
||||
def print_pairwise_dataset_example(example):
|
||||
print("accept_ids:\n{}".format(example["accept_ids"]))
|
||||
print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"])))
|
||||
print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
|
||||
print("reject_ids:\n{}".format(example["reject_ids"]))
|
||||
print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"])))
|
||||
print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))
|
||||
|
||||
def print_unsupervised_dataset_example(example):
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
|
||||
if stage == "pt":
|
||||
preprocess_function = preprocess_pretrain_dataset
|
||||
|
@ -10,14 +10,11 @@ class DatasetAttr:
|
||||
|
||||
load_from: str
|
||||
dataset_name: Optional[str] = None
|
||||
file_name: Optional[str] = None
|
||||
file_sha1: Optional[str] = None
|
||||
dataset_sha1: Optional[str] = None
|
||||
source_prefix: Optional[str] = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if self.dataset_name is not None:
|
||||
return self.dataset_name
|
||||
else:
|
||||
return self.file_name
|
||||
return self.dataset_name
|
||||
|
||||
def __post_init__(self):
|
||||
self.prompt_column = "instruction"
|
||||
@ -137,7 +134,7 @@ class DataTrainingArguments:
|
||||
)
|
||||
source_prefix: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
||||
metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes."}
|
||||
)
|
||||
dev_ratio: Optional[float] = field(
|
||||
default=0,
|
||||
@ -153,8 +150,15 @@ class DataTrainingArguments:
|
||||
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
|
||||
dataset_info = json.load(f)
|
||||
|
||||
if self.source_prefix is not None:
|
||||
prefix_list = self.source_prefix.split("|")
|
||||
prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
|
||||
assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
|
||||
else:
|
||||
prefix_list = [None] * len(dataset_names)
|
||||
|
||||
self.dataset_list: List[DatasetAttr] = []
|
||||
for name in dataset_names:
|
||||
for i, name in enumerate(dataset_names):
|
||||
if name not in dataset_info:
|
||||
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
|
||||
|
||||
@ -165,10 +169,12 @@ class DataTrainingArguments:
|
||||
else:
|
||||
dataset_attr = DatasetAttr(
|
||||
"file",
|
||||
file_name=dataset_info[name]["file_name"],
|
||||
file_sha1=dataset_info[name].get("file_sha1", None)
|
||||
dataset_name=dataset_info[name]["file_name"],
|
||||
dataset_sha1=dataset_info[name].get("file_sha1", None)
|
||||
)
|
||||
|
||||
dataset_attr.source_prefix = prefix_list[i]
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
|
||||
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
|
||||
|
Loading…
x
Reference in New Issue
Block a user