/******************************************************************************* * * MIT License * * Copyright 2021-2025 AMD ROCm(TM) Software * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
* *******************************************************************************/

// NOTE(review): in this copy of the header every `#include` target and most
// template argument / template parameter lists (the `<...>` text) are missing,
// e.g. `std::make_shared(Add{a, b})`, `using T = std::decay_t;`,
// `template inline ExpressionPtr convert(...)`, `struct ExpressionInfo`.
// The code below cannot compile as written; restore the bracketed text from
// version control rather than guessing it.

#pragma once

#include
#include
#include
#include
#include
#include
#include

namespace rocRoller
{
    namespace Expression
    {
        /**
         * @brief Returns the enumerator name of an EvaluationTime.
         *
         * Reports an error through Throw("Invalid EvaluationTime") for any
         * value not covered by the switch.
         */
        inline std::string toString(EvaluationTime t)
        {
            switch(t)
            {
            case EvaluationTime::Translate:
                return "Translate";
            case EvaluationTime::KernelLaunch:
                return "KernelLaunch";
            case EvaluationTime::KernelExecute:
                return "KernelExecute";
            default:
                break;
            }
            Throw("Invalid EvaluationTime");
        }

        /// Streams toString(t).
        inline std::ostream& operator<<(std::ostream& stream, EvaluationTime const& t)
        {
            return stream << toString(t);
        }

        /**
         * @brief Returns the enumerator name of an AlgebraicProperty.
         *
         * Reports an error through Throw("Invalid AlgebraicProperty") for any
         * value not covered by the switch.
         */
        inline std::string toString(AlgebraicProperty p)
        {
            switch(p)
            {
            case AlgebraicProperty::Commutative:
                return "Commutative";
            case AlgebraicProperty::Associative:
                return "Associative";
            default:
                break;
            }
            Throw("Invalid AlgebraicProperty");
        }

        /// Streams toString(p).
        inline std::ostream& operator<<(std::ostream& stream, AlgebraicProperty const& p)
        {
            return stream << toString(p);
        }

        /**
         * @brief Returns the enumerator name of a Category.
         *
         * Reports an error through Throw("Invalid Category") for any value not
         * covered by the switch.
         */
        inline std::string toString(Category c)
        {
            switch(c)
            {
            case Category::Arithmetic:
                return "Arithmetic";
            case Category::Comparison:
                return "Comparison";
            case Category::Logical:
                return "Logical";
            case Category::Conversion:
                return "Conversion";
            case Category::Value:
                return "Value";
            default:
                break;
            }
            Throw("Invalid Category");
        }

        /// Streams toString(c).
        inline std::ostream& operator<<(std::ostream& stream, Category const& c)
        {
            return stream << toString(c);
        }

        /**
         * @brief Returns true if `expr` is a literal Raw32 value.
         *
         * Visits the expression variant. Only the single alternative selected
         * by the `if constexpr` test can yield true, and then only if that
         * alternative itself holds the tested alternative. Every other
         * expression kind yields false.
         *
         * NOTE(review): the template arguments of `std::decay_t`,
         * `std::is_same_v` and `std::holds_alternative` are missing in this
         * copy. Judging by the name and the Raw32 error messages below, they
         * are presumably `decltype(arg)`, `T, CommandArgumentValue` and
         * `Raw32`; confirm against version control.
         */
        inline bool isRaw32Literal(Expression const& expr)
        {
            return std::visit(
                [](auto&& arg) {
                    using T = std::decay_t;
                    if constexpr(std::is_same_v)
                    {
                        return std::holds_alternative(arg);
                    }
                    else
                        return false;
                },
                expr);
        }

        /// Null-safe overload: a null pointer is never a Raw32 literal.
        inline bool isRaw32Literal(ExpressionPtr const& exprPtr)
        {
            if(not exprPtr)
                return false;
            return isRaw32Literal(*exprPtr);
        }

