From 4c4e319ecfa46b7bc24adc3da957d6479dfde4b9 Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Wed, 2 Nov 2022 15:00:09 +0000 Subject: [PATCH] Fix overloads with `target="generic"` for CUDA As identified in #8271, the CUDA target needs to be set as the target at the bottom of the call stack, otherwise overloads for the generic target cannot be resolved. This is required so that the fix applied in #8562 (using the generic target for `ol_compatible_view` from #8537) actually works. --- numba/cuda/compiler.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/numba/cuda/compiler.py b/numba/cuda/compiler.py index 29dfcc76d40..b72a8c625b7 100644 --- a/numba/cuda/compiler.py +++ b/numba/cuda/compiler.py @@ -207,14 +207,16 @@ def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False, flags.nvvm_options = nvvm_options # Run compilation pipeline - cres = compiler.compile_extra(typingctx=typingctx, - targetctx=targetctx, - func=pyfunc, - args=args, - return_type=return_type, - flags=flags, - locals={}, - pipeline_class=CUDACompiler) + from numba.core.target_extension import target_override + with target_override('cuda'): + cres = compiler.compile_extra(typingctx=typingctx, + targetctx=targetctx, + func=pyfunc, + args=args, + return_type=return_type, + flags=flags, + locals={}, + pipeline_class=CUDACompiler) library = cres.library library.finalize()