Mirror of https://github.com/hiyouga/LLaMA-Factory.git
refactor data preprocessing, fix mllm rlhf
Former-commit-id: 3a023bca2a
@@ -61,7 +61,7 @@ class HuggingfaceEngine(BaseEngine):
             and image is not None
             and not hasattr(processor, "image_seq_length")
             and IMAGE_TOKEN not in messages[0]["content"]
-        ):  # llava case
+        ):  # llava-like models
             messages[0]["content"] = IMAGE_TOKEN + messages[0]["content"]
 
         paired_messages = messages + [{"role": "assistant", "content": ""}]
@@ -74,7 +74,7 @@ class HuggingfaceEngine(BaseEngine):
             image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
             batch_feature = image_processor(image, return_tensors="pt")
             pixel_values = batch_feature.to(model.device)["pixel_values"]  # shape (B, C, H, W)
-            if hasattr(processor, "image_seq_length"):  # paligemma case
+            if hasattr(processor, "image_seq_length"):  # paligemma models
                 image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
                 prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
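
For context, a minimal self-contained sketch of the paligemma-style branch above. This is not code from the commit; the helper name prefix_paligemma_image_ids is hypothetical. The idea shown is that when the processor exposes image_seq_length, the image placeholder is injected as repeated token ids in front of the already-tokenized prompt.

from typing import List


def prefix_paligemma_image_ids(prompt_ids: List[int], image_token_id: int, image_seq_length: int) -> List[int]:
    # Repeat the image token id once per image slot and prepend it to the prompt,
    # mirroring the `hasattr(processor, "image_seq_length")` branch in the diff.
    return [image_token_id] * image_seq_length + prompt_ids


# Example: three image slots (id 99) prepended to a tokenized prompt.
assert prefix_paligemma_image_ids([5, 6, 7], 99, 3) == [99, 99, 99, 5, 6, 7]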
@@ -98,7 +98,7 @@ class VllmEngine(BaseEngine):
             and image is not None
             and not hasattr(self.processor, "image_seq_length")
             and IMAGE_TOKEN not in messages[0]["content"]
-        ):  # llava case
+        ):  # llava-like models
             messages[0]["content"] = IMAGE_TOKEN * self.image_feature_size + messages[0]["content"]
 
         paired_messages = messages + [{"role": "assistant", "content": ""}]
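
The llava-like branch works one level higher, on the message text rather than on token ids. Below is a hedged sketch of that check, simplified from the diff: the IMAGE_TOKEN value, the image_feature_size default, and the image-presence check being elided are assumptions of this illustration, not facts taken from the repo.

from typing import Any, Dict, List

IMAGE_TOKEN = "<image>"  # assumed placeholder string


def maybe_prepend_image_token(
    messages: List[Dict[str, str]],
    processor: Any,
    image_feature_size: int = 1,  # >1 corresponds to the vLLM path, 1 to the HF path
) -> List[Dict[str, str]]:
    # llava-like models: the processor has no `image_seq_length` attribute, so the
    # placeholder is prepended to the first message if it is not already present.
    # paligemma-like processors are skipped here and handled at the token-id level.
    # (The `image is not None` guard from the original condition is omitted for brevity.)
    if (
        processor is not None
        and not hasattr(processor, "image_seq_length")
        and IMAGE_TOKEN not in messages[0]["content"]
    ):
        messages[0]["content"] = IMAGE_TOKEN * image_feature_size + messages[0]["content"]
    return messages


# Example with a stand-in processor that lacks `image_seq_length`.
class _DummyProcessor:
    pass


msgs = [{"role": "user", "content": "Describe the picture."}]
print(maybe_prepend_image_token(msgs, _DummyProcessor(), image_feature_size=2)[0]["content"])
# -> "<image><image>Describe the picture."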