Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-24 06:42:52 +08:00)

fix some

Former-commit-id: 15bbcdf8d3265f4154d3937719da5e54a5963355
Parent: cc6a6f698f
Commit: 994049380d
@@ -281,7 +281,7 @@ class CpmOPlugin(BasePlugin):
         mm_inputs = self._get_mm_inputs(images, videos, processor)
 
         pattern = "(<image>./</image>)"
-        images, image_sizes, _ = mm_inputs["pixel_values"], mm_inputs["image_sizes"], mm_inputs["tgt_sizes"]
+        images, image_sizes = mm_inputs["pixel_values"], mm_inputs["image_sizes"]
 
         image_index = 0
         for index, message in enumerate(messages):
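Note: the change above narrows the unpacking to the two keys the plugin actually reads, dropping the unused tgt_sizes entry. A minimal sketch of the difference, with a placeholder dict standing in for the real processor output:

    # Hedged sketch: mm_inputs below is a placeholder, not real processor output.
    mm_inputs = {
        "pixel_values": [["patch_0", "patch_1"]],
        "image_sizes": [[(448, 448)]],
    }

    # Old form raises KeyError once the processor stops returning "tgt_sizes":
    # images, image_sizes, _ = mm_inputs["pixel_values"], mm_inputs["image_sizes"], mm_inputs["tgt_sizes"]

    # New form touches only the keys used afterwards:
    images, image_sizes = mm_inputs["pixel_values"], mm_inputs["image_sizes"]
    print(images, image_sizes)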
@@ -334,6 +334,7 @@ class CpmOPlugin(BasePlugin):
                 new_images.append(images[idx : idx + valid_image_nums])
                 idx += valid_image_nums
             images = new_images
+
         image_inputs = image_processor(
             images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
         )
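For context, the loop shown regroups a flat image list into per-sample sublists before the processor call. A hedged, self-contained sketch; the stub processor and image counts are illustrative, not the real trust_remote_code processor:

    # Hedged sketch: StubImageProcessor only mimics the call shape used above.
    class StubImageProcessor:
        max_slice_nums = 9

        def __call__(self, images, do_pad=True, max_slice_nums=9, return_tensors="pt"):
            return {"pixel_values": images}  # dummy batch; the real one returns tensors

    images = ["img_a", "img_b", "img_c"]
    valid_image_nums_ls = [2, 1]  # sample 0 owns 2 images, sample 1 owns 1

    new_images, idx = [], 0
    for valid_image_nums in valid_image_nums_ls:
        new_images.append(images[idx : idx + valid_image_nums])
        idx += valid_image_nums
    images = new_images  # [["img_a", "img_b"], ["img_c"]]

    image_processor = StubImageProcessor()
    image_inputs = image_processor(
        images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
    )
    print(image_inputs)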
@@ -100,7 +100,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
         patch_processor(processor, config, tokenizer, model_args)
     except Exception as e:
-        logger.info(f"Processor was not found: {e}.")
+        logger.debug(f"Processor was not found: {e}.")
         processor = None
 
     # Avoid load tokenizer, see:
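The hunk above only lowers the log level: a missing processor is the normal case for text-only models, so it is reported at DEBUG rather than INFO. A runnable sketch of the same fallback pattern, with the failing loader simulated:

    import logging

    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger("llamafactory.loader")

    def load_processor_or_none(path: str):
        try:
            # Simulates AutoProcessor.from_pretrained(...) failing on a text-only model.
            raise OSError(f"{path} does not expose a processor config")
        except Exception as e:
            logger.debug(f"Processor was not found: {e}.")
            return None

    assert load_processor_or_none("some/text-only-model") is None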
@@ -138,13 +138,12 @@ def patch_model(
     add_valuehead: bool,
 ) -> None:
     gen_config = model.generation_config  # check and fix generation config
-    if gen_config is not None:
-        if not gen_config.do_sample and (
-            (gen_config.temperature is not None and gen_config.temperature != 1.0)
-            or (gen_config.top_p is not None and gen_config.top_p != 1.0)
-            or (gen_config.typical_p is not None and gen_config.typical_p != 1.0)
-        ):
-            gen_config.do_sample = True
+    if not gen_config.do_sample and (
+        (gen_config.temperature is not None and gen_config.temperature != 1.0)
+        or (gen_config.top_p is not None and gen_config.top_p != 1.0)
+        or (gen_config.typical_p is not None and gen_config.typical_p != 1.0)
+    ):
+        gen_config.do_sample = True
 
     if "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)
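The hunk above drops the `if gen_config is not None:` guard while keeping the consistency fix: if any sampling knob deviates from its default while do_sample is False, the flag is switched on. A sketch of that fix, with a dataclass standing in for transformers' GenerationConfig:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class FakeGenerationConfig:  # stand-in, not the transformers class
        do_sample: bool = False
        temperature: Optional[float] = None
        top_p: Optional[float] = None
        typical_p: Optional[float] = None

    def fix_generation_config(gen_config: FakeGenerationConfig) -> None:
        if not gen_config.do_sample and (
            (gen_config.temperature is not None and gen_config.temperature != 1.0)
            or (gen_config.top_p is not None and gen_config.top_p != 1.0)
            or (gen_config.typical_p is not None and gen_config.typical_p != 1.0)
        ):
            gen_config.do_sample = True

    cfg = FakeGenerationConfig(temperature=0.7)
    fix_generation_config(cfg)
    assert cfg.do_sample is True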
@@ -122,6 +122,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             labels = inputs.pop("labels", None)
         else:
             labels = inputs.get("labels")
+
         loss, generated_tokens, _ = super().prediction_step(
             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
         )
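Context for the hunk above: when generating, the labels are popped before the inputs reach generate(), while a loss-only pass keeps them in place. A small sketch of that split; the helper name is invented:

    def extract_labels(inputs: dict, generating: bool):
        if generating:
            labels = inputs.pop("labels", None)  # generate() must not see labels
        else:
            labels = inputs.get("labels")  # the loss pass still needs them
        return labels

    batch = {"input_ids": [1, 2, 3], "labels": [4, 5, 6]}
    labels = extract_labels(batch, generating=True)
    assert labels == [4, 5, 6] and "labels" not in batch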