From 2ba88c6b08c1b03cdea8600a97533518d306a66c Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 14 Nov 2023 21:14:42 +0800 Subject: [PATCH] Update cal_lr.py Former-commit-id: 829e879e040cae3b49ace981b3d4a8eaf1e0c4ae --- tests/cal_lr.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/cal_lr.py b/tests/cal_lr.py index a036b414..317520dc 100644 --- a/tests/cal_lr.py +++ b/tests/cal_lr.py @@ -32,19 +32,14 @@ def calculate_lr( dataset=dataset, template="default", cutoff_len=cutoff_len, - output_dir="dummy_dir", - fp16=True + output_dir="dummy_dir" )) trainset = get_dataset(model_args, data_args) _, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft") trainset = preprocess_dataset(trainset, tokenizer, data_args, training_args, stage="sft") data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX) dataloader = DataLoader( - dataset=trainset, - batch_size=batch_size, - shuffle=True, - collate_fn=data_collator, - pin_memory=True + dataset=trainset, batch_size=batch_size, shuffle=True, collate_fn=data_collator, pin_memory=True ) valid_tokens, total_tokens = 0, 0 for batch in tqdm(dataloader):