Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-19 12:12:48 +08:00

[data] fix template (#8827)

This commit is contained in:
parent 706b3e5ee7
commit bc54ed8efb
@@ -62,7 +62,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
 
             if self.data_args.train_on_prompt:
                 source_label = source_ids
-            elif self.template.efficient_eos:
+            elif self.template.efficient_eos and turn_idx != 0:
                 source_label = [self.tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
             else:
                 source_label = [IGNORE_INDEX] * source_len
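
The added `turn_idx != 0` guard matters because, with `efficient_eos`, the EOS closing the previous assistant turn is supervised through the first token of the next source span; the first turn has no preceding assistant turn, so its whole source span should stay masked. A minimal sketch of the resulting label masking, assuming `IGNORE_INDEX = -100` and an illustrative `eos_token_id` (the names mirror the diff, but the function itself is not project code):

    IGNORE_INDEX = -100  # the project's masking constant
    eos_token_id = 2     # illustrative value, model-dependent

    def source_labels(source_ids, train_on_prompt, efficient_eos, turn_idx):
        source_len = len(source_ids)
        if train_on_prompt:
            return source_ids  # learn the prompt tokens directly
        elif efficient_eos and turn_idx != 0:
            # supervise the EOS that closes the previous assistant turn
            return [eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
        else:
            # first turn: no preceding assistant EOS, mask everything
            return [IGNORE_INDEX] * source_len

    print(source_labels([5, 6, 7], False, True, turn_idx=0))  # [-100, -100, -100]
    print(source_labels([5, 6, 7], False, True, turn_idx=1))  # [2, -100, -100]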
@@ -1069,7 +1069,9 @@ register_template(
     format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
     format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
     default_system="You are ChatGPT, a large language model trained by OpenAI.",
+    thought_words=("<|channel|>analysis<|message|>", "<|end|><|start|>assistant<|channel|>final<|message|>"),
     efficient_eos=True,
+    template_class=ReasoningTemplate,
 )
 
 
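
The two added arguments wire the channel markers into the template: `thought_words` marks where the analysis (reasoning) channel starts and where the final answer channel resumes, and `template_class=ReasoningTemplate` lets the framework treat that span as reasoning. As a rough sketch of what such delimiters enable (the helper below is illustrative, not LLaMA-Factory's parser):

    THOUGHT_START = "<|channel|>analysis<|message|>"
    THOUGHT_END = "<|end|><|start|>assistant<|channel|>final<|message|>"

    def split_thought(text: str) -> tuple[str, str]:
        """Split raw assistant output into (reasoning, final answer)."""
        if THOUGHT_START in text and THOUGHT_END in text:
            head, _, rest = text.partition(THOUGHT_START)
            thought, _, final = rest.partition(THOUGHT_END)
            return thought.strip(), (head + final).strip()
        return "", text.strip()  # no reasoning span found

    raw = f"{THOUGHT_START}Check units first.{THOUGHT_END}42<|end|>"
    print(split_thought(raw))  # ('Check units first.', '42<|end|>')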
@@ -126,6 +126,7 @@ class QuantizationMethod(str, Enum):
     QUANTO = "quanto"
     EETQ = "eetq"
     HQQ = "hqq"
+    MXFP4 = "mxfp4"
 
 
 class RopeScaling(str, Enum):
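
Since `QuantizationMethod` subclasses `str`, the new `MXFP4` member compares directly with the raw `quant_method` string read from a checkpoint's `quantization_config`, which is what the guard in the next hunk relies on. A quick self-contained illustration (enum body copied from the diff, earlier members elided):

    from enum import Enum

    class QuantizationMethod(str, Enum):
        QUANTO = "quanto"
        EETQ = "eetq"
        HQQ = "hqq"
        MXFP4 = "mxfp4"

    assert QuantizationMethod.MXFP4 == "mxfp4"  # str subclass: equal to the raw config string
    assert "gptq" != QuantizationMethod.MXFP4   # so comparing against plain strings just works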
@@ -90,12 +90,13 @@ def configure_quantization(
         if model_args.quantization_bit is not None:
             logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.")
 
-        if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
-            raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
-
         quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
         quant_method = quantization_config.get("quant_method", "")
 
+        if quant_method != QuantizationMethod.MXFP4 and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
+            # mxfp4 will dequant the model weights
+            raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
+
         if quant_method == QuantizationMethod.GPTQ:
             check_version("gptqmodel>=2.0.0", mandatory=True)
             quantization_config.pop("disable_exllama", None)  # remove deprecated args
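
The incompatibility check moves below the `quant_method` lookup so MXFP4 checkpoints can be exempted: per the new comment, mxfp4 weights are dequantized at load time, so sharded backends such as DeepSpeed ZeRO-3 or FSDP can partition them, while other PTQ formats still raise. A condensed sketch of the control flow (function and flag names are illustrative, not the project's API):

    def guard_sharded_backends(quant_method: str, sharded: bool) -> None:
        # mxfp4 will dequant the model weights, so sharding is safe for it
        if quant_method != "mxfp4" and sharded:
            raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")

    guard_sharded_backends("mxfp4", sharded=True)   # passes
    # guard_sharded_backends("gptq", sharded=True)  # raises ValueError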