        // ------------------------------------------------------------------
        // Expression-building operators.
        //
        // Each operator allocates a new expression node that holds its
        // operands. Raw32 literals are treated as a pure bit type. Before
        // building the node, AssertFatal rejects them as operands of:
        //   - arithmetic operators;
        //   - ordering comparisons and logical operators;
        //   - the right-hand side of shifts.
        // The bitwise operators (&, |, ^, ~) accept them without a check.
        // ------------------------------------------------------------------

        /// Builds Add{a, b}; neither operand may be a Raw32 literal.
        inline ExpressionPtr operator+(ExpressionPtr a, ExpressionPtr b)
        {
            AssertFatal(!isRaw32Literal(a) and !isRaw32Literal(b),
                        "Raw32 is a bit type and cannot be used in arithmetic (+) operation: ",
                        ShowValue(a),
                        ShowValue(b));
            return std::make_shared(Add{a, b});
        }

        /// Builds Subtract{a, b}; neither operand may be a Raw32 literal.
        inline ExpressionPtr operator-(ExpressionPtr a, ExpressionPtr b)
        {
            AssertFatal(!isRaw32Literal(a) and !isRaw32Literal(b),
                        "Raw32 is a bit type and cannot be used in arithmetic (-) operation: ",
                        ShowValue(a),
                        ShowValue(b));
            return std::make_shared(Subtract{a, b});
        }

        /// Builds Multiply{a, b}; neither operand may be a Raw32 literal.
        inline ExpressionPtr operator*(ExpressionPtr a, ExpressionPtr b)
        {
            AssertFatal(!isRaw32Literal(a) and !isRaw32Literal(b),
                        "Raw32 is a bit type and cannot be used in arithmetic (*) operation: ",
                        ShowValue(a),
                        ShowValue(b));
            return std::make_shared(Multiply{a, b});
        }

        /// Builds Divide{a, b}; neither operand may be a Raw32 literal.
        inline ExpressionPtr operator/(ExpressionPtr a, ExpressionPtr b)
        {
            AssertFatal(!isRaw32Literal(a) and !isRaw32Literal(b),
                        "Raw32 is a bit type and cannot be used in arithmetic (/) operation: ",
                        ShowValue(a),
                        ShowValue(b));
            return std::make_shared(Divide{a, b});
        }

        /// Builds Modulo{a, b}; neither operand may be a Raw32 literal.
        inline ExpressionPtr operator%(ExpressionPtr a, ExpressionPtr b)
        {
            AssertFatal(!isRaw32Literal(a) and !isRaw32Literal(b),
                        "Raw32 is a bit type and cannot be used in arithmetic (%) operation: ",
                        ShowValue(a),
                        ShowValue(b));
            return std::make_shared(Modulo{a, b});
        }

        /// Builds ShiftL{a, b}. `a` may be Raw32 (bits are shifted); the
        /// shift amount `b` may not.
        inline ExpressionPtr operator<<(ExpressionPtr a, ExpressionPtr b)
        {
            AssertFatal(!isRaw32Literal(b),
                        "Raw32 is a bit type and cannot be used as RHS in <<: ",
                        ShowValue(b));
            return std::make_shared(ShiftL{a, b});
        }

        /// Builds ArithmeticShiftR{a, b}: `>>` is the arithmetic right shift.
        /// Use logicalShiftR() for a logical shift. The shift amount `b` may
        /// not be Raw32.
        inline ExpressionPtr operator>>(ExpressionPtr a, ExpressionPtr b)
        {
            AssertFatal(!isRaw32Literal(b),
                        "Raw32 is a bit type and cannot be used as RHS in >>: ",
                        ShowValue(b));
            return std::make_shared(ArithmeticShiftR{a, b});
        }

        /// Named equivalent of operator>>: builds ArithmeticShiftR{a, b}.
        inline ExpressionPtr arithmeticShiftR(ExpressionPtr a, ExpressionPtr b)
        {
            AssertFatal(!isRaw32Literal(b),
                        "Raw32 is a bit type and cannot be used as RHS in ArithmeticShiftR: ",
                        ShowValue(b));
            return std::make_shared(ArithmeticShiftR{a, b});
        }

