/*******************************************************************************
 *
 * MIT License
 *
 * Copyright 2024-2025 AMD ROCm(TM) Software
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/

/**
 * @file
 * Inline definitions for the ControlGraph ordering helpers: string conversion
 * of NodeOrdering / CacheStatus, order-cache queries, typed element access and
 * helpers for adding edges.
 *
 * NOTE(review): in this copy of the file the contents of every angle-bracket
 * list (template parameters, template arguments, the #include operand) appear
 * to have been lost. The code tokens below are left exactly as found; restore
 * them from the upstream file before building.
 */
#pragma once
// NOTE(review): the header name is missing from this directive in this copy.
#include
namespace rocRoller::KernelGraph::ControlGraph
{
    /**
     * Returns the enumerator name of `n` as a string.
     *
     * NodeOrdering::Count and any value not listed fall out of the switch and
     * reach the Throw call with the message "Invalid NodeOrdering".
     */
    inline std::string toString(NodeOrdering n)
    {
        switch(n)
        {
        case NodeOrdering::LeftFirst:
            return "LeftFirst";
        case NodeOrdering::LeftInBodyOfRight:
            return "LeftInBodyOfRight";
        case NodeOrdering::Undefined:
            return "Undefined";
        case NodeOrdering::RightInBodyOfLeft:
            return "RightInBodyOfLeft";
        case NodeOrdering::RightFirst:
            return "RightFirst";
        case NodeOrdering::Count:
        default:
            break;
        }
        // NOTE(review): presumably Throw takes an exception type as a template
        // argument that was lost in this copy - confirm.
        Throw("Invalid NodeOrdering");
    }

    /**
     * Streams the string produced by toString(NodeOrdering).
     */
    inline std::ostream& operator<<(std::ostream& stream, NodeOrdering n)
    {
        return stream << toString(n);
    }

    /**
     * Returns the enumerator name of `c` as a string.
     *
     * Any value other than Invalid, Partial or Valid falls out of the switch
     * and reaches the Throw call with the message "Invalid CacheStatus".
     */
    inline std::string toString(CacheStatus c)
    {
        switch(c)
        {
        case CacheStatus::Invalid:
            return "Invalid";
        case CacheStatus::Partial:
            return "Partial";
        case CacheStatus::Valid:
            return "Valid";
        default:
            break;
        }
        Throw("Invalid CacheStatus");
    }

    /**
     * Streams the string produced by toString(CacheStatus).
     */
    inline std::ostream& operator<<(std::ostream& stream, CacheStatus c)
    {
        return stream << toString(c);
    }

    // NOTE(review): the concept's argument is missing in this copy; presumably
    // this checks that NodeOrdering satisfies CCountedEnum - confirm.
    static_assert(CCountedEnum);

    /**
     * Returns a three-character label for `n` (" LF", "LIB", "und", "RIB",
     * " RF").
     *
     * NodeOrdering::Count and any value not listed reach the Throw call with
     * the message "Invalid NodeOrdering".
     */
    inline std::string abbrev(NodeOrdering n)
    {
        switch(n)
        {
        case NodeOrdering::LeftFirst:
            return " LF";
        case NodeOrdering::LeftInBodyOfRight:
            return "LIB";
        case NodeOrdering::Undefined:
            return "und";
        case NodeOrdering::RightInBodyOfLeft:
            return "RIB";
        case NodeOrdering::RightFirst:
            return " RF";
        case NodeOrdering::Count:
        default:
            break;
        }
        Throw("Invalid NodeOrdering");
    }

    /**
     * Returns the ordering obtained by swapping the roles of left and right:
     * LeftFirst <-> RightFirst, LeftInBodyOfRight <-> RightInBodyOfLeft, and
     * Undefined maps to itself.
     *
     * NodeOrdering::Count and any value not listed reach the Throw call with
     * the message "Invalid NodeOrdering".
     */
    inline NodeOrdering opposite(NodeOrdering order)
    {
        switch(order)
        {
        case NodeOrdering::LeftFirst:
            return NodeOrdering::RightFirst;
        case NodeOrdering::LeftInBodyOfRight:
            return NodeOrdering::RightInBodyOfLeft;
        case NodeOrdering::Undefined:
            return NodeOrdering::Undefined;
        case NodeOrdering::RightInBodyOfLeft:
            return NodeOrdering::LeftInBodyOfRight;
        case NodeOrdering::RightFirst:
            return NodeOrdering::LeftFirst;
        case NodeOrdering::Count:
        default:
            break;
        }
        Throw("Invalid NodeOrdering");
    }

