/*******************************************************************************
 *
 * MIT License
 *
 * Copyright 2024-2025 AMD ROCm(TM) Software
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/

#pragma once

// NOTE(review): in this copy every angle-bracketed list is missing -- the
// include targets below, the template parameter lists, and the explicit
// template arguments of calls such as std::holds_alternative / std::get /
// path / getNeighbours / to. Restore them from the upstream file before
// compiling. The comments in this file describe only what the remaining text
// shows; anything about the missing arguments is marked as an assumption.
#include
#include

namespace rocRoller
{
    namespace KernelGraph::CoordinateGraph
    {
        /**
         * Walks the path between `srcs` and `dsts`, restricted to elements whose
         * getEdgeType() is EdgeType::CoordinateTransform, and builds one
         * expression per end dimension by calling `visitor` on each edge of
         * that path.
         *
         * When `Dir` is Graph::Direction::Downstream the walk starts at `srcs`
         * and ends at `dsts`; for any other direction the roles are swapped.
         *
         * @param sdims   Expressions for the start dimensions; `sdims[i]` is
         *                associated with the i-th start tag.
         * @param srcs    Source dimension tags.
         * @param dsts    Destination dimension tags.
         * @param visitor Receives setLocation(...) and then call(edge) for each
         *                edge on the path; call() must return one expression
         *                per output tag of that edge.
         * @return One expression per end tag, in the order of the end tags.
         *
         * Fails via AssertFatal if the visitor returns a different number of
         * expressions than there are output tags, or if an end tag never
         * received an expression.
         *
         * NOTE(review): `starts[i]` is read for every `i < sdims.size()` with no
         * visible check that `starts` is at least that long -- confirm callers
         * guarantee it.
         */
        template
        inline std::vector
            CoordinateGraph::traverse(std::vector        sdims,
                                      std::vector const& srcs,
                                      std::vector const& dsts,
                                      Visitor&           visitor) const
        {
            bool constexpr forward     = Dir == Graph::Direction::Downstream;
            auto constexpr OppositeDir = opposite(Dir);

            // The walk always goes from `starts` to `ends`; which of
            // srcs/dsts plays which role depends on the direction.
            auto const& starts = forward ? srcs : dsts;
            auto const& ends   = forward ? dsts : srcs;

            // Maps a dimension tag to the expression computed for it so far.
            std::map    exprMap;
            std::vector visitedExprs;

            // Only follow elements classified as coordinate-transform edges.
            auto edgeSelector = [this](int element) {
                return getEdgeType(element) == EdgeType::CoordinateTransform;
            };

            // Seed the map with the caller-supplied expressions. emplace()
            // keeps the first expression if a start tag is repeated.
            for(size_t i = 0; i < sdims.size(); i++)
            {
                int key = starts[i];
                exprMap.emplace(key, sdims[i]);
            }

            // traverse through the edges in the path from `starts` to `ends`
            // and generate expressions (populates exprMap) for the target
            // dimension(s) in `ends` via the given visitor.
            for(auto const elemId : path(starts, ends, edgeSelector))
            {
                Element const& element = getElement(elemId);
                // Non-matching elements on the path are skipped.
                // NOTE(review): the tested alternative is stripped in this
                // copy; presumably it is the edge alternative -- confirm.
                if(std::holds_alternative(element))
                {
                    Edge const& edge = std::get(std::get(element));

                    // einds: expressions fed into the visitor.
                    // keys:  tags that will receive the visitor's results.
                    std::vector einds;
                    std::vector keys, localSrcTags, localDstTags;
                    std::vector localSrcs, localDsts;

                    // NOTE(review): the two getNeighbours loops are textually
                    // identical here; presumably one was instantiated with
                    // OppositeDir and the other with Dir (OppositeDir is
                    // otherwise unused in the visible text) -- confirm
                    // against upstream.
                    //
                    // First neighbour set: recorded as the edge's local
                    // sources. Walking forward these are the inputs, otherwise
                    // they are the tags to be filled in.
                    for(auto const& tag : getNeighbours(elemId))
                    {
                        if(forward)
                        {
                            // operator[] inserts a default-constructed entry
                            // when `tag` has no expression yet.
                            einds.push_back(exprMap[tag]);
                        }
                        else
                        {
                            keys.push_back(tag);
                        }
                        localSrcs.emplace_back(getNode(tag));
                        localSrcTags.emplace_back(tag);
                    }

                    // Second neighbour set: recorded as the edge's local
                    // destinations, with the input/output roles mirrored.
                    for(auto const& tag : getNeighbours(elemId))
                    {
                        if(!forward)
                        {
                            // Same default-inserting lookup as above.
                            einds.push_back(exprMap[tag]);
                        }
                        else
                        {
                            keys.push_back(tag);
                        }
                        localDsts.emplace_back(getNode(tag));
                        localDstTags.emplace_back(tag);
                    }

                    // Give the visitor its context, then let it transform
                    // the input expressions across this edge.
                    visitor.setLocation(einds, localSrcs, localDsts, localSrcTags, localDstTags);
                    visitedExprs = visitor.call(edge);

                    // One result is required per output tag.
                    AssertFatal(visitedExprs.size() == keys.size(), ShowValue(visitedExprs));
                    for(size_t i = 0; i < visitedExprs.size(); i++)
                    {
                        // Overwrites any expression previously stored for the tag.
                        exprMap[keys[i]] = std::move(visitedExprs[i]);
                    }
                }
            }

            // Collect the expressions of the end tags, in the order given.
            std::vector results;
            for(int const key : ends)
            {
                if(!exprMap.contains(key))
                {
                    // Error path only: list every tag that does have an
                    // expression so the failure message shows what was reached.
                    auto keys = [&exprMap]() -> Generator {
                        for(auto const& pair : exprMap)
                            co_yield pair.first;
                    }()
                                                    .template to();

                    std::ostringstream msg;
                    streamJoin(msg, keys, ", ");

                    // The condition is known to be false inside this branch,
                    // so this assertion always fails here.
                    AssertFatal(exprMap.contains(key),
                                "Path not found for ",
                                Graph::variantToString(getElement(key)),
                                ShowValue(key),
                                ShowValue(Dir),
                                msg.str());
                }
                results.push_back(exprMap.at(key));
            }
            return results;
        }

        /**
         * Reports whether the path search between `srcs` and `dsts`, restricted
         * to elements whose getEdgeType() is EdgeType::CoordinateTransform,
         * reaches every end tag.
         *
         * The end tags are `dsts` when `Dir` is Graph::Direction::Downstream and
         * `srcs` otherwise.
         *
         * @param srcs Source dimension tags.
         * @param dsts Destination dimension tags.
         * @return true if every end tag is contained in the collected path
         *         result (trivially true when there are no end tags),
         *         false as soon as one is missing.
         */
        template
        inline bool CoordinateGraph::hasPath(std::vector const& srcs,
                                             std::vector const& dsts) const
        {
            bool constexpr forward = Dir == Graph::Direction::Downstream;

            auto const& starts = forward ? srcs : dsts;
            auto const& ends   = forward ? dsts : srcs;

            // Same edge filter as traverse().
            auto edgeSelector = [this](int element) {
                return getEdgeType(element) == EdgeType::CoordinateTransform;
            };

            // Materialise the path result into a container so membership can
            // be queried with contains().
            auto partial = path(starts, ends, edgeSelector).template to();

            for(auto end : ends)
            {
                if(!partial.contains(end))
                    return false;
            }
            return true;
        }

        /**
         * Classifies the graph element with the given index.
         *
         * @param index Index passed to getElement().
         * @return EdgeType::DataFlow or EdgeType::CoordinateTransform when the
         *         element passes the outer alternative check and its edge holds
         *         the corresponding alternative; EdgeType::None in every other
         *         case.
         *
         * NOTE(review): the alternative types tested by holds_alternative are
         * stripped in this copy; presumably the outer test is for an edge and
         * the inner tests are for the data-flow and coordinate-transform edge
         * kinds -- confirm against upstream.
         */
        inline EdgeType CoordinateGraph::getEdgeType(int index) const
        {
            Element const& elem = getElement(index);
            if(std::holds_alternative(elem))
            {
                Edge const& edge = std::get(elem);
                if(std::holds_alternative(edge))
                {
                    return EdgeType::DataFlow;
                }
                else if(std::holds_alternative(edge))
                {
                    return EdgeType::CoordinateTransform;
                }
            }
            return EdgeType::None;
        }

        /**
         * Fetches the element with the given tag as a `T`, if it currently
         * holds a `T`.
         *
         * The `if constexpr` tests select, at compile time, which nesting of
         * variants could contain a `T`; the runtime holds_alternative checks
         * then verify each level before extracting the value.
         *
         * @param tag Tag passed to getElement().
         * @return The stored value, or an empty optional when no level matches.
         *
         * NOTE(review): `x` and `y` are declared with plain `auto`, so each is a
         * separate object rather than a reference, unlike the
         * `Element const&` bindings used in traverse() and getEdgeType().
         * Consider whether the copies are intended.
         */
        template
        requires(std::constructible_from) inline std::optional<T> CoordinateGraph::get(
            int tag) const
        {
            auto x = getElement(tag);

            // First candidate: a value nested two variant levels deep.
            // NOTE(review): presumably the edge branch -- type arguments are
            // stripped in this copy.
            if constexpr(std::constructible_from)
            {
                if(std::holds_alternative(x))
                {
                    auto y = std::get(x);
                    if constexpr(std::constructible_from)
                    {
                        if(std::holds_alternative(y))
                        {
                            if(std::holds_alternative(std::get(y)))
                            {
                                return std::get(std::get(y));
                            }
                        }
                    }
                    else if constexpr(std::constructible_from)
                    {
                        if(std::holds_alternative(y))
                        {
                            if(std::holds_alternative(std::get(y)))
                            {
                                return std::get(std::get(y));
                            }
                        }
                    }
                }
            }

            // Second candidate: a value nested one variant level deep.
            // NOTE(review): presumably the dimension branch -- confirm.
            if constexpr(std::constructible_from)
            {
                if(std::holds_alternative(x))
                {
                    if(std::holds_alternative(std::get(x)))
                    {
                        return std::get(std::get(x));
                    }
                }
            }

            // Nothing matched: empty optional.
            return {};
        }

        /**
         * Returns the name of a coordinate-graph element by forwarding to
         * CoordinateGraph::ElementName.
         *
         * @param el Element to name.
         * @return The string produced by CoordinateGraph::ElementName(el).
         */
        inline std::string name(CoordinateGraph::Element const& el)
        {
            return CoordinateGraph::ElementName(el);
        }
    }
}