        /// Builds LogicalShiftR{a, b}; the shift amount `b` may not be Raw32.
        inline ExpressionPtr logicalShiftR(ExpressionPtr a, ExpressionPtr b)
        {
            AssertFatal(!isRaw32Literal(b),
                        "Raw32 is a bit type and cannot be used as RHS in logicalShiftR: ",
                        ShowValue(b));
            return std::make_shared(LogicalShiftR{a, b});
        }

        /// Builds BitwiseAnd{a, b}. No Raw32 restriction (bitwise op).
        inline ExpressionPtr operator&(ExpressionPtr a, ExpressionPtr b)
        {
            return std::make_shared(BitwiseAnd{a, b});
        }

        /// Builds BitwiseOr{a, b}. No Raw32 restriction (bitwise op).
        inline ExpressionPtr operator|(ExpressionPtr a, ExpressionPtr b)
        {
            return std::make_shared(BitwiseOr{a, b});
        }

        /// Builds BitwiseXor{a, b}. No Raw32 restriction (bitwise op).
        inline ExpressionPtr operator^(ExpressionPtr a, ExpressionPtr b)
        {
            return std::make_shared(BitwiseXor{a, b});
        }

        /// Builds GreaterThan{a, b}; Raw32 operands have no ordering.
        inline ExpressionPtr operator>(ExpressionPtr a, ExpressionPtr b)
        {
            AssertFatal(!isRaw32Literal(a) and !isRaw32Literal(b),
                        "Raw32 is a bit type and cannot be used for >: ",
                        ShowValue(a),
                        ShowValue(b));
            return std::make_shared(GreaterThan{a, b});
        }

        /// Builds GreaterThanEqual{a, b}; Raw32 operands have no ordering.
        inline ExpressionPtr operator>=(ExpressionPtr a, ExpressionPtr b)
        {
            AssertFatal(!isRaw32Literal(a) and !isRaw32Literal(b),
                        "Raw32 is a bit type and cannot be used for >=: ",
                        ShowValue(a),
                        ShowValue(b));
            return std::make_shared(GreaterThanEqual{a, b});
        }

        /// Builds LessThan{a, b}; Raw32 operands have no ordering.
        inline ExpressionPtr operator<(ExpressionPtr a, ExpressionPtr b)
        {
            AssertFatal(!isRaw32Literal(a) and !isRaw32Literal(b),
                        "Raw32 is a bit type and cannot be used for <: ",
                        ShowValue(a),
                        ShowValue(b));
            return std::make_shared(LessThan{a, b});
        }

        /// Builds LessThanEqual{a, b}; Raw32 operands have no ordering.
        inline ExpressionPtr operator<=(ExpressionPtr a, ExpressionPtr b)
        {
            AssertFatal(!isRaw32Literal(a) and !isRaw32Literal(b),
                        "Raw32 is a bit type and cannot be used for <=: ",
                        ShowValue(a),
                        ShowValue(b));
            return std::make_shared(LessThanEqual{a, b});
        }

        /// Builds Equal{a, b}. Raw32 may be compared for equality, but only
        /// with another Raw32 literal.
        inline ExpressionPtr operator==(ExpressionPtr a, ExpressionPtr b)
        {
            // Either both are Raw32 or both are not
            AssertFatal((isRaw32Literal(a) and isRaw32Literal(b))
                            || (!isRaw32Literal(a) and !isRaw32Literal(b)),
                        "Cannot compare Raw32 with other types: ",
                        ShowValue(a),
                        ShowValue(b));
            return std::make_shared(Equal{a, b});
        }

        /// Builds NotEqual{a, b}. Same Raw32 rule as operator==.
        inline ExpressionPtr operator!=(ExpressionPtr a, ExpressionPtr b)
        {
            // Either both are Raw32 or both are not
            AssertFatal((isRaw32Literal(a) and isRaw32Literal(b))
                            || (!isRaw32Literal(a) and !isRaw32Literal(b)),
                        "Cannot compare Raw32 with other types: ",
                        ShowValue(a),
                        ShowValue(b));
            return std::make_shared(NotEqual{a, b});
        }

