/******************************************************************************* * * MIT License * * Copyright 2024-2025 AMD ROCm(TM) Software * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * *******************************************************************************/ #pragma once #include #include #include #include namespace rocRoller { /// Base Arithmetic Generator class. All Arithmetic generators should be derived /// from this class. class ArithmeticGenerator { public: ArithmeticGenerator(ContextPtr context) : m_context(context) { } protected: ContextPtr m_context; /// Move a value into a single VGPR Generator moveToVGPR(Register::ValuePtr& val); /// Copy the sign bit from `src` into every bit of `dst`. 
Generator signExtendDWord(Register::ValuePtr dst, Register::ValuePtr src); /// Split a single register value into two registers each containing one word Generator get2DwordsScalar(Register::ValuePtr& lsd, Register::ValuePtr& msd, Register::ValuePtr input); Generator get2DwordsVector(Register::ValuePtr& lsd, Register::ValuePtr& msd, Register::ValuePtr input); /// Generate comments describing an operation that is being generated. Generator describeOpArgs(std::string const& argName0, Register::ValuePtr arg0, std::string const& argName1, Register::ValuePtr arg1, std::string const& argName2, Register::ValuePtr arg2); virtual std::string name() const = 0; /** * @brief Use VALU to perform a scalar comparison. * * Some vectors comparison instructions (like v_cmp_le_i64) * don't have scalar equivalents. In this case, the RHS is * copied to VGPRs and the comparison is done with the VALU. * * @param lhs LHS of comparison (stored in SGPR) * @param rhs RHS of comparison (stored in SGPR) * @param dst Destination, can be null (in which case result is in SCC). */ Generator scalarCompareThroughVALU(std::string const instruction, Register::ValuePtr dst, Register::ValuePtr lhs, Register::ValuePtr rhs); }; // Unary Arithmetic Generator. Most unary generators should be derived from // this class. template class UnaryArithmeticGenerator : public ArithmeticGenerator { public: UnaryArithmeticGenerator(ContextPtr context) : ArithmeticGenerator(context) { } virtual Generator generate(Register::ValuePtr dst, Register::ValuePtr arg, Operation const& expr) = 0; using Argument = std::tuple; using Base = UnaryArithmeticGenerator; static const std::string Basename; std::string name() const override { return Expression::ExpressionInfo::name(); } }; template const std::string UnaryArithmeticGenerator::Basename = concatenate(Expression::ExpressionInfo::name(), "Generator"); // Binary Arithmetic Generator. Most binary generators should be derived from // this class. 
template class BinaryArithmeticGenerator : public ArithmeticGenerator
{
public:
    /// Forwards the context to the ArithmeticGenerator base.
    BinaryArithmeticGenerator(ContextPtr context)
        : ArithmeticGenerator(context)
    {
    }

    /// Generate the code for the operation on two arguments (lhs, rhs),
    /// with dst as the destination register value.
    /// Implemented by each concrete generator.
    virtual Generator generate(Register::ValuePtr dst,
                               Register::ValuePtr lhs,
                               Register::ValuePtr rhs,
                               Operation const& expr) = 0;

    // NOTE(review): template parameter/argument lists (here and on the
    // `template` headers in this section) are not visible in this view;
    // confirm against the upstream header. Argument is presumably the key
    // used to select a concrete generator.
    using Argument = std::tuple;
    using Base = BinaryArithmeticGenerator;
    static const std::string Basename;

    /// Returns the name reported by Expression::ExpressionInfo for the operation.
    std::string name() const override
    {
        return Expression::ExpressionInfo::name();
    }
};

// Basename is the operation's expression name followed by "Generator".
template const std::string BinaryArithmeticGenerator::Basename = concatenate(Expression::ExpressionInfo::name(), "Generator");

// Ternary Arithmetic Generator. Most ternary generators should be derived from
// this class.
template class TernaryArithmeticGenerator : public ArithmeticGenerator
{
public:
    /// Forwards the context to the ArithmeticGenerator base.
    TernaryArithmeticGenerator(ContextPtr context)
        : ArithmeticGenerator(context)
    {
    }

    /// Generate the code for the operation on three arguments,
    /// with dst as the destination register value.
    /// Implemented by each concrete generator.
    virtual Generator generate(Register::ValuePtr dst,
                               Register::ValuePtr arg1,
                               Register::ValuePtr arg2,
                               Register::ValuePtr arg3,
                               Operation const& expr) = 0;

    // Presumably the key used to select a concrete generator -- the tuple's
    // element types are not visible here; confirm against upstream.
    using Argument = std::tuple;
    using Base = TernaryArithmeticGenerator;
    static const std::string Basename;

    /// Returns the name reported by Expression::ExpressionInfo for the operation.
    std::string name() const override
    {
        return Expression::ExpressionInfo::name();
    }
};

// Basename is the operation's expression name followed by "Generator".
template const std::string TernaryArithmeticGenerator::Basename = concatenate(Expression::ExpressionInfo::name(), "Generator");

