/******************************************************************************* * * MIT License * * Copyright 2019-2025 AMD ROCm(TM) Software * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * *******************************************************************************/ #pragma once #include "DataTypes.hpp" namespace rocRoller { template std::string friendlyTypeName() { if constexpr(CHasTypeInfo) { return TypeInfo::Name(); } else { return typeName(); } } inline constexpr DataType getIntegerType(bool isSigned, int sizeBytes) { if(isSigned) { switch(sizeBytes) { case 1: return DataType::Int8; case 2: return DataType::Int16; case 4: return DataType::Int32; case 8: return DataType::Int64; } } else { switch(sizeBytes) { case 1: return DataType::UInt8; case 2: return DataType::UInt16; case 4: return DataType::UInt32; case 8: return DataType::UInt64; } } auto prefix = isSigned ? 
"signed" : "unsigned"; Throw( "No enumeration for ", prefix, " integer with size ", sizeBytes, " bytes."); // cppcheck doesn't seem to notice that Throw<>() is marked [[noreturn]] so it will // complain if this isn't here. return DataType::None; } template <> inline DataType fromString(std::string const& str) { using myInt = std::underlying_type_t; auto maxValue = static_cast(DataType::Count); for(myInt i = 0; i < maxValue; ++i) { auto val = static_cast(i); std::string testStr = toString(val); if(std::equal( str.begin(), str.end(), testStr.begin(), testStr.end(), [](auto a, auto b) { return std::tolower(a) == std::tolower(b); })) return val; } // Special cases std::string strCopy = str; std::transform(strCopy.begin(), strCopy.end(), strCopy.begin(), ::tolower); if(strCopy == "fp16") { return DataType::Half; } if(strCopy == "bf16") { return DataType::BFloat16; } Throw( "Invalid fromString: type name: ", typeName(), ", string input: ", str); // Unreachable code return DataType::None; } inline constexpr VariableType::VariableType() : dataType() { } inline constexpr VariableType::VariableType(VariableType const& v) : dataType(v.dataType) , pointerType(v.pointerType) { } inline constexpr VariableType::VariableType(DataType d) : dataType(d) , pointerType(PointerType::Value) { } inline constexpr VariableType::VariableType(DataType d, PointerType p) : dataType(d) , pointerType(p) { } inline constexpr VariableType::VariableType(PointerType p) : dataType() , pointerType(p) { } inline bool VariableType::isPointer() const { return pointerType != PointerType::Value; } inline bool VariableType::isGlobalPointer() const { return pointerType == PointerType::PointerGlobal; } inline VariableType VariableType::getDereferencedType() const { return VariableType(dataType); } inline VariableType VariableType::getPointer() const { AssertFatal(pointerType == PointerType::Value, ShowValue(pointerType)); return VariableType(dataType, PointerType::PointerGlobal); } inline DataType 
VariableType::getArithmeticType() const { if(pointerType == PointerType::Value) return dataType; return getIntegerType(false, getElementSize()); } inline bool CompareVariableTypesPointersEqual::operator()(VariableType const& lhs, VariableType const& rhs) const { if(lhs.pointerType < rhs.pointerType) return true; if(lhs.pointerType == rhs.pointerType && lhs.pointerType == PointerType::Value) return lhs.dataType < rhs.dataType; return false; } template constexpr VariableType BaseTypeInfo::Var; template constexpr VariableType BaseTypeInfo::SegmentVariableType; template constexpr size_t BaseTypeInfo::ElementBytes; template constexpr size_t BaseTypeInfo::ElementBits; template constexpr size_t BaseTypeInfo::Packing; template constexpr size_t BaseTypeInfo::RegisterCount; template constexpr bool BaseTypeInfo::IsComplex; template constexpr bool BaseTypeInfo::IsIntegral; #define DeclareDefaultValueTypeInfo(dtype, enumVal) \ template <> \ struct TypeInfo : public BaseTypeInfo, \ std::is_signed_v> \ { \ } DeclareDefaultValueTypeInfo(float, Float); DeclareDefaultValueTypeInfo(int8_t, Int8); DeclareDefaultValueTypeInfo(int16_t, Int16); DeclareDefaultValueTypeInfo(int32_t, Int32); DeclareDefaultValueTypeInfo(uint8_t, UInt8); DeclareDefaultValueTypeInfo(uint16_t, UInt16); DeclareDefaultValueTypeInfo(uint32_t, UInt32); #undef DeclareDefaultValueTypeInfo template <> struct TypeInfo : public BaseTypeInfo { }; template <> struct TypeInfo : public BaseTypeInfo { }; template <> struct TypeInfo : public BaseTypeInfo { }; template <> struct TypeInfo> : public BaseTypeInfo, DataType::ComplexFloat, DataType::ComplexFloat, PointerType::Value, 1, 2, 64, true, false, true> { }; template <> struct TypeInfo> : public BaseTypeInfo, DataType::ComplexDouble, DataType::ComplexDouble, PointerType::Value, 1, 4, 128, true, false, true> { }; template <> struct TypeInfo : public BaseTypeInfo { }; template <> struct TypeInfo : public BaseTypeInfo { }; template <> struct TypeInfo : public BaseTypeInfo { }; 
// NOTE(review): in the copy under review, every specialization below reads
// `struct TypeInfo : public BaseTypeInfo { }` with its template argument lists
// missing. Which C++ type each one describes therefore cannot be determined
// from this view; verify against the repository copy.
template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