        /// Builds LogicalAnd{a, b}. This overload does NOT short-circuit: both
        /// arguments are evaluated before the node is built. Raw32 operands
        /// are rejected.
        inline ExpressionPtr operator&&(ExpressionPtr a, ExpressionPtr b)
        {
            AssertFatal(!isRaw32Literal(a) and !isRaw32Literal(b),
                        "Raw32 is a bit type and cannot be used for LogicalAnd: ",
                        ShowValue(a),
                        ShowValue(b));
            return std::make_shared(LogicalAnd{a, b});
        }

        /// Builds LogicalOr{a, b}. Like operator&&, this overload does not
        /// short-circuit. Raw32 operands are rejected.
        inline ExpressionPtr operator||(ExpressionPtr a, ExpressionPtr b)
        {
            AssertFatal(!isRaw32Literal(a) and !isRaw32Literal(b),
                        "Raw32 is a bit type and cannot be used for LogicalOr: ",
                        ShowValue(a),
                        ShowValue(b));
            return std::make_shared(LogicalOr{a, b});
        }

        /// Builds LogicalNot{a}. A named function is used because operator!
        /// on a smart pointer already means "is null".
        inline ExpressionPtr logicalNot(ExpressionPtr a)
        {
            AssertFatal(!isRaw32Literal(a),
                        "Raw32 is a bit type and cannot be used for LogicalNot: ",
                        ShowValue(a));
            return std::make_shared(LogicalNot{a});
        }

        /// Builds Negate{a} (unary minus); Raw32 is rejected.
        inline ExpressionPtr operator-(ExpressionPtr a)
        {
            AssertFatal(
                !isRaw32Literal(a), "Raw32 is a bit type and cannot be used for -: ", ShowValue(a));
            return std::make_shared(Negate{a});
        }

        /// Builds BitwiseNegate{a}. No Raw32 restriction (bitwise op).
        inline ExpressionPtr operator~(ExpressionPtr a)
        {
            return std::make_shared(BitwiseNegate{a});
        }

        /// Builds Exponential2{a}; Raw32 is rejected.
        inline ExpressionPtr exp2(ExpressionPtr a)
        {
            AssertFatal(!isRaw32Literal(a),
                        "Raw32 is a bit type and cannot be used for exp2: ",
                        ShowValue(a));
            return std::make_shared(Exponential2{a});
        }

        /// Builds Exponential{a}; Raw32 is rejected.
        inline ExpressionPtr exp(ExpressionPtr a)
        {
            AssertFatal(!isRaw32Literal(a),
                        "Raw32 is a bit type and cannot be used for exp: ",
                        ShowValue(a));
            return std::make_shared(Exponential{a});
        }

        /**
         * @brief Returns true if `dt` is an allowed destination type for
         *        convert().
         *
         * An explicit whitelist is used because the allowed types are not
         * contiguous in the DataType enum.
         *
         * NOTE(review): `static` at namespace scope in a header gives every
         * including translation unit its own copy; `inline` alone would be
         * enough. Left unchanged here.
         */
        inline static bool convertibleTo(DataType dt)
        {
            return dt == DataType::Half || dt == DataType::Halfx2 || dt == DataType::BFloat16
                   || dt == DataType::BFloat16x2 || dt == DataType::FP8 || dt == DataType::BF8
                   || dt == DataType::FP8x4 || dt == DataType::BF8x4 || dt == DataType::Float
                   || dt == DataType::FP6x16 || dt == DataType::BF6x16 || dt == DataType::FP4x8
                   || dt == DataType::Double || dt == DataType::Int32 || dt == DataType::Int64
                   || dt == DataType::UInt32 || dt == DataType::UInt64 || dt == DataType::Bool
                   || dt == DataType::Bool32 || dt == DataType::Bool64;
        }

