Former-commit-id: 15bbcdf8d3265f4154d3937719da5e54a5963355
This commit is contained in:
fzc8578 2025-01-10 20:55:52 +08:00
parent cc6a6f698f
commit 994049380d
4 changed files with 10 additions and 9 deletions

View File

@ -281,7 +281,7 @@ class CpmOPlugin(BasePlugin):
mm_inputs = self._get_mm_inputs(images, videos, processor) mm_inputs = self._get_mm_inputs(images, videos, processor)
pattern = "(<image>./</image>)" pattern = "(<image>./</image>)"
images, image_sizes, _ = mm_inputs["pixel_values"], mm_inputs["image_sizes"], mm_inputs["tgt_sizes"] images, image_sizes = mm_inputs["pixel_values"], mm_inputs["image_sizes"]
image_index = 0 image_index = 0
for index, message in enumerate(messages): for index, message in enumerate(messages):
@ -334,6 +334,7 @@ class CpmOPlugin(BasePlugin):
new_images.append(images[idx : idx + valid_image_nums]) new_images.append(images[idx : idx + valid_image_nums])
idx += valid_image_nums idx += valid_image_nums
images = new_images images = new_images
image_inputs = image_processor( image_inputs = image_processor(
images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt" images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
) )

View File

@ -100,7 +100,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs) processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_processor(processor, config, tokenizer, model_args) patch_processor(processor, config, tokenizer, model_args)
except Exception as e: except Exception as e:
logger.info(f"Processor was not found: {e}.") logger.debug(f"Processor was not found: {e}.")
processor = None processor = None
# Avoid load tokenizer, see: # Avoid load tokenizer, see:

View File

@ -138,13 +138,12 @@ def patch_model(
add_valuehead: bool, add_valuehead: bool,
) -> None: ) -> None:
gen_config = model.generation_config # check and fix generation config gen_config = model.generation_config # check and fix generation config
if gen_config is not None: if not gen_config.do_sample and (
if not gen_config.do_sample and ( (gen_config.temperature is not None and gen_config.temperature != 1.0)
(gen_config.temperature is not None and gen_config.temperature != 1.0) or (gen_config.top_p is not None and gen_config.top_p != 1.0)
or (gen_config.top_p is not None and gen_config.top_p != 1.0) or (gen_config.typical_p is not None and gen_config.typical_p != 1.0)
or (gen_config.typical_p is not None and gen_config.typical_p != 1.0) ):
): gen_config.do_sample = True
gen_config.do_sample = True
if "GenerationMixin" not in str(model.generate.__func__): if "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model) model.generate = MethodType(PreTrainedModel.generate, model)

View File

@ -122,6 +122,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
labels = inputs.pop("labels", None) labels = inputs.pop("labels", None)
else: else:
labels = inputs.get("labels") labels = inputs.get("labels")
loss, generated_tokens, _ = super().prediction_step( loss, generated_tokens, _ = super().prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
) )