template <>
struct TypeInfo : public BaseTypeInfo
{
};

// DeclareEnumTypeInfo(typeEnum, dtype) declares an EnumTypeInfo specialization
// that simply inherits from TypeInfo. Each invocation below pairs a DataType
// enumerator name (typeEnum, e.g. Float) with the C++ type that represents it
// (dtype, e.g. float), giving enum -> TypeInfo lookup.
#define DeclareEnumTypeInfo(typeEnum, dtype) \
    template <>                              \
    struct EnumTypeInfo : public TypeInfo    \
    {                                        \
    }

DeclareEnumTypeInfo(Float, float);
DeclareEnumTypeInfo(Double, double);
DeclareEnumTypeInfo(ComplexFloat, std::complex);
DeclareEnumTypeInfo(ComplexDouble, std::complex);
DeclareEnumTypeInfo(Half, Half);
DeclareEnumTypeInfo(Halfx2, Halfx2);
DeclareEnumTypeInfo(FP8, FP8);
DeclareEnumTypeInfo(FP8x4, FP8x4);
DeclareEnumTypeInfo(BF8, BF8);
DeclareEnumTypeInfo(BF8x4, BF8x4);
DeclareEnumTypeInfo(FP6, FP6);
DeclareEnumTypeInfo(FP6x16, FP6x16);
DeclareEnumTypeInfo(BF6, BF6);
DeclareEnumTypeInfo(BF6x16, BF6x16);
DeclareEnumTypeInfo(FP4, FP4);
DeclareEnumTypeInfo(FP4x8, FP4x8);
DeclareEnumTypeInfo(Int8x4, Int8x4);
DeclareEnumTypeInfo(Int8, int8_t);
DeclareEnumTypeInfo(Int16, int16_t);
DeclareEnumTypeInfo(Int32, int32_t);
DeclareEnumTypeInfo(Int64, int64_t);
DeclareEnumTypeInfo(BFloat16, BFloat16);
DeclareEnumTypeInfo(BFloat16x2, BFloat16x2);
DeclareEnumTypeInfo(Raw32, Raw32);
DeclareEnumTypeInfo(UInt8x4, UInt8x4);
DeclareEnumTypeInfo(UInt8, uint8_t);
DeclareEnumTypeInfo(UInt16, uint16_t);
DeclareEnumTypeInfo(UInt32, uint32_t);
DeclareEnumTypeInfo(UInt64, uint64_t);
DeclareEnumTypeInfo(Bool, bool);
DeclareEnumTypeInfo(Bool32, Bool32);
DeclareEnumTypeInfo(Bool64, Bool64);
DeclareEnumTypeInfo(E8M0, E8M0);
DeclareEnumTypeInfo(E8M0x4, E8M0x4);

#undef DeclareEnumTypeInfo
} // namespace rocRoller