        /**
         * @brief Builds a Convert node that converts `a` to data type `dt`.
         *
         * Reports an error through Throw(...) if `dt` is not accepted by
         * convertibleTo().
         */
        inline ExpressionPtr convert(DataType dt, ExpressionPtr a)
        {
            // Convert is no longer a class template, so the destination type
            // can no longer be checked at compile time (the earlier
            // CExpression / std::constructible_from check does not apply).
            // It is validated at runtime instead. A simple range check on `dt`
            // would not work, because the valid destination types are not
            // laid out contiguously in the DataType enum.
            if(!convertibleTo(dt))
                Throw("Expression - Unsupported datatype conversion: ", ShowValue(dt));

            return std::make_shared(Convert{{.arg{a}}, dt});
        }

        /// Converts to `vt.dataType`. Pointer variable types are rejected with
        /// AssertFatal.
        inline ExpressionPtr convert(VariableType vt, ExpressionPtr a)
        {
            AssertFatal(!vt.isPointer(), "Convert to pointer type not supported.", ShowValue(vt));
            return convert(vt.dataType, a);
        }

        /// Compile-time-typed wrapper: forwards to convert(DATATYPE, a).
        /// NOTE(review): the template parameter list (presumably a DataType
        /// non-type parameter named DATATYPE) is missing in this copy.
        template
        inline ExpressionPtr convert(ExpressionPtr a)
        {
            return convert(DATATYPE, a);
        }

        /// Wraps a plain value in a new expression node.
        template
        inline ExpressionPtr literal(T value)
        {
            return std::make_shared(value);
        }

        /**
         * @brief Builds a buffer-descriptor literal.
         *
         * Creates a Concatenate node over literals of the four descriptor
         * words (desc0..desc3), typed as {DataType::None, PointerType::Buffer}.
         */
        inline ExpressionPtr literal(Buffer value)
        {
            std::vector operands{literal(value.desc0),
                                 literal(value.desc1),
                                 literal(value.desc2),
                                 literal(value.desc3)};
            return std::make_shared(
                Concatenate{{operands}, {DataType::None, PointerType::Buffer}});
        }

        /**
         * @brief Builds a literal of the data type named by `v`.
         *
         * `v` must be a non-pointer type (AssertFatal on
         * `v.pointerType == PointerType::Value`). Supported data types:
         * Int32, UInt32, Int64, UInt64, Bool, Half, Float, Double and Raw32.
         * Any other data type is reported through Throw(...).
         *
         * NOTE(review): the per-case explicit template arguments / cast
         * targets appear to be missing in this copy, so every case currently
         * reads `literal(value)`.
         */
        template
        ExpressionPtr literal(T value, VariableType v)
        {
            AssertFatal(v.pointerType == PointerType::Value);

            switch(v.dataType)
            {
            case DataType::Int32:
                return literal(value);
            case DataType::UInt32:
                return literal(value);
            case DataType::Int64:
                return literal(value);
            case DataType::UInt64:
                return literal(value);
            case DataType::Bool:
                return literal(value);
            case DataType::Half:
                return literal(static_cast(value));
            case DataType::Float:
                return literal(value);
            case DataType::Double:
                return literal(value);
            case DataType::Raw32:
                return literal(Raw32(static_cast(value)));
            default:
                Throw(
                    "Unsupported datatype ", v.dataType, " provided to Expression::literal");
            }
        }

        // Compile-time checks on the CExpression concept. Per the message,
        // ValuePtr satisfies it but Value does not.
        static_assert(CExpression);
        static_assert(!CExpression, "ValuePtr can be an Expression but Value cannot.");

        /// Primary template. Specializations generated below provide a
        /// `constexpr static auto name()` that returns a display name for each
        /// expression / leaf type.
        template
        struct ExpressionInfo
        {
        };