// TernaryMixed Arithmetic Generator. Only Ternary generators that can support mixed
// arithmetic should be derived from this class.
template class TernaryMixedArithmeticGenerator : public ArithmeticGenerator
{
public:
    /// Forwards the context to the ArithmeticGenerator base.
    TernaryMixedArithmeticGenerator(ContextPtr context)
        : ArithmeticGenerator(context)
    {
    }

    /// Generate the code for the operation on three arguments,
    /// with dst as the destination register value.
    /// Implemented by each concrete generator.
    virtual Generator generate(Register::ValuePtr dst,
                               Register::ValuePtr arg1,
                               Register::ValuePtr arg2,
                               Register::ValuePtr arg3,
                               Operation const& expr) = 0;

    // NOTE(review): template parameter/argument lists (here and on the
    // `template` headers in this section) are not visible in this view;
    // confirm against the upstream header. Argument is presumably the key
    // used to select a concrete generator.
    using Argument = std::tuple;
    using Base = TernaryMixedArithmeticGenerator;
    static const std::string Basename;

    /// Returns the name reported by Expression::ExpressionInfo for the operation.
    std::string name() const override
    {
        return Expression::ExpressionInfo::name();
    }
};

// Basename is the operation's expression name followed by "Generator".
template const std::string TernaryMixedArithmeticGenerator::Basename = concatenate(Expression::ExpressionInfo::name(), "Generator");

// --------------------------------------------------
// Get Functions
// These functions are used to pick the proper Generator class for the provided
// Expression and arguments.

// Unary primary template: returns nullptr, i.e. "no generator".
// NOTE(review): presumably specialized per operation elsewhere (possibly by
// the headers included at the end of this file) -- confirm.
template std::shared_ptr> GetGenerator(Register::ValuePtr dst,
                                       Register::ValuePtr arg,
                                       Operation const& expr)
{
    return nullptr;
}

// Declaration carrying the default for `expr` (a default-constructed Operation).
template Generator generateOp(Register::ValuePtr dst,
                              Register::ValuePtr arg,
                              Operation const& expr = Operation{});

// Unary generateOp: looks up the generator with GetGenerator, fails fatally
// with "No generator" if the lookup returned nullptr, then yields the
// instructions produced by the generator's generate().
template Generator generateOp(Register::ValuePtr dst,
                              Register::ValuePtr arg,
                              Operation const& expr)
{
    // NOTE(review): the operands of same_as are not visible in this view;
    // presumably this rejects Expression::ToScalar, which is handled by the
    // explicit specialization declared below -- confirm.
    static_assert(!std::same_as);
    auto gen = GetGenerator(dst, arg, expr);
    AssertFatal(gen != nullptr, "No generator");
    co_yield gen->generate(dst, arg, expr);
}

// Explicit specialization for Expression::ToScalar; declared here only.
template <> Generator generateOp(Register::ValuePtr dst,
                                 Register::ValuePtr arg,
                                 Expression::ToScalar const& expr);

// Binary primary template: returns nullptr, i.e. "no generator".
template std::shared_ptr> GetGenerator(Register::ValuePtr dst,
                                       Register::ValuePtr lhs,
                                       Register::ValuePtr rhs,
                                       Operation const& expr)
{
    return nullptr;
}

// Binary generateOp: same lookup / assert / yield sequence as the unary form.
template Generator generateOp(Register::ValuePtr const dst,
                              Register::ValuePtr lhs,
                              Register::ValuePtr rhs,
                              Operation const& expr = Operation{})
{
    auto gen = GetGenerator(dst, lhs, rhs, expr);
    AssertFatal(gen != nullptr, "No generator");
    co_yield gen->generate(dst, lhs, rhs, expr);
}

// Ternary primary template: returns nullptr, i.e. "no generator".
template std::shared_ptr> GetGenerator(Register::ValuePtr dst,
                                       Register::ValuePtr arg1,
                                       Register::ValuePtr arg2,
                                       Register::ValuePtr arg3,
                                       Operation const& expr)
{
    return nullptr;
}

// Ternary generateOp: same lookup / assert / yield sequence as the unary form.
template Generator generateOp(Register::ValuePtr dst,
                              Register::ValuePtr arg1,
                              Register::ValuePtr arg2,
                              Register::ValuePtr arg3,
                              Operation const& expr = Operation{})
{
    auto gen = GetGenerator(dst, arg1, arg2, arg3, expr);
    AssertFatal(gen != nullptr, "No generator");
    co_yield gen->generate(dst, arg1, arg2, arg3, expr);
}

// --------------------------------------------------
// Helper functions

// Return the expected datatype from the arguments to an operation.
DataType promoteDataType(Register::ValuePtr dst,
                         Register::ValuePtr lhs,
                         Register::ValuePtr rhs);

// Return the expected register type from the arguments to an operation.
Register::Type promoteRegisterType(Register::ValuePtr dst,
                                   Register::ValuePtr lhs,
                                   Register::ValuePtr rhs);

/**
 * @brief Return the data type of a register that will be used for arithmetic calculations.
 *
 * If the register contains a pointer, the DataType that is returned is UInt64.
 *
 * @param reg Register value to inspect.
 * @return DataType
 */
DataType getArithDataType(Register::ValuePtr const reg);

// Return the context from a list of register values.
//
// Single-value overload (end of the recursion): fails fatally with
// "No context" when `r` is null, otherwise returns r->context(). The returned
// context itself is not checked here.
inline ContextPtr getContextFromValues(Register::ValuePtr const r)
{
    AssertFatal(r != nullptr, "No context");
    return r->context();
}

// Multi-value overload: returns the context of `arg` when `arg` is non-null
// and its context tests true; otherwise tries the remaining values in order.
template inline ContextPtr getContextFromValues(Register::ValuePtr const arg, Args... args)
{
    if(arg && arg->context())
    {
        return arg->context();
    }
    else
    {
        return getContextFromValues(args...);
    }
}
}

// NOTE(review): the targets of these #include directives are not visible in
// this view; restore them from the upstream header.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include