    /**
     * Calls populateOrderCache() when m_orderCache is empty; otherwise does
     * nothing.
     */
    inline void ControlGraph::checkOrderCache() const
    {
        if(m_orderCache.empty())
            populateOrderCache();
    }

    /**
     * Looks up the ordering of (nodeA, nodeB) using only what is already in
     * m_orderCache; this overload never populates the cache.
     *
     * The unnamed first parameter only selects this overload.
     *
     * @return The stored ordering, or NodeOrdering::Undefined when no entry
     *         exists for the pair.
     */
    inline NodeOrdering
        ControlGraph::lookupOrder(CacheOnlyPolicy const, int nodeA, int nodeB) const
    {
        // The lookup below is keyed by (nodeA, nodeB) with nodeA <= nodeB; for
        // the reversed pair, query the swapped pair and mirror the answer.
        if(nodeA > nodeB)
            return opposite(lookupOrder(CacheOnly, nodeB, nodeA));

        if(not m_orderCache.contains(nodeA))
            return NodeOrdering::Undefined;

        auto const& nodeOrderPairs = m_orderCache.at(nodeA);
        return nodeOrderPairs.contains(nodeB) ? nodeOrderPairs.at(nodeB)
                                              : NodeOrdering::Undefined;
    }

    /**
     * Yields each otherNode whose cached ordering with `node` is LeftFirst
     * when `node` is the outer cache key, or RightFirst when otherNode is the
     * outer key and otherNode < node.
     *
     * NOTE(review): calls populateOrderCache() directly rather than
     * checkOrderCache(); presumably populateOrderCache() is cheap when the
     * cache is already filled - confirm.
     */
    inline Generator ControlGraph::nodesAfter(int node) const
    {
        populateOrderCache();

        // Entries stored under `node` itself.
        if(m_orderCache.contains(node))
        {
            for(auto const& [otherNode, order] : m_orderCache.at(node))
            {
                if(order == NodeOrdering::LeftFirst)
                    co_yield otherNode;
            }
        }

        // Entries stored under a smaller key, where `node` is the inner key;
        // the stored ordering is read from the other node's point of view.
        for(auto const& [otherNode, nodeOrderPairs] : m_orderCache)
        {
            if(otherNode >= node)
                continue;

            if(nodeOrderPairs.contains(node)
               && nodeOrderPairs.at(node) == NodeOrdering::RightFirst)
                co_yield otherNode;
        }
    }

    /**
     * Yields each otherNode whose cached ordering with `node` is RightFirst
     * when `node` is the outer cache key, or LeftFirst when otherNode is the
     * outer key and otherNode < node.
     */
    inline Generator ControlGraph::nodesBefore(int node) const
    {
        populateOrderCache();

        // Entries stored under `node` itself.
        if(m_orderCache.contains(node))
        {
            for(auto const& [otherNode, order] : m_orderCache.at(node))
            {
                if(order == NodeOrdering::RightFirst)
                    co_yield otherNode;
            }
        }

        // Entries stored under a smaller key, where `node` is the inner key.
        for(auto const& [otherNode, nodeOrderPairs] : m_orderCache)
        {
            if(otherNode >= node)
                continue;

            if(nodeOrderPairs.contains(node)
               && nodeOrderPairs.at(node) == NodeOrdering::LeftFirst)
                co_yield otherNode;
        }
    }

    /**
     * Yields each otherNode whose cached ordering with `node` is
     * RightInBodyOfLeft when `node` is the outer cache key, or
     * LeftInBodyOfRight when otherNode is the outer key and otherNode < node.
     */
    inline Generator ControlGraph::nodesInBody(int node) const
    {
        populateOrderCache();

        // Entries stored under `node` itself.
        if(m_orderCache.contains(node))
        {
            for(auto const& [otherNode, order] : m_orderCache.at(node))
            {
                if(order == NodeOrdering::RightInBodyOfLeft)
                    co_yield otherNode;
            }
        }

        // Entries stored under a smaller key, where `node` is the inner key.
        for(auto const& [otherNode, nodeOrderPairs] : m_orderCache)
        {
            if(otherNode >= node)
                continue;

            if(nodeOrderPairs.contains(node)
               && nodeOrderPairs.at(node) == NodeOrdering::LeftInBodyOfRight)
                co_yield otherNode;
        }
    }