        // EXPRESSION_INFO_CUSTOM(cls, cls_name) specializes ExpressionInfo for
        // `cls` with name() returning `cls_name`. EXPRESSION_INFO(cls) uses
        // the stringized class name. Both macros are #undef'd after use.
        //
        // NOTE(review): the two SRConvert lines below presumably specialize
        // different SRConvert instantiations (FP8 / BF8). With the template
        // arguments stripped in this copy, they look like duplicate
        // specializations.
#define EXPRESSION_INFO_CUSTOM(cls, cls_name) \
    template <>                               \
    struct ExpressionInfo                     \
    {                                         \
        constexpr static auto name()          \
        {                                     \
            return cls_name;                  \
        }                                     \
    }

#define EXPRESSION_INFO(cls) EXPRESSION_INFO_CUSTOM(cls, #cls)

        EXPRESSION_INFO(Add);
        EXPRESSION_INFO(Subtract);
        EXPRESSION_INFO(MatrixMultiply);
        EXPRESSION_INFO(ScaledMatrixMultiply);
        EXPRESSION_INFO(Multiply);
        EXPRESSION_INFO(MultiplyAdd);
        EXPRESSION_INFO(MultiplyHigh);

        EXPRESSION_INFO(Divide);
        EXPRESSION_INFO(Modulo);

        EXPRESSION_INFO(ShiftL);
        EXPRESSION_INFO(LogicalShiftR);
        EXPRESSION_INFO(ArithmeticShiftR);

        EXPRESSION_INFO(BitfieldCombine);
        EXPRESSION_INFO(BitwiseNegate);
        EXPRESSION_INFO(BitwiseAnd);
        EXPRESSION_INFO(BitwiseOr);
        EXPRESSION_INFO(BitwiseXor);
        EXPRESSION_INFO(Exponential2);
        EXPRESSION_INFO(Exponential);

        EXPRESSION_INFO(ShiftLAdd);
        EXPRESSION_INFO(AddShiftL);

        EXPRESSION_INFO(Conditional);

        EXPRESSION_INFO(GreaterThan);
        EXPRESSION_INFO(GreaterThanEqual);
        EXPRESSION_INFO(LessThan);
        EXPRESSION_INFO(LessThanEqual);
        EXPRESSION_INFO(Equal);
        EXPRESSION_INFO(NotEqual);
        EXPRESSION_INFO(LogicalAnd);
        EXPRESSION_INFO(LogicalOr);
        EXPRESSION_INFO(LogicalNot);

        EXPRESSION_INFO(MagicMultiple);
        EXPRESSION_INFO(MagicShifts);
        EXPRESSION_INFO(MagicShiftAndSign);
        EXPRESSION_INFO(Negate);

        EXPRESSION_INFO(RandomNumber);

        EXPRESSION_INFO(ToScalar);

        EXPRESSION_INFO(BitFieldExtract);

        EXPRESSION_INFO(Convert);

        EXPRESSION_INFO(Concatenate);

        EXPRESSION_INFO_CUSTOM(SRConvert, "SRConvert_FP8");
        EXPRESSION_INFO_CUSTOM(SRConvert, "SRConvert_BF8");

        EXPRESSION_INFO_CUSTOM(Register::ValuePtr, "RegisterValue");
        EXPRESSION_INFO_CUSTOM(CommandArgumentPtr, "CommandArgument");
        EXPRESSION_INFO_CUSTOM(CommandArgumentValue, "LiteralValue");
        EXPRESSION_INFO_CUSTOM(AssemblyKernelArgumentPtr, "Kernel Argument");
        EXPRESSION_INFO_CUSTOM(WaveTilePtr, "WaveTile");

        EXPRESSION_INFO(DataFlowTag);
        EXPRESSION_INFO(PositionalArgument);

#undef EXPRESSION_INFO
#undef EXPRESSION_INFO_CUSTOM

