mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-15 00:25:59 +08:00

Compare commits: 3 commits (b4e051bea4...5cfd804b59)
| Author | SHA1 | Date |
|---|---|---|
| | 5cfd804b59 | |
| | 4c1eb922e2 | |
| | 958fb523a2 | |
@@ -298,6 +298,7 @@ Read technical notes:
 | [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
 | [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
 | [Ling 2.0 (mini/flash)](https://huggingface.co/inclusionAI) | 16B/100B | bailing_v2 |
+| [LFM 2.5 (VL)](https://huggingface.co/LiquidAI) | 1.2B/1.6B | lfm2/lfm2_vl |
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
 | [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |

@@ -300,6 +300,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [InternLM/Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
 | [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
 | [Ling 2.0 (mini/flash)](https://huggingface.co/inclusionAI) | 16B/100B | bailing_v2 |
+| [LFM 2.5 (VL)](https://huggingface.co/LiquidAI) | 1.2B/1.6B | lfm2/lfm2_vl |
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
 | [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |

@@ -36,5 +36,3 @@ lr_scheduler_type: cosine
 warmup_ratio: 0.1
 bf16: true
 ddp_timeout: 180000000
-
-

@@ -2092,6 +2092,73 @@ class VideoLlavaPlugin(BasePlugin):
         return messages


+@dataclass
+class LFMVLPlugin(BasePlugin):
+    r"""Plugin for LFM2.5-VL vision-language models.
+
+    LFM2.5-VL uses dynamic image token counts based on image resolution.
+    The image processor returns a spatial_shapes tensor with [height, width] grid dimensions.
+    Token count per image = (spatial_h * spatial_w) / (downsample_factor^2).
+    """
+
+    @override
+    def _get_mm_inputs(
+        self,
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: "MMProcessor",
+    ) -> dict[str, "torch.Tensor"]:
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+        mm_inputs = {}
+        if len(images) != 0:
+            images = self._regularize_images(
+                images,
+                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+            )["images"]
+            mm_inputs.update(image_processor(images, return_tensors="pt"))
+        return mm_inputs
+
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+        downsample_factor: int = getattr(image_processor, "downsample_factor", 2)
+
+        if self.expand_mm_tokens and len(images) > 0:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            spatial_shapes = mm_inputs.get("spatial_shapes", [])
+        else:
+            spatial_shapes = []
+
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                if self.expand_mm_tokens and len(spatial_shapes) > num_image_tokens:
+                    h, w = spatial_shapes[num_image_tokens].tolist()
+                    image_seqlen = (h * w) // (downsample_factor * downsample_factor)
+                else:
+                    image_seqlen = 1
+
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+                num_image_tokens += 1
+
+            message["content"] = content.replace("{{image}}", self.image_token)
+
+        return messages
+
+
 PLUGINS = {
     "base": BasePlugin,
     "ernie_vl": ErnieVLPlugin,

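The docstring's token arithmetic is easiest to see with concrete numbers. Below is a minimal standalone sketch of the expansion step; the helper name and values are illustrative, not part of this patch:

```python
# Illustrative sketch of LFMVLPlugin's per-image token expansion.
IMAGE_PLACEHOLDER = "<image>"  # stand-in for the project-level constant


def expand_image_tokens(content: str, spatial_shape: tuple[int, int], downsample_factor: int = 2) -> str:
    """Replace one placeholder with the token count implied by the processor's grid."""
    h, w = spatial_shape  # [height, width] grid dimensions from spatial_shapes
    image_seqlen = (h * w) // (downsample_factor * downsample_factor)
    return content.replace(IMAGE_PLACEHOLDER, IMAGE_PLACEHOLDER * image_seqlen, 1)


# A 16x16 grid with downsample_factor=2 expands to (16 * 16) // 4 = 64 tokens.
assert expand_image_tokens("Describe this: <image>", (16, 16)).count("<image>") == 64
```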
@@ -2104,6 +2171,7 @@ PLUGINS = {
     "llava": LlavaPlugin,
     "llava_next": LlavaNextPlugin,
     "llava_next_video": LlavaNextVideoPlugin,
+    "lfm2_vl": LFMVLPlugin,
     "minicpm_v": MiniCPMVPlugin,
     "mllama": MllamaPlugin,
     "paligemma": PaliGemmaPlugin,

@@ -1331,18 +1331,18 @@ register_template(


 register_template(
-    name="lfm",
+    name="lfm2",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm"),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm2"),
     format_observation=StringFormatter(
         slots=[
             "<|im_start|>tool\n<|tool_response_start|>{{content}}<|tool_response_end|><|im_end|>\n"
             "<|im_start|>assistant\n"
         ]
     ),
-    format_tools=ToolFormatter(tool_format="lfm"),
+    format_tools=ToolFormatter(tool_format="lfm2"),
     default_system="You are a helpful AI assistant.",
     stop_words=["<|im_end|>"],
     tool_call_words=("<|tool_call_start|>", "<|tool_call_end|>"),

@@ -1350,6 +1350,27 @@ register_template(
 )


+register_template(
+    name="lfm2_vl",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm2"),
+    format_observation=StringFormatter(
+        slots=[
+            "<|im_start|>tool\n<|tool_response_start|>{{content}}<|tool_response_end|><|im_end|>\n"
+            "<|im_start|>assistant\n"
+        ]
+    ),
+    format_tools=ToolFormatter(tool_format="lfm2"),
+    default_system="You are a helpful multimodal assistant by Liquid AI.",
+    stop_words=["<|im_end|>"],
+    tool_call_words=("<|tool_call_start|>", "<|tool_call_end|>"),
+    replace_eos=True,
+    mm_plugin=get_mm_plugin(name="lfm2_vl", image_token="<image>"),
+)
+
+
 register_template(
     name="llama2",
     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),

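Assembled by hand from the slot definitions above, one rendered lfm2_vl exchange looks roughly like this (a sketch, not output captured from the library):

```python
# Hand-rendered single turn following the lfm2_vl slots defined above.
system = "<|im_start|>system\nYou are a helpful multimodal assistant by Liquid AI.<|im_end|>\n"
user = "<|im_start|>user\n<image>What is in this picture?<|im_end|>\n<|im_start|>assistant\n"
assistant = "A cat sitting on a windowsill.<|im_end|>\n"

# The single <image> placeholder is later expanded by LFMVLPlugin into one
# token per downsampled grid cell before tokenization.
prompt = system + user + assistant
print(prompt)
```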
@@ -102,7 +102,7 @@ LING_TOOL_PROMPT = (
     """"arguments": <args-json-object>}}\n</tool_call>"""
 )

-LFM_TOOL_PROMPT = "List of tools: <|tool_list_start|>{tool_text}<|tool_list_end|>"
+LFM2_TOOL_PROMPT = "List of tools: <|tool_list_start|>{tool_text}<|tool_list_end|>"


 @dataclass

@@ -549,8 +549,8 @@ class LingToolUtils(QwenToolUtils):
         return LING_TOOL_PROMPT.format(tool_text=tool_text) + "\n" + "detailed thinking off"


-class LFMToolUtils(ToolUtils):
-    r"""LFM 2.5 tool using template with Pythonic function call syntax."""
+class LFM2ToolUtils(ToolUtils):
+    r"""LFM2.5 tool using template with Pythonic function call syntax."""

     @override
     @staticmethod

@@ -560,7 +560,7 @@ class LFMToolUtils(ToolUtils):
             tool = tool.get("function", tool) if tool.get("type") == "function" else tool
             tool_list.append(tool)

-        return LFM_TOOL_PROMPT.format(tool_text=json.dumps(tool_list, ensure_ascii=False))
+        return LFM2_TOOL_PROMPT.format(tool_text=json.dumps(tool_list, ensure_ascii=False))

     @override
     @staticmethod

@@ -643,7 +643,7 @@ class LFMToolUtils(ToolUtils):
         for keyword in node.keywords:
             key = keyword.arg
             try:
-                value = LFMToolUtils._ast_to_value(keyword.value)
+                value = LFM2ToolUtils._ast_to_value(keyword.value)
             except (ValueError, SyntaxError):
                 return content
             args_dict[key] = value

@@ -657,7 +657,7 @@ TOOLS = {
     "default": DefaultToolUtils(),
     "glm4": GLM4ToolUtils(),
     "llama3": Llama3ToolUtils(),
-    "lfm": LFMToolUtils(),
+    "lfm2": LFM2ToolUtils(),
     "minimax1": MiniMaxM1ToolUtils(),
     "minimax2": MiniMaxM2ToolUtils(),
     "mistral": MistralToolUtils(),

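The renamed `lfm2` format keeps the Pythonic call syntax exercised by the tests further down. As a quick reference, a round trip looks like this (a sketch mirroring the test fixtures; the import path is an assumption and may differ in this tree):

```python
import json

from llamafactory.data.formatter import FunctionFormatter, ToolFormatter  # assumed path

# Format a JSON tool call into LFM2's Pythonic syntax, then parse it back.
function_formatter = FunctionFormatter(slots=["{{content}}"], tool_format="lfm2")
tool_formatter = ToolFormatter(tool_format="lfm2")

call = {"name": "get_weather", "arguments": {"city": "Tokyo", "days": 3}}
formatted = function_formatter.apply(content=json.dumps(call))
# formatted[0] == '<|tool_call_start|>[get_weather(city="Tokyo", days=3)]<|tool_call_end|>'
assert tool_formatter.extract(formatted[0]) == [("get_weather", '{"city": "Tokyo", "days": 3}')]
```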
@@ -1502,7 +1502,18 @@ register_model_group(
             DownloadSource.DEFAULT: "LiquidAI/LFM2.5-1.2B-Instruct",
         },
     },
-    template="lfm",
+    template="lfm2",
+)
+
+
+register_model_group(
+    models={
+        "LFM2.5-VL-1.6B": {
+            DownloadSource.DEFAULT: "LiquidAI/LFM2.5-VL-1.6B",
+        },
+    },
+    template="lfm2_vl",
+    multimodal=True,
 )


@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 import os
 import sys
 from pathlib import Path

@@ -70,13 +71,13 @@ def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any]
     if args is not None:
         return args

-    if sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml"):
+    if len(sys.argv) > 1 and (sys.argv[1].endswith(".yaml") or sys.argv[1].endswith(".yml")):
         override_config = OmegaConf.from_cli(sys.argv[2:])
         dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
         return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
-    elif sys.argv[1].endswith(".json"):
+    elif len(sys.argv) > 1 and sys.argv[1].endswith(".json"):
         override_config = OmegaConf.from_cli(sys.argv[2:])
-        dict_config = OmegaConf.load(Path(sys.argv[1]).absolute())
+        dict_config = OmegaConf.create(json.loads(Path(sys.argv[1]).absolute().read_text()))
         return OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
     else:
         return sys.argv[1:]

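The added `len(sys.argv) > 1` guard means a bare invocation now falls through to the `else` branch instead of raising `IndexError`, and the JSON branch now builds its config through OmegaConf like the YAML one. A minimal standalone sketch of the merge semantics `read_args` relies on (values are illustrative):

```python
from omegaconf import OmegaConf

# Config-file values load first; dotlist-style CLI pairs override them,
# mirroring read_args' OmegaConf.merge(dict_config, override_config).
dict_config = OmegaConf.create({"learning_rate": 1e-4, "bf16": True})
override_config = OmegaConf.from_cli(["learning_rate=5e-5"])  # stand-in for sys.argv[2:]
merged = OmegaConf.to_container(OmegaConf.merge(dict_config, override_config))
print(merged)  # the CLI value for learning_rate wins
```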
@@ -151,6 +151,12 @@ def patch_config(
     if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
         raise RuntimeError("InternLM3 model requires transformers>=4.47.1, please upgrade it.")

+    if getattr(config, "model_type", None) == "lfm2_vl" and not is_transformers_version_greater_than("4.58.0"):
+        raise RuntimeError(
+            "LFM2.5-VL model requires transformers>=4.58.0 or install from commit: "
+            "pip install git+https://github.com/huggingface/transformers.git@3c2517727ce28a30f5044e01663ee204deb1cdbe"
+        )
+
     if getattr(config, "model_type", None) == "qwen3_omni_moe":
         patch_qwen3_omni_moe_thinker_text_sparse_moe_block()

@@ -30,21 +30,6 @@ from .training_args import TrainingArguments
 InputArgument = dict[str, Any] | list[str] | None


-def validate_args(
-    data_args: DataArguments,
-    model_args: ModelArguments,
-    training_args: TrainingArguments,
-    sample_args: SampleArguments,
-):
-    """Validate arguments."""
-    if (
-        model_args.quant_config is not None
-        and training_args.dist_config is not None
-        and training_args.dist_config.name == "deepspeed"
-    ):
-        raise ValueError("Quantization is not supported with deepspeed backend.")
-
-
 def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments, TrainingArguments, SampleArguments]:
     """Parse arguments from command line or config file."""
     parser = HfArgumentParser([DataArguments, ModelArguments, TrainingArguments, SampleArguments])

@@ -71,8 +56,6 @@ def get_args(args: InputArgument = None) -> tuple[DataArguments, ModelArguments,
         print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
         raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")

-    validate_args(*parsed_args)
-
     return tuple(parsed_args)


@@ -295,8 +295,8 @@ def test_qwen_multi_tool_extractor():


 @pytest.mark.runs_on(["cpu", "mps"])
-def test_lfm_function_formatter():
-    formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm")
+def test_lfm2_function_formatter():
+    formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm2")
     tool_calls = json.dumps(FUNCTION)
     assert formatter.apply(content=tool_calls) == [
         """<|tool_call_start|>[tool_name(foo="bar", size=10)]<|tool_call_end|><|im_end|>\n"""

@@ -304,8 +304,8 @@ def test_lfm_function_formatter():


 @pytest.mark.runs_on(["cpu", "mps"])
-def test_lfm_multi_function_formatter():
-    formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm")
+def test_lfm2_multi_function_formatter():
+    formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm2")
     tool_calls = json.dumps([FUNCTION] * 2)
     assert formatter.apply(content=tool_calls) == [
         """<|tool_call_start|>[tool_name(foo="bar", size=10), tool_name(foo="bar", size=10)]<|tool_call_end|>"""

@@ -314,23 +314,23 @@ def test_lfm_multi_function_formatter():


 @pytest.mark.runs_on(["cpu", "mps"])
-def test_lfm_tool_formatter():
-    formatter = ToolFormatter(tool_format="lfm")
+def test_lfm2_tool_formatter():
+    formatter = ToolFormatter(tool_format="lfm2")
     assert formatter.apply(content=json.dumps(TOOLS)) == [
         "List of tools: <|tool_list_start|>" + json.dumps(TOOLS, ensure_ascii=False) + "<|tool_list_end|>"
     ]


 @pytest.mark.runs_on(["cpu", "mps"])
-def test_lfm_tool_extractor():
-    formatter = ToolFormatter(tool_format="lfm")
+def test_lfm2_tool_extractor():
+    formatter = ToolFormatter(tool_format="lfm2")
     result = """<|tool_call_start|>[test_tool(foo="bar", size=10)]<|tool_call_end|>"""
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]


 @pytest.mark.runs_on(["cpu", "mps"])
-def test_lfm_multi_tool_extractor():
-    formatter = ToolFormatter(tool_format="lfm")
+def test_lfm2_multi_tool_extractor():
+    formatter = ToolFormatter(tool_format="lfm2")
     result = """<|tool_call_start|>[test_tool(foo="bar", size=10), another_tool(foo="job", size=2)]<|tool_call_end|>"""
     assert formatter.extract(result) == [
         ("test_tool", """{"foo": "bar", "size": 10}"""),

@@ -339,8 +339,8 @@ def test_lfm_multi_tool_extractor():


 @pytest.mark.runs_on(["cpu", "mps"])
-def test_lfm_tool_extractor_with_nested_dict():
-    formatter = ToolFormatter(tool_format="lfm")
+def test_lfm2_tool_extractor_with_nested_dict():
+    formatter = ToolFormatter(tool_format="lfm2")
     result = """<|tool_call_start|>[search(query="test", options={"limit": 10, "offset": 0})]<|tool_call_end|>"""
     extracted = formatter.extract(result)
     assert len(extracted) == 1

@@ -351,8 +351,8 @@ def test_lfm_tool_extractor_with_nested_dict():


 @pytest.mark.runs_on(["cpu", "mps"])
-def test_lfm_tool_extractor_with_list_arg():
-    formatter = ToolFormatter(tool_format="lfm")
+def test_lfm2_tool_extractor_with_list_arg():
+    formatter = ToolFormatter(tool_format="lfm2")
     result = """<|tool_call_start|>[batch_process(items=[1, 2, 3], enabled=True)]<|tool_call_end|>"""
     extracted = formatter.extract(result)
     assert len(extracted) == 1

@@ -363,17 +363,17 @@ def test_lfm_tool_extractor_with_list_arg():


 @pytest.mark.runs_on(["cpu", "mps"])
-def test_lfm_tool_extractor_no_match():
-    formatter = ToolFormatter(tool_format="lfm")
+def test_lfm2_tool_extractor_no_match():
+    formatter = ToolFormatter(tool_format="lfm2")
     result = "This is a regular response without tool calls."
     extracted = formatter.extract(result)
     assert extracted == result


 @pytest.mark.runs_on(["cpu", "mps"])
-def test_lfm_tool_round_trip():
-    formatter = FunctionFormatter(slots=["{{content}}"], tool_format="lfm")
-    tool_formatter = ToolFormatter(tool_format="lfm")
+def test_lfm2_tool_round_trip():
+    formatter = FunctionFormatter(slots=["{{content}}"], tool_format="lfm2")
+    tool_formatter = ToolFormatter(tool_format="lfm2")
     original = {"name": "my_func", "arguments": {"arg1": "hello", "arg2": 42, "arg3": True}}
     formatted = formatter.apply(content=json.dumps(original))
     extracted = tool_formatter.extract(formatted[0])

@@ -419,3 +419,15 @@ def test_video_llava_plugin():
     ]
     check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
     _check_plugin(**check_inputs)
+
+
+@pytest.mark.runs_on(["cpu", "mps"])
+def test_lfm2_vl_plugin():
+    """Test LFM2.5-VL plugin instantiation."""
+    # Test plugin can be instantiated with correct tokens
+    lfm2_vl_plugin = get_mm_plugin(name="lfm2_vl", image_token="<image>")
+    assert lfm2_vl_plugin is not None
+    assert lfm2_vl_plugin.image_token == "<image>"
+    assert lfm2_vl_plugin.video_token is None
+    assert lfm2_vl_plugin.audio_token is None
+    assert lfm2_vl_plugin.__class__.__name__ == "LFMVLPlugin"