/******************************************************************************* * * MIT License * * Copyright 2023-2025 AMD ROCm(TM) Software * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * *******************************************************************************/ #pragma once #include #include #include #include #include namespace rocRoller { namespace KernelGraph { /** * @brief Class for generating instructions related to loading and storing tiles * to and from memory. 
*
 * Each public gen* entry point takes the tag of a control-graph node, the
 * node itself and the currently known coordinates, and returns a Generator
 * for the corresponding instructions.
 *
 * NOTE(review): in this copy of the header the template argument lists
 * appear to have been lost (`Generator`, `std::shared_ptr`, and the bare
 * `template` keyword in front of the moveTile* helpers), as have the
 * `#include` targets. As written those declarations are not valid C++;
 * restore them from version control.
 */
class LoadStoreTileGenerator
{
public:
    /**
     * @brief Construct a generator for a kernel graph within a context.
     *
     * @note The `unsigned int` argument is presumably the total workgroup
     *       size (cf. m_workgroupSizeTotal) -- TODO confirm against the
     *       definition.
     */
    LoadStoreTileGenerator(KernelGraphPtr, ContextPtr, unsigned int);

    /**
     * @brief Generate instructions needed to load a tile from global memory.
     *
     * @param tag    The tag of the node in the control graph
     * @param load   The LoadTiled node in the control graph
     * @param coords Known coordinates
     * @return Generator for the load instructions
     */
    Generator genLoadTile(int tag,
                          ControlGraph::LoadTiled const& load,
                          CoordinateGraph::Transformer coords);

    /**
     * @brief Generate instructions needed to load a tile from LDS.
     *
     * @param tag    The tag of the node in the control graph
     * @param load   The LoadLDSTile node in the control graph
     * @param coords Known coordinates
     * @return Generator for the load instructions
     */
    Generator genLoadLDSTile(int tag,
                             ControlGraph::LoadLDSTile const& load,
                             CoordinateGraph::Transformer coords);

    /**
     * @brief Generate instructions needed to load a tile from global memory
     *        directly into LDS.
     *
     * @param tag    The tag of the node in the control graph
     * @param load   The LoadTileDirect2LDS node in the control graph
     * @param coords Known coordinates
     * @return Generator for the load instructions
     */
    Generator genLoadTileDirect2LDS(int tag,
                                    ControlGraph::LoadTileDirect2LDS const& load,
                                    CoordinateGraph::Transformer coords);

    /**
     * @brief Generate instructions needed to store a tile to global memory.
     *
     * @param tag    The tag of the node in the control graph
     * @param store  The StoreTiled node in the control graph
     * @param coords Known coordinates
     * @return Generator for the store instructions
     */
    Generator genStoreTile(int tag,
                           ControlGraph::StoreTiled const& store,
                           CoordinateGraph::Transformer coords);

    /**
     * @brief Generate instructions needed to store a tile to LDS.
     *
     * @param tag    The tag of the node in the control graph
     * @param store  The StoreLDSTile node in the control graph
     * @param coords Known coordinates
     * @return Generator for the store instructions
     */
    Generator genStoreLDSTile(int tag,
                              ControlGraph::StoreLDSTile const& store,
                              CoordinateGraph::Transformer coords);

    /**
     * @brief Generate instructions needed to calculate offset and stride
     *        information.
     *
     * @param tag    The tag of the node in the control graph
     * @param ci     The ComputeIndex node in the control graph
     * @param coords Known coordinates
     * @return Generator for the index-computation instructions
     */
    Generator genComputeIndex(int tag,
                              ControlGraph::ComputeIndex const& ci,
                              CoordinateGraph::Transformer coords);

