From 28a807472b22d84efca3f7a25dd107d97eb222a7 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 9 Aug 2023 16:23:31 +0800 Subject: [PATCH] fix rm #420, fix template #426, fix #423 Former-commit-id: 39cd8b6989c9190d213e65467ec41f34ea04c5bc --- src/llmtuner/dsets/preprocess.py | 8 ++++---- src/llmtuner/extras/template.py | 7 ++----- src/llmtuner/webui/runner.py | 2 +- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py index dc01a77d..d2150dbc 100644 --- a/src/llmtuner/dsets/preprocess.py +++ b/src/llmtuner/dsets/preprocess.py @@ -103,13 +103,13 @@ def preprocess_dataset( if len(source_ids) > data_args.max_source_length: source_ids = source_ids[:data_args.max_source_length] - if len(accept_ids) > data_args.max_target_length - 1: # eos token + if len(accept_ids) > data_args.max_target_length: accept_ids = accept_ids[:data_args.max_target_length - 1] - if len(reject_ids) > data_args.max_target_length - 1: # eos token + if len(reject_ids) > data_args.max_target_length: reject_ids = reject_ids[:data_args.max_target_length - 1] - accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id] - reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id] + accept_ids = source_ids + accept_ids + reject_ids = source_ids + reject_ids model_inputs["accept_ids"].append(accept_ids) model_inputs["reject_ids"].append(reject_ids) diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index 5f120253..5b00af03 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -388,12 +388,9 @@ register_template( name="intern", prefix=[], prompt=[ - {"token": "<|User|>"}, - ":{{query}}", + "<|User|>:{{query}}", {"token": ""}, - "\n", - {"token": "<|Bot|>"}, - ":" + "\n<|Bot|>:" ], sep=[ "\n" diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index 53c18828..763ff614 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -156,7 +156,7 @@ class Runner: else: yield format_info(logger_handler.log, trainer_callback) - if os.path.exists(os.path.join(output_dir), TRAINING_ARGS_NAME): + if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)): finish_info = ALERTS["info_finished"][lang] else: finish_info = ALERTS["err_failed"][lang]