From a3a7465f008ff951581647b86d5b57c08345328f Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 9 Aug 2023 16:23:31 +0800 Subject: [PATCH] fix rm #420, fix template #426, fix #423 Former-commit-id: 70ea3caaa7a7695c77179cd1bb18707a80a373d7 --- src/llmtuner/dsets/preprocess.py | 8 ++++---- src/llmtuner/extras/template.py | 7 ++----- src/llmtuner/webui/runner.py | 2 +- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py index dc01a77d..d2150dbc 100644 --- a/src/llmtuner/dsets/preprocess.py +++ b/src/llmtuner/dsets/preprocess.py @@ -103,13 +103,13 @@ def preprocess_dataset( if len(source_ids) > data_args.max_source_length: source_ids = source_ids[:data_args.max_source_length] - if len(accept_ids) > data_args.max_target_length - 1: # eos token + if len(accept_ids) > data_args.max_target_length: accept_ids = accept_ids[:data_args.max_target_length - 1] - if len(reject_ids) > data_args.max_target_length - 1: # eos token + if len(reject_ids) > data_args.max_target_length: reject_ids = reject_ids[:data_args.max_target_length - 1] - accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id] - reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id] + accept_ids = source_ids + accept_ids + reject_ids = source_ids + reject_ids model_inputs["accept_ids"].append(accept_ids) model_inputs["reject_ids"].append(reject_ids) diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index 5f120253..5b00af03 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -388,12 +388,9 @@ register_template( name="intern", prefix=[], prompt=[ - {"token": "<|User|>"}, - ":{{query}}", + "<|User|>:{{query}}", {"token": ""}, - "\n", - {"token": "<|Bot|>"}, - ":" + "\n<|Bot|>:" ], sep=[ "\n" diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index 53c18828..763ff614 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -156,7 +156,7 @@ class Runner: else: yield format_info(logger_handler.log, trainer_callback) - if os.path.exists(os.path.join(output_dir), TRAINING_ARGS_NAME): + if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)): finish_info = ALERTS["info_finished"][lang] else: finish_info = ALERTS["err_failed"][lang]