remove visual_inputs, fix qlora

This commit is contained in:
hiyouga
2024-08-31 00:24:51 +08:00
parent a244f143f4
commit a025c3df61
22 changed files with 112 additions and 106 deletions

View File

@@ -61,7 +61,6 @@ def run_sft(
# Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
-training_args.remove_unused_columns = False if model_args.visual_inputs else training_args.remove_unused_columns
# Metric utils
metric_module = {}

View File

@@ -132,7 +132,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
if model_args.export_hub_model_id is not None:
tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
-if model_args.visual_inputs and processor is not None:
+if processor is not None:
getattr(processor, "image_processor").save_pretrained(model_args.export_dir)
if model_args.export_hub_model_id is not None:
getattr(processor, "image_processor").push_to_hub(