diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 1f859be6..f9eeb66a 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -310,14 +310,15 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}" jinja_template += ( - "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}" + "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}" + "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}" ) system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message") if not isinstance(template, Llama2Template): jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}" - jinja_template += "{% for message in messages %}" + jinja_template += "{% for message in loop_messages %}" jinja_template += "{% set content = message['content'] %}" if isinstance(template, Llama2Template): jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"