    /**
     * @brief Information needed in order to load or store a tile.
     *
     * @field tag                 The tag of the control graph node generating the
     *                            load or store (-1 when unset)
     * @field kind                The kind of memory instruction to use
     *                            (MemoryKind::Count by default, presumably "not chosen yet")
     * @field m                   Number of rows in the tile
     * @field n                   Number of columns in the tile
     * @field elementBits         Size of one element in bits (presumably of varType -- confirm)
     * @field packedAmount        NOTE(review): presumably the number of elements packed
     *                            together per value -- confirm
     * @field ldsWriteStride      NOTE(review): presumably a stride used when writing to LDS;
     *                            units unconfirmed
     * @field data                The registers holding the data; may be null when loading
     * @field varType             The type of the data being loaded or stored
     * @field rowOffsetReg        Register holding the row offset (presumably)
     * @field rowStrideReg        Register holding the row stride
     * @field rowStrideAttributes Expression attributes describing rowStrideReg
     * @field colStrideReg        Register holding the column stride
     * @field colStrideAttributes Expression attributes describing colStrideReg
     * @field offset              Offset from the starting index
     * @field bufDesc             Buffer descriptor (presumably obtained via getBufferDesc)
     * @field bufOpts             Options for buffer memory instructions
     * @field isTransposedTile    Whether the tile needs to be transposed
     * @field isPadded            NOTE(review): presumably whether the tile is padded -- confirm
     */
    struct LoadStoreTileInfo
    {
        int tag = -1;
        MemoryInstructions::MemoryKind kind = MemoryInstructions::MemoryKind::Count;
        uint64_t m = 0;
        uint64_t n = 0;
        uint32_t elementBits = 0;
        uint32_t packedAmount = 0;
        uint32_t ldsWriteStride = 0;
        Register::ValuePtr data = nullptr;
        VariableType varType = VariableType{DataType::Count};
        Register::ValuePtr rowOffsetReg = nullptr;
        Register::ValuePtr rowStrideReg = nullptr;
        RegisterExpressionAttributes rowStrideAttributes;
        Register::ValuePtr colStrideReg = nullptr;
        RegisterExpressionAttributes colStrideAttributes;
        Register::ValuePtr offset = nullptr;
        std::shared_ptr bufDesc = nullptr;
        BufferInstructionOptions bufOpts = {};
        bool isTransposedTile = false;
        bool isPadded = false;
    };

private:
    ContextPtr m_context;
    KernelGraphPtr m_graph;
    // NOTE(review): presumably an expression transform (e.g. fast-arithmetic
    // simplification) applied to index expressions before code generation -- confirm.
    Expression::ExpressionTransducer m_fastArith;
    // Presumably the total number of work-items in a workgroup -- confirm.
    unsigned int m_workgroupSizeTotal;

    /**
     * @brief Generate instructions evaluating `expr` into `dest`.
     *
     * NOTE(review): inferred from the signature; `dest` is presumably a
     * register value that the call may allocate -- confirm.
     */
    inline Generator generate(auto& dest, Expression::ExpressionPtr expr) const;

    // Index calculation Helpers

    /**
     * @brief Look up (presumably) the buffer descriptor for the operation
     *        tagged `tag`.
     */
    std::shared_ptr getBufferDesc(int tag);

    /**
     * @brief Build the offset expression for the operation `opTag`.
     *
     * @param opTag                    Tag of the load/store operation
     * @param isStorePartOfGlobalToLDS Whether the offset is for the LDS-store half
     *                                 of a global-to-LDS transfer (inferred from the name)
     * @param coords                   Known coordinates
     */
    Expression::ExpressionPtr getOffsetExpr(int opTag,
                                            bool isStorePartOfGlobalToLDS,
                                            CoordinateGraph::Transformer const& coords);

    /**
     * @brief Generate instructions computing the offset for `info`.
     *
     * NOTE(review): presumably writes the result into `info.offset`; the
     * exact meaning of `preserveOffset` should be confirmed against the
     * definition. `isStorePartOfGlobalToLDS` is forwarded with the same
     * meaning as in getOffsetExpr (presumably).
     */
    Generator getOffset(LoadStoreTileInfo& info,
                        CoordinateGraph::Transformer coords,
                        bool preserveOffset,
                        bool isStorePartOfGlobalToLDS = false);

    /**
     * @brief Generate stride (in bytes).
     *
     * The generated stride register is returned through `stride`. Stride
     * expressions carry metadata that makes explicit whether the
     * byte-stride corresponds to a unit element-stride; that metadata is
     * presumably what is returned through `attrs`. A unit element-stride
     * is a unitary (=1) stride with respect to the element of the
     * underlying data type.
     *
     * The generated stride is in bytes. This makes it easy, for example,
     * to advance offset registers to the next macro tile by simply adding
     * the stride in the increment of a for loop.
     *
     * However, deciding from the byte-stride alone whether a stride is a
     * unit-stride is tricky for sub-byte datatypes, which is why the
     * metadata is attached explicitly.
     *
     * For example, if we only knew the byte-stride:
     *
     * | data type | byte-stride | unit element-stride |
     * |-----------|-------------|---------------------|
     * | FP64      | 8           | true                |
     * | FP32      | 4           | true                |
     * | FP32      | 8           | false               |
     * | FP16      | 2           | true                |
     * | FP8       | 1           | true                |
     * | Sub-byte  | 1           | maybe!              |
     *
     * @param stride    Output: register holding the stride, in bytes
     * @param attrs     Output: attributes of the stride expression
     * @param tag       Tag identifying whose stride is generated (presumably a graph node)
     * @param dimension Index of the dimension whose stride is generated
     */
    Generator generateStride(Register::ValuePtr& stride,
                             RegisterExpressionAttributes& attrs,
                             int tag,
                             int dimension);