        /**
         * @brief Visitor returning ExpressionInfo::name() for the alternative
         *        held by an expression.
         *
         * NOTE(review): call(ExpressionPtr) dereferences `expr` without a null
         * check, unlike ExpressionEvaluationTimesVisitor::call below.
         */
        struct ExpressionNameVisitor
        {
            template
            std::string operator()(Expr const& expr) const
            {
                return ExpressionInfo::name();
            }

            std::string call(Expression const& expr) const
            {
                return std::visit(*this, expr);
            }

            std::string call(ExpressionPtr const& expr) const
            {
                return call(*expr);
            }
        };

        /// Returns the type name of the expression (see ExpressionNameVisitor).
        inline std::string name(ExpressionPtr const& expr)
        {
            return ExpressionNameVisitor().call(expr);
        }

        /// Returns the type name of the expression (see ExpressionNameVisitor).
        inline std::string name(Expression const& expr)
        {
            return ExpressionNameVisitor().call(expr);
        }

        /**
         * @brief Like ExpressionNameVisitor, but a non-null CommandArgumentPtr
         *        yields the argument's own name() instead of its type name.
         *
         * NOTE(review): call(ExpressionPtr) dereferences `expr` without a null
         * check.
         */
        struct ExpressionArgumentNameVisitor
        {
            template
            std::string operator()(Expr const& expr) const
            {
                return ExpressionInfo::name();
            }

            std::string operator()(CommandArgumentPtr const& expr) const
            {
                if(expr)
                    return expr->name();
                // Null argument pointer: fall back to the type name.
                return ExpressionInfo::name();
            }

            std::string call(Expression const& expr) const
            {
                return std::visit(*this, expr);
            }

            std::string call(ExpressionPtr const& expr) const
            {
                return call(*expr);
            }
        };

        /// Returns the argument name of the expression (see
        /// ExpressionArgumentNameVisitor).
        inline std::string argumentName(ExpressionPtr const& expr)
        {
            return ExpressionArgumentNameVisitor().call(expr);
        }

        /// Returns the argument name of the expression (see
        /// ExpressionArgumentNameVisitor).
        inline std::string argumentName(Expression const& expr)
        {
            return ExpressionArgumentNameVisitor().call(expr);
        }

        /**
         * @brief Computes the set of EvaluationTimes at which an expression
         *        can be evaluated.
         *
         * For an operation node, the result combines, with `&`, the node
         * type's own `EvalTimes` and the result for every operand.
         * Presumably `&` is set intersection, so a node is only evaluable
         * when all of its operands are (TODO confirm).
         *
         * Leaf results:
         *   - WaveTilePtr, DataFlowTag: KernelExecute.
         *   - Register::ValuePtr: All() if it is a Literal register,
         *     otherwise KernelExecute.
         *   - AssemblyKernelArgumentPtr: KernelLaunch and KernelExecute.
         *   - PositionalArgument: Translate.
         *   - CommandArgumentPtr: KernelLaunch.
         *   - CommandArgumentValue: All().
         *   - A null ExpressionPtr: All().
         *
         * NOTE(review): the template heads of the generic operator()
         * overloads have lost their constraints in this copy. Judging by the
         * members used (operands / lhs,r1hs,r2hs / lhs,rhs / arg), they are
         * presumably n-ary, ternary, binary and unary expression concepts.
         */
        struct ExpressionEvaluationTimesVisitor
        {
            EvaluationTimes operator()(WaveTilePtr const& expr) const
            {
                return {EvaluationTime::KernelExecute};
            }

            // Five operands, so it needs its own overload.
            EvaluationTimes operator()(ScaledMatrixMultiply const& expr) const
            {
                auto matA   = call(expr.matA);
                auto matB   = call(expr.matB);
                auto matC   = call(expr.matC);
                auto scaleA = call(expr.scaleA);
                auto scaleB = call(expr.scaleB);

                return matA & matB & matC & scaleA & scaleB & ScaledMatrixMultiply::EvalTimes;
            }

            // Nodes with an `operands` collection.
            template
            EvaluationTimes operator()(Expr const& expr) const
            {
                EvaluationTimes result = Expr::EvalTimes;
                for(auto const& operand : expr.operands)
                {
                    result = result & call(operand);
                }
                return result;
            }

