From 45f0437a14f10cbbf8f0d61877e741f3c9888ca5 Mon Sep 17 00:00:00 2001 From: Yinlei Sun Date: Tue, 18 Nov 2025 13:44:08 +0800 Subject: [PATCH] [v1] Add support for ShareGPT format. (#9486) --- .../v1/plugins/data_plugins/converter.py | 92 +++++++++++++++++- .../plugins/data_plugins/test_converter.py | 93 +++++++++++++++++++ 2 files changed, 184 insertions(+), 1 deletion(-) diff --git a/src/llamafactory/v1/plugins/data_plugins/converter.py b/src/llamafactory/v1/plugins/data_plugins/converter.py index 37125926..777b710e 100644 --- a/src/llamafactory/v1/plugins/data_plugins/converter.py +++ b/src/llamafactory/v1/plugins/data_plugins/converter.py @@ -15,11 +15,15 @@ from typing import Callable, TypedDict -from typing_extensions import NotRequired +from typing_extensions import NotRequired, Required +from ....extras import logging from ...extras.types import DPOSample, Sample, SFTSample +logger = logging.get_logger(__name__) + + class AlpacaSample(TypedDict, total=False): system: NotRequired[str] instruction: NotRequired[str] @@ -27,6 +31,21 @@ class AlpacaSample(TypedDict, total=False): output: NotRequired[str] +ShareGPTMessage = TypedDict( + "ShareGPTMessage", + { + "from": Required[str], # Role of the message sender (e.g., "human", "gpt", "system") + "value": Required[str], # Content of the message + }, +) + + +class ShareGPTSample(TypedDict, total=False): + """Type definition for raw ShareGPT sample.""" + + conversations: Required[list[ShareGPTMessage]] + + class PairSample(TypedDict, total=False): prompt: NotRequired[str] chosen: NotRequired[list[dict]] @@ -48,6 +67,20 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample: {"role": "system", "content": [{"type": "text", "value": raw_sample["system"]}], "loss_weight": 0.0} ) + if "history" in raw_sample: + for idx, item in enumerate(raw_sample["history"]): + if len(item) != 2: + logger.warning_rank0( + f"Warning: History item at index {idx} has invalid length (expected 2, got {len(item)}). Skipping." + ) + continue + + old_prompt, old_response = item + messages.append({"role": "user", "content": [{"type": "text", "value": old_prompt}], "loss_weight": 0.0}) + messages.append( + {"role": "assistant", "content": [{"type": "text", "value": old_response}], "loss_weight": 1.0} + ) + if "instruction" in raw_sample or "input" in raw_sample: messages.append( { @@ -67,6 +100,62 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample: return {"messages": messages} +def sharegpt_converter(raw_sample: ShareGPTSample) -> SFTSample: + """Converts a raw ShareGPT sample into a formatted SFT (Supervised Fine-Tuning) sample. + + Retains only SFT-relevant scenarios and removes parity checks. + + Args: + raw_sample (ShareGPTSample): A raw sample in ShareGPT format. + + Returns: + dict: A dictionary containing the formatted 'messages' list for SFT training. + Returns an empty list if the input data is invalid. + """ + tag_mapping = { + "human": "user", + "gpt": "assistant", + "observation": "observation", + "function_call": "function", + } + messages = raw_sample.get("conversations", []) + aligned_messages = [] + system_content = "" + + # Extract system message if present (typically the first message) + if messages and messages[0]["from"] == "system": + system_content = messages[0]["value"] + messages = messages[1:] + + if system_content: + aligned_messages.append( + {"role": "system", "content": [{"type": "text", "value": system_content}], "loss_weight": 0.0} + ) + + has_invalid_role = False + for message in messages: + sender = message["from"] + # validate sender is in supported tags + if sender not in tag_mapping: + logger.warning_rank0(f"Unsupported role tag '{sender}' in message: {message}") + has_invalid_role = True + break + + aligned_messages.append( + { + "role": tag_mapping[sender], + "content": [{"type": "text", "value": message["value"]}], + "loss_weight": 0.0 if sender in ("human", "observation") else 1.0, + } + ) + + if has_invalid_role: + logger.warning_rank0("Skipping invalid example due to unsupported role tags.") + return {"messages": []} + + return {"messages": aligned_messages} + + def pair_converter(raw_sample: PairSample) -> DPOSample: """Convert Pair sample to standard DPO sample. @@ -148,6 +237,7 @@ def pair_converter(raw_sample: PairSample) -> DPOSample: CONVERTERS = { "alpaca": alpaca_converter, "pair": pair_converter, + "sharegpt": sharegpt_converter, } diff --git a/tests_v1/plugins/data_plugins/test_converter.py b/tests_v1/plugins/data_plugins/test_converter.py index 03f8e17d..e64b8c25 100644 --- a/tests_v1/plugins/data_plugins/test_converter.py +++ b/tests_v1/plugins/data_plugins/test_converter.py @@ -19,6 +19,7 @@ from datasets import load_dataset from llamafactory.v1.config.data_args import DataArguments from llamafactory.v1.core.data_engine import DataEngine +from llamafactory.v1.plugins.data_plugins.converter import get_converter @pytest.mark.parametrize("num_samples", [16]) @@ -48,6 +49,96 @@ def test_alpaca_converter(num_samples: int): assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data} +def test_sharegpt_converter_invalid(): + example = { + "conversations": [ + { + "from": "system", + "value": "Processes historical market data to generate trading signals " + "based on specified technical indicators.", + }, + { + "from": "human", + "value": "I possess a detailed dataset, 'Historical_Market_Data.csv'. " + "Could you proceed with these function calls to assist me with the task?", + }, + { + "from": "gpt", + "value": "```tool_call\n{'arguments': '{\"data_file\": \"Historical_Market_Data.csv\"]}', " + "'name': 'backtest_trading_signals'}```\n", + }, + { + "from": "tool", + "value": '\n{"analysis": {"RSI_signals": [{"date": "2025-01-10", ' + '"symbol": "AAPL", "signal": "Buy"}]}}}\n\n', + }, + ] + } + dataset_converter = get_converter("sharegpt") + assert dataset_converter(example) == {"messages": []} + + +def test_sharegpt_converter_valid(): + example = { + "conversations": [ + { + "from": "system", + "value": "Processes historical market data to generate trading signals based on " + "specified technical indicators.", + }, + { + "from": "human", + "value": "I possess a detailed dataset, 'Historical_Market_Data.csv'. " + "Could you proceed with these function calls to assist me with the task?", + }, + { + "from": "gpt", + "value": "```tool_call\n{'arguments': '{\"data_file\": \"Historical_Market_Data.csv\"]}', " + "'name': 'backtest_trading_signals'}```\n", + }, + ] + } + dataset_converter = get_converter("sharegpt") + expected_data = { + "messages": [ + { + "content": [ + { + "type": "text", + "value": "Processes historical market data to generate trading signals based on " + "specified technical indicators.", + } + ], + "loss_weight": 0.0, + "role": "system", + }, + { + "content": [ + { + "type": "text", + "value": "I possess a detailed dataset, 'Historical_Market_Data.csv'. " + "Could you proceed with these function calls to assist me with the task?", + } + ], + "loss_weight": 0.0, + "role": "user", + }, + { + "content": [ + { + "type": "text", + "value": "```tool_call\n{'arguments': '{\"data_file\": \"Historical_Market_Data.csv\"]}', " + "'name': 'backtest_trading_signals'}```\n", + } + ], + "loss_weight": 1.0, + "role": "assistant", + }, + ] + } + assert dataset_converter(example) == expected_data + + @pytest.mark.parametrize("num_samples", [16]) def test_pair_converter(num_samples: int): data_args = DataArguments(dataset="frozenleaves/tiny-dpo/dataset_info.yaml") @@ -98,4 +189,6 @@ def test_pair_converter(num_samples: int): if __name__ == "__main__": test_alpaca_converter(1) + test_sharegpt_converter_invalid() + test_sharegpt_converter_valid() test_pair_converter(1)