    /**
     * Yields each otherNode whose cached ordering with `node` is
     * LeftInBodyOfRight when `node` is the outer cache key, or
     * RightInBodyOfLeft when otherNode is the outer key and otherNode < node.
     */
    inline Generator ControlGraph::nodesContaining(int node) const
    {
        populateOrderCache();

        // Entries stored under `node` itself.
        if(m_orderCache.contains(node))
        {
            for(auto const& [otherNode, order] : m_orderCache.at(node))
            {
                if(order == NodeOrdering::LeftInBodyOfRight)
                    co_yield otherNode;
            }
        }

        // Entries stored under a smaller key, where `node` is the inner key.
        for(auto const& [otherNode, nodeOrderPairs] : m_orderCache)
        {
            if(otherNode >= node)
                continue;

            if(nodeOrderPairs.contains(node)
               && nodeOrderPairs.at(node) == NodeOrdering::RightInBodyOfLeft)
                co_yield otherNode;
        }
    }

    /**
     * Returns every pair of nodes (from the collection produced by
     * getNodes()) for which compareNodes reports NodeOrdering::Undefined.
     *
     * Each node is compared only against the nodes that follow it in the
     * collected container, so each unordered pair is tested once.
     * compareNodes is invoked with rocRoller::UpdateCache.
     *
     * NOTE(review): the template parameter list, the constraint arguments and
     * the container element types are missing in this copy; presumably the
     * template parameter selects which node type getNodes() collects -
     * confirm against upstream.
     */
    template requires(std::constructible_from)
    inline std::set< std::pair> ControlGraph::ambiguousNodes() const
    {
        std::set> badNodes;

        auto memNodes = getNodes().template to();
        for(auto iter = memNodes.begin(); iter != memNodes.end(); iter++)
        {
            // Everything after the current node in iteration order.
            std::set otherNodes(std::next(iter), memNodes.end());
            for(auto node : otherNodes)
            {
                if(compareNodes(rocRoller::UpdateCache, *iter, node)
                   == NodeOrdering::Undefined)
                {
                    badNodes.insert(std::make_pair(*iter, node));
                }
            }
        }
        return badNodes;
    }

    /**
     * Fetches the element stored under `tag` via getElement and returns it as
     * the requested type when the element currently holds that alternative.
     *
     * @return The held value, or an empty optional when the element does not
     *         hold the requested alternative.
     *
     * NOTE(review): all type arguments are missing in this copy. From the
     * shape of the code, the first branch handles a request for one of the
     * outer variant alternatives directly, and the else branch looks one
     * level deeper into a nested variant (two separately constrained cases) -
     * confirm the exact types against upstream.
     */
    template requires(std::constructible_from)
    inline std::optional ControlGraph::get(int tag) const
    {
        auto x = getElement(tag);
        if constexpr(CIsAnyOf)
        {
            if(std::holds_alternative(x))
            {
                return std::get(x);
            }
        }
        else
        {
            if constexpr(std::constructible_from)
            {
                if(std::holds_alternative(x))
                {
                    if(std::holds_alternative(std::get(x)))
                    {
                        return std::get(std::get(x));
                    }
                }
            }
            if constexpr(std::constructible_from)
            {
                if(std::holds_alternative(x))
                {
                    if(std::holds_alternative(std::get(x)))
                    {
                        return std::get(std::get(x));
                    }
                }
            }
        }
        // Falls through when no branch matched.
        return {};
    }

    /**
     * Adds an Edge() element from `a` to `b`, then repeats with `b` and the
     * next argument until `remaining` is exhausted, linking the arguments in
     * the order given.
     *
     * NOTE(review): the template header is truncated in this copy; presumably
     * Edge is a template type parameter and Nodes is a constrained pack -
     * confirm.
     */
    template ... Nodes>
    void ControlGraph::chain(int a, int b, Nodes... remaining)
    {
        addElement(Edge(), {a}, {b});

        if constexpr(sizeof...(remaining) > 0)
            chain(b, remaining...);
    }

    /**
     * Returns true when `x` holds the expected outer alternative and that
     * alternative in turn holds the requested inner alternative.
     *
     * NOTE(review): type arguments are missing in this copy, which makes this
     * body read identically to isEdge below; presumably the outer alternative
     * here is the operation variant - confirm.
     */
    template requires(std::constructible_from)
    inline bool isOperation(auto const& x)
    {
        if(std::holds_alternative(x))
        {
            if(std::holds_alternative(std::get(x)))
                return true;
        }

        return false;
    }

