Source code for dpctl.program.utils._utils
# Data Parallel Control (dpctl)
#
# Copyright 2026 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements various utilities for the dpctl.program module."""
from dataclasses import dataclass
from enum import IntEnum
import numpy as np
# these constants come from the SPIR-V spec:
# https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html
class SpirvOpCode(IntEnum):
OpName = 5
OpTypeBool = 20
OpTypeInt = 21
OpTypeFloat = 22
OpSpecConstantTrue = 48
OpSpecConstantFalse = 49
OpSpecConstant = 50
OpFunction = 54
OpDecorate = 71
class SpirvDecoration(IntEnum):
SpecId = 1
[docs]@dataclass(frozen=True)
class SpecializationConstantInfo:
"""Data class representing specialization constant information."""
spec_id: int
dtype: str
name: str
itemsize: int
default_value: int | float | bool | None
[docs]def parse_spirv_specializations(
spv_bytes: bytes | bytearray | memoryview,
) -> tuple[SpecializationConstantInfo]:
"""
Parses SPIR-V byte stream to extract information about specializations,
including the specialization IDs, types, names, and default values.
Note that the dtype information may be imprecise, as the compiler may
choose to, for example, represent a bool as char, or may represent both
signed and unsigned integers as unsigned integer bit buckets of the same
length.
Args:
spv_bytes (bytes | bytearray | memoryview):
the SPIR-V byte stream.
Returns:
tuple[SpecializationConstantInfo]:
a tuple of parsed constants and their information represented by
`SpecializationConstantInfo` objects, sorted by their
specialization IDs. The length of the tuple is equal to the number
of specialization constants found. Each
`SpecializationConstantInfo` object contains the following
attributes:
- `spec_id` (int): The specialization ID.
- `dtype` (str): A NumPy style string representing the data type.
- `itemsize` (int): The size of the specialization constant in
bytes.
- `name` (str): The variable name. If not preserved in the binary,
a default name in the format `unnamed_spec_const_{spec_id}` is
used.
- `default_value` (int | float | bool | None): The default value of
the specialization constant. If not specified, `None` is used.
"""
words = np.frombuffer(spv_bytes, dtype=np.uint32)
# verify magic number
if len(words) < 5 or words[0] != 0x07230203:
raise ValueError("Invalid SPIR-V binary")
types = {}
ids = {}
names = {}
constants = {}
defaults = {}
i = 5 # skip 5 word header
while i < len(words):
word = words[i]
opcode = word & 0xFFFF
word_count = word >> 16
if word_count == 0:
raise ValueError(f"Invalid SPIR-V instruction at word index {i}")
if i + word_count > len(words):
raise ValueError(
f"Invalid SPIR-V instruction at offset {i} (extends beyond "
"buffer)"
)
if opcode == SpirvOpCode.OpFunction:
# everything following is not relevant to specialization constant
# parsing, so we can stop parsing at this point
break
elif opcode == SpirvOpCode.OpTypeBool:
result_id = int(words[i + 1])
types[result_id] = {"dtype": "?", "itemsize": 1}
elif opcode == SpirvOpCode.OpTypeInt:
result_id = int(words[i + 1])
width = int(words[i + 2])
signed = int(words[i + 3])
prefix = "i" if signed else "u"
types[result_id] = {
"dtype": f"{prefix}{width // 8}",
"itemsize": width // 8,
}
elif opcode == SpirvOpCode.OpTypeFloat:
result_id = int(words[i + 1])
width = int(words[i + 2])
types[result_id] = {
"dtype": f"f{width // 8}",
"itemsize": width // 8,
}
elif opcode == SpirvOpCode.OpSpecConstant:
type_id = int(words[i + 1])
result_id = int(words[i + 2])
constants[result_id] = type_id
literal_words = words[i + 3 : i + word_count]
defaults[result_id] = literal_words.tobytes()
elif opcode == SpirvOpCode.OpSpecConstantTrue:
type_id = int(words[i + 1])
result_id = int(words[i + 2])
constants[result_id] = type_id
defaults[result_id] = True
elif opcode == SpirvOpCode.OpSpecConstantFalse:
type_id = int(words[i + 1])
result_id = int(words[i + 2])
constants[result_id] = type_id
defaults[result_id] = False
elif opcode == SpirvOpCode.OpDecorate:
target_id = int(words[i + 1])
decoration = int(words[i + 2])
if decoration == SpirvDecoration.SpecId:
ids[target_id] = int(words[i + 3])
elif opcode == SpirvOpCode.OpName:
target_id = int(words[i + 1])
name_bytes = words[i + 2 : i + word_count].tobytes()
names[target_id] = name_bytes.split(b"\x00", 1)[0].decode("utf-8")
i += word_count
# a spec ID may appear multiple times in the same binary with different
# target IDs. We only need to keep one, so skip duplicates
unique_ids = set()
result = []
for target_id, spec_id in ids.items():
if spec_id in unique_ids:
continue
unique_ids.add(spec_id)
type_id = constants.get(target_id)
type_info = types.get(type_id, {"dtype": "unknown_type", "itemsize": 0})
name = names.get(target_id, f"unnamed_spec_const_{spec_id}")
dtype_str = type_info["dtype"]
raw_default = defaults.get(target_id)
default_value = None
if isinstance(raw_default, bool):
default_value = raw_default
elif isinstance(raw_default, bytes) and dtype_str != "unknown_type":
try:
default_value = np.frombuffer(raw_default, dtype=dtype_str)[
0
].item()
except (ValueError, TypeError):
default_value = None
result.append(
SpecializationConstantInfo(
spec_id=spec_id,
dtype=dtype_str,
name=name,
itemsize=type_info["itemsize"],
default_value=default_value,
)
)
return tuple(sorted(result, key=lambda x: x.spec_id))