/******************************************************************************* * * MIT License * * Copyright 2021-2025 AMD ROCm(TM) Software * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * *******************************************************************************/ #pragma once #include #include #include #include #include #include #include #include namespace rocRoller { namespace Expression { std::string toString(EvaluationTime t); std::ostream& operator<<(std::ostream&, EvaluationTime const&); using EvaluationTimes = EnumBitset; enum class AlgebraicProperty : int { Commutative = 0, Associative, Count }; std::string toString(AlgebraicProperty t); std::ostream& operator<<(std::ostream&, AlgebraicProperty const&); using AlgebraicProperties = EnumBitset; enum class Category : int { Arithmetic = 0, Comparison, Logical, Conversion, Value, Count }; std::string toString(Category c); std::ostream& operator<<(std::ostream&, Category const&); // Expression: type alias for std::variant of all expression subtypes. // Defined in Expression_fwd.hpp. struct Binary { ExpressionPtr lhs, rhs; std::string comment = ""; template requires std::derived_from inline T& copyParams(const T& other) { return static_cast(*this); } }; template concept CBinary = requires { requires std::derived_from; }; // Complexity is a heuristic that estimates the relative cost of computing different // expressions. See the KernelOption minLaunchTimeExpressionComplexity for a more // in-depth description. struct Add : Binary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline AlgebraicProperties Properties{AlgebraicProperty::Associative, AlgebraicProperty::Commutative}; constexpr static inline int Complexity = 2; }; struct Subtract : Binary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline AlgebraicProperties Properties{}; constexpr static inline int Complexity = 2; }; struct Multiply : Binary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline AlgebraicProperties Properties{AlgebraicProperty::Associative, AlgebraicProperty::Commutative}; constexpr static inline int Complexity = 4; }; struct MultiplyHigh : Binary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline AlgebraicProperties Properties{AlgebraicProperty::Commutative}; constexpr static inline int Complexity = 4; }; struct Divide : Binary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline AlgebraicProperties Properties{}; constexpr static inline int Complexity = 50; }; struct Modulo : Binary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline AlgebraicProperties Properties{}; constexpr static inline int Complexity = 50; }; struct ShiftL : Binary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline AlgebraicProperties Properties{}; constexpr static inline int Complexity = 1; }; struct LogicalShiftR : Binary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline AlgebraicProperties Properties{}; constexpr static inline int Complexity = 1; }; struct ArithmeticShiftR : Binary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline AlgebraicProperties Properties{}; constexpr static inline int Complexity = 1; }; struct BitwiseAnd : Binary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline AlgebraicProperties Properties{AlgebraicProperty::Associative, AlgebraicProperty::Commutative}; constexpr static inline int Complexity = 1; }; struct BitwiseOr : Binary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline AlgebraicProperties Properties{AlgebraicProperty::Associative, AlgebraicProperty::Commutative}; constexpr static inline int Complexity = 1; }; struct BitwiseXor : Binary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline AlgebraicProperties Properties{AlgebraicProperty::Associative, AlgebraicProperty::Commutative}; constexpr static inline int Complexity = 1; }; struct GreaterThan : Binary { constexpr static inline auto Type = Category::Comparison; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline AlgebraicProperties Properties{}; constexpr static inline int Complexity = 2; }; struct GreaterThanEqual : Binary { constexpr static inline auto Type = Category::Comparison; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline AlgebraicProperties Properties{}; constexpr static inline int Complexity = 2; }; struct LessThan : Binary { constexpr static inline auto Type = Category::Comparison; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline AlgebraicProperties Properties{}; constexpr static inline int Complexity = 2; }; struct LessThanEqual : Binary { constexpr static inline auto Type = Category::Comparison; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline AlgebraicProperties Properties{}; constexpr static inline int Complexity = 2; }; struct Equal : Binary { constexpr static inline auto Type = Category::Comparison; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline AlgebraicProperties Properties{AlgebraicProperty::Commutative}; constexpr static inline int Complexity = 2; }; struct NotEqual : Binary { constexpr static inline auto Type = Category::Comparison; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline AlgebraicProperties Properties{AlgebraicProperty::Commutative}; constexpr static inline int Complexity = 1; }; struct LogicalAnd : Binary { constexpr static inline auto Type = Category::Logical; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline AlgebraicProperties Properties{AlgebraicProperty::Associative, AlgebraicProperty::Commutative}; constexpr static inline int Complexity = 1; }; struct LogicalOr : Binary { constexpr static inline auto Type = Category::Logical; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline AlgebraicProperties Properties{AlgebraicProperty::Associative, AlgebraicProperty::Commutative}; constexpr static inline int Complexity = 1; }; struct BitfieldCombine : Binary { unsigned srcOffset = 0u; unsigned dstOffset = 0u; unsigned width = 0u; // if srcIsZero sets to true, that means bits outside [srcOffset:srcOffset+width-1] are 0 std::optional srcIsZero = std::nullopt; // if dstIsZero sets to true, that means bits [dstOffset:dstOffset+width-1] are 0 std::optional dstIsZero = std::nullopt; constexpr static inline auto Type = Category::Arithmetic; constexpr static inline EvaluationTimes EvalTimes{EvaluationTime::Translate}; constexpr static inline AlgebraicProperties Properties{}; constexpr static inline int Complexity = 4; }; /* * SRConversion performs a stochastic rounding conversion. * The lhs is the value to be converted, the rhs is the seed * for stochastic rounding. */ template struct SRConvert : Binary { constexpr static inline auto DestinationType = DATATYPE; constexpr static inline auto Type = Category::Conversion; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline int Complexity = 2; }; struct Ternary { ExpressionPtr lhs, r1hs, r2hs; std::string comment = ""; template requires std::derived_from inline T& copyParams(const T& other) { return static_cast(*this); } }; struct TernaryMixed : Ternary { }; template concept CTernaryMixed = requires { requires std::derived_from; }; template concept CTernary = requires { requires std::derived_from || CTernaryMixed; }; /** * `result = (lhs + r1hs) << r2hs` * * AddShiftL performs a fusion of Add expression followed by * ShiftL expression, lowering to the fused instruction if possible. */ struct AddShiftL : Ternary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline EvaluationTimes EvalTimes{EvaluationTime::KernelExecute}; constexpr static inline int Complexity = 2; }; /** * `result = (lhs << r1hs) + r2hs` * * ShiftLAdd performs a fusion of ShiftL expression followed by * Add expression, lowering to the fused instruction if possible. */ struct ShiftLAdd : Ternary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline EvaluationTimes EvalTimes{EvaluationTime::KernelExecute}; constexpr static inline int Complexity = 2; }; /** * result = (lhs x r1hs) + r2hs. * * MatA is M x K, with B batches. MatB is K x N, with B batches. MatC is M x N, with B batches. */ struct MatrixMultiply : Ternary { MatrixMultiply() = default; /** * @brief Construct a new Matrix Multiply object * * @param matA WaveTile. M x K, B batches * @param matB WaveTile. K x N, B batches * @param matC WaveTile. M x N, B batches */ MatrixMultiply(ExpressionPtr matA, ExpressionPtr matB, ExpressionPtr matC) : Ternary{matA, matB, matC} { } DataType accumulationPrecision = DataType::Float; constexpr static inline auto Type = Category::Arithmetic; constexpr static inline EvaluationTimes EvalTimes{EvaluationTime::KernelExecute}; constexpr static inline int Complexity = 20; }; /** * result = ((matA * scaleA) x (matB * scaleB)) + matC */ struct ScaledMatrixMultiply { ExpressionPtr matA, matB, matC, scaleA, scaleB; DataType accumulationPrecision = DataType::Float; std::string comment = ""; ScaledMatrixMultiply() = default; ScaledMatrixMultiply(ExpressionPtr a, ExpressionPtr b, ExpressionPtr c, ExpressionPtr sA, ExpressionPtr sB) : matA(a) , matB(b) , matC(c) , scaleA(sA) , scaleB(sB) { } constexpr static inline auto Type = Category::Arithmetic; constexpr static inline EvaluationTimes EvalTimes{EvaluationTime::KernelExecute}; constexpr static inline int Complexity = 20; }; /** * dest = lhs ? r1hs : r2hs. * * Utilizes cselect */ struct Conditional : Ternary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline int Complexity = 4; }; /** * dest = lhs * r1hs + r2hs. * * Utilizes TernaryMixed instead of Ternary * allows for mixed precision arithmetic */ struct MultiplyAdd : TernaryMixed { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline bool Associative = false; constexpr static inline bool Commutative = false; constexpr static inline int Complexity = 4; }; struct Unary { ExpressionPtr arg; std::string comment = ""; template requires std::derived_from inline T& copyParams(const T& other) { return static_cast(*this); } }; template concept CUnary = requires { requires std::derived_from; }; struct MagicMultiple : Unary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline EvaluationTimes EvalTimes{EvaluationTime::Translate, EvaluationTime::KernelLaunch}; constexpr static inline int Complexity = 0; }; struct MagicShifts : Unary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline EvaluationTimes EvalTimes{EvaluationTime::Translate, EvaluationTime::KernelLaunch}; constexpr static inline int Complexity = 0; }; struct MagicShiftAndSign : Unary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline EvaluationTimes EvalTimes{EvaluationTime::Translate, EvaluationTime::KernelLaunch}; constexpr static inline int Complexity = 0; }; struct Negate : Unary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline int Complexity = 1; }; struct BitwiseNegate : Unary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline int Complexity = 1; }; struct Convert : Unary { inline Convert& copyParams(const Convert& other) { destinationType = other.destinationType; return *this; } constexpr static inline auto Type = Category::Conversion; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline int Complexity = 2; DataType destinationType = DataType::None; }; struct LogicalNot : Unary { constexpr static inline auto Type = Category::Logical; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline int Complexity = 1; }; struct Exponential2 : Unary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline int Complexity = 1; }; struct Exponential : Unary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline EvaluationTimes EvalTimes{EvaluationTime::Translate, EvaluationTime::KernelLaunch}; constexpr static inline int Complexity = 2; }; struct RandomNumber : Unary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline int Complexity = 1; }; struct ToScalar : Unary { constexpr static inline auto Type = Category::Arithmetic; constexpr static inline auto EvalTimes = EvaluationTimes::All(); constexpr static inline int Complexity = 1; }; struct BitFieldExtract : Unary { inline BitFieldExtract& copyParams(const BitFieldExtract& other) { outputDataType = other.outputDataType; offset = other.offset; width = other.width; return *this; } constexpr static inline auto Type = Category::Arithmetic; constexpr static inline EvaluationTimes EvalTimes{EvaluationTime::Translate}; constexpr static inline int Complexity = 1; DataType outputDataType = DataType::None; int offset = 0; int width = 0; }; struct Nary { std::vector operands; std::string comment = ""; template requires std::derived_from inline T& copyParams(const T& other) { return static_cast(*this); } }; template concept CNary = requires { requires std::derived_from; }; /** * @brief Perform bitwise concatenation among all operands. * * Each operand must be dword aligned and the total number of operands' * registers must be equal to the number of registers for * 'destinationType'. * * All operands should have register type of literal, scalar or * vector. */ struct Concatenate : Nary { constexpr static inline auto Type = Category::Value; constexpr static inline EvaluationTimes EvalTimes = EvaluationTimes{}; constexpr static inline int Complexity = 1; VariableType destinationType; inline Concatenate& copyParams(const Concatenate& other) { destinationType = other.destinationType; return Nary::copyParams(other); } }; /** * @brief Register value from the coordinate graph. * * If the register associated with the `tag` hasn't been * allocated yet, a new register is created based on `regType` * and `varType`. * * If `varType` is `DataType::None`, the data type is * "deferred". */ struct DataFlowTag { int tag; Register::Type regType; VariableType varType; auto operator<=>(DataFlowTag const&) const = default; }; /** * @brief Positional argument */ struct PositionalArgument { int slot; Register::Type regType; VariableType varType; auto operator<=>(PositionalArgument const&) const = default; }; ExpressionPtr operator+(ExpressionPtr a, ExpressionPtr b); ExpressionPtr operator-(ExpressionPtr a, ExpressionPtr b); ExpressionPtr operator*(ExpressionPtr a, ExpressionPtr b); ExpressionPtr operator/(ExpressionPtr a, ExpressionPtr b); ExpressionPtr operator%(ExpressionPtr a, ExpressionPtr b); ExpressionPtr operator<<(ExpressionPtr a, ExpressionPtr b); ExpressionPtr operator>>(ExpressionPtr a, ExpressionPtr b); ExpressionPtr operator&(ExpressionPtr a, ExpressionPtr b); ExpressionPtr operator>(ExpressionPtr a, ExpressionPtr b); ExpressionPtr operator>=(ExpressionPtr a, ExpressionPtr b); ExpressionPtr operator<(ExpressionPtr a, ExpressionPtr b); ExpressionPtr operator<=(ExpressionPtr a, ExpressionPtr b); ExpressionPtr operator==(ExpressionPtr a, ExpressionPtr b); ExpressionPtr operator&&(ExpressionPtr a, ExpressionPtr b); ExpressionPtr operator||(ExpressionPtr a, ExpressionPtr b); ExpressionPtr operator-(ExpressionPtr a); ExpressionPtr logicalNot(ExpressionPtr a); ExpressionPtr multiplyHigh(ExpressionPtr a, ExpressionPtr b); ExpressionPtr multiplyAdd(ExpressionPtr a, ExpressionPtr b, ExpressionPtr c); ExpressionPtr addShiftL(ExpressionPtr a, ExpressionPtr b, ExpressionPtr c); ExpressionPtr shiftLAdd(ExpressionPtr a, ExpressionPtr b, ExpressionPtr c); ExpressionPtr conditional(ExpressionPtr a, ExpressionPtr b, ExpressionPtr c); // arithmeticShiftR is the same as >> ExpressionPtr arithmeticShiftR(ExpressionPtr a, ExpressionPtr b); ExpressionPtr logicalShiftR(ExpressionPtr a, ExpressionPtr b); ExpressionPtr magicMultiple(ExpressionPtr a); ExpressionPtr magicShifts(ExpressionPtr a); ExpressionPtr magicShiftAndSign(ExpressionPtr a); ExpressionPtr convert(VariableType vt, ExpressionPtr a); ExpressionPtr convert(DataType dt, ExpressionPtr a); template ExpressionPtr convert(ExpressionPtr a); ExpressionPtr bfe(DataType dt, ExpressionPtr a, uint8_t offset, uint8_t width); ExpressionPtr bfe(ExpressionPtr a, uint8_t offset, uint8_t width); ExpressionPtr bfc(ExpressionPtr src, ExpressionPtr dst, unsigned srcOffset, unsigned dstOffset, unsigned width); ExpressionPtr concat(const std::vector& ops, VariableType v); template ExpressionPtr literal(T value); ExpressionPtr fromKernelArgument(AssemblyKernelArgument const& arg); /** * @brief Create an Expression representing a literal value with a * specific datatype. Does not accept pointer variable types. * * @tparam T * @param value The value to represent. * @param v The datatype of value. * @return ExpressionPtr */ template ExpressionPtr literal(T value, VariableType v); ExpressionPtr dataFlowTag(int tag, Register::Type t, VariableType v); ExpressionPtr positionalArgument(int slot, Register::Type t, VariableType v); template concept CValue = CIsAnyOf; template concept COpCategory = requires { requires static_cast(T::Type) == cat; }; template concept CArithmetic = requires { requires static_cast(T::Type) == Category::Arithmetic; }; template concept CComparison = requires { requires static_cast(T::Type) == Category::Comparison; }; template concept CLogical = requires { requires static_cast(T::Type) == Category::Logical; }; template concept CConversion = requires { requires static_cast(T::Type) == Category::Conversion; }; template concept CShift = CIsAnyOf; template concept CBitwise = CIsAnyOf; template concept CAssociativeBinary = requires { requires CBinary && T::Properties[AlgebraicProperty::Associative] == true; }; template concept CCommutativeBinary = requires { requires CBinary && T::Properties[AlgebraicProperty::Commutative] == true; }; static_assert(CBinary); static_assert(CArithmetic); static_assert(!CComparison); static_assert(!CBinary); static_assert(CAssociativeBinary); static_assert(!CAssociativeBinary); template concept CTranslateTimeValue = std::same_as; template concept CTranslateTimeOperation = requires { requires T::EvalTimes[EvaluationTime::Translate] == true; }; template concept CTranslateTime = requires { requires CTranslateTimeValue || CTranslateTimeOperation; }; template concept CKernelLaunchTimeValue = CIsAnyOf; template concept CKernelLaunchTimeOperation = requires { requires T::EvalTimes[EvaluationTime::KernelLaunch] == true; }; template concept CKernelLaunchTime = requires { requires CKernelLaunchTimeValue || CKernelLaunchTimeOperation; }; template concept CKernelExecuteTimeValue = CIsAnyOf; template concept CKernelExecuteTimeOperation = requires { requires(T::EvalTimes[EvaluationTime::KernelExecute] == true); }; template concept CKernelExecuteTime = requires { requires CKernelExecuteTimeValue || CKernelExecuteTimeOperation; }; static_assert(CTranslateTime); static_assert(CTranslateTime); static_assert(CKernelLaunchTime); static_assert(CKernelLaunchTime); static_assert(CKernelExecuteTime); static_assert(CKernelExecuteTime); static_assert(!CKernelExecuteTime); // // Other visitors // std::string toString(ExpressionPtr const& expr); std::string toString(Expression const& expr); std::ostream& operator<<(std::ostream&, ExpressionPtr const&); std::ostream& operator<<(std::ostream&, Expression const&); std::ostream& operator<<(std::ostream&, std::vector const&); std::string name(ExpressionPtr const& expr); std::string name(Expression const& expr); std::string argumentName(ExpressionPtr const& expr); std::string argumentName(Expression const& expr); // EvaluationTime max(EvaluationTime lhs, EvaluationTime rhs); EvaluationTimes evaluationTimes(ExpressionPtr const& expr); EvaluationTimes evaluationTimes(Expression const& expr); VariableType resultVariableType(Expression const& expr); VariableType resultVariableType(ExpressionPtr const& expr); Register::Type resultRegisterType(Expression const& expr); Register::Type resultRegisterType(ExpressionPtr const& expr); struct ResultType { Register::Type regType; VariableType varType; bool operator==(ResultType const&) const = default; }; ResultType resultType(ExpressionPtr const& expr); ResultType resultType(Expression const& expr); std::string toString(ResultType const& obj); std::ostream& operator<<(std::ostream&, ResultType const&); /** * True when two expressions are identical. * * NOTE: Never considers commutativity or associativity. */ bool identical(ExpressionPtr const&, ExpressionPtr const&); bool identical(Expression const&, Expression const&); /** * True when two expressions are equivalent. * Optionally considers algebraic properties like commutativity. */ bool equivalent(ExpressionPtr const&, ExpressionPtr const&, AlgebraicProperties = AlgebraicProperties::All()); /** * Comment accessors. */ void setComment(ExpressionPtr& expr, std::string comment); void setComment(Expression& expr, std::string comment); std::string getComment(Expression const& expr, bool includeRegisterComments); std::string getComment(ExpressionPtr const& expr, bool includeRegisterComments); std::string getComment(ExpressionPtr const& expr); std::string getComment(Expression const& expr); std::string getComment(ExpressionPtr const& expr, bool includeRegisterComments); std::string getComment(Expression const& expr, bool includeRegisterComments); /** * Copies any comments from src into dst. If dst is not of a type that allows * comments, does nothing. */ void copyComment(ExpressionPtr const& dst, ExpressionPtr const& src); void copyComment(Expression& dst, ExpressionPtr const& src); void copyComment(ExpressionPtr const& dst, Expression const& src); void copyComment(Expression& dst, Expression const& src); void appendComment(ExpressionPtr& expr, std::string comment); void appendComment(Expression& expr, std::string comment); /** * Evaluate an expression whose evaluationTime is Translate. Will throw an exception * otherwise. */ CommandArgumentValue evaluate(ExpressionPtr const& expr); CommandArgumentValue evaluate(Expression const& expr); /** * Evaluate an expression if its evaluationTime is Translate, returns nullopt * otherwise. */ std::optional tryEvaluate(ExpressionPtr const& expr); std::optional tryEvaluate(Expression const& expr); bool canEvaluateTo(CommandArgumentValue val, ExpressionPtr const& expr); /** * Evaluate an expression whose evaluationTime is Translate or KernelLaunch. Will throw * an exception if it contains any Register values. */ CommandArgumentValue evaluate(ExpressionPtr const& expr, RuntimeArguments const& args); CommandArgumentValue evaluate(Expression const& expr, RuntimeArguments const& args); /** * Splits an expression and returns its operands in a tuple. * * Return type: * std::tuple for unary expressions * std::tuple for binary expressions * std::tuple for ternary expressions * * Throws if expr is not of type Expr. */ template requires(CUnary || CBinary || CTernary) auto split(ExpressionPtr expr); /** * Returns an approximate total complexity for an expression, to be used as a heuristic. * See the KernelOption minLaunchTimeExpressionComplexity for a more in-depth * description. */ int complexity(ExpressionPtr expr); int complexity(Expression const& expr); Generator generate(Register::ValuePtr& dest, ExpressionPtr expr, ContextPtr context); std::string toYAML(ExpressionPtr const& expr); ExpressionPtr fromYAML(std::string const& str); /** * Returns true if expr is of type T or if expr contains a subexpression of type T. */ template bool contains(ExpressionPtr expr); /** * Returns true if expr is of type T or if expr contains a subexpression of type T. */ template bool contains(Expression const& expr); /** * Returns true if expr contains a sub-expression */ bool containsSubExpression(ExpressionPtr const& expr, ExpressionPtr const& subExpr); bool containsSubExpression(Expression const& expr, Expression const& subExpr); std::unordered_set referencedKernelArguments(ExpressionPtr const& expr); std::unordered_set referencedKernelArguments(Expression const& expr); std::unordered_set referencedKernelArguments(ExpressionPtr const& expr, RegisterTagManager const& tagManager); std::unordered_set referencedKernelArguments(Expression const& expr, RegisterTagManager const& tagManager); } // namespace Expression } // namespace rocRoller #include