mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-03 04:02:49 +08:00
remove useless codes
Former-commit-id: 9b2642a2b53d3392e95061ed0f2c8dc10580c9e8
This commit is contained in:
parent
9c4941a1ea
commit
a24f94a36c
@ -68,42 +68,6 @@ def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
return image_processor(images=IMAGES, return_tensors="pt")
|
||||
|
||||
def _is_nested_tensor_list(element):
|
||||
if not isinstance(element, list):
|
||||
return False
|
||||
|
||||
for item in element:
|
||||
if isinstance(item, list):
|
||||
if not _is_nested_tensor_list(item):
|
||||
return False
|
||||
|
||||
elif not isinstance(item, torch.Tensor):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _equal_nested_tensor_list(a: List[List[torch.Tensor]], b: List[List[torch.Tensor]]) -> bool:
|
||||
if type(a) != type(b):
|
||||
return False
|
||||
|
||||
if isinstance(a, list) and isinstance(b, list):
|
||||
if len(a) != len(b):
|
||||
return False
|
||||
|
||||
for sub_a, sub_b in zip(a, b):
|
||||
if isinstance(sub_a, list) and isinstance(sub_b, list):
|
||||
if not _equal_nested_tensor_list(sub_a, sub_b):
|
||||
return False
|
||||
elif isinstance(sub_a, torch.Tensor) and isinstance(sub_b, torch.Tensor):
|
||||
if not torch.equal(sub_a, sub_b):
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
|
||||
assert batch_a.keys() == batch_b.keys()
|
||||
|
Loading…
x
Reference in New Issue
Block a user