    // Move Tile Helpers
    //
    // NOTE(review): the moveTile* helpers are templates whose parameter list
    // is missing in this copy; it presumably selects the transfer direction
    // (load vs. store) -- confirm.

    /**
     * @brief Move (load or store) the tile described by `info`.
     */
    template Generator moveTile(LoadStoreTileInfo& info, CoordinateGraph::Transformer& coords);

    /// Move a tile whose strides are literal values (inferred from the name).
    template Generator moveTileLiteralStrides(LoadStoreTileInfo& info);

    /// Move a tile whose column stride is one (inferred from the name).
    template Generator moveTileColStrideOne(LoadStoreTileInfo& info);

    /// Move a tile whose strides are only known at runtime (inferred from the name).
    template Generator moveTileRuntimeStrides(LoadStoreTileInfo& info);

    /**
     * @brief Move a tile directly between global memory and LDS (inferred
     *        from the name).
     *
     * NOTE(review): `numBytes` is presumably the transfer size per
     * instruction, `setM0` whether the M0 register must be initialised, and
     * `readAddr` the source address register -- confirm.
     */
    template Generator moveTileDirect2LDS(LoadStoreTileInfo& info,
                                          int numBytes,
                                          bool setM0,
                                          Register::ValuePtr readAddr);

    /// Packed-load counterpart of moveTileLiteralStrides (inferred from the name).
    Generator loadTileLiteralStridesPack(LoadStoreTileInfo& info);

    /// Packed-load counterpart of moveTileRuntimeStrides (inferred from the name).
    Generator loadTileRuntimeStridesPack(LoadStoreTileInfo& info);

    // Load Tile Helpers
    //
    // Each helper handles one destination/layout variant named by its suffix
    // (VGPR, LDS, WAVELDS, WAVE, WAVECIACCUM, Direct2LDS). Presumably they are
    // dispatched from the public genLoad* entry points -- confirm.
    Generator loadMacroTileVGPR(int tag,
                                ControlGraph::LoadTiled const& load,
                                CoordinateGraph::Transformer coords);
    Generator loadMacroTileLDS(int tag,
                               ControlGraph::LoadLDSTile const& load,
                               CoordinateGraph::Transformer coords);
    Generator loadMacroTileWAVELDS(int tag,
                                   ControlGraph::LoadLDSTile const& load,
                                   CoordinateGraph::Transformer coords);
    Generator loadMacroTileWAVE(int tag,
                                ControlGraph::LoadTiled const& load,
                                CoordinateGraph::Transformer coords);
    Generator loadMacroTileWAVECIACCUM(int tag,
                                       ControlGraph::LoadTiled const& load,
                                       CoordinateGraph::Transformer coords);
    Generator loadMacroTileDirect2LDS(int tag,
                                      ControlGraph::LoadTileDirect2LDS const& load,
                                      CoordinateGraph::Transformer coords);

    // Store Tile Helpers
    //
    // Each helper handles one source/layout variant named by its suffix
    // (LDS, VGPR, WAVELDS, WAVE). Presumably they are dispatched from the
    // public genStore* entry points -- confirm.
    Generator storeMacroTileLDS(int tag,
                                ControlGraph::StoreLDSTile const& store,
                                CoordinateGraph::Transformer coords);
    Generator storeMacroTileVGPR(int tag,
                                 ControlGraph::StoreTiled const& store,
                                 CoordinateGraph::Transformer coords);
    Generator storeMacroTileWAVELDS(int tag,
                                    ControlGraph::StoreLDSTile const& store,
                                    CoordinateGraph::Transformer coords);
    Generator storeMacroTileWAVE(int tag,
                                 ControlGraph::StoreTiled const& store,
                                 CoordinateGraph::Transformer coords);
};

/**
 * @brief Return a string describing `info` (presumably for logging and
 *        debugging).
 */
std::string toString(LoadStoreTileGenerator::LoadStoreTileInfo const& info);
} // namespace KernelGraph
} // namespace rocRoller