mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 06:12:50 +08:00
fix unittest
Former-commit-id: e80006795fe6344ea98b61f9a8db16356498c7cb
This commit is contained in:
parent
e1e01d7efd
commit
d0891f05fa
@ -33,6 +33,8 @@ TRAIN_ARGS = {
|
||||
"stage": "sft",
|
||||
"do_predict": True,
|
||||
"finetuning_type": "full",
|
||||
"eval_dataset": "system_chat",
|
||||
"dataset_dir": "REMOTE:" + DEMO_DATA,
|
||||
"template": "llama3",
|
||||
"cutoff_len": 8192,
|
||||
"overwrite_cache": True,
|
||||
@ -45,7 +47,7 @@ TRAIN_ARGS = {
|
||||
|
||||
@pytest.mark.parametrize("num_samples", [16])
|
||||
def test_unsupervised_data(num_samples: int):
|
||||
train_dataset = load_train_dataset(dataset_dir="REMOTE:" + DEMO_DATA, dataset="system_chat", **TRAIN_ARGS)
|
||||
train_dataset = load_train_dataset(**TRAIN_ARGS)
|
||||
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
|
||||
original_data = load_dataset(DEMO_DATA, name="system_chat", split="train")
|
||||
indexes = random.choices(range(len(original_data)), k=num_samples)
|
||||
|
Loading…
x
Reference in New Issue
Block a user