DPNP C++ backend kernel library 0.18.0dev1
Data Parallel Extension for NumPy*
Loading...
Searching...
No Matches
common_internal.hpp
1//*****************************************************************************
2// Copyright (c) 2024-2025, Intel Corporation
3// All rights reserved.
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are met:
7// - Redistributions of source code must retain the above copyright notice,
8// this list of conditions and the following disclaimer.
9// - Redistributions in binary form must reproduce the above copyright notice,
10// this list of conditions and the following disclaimer in the documentation
11// and/or other materials provided with the distribution.
12//
13// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23// THE POSSIBILITY OF SUCH DAMAGE.
24//*****************************************************************************
25
26#include "ext/common.hpp"
27#include "utils/type_dispatch.hpp"
28#include <pybind11/pybind11.h>
29
30namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
31
32namespace ext::common
33{
34inline size_t get_max_local_size(const sycl::device &device)
35{
36 constexpr const int default_max_cpu_local_size = 256;
37 constexpr const int default_max_gpu_local_size = 0;
38
39 return get_max_local_size(device, default_max_cpu_local_size,
40 default_max_gpu_local_size);
41}
42
43inline size_t get_max_local_size(const sycl::device &device,
44 int cpu_local_size_limit,
45 int gpu_local_size_limit)
46{
47 int max_work_group_size =
48 device.get_info<sycl::info::device::max_work_group_size>();
49 if (device.is_cpu() && cpu_local_size_limit > 0) {
50 return std::min(cpu_local_size_limit, max_work_group_size);
51 }
52 else if (device.is_gpu() && gpu_local_size_limit > 0) {
53 return std::min(gpu_local_size_limit, max_work_group_size);
54 }
55
56 return max_work_group_size;
57}
58
59inline sycl::nd_range<1>
60 make_ndrange(size_t global_size, size_t local_range, size_t work_per_item)
61{
62 return make_ndrange(sycl::range<1>(global_size),
63 sycl::range<1>(local_range),
64 sycl::range<1>(work_per_item));
65}
66
67inline size_t get_local_mem_size_in_bytes(const sycl::device &device)
68{
69 // Reserving 1kb for runtime needs
70 constexpr const size_t reserve = 1024;
71
72 return get_local_mem_size_in_bytes(device, reserve);
73}
74
75inline size_t get_local_mem_size_in_bytes(const sycl::device &device,
76 size_t reserve)
77{
78 size_t local_mem_size =
79 device.get_info<sycl::info::device::local_mem_size>();
80 return local_mem_size - reserve;
81}
82
83inline pybind11::dtype dtype_from_typenum(int dst_typenum)
84{
85 dpctl_td_ns::typenum_t dst_typenum_t =
86 static_cast<dpctl_td_ns::typenum_t>(dst_typenum);
87 switch (dst_typenum_t) {
88 case dpctl_td_ns::typenum_t::BOOL:
89 return py::dtype("?");
90 case dpctl_td_ns::typenum_t::INT8:
91 return py::dtype("i1");
92 case dpctl_td_ns::typenum_t::UINT8:
93 return py::dtype("u1");
94 case dpctl_td_ns::typenum_t::INT16:
95 return py::dtype("i2");
96 case dpctl_td_ns::typenum_t::UINT16:
97 return py::dtype("u2");
98 case dpctl_td_ns::typenum_t::INT32:
99 return py::dtype("i4");
100 case dpctl_td_ns::typenum_t::UINT32:
101 return py::dtype("u4");
102 case dpctl_td_ns::typenum_t::INT64:
103 return py::dtype("i8");
104 case dpctl_td_ns::typenum_t::UINT64:
105 return py::dtype("u8");
106 case dpctl_td_ns::typenum_t::HALF:
107 return py::dtype("f2");
108 case dpctl_td_ns::typenum_t::FLOAT:
109 return py::dtype("f4");
110 case dpctl_td_ns::typenum_t::DOUBLE:
111 return py::dtype("f8");
112 case dpctl_td_ns::typenum_t::CFLOAT:
113 return py::dtype("c8");
114 case dpctl_td_ns::typenum_t::CDOUBLE:
115 return py::dtype("c16");
116 default:
117 throw py::value_error("Unrecognized dst_typeid");
118 }
119}
120
121} // namespace ext::common