diff --git a/src/api_demo.py b/src/api_demo.py index 85c41eae..5f8d99d6 100644 --- a/src/api_demo.py +++ b/src/api_demo.py @@ -34,6 +34,21 @@ async def lifespan(app: FastAPI): # collects GPU memory app = FastAPI(lifespan=lifespan) +class ModelCard(BaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "owner" + root: Optional[str] = None + parent: Optional[str] = None + permission: Optional[list] = None + + +class ModelList(BaseModel): + object: str = "list" + data: List[ModelCard] = [] + + class ChatMessage(BaseModel): role: Literal["user", "assistant", "system"] content: str @@ -73,6 +88,13 @@ class ChatCompletionResponse(BaseModel): created: Optional[int] = Field(default_factory=lambda: int(time.time())) +@app.get("/v1/models", response_model=ModelList) +async def list_models(): + global model_args + model_card = ModelCard(id="gpt-3.5-turbo") + return ModelList(data=[model_card]) + + @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) async def create_chat_completion(request: ChatCompletionRequest): global model, tokenizer, source_prefix, generating_args diff --git a/src/utils/common.py b/src/utils/common.py index 93c9a2ce..2bb82d6e 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -344,6 +344,13 @@ def prepare_data( if sha1 != hash: logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path)) + ext2type = { + "csv": "csv", + "json": "json", + "jsonl": "json", + "txt": "text" + } + max_samples = data_args.max_samples all_datasets: List[Dataset] = [] # support multiple datasets @@ -352,37 +359,44 @@ def prepare_data( logger.info("Loading dataset {}...".format(dataset_attr)) if dataset_attr.load_from == "hf_hub": - raw_datasets = load_dataset(dataset_attr.dataset_name, cache_dir=model_args.cache_dir) + data_path = dataset_attr.dataset_name + data_files = None elif dataset_attr.load_from == "script": - raw_datasets = load_dataset( - os.path.join(data_args.dataset_dir, dataset_attr.dataset_name), - cache_dir=model_args.cache_dir - ) + data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) + data_files = None elif dataset_attr.load_from == "file": - data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name) + data_path = None + data_files: List[str] = [] - extension = dataset_attr.file_name.split(".")[-1] - if extension == "csv": - file_type = "csv" - elif extension == "json" or extension == "jsonl": - file_type = "json" + if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): + for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): + data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name)) + + if data_path is None: + data_path = ext2type.get(data_files[0].split(".")[-1], None) + else: + assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match." + elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): + data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)) + data_path = ext2type.get(data_files[0].split(".")[-1], None) else: - file_type = "text" + raise ValueError("File not found.") - if dataset_attr.file_sha1 is not None: - checksum(data_file, dataset_attr.file_sha1) + assert data_path, "File extension must be txt, csv, json or jsonl." + + if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None: + checksum(data_files[0], dataset_attr.dataset_sha1) else: - logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.") - - raw_datasets = load_dataset( - file_type, - data_files=data_file, - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None - ) + logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.") else: raise NotImplementedError + raw_datasets = load_dataset( + data_path, + data_files=data_files, + cache_dir=model_args.cache_dir, + use_auth_token=True if model_args.use_auth_token else None + ) dataset = raw_datasets[data_args.split] if max_samples is not None: @@ -390,6 +404,7 @@ def prepare_data( dataset = dataset.select(range(max_samples_temp)) dummy_data = [None] * len(dataset) + prefix_data = [dataset_attr.source_prefix] * len(dataset) for column_name, target_name in [ ("prompt_column", "prompt"), ("query_column", "query"), @@ -401,6 +416,7 @@ def prepare_data( dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name) else: # None or empty string dataset = dataset.add_column(target_name, dummy_data) + dataset = dataset.add_column("prefix", prefix_data) all_datasets.append(dataset) if len(data_args.dataset_list) == 1: @@ -420,7 +436,6 @@ def preprocess_data( ) -> Dataset: column_names = list(dataset.column_names) - prefix = data_args.source_prefix if data_args.source_prefix is not None else "" prompt_template = Template(data_args.prompt_template) # support question with a single answer or multiple answers @@ -429,6 +444,7 @@ def preprocess_data( if examples["prompt"][i] and examples["response"][i]: query, answer = examples["prompt"][i], examples["response"][i] query = query + "\n" + examples["query"][i] if examples["query"][i] else query + prefix = examples["prefix"][i] if examples["prefix"][i] else "" dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix) yield dialog @@ -513,21 +529,22 @@ def preprocess_data( def print_supervised_dataset_example(example): print("input_ids:\n{}".format(example["input_ids"])) - print("inputs:\n{}".format(tokenizer.decode(example["input_ids"]))) + print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) print("label_ids:\n{}".format(example["labels"])) print("labels:\n{}".format( - tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]])) - ) + tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]], + skip_special_tokens=False) + )) def print_pairwise_dataset_example(example): print("accept_ids:\n{}".format(example["accept_ids"])) - print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"]))) + print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False))) print("reject_ids:\n{}".format(example["reject_ids"])) - print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"]))) + print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False))) def print_unsupervised_dataset_example(example): print("input_ids:\n{}".format(example["input_ids"])) - print("inputs:\n{}".format(tokenizer.decode(example["input_ids"]))) + print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) if stage == "pt": preprocess_function = preprocess_pretrain_dataset diff --git a/src/utils/config.py b/src/utils/config.py index ffd4ee7a..ff300424 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -10,14 +10,11 @@ class DatasetAttr: load_from: str dataset_name: Optional[str] = None - file_name: Optional[str] = None - file_sha1: Optional[str] = None + dataset_sha1: Optional[str] = None + source_prefix: Optional[str] = None def __repr__(self) -> str: - if self.dataset_name is not None: - return self.dataset_name - else: - return self.file_name + return self.dataset_name def __post_init__(self): self.prompt_column = "instruction" @@ -137,7 +134,7 @@ class DataTrainingArguments: ) source_prefix: Optional[str] = field( default=None, - metadata={"help": "A prefix to add before every source text (useful for T5 models)."} + metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes."} ) dev_ratio: Optional[float] = field( default=0, @@ -153,8 +150,15 @@ class DataTrainingArguments: with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f: dataset_info = json.load(f) + if self.source_prefix is not None: + prefix_list = self.source_prefix.split("|") + prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list + assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1." + else: + prefix_list = [None] * len(dataset_names) + self.dataset_list: List[DatasetAttr] = [] - for name in dataset_names: + for i, name in enumerate(dataset_names): if name not in dataset_info: raise ValueError("Undefined dataset {} in dataset_info.json.".format(name)) @@ -165,10 +169,12 @@ class DataTrainingArguments: else: dataset_attr = DatasetAttr( "file", - file_name=dataset_info[name]["file_name"], - file_sha1=dataset_info[name].get("file_sha1", None) + dataset_name=dataset_info[name]["file_name"], + dataset_sha1=dataset_info[name].get("file_sha1", None) ) + dataset_attr.source_prefix = prefix_list[i] + if "columns" in dataset_info[name]: dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None) dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)