Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-23 06:12:50 +08:00
remove some unnecessary if conditions
Former-commit-id: de06e2678e2168586614242f65939c5772e78774
This commit is contained in:
parent b76116bb6c
commit 66e473d519
@@ -168,7 +168,7 @@ class HuggingfaceEngine(BaseEngine):
         for key, value in mm_inputs.items():
             value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
             gen_kwargs[key] = value.to(model.device)

         return gen_kwargs, prompt_length

     @staticmethod
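For orientation, the loop in this hunk normalizes whatever the multimodal processor returned before generation: values that are not already tensors are converted, and everything is moved onto the model's device. A minimal self-contained sketch of the same pattern (the function name and example inputs are illustrative, not LLaMA-Factory API):

```python
import torch

def move_mm_inputs_to_device(mm_inputs: dict, device: torch.device) -> dict:
    """Convert processor outputs to tensors and place them on the model device.

    Hypothetical helper mirroring the hunk above; `mm_inputs` is assumed to be
    a dict of tensors or plain nested lists (e.g. pixel_values, image_sizes).
    """
    gen_kwargs = {}
    for key, value in mm_inputs.items():
        # Processors may return plain Python lists; generation kwargs need tensors.
        value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
        gen_kwargs[key] = value.to(device)
    return gen_kwargs
```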
@@ -325,6 +325,14 @@ class PaliGemmaPlugin(BasePlugin):
         return mm_inputs

 class PixtralPlugin(BasePlugin):
+    # @override
+    # def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
+    #     image = super()._preprocess_image(image, **kwargs)
+    #     UP_SIZE = (512,512)
+    #     image = image.resize(UP_SIZE, resample=Image.NEAREST)
+
+    #     return image
+
     @override
     def process_messages(
         self,
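The commented-out `_preprocess_image` above sketches a fixed nearest-neighbor upscale to 512×512. A runnable standalone version of that idea (the helper name `upscale_image` is hypothetical, and the override is left commented in the diff, so this is a sketch of the intent rather than active plugin behavior):

```python
from PIL import Image

def upscale_image(image: Image.Image, size: tuple = (512, 512)) -> Image.Image:
    """Nearest-neighbor upscale, mirroring the commented-out sketch above.

    NEAREST avoids interpolating new pixel values during the upscale, and a
    fixed size makes image_size // patch_size yield a predictable token grid.
    """
    return image.resize(size, resample=Image.NEAREST)
```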
@@ -340,15 +348,22 @@ class PixtralPlugin(BasePlugin):

         self._validate_input(images, videos)
         num_image_tokens = 0
-        image_input_sizes = self._get_mm_inputs(images, videos, processor)["image_sizes"]
+        img_kwargs = self._get_mm_inputs(images, videos, processor)
+        image_input_sizes = None
+
+        if img_kwargs.get("pixel_values") is not None:
+            image_input_sizes = img_kwargs["image_sizes"]
+
         messages = deepcopy(messages)
-        print(image_input_sizes[0], messages)
         for message in messages:
             content = message["content"]
             img_id = 0
             while IMAGE_PLACEHOLDER in content:
-                # only support one image for one time?
-                image_size = image_input_sizes[0][0]
+                if image_input_sizes is None:
+                    raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+
+                image_size = image_input_sizes[0][img_id]
                 height, width = image_size
                 num_height_tokens = height // patch_size
                 num_width_tokens = width // patch_size
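To make the token accounting concrete: together with the assembly lines in the next hunk, each `IMAGE_PLACEHOLDER` is expanded into a grid of image tokens whose shape is the image size divided by the patch size. A hedged sketch, assuming Pixtral's special tokens `[IMG]`, `[IMG_BREAK]`, `[IMG_END]` and its 16-pixel patches (the row-construction line is reconstructed from context, since it sits in unchanged code between the two hunks):

```python
def expand_image_placeholder(height: int, width: int, patch_size: int = 16,
                             image_token: str = "[IMG]",
                             image_break_token: str = "[IMG_BREAK]",
                             image_end_token: str = "[IMG_END]") -> str:
    """Build the replacement string for one IMAGE_PLACEHOLDER."""
    num_height_tokens = height // patch_size
    num_width_tokens = width // patch_size
    # One row = num_width_tokens image tokens followed by a break token.
    replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
    replace_tokens = [item for sublist in replace_tokens for item in sublist]
    replace_tokens[-1] = image_end_token  # the grid ends with [IMG_END], not [IMG_BREAK]
    return "".join(replace_tokens)

# A 512x512 image with 16-pixel patches yields a 32x32 grid:
s = expand_image_placeholder(512, 512)
assert s.count("[IMG]") == 32 * 32        # 1024 patch tokens
assert s.count("[IMG_BREAK]") == 31       # one per row, except the last
assert s.count("[IMG_END]") == 1          # closes the grid
```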
@@ -359,7 +374,7 @@ class PixtralPlugin(BasePlugin):
                 replace_tokens = [item for sublist in replace_tokens for item in sublist]
                 replace_tokens[-1] = image_end_token
                 replace_str = "".join(replace_tokens)
-                content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
+                content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)

                 img_id += 1
                 num_image_tokens += 1
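This one-line change fixes a real bug: Python strings are immutable, so `str.replace` returns a new string and the old call silently discarded its result. The placeholder was never substituted, which would make `while IMAGE_PLACEHOLDER in content` spin forever. A tiny demonstration, using `<image>` as a stand-in placeholder:

```python
content = "Describe <image> please."
content.replace("<image>", "[IMG]...", 1)   # result discarded; content is unchanged
assert "<image>" in content                 # still true -> the while loop never exits

content = content.replace("<image>", "[IMG]...", 1)  # the fix: rebind the result
assert "<image>" not in content
```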
@@ -383,7 +398,16 @@ class PixtralPlugin(BasePlugin):
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
-        return self._get_mm_inputs(images, videos, processor)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        if mm_inputs.get("image_sizes"):
+            del mm_inputs["image_sizes"]
+        # TODO: fix this type error
+        # if isinstance(mm_inputs.get("pixel_values"), list):  # List[List[torch.Tensor]] -> [B, C, W, H]
+        # batch size 1 per GPU is recommended; otherwise BatchEncoding raises an error
+        mm_inputs["pixel_values"] = mm_inputs.get("pixel_values")[0][0].unsqueeze(0)
+        # mm_inputs["pixel_values"] = mm_inputs.get("pixel_values")
+
+        return mm_inputs


 class Qwen2vlPlugin(BasePlugin):
     @override
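The new `get_mm_inputs` body also works around a batching quirk: the Pixtral processor can return `pixel_values` as a nested list of per-image tensors rather than one batched tensor, and the code keeps only the first image of the first sample, re-adding a batch dimension (`image_sizes` is dropped, presumably because downstream model calls do not accept that key). A minimal sketch of the unwrapping, with illustrative shapes:

```python
import torch

# Processor-style output: one inner list of [C, H, W] image tensors per sample.
pixel_values = [[torch.randn(3, 512, 512)]]

# Keep the first image of the first sample and restore a batch dimension.
batched = pixel_values[0][0].unsqueeze(0)
print(batched.shape)  # torch.Size([1, 3, 512, 512])
```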
@@ -917,16 +917,6 @@ register_model_group(
     template="mistral",
 )

-
-register_model_group(
-    models={
-        "Pixtral-12B-2409": {
-            DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
-            DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
-        }
-    },
-    template="mistral"
-)

 register_model_group(
     models={
@@ -1067,6 +1057,18 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Pixtral-12B-2409": {
+            DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
+        }
+    },
+    template="mistral",
+    vision=True
+)
+
+
 register_model_group(
     models={
         "Qwen-1.8B": {