diff --git a/src/llamafactory/data/processor/supervised.py b/src/llamafactory/data/processor/supervised.py
index 26e50d93..b5aba11b 100644
--- a/src/llamafactory/data/processor/supervised.py
+++ b/src/llamafactory/data/processor/supervised.py
@@ -62,7 +62,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
 
             if self.data_args.train_on_prompt:
                 source_label = source_ids
-            elif self.template.efficient_eos:
+            elif self.template.efficient_eos and turn_idx != 0:
                 source_label = [self.tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
             else:
                 source_label = [IGNORE_INDEX] * source_len
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 70f5e435..e28d462b 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -1069,7 +1069,9 @@ register_template(
     format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
     format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
     default_system="You are ChatGPT, a large language model trained by OpenAI.",
+    thought_words=("<|channel|>analysis<|message|>", "<|end|><|start|>assistant<|channel|>final<|message|>"),
     efficient_eos=True,
+    template_class=ReasoningTemplate,
 )
 
 
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index 077a3497..14ae8c8c 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -126,6 +126,7 @@ class QuantizationMethod(str, Enum):
     QUANTO = "quanto"
     EETQ = "eetq"
     HQQ = "hqq"
+    MXFP4 = "mxfp4"
 
 
 class RopeScaling(str, Enum):
diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py
index d3160f69..8b227b7f 100644
--- a/src/llamafactory/model/model_utils/quantization.py
+++ b/src/llamafactory/model/model_utils/quantization.py
@@ -90,12 +90,13 @@ def configure_quantization(
         if model_args.quantization_bit is not None:
             logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.")
 
-        if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
-            raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
-
         quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
         quant_method = quantization_config.get("quant_method", "")
 
+        if quant_method != QuantizationMethod.MXFP4 and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
+            # mxfp4 will dequant the model weights
+            raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
+
         if quant_method == QuantizationMethod.GPTQ:
             check_version("gptqmodel>=2.0.0", mandatory=True)
             quantization_config.pop("disable_exllama", None)  # remove deprecated args
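
Reviewer note on the first hunk: as I read it, with `efficient_eos` the EOS that closes turn N's reply is tokenized at the head of turn N+1's source chunk, so that leading EOS should stay supervised on every turn except the first, whose source begins with the prompt itself. A minimal standalone sketch of the patched branch logic (not LLaMA-Factory's actual class; `EOS_TOKEN_ID` is an illustrative value):

```python
IGNORE_INDEX = -100  # same sentinel the hunk uses
EOS_TOKEN_ID = 2     # hypothetical eos id for illustration

def source_labels(source_ids, turn_idx, efficient_eos, train_on_prompt=False):
    """Mirror of the patched branch in SupervisedDatasetProcessor."""
    if train_on_prompt:
        return list(source_ids)
    elif efficient_eos and turn_idx != 0:
        # non-first turns start with the EOS closing the previous reply:
        # keep it supervised, mask the rest of the prompt tokens
        return [EOS_TOKEN_ID] + [IGNORE_INDEX] * (len(source_ids) - 1)
    else:
        return [IGNORE_INDEX] * len(source_ids)

# turn 0: plain prompt, fully masked; turn 1: begins with turn 0's EOS
print(source_labels([5, 6, 7], turn_idx=0, efficient_eos=True))  # [-100, -100, -100]
print(source_labels([2, 8, 9], turn_idx=1, efficient_eos=True))  # [2, -100, -100]
```

Before this patch, turn 0 also got the `[eos, IGNORE, ...]` labels, mislabeling its first prompt token as EOS.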
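
The template.py hunk wires the gpt-oss analysis/final channels in as `thought_words` and switches the template to `ReasoningTemplate`. A rough sketch of how such delimiters allow reasoning content to be separated from the final answer (`split_thought` is a hypothetical helper, not the actual `ReasoningTemplate` method; the delimiters are taken verbatim from the hunk):

```python
THOUGHT_WORDS = (
    "<|channel|>analysis<|message|>",
    "<|end|><|start|>assistant<|channel|>final<|message|>",
)

def split_thought(content: str) -> tuple[str, str]:
    """Split an assistant message into (reasoning, final answer)."""
    start, end = THOUGHT_WORDS
    if content.startswith(start) and end in content:
        thought, final = content[len(start):].split(end, 1)
        return thought, final
    return "", content  # no analysis channel present

raw = "<|channel|>analysis<|message|>check units<|end|><|start|>assistant<|channel|>final<|message|>42"
print(split_thought(raw))  # ('check units', '42')
```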
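
Finally, a condensed sketch of the reordered guard in quantization.py (simplified signature; the real function takes `config`/`model_args`): the ZeRO-3/FSDP check now depends on `quant_method`, so it must move below the `quantization_config` lookup, and MXFP4 checkpoints are exempted because, per the new comment, mxfp4 dequantizes the model weights:

```python
from enum import Enum

class QuantizationMethod(str, Enum):  # trimmed to the members this sketch needs
    GPTQ = "gptq"
    MXFP4 = "mxfp4"

def check_ptq_compat(quantization_config: dict, zero3_or_fsdp: bool) -> None:
    quant_method = quantization_config.get("quant_method", "")
    if quant_method != QuantizationMethod.MXFP4 and zero3_or_fsdp:
        # mxfp4 weights are dequantized on load, so sharded training is fine
        raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")

check_ptq_compat({"quant_method": "mxfp4"}, zero3_or_fsdp=True)  # passes silently
try:
    check_ptq_compat({"quant_method": "gptq"}, zero3_or_fsdp=True)
except ValueError as err:
    print(err)  # DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.
```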