Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 22:32:54 +08:00)
parent: d7130ec635
commit: a9ce54d143
@@ -75,18 +75,23 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
     get_template_and_fix_tokenizer(tokenizer, data_args.template)
     model = load_model(tokenizer, model_args, finetuning_args)  # must after fixing tokenizer to resize vocab

-    if getattr(model, "quantization_method", None) and model_args.adapter_name_or_path is not None:
+    if getattr(model, "quantization_method", None) is not None and model_args.adapter_name_or_path is not None:
         raise ValueError("Cannot merge adapters to a quantized model.")

     if not isinstance(model, PreTrainedModel):
         raise ValueError("The model is not a `PreTrainedModel`, export aborted.")

-    if getattr(model, "quantization_method", None) is None:  # cannot convert dtype of a quantized model
-        output_dtype = getattr(model.config, "torch_dtype", torch.float16)
-        setattr(model.config, "torch_dtype", output_dtype)
-        model = model.to(output_dtype)
-    else:
-        setattr(model.config, "torch_dtype", torch.float16)
+    if getattr(model, "quantization_method", None) is not None:  # quantized model adopts float16 type
+        setattr(model.config, "torch_dtype", torch.float16)
+    else:
+        if model_args.infer_dtype == "auto":
+            output_dtype = getattr(model.config, "torch_dtype", torch.float16)
+        else:
+            output_dtype = getattr(torch, model_args.infer_dtype)
+
+        setattr(model.config, "torch_dtype", output_dtype)
+        model = model.to(output_dtype)
+        logger.info("Convert model dtype to: {}.".format(output_dtype))

     model.save_pretrained(
         save_directory=model_args.export_dir,
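The hunk above does two things: the first guard now checks quantization_method explicitly with "is not None" rather than relying on truthiness, and the dtype-conversion block is keyed off the new infer_dtype option. Quantized models stay pinned to float16, while full-precision models either keep the dtype recorded in their config ("auto") or are cast to an explicitly named torch dtype. Below is a minimal sketch of that resolution rule; the helper name resolve_output_dtype is ours for illustration, not the repository's:

from typing import Optional

import torch


def resolve_output_dtype(config_dtype: Optional[torch.dtype], infer_dtype: str) -> torch.dtype:
    # Hypothetical helper mirroring the new else-branch in export_model.
    if infer_dtype == "auto":
        # "auto" keeps whatever dtype the model config records, defaulting to float16.
        return config_dtype if config_dtype is not None else torch.float16
    # An explicit name such as "bfloat16" is looked up on the torch module.
    return getattr(torch, infer_dtype)


assert resolve_output_dtype(torch.bfloat16, "auto") is torch.bfloat16
assert resolve_output_dtype(None, "auto") is torch.float16
assert resolve_output_dtype(torch.float32, "bfloat16") is torch.bfloat16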
@@ -48,7 +48,7 @@ def save_model(
     template: str,
     visual_inputs: bool,
     export_size: int,
-    export_quantization_bit: int,
+    export_quantization_bit: str,
     export_quantization_dataset: str,
     export_device: str,
     export_legacy_format: bool,
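The only change here widens export_quantization_bit from int to str, which matches what a Web UI form control actually delivers: a dropdown selection arrives as a string (for example "none" or "8"; these choice values are an assumption, not shown in the diff) and must be parsed before export. A hedged sketch of that parsing, with parse_quantization_bit as an illustrative name rather than code from the repository:

from typing import Optional


def parse_quantization_bit(choice: str) -> Optional[int]:
    # Only numeric dropdown choices map to a bit width; anything else means no quantization.
    return int(choice) if choice.isdigit() else None


assert parse_quantization_bit("8") == 8
assert parse_quantization_bit("none") is None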