Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-17 20:30:36 +08:00)
[misc] support export ollama modelfile (#6899)

* support export ollama modelfile
* update config
* add system and num ctx
@@ -239,11 +239,9 @@ class Template:
         Returns the jinja template.
         """
         prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
-        system_message = self._convert_slots_to_jinja(
-            self.format_system.apply(), tokenizer, placeholder="system_message"
-        )
-        user_message = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
-        assistant_message = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
+        system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
+        user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
+        assistant = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
         jinja_template = ""
         if prefix:
             jinja_template += "{{ " + prefix + " }}"
@@ -254,13 +252,13 @@ class Template:
         jinja_template += (
             "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
             "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
-            "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
+            "{% if system_message is defined %}{{ " + system + " }}{% endif %}"
             "{% for message in loop_messages %}"
             "{% set content = message['content'] %}"
             "{% if message['role'] == 'user' %}"
-            "{{ " + user_message + " }}"
+            "{{ " + user + " }}"
             "{% elif message['role'] == 'assistant' %}"
-            "{{ " + assistant_message + " }}"
+            "{{ " + assistant + " }}"
             "{% endif %}"
             "{% endfor %}"
         )
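
To sanity-check the shape of the string built above, here is a minimal standalone sketch that renders a hand-written ChatML-style template with jinja2; the `<|im_start|>`/`<|im_end|>` literals are illustrative assumptions, not taken from this diff:

    from jinja2 import Template as JinjaTemplate

    # ChatML-style chat template, shaped like the string _get_jinja_template builds
    # (illustrative markers, not the output of this commit's code):
    chat_template = (
        "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
        "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
        "{% if system_message is defined %}{{ '<|im_start|>system\n' + system_message + '<|im_end|>\n' }}{% endif %}"
        "{% for message in loop_messages %}{% set content = message['content'] %}"
        "{% if message['role'] == 'user' %}{{ '<|im_start|>user\n' + content + '<|im_end|>\n<|im_start|>assistant\n' }}"
        "{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>\n' }}{% endif %}"
        "{% endfor %}"
    )

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
    print(JinjaTemplate(chat_template).render(messages=messages))
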
@@ -276,6 +274,64 @@ class Template:
         except ValueError as e:
             logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")

+    @staticmethod
+    def _convert_slots_to_ollama(
+        slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content"
+    ) -> str:
+        r"""
+        Converts slots to ollama template.
+        """
+        slot_items = []
+        for slot in slots:
+            if isinstance(slot, str):
+                slot_pieces = slot.split("{{content}}")
+                if slot_pieces[0]:
+                    slot_items.append(slot_pieces[0])
+                if len(slot_pieces) > 1:
+                    slot_items.append("{{ " + placeholder + " }}")
+                    if slot_pieces[1]:
+                        slot_items.append(slot_pieces[1])
+            elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
+                if "bos_token" in slot and tokenizer.bos_token_id is not None:
+                    slot_items.append(tokenizer.bos_token)
+                elif "eos_token" in slot and tokenizer.eos_token_id is not None:
+                    slot_items.append(tokenizer.eos_token)
+            elif isinstance(slot, dict):
+                raise ValueError("Dict is not supported.")
+
+        return "".join(slot_items)
+
+    def _get_ollama_template(self, tokenizer: "PreTrainedTokenizer") -> str:
+        r"""
+        Returns the ollama template.
+        """
+        prefix = self._convert_slots_to_ollama(self.format_prefix.apply(), tokenizer)
+        system = self._convert_slots_to_ollama(self.format_system.apply(), tokenizer, placeholder=".System")
+        user = self._convert_slots_to_ollama(self.format_user.apply(), tokenizer, placeholder=".Content")
+        assistant = self._convert_slots_to_ollama(self.format_assistant.apply(), tokenizer, placeholder=".Content")
+        return (
+            f"{prefix}{{{{ if .System }}}}{system}{{{{ end }}}}"
+            f"""{{{{ range .Messages }}}}{{{{ if eq .Role "user" }}}}{user}"""
+            f"""{{{{ else if eq .Role "assistant" }}}}{assistant}{{{{ end }}}}{{{{ end }}}}"""
+        )
+
+    def get_ollama_modelfile(self, tokenizer: "PreTrainedTokenizer") -> str:
+        r"""
+        Returns the ollama modelfile.
+
+        TODO: support function calling.
+        """
+        modelfile = f'FROM .\n\nTEMPLATE """{self._get_ollama_template(tokenizer)}"""\n\n'
+
+        if self.default_system:
+            modelfile += f'SYSTEM """{self.default_system}"""\n\n'
+
+        for stop_token_id in self.get_stop_token_ids(tokenizer):
+            modelfile += f'PARAMETER stop "{tokenizer.convert_ids_to_tokens(stop_token_id)}"\n'
+
+        modelfile += "PARAMETER num_ctx 4096\n"
+        return modelfile
+

 @dataclass
 class Llama2Template(Template):
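
For a concrete picture of the output, assuming a ChatML-style template (the `<|im_start|>`/`<|im_end|>` markers, default system prompt, and `<|im_end|>` stop token below are illustrative assumptions, not read from this diff), `get_ollama_modelfile` would emit text roughly like:

    FROM .

    TEMPLATE """{{ if .System }}<|im_start|>system
    {{ .System }}<|im_end|>
    {{ end }}{{ range .Messages }}{{ if eq .Role "user" }}<|im_start|>user
    {{ .Content }}<|im_end|>
    <|im_start|>assistant
    {{ else if eq .Role "assistant" }}{{ .Content }}<|im_end|>
    {{ end }}{{ end }}"""

    SYSTEM """You are a helpful assistant."""

    PARAMETER stop "<|im_end|>"
    PARAMETER num_ctx 4096

Note the Go-template placeholders (`.System`, `.Content`, `.Role`) are Ollama's, while the surrounding literals come straight from the template's slot strings.
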
@@ -1020,7 +1076,7 @@ _register_template(
 )


-# copied from chatml template
+# copied from minicpm_v template
 _register_template(
     name="minicpm_o",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -87,7 +87,7 @@ class ExportArguments:
         metadata={"help": "Path to the directory to save the exported model."},
     )
     export_size: int = field(
-        default=1,
+        default=5,
         metadata={"help": "The file shard size (in GB) of the exported model."},
     )
     export_device: Literal["cpu", "auto"] = field(
@@ -104,7 +104,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
     processor = tokenizer_module["processor"]
-    get_template_and_fix_tokenizer(tokenizer, data_args)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
     model = load_model(tokenizer, model_args, finetuning_args)  # must after fixing tokenizer to resize vocab

     if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None:
@@ -171,3 +171,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:

     except Exception as e:
         logger.warning_rank0(f"Cannot save tokenizer, please copy the files manually: {e}.")
+
+    with open(os.path.join(model_args.export_dir, "Modelfile"), "w", encoding="utf-8") as f:
+        f.write(template.get_ollama_modelfile(tokenizer))
+        logger.info_rank0(f"Saved ollama modelfile to {model_args.export_dir}.")
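
End to end, the feature can be exercised along these lines; this is a hypothetical sketch where the model name, template name, and paths are placeholders, the import path is assumed from the repository layout, and the final `ollama create` step is ordinary Ollama CLI usage rather than part of this commit:

    from llamafactory.train.tuner import export_model  # assumed import path

    export_model({
        "model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",  # placeholder model
        "template": "qwen",                                # must match the model's chat format
        "export_dir": "output/qwen_export",
        "export_size": 5,                                  # shard size in GB (the new default)
    })
    # The export directory now contains a Modelfile next to the weights, so:
    #   cd output/qwen_export && ollama create qwen-exported -f Modelfile
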