mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 11:50:35 +08:00
fix inputs
This commit is contained in:
@@ -164,7 +164,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
logits_processor=get_logits_processor(),
|
||||
)
|
||||
|
||||
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
|
||||
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
|
||||
for key, value in mm_inputs.items():
|
||||
if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs
|
||||
value = torch.stack(value) # assume they have same sizes
|
||||
|
||||
Reference in New Issue
Block a user