From 78bb6d17faf33d947999ef8e3642843a5cd25bb6 Mon Sep 17 00:00:00 2001
From: Krzysztof Chalupka
Date: Fri, 22 Jul 2022 09:43:05 -0700
Subject: [PATCH] Add EGLContext and DeviceContextManager

Summary: EGLContext is a utility to render with OpenGL without an attached
display (that is, without a monitor). DeviceContextManager allows us to avoid
unnecessary context creations and releases. See docstrings for more info.

Reviewed By: jcjohnson

Differential Revision: D36562551

fbshipit-source-id: eb0d2a2f85555ee110e203d435a44ad243281d2c
---
 pytorch3d/renderer/__init__.py            |   6 +
 pytorch3d/renderer/opengl/__init__.py     |  36 ++
 pytorch3d/renderer/opengl/opengl_utils.py | 422 ++++++++++++++++++++++
 tests/implicitron/test_build.py           |   2 +-
 tests/test_opengl_utils.py                | 387 ++++++++++++++++++++
 5 files changed, 852 insertions(+), 1 deletion(-)
 create mode 100644 pytorch3d/renderer/opengl/__init__.py
 create mode 100755 pytorch3d/renderer/opengl/opengl_utils.py
 create mode 100644 tests/test_opengl_utils.py

diff --git a/pytorch3d/renderer/__init__.py b/pytorch3d/renderer/__init__.py
index 437cbad4..6f31d2dd 100644
--- a/pytorch3d/renderer/__init__.py
+++ b/pytorch3d/renderer/__init__.py
@@ -64,6 +64,12 @@ from .mesh import (
     TexturesUV,
     TexturesVertex,
 )
+
+try:
+    from .opengl import EGLContext, global_device_context_store
+except (ImportError, ModuleNotFoundError):
+    pass  # opengl or pycuda.gl not available, or pytorch3d_opengl not in TARGETS.
+
 from .points import (
     AlphaCompositor,
     NormWeightedCompositor,
diff --git a/pytorch3d/renderer/opengl/__init__.py b/pytorch3d/renderer/opengl/__init__.py
new file mode 100644
index 00000000..e4c01645
--- /dev/null
+++ b/pytorch3d/renderer/opengl/__init__.py
@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# If we can access EGL, import MeshRasterizerOpenGL.
+def _can_import_egl_and_pycuda():
+    import os
+    import warnings
+
+    try:
+        os.environ["PYOPENGL_PLATFORM"] = "egl"
+        import OpenGL.EGL
+    except (AttributeError, ImportError, ModuleNotFoundError):
+        warnings.warn(
+            "Can't import EGL, not importing MeshRasterizerOpenGL. This might happen if"
+            " your Python application imported OpenGL with a non-EGL backend before"
+            " importing PyTorch3D, or if you don't have pyopengl installed as part"
+            " of your Python distribution."
+        )
+        return False
+
+    try:
+        import pycuda.gl
+    except (ImportError, ModuleNotFoundError):
+        warnings.warn("Can't import pycuda.gl, not importing MeshRasterizerOpenGL.")
+        return False
+
+    return True
+
+
+if _can_import_egl_and_pycuda():
+    from .opengl_utils import EGLContext, global_device_context_store
+
+__all__ = [k for k in globals().keys() if not k.startswith("_")]
diff --git a/pytorch3d/renderer/opengl/opengl_utils.py b/pytorch3d/renderer/opengl/opengl_utils.py
new file mode 100755
index 00000000..7bdd4e66
--- /dev/null
+++ b/pytorch3d/renderer/opengl/opengl_utils.py
@@ -0,0 +1,422 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Utilities useful for OpenGL rendering.
+#
+# NOTE: This module MUST be imported before any other OpenGL modules in this Python
+# session, unless you set PYOPENGL_PLATFORM to egl *before* importing other modules.
+# Otherwise, the imports below will throw an error.
+#
+# This module (as well as rasterizer_opengl) will not be imported into pytorch3d if
+# you do not have pycuda.gl and pyopengl installed.
+
+import contextlib
+import ctypes
+import os
+import threading
+from typing import Any, Dict
+
+
+os.environ["PYOPENGL_PLATFORM"] = "egl"
+import OpenGL.EGL as egl  # noqa
+
+import pycuda.driver as cuda  # noqa
+from OpenGL._opaque import opaque_pointer_cls  # noqa
+from OpenGL.raw.EGL._errors import EGLError  # noqa
+
+# A few constants necessary to use EGL extensions, see links for details.
+
+# https://www.khronos.org/registry/EGL/extensions/EXT/EGL_EXT_platform_device.txt
+EGL_PLATFORM_DEVICE_EXT = 0x313F
+# https://www.khronos.org/registry/EGL/extensions/NV/EGL_NV_device_cuda.txt
+EGL_CUDA_DEVICE_NV = 0x323A
+
+
+# To use EGL extensions, we need to tell OpenGL about them. For details, see
+# https://developer.nvidia.com/blog/egl-eye-opengl-visualization-without-x-server/.
+# To avoid garbage collection of the protos, we'll store them in a module-global list.
+def _define_egl_extension(name: str, type):
+    if hasattr(egl, name):
+        return
+    addr = egl.eglGetProcAddress(name)
+    if addr is None:
+        raise RuntimeError(f"Cannot find EGL extension {name}.")
+    else:
+        proto = ctypes.CFUNCTYPE(type)
+        func = proto(addr)
+        setattr(egl, name, func)
+    return proto
+
+
+_protos = []
+_protos.append(_define_egl_extension("eglGetPlatformDisplayEXT", egl.EGLDisplay))
+_protos.append(_define_egl_extension("eglQueryDevicesEXT", egl.EGLBoolean))
+_protos.append(_define_egl_extension("eglQueryDeviceAttribEXT", egl.EGLBoolean))
+_protos.append(_define_egl_extension("eglQueryDisplayAttribEXT", egl.EGLBoolean))
+_protos.append(_define_egl_extension("eglQueryDeviceStringEXT", ctypes.c_char_p))
+
+if not hasattr(egl, "EGLDeviceEXT"):
+    egl.EGLDeviceEXT = opaque_pointer_cls("EGLDeviceEXT")
+
+
+def _egl_convert_to_int_array(egl_attributes):
+    """
+    Convert a Python dict of EGL attributes into an array of ints (some of which are
+    special EGL ints).
+
+    Args:
+        egl_attributes: A dict where keys are EGL attributes, and values are their vals.
+
+    Returns:
+        A c-list of length 2 * len(egl_attributes) + 1, of the form [key1, val1, ...,
+        keyN, valN, EGL_NONE]
+    """
+    attributes_list = sum(([k, v] for k, v in egl_attributes.items()), []) + [
+        egl.EGL_NONE
+    ]
+    return (egl.EGLint * len(attributes_list))(*attributes_list)
+
+
+def _get_cuda_device(requested_device_id: int):
+    """
+    Find an EGL device with a given CUDA device ID.
+
+    Args:
+        requested_device_id: The desired CUDA device ID, e.g. 1 for "cuda:1".
+
+    Returns:
+        EGL device with the desired CUDA ID.
+    """
+    num_devices = egl.EGLint()
+    if (
+        # pyre-ignore Undefined attribute [16]
+        not egl.eglQueryDevicesEXT(0, None, ctypes.pointer(num_devices))
+        or num_devices.value < 1
+    ):
+        raise RuntimeError("EGL requires a system that supports at least one device.")
+    devices = (egl.EGLDeviceEXT * num_devices.value)()  # array of size num_devices
+    if (
+        # pyre-ignore Undefined attribute [16]
+        not egl.eglQueryDevicesEXT(
+            num_devices.value, devices, ctypes.pointer(num_devices)
+        )
+        or num_devices.value < 1
+    ):
+        raise RuntimeError("EGL sees no available devices.")
+    if len(devices) < requested_device_id + 1:
+        raise ValueError(
+            f"Device {requested_device_id} not available. Found only {len(devices)} devices."
+        )
+
+    # Iterate over all the EGL devices, and check if their CUDA ID matches the request.
+    for device in devices:
+        available_device_id = egl.EGLAttrib(ctypes.c_int(-1))
+        # pyre-ignore Undefined attribute [16]
+        egl.eglQueryDeviceAttribEXT(device, EGL_CUDA_DEVICE_NV, available_device_id)
+        if available_device_id.contents.value == requested_device_id:
+            return device
+    raise ValueError(
+        f"Found {len(devices)} CUDA devices, but none with CUDA id {requested_device_id}."
+    )
+
+
+def _get_egl_config(egl_dpy, surface_type):
+    """
+    Get an EGL config with reasonable settings (for use with MeshRasterizerOpenGL).
+
+    Args:
+        egl_dpy: An EGL display constant (int).
+        surface_type: An EGL surface_type int.
+
+    Returns:
+        An EGL config object.
+
+    Throws:
+        ValueError if the desired config is not available or invalid.
+    """
+    egl_config_dict = {
+        egl.EGL_RED_SIZE: 8,
+        egl.EGL_GREEN_SIZE: 8,
+        egl.EGL_BLUE_SIZE: 8,
+        egl.EGL_ALPHA_SIZE: 8,
+        egl.EGL_DEPTH_SIZE: 24,
+        egl.EGL_STENCIL_SIZE: egl.EGL_DONT_CARE,
+        egl.EGL_RENDERABLE_TYPE: egl.EGL_OPENGL_BIT,
+        egl.EGL_SURFACE_TYPE: surface_type,
+    }
+    egl_config_array = _egl_convert_to_int_array(egl_config_dict)
+    egl_config = egl.EGLConfig()
+    num_configs = egl.EGLint()
+    if (
+        not egl.eglChooseConfig(
+            egl_dpy,
+            egl_config_array,
+            ctypes.pointer(egl_config),
+            1,
+            ctypes.pointer(num_configs),
+        )
+        or num_configs.value == 0
+    ):
+        raise ValueError("Invalid EGL config.")
+    return egl_config
+
+
+class EGLContext:
+    """
+    A class representing an EGL context. In short, EGL allows us to render OpenGL
+    content in a headless mode, that is, without an actual display to render to.
+    This capability enables MeshRasterizerOpenGL to render on the GPU and then
+    transfer the results to PyTorch3D.
+    """
+
+    def __init__(self, width: int, height: int, cuda_device_id: int = 0) -> None:
+        """
+        Args:
+            width: Width of the "display" to render to.
+            height: Height of the "display" to render to.
+            cuda_device_id: Device ID to render to, in the CUDA convention (note that
+                this might be different than EGL's device numbering).
+        """
+        # Lock used to prevent multiple threads from rendering on the same device
+        # at the same time, creating/destroying contexts at the same time, etc.
+        self.lock = threading.Lock()
+        self.cuda_device_id = cuda_device_id
+        self.device = _get_cuda_device(self.cuda_device_id)
+        self.width = width
+        self.height = height
+        self.dpy = egl.eglGetPlatformDisplayEXT(
+            EGL_PLATFORM_DEVICE_EXT, self.device, None
+        )
+        major, minor = egl.EGLint(), egl.EGLint()
+
+        # Initialize EGL components: the display, surface, and context
+        egl.eglInitialize(self.dpy, ctypes.pointer(major), ctypes.pointer(minor))
+
+        config = _get_egl_config(self.dpy, egl.EGL_PBUFFER_BIT)
+        pb_surf_attribs = _egl_convert_to_int_array(
+            {
+                egl.EGL_WIDTH: width,
+                egl.EGL_HEIGHT: height,
+            }
+        )
+        self.surface = egl.eglCreatePbufferSurface(self.dpy, config, pb_surf_attribs)
+        if self.surface == egl.EGL_NO_SURFACE:
+            raise RuntimeError("Failed to create an EGL surface.")
+
+        if not egl.eglBindAPI(egl.EGL_OPENGL_API):
+            raise RuntimeError("Failed to bind EGL to the OpenGL API.")
+        self.context = egl.eglCreateContext(self.dpy, config, egl.EGL_NO_CONTEXT, None)
+        if self.context == egl.EGL_NO_CONTEXT:
+            raise RuntimeError("Failed to create an EGL context.")
+
+    @contextlib.contextmanager
+    def active_and_locked(self):
+        """
+        A context manager used to make sure a given EGL context is only current in
+        a single thread at a single time. It is recommended to ALWAYS use EGL within
+        a `with context.active_and_locked():` context.
+
+        Throws:
+            EGLError when the context cannot be made current or made non-current.
+        """
+        self.lock.acquire()
+        egl.eglMakeCurrent(self.dpy, self.surface, self.surface, self.context)
+        yield
+        egl.eglMakeCurrent(
+            self.dpy, egl.EGL_NO_SURFACE, egl.EGL_NO_SURFACE, egl.EGL_NO_CONTEXT
+        )
+        self.lock.release()
+
+    def get_context_info(self) -> Dict[str, Any]:
+        """
+        Return context info. Useful for debugging.
+
+        Returns:
+            A dict of keys and ints, representing the context's display, surface,
+            the context itself, and the current thread.
+        """
+        return {
+            "dpy": self.dpy,
+            "surface": self.surface,
+            "context": self.context,
+            "thread": threading.get_ident(),
+        }
+
+    def release(self):
+        """
+        Release the context's resources.
+        """
+        self.lock.acquire()
+        try:
+            if self.surface:
+                egl.eglDestroySurface(self.dpy, self.surface)
+            if self.context and self.dpy:
+                egl.eglDestroyContext(self.dpy, self.context)
+            egl.eglMakeCurrent(
+                self.dpy, egl.EGL_NO_SURFACE, egl.EGL_NO_SURFACE, egl.EGL_NO_CONTEXT
+            )
+            if self.dpy:
+                egl.eglTerminate(self.dpy)
+        except EGLError as err:
+            print(
+                f"EGL could not release context on device cuda:{self.cuda_device_id}."
+                " This can happen if you created two contexts on the same device."
+                " Instead, you can use DeviceContextStore to use a single context"
+                " per device, and EGLContext.make_(in)active_in_current_thread to"
+                " (in)activate the context as needed."
+            )
+            raise err
+
+        egl.eglReleaseThread()
+        self.lock.release()
+
+
+class _DeviceContextStore:
+    """
+    DeviceContextStore provides thread-safe storage for EGL and pycuda contexts. It
+    should not be used directly. opengl_utils instantiates a module-global variable
+    called opengl_utils.global_device_context_store. MeshRasterizerOpenGL uses this
+    store to avoid unnecessary context creation and destruction.
+
+    The EGL/CUDA contexts are not meant to be created and destroyed all the time,
+    and having multiple on a single device can be troublesome. Intended use is entirely
+    transparent to the user:
+
+    ```
+    rasterizer1 = MeshRasterizerOpenGL(...some args...)
+    mesh1 = load_mesh_on_cuda_0()
+
+    # Now rasterizer1 will request EGL/CUDA contexts from global_device_context_store
+    # on cuda:0, and since there aren't any, the store will create new ones.
+    rasterizer1.rasterize(mesh1)
+
+    # rasterizer2 also needs EGL & CUDA contexts. But global_device_context_store
+    # already has them for cuda:0. Instead of creating new contexts, the store will
+    # tell rasterizer2 to use them.
+    rasterizer2 = MeshRasterizerOpenGL(...some args...)
+    rasterizer2.rasterize(mesh1)
+
+    # When rasterizer1 needs to render on cuda:1, the store will create new contexts.
+    mesh2 = load_mesh_on_cuda_1()
+    rasterizer1.rasterize(mesh2)
+    ```
+    """
+
+    def __init__(self):
+        cuda.init()
+        # pycuda contexts, at most one per device.
+        self._cuda_contexts = {}
+        # EGL contexts, at most one per device.
+        self._egl_contexts = {}
+        # Any extra per-device data (e.g. precompiled GL objects).
+        self._context_data = {}
+        # Lock for DeviceContextStore used in multithreaded multidevice scenarios.
+        self._lock = threading.Lock()
+        # All EGL contexts created by this store will have this resolution.
+        self.max_egl_width = 2048
+        self.max_egl_height = 2048
+
+    def get_cuda_context(self, device):
+        """
+        Return a pycuda CUDA context on a given CUDA device. If we have not created
+        such a context yet, create a new one and store it in a dict. The context is
+        popped (you need to call context.push() to start using it). This function
+        is thread-safe.
+
+        Args:
+            device: A torch.device.
+
+        Returns: A pycuda context corresponding to the given device.
+        """
+        cuda_device_id = device.index
+        with self._lock:
+            if cuda_device_id not in self._cuda_contexts:
+                self._cuda_contexts[cuda_device_id] = _init_cuda_context(cuda_device_id)
+                self._cuda_contexts[cuda_device_id].pop()
+            return self._cuda_contexts[cuda_device_id]
+
+    def get_egl_context(self, device):
+        """
+        Return an EGL context on a given CUDA device. If we have not created such a
+        context yet, create a new one and store it in a dict. The context is not current
+        (you should use the `with egl_context.active_and_locked():` context manager when
+        you need it to be current). This function is thread-safe.
+
+        Args:
+            device: A torch.device.
+
+        Returns: An EGLContext on the requested device. The context will have size
+            self.max_egl_width and self.max_egl_height.
+        """
+        cuda_device_id = device.index
+        with self._lock:
+            egl_context = self._egl_contexts.get(cuda_device_id, None)
+            if egl_context is None:
+                self._egl_contexts[cuda_device_id] = EGLContext(
+                    self.max_egl_width, self.max_egl_height, cuda_device_id
+                )
+            return self._egl_contexts[cuda_device_id]
+
+    def set_context_data(self, device, value):
+        """
+        Set arbitrary data in a per-device dict.
+
+        This function is intended for storing precompiled OpenGL objects separately for
+        EGL contexts on different devices. Each such context needs a separately compiled
+        OpenGL program, but (in the case of e.g. MeshRasterizerOpenGL) there's no need to
+        re-compile it each time we move the rasterizer to the same device repeatedly,
+        as happens when using DataParallel.
+
+        Args:
+            device: A torch.device
+            value: An arbitrary Python object.
+        """
+
+        cuda_device_id = device.index
+        self._context_data[cuda_device_id] = value
+
+    def get_context_data(self, device):
+        """
+        Get arbitrary data from a per-device dict. See set_context_data for more detail.
+
+        Args:
+            device: A torch.device
+
+        Returns:
+            The most recent object stored using set_context_data.
+        """
+        cuda_device_id = device.index
+        return self._context_data.get(cuda_device_id, None)
+
+    def release(self):
+        """
+        Release all CUDA and EGL contexts.
+        """
+        for context in self._cuda_contexts.values():
+            context.detach()
+
+        for context in self._egl_contexts.values():
+            context.release()
+
+
+def _init_cuda_context(device_id: int = 0):
+    """
+    Initialize a pycuda context on a chosen device.
+
+    Args:
+        device_id: int, specifies which GPU to use.
+
+    Returns:
+        A pycuda Context.
+    """
+    # pyre-ignore Undefined attribute [16]
+    device = cuda.Device(device_id)
+    cuda_context = device.make_context()
+    return cuda_context
+
+
+# Initialize a global _DeviceContextStore. Almost always we will only need a single one.
+global_device_context_store = _DeviceContextStore()
diff --git a/tests/implicitron/test_build.py b/tests/implicitron/test_build.py
index f554c0f4..3a8579ac 100644
--- a/tests/implicitron/test_build.py
+++ b/tests/implicitron/test_build.py
@@ -27,7 +27,7 @@ class TestBuild(unittest.TestCase):
         root_dir = get_pytorch3d_dir() / "pytorch3d"
 
         for module_file in root_dir.glob("**/*.py"):
-            if module_file.stem in ("__init__", "plotly_vis"):
+            if module_file.stem in ("__init__", "plotly_vis", "opengl_utils"):
                 continue
             relative_module = str(module_file.relative_to(root_dir))[:-3]
             module = "pytorch3d." + relative_module.replace("/", ".")
diff --git a/tests/test_opengl_utils.py b/tests/test_opengl_utils.py
new file mode 100644
index 00000000..31d9c2bb
--- /dev/null
+++ b/tests/test_opengl_utils.py
@@ -0,0 +1,387 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import ctypes
+import os
+import sys
+import threading
+import unittest
+
+import torch
+
+os.environ["PYOPENGL_PLATFORM"] = "egl"
+import pycuda._driver  # noqa
+from OpenGL import GL as gl  # noqa
+from OpenGL.raw.EGL._errors import EGLError  # noqa
+from pytorch3d.renderer.opengl import _can_import_egl_and_pycuda  # noqa
+from pytorch3d.renderer.opengl.opengl_utils import (  # noqa
+    _define_egl_extension,
+    _egl_convert_to_int_array,
+    _get_cuda_device,
+    egl,
+    EGLContext,
+    global_device_context_store,
+)
+
+from .common_testing import TestCaseMixin  # noqa
+
+MAX_EGL_HEIGHT = global_device_context_store.max_egl_height
+MAX_EGL_WIDTH = global_device_context_store.max_egl_width
+
+
+def _draw_square(r=1.0, g=0.0, b=1.0, **kwargs) -> torch.Tensor:
+    gl.glClear(gl.GL_COLOR_BUFFER_BIT)
+    gl.glColor3f(r, g, b)
+    x1, x2 = -0.5, 0.5
+    y1, y2 = -0.5, 0.5
+    gl.glRectf(x1, y1, x2, y2)
+    out_buffer = gl.glReadPixels(
+        0, 0, MAX_EGL_WIDTH, MAX_EGL_HEIGHT, gl.GL_RGB, gl.GL_UNSIGNED_BYTE
+    )
+    image = torch.frombuffer(out_buffer, dtype=torch.uint8).reshape(
+        MAX_EGL_HEIGHT, MAX_EGL_WIDTH, 3
+    )
+    return image
+
+
+def _draw_squares_with_context(
+    cuda_device_id=0, result=None, thread_id=None, **kwargs
+) -> None:
+    context = EGLContext(MAX_EGL_WIDTH, MAX_EGL_HEIGHT, cuda_device_id)
+    with context.active_and_locked():
+        images = []
+        for _ in range(3):
+            images.append(_draw_square(**kwargs).float())
+        if result is not None and thread_id is not None:
+            egl_info = context.get_context_info()
+            data = {"egl": egl_info, "images": images}
+            result[thread_id] = data
+
+
+def _draw_squares_with_context_store(
+    cuda_device_id=0,
+    result=None,
+    thread_id=None,
+    verbose=False,
+    **kwargs,
+) -> None:
+    device = torch.device(f"cuda:{cuda_device_id}")
+    context = global_device_context_store.get_egl_context(device)
+    if verbose:
+        print(f"In thread {thread_id}, device {cuda_device_id}.")
+    with context.active_and_locked():
+        images = []
+        for _ in range(3):
+            images.append(_draw_square(**kwargs).float())
+        if result is not None and thread_id is not None:
+            egl_info = context.get_context_info()
+            data = {"egl": egl_info, "images": images}
+            result[thread_id] = data
+
+
+class TestDeviceContextStore(TestCaseMixin, unittest.TestCase):
+    def test_cuda_context(self):
+        cuda_context_1 = global_device_context_store.get_cuda_context(
+            device=torch.device("cuda:0")
+        )
+        cuda_context_2 = global_device_context_store.get_cuda_context(
+            device=torch.device("cuda:0")
+        )
+        cuda_context_3 = global_device_context_store.get_cuda_context(
+            device=torch.device("cuda:1")
+        )
+        cuda_context_4 = global_device_context_store.get_cuda_context(
+            device=torch.device("cuda:1")
+        )
+        self.assertIs(cuda_context_1, cuda_context_2)
+        self.assertIs(cuda_context_3, cuda_context_4)
+        self.assertIsNot(cuda_context_1, cuda_context_3)
+
+    def test_egl_context(self):
+        egl_context_1 = global_device_context_store.get_egl_context(
+            torch.device("cuda:0")
+        )
+        egl_context_2 = global_device_context_store.get_egl_context(
+            torch.device("cuda:0")
+        )
+        egl_context_3 = global_device_context_store.get_egl_context(
+            torch.device("cuda:1")
+        )
+        egl_context_4 = global_device_context_store.get_egl_context(
+            torch.device("cuda:1")
+        )
+        self.assertIs(egl_context_1, egl_context_2)
+        self.assertIs(egl_context_3, egl_context_4)
+        self.assertIsNot(egl_context_1, egl_context_3)
+
+
+class TestUtils(TestCaseMixin, unittest.TestCase):
+    def test_load_extensions(self):
+        # This should work
+        _define_egl_extension("eglGetPlatformDisplayEXT", egl.EGLDisplay)
+
+        # And this shouldn't (wrong extension)
+        with self.assertRaisesRegex(RuntimeError, "Cannot find EGL extension"):
+            _define_egl_extension("eglFakeExtensionEXT", egl.EGLBoolean)
+
+    def test_get_cuda_device(self):
+        # This should work
+        device = _get_cuda_device(0)
+        self.assertIsNotNone(device)
+
+        with self.assertRaisesRegex(ValueError, "Device 10000 not available"):
+            _get_cuda_device(10000)
+
+    def test_egl_convert_to_int_array(self):
+        egl_attributes = {egl.EGL_RED_SIZE: 8}
+        attribute_array = _egl_convert_to_int_array(egl_attributes)
+        self.assertEqual(attribute_array._type_, ctypes.c_int)
+        self.assertEqual(attribute_array._length_, 3)
+        self.assertEqual(attribute_array[0], egl.EGL_RED_SIZE)
+        self.assertEqual(attribute_array[1], 8)
+        self.assertEqual(attribute_array[2], egl.EGL_NONE)
+
+
+class TestOpenGLSingleThreaded(TestCaseMixin, unittest.TestCase):
+    def test_draw_square(self):
+        context = EGLContext(width=MAX_EGL_WIDTH, height=MAX_EGL_HEIGHT)
+        with context.active_and_locked():
+            rendering_result = _draw_square().float()
+        expected_result = torch.zeros(
+            (MAX_EGL_WIDTH, MAX_EGL_HEIGHT, 3), dtype=torch.float
+        )
+        start_px = int(MAX_EGL_WIDTH / 4)
+        end_px = int(MAX_EGL_WIDTH * 3 / 4)
+        expected_result[start_px:end_px, start_px:end_px, 0] = 255.0
+        expected_result[start_px:end_px, start_px:end_px, 2] = 255.0
+
+        self.assertTrue(torch.all(expected_result == rendering_result))
+
+    def test_render_two_squares(self):
+        # Check that drawing twice doesn't overwrite the initial buffer.
+        context = EGLContext(width=MAX_EGL_WIDTH, height=MAX_EGL_HEIGHT)
+        with context.active_and_locked():
+            red_square = _draw_square(r=1.0, g=0.0, b=0.0)
+            blue_square = _draw_square(r=0.0, g=0.0, b=1.0)
+
+        start_px = int(MAX_EGL_WIDTH / 4)
+        end_px = int(MAX_EGL_WIDTH * 3 / 4)
+
+        self.assertTrue(
+            torch.all(
+                red_square[start_px:end_px, start_px:end_px]
+                == torch.tensor([255, 0, 0])
+            )
+        )
+        self.assertTrue(
+            torch.all(
+                blue_square[start_px:end_px, start_px:end_px]
+                == torch.tensor([0, 0, 255])
+            )
+        )
+
+
+class TestOpenGLMultiThreaded(TestCaseMixin, unittest.TestCase):
+    def test_multiple_renders_single_gpu_single_context(self):
+        _draw_squares_with_context()
+
+    def test_multiple_renders_single_gpu_context_store(self):
+        _draw_squares_with_context_store()
+
+    def test_render_two_threads_single_gpu(self):
+        self._render_two_threads_single_gpu(_draw_squares_with_context)
+
+    def test_render_two_threads_single_gpu_context_store(self):
+        self._render_two_threads_single_gpu(_draw_squares_with_context_store)
+
+    def test_render_two_threads_two_gpus(self):
+        self._render_two_threads_two_gpus(_draw_squares_with_context)
+
+    def test_render_two_threads_two_gpus_context_store(self):
+        self._render_two_threads_two_gpus(_draw_squares_with_context_store)
+
+    def _render_two_threads_single_gpu(self, draw_fn):
+        result = [None] * 2
+        thread1 = threading.Thread(
+            target=draw_fn,
+            kwargs={
+                "cuda_device_id": 0,
+                "result": result,
+                "thread_id": 0,
+                "r": 1.0,
+                "g": 0.0,
+                "b": 0.0,
+            },
+        )
+        thread2 = threading.Thread(
+            target=draw_fn,
+            kwargs={
+                "cuda_device_id": 0,
+                "result": result,
+                "thread_id": 1,
+                "r": 0.0,
+                "g": 1.0,
+                "b": 0.0,
+            },
+        )
+
+        thread1.start()
+        thread2.start()
+        thread1.join()
+        thread2.join()
+
+        start_px = int(MAX_EGL_WIDTH / 4)
+        end_px = int(MAX_EGL_WIDTH * 3 / 4)
+        red_squares = torch.stack(result[0]["images"], dim=0)[
+            :, start_px:end_px, start_px:end_px
+        ]
+        green_squares = torch.stack(result[1]["images"], dim=0)[
+            :, start_px:end_px, start_px:end_px
+        ]
+        self.assertTrue(torch.all(red_squares == torch.tensor([255.0, 0.0, 0.0])))
+        self.assertTrue(torch.all(green_squares == torch.tensor([0.0, 255.0, 0.0])))
+
+    def _render_two_threads_two_gpus(self, draw_fn):
+        # Contrary to _render_two_threads_single_gpu, this renders in two separate
+        # threads but on a different GPU each. This means using different EGL contexts
+        # and is a much less risky endeavour.
+        result = [None] * 2
+        thread1 = threading.Thread(
+            target=draw_fn,
+            kwargs={
+                "cuda_device_id": 0,
+                "result": result,
+                "thread_id": 0,
+                "r": 1.0,
+                "g": 0.0,
+                "b": 0.0,
+            },
+        )
+        thread2 = threading.Thread(
+            target=draw_fn,
+            kwargs={
+                "cuda_device_id": 1,
+                "result": result,
+                "thread_id": 1,
+                "r": 0.0,
+                "g": 1.0,
+                "b": 0.0,
+            },
+        )
+        thread1.start()
+        thread2.start()
+        thread1.join()
+        thread2.join()
+        self.assertNotEqual(
+            result[0]["egl"]["context"].address, result[1]["egl"]["context"].address
+        )
+
+        start_px = int(MAX_EGL_WIDTH / 4)
+        end_px = int(MAX_EGL_WIDTH * 3 / 4)
+        red_squares = torch.stack(result[0]["images"], dim=0)[
+            :, start_px:end_px, start_px:end_px
+        ]
+        green_squares = torch.stack(result[1]["images"], dim=0)[
+            :, start_px:end_px, start_px:end_px
+        ]
+        self.assertTrue(torch.all(red_squares == torch.tensor([255.0, 0.0, 0.0])))
+        self.assertTrue(torch.all(green_squares == torch.tensor([0.0, 255.0, 0.0])))
+
+    def test_render_multi_thread_multi_gpu(self):
+        # Multiple threads using up multiple GPUs; more threads than GPUs.
+        # This is certainly not encouraged in practice, but shouldn't fail. Note that
+        # the context store will only allow one rendering at a time to occur on a
+        # single GPU, even across threads.
+        n_gpus = torch.cuda.device_count()
+        n_threads = 10
+        kwargs = {
+            "r": 1.0,
+            "g": 0.0,
+            "b": 0.0,
+            "verbose": True,
+        }
+
+        threads = []
+        for thread_id in range(n_threads):
+            kwargs.update(
+                {"cuda_device_id": thread_id % n_gpus, "thread_id": thread_id}
+            )
+            threads.append(
+                threading.Thread(
+                    target=_draw_squares_with_context_store, kwargs=dict(kwargs)
+                )
+            )
+
+        for thread in threads:
+            thread.start()
+        for thread in threads:
+            thread.join()
+
+
+class TestOpenGLUtils(TestCaseMixin, unittest.TestCase):
+    def test_device_context_store(self):
+        # Most of DCS's functionality is tested in the tests above; test the remainder.
+        device = torch.device("cuda:0")
+        global_device_context_store.set_context_data(device, 123)
+
+        self.assertEqual(global_device_context_store.get_context_data(device), 123)
+
+        self.assertEqual(
+            global_device_context_store.get_context_data(torch.device("cuda:1")), None
+        )
+
+        # Check that contexts in the store can be manually released (although that's a
+        # very bad idea! Don't do it manually!)
+        egl_ctx = global_device_context_store.get_egl_context(device)
+        cuda_ctx = global_device_context_store.get_cuda_context(device)
+        egl_ctx.release()
+        cuda_ctx.detach()
+
+        # Reset the contexts (just for testing! never do this manually!). Then, check
+        # that first running DeviceContextStore.release() will cause subsequent releases
+        # to fail (because we already released all the contexts).
+        global_device_context_store._cuda_contexts = {}
+        global_device_context_store._egl_contexts = {}
+
+        egl_ctx = global_device_context_store.get_egl_context(device)
+        cuda_ctx = global_device_context_store.get_cuda_context(device)
+        global_device_context_store.release()
+        with self.assertRaisesRegex(EGLError, "EGL_NOT_INITIALIZED"):
+            egl_ctx.release()
+        with self.assertRaisesRegex(pycuda._driver.LogicError, "cannot detach"):
+            cuda_ctx.detach()
+
+    def test_no_egl_error(self):
+        # Remove EGL, import OpenGL with the wrong backend. This should make it
+        # impossible to import OpenGL.EGL.
+        del os.environ["PYOPENGL_PLATFORM"]
+        modules = list(sys.modules)
+        for m in modules:
+            if "OpenGL" in m:
+                del sys.modules[m]
+        import OpenGL.GL  # noqa
+
+        self.assertFalse(_can_import_egl_and_pycuda())
+
+        # Import OpenGL back with the right backend. This should get things on track.
+        modules = list(sys.modules)
+        for m in modules:
+            if "OpenGL" in m:
+                del sys.modules[m]
+
+        os.environ["PYOPENGL_PLATFORM"] = "egl"
+        self.assertTrue(_can_import_egl_and_pycuda())
+
+    def test_egl_release_error(self):
+        # Creating two contexts on the same device will lead to trouble (that's one of
+        # the reasons behind DeviceContextStore). You can release one of them,
+        # but you cannot release the same EGL resources twice!
+        ctx1 = EGLContext(width=100, height=100)
+        ctx2 = EGLContext(width=100, height=100)
+
+        ctx1.release()
+        with self.assertRaisesRegex(EGLError, "EGL_NOT_INITIALIZED"):
+            ctx2.release()
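
A minimal usage sketch of the new EGLContext for headless rendering (illustrative only, not part of the diff). It mirrors the helpers in tests/test_opengl_utils.py and assumes a machine with EGL-capable NVIDIA drivers and at least one CUDA device; the 256x256 size is an arbitrary choice.

```
import torch
from pytorch3d.renderer.opengl import EGLContext  # sets PYOPENGL_PLATFORM to "egl"
from OpenGL import GL as gl

# Create a 256x256 headless rendering surface on cuda:0.
context = EGLContext(width=256, height=256, cuda_device_id=0)

# All GL calls must happen while the context is current; the context manager
# also serializes access to the context across threads.
with context.active_and_locked():
    gl.glClear(gl.GL_COLOR_BUFFER_BIT)
    pixels = gl.glReadPixels(0, 0, 256, 256, gl.GL_RGB, gl.GL_UNSIGNED_BYTE)

# Wrap the raw framebuffer bytes in a tensor, as the tests do.
image = torch.frombuffer(pixels, dtype=torch.uint8).reshape(256, 256, 3)

# Free the EGL surface/display/context when done.
context.release()
```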
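And a sketch (under the same caveats) of how downstream code is expected to obtain shared contexts through global_device_context_store rather than constructing them directly: the store caches one EGL and one pycuda context per device and hands back the same objects on repeated calls, even across threads. The GL/pycuda work inside is left as placeholders.

```
import torch
from pytorch3d.renderer.opengl import global_device_context_store

device = torch.device("cuda:0")

# Repeated calls for the same device return the same cached objects.
egl_context = global_device_context_store.get_egl_context(device)
cuda_context = global_device_context_store.get_cuda_context(device)

with egl_context.active_and_locked():
    pass  # issue GL draw/read calls here

# The store hands back the pycuda context popped; push it before CUDA interop
# work and pop it again afterwards.
cuda_context.push()
# ... pycuda.gl buffer registration / copies would go here ...
cuda_context.pop()
```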