    /**
     * Returns true when `x` holds the expected outer alternative and that
     * alternative in turn holds the requested inner alternative.
     *
     * NOTE(review): type arguments are missing in this copy; presumably the
     * outer alternative here is the edge variant - confirm.
     */
    template requires(std::constructible_from)
    inline bool isEdge(auto const& x)
    {
        if(std::holds_alternative(x))
        {
            if(std::holds_alternative(std::get(x)))
                return true;
        }

        return false;
    }

    /**
     * Returns ControlGraph::ElementName(el).
     */
    inline std::string name(ControlGraph::Element const& el)
    {
        return ControlGraph::ElementName(el);
    }

    /**
     * Adds a Sequence() element between two nodes chosen from the given
     * control stacks.
     *
     * The stacks are scanned in parallel for the first index at which they
     * differ. At that index:
     *  - if `ordered` is true, the edge goes from the entry in aControlStack
     *    to the entry in bControlStack;
     *  - otherwise, if get() yields a value for both entries and both values
     *    are flagged for Expression::EvaluationTime::Translate, the two
     *    values are evaluated and the edge goes from the entry with the
     *    smaller value to the one with the larger (a-to-b on a tie);
     *  - otherwise the edge goes from the smaller tag to the larger tag.
     *
     * If either end is still -1 after the scan, both stacks are read at the
     * last index of the shorter stack and the edge goes from the smaller of
     * the two tags to the larger.
     *
     * @param aControlStack First stack of node tags; indexed with at() and
     *                      sized with size().
     * @param bControlStack Second stack of node tags.
     * @param ordered       When true, keep the a-before-b direction at the
     *                      first differing index.
     *
     * NOTE(review): the template header and the type argument of get() are
     * missing in this copy - confirm against upstream.
     */
    template Range>
    inline void ControlGraph::orderMemoryNodes(Range const& aControlStack,
                                               Range const& bControlStack,
                                               bool ordered)
    {
        // -1 marks "not chosen yet".
        int src = -1;
        int dest = -1;
        for(int i = 0; (i < aControlStack.size()) && (i < bControlStack.size()); ++i)
        {
            if(aControlStack.at(i) != bControlStack.at(i))
            {
                auto setCoordA = get(aControlStack.at(i));
                auto setCoordB = get(bControlStack.at(i));
                if(ordered)
                {
                    src = aControlStack.at(i);
                    dest = bControlStack.at(i);
                }
                else if(setCoordA && setCoordB
                        && evaluationTimes(
                            setCoordA->value)[Expression::EvaluationTime::Translate]
                        && evaluationTimes(
                            setCoordB->value)[Expression::EvaluationTime::Translate])
                {
                    // Both values can be evaluated now: order by value,
                    // defaulting to a-before-b and swapping only if b < a.
                    src = aControlStack.at(i);
                    dest = bControlStack.at(i);
                    if(getUnsignedInt(evaluate(setCoordB->value))
                       < getUnsignedInt(evaluate(setCoordA->value)))
                    {
                        src = bControlStack.at(i);
                        dest = aControlStack.at(i);
                    }
                }
                else
                {
                    // No other criterion applies: order by tag value.
                    src = std::min(aControlStack.at(i), bControlStack.at(i));
                    dest = std::max(aControlStack.at(i), bControlStack.at(i));
                }
                // Only the first differing index is considered.
                break;
            }
        }

        if(src == -1 || dest == -1)
        {
            // NOTE(review): if either stack is empty these indices become -1
            // before being passed to at() - confirm callers never pass an
            // empty stack.
            int aIndex = aControlStack.size() - 1;
            int bIndex = bControlStack.size() - 1;

            // Clamp both indices to the last index of the shorter stack.
            if(aControlStack.size() > bControlStack.size())
            {
                aIndex = bIndex;
            }
            else
            {
                bIndex = aIndex;
            }

            // NOTE(review): when this path is reached because the loop found
            // no differing entry, both stacks were already seen to be equal at
            // this index, so src and dest look like they end up identical -
            // confirm this is intended.
            src = std::min(aControlStack.at(aIndex), bControlStack.at(bIndex));
            dest = std::max(aControlStack.at(aIndex), bControlStack.at(bIndex));
        }

        addElement(Sequence(), {src}, {dest});
    }
}