            // Nodes with lhs / r1hs / r2hs operands.
            template
            EvaluationTimes operator()(Expr const& expr) const
            {
                auto lhs  = call(expr.lhs);
                auto r1hs = call(expr.r1hs);
                auto r2hs = call(expr.r2hs);

                return lhs & r1hs & r2hs & Expr::EvalTimes;
            }

            // Nodes with lhs / rhs operands.
            template
            EvaluationTimes operator()(Expr const& expr) const
            {
                auto lhs = call(expr.lhs);
                auto rhs = call(expr.rhs);

                return lhs & rhs & Expr::EvalTimes;
            }

            // Nodes with a single `arg` operand.
            template
            EvaluationTimes operator()(Expr const& expr) const
            {
                return call(expr.arg) & Expr::EvalTimes;
            }

            // NOTE(review): dereferences `expr` without a null check.
            EvaluationTimes operator()(Register::ValuePtr const& expr) const
            {
                if(expr->regType() == Register::Type::Literal)
                    return EvaluationTimes::All();

                return {EvaluationTime::KernelExecute};
            }

            constexpr EvaluationTimes operator()(AssemblyKernelArgumentPtr const& expr) const
            {
                return {EvaluationTime::KernelLaunch, EvaluationTime::KernelExecute};
            }

            constexpr EvaluationTimes operator()(DataFlowTag const& expr) const
            {
                return {EvaluationTime::KernelExecute};
            }

            constexpr EvaluationTimes operator()(PositionalArgument const& expr) const
            {
                return {EvaluationTime::Translate};
            }

            constexpr EvaluationTimes operator()(CommandArgumentPtr const& expr) const
            {
                return {EvaluationTime::KernelLaunch};
            }

            constexpr EvaluationTimes operator()(CommandArgumentValue const& expr) const
            {
                return EvaluationTimes::All();
            }

            EvaluationTimes call(Expression const& expr) const
            {
                return std::visit(*this, expr);
            }

            // A missing (null) sub-expression places no restriction on timing.
            EvaluationTimes call(ExpressionPtr const& expr) const
            {
                if(expr == nullptr)
                {
                    return EvaluationTimes::All();
                }

                return call(*expr);
            }
        };

        /// See ExpressionEvaluationTimesVisitor.
        inline EvaluationTimes evaluationTimes(Expression const& expr)
        {
            return ExpressionEvaluationTimesVisitor().call(expr);
        }

        /// See ExpressionEvaluationTimesVisitor. Returns All() for a null pointer.
        inline EvaluationTimes evaluationTimes(ExpressionPtr const& expr)
        {
            return ExpressionEvaluationTimesVisitor().call(expr);
        }

        /**
         * @brief Unpacks the operands of a unary, binary or ternary expression
         *        node into a tuple.
         *
         * Asserts (AssertFatal) that `expr` is non-null and holds the
         * requested node type. Returns:
         *   - unary:   std::tuple(arg)
         *   - binary:  std::tuple(lhs, rhs)
         *   - ternary: std::tuple(lhs, r1hs, r2hs)
         *
         * NOTE(review): `auto exp = std::get(*expr);` copies the node, and so
         * every operand pointer in it; `auto const&` would avoid that. Also,
         * the local name `exp` shadows the exp() builder above.
         */
        template
            requires(CUnary || CBinary || CTernary)
        auto split(ExpressionPtr expr)
        {
            AssertFatal(expr && std::holds_alternative(*expr),
                        "Expression does not hold the correct type");

            auto exp = std::get(*expr);

            if constexpr(CUnary)
            {
                return std::make_tuple(exp.arg);
            }
            else if constexpr(CBinary)
            {
                return std::make_tuple(exp.lhs, exp.rhs);
            }
            else if constexpr(CTernary)
            {
                return std::make_tuple(exp.lhs, exp.r1hs, exp.r2hs);
            }
        }
    }
}