Kernel Programming

The tutorial covers the numba-dpex kernel programming API (kapi) and introduces the concepts needed to write data-parallel kernels in numba-dpex.

Core concepts

Writing a range kernel

A range kernel represents the simplest form of parallelism that can be expressed in numba-dpex using kapi. It expresses a data-parallel execution over a set of work-items, with each work-item representing a logical thread of execution. Example: Vector addition using a range kernel shows an example of a range kernel written in numba-dpex.

Example: Vector addition using a range kernel
 1import dpnp
 2import numba_dpex as dpex
 3from numba_dpex import kernel_api as kapi
 4
 5
 6# Data parallel kernel implementing vector sum
 7@dpex.kernel
 8def vecadd(item: kapi.Item, a, b, c):
 9    i = item.get_id(0)
10    c[i] = a[i] + b[i]
11
12
13N = 1024
14a = dpnp.ones(N)
15b = dpnp.ones_like(a)
16c = dpnp.zeros_like(a)
17dpex.call_kernel(vecadd, kapi.Range(N), a, b, c)

The highlighted lines in the example demonstrate the definition of the execution range on line 17 and the extraction of each work-item's id, or index position, via the item.get_id call on line 9. An execution range comprising 1024 work-items is defined when calling the kernel, and each work-item then executes a single addition.

There are a few semantic rules that have to be adhered to when writing a range kernel:

  • Analogous to the API of SYCL, a range kernel can execute only over a 1-, 2-, or 3-dimensional set of work-items.

  • Every range kernel requires its first argument to be an instance of the numba_dpex.kernel_api.Item class. The Item object is an abstraction encapsulating the index position (id) of a single work-item in the global execution range. The id will be a 1-, 2-, or 3-tuple depending on the dimensionality of the execution range.

  • A range kernel cannot return any value.

    Note that the rule is enforced only in the compiled mode and not in the pure Python execution of a kapi kernel.

  • A kernel can accept both array and scalar arguments. Array arguments currently can either be a dpnp.ndarray or a dpctl.tensor.usm_ndarray. Scalar values can be of any Python numeric type. Array arguments are passed by reference, i.e., changes to an array in a kernel are visible outside the kernel. Scalar values are always passed by value.

  • At least one argument of a kernel should be an array. The requirement exists so that the kernel launcher (numba_dpex.core.kernel_launcher.call_kernel()) can determine the execution queue on which to launch the kernel. Refer to the Launching a kernel section for more details.

A range kernel has to be executed via the numba_dpex.core.kernel_launcher.call_kernel() function by passing in an instance of the numba_dpex.kernel_api.Range class. Refer to the Launching a kernel section for more details on how to launch a range kernel.

A range kernel is meant to express a basic parallel-for calculation that is ideally suited for embarrassingly parallel kernels such as element-wise computations over n-dimensional arrays (ndarrays). The API for expressing a range kernel does not allow advanced features such as synchronization of work-items and fine-grained control over memory allocation on a device. For such advanced features, an nd-range kernel should be used.
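
As an illustration of a multi-dimensional execution range, the following sketch launches an element-wise kernel over a 2-D set of work-items. The kernel name and array sizes are illustrative; the kapi usage mirrors the vector addition example above.

    import dpnp
    import numba_dpex as dpex
    from numba_dpex import kernel_api as kapi


    # Illustrative element-wise kernel over a 2-D execution range
    @dpex.kernel
    def scale2d(item: kapi.Item, a, b):
        i = item.get_id(0)  # index along the first dimension
        j = item.get_id(1)  # index along the second dimension
        b[i, j] = 2.0 * a[i, j]


    M, N = 64, 32
    a = dpnp.ones((M, N))
    b = dpnp.zeros_like(a)
    # A 2-D Range launches one work-item per element of the M x N arrays
    dpex.call_kernel(scale2d, kapi.Range(M, N), a, b)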

Writing an nd-range kernel

In a range kernel, the kernel execution is scheduled over a set of work-items without any explicit grouping of the work-items. The basic form of parallelism that can be expressed using a range kernel does not allow expressing any notion of locality within the kernel. To get around that limitation, kapi provides a second form of expressing a parallel kernel that is called an nd-range kernel. An nd-range kernel represents a data-parallel execution of the kernel by a set of explicitly defined groups of work-items. An individual group of work-items is called a work-group. Example: Sliding window matrix multiplication as an nd-range kernel demonstrates an nd-range kernel and some of the advanced features programmers can use in this type of kernel.

Example: Sliding window matrix multiplication as an nd-range kernel
 1from numba_dpex import kernel_api as kapi
 2import numba_dpex as dpex
 3import numpy as np
 4import dpctl.tensor as dpt
 5
 6square_block_side = 2
 7work_group_size = (square_block_side, square_block_side)
 8dtype = np.float32
 9
10
11@dpex.kernel
12def matmul(
13    nditem: kapi.NdItem,
14    X,  # IN READ-ONLY    (X_n_rows, n_cols)
15    y,  # IN READ-ONLY    (n_cols, y_n_rows),
16    X_slm,  # SLM to store a sliding window over X
17    Y_slm,  # SLM to store a sliding window over Y
18    result,  # OUT        (X_n_rows, y_n_rows)
19):
20    X_n_rows = X.shape[0]
21    Y_n_cols = y.shape[1]
22    n_cols = X.shape[1]
23
24    result_row_idx = nditem.get_global_id(0)
25    result_col_idx = nditem.get_global_id(1)
26
27    local_row_idx = nditem.get_local_id(0)
28    local_col_idx = nditem.get_local_id(1)
29
30    n_blocks_for_cols = n_cols // square_block_side
31    if (n_cols % square_block_side) > 0:
32        n_blocks_for_cols += 1
33
34    output = dtype(0)
35
36    gr = nditem.get_group()
37
38    for block_idx in range(n_blocks_for_cols):
39        X_slm[local_row_idx, local_col_idx] = dtype(0)
40        Y_slm[local_row_idx, local_col_idx] = dtype(0)
41        if (result_row_idx < X_n_rows) and (
42            (local_col_idx + (square_block_side * block_idx)) < n_cols
43        ):
44            X_slm[local_row_idx, local_col_idx] = X[
45                result_row_idx, local_col_idx + (square_block_side * block_idx)
46            ]
47
48        if (result_col_idx < Y_n_cols) and (
49            (local_row_idx + (square_block_side * block_idx)) < n_cols
50        ):
51            Y_slm[local_row_idx, local_col_idx] = y[
52                local_row_idx + (square_block_side * block_idx), result_col_idx
53            ]
54
55        kapi.group_barrier(gr)
56
57        for idx in range(square_block_side):
58            output += X_slm[local_row_idx, idx] * Y_slm[idx, local_col_idx]
59
60        kapi.group_barrier(gr)
61
62    if (result_row_idx < X_n_rows) and (result_col_idx < Y_n_cols):
63        result[result_row_idx, result_col_idx] = output
64
65
66def _arange_reshaped(shape, dtype):
67    n_items = shape[0] * shape[1]
68    return np.arange(n_items, dtype=dtype).reshape(shape)
69
70
71X = _arange_reshaped((5, 5), dtype)
72Y = _arange_reshaped((5, 5), dtype)
73X = dpt.asarray(X)
74Y = dpt.asarray(Y)
75device = X.device.sycl_device
76result = dpt.zeros((5, 5), dtype, device=device)
77X_slm = kapi.LocalAccessor(shape=work_group_size, dtype=dtype)
78Y_slm = kapi.LocalAccessor(shape=work_group_size, dtype=dtype)
79
80dpex.call_kernel(matmul, kapi.NdRange((6, 6), (2, 2)), X, Y, X_slm, Y_slm, result)

When writing an nd-range kernel, a programmer defines a set of groups of work-items instead of a flat execution range. There are several semantic rules associated with both a work-group and the work-items in a work-group:

  • Each work-group gets executed in an arbitrary order by the underlying runtime and programmers should not assume any implicit ordering.

  • Work-items in different work-groups cannot communicate with each other except via atomic operations on global memory.

  • Work-items within a work-group share a common memory region called “shared local memory” (SLM). Depending on the device, the SLM may be mapped to a dedicated fast memory.

  • Work-items in a work-group can synchronize using a numba_dpex.kernel_api.group_barrier() operation that can additionally guarantee memory consistency using a work-group memory fence.

Note

The SYCL language provides additional features for work-items in a work-group such as group functions that specify communication routines across work-items and also implement patterns such as reduction and scan. These features are not yet available in numba-dpex.

An nd-range kernel needs to be launched with an instance of the numba_dpex.kernel_api.NdRange class, and the first argument to an nd-range kernel has to be an instance of numba_dpex.kernel_api.NdItem. Apart from the need to provide an NdItem parameter, the rest of the semantic rules that apply to a range kernel also apply to an nd-range kernel.
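
As a simpler illustration than the matrix multiplication example, the following sketch launches a 1-D kernel over an nd-range so that the work-group size is specified explicitly. The kernel name and sizes are illustrative; note that in SYCL the global range should be evenly divisible by the work-group size in every dimension.

    import dpnp
    import numba_dpex as dpex
    from numba_dpex import kernel_api as kapi


    # Illustrative kernel that only needs the global id, but is written as an
    # nd-range kernel so that an explicit work-group size can be chosen.
    @dpex.kernel
    def vecadd_ndrange(nd_item: kapi.NdItem, a, b, c):
        i = nd_item.get_global_id(0)
        c[i] = a[i] + b[i]


    N = 1024
    a = dpnp.ones(N)
    b = dpnp.ones_like(a)
    c = dpnp.zeros_like(a)
    # 1024 work-items arranged into work-groups of 64 work-items each
    dpex.call_kernel(vecadd_ndrange, kapi.NdRange((N,), (64,)), a, b, c)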

Launching a kernel

A kernel-decorated kapi function produces a KernelDispatcher object, which is a type of Numba* Dispatcher object. However, unlike regular Numba* Dispatcher objects, a KernelDispatcher object cannot be directly invoked from either CPython or another compiled Numba* jit function. To invoke a kernel-decorated function, a programmer has to use the numba_dpex.core.kernel_launcher.call_kernel() function.

To invoke a KernelDispatcher, the call_kernel function requires three things: the KernelDispatcher object, the Range or NdRange object over which the kernel is to be executed, and the list of arguments to be passed to the compiled kernel. Once called with the necessary arguments, the call_kernel function performs the following steps:

  • Compiles the KernelDispatcher object, specializing it for the provided argument types.

  • Unboxes the kernel arguments by converting CPython objects into Numba* or numba-dpex objects.

  • Infers the execution queue on which to submit the kernel from the provided kernel arguments, following the compute-follows-data programming model.

  • Submits the kernel to the execution queue.

  • Waits for the execution to complete before returning control back to the caller.

Important

Programmers should note the following two things when defining the global or local range to launch a kernel.

  • Numba-dpex currently limits the maximum allowed global range size to 2^31-1. This limit is due to the capabilities of current OpenCL GPU backends, which generally do not support more than 32-bit global range sizes. A kernel requesting a larger global range will not execute, and a dpctl._sycl_queue.SyclKernelSubmitError will be raised.

    The Intel dpcpp SYCL compiler does handle greater-than-32-bit global ranges for GPU backends by wrapping the kernel in a new kernel that has each work-item perform multiple invocations of the original kernel in a 32-bit global range. Such a feature is not yet available in numba-dpex.

  • When launching an nd-range kernel, if the number of work-items for a particular dimension of a work-group exceeds the maximum device capability, it can result in undefined behavior.

The maximum number of work-items allowed for a device can be queried programmatically, as shown in Example: Query maximum number of work-items for a device.

Example: Query maximum number of work-items for a device
 1import dpctl
 2import math
 3
 4d = dpctl.SyclDevice("gpu")
 5d.print_device_info()
 6
 7max_num_work_items = (
 8    d.max_work_group_size
 9    * d.max_work_item_sizes1d[0]
10    * d.max_work_item_sizes2d[0]
11    * d.max_work_item_sizes3d[0]
12)
13print(max_num_work_items, f"(2^{int(math.log(max_num_work_items, 2))})")
14
15cpud = dpctl.SyclDevice("cpu")
16cpud.print_device_info()
17
18max_num_work_items_cpu = (
19    cpud.max_work_group_size
20    * cpud.max_work_item_sizes1d[0]
21    * cpud.max_work_item_sizes2d[0]
22    * cpud.max_work_item_sizes3d[0]
23)
24print(max_num_work_items_cpu, f"(2^{int(math.log(max_num_work_items_cpu, 2))})")

The output for Example: Query maximum number of work-items for a device on a system with an Intel Gen9 integrated graphics processor and a 9th Generation Coffee Lake CPU is shown in OUTPUT: Query maximum number of work-items for a device.

OUTPUT: Query maximum number of work-items for a device
    Name            Intel(R) UHD Graphics 630 [0x3e98]
    Driver version  1.3.24595
    Vendor          Intel(R) Corporation
    Filter string   level_zero:gpu:0

4294967296 (2^32)
    Name            Intel(R) Core(TM) i7-9700 CPU @ 3.00GHz
    Driver version  2023.16.12.0.12_195853.xmain-hotfix
    Vendor          Intel(R) Corporation
    Filter string   opencl:cpu:0

4503599627370496 (2^52)

The call_kernel function can be invoked both from CPython and from another Numba* compiled function. Note that the call_kernel function supports only synchronous execution of a kernel; the call_kernel_async function should be used for asynchronous kernel execution (refer to Async kernel execution).
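
The following sketch shows one possible way of launching a kernel from another compiled function, assuming the numba_dpex.dpjit decorator is used to compile the caller. The kernel, function names, and sizes are illustrative.

    import dpnp
    import numba_dpex as dpex
    from numba_dpex import kernel_api as kapi


    @dpex.kernel
    def scale(item: kapi.Item, a, b):
        i = item.get_id(0)
        b[i] = 2.0 * a[i]


    # A dpjit-compiled host function that itself calls dpex.call_kernel
    @dpex.dpjit
    def launcher(a, b):
        dpex.call_kernel(scale, kapi.Range(a.size), a, b)


    a = dpnp.ones(1024)
    b = dpnp.zeros_like(a)
    launcher(a, b)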

See also

Refer to the API documentation for numba_dpex.core.kernel_launcher.call_kernel() for more details.

The device_func decorator

Numba-dpex provides a decorator to express auxiliary device-only functions that can be called from a kernel or another device function, but are not callable from the host. The numba_dpex.core.decorators.device_func() decorator has no direct analogue in SYCL and is provided primarily to help programmers make their kapi applications modular. Example: Basic usage of device_func shows a simple usage of the device_func decorator.

Example: Basic usage of device_func
 1import dpnp
 2
 3import numba_dpex as dpex
 4from numba_dpex import kernel_api as kapi
 5
 6# Array size
 7N = 10
 8
 9
10@dpex.device_func
11def a_device_function(a):
12    """A device callable function that can be invoked from a kernel or
13    another device function.
14    """
15    return a + 1
16
17
18@dpex.kernel
19def a_kernel_function(item: kapi.Item, a, b):
20    """Demonstrates calling a device function from a kernel."""
21    i = item.get_id(0)
22    b[i] = a_device_function(a[i])
23
24
25N = 16
26a = dpnp.ones(N, dtype=dpnp.int32)
27b = dpnp.zeros(N, dtype=dpnp.int32)
28
29dpex.call_kernel(a_kernel_function, dpex.Range(N), a, b)

Kapi functionality, such as numba_dpex.kernel_api.group_barrier(), can also be used inside a device function, as demonstrated in Example: Using kapi functionalities in a device_func.

Example: Using kapi functionalities in a device_func
 1import dpnp
 2
 3import numba_dpex as dpex
 4from numba_dpex import kernel_api as kapi
 5
 6
 7@dpex.device_func
 8def increment_value(nd_item: kapi.NdItem, a):
 9    """Demonstrates the usage of group_barrier and NdItem usage in a
10    device_func.
11    """
12    i = nd_item.get_global_id(0)
13
14    a[i] += 1
15    kapi.group_barrier(nd_item.get_group(), kapi.MemoryScope.DEVICE)
16
17    if i == 0:
18        for idx in range(1, a.size):
19            a[0] += a[idx]
20
21
22@dpex.kernel
23def another_kernel(nd_item: kapi.NdItem, a):
24    """The kernel does everything by calling a device_func."""
25    increment_value(nd_item, a)
26
27
28N = 16
29b = dpnp.ones(N, dtype=dpnp.int32)
30
31dpex.call_kernel(another_kernel, dpex.NdRange((N,), (N,)), b)

A device function does not require its first argument to be an index-space id class, and unlike a kernel function, a device function is allowed to return a value. All kapi functionality can be used in a device_func decorated function, and at the compilation stage numba-dpex will attempt to inline a device_func into the kernel where it is used.

Supported types of kernel argument

A kapi kernel function can have both array and scalar arguments. At least one of the arguments to every kernel function has to be an array. The requirement is enforced so that an execution queue can be inferred at the kernel launch stage. Array arguments are passed by reference to the kernel and all scalar arguments are passed by value.

Supported array types

The following array types are currently supported as kernel arguments:

  • dpnp.ndarray

  • dpctl.tensor.usm_ndarray
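
As an illustration, the sketch below passes first a dpnp.ndarray and then a dpctl.tensor.usm_ndarray to the same kernel. The kernel is illustrative; both launches rely on the array argument to infer the execution queue.

    import dpctl.tensor as dpt
    import dpnp
    import numba_dpex as dpex
    from numba_dpex import kernel_api as kapi


    @dpex.kernel
    def fill_twos(item: kapi.Item, a):
        i = item.get_id(0)
        a[i] = 2


    # dpnp.ndarray argument
    a_dpnp = dpnp.zeros(16, dtype=dpnp.int32)
    dpex.call_kernel(fill_twos, kapi.Range(16), a_dpnp)

    # dpctl.tensor.usm_ndarray argument
    a_dpt = dpt.zeros(16, dtype="int32")
    dpex.call_kernel(fill_twos, kapi.Range(16), a_dpt)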

Scalar types

Scalar values can be passed to a kernel function either using the default Python scalar types or as explicit NumPy or dpnp data type objects. Example: Ways of defining a scalar kernel argument shows the two possible ways of defining a scalar value. In both scenarios, numba-dpex depends on the default Numba* type-inference algorithm to determine the LLVM IR type of a Python object that represents a scalar value. At the kernel submission stage, the LLVM IR type is reinterpreted as a C++11 type to interoperate with the underlying SYCL runtime.

Example: Ways of defining a scalar kernel argument
import dpnp

a = 1
b = dpnp.dtype("int32").type(1)

print(type(a))
print(type(b))
Output: Ways of defining a scalar kernel argument
<class 'int'>
<class 'numpy.int32'>

The following scalar types are currently supported as arguments of a numba-dpex kernel function:

  • int

  • float

  • complex

  • numpy.int32

  • numpy.uint32

  • numpy.int64

  • numpy.uint64

  • numpy.float32

  • numpy.float64

Important

The Numba* type-inference algorithm by default infers a native Python scalar type to be a 64-bit value, which is consistent with the default CPython behavior. The default inferred 64-bit type can cause compilation failures on platforms that do not have native 64-bit floating-point support. Another potential fallout of the default 64-bit type inference occurs when a narrower-width type is required by a specific kernel. To avoid these issues, users are advised to always use a dpnp or NumPy type object to explicitly define the type of a scalar value.
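
The sketch below follows that advice by passing a scalar multiplier as an explicit numpy.float32 object instead of a native Python float. The kernel name is illustrative.

    import dpnp
    import numpy as np

    import numba_dpex as dpex
    from numba_dpex import kernel_api as kapi


    @dpex.kernel
    def scale(item: kapi.Item, a, b, factor):
        i = item.get_id(0)
        b[i] = factor * a[i]


    a = dpnp.ones(1024, dtype=dpnp.float32)
    b = dpnp.zeros_like(a)

    # np.float32(2.0) keeps the scalar 32 bits wide; passing the Python float
    # 2.0 would be inferred as a 64-bit value by default.
    dpex.call_kernel(scale, kapi.Range(a.size), a, b, np.float32(2.0))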

DLPack support

At this time, direct support for the DLPack protocol has not been added to numba-dpex. To interoperate numba-dpex with other SYCL USM-based libraries, users should first convert their input tensor or ndarray object into one of the two supported array types, both of which support DLPack.
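
A minimal sketch of the suggested workflow is shown below, where "external" stands in for an object produced by another SYCL USM-based library that exports the DLPack protocol. The conversion uses dpctl.tensor.from_dlpack; the kernel is illustrative.

    import dpctl.tensor as dpt
    import numba_dpex as dpex
    from numba_dpex import kernel_api as kapi


    @dpex.kernel
    def negate(item: kapi.Item, a):
        i = item.get_id(0)
        a[i] = -a[i]


    # Stand-in for a DLPack-exporting object from another library
    external = dpt.arange(8, dtype="int32")

    # Convert the DLPack-exporting object into a dpctl.tensor.usm_ndarray
    # before passing it to the kernel
    a = dpt.from_dlpack(external)
    dpex.call_kernel(negate, kapi.Range(a.size), a)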

Supported Python features

Mathematical operations

Scalar mathematical functions from the Python math module and the dpnp library can be used inside a kernel function. During compilation the mathematical functions get compiled into device-specific intrinsic instructions.
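
For instance, a kernel may call math.sqrt on every element of an array. A minimal sketch, with an illustrative kernel name:

    import math

    import dpnp
    import numba_dpex as dpex
    from numba_dpex import kernel_api as kapi


    @dpex.kernel
    def elementwise_sqrt(item: kapi.Item, a, b):
        i = item.get_id(0)
        # math.sqrt is lowered to the corresponding device intrinsic
        b[i] = math.sqrt(a[i])


    a = dpnp.arange(16, dtype=dpnp.float32)
    b = dpnp.zeros_like(a)
    dpex.call_kernel(elementwise_sqrt, kapi.Range(a.size), a, b)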

Current support matrix of math module functions

  • math.isnan: types.float32(types.float32); types.float64(types.float64)
  • math.isinf: types.float32(types.float32); types.float64(types.float64)
  • math.ceil: types.float32(types.float32); types.float64(types.float64)
  • math.floor: types.float32(types.float32); types.float64(types.float64)
  • math.trunc: types.float32(types.float32); types.float64(types.float64)
  • math.fabs: types.float32(types.float32); types.float64(types.float64)
  • math.sqrt: types.float32(types.float32); types.float64(types.float64)
  • math.exp: types.float32(types.float32); types.float64(types.float64)
  • math.expm1: types.float32(types.float32); types.float64(types.float64)
  • math.log: types.float32(types.float32); types.float64(types.float64)
  • math.log10: types.float32(types.float32); types.float64(types.float64)
  • math.log1p: types.float32(types.float32); types.float64(types.float64)
  • math.sin: types.float32(types.float32); types.float64(types.float64)
  • math.cos: types.float32(types.float32); types.float64(types.float64)
  • math.tan: types.float32(types.float32); types.float64(types.float64)
  • math.asin: types.float32(types.float32); types.float64(types.float64)
  • math.acos: types.float32(types.float32); types.float64(types.float64)
  • math.atan: types.float32(types.float32); types.float64(types.float64)
  • math.sinh: types.float32(types.float32); types.float64(types.float64)
  • math.cosh: types.float32(types.float32); types.float64(types.float64)
  • math.tanh: types.float32(types.float32); types.float64(types.float64)
  • math.asinh: types.float32(types.float32); types.float64(types.float64)
  • math.acosh: types.float32(types.float32); types.float64(types.float64)
  • math.atanh: types.float32(types.float32); types.float64(types.float64)
  • math.exp2: types.float32(types.float32); types.float64(types.float64)
  • math.log2: types.float32(types.float32); types.float64(types.float64)
  • math.erf: types.float32(types.float32); types.float64(types.float64)
  • math.erfc: types.float32(types.float32); types.float64(types.float64)
  • math.gamma: types.float32(types.float32); types.float64(types.float64)
  • math.lgamma: types.float32(types.float32); types.float64(types.float64)
  • math.copysign: types.float32(types.float32, types.float32); types.float64(types.float64, types.float64)
  • math.atan2: types.float32(types.float32, types.float32); types.float64(types.float64, types.float64)
  • math.pow: types.float32(types.float32, types.float32); types.float64(types.float64, types.float64)
  • math.fmod: types.float32(types.float32, types.float32); types.float64(types.float64, types.float64)
  • math.ldexp: types.float32(types.float32, types.int32); types.float32(types.float32, types.int64); types.float64(types.float64, types.int32); types.float64(types.float64, types.int64)
  • math.hypot: types.float32(types.float32, types.int32); types.float32(types.float32, types.int64); types.float64(types.float64, types.int32); types.float64(types.float64, types.int64)
  • math.frexp: Not supported
  • math.ldexp: Not supported
  • math.trunc: Not supported
  • math.modf: Not supported
  • math.factorial: Not supported
  • math.fsum: Not supported

Caution

The supported signatures for some of the math module functions in the compiled mode differ from CPython. The divergence in behavior is a known issue. Please refer to https://github.com/IntelPython/numba-dpex/issues/759 for updates.

Current support matrix of dpnp functions

  • dpnp.add: types.float32, types.float64, types.int32, types.int64
  • dpnp.arctan2: types.float32, types.float64, types.int32, types.int64 (not supported on devices that lack FP64 support)
  • dpnp.bitwise_and: types.int32, types.int64
  • dpnp.bitwise_or: types.int32, types.int64
  • dpnp.bitwise_xor: types.int32, types.int64
  • dpnp.copysign: types.float32, types.float64
  • dpnp.divide: types.float32, types.float64
  • dpnp.equal: types.float32, types.float64, types.int32, types.int64
  • dpnp.floor_divide: types.float32, types.float64, types.int32, types.int64
  • dpnp.fmax: types.float32, types.float64, types.int32, types.int64
  • dpnp.fmin: types.float32, types.float64, types.int32, types.int64
  • dpnp.fmod: types.float32, types.float64
  • dpnp.greater: types.float32, types.float64, types.int32, types.int64
  • dpnp.greater_equal: types.float32, types.float64, types.int32, types.int64
  • dpnp.hypot: types.float32, types.float64
  • dpnp.left_shift: types.int32, types.int64
  • dpnp.less: types.float32, types.float64, types.int32, types.int64
  • dpnp.less_equal: types.float32, types.float64, types.int32, types.int64
  • dpnp.logical_and: types.float32, types.float64, types.int32, types.int64
  • dpnp.logical_or: types.float32, types.float64, types.int32, types.int64
  • dpnp.logical_xor: types.float32, types.float64, types.int32, types.int64
  • dpnp.maximum: types.float32, types.float64, types.int32, types.int64
  • dpnp.minimum: types.float32, types.float64, types.int32, types.int64
  • dpnp.mod: types.int32, types.int64
  • dpnp.multiply: types.float32, types.float64, types.int32, types.int64
  • dpnp.not_equal: types.float32, types.float64, types.int32, types.int64
  • dpnp.power: types.float32, types.float64
  • dpnp.remainder: types.float32, types.float64, types.int32, types.int64
  • dpnp.right_shift: types.float32, types.float64
  • dpnp.subtract: types.float32, types.float64, types.int32, types.int64
  • dpnp.true_divide: types.float32, types.float64
  • dpnp.abs: types.float32, types.float64, types.int32, types.int64
  • dpnp.absolute: types.float32, types.float64, types.int32, types.int64
  • dpnp.arccos: types.float32, types.float64
  • dpnp.arccosh: types.float32, types.float64 (not supported on Intel Xe (Gen12) GPUs)
  • dpnp.arcsin: types.float32, types.float64
  • dpnp.arcsinh: types.float32, types.float64
  • dpnp.arctan: types.float32, types.float64
  • dpnp.arctanh: types.float32, types.float64
  • dpnp.bitwise_not: types.int32, types.int64
  • dpnp.cbrt: not supported
  • dpnp.ceil: types.float32, types.float64
  • dpnp.conjugate: types.float32, types.float64, types.int32, types.int64
  • dpnp.cos: types.float32, types.float64
  • dpnp.cosh: types.float32, types.float64
  • dpnp.deg2rad: types.float32, types.float64
  • dpnp.degrees: types.float32, types.float64
  • dpnp.erf: types.float32, types.float64
  • dpnp.exp: types.float32, types.float64
  • dpnp.exp2: types.float32, types.float64
  • dpnp.expm1: types.float32, types.float64 (not supported on Intel Xe (Gen12) GPUs)
  • dpnp.fabs: types.float32, types.float64
  • dpnp.floor: types.float32, types.float64
  • dpnp.frexp: not supported
  • dpnp.invert: types.int32, types.int64
  • dpnp.isfinite: types.float32, types.float64, types.int32, types.int64
  • dpnp.isinf: types.float32, types.float64, types.int32, types.int64
  • dpnp.isnan: types.float32, types.float64, types.int32, types.int64
  • dpnp.log: types.float32, types.float64 (not supported on Intel Xe (Gen12) GPUs)
  • dpnp.log10: types.float32, types.float64 (not supported on Intel Xe (Gen12) GPUs)
  • dpnp.log1p: types.float32, types.float64
  • dpnp.log2: types.float32, types.float64 (not supported on Intel Xe (Gen12) GPUs)
  • dpnp.log2: not supported
  • dpnp.logical_not: types.float32, types.float64, types.int32, types.int64
  • dpnp.logaddexp: not supported
  • dpnp.logaddexp2: not supported
  • dpnp.negative: types.float32, types.float64, types.int32, types.int64
  • dpnp.rad2deg: types.float32, types.float64
  • dpnp.radians: types.float32, types.float64
  • dpnp.reciprocal: types.float32, types.float64
  • dpnp.sign: types.float32, types.float64, types.int32, types.int64 (not supported on Intel Xe (Gen12) GPUs)
  • dpnp.sin: types.float32, types.float64
  • dpnp.sinh: types.float32, types.float64
  • dpnp.sqrt: types.float32, types.float64
  • dpnp.square: types.float32, types.float64, types.int32, types.int64
  • dpnp.tan: types.float32, types.float64
  • dpnp.tanh: types.float32, types.float64
  • dpnp.trunc: types.float32, types.float64

Operators

The following Python operators are supported and can be used in a kernel or a device_func decorated function.

Current support matrix of Python operators

  • Addition: +
  • Multiplication: *
  • Subtraction: -
  • Division: /
  • Floor Division: //
  • Modulo: %
  • Exponent: **
  • In-place Addition: +=
  • In-place Subtraction: -=
  • In-place Division: /=
  • In-place Floor Division: //=
  • In-place Modulo: %=
  • In-place Exponent: **= (only supported on OpenCL CPU devices)
  • Bitwise And: &
  • Bitwise Left Shift: <<
  • Bitwise Right Shift: >>
  • Bitwise Or: |
  • Bitwise Exclusive Or: ^
  • In-place Bitwise And: &=
  • In-place Bitwise Left Shift: <<=
  • In-place Bitwise Right Shift: >>=
  • In-place Bitwise Or: |=
  • In-place Bitwise Exclusive Or: ^=
  • Negation: -
  • Complement: ~
  • Pos: +
  • Less Than: <
  • Less Than Equal: <=
  • Greater Than: >
  • Greater Than Equal: >=
  • Equal To: ==
  • Not Equal To: !=
  • Matmul: @ (not supported)
  • In-place Matmul: @= (not supported)
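
As an illustration, the sketch below uses a few of the supported operators (floor division, modulo, comparison, and unary negation) inside a kernel. The kernel name is illustrative.

    import dpnp
    import numba_dpex as dpex
    from numba_dpex import kernel_api as kapi


    @dpex.kernel
    def classify(item: kapi.Item, a, b):
        i = item.get_id(0)
        q = a[i] // 3  # floor division
        r = a[i] % 3  # modulo
        if r == 0:  # comparison
            b[i] = q
        else:
            b[i] = -q  # unary negation


    a = dpnp.arange(12, dtype=dpnp.int64)
    b = dpnp.zeros_like(a)
    dpex.call_kernel(classify, kapi.Range(a.size), a, b)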

General Python features

When run in the purely interpreted mode by the CPython interpreter, a kapi function is a regular Python function, and as such, in theory, any Python feature can be used in the body of the function. In practice, to be JIT compilable and executable on a device, only a subset of Python language features is supported in a kapi function. The restriction stems both from limitations in the Numba compiler tooling and from the device-specific calling convention and other restrictions applied by a device's ABI.

This section provides a partial support matrix for Python features with respect to their usage in a kapi function.

Built-in types

Supported Types

  • int

  • float

Unsupported Types

  • complex

  • bool

  • None

  • tuple

Built-in functions

The following built-in functions are supported:

  • abs()

  • float

  • int

  • len()

  • range()

  • round()

Unsupported Constructs

The following Python constructs are not supported:

  • Exception handling (try .. except, try .. finally)

  • Context management (the with statement)

  • Comprehensions (either list, dict, set or generator comprehensions)

  • Generator (any yield statements)

  • The raise statement

  • The assert statement

Advanced concepts

Local memory allocation

Private memory allocation

Group barrier synchronization

Atomic operations

Async kernel execution

Specializing a kernel or a device_func