/*******************************************************************************
 *
 * MIT License
 *
 * Copyright 2024-2025 AMD ROCm(TM) Software
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/

// Out-of-line implementation of the Hypergraph member functions plus a few
// free helpers (toString / operator<< / opposite / getConnectingType /
// reachableNodes) for the ElementType, Direction and GraphModification enums.
//
// NOTE(review): in this copy of the file every #include has lost its target
// and every angle-bracketed span is missing: template parameter lists
// (`template` is followed directly by the declaration), template arguments
// (e.g. on getNeighbours, std::get, std::holds_alternative, std::set,
// std::vector, static_cast) and concept arguments. Several paired branches
// below therefore look identical even though they presumably differ only in a
// stripped Direction template argument. Restore the file from version control
// before building; the comments below describe behaviour only where the
// surviving tokens establish it.

#pragma once

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace rocRoller
{
    namespace Graph
    {
        /**
         * Returns the enumerator name of `e` ("Node" or "Edge").
         * Throws std::runtime_error for ElementType::Count or any value the
         * switch does not handle.
         */
        inline std::string toString(ElementType e)
        {
            switch(e)
            {
            case ElementType::Node:
                return "Node";
            case ElementType::Edge:
                return "Edge";
            case ElementType::Count:
                break;
            }

            throw std::runtime_error("Invalid ElementType");
        }

        /// Streams toString(e).
        inline std::ostream& operator<<(std::ostream& stream, ElementType const& e)
        {
            return stream << toString(e);
        }

        /**
         * Returns the enumerator name of `d` ("Upstream" or "Downstream").
         * Throws std::runtime_error for Direction::Count or any unhandled value.
         */
        inline std::string toString(Direction d)
        {
            switch(d)
            {
            case Direction::Upstream:
                return "Upstream";
            case Direction::Downstream:
                return "Downstream";
            case Direction::Count:
                break;
            }

            throw std::runtime_error("Invalid Direction");
        }

        /**
         * Returns the reverse traversal direction: Upstream for Downstream and
         * Downstream for every other value (including Direction::Count).
         * constexpr so it can be used to compute a template argument (see path()).
         */
        constexpr inline Direction opposite(Direction d)
        {
            return d == Direction::Downstream ? Direction::Upstream : Direction::Downstream;
        }

        /**
         * Returns the enumerator name of `g`.
         * Unlike the ElementType / Direction overloads above, the Count
         * enumerator is NOT rejected here: it maps to the string "Count". Only a
         * value outside the switch reaches Throw(...).
         */
        inline std::string toString(GraphModification g)
        {
            switch(g)
            {
            case GraphModification::DeleteElement:
                return "DeleteElement";
            case GraphModification::AddElement:
                return "AddElement";
            case GraphModification::SetElement:
                return "SetElement";
            case GraphModification::Count:
                return "Count";
            }
            Throw("Invalid GraphModification ", static_cast(g));
        }

        /// Returns "Node" or "Edge" depending on which alternative `el` holds.
        template
        std::string Hypergraph::ElementName(
            typename Hypergraph::Element const& el)
        {
            return std::visit(rocRoller::overloaded{[](Node const&) { return "Node"; },
                                                    [](Edge const&) { return "Edge"; }},
                              el);
        }

        /**
         * Two locations are equal when LexicographicCompare over
         * (tag, incoming, outgoing) reports 0. The `element` member that
         * getLocation() also fills in is not part of the comparison.
         */
        template
        constexpr inline bool Hypergraph::Location::operator==(Location const& rhs) const
        {
            // clang-format off
            return LexicographicCompare(
                tag, rhs.tag,
                incoming, rhs.incoming,
                outgoing, rhs.outgoing) == 0;
            // clang-format on
        }

        /// True iff an element (node or edge) is stored under `tag`.
        template
        bool Hypergraph::exists(int tag) const
        {
            return m_elements.contains(tag);
        }

        /// Node/Edge classification of the element stored under `tag`
        /// (AssertFatal via getElement() if `tag` is not in the graph).
        template
        ElementType Hypergraph::getElementType(int tag) const
        {
            return getElementType(getElement(tag));
        }

        /// ElementType::Node when the variant check succeeds, otherwise
        /// ElementType::Edge. (The holds_alternative template argument,
        /// presumably Node, is missing in this copy.)
        template
        ElementType Hypergraph::getElementType(Element const& e) const
        {
            if(holds_alternative(e))
                return ElementType::Node;
            return ElementType::Edge;
        }

        /// The element kind that may be incident to an element of kind `t`:
        /// Edge for Node, Node for anything else. Used by addElement() to
        /// validate inputs/outputs.
        inline ElementType getConnectingType(ElementType t)
        {
            return t == ElementType::Node ? ElementType::Edge : ElementType::Node;
        }

        /// Hook invoked by every mutating member with the kind of change made.
        /// It is a no-op in this implementation.
        template
        void Hypergraph::clearCache(GraphModification)
        {
        }

        /**
         * Stores `element` under nextAvailableTag() with no incident
         * connections and returns the new tag.
         *
         * NOTE(review): the element is emplaced before isModificationAllowed()
         * is checked, so if AssertFatal throws the element remains in
         * m_elements.
         */
        template
        template
        int Hypergraph::addElement(T&& element)
        {
            auto tag = nextAvailableTag();

            m_elements.emplace(tag, std::forward(element));

            AssertFatal(isModificationAllowed(tag), "addElement is disallowed on this graph");

            clearCache(GraphModification::AddElement);

            return tag;
        }

        /// Adds each argument with the single-element addElement() and returns
        /// a std::tuple of the new tags in argument order (the braced
        /// initializer evaluates its elements left to right).
        template
        template
        auto Hypergraph::addElements(T&& element, Ts&&... rest)
        {
            auto myValue = addElement(std::forward(element));

            return std::tuple{myValue, addElement(std::forward(rest))...};
        }

        /**
         * Replaces the element stored under an existing `tag`.
         * Incidence connections are left unchanged.
         *
         * NOTE(review): unlike addElement()/deleteElement(), this neither
         * consults isModificationAllowed() nor checks that the replacement has
         * the same ElementType (Node vs Edge) as the value it replaces.
         */
        template
        template
        void Hypergraph::setElement(int tag, T&& element)
        {
            AssertFatal(m_elements.find(tag) != m_elements.end());

            m_elements[tag] = std::forward(element);

            clearCache(GraphModification::SetElement);
        }

        /// Brace-list convenience overload: forwards to the range overload so
        /// callers can write addElement(x, {a, b}, {c}).
        template
        template
        int Hypergraph::addElement(T&& element,
                                   std::initializer_list inputs,
                                   std::initializer_list outputs)
        {
            return addElement, std::initializer_list>(
                std::forward(element), inputs, outputs);
        }

        /// Allocates nextAvailableTag(), inserts `element` connected to
        /// `inputs`/`outputs` via the tagged overload, and returns the tag.
        template
        template T_Inputs, CForwardRangeOf T_Outputs>
        int Hypergraph::addElement(T&& element,
                                   T_Inputs const&  inputs,
                                   T_Outputs const& outputs)
        {
            auto tag = nextAvailableTag();

            addElement(tag, std::forward(element), inputs, outputs);

            return tag;
        }

        /**
         * Inserts `element` under the caller-chosen `tag` and connects it to
         * `inputs` (incoming) and `outputs` (outgoing).
         *
         * Preconditions (AssertFatal):
         *  - `tag` is not already used;
         *  - every input/output is of the connecting type (edges around a
         *    node, nodes around an edge);
         *  - isModificationAllowed(tag).
         *
         * Postconditions (AssertFatal):
         *  - non-hypergraph (!Hyper): every edge involved has exactly one
         *    source and one destination;
         *  - hypergraph: a newly added edge has at least one source and at
         *    least one destination.
         *
         * NOTE(review): the postconditions are checked after the element and
         * its incidence have been recorded, so a failing check leaves them in
         * the graph if AssertFatal throws.
         */
        template
        template T_Inputs, CForwardRangeOf T_Outputs>
        void Hypergraph::addElement(int              tag,
                                    T&&              element,
                                    T_Inputs const&  inputs,
                                    T_Outputs const& outputs)
        {
            AssertFatal(m_elements.find(tag) == m_elements.end());

            auto elementType    = getElementType(element);
            auto connectingType = getConnectingType(elementType);

            for(auto input : inputs)
            {
                AssertFatal(getElementType(input) == connectingType);
            }
            for(auto output : outputs)
            {
                AssertFatal(getElementType(output) == connectingType);
            }

            clearCache(GraphModification::AddElement);

            m_elements.emplace(tag, std::forward(element));

            AssertFatal(isModificationAllowed(tag), "addElement is disallowed on this graph");

            m_incidence.addIncidentConnections(tag, inputs, outputs);

            if constexpr(!Hyper)
            {
                // Enforce "calm" graph restriction of Edges requiring exactly one incoming Node and
                // one outgoing Node
                std::string errorMsg = "Graph is not a Hypergraph and Edge requires exactly one "
                                       "incoming Node and one outgoing Node";
                if(elementType == ElementType::Edge)
                {
                    AssertFatal(m_incidence.getSrcCount(tag) == 1
                                    && m_incidence.getDstCount(tag) == 1,
                                errorMsg,
                                ShowValue(tag));
                }
                else
                {
                    // Adding a node: the edges feeding it must now have exactly
                    // one destination, and the edges leaving it exactly one source.
                    for(auto input : inputs)
                    {
                        AssertFatal(m_incidence.getDstCount(input) == 1,
                                    errorMsg,
                                    ShowValue(tag),
                                    ShowValue(input));
                    }
                    for(auto output : outputs)
                    {
                        AssertFatal(m_incidence.getSrcCount(output) == 1,
                                    errorMsg,
                                    ShowValue(tag),
                                    ShowValue(output));
                    }
                }
            }
            else
            {
                // Check if we accidentally added a dangling edge
                if(elementType == ElementType::Edge)
                {
                    AssertFatal(m_incidence.getSrcCount(tag) >= 1
                                    && m_incidence.getDstCount(tag) >= 1,
                                "Hypergraph has dangling edge",
                                ShowValue(tag));
                }
            }
        }

        /// Removes `tag` and all of its incidence connections
        /// (AssertFatal if isModificationAllowed(tag) is false).
        template
        void Hypergraph::deleteElement(int tag)
        {
            AssertFatal(isModificationAllowed(tag), "deleteElement is disallowed on this graph");

            clearCache(GraphModification::DeleteElement);

            m_incidence.deleteTag(tag);
            m_elements.erase(tag);
        }

        /**
         * Deletes the edge connecting exactly `inputs` to `outputs`.
         *
         * Candidate edges are the neighbours of inputs[0] that `edgePredicate`
         * accepts. A candidate matches when its two neighbour lists (named
         * srcs/dsts) have the same sizes as `inputs`/`outputs` and contain
         * every element of them. Only the first match is deleted (duplicates
         * are kept). AssertFatal fires if either range is empty, if any handle
         * is not a node, or if no edge matches.
         *
         * NOTE(review): `inputs[0]` requires T_Inputs to support operator[],
         * which the CForwardRangeOf constraint does not guarantee
         * (std::initializer_list, for example, has no operator[]);
         * `*inputs.begin()` would not have that requirement. clearCache() is
         * also called before any validation or match is found.
         */
        template
        template T_Inputs,
                 CForwardRangeOf T_Outputs,
                 std::predicate  T_Predicate>
        void Hypergraph::deleteElement(T_Inputs const&  inputs,
                                       T_Outputs const& outputs,
                                       T_Predicate      edgePredicate)
        {
            AssertFatal(!inputs.empty() && !outputs.empty());

            clearCache(GraphModification::DeleteElement);

            for(auto input : inputs)
            {
                AssertFatal(getElementType(input) == ElementType::Node, "Requires node handles");
            }

            for(auto output : outputs)
            {
                AssertFatal(getElementType(output) == ElementType::Node, "Requires node handles");
            }

            auto match = false;
            for(auto e : getNeighbours(inputs[0]))
            {
                auto elem = getElement(e);
                if(!edgePredicate(std::get(elem)))
                    continue;

                match     = true;
                auto srcs = getNeighbours(e);
                if(srcs.size() != inputs.size())
                {
                    match = false;
                    continue;
                }
                for(auto src : inputs)
                {
                    if(std::find(srcs.begin(), srcs.end(), src) == srcs.end())
                    {
                        match = false;
                        break;
                    }
                }

                if(match)
                {
                    auto dsts = getNeighbours(e);
                    if(dsts.size() != outputs.size())
                    {
                        match = false;
                        continue;
                    }
                    for(auto dst : outputs)
                    {
                        if(std::find(dsts.begin(), dsts.end(), dst) == dsts.end())
                        {
                            match = false;
                            break;
                        }
                    }
                }
                if(match)
                {
                    deleteElement(e);
                    return;
                }
            }
            AssertFatal(match, "edge to delete : match not found");
        }

        /// As above, with an edge predicate that accepts only edges for which
        /// the lambda's holds_alternative check succeeds (presumably: edges
        /// holding the T alternative; template argument missing in this copy).
        template
        template T_Inputs, CForwardRangeOf T_Outputs>
        requires(std::constructible_from) void Hypergraph::deleteElement(T_Inputs const&  inputs,
                                                                         T_Outputs const& outputs)
        {
            return deleteElement(
                inputs, outputs, [](Edge const& edge) { return std::holds_alternative(edge); });
        }

        /// Number of stored elements (nodes and edges together).
        template
        size_t Hypergraph::getElementCount() const
        {
            return m_elements.size();
        }

        /// Reference to the element stored under `tag`
        /// (AssertFatal "Element not found" if absent).
        template
        auto Hypergraph::getElement(int tag) const -> Element const&
        {
            AssertFatal(m_elements.contains(tag), "Element not found", ShowValue(tag));
            return m_elements.at(tag);
        }

        /// Node accessor. From the surviving structure: returns the whole node
        /// when the same_as check holds, otherwise extracts one alternative
        /// from it with std::get (which throws std::bad_variant_access on a
        /// mismatch). Template arguments are missing in this copy.
        template
        template
        T Hypergraph::getNode(int tag) const
        {
            static_assert(std::constructible_from);
            auto const& node = std::get(getElement(tag));
            if constexpr(std::same_as)
            {
                return node;
            }
            else
            {
                return std::get(node);
            }
        }

        /// Edge accessor; same shape as getNode().
        template
        template
        T Hypergraph::getEdge(int tag) const
        {
            static_assert(std::constructible_from);
            auto const& edge = std::get(getElement(tag));
            if constexpr(std::same_as)
            {
                return edge;
            }
            else
            {
                return std::get(edge);
            }
        }

        /// Snapshot of `tag`: its incoming (getSrcs) and outgoing (getDsts)
        /// neighbours plus the stored element.
        template
        auto Hypergraph::getLocation(int tag) const -> Location
        {
            return {.tag      = tag,
                    .incoming = m_incidence.getSrcs(tag),
                    .outgoing = m_incidence.getDsts(tag),
                    .element  = getElement(tag)};
        }

        /// Yields every element (node or edge) with no incoming connection,
        /// in m_elements iteration order.
        template
        Generator Hypergraph::roots() const
        {
            for(auto const& pair : m_elements)
            {
                auto tag = pair.first;
                if(m_incidence.getSrcCount(tag) == 0)
                    co_yield tag;
            }
        }

        /// Yields every element (node or edge) with no outgoing connection,
        /// in m_elements iteration order.
        template
        Generator Hypergraph::leaves() const
        {
            for(auto const& pair : m_elements)
            {
                auto tag = pair.first;
                if(m_incidence.getDstCount(tag) == 0)
                    co_yield tag;
            }
        }

        /**
         * For a node: yields each distinct element two hops away (through its
         * adjacent edges), each at most once. For an edge: yields its direct
         * neighbours. The traversal direction is carried by the getNeighbours
         * template argument (missing in this copy; presumably Downstream).
         */
        template
        Generator Hypergraph::childNodes(int parent) const
        {
            if(getElementType(parent) == ElementType::Node)
            {
                std::set visited;
                for(auto edgeTag : getNeighbours(parent))
                {
                    for(auto neighbour : getNeighbours(edgeTag))
                    {
                        if(!visited.contains(neighbour))
                        {
                            visited.insert(neighbour);
                            co_yield neighbour;
                        }
                    }
                }
            }
            else
            {
                for(auto child : getNeighbours(parent))
                    co_yield child;
            }
        }

        /// Mirror of childNodes() (presumably using the Upstream neighbours;
        /// template arguments missing in this copy).
        template
        Generator Hypergraph::parentNodes(int child) const
        {
            if(getElementType(child) == ElementType::Node)
            {
                std::set visited;
                for(auto edgeTag : getNeighbours(child))
                {
                    for(auto neighbour : getNeighbours(edgeTag))
                    {
                        if(!visited.contains(neighbour))
                        {
                            visited.insert(neighbour);
                            co_yield neighbour;
                        }
                    }
                }
            }
            else
            {
                for(auto parent : getNeighbours(child))
                    co_yield parent;
            }
        }

        /**
         * Pre-order depth-first visit from every tag in `starts`, in direction
         * `dir`. One visited set is shared by all starts, so an element
         * reachable from several starts is yielded once. The two branches
         * presumably differ only in the (stripped) Direction template argument
         * of the recursive helper.
         */
        template
        template Range>
        Generator Hypergraph::depthFirstVisit(Range const& starts,
                                              Direction    dir) const
        {
            std::unordered_set visitedNodes;

            if(dir == Direction::Downstream)
            {
                for(auto tag : starts)
                    co_yield depthFirstVisit(tag, visitedNodes);
            }
            else
            {
                for(auto tag : starts)
                    co_yield depthFirstVisit(tag, visitedNodes);
            }
        }

        /// Multi-start depth-first visit that only crosses edges accepted by
        /// `edgePredicate`; see the recursive predicate helper below.
        template
        template Range, std::predicate Predicate>
        Generator Hypergraph::depthFirstVisit(Range const& starts,
                                              Predicate    edgePredicate,
                                              Direction    dir) const
        {
            std::unordered_set visitedNodes;

            if(dir == Direction::Downstream)
            {
                for(auto tag : starts)
                    co_yield depthFirstVisit(tag, edgePredicate, visitedNodes);
            }
            else
            {
                for(auto tag : starts)
                    co_yield depthFirstVisit(tag, edgePredicate, visitedNodes);
            }
        }

        /// Single-start variant of the predicate-filtered depth-first visit.
        template
        template Predicate>
        Generator Hypergraph::depthFirstVisit(int start, Predicate edgePredicate, Direction dir) const
        {
            std::unordered_set visitedNodes;

            if(dir == Direction::Downstream)
            {
                co_yield depthFirstVisit(start, edgePredicate, visitedNodes);
            }
            else
            {
                co_yield depthFirstVisit(start, edgePredicate, visitedNodes);
            }
        }

        /// Single-start convenience wrapper around the range overload.
        template
        Generator Hypergraph::depthFirstVisit(int start, Direction dir) const
        {
            std::initializer_list starts{start};
            co_yield depthFirstVisit(starts, dir);
        }

        /// Yields the depth-first visit from `start` filtered by `nodeSelector`.
        /// Note that the visit yields edge tags as well as node tags, so
        /// `nodeSelector` is applied to both.
        template
        template Predicate>
        Generator Hypergraph::findNodes(int start, Predicate nodeSelector, Direction dir) const
        {
            co_yield filter(nodeSelector, depthFirstVisit(start, dir));
        }

        /// Multi-start variant of findNodes().
        template
        template Range, std::predicate Predicate>
        Generator Hypergraph::findNodes(Range const& starts, Predicate nodeSelector, Direction dir) const
        {
            co_yield filter(nodeSelector, depthFirstVisit(starts, dir));
        }

        /// Yields every stored tag (nodes and edges) in m_elements order.
        template
        Generator Hypergraph::allElements() const
        {
            for(auto const& pair : m_elements)
                co_yield pair.first;
        }

        /// Yields every stored tag accepted by `nodeSelector` (applied to edges
        /// as well as nodes).
        template
        template Predicate>
        Generator Hypergraph::findElements(Predicate nodeSelector) const
        {
            co_yield filter(nodeSelector, allElements());
        }

        /**
         * Breadth-first visit from `start` in direction `dir`. Yields `start`
         * first, then every reachable element (nodes and edges alike), each
         * exactly once.
         */
        template
        Generator Hypergraph::breadthFirstVisit(int start, Direction dir) const
        {
            std::unordered_set visitedNodes;
            visitedNodes.insert(start);
            co_yield start;

            // Each queued entry is (element just reached, element adjacent to it
            // in direction `dir`): (src, dst) when walking Downstream and
            // (dst, src) when walking Upstream. Only `.second` is visited;
            // `noted` stops the same pair from being queued twice.
            std::deque> toExplore;
            std::set> noted;

            for(auto connected : dir == Direction::Downstream ? m_incidence.getDsts(start)
                                                              : m_incidence.getSrcs(start))
            {
                std::pair candidate = {start, connected};
                toExplore.push_back(candidate);
                noted.insert(candidate);
            }

            while(!toExplore.empty())
            {
                auto i    = toExplore.front();
                auto node = i.second;
                toExplore.pop_front();

                if(visitedNodes.contains(node))
                    continue;

                visitedNodes.insert(node);
                co_yield node;

                for(auto connected : dir == Direction::Downstream ? m_incidence.getDsts(node)
                                                                  : m_incidence.getSrcs(node))
                {
                    std::pair candidate = {node, connected};
                    if(!noted.contains(candidate))
                    {
                        toExplore.push_back(candidate);
                        noted.insert(candidate);
                    }
                }
            }

            co_return;
        }

        /// Recursive pre-order DFS helper: yields `start` (unless already in
        /// `visitedNodes`), then recurses into each of its neighbours.
        template
        template
        Generator Hypergraph::depthFirstVisit(
            int start, std::unordered_set& visitedNodes) const
        {
            if(visitedNodes.contains(start))
                co_return;

            visitedNodes.insert(start);

            co_yield start;

            for(auto element : getNeighbours(start))
            {
                co_yield depthFirstVisit(element, visitedNodes);
            }
        }

        /**
         * Recursive predicate-filtered DFS helper. Yields `start`, then, for
         * each adjacent element `tag` (edges when `start` is a node), marks
         * `tag` visited WITHOUT yielding it, and recurses into the far side of
         * `tag` only if edgePredicate(tag) holds. The predicate receives the
         * integer tag, not the edge value.
         */
        template
        template Predicate>
        Generator Hypergraph::depthFirstVisit(
            int start, Predicate edgePredicate, std::unordered_set& visitedElements) const
        {
            if(visitedElements.contains(start))
                co_return;

            visitedElements.insert(start);

            co_yield start;

            for(auto tag : getNeighbours(start))
            {
                visitedElements.insert(tag);

                if(edgePredicate(tag))
                {
                    for(auto child : getNeighbours(tag))
                    {
                        co_yield depthFirstVisit(child, edgePredicate, visitedElements);
                    }
                }
            }
        }

        /// path() with an edge selector that accepts every edge.
        template
        template RangeStart, CForwardRangeOf RangeEnd>
        Generator Hypergraph::path(RangeStart const& starts,
                                   RangeEnd const&   ends) const
        {
            auto truePred = [](int) { return true; };
            co_yield path(starts, ends, truePred);
        }

        /// path() with a fresh memo table.
        template
        template RangeStart,
                 CForwardRangeOf RangeEnd,
                 std::predicate  Predicate>
        Generator Hypergraph::path(RangeStart const& starts,
                                   RangeEnd const&   ends,
                                   Predicate         edgeSelector) const
        {
            std::map visitedElements;
            co_yield path(starts, ends, edgeSelector, visitedElements);
        }

        /**
         * Yields, in post-order (predecessors before successors), every element
         * lying on a path from some tag in `starts` to each tag in `ends`.
         *
         * The walk proceeds backwards from each `end` (neighbours in direction
         * reverseDir = opposite(Dir); the getNeighbours template argument is
         * missing in this copy). `visitedElements` memoises, per element,
         * whether it connects back to a start (true) or not (false):
         *  - an element is marked false before its neighbours are explored,
         *    which also stops revisits through cycles;
         *  - edges rejected by `edgeSelector` are skipped;
         *  - a node becomes true if any explored neighbour is true;
         *  - an edge additionally requires edgeSatisfied(), i.e. all of its
         *    neighbours must already be memoised true.
         *
         * NOTE(review): when `end` resolves to false (e.g. an Edge with one
         * unreachable neighbour), `results` is discarded even though the
         * elements in it were already memoised true. A later walk that reaches
         * them hits the `contains` early-continue and yields nothing, so those
         * elements may never be emitted. Confirm whether callers can hit this
         * (topologicalSort passes all roots as starts).
         */
        template
        template RangeStart,
                 CForwardRangeOf RangeEnd,
                 std::predicate  Predicate>
        Generator Hypergraph::path(RangeStart const&    starts,
                                   RangeEnd const&      ends,
                                   Predicate            edgeSelector,
                                   std::map& visitedElements) const
        {
            constexpr Direction reverseDir = opposite(Dir);

            for(auto end : ends)
            {
                // Already resolved (or currently being resolved further up the
                // recursion): nothing more to yield for it.
                if(visitedElements.contains(end))
                {
                    continue;
                }

                // Reached a start: it trivially lies on a path.
                if(std::count(starts.begin(), starts.end(), end) > 0)
                {
                    visitedElements[end] = true;
                    co_yield end;
                    continue;
                }

                visitedElements[end] = false;

                std::vector results;
                for(auto nextElement : getNeighbours(end))
                {
                    if(getElementType(nextElement) == ElementType::Edge
                       && !edgeSelector(nextElement))
                    {
                        continue;
                    }

                    std::vector branchResults
                        = path(starts, std::vector{nextElement}, edgeSelector, visitedElements)
                              .template to();
                    results.insert(results.end(), branchResults.begin(), branchResults.end());

                    bool satisfied = (getElementType(end) != ElementType::Edge
                                      || edgeSatisfied(end, visitedElements));

                    visitedElements[end] = visitedElements[end] || visitedElements[nextElement];
                    visitedElements[end] = visitedElements[end] && satisfied;
                }

                if(visitedElements.at(end))
                {
                    for(auto const result : results)
                    {
                        co_yield result;
                    }
                    co_yield end;
                }
            }
        }

        /// True iff every neighbour of `edge` (direction given by the stripped
        /// template argument) is present in `visitedElements` and memoised true.
        template
        template
        bool Hypergraph::edgeSatisfied(
            int const edge, std::map const& visitedElements) const
        {
            for(auto element : getNeighbours(edge))
            {
                auto iter = visitedElements.find(element);
                if(iter == visitedElements.end() || !iter->second)
                    return false;
            }

            return true;
        }

        /// Copy of the neighbours of `tag` in compile-time direction `Dir`:
        /// getDsts for Downstream, getSrcs otherwise. AssertFatal if `tag` is
        /// not in the graph.
        template
        template
        std::vector Hypergraph::getNeighbours(int const tag) const
        {
            AssertFatal(m_elements.contains(tag),
                        "Graph tag not registered, element not in graph",
                        ShowValue(tag));

            if constexpr(Dir == Direction::Downstream)
            {
                return m_incidence.getDsts(tag);
            }
            else
            {
                return m_incidence.getSrcs(tag);
            }
        }

        /// Runtime-direction dispatcher to the templated getNeighbours(); the
        /// branches presumably differ only in the stripped template argument.
        template
        std::vector Hypergraph::getNeighbours(int const tag, Direction Dir) const
        {
            if(Dir == Direction::Downstream)
            {
                return getNeighbours(tag);
            }
            else
            {
                return getNeighbours(tag);
            }
        }

        /// Every element on a path from roots() to leaves(), predecessors first
        /// (see path()).
        template
        inline Generator Hypergraph::topologicalSort() const
        {
            auto start = roots().template to();
            auto end   = leaves().template to();

            co_yield path(start, end);
        }

        /// As topologicalSort() but with the roles of roots and leaves swapped
        /// (and, presumably, the stripped direction argument reversed).
        template
        inline Generator Hypergraph::reverseTopologicalSort() const
        {
            auto start = roots().template to();
            auto end   = leaves().template to();

            co_yield path(end, start);
        }

        /**
         * Graphviz DOT rendering of the whole graph.
         * Each element is emitted as `"<prefix><tag>"[label="<toString>(<tag>)"]`,
         * with edges drawn as boxes, followed by m_incidence.toDOTSection(prefix).
         * For every edge with more than one incoming (or outgoing) neighbour, an
         * invisible rank=same chain keeps those neighbours in left-to-right order.
         * @param prefix      prepended to every DOT identifier (lets several
         *                    graphs share one DOT file).
         * @param standalone  when true, wraps the output in `digraph { ... }`.
         */
        template
        std::string Hypergraph::toDOT(std::string const& prefix, bool standalone) const
        {
            std::ostringstream msg;

            if(standalone)
                msg << "digraph {" << std::endl;

            for(auto const& pair : m_elements)
            {
                msg << '"' << prefix << pair.first << '"' << "[label=\"";
                if(getElementType(pair.second) == ElementType::Node)
                {
                    auto x = std::get(pair.second);
                    msg << toString(x) << "(" << pair.first << ")\"";
                }
                else
                {
                    auto x = std::get(pair.second);
                    msg << toString(x) << "(" << pair.first << ")\",shape=box";
                }

                msg << "];" << std::endl;
            }

            msg << m_incidence.toDOTSection(prefix);

            // Enforce left-to-right ordering for elements connected to an edge.
            for(auto const& pair : m_elements)
            {
                if(getElementType(pair.second) != ElementType::Edge)
                    continue;

                auto const& loc = getLocation(pair.first);

                if(loc.incoming.size() > 1)
                {
                    msg << "{\nrank=same\n";
                    bool first = true;
                    for(auto idx : loc.incoming)
                    {
                        if(!first)
                            msg << "->";
                        msg << '"' << prefix << idx << '"';
                        first = false;
                    }
                    msg << "[style=invis]\nrankdir=LR\n}\n";
                }

                if(loc.outgoing.size() > 1)
                {
                    msg << "{\nrank=same\n";
                    bool first = true;
                    for(auto idx : loc.outgoing)
                    {
                        if(!first)
                            msg << "->";
                        msg << '"' << prefix << idx << '"';
                        first = false;
                    }
                    msg << "[style=invis]\nrankdir=LR\n}\n";
                }
            }

            if(standalone)
                msg << "}" << std::endl;

            return msg.str();
        }

        /**
         * Standalone Graphviz DOT rendering that shows every node but only the
         * edges accepted by `edgePredicate` (applied to the edge value).
         * For each shown edge, one arrow is written from each element of the
         * first neighbour list into the edge and one from the edge to each
         * element of the second list (directions given by stripped template
         * arguments; presumably upstream then downstream). Unlike
         * toDOT(prefix, standalone), no ordering hints are emitted.
         */
        template
        template Predicate>
        std::string Hypergraph::toDOT(Predicate edgePredicate) const
        {
            std::ostringstream msg;
            std::string const  prefix = "";

            msg << "digraph {" << std::endl;

            for(auto const& pair : m_elements)
            {
                if(getElementType(pair.second) == ElementType::Node)
                {
                    auto x = std::get(pair.second);
                    msg << '"' << prefix << pair.first << '"' << "[label=\"";
                    msg << toString(x) << "(" << pair.first << ")\"";
                    msg << "];" << std::endl;
                }
                else
                {
                    auto x = std::get(pair.second);
                    if(edgePredicate(x))
                    {
                        msg << '"' << prefix << pair.first << '"' << "[label=\"";
                        msg << toString(x) << "(" << pair.first << ")\",shape=box";
                        msg << "];" << std::endl;
                    }
                }
            }

            for(auto const& pair : m_elements)
            {
                if(getElementType(pair.second) == ElementType::Edge)
                {
                    auto x = std::get(pair.second);
                    if(edgePredicate(x))
                    {
                        for(auto y : getNeighbours(pair.first))
                        {
                            msg << '"' << prefix << y << "\" -> \"" << prefix << pair.first << '"'
                                << std::endl;
                        }
                        for(auto y : getNeighbours(pair.first))
                        {
                            msg << '"' << prefix << pair.first << "\" -> \"" << prefix << y << '"'
                                << std::endl;
                        }
                    }
                }
            }

            msg << "}" << std::endl;

            return msg.str();
        }

        /**
         * Yields the tags of stored elements matching type T, in m_elements order.
         * The first branch tests the outer variant directly (presumably when T
         * is the Node or Edge type itself). The other branches first check the
         * outer alternative, then test the inner Node/Edge variant for T.
         * (Template arguments are missing in this copy.)
         */
        template
        template
        requires(std::constructible_from || std::constructible_from)
            Generator Hypergraph::getElements() const
        {
            for(auto const& elem : m_elements)
            {
                if constexpr(std::same_as || std::same_as)
                {
                    if(std::holds_alternative(elem.second))
                        co_yield elem.first;
                }
                else if constexpr(std::constructible_from)
                {
                    if(std::holds_alternative(elem.second))
                    {
                        auto const& node = std::get(elem.second);
                        if(std::holds_alternative(node))
                            co_yield elem.first;
                    }
                }
                else if constexpr(std::constructible_from)
                {
                    if(std::holds_alternative(elem.second))
                    {
                        auto const& edge = std::get(elem.second);
                        if(std::holds_alternative(edge))
                            co_yield elem.first;
                    }
                }
            }
        }

        /// Node-constrained wrapper around getElements().
        template
        template
        requires(std::constructible_from) Generator Hypergraph::getNodes() const
        {
            co_yield getElements();
        }

        /// Edge-constrained wrapper around getElements().
        template
        template
        requires(std::constructible_from) Generator Hypergraph::getEdges() const
        {
            co_yield getElements();
        }

        /// Nodes connected to node `dst` through edges of type T. When the
        /// same_as check holds, every edge is accepted; otherwise only edges
        /// for which the holds_alternative check succeeds.
        template
        template
        requires(std::constructible_from) Generator Hypergraph::getConnectedNodeIndices(
            int const dst) const
        {
            if constexpr(std::same_as)
            {
                auto truePredicate = [](auto const&) { return true; };
                co_yield getConnectedNodeIndices(dst, truePredicate);
            }
            else
            {
                auto edgePredicate
                    = [](Edge const& edge) { return std::holds_alternative(edge); };
                co_yield getConnectedNodeIndices(dst, edgePredicate);
            }
        }

        /**
         * For node `dst` (AssertFatal if it is not a node), visits each adjacent
         * edge whose value satisfies `edgePredicate` and yields that edge's
         * neighbours on the far side. No de-duplication: a node reached through
         * two accepted edges is yielded twice. Despite the parameter name, the
         * traversal direction comes from the (stripped) template argument.
         */
        template
        template Predicate>
        Generator Hypergraph::getConnectedNodeIndices(int const dst, Predicate edgePredicate) const
        {
            AssertFatal(getElementType(dst) == ElementType::Node, "Require a node handle");

            for(auto elem : getNeighbours(dst))
            {
                if(edgePredicate(std::get(getElement(elem))))
                {
                    for(auto tag : getNeighbours(elem))
                        co_yield tag;
                }
            }
        }

        /// Nodes feeding `dst` (delegates to getConnectedNodeIndices).
        template
        template
        requires(std::constructible_from) Generator Hypergraph::getInputNodeIndices(
            int const dst) const
        {
            co_yield getConnectedNodeIndices(dst);
        }

        /// Nodes feeding `dst` through edges accepted by `edgePredicate`.
        template
        template Predicate>
        Generator Hypergraph::getInputNodeIndices(int const dst, Predicate edgePredicate) const
        {
            co_yield getConnectedNodeIndices(dst, edgePredicate);
        }

        /// Nodes fed by `src` (delegates to getConnectedNodeIndices).
        template
        template
        requires(std::constructible_from) Generator Hypergraph::getOutputNodeIndices(
            int const src) const
        {
            co_yield getConnectedNodeIndices(src);
        }

        /// Nodes fed by `src` through edges accepted by `edgePredicate`.
        template
        template Predicate>
        Generator Hypergraph::getOutputNodeIndices(int const src, Predicate edgePredicate) const
        {
            co_yield getConnectedNodeIndices(src, edgePredicate);
        }

        /**
         * Returns `candidates` together with every node transitively reachable
         * from them by repeatedly applying getOutputNodeIndices (edge type T).
         * Sweeps repeat until a sweep adds no new node to the result.
         */
        template
        template
        requires(std::constructible_from) std::set Hypergraph::followEdges(
            std::set const& candidates) const
        {
            // Nodes to be analyzed
            std::set currentNodes = candidates;

            // Full set of connected nodes to be returned
            std::set connectedNodes = candidates;

            auto numCandidates = connectedNodes.size();
            do
            {
                // Nodes which are found by this sweep
                std::set foundNodes;

                numCandidates = connectedNodes.size();

                for(auto tag : currentNodes)
                {
                    auto outTags = getOutputNodeIndices(tag);
                    foundNodes.insert(outTags.begin(), outTags.end());
                }

                connectedNodes.insert(foundNodes.begin(), foundNodes.end());
                currentNodes = std::move(foundNodes);
            } while(numCandidates != connectedNodes.size());

            return connectedNodes;
        }

        /// Streams graph.toDOT() using that function's default arguments
        /// (declared elsewhere).
        template
        inline std::ostream& operator<<(std::ostream& stream, Hypergraph const& graph)
        {
            return stream << graph.toDOT();
        }

        /**
         * Recursively walks nodes connected to `start` through edges accepted by
         * `edgePredicate`. Yields each neighbour satisfying `destNodePredicate`,
         * and continues the walk through each neighbour satisfying
         * `nodePredicate`.
         *
         * NOTE(review): no visited set is kept, so a node reachable along
         * several paths is yielded once per path, and a cycle whose nodes all
         * satisfy `nodePredicate` would recurse without bound.
         */
        template
        Generator reachableNodes(Graph::Hypergraph const& graph,
                                 int                      start,
                                 auto                     nodePredicate,
                                 auto                     edgePredicate,
                                 auto                     destNodePredicate)
        {
            for(auto nextNode : graph.template getConnectedNodeIndices(start, edgePredicate))
            {
                auto const& node = graph.getNode(nextNode);
                if(destNodePredicate(node))
                    co_yield nextNode;

                if(nodePredicate(node))
                    co_yield reachableNodes(
                        graph, nextNode, nodePredicate, edgePredicate, destNodePredicate);
            }
        }

        /**
         * Non-hypergraph only (static_assert). Returns the first element that is
         * both an outgoing neighbour of `tail` and an incoming neighbour of
         * `head`, i.e. an edge tail -> head. Returns std::nullopt if there is
         * none. AssertFatal if either tag is not in the graph.
         */
        template
        std::optional Hypergraph::findEdge(int tail, int head) const
        {
            static_assert(!Hyper, "findEdge not supported for hypergraphs.");

            AssertFatal(m_elements.contains(tail) && m_elements.contains(head),
                        "Graph tags not registered, elements not in graph",
                        ShowValue(tail),
                        ShowValue(head));

            auto dsts = m_incidence.getDsts(tail);
            for(auto src : m_incidence.getSrcs(head))
            {
                auto rv = std::find(dsts.begin(), dsts.end(), src);
                if(rv != dsts.end())
                    return *rv;
            }

            return std::nullopt;
        }

        /**
         * 1 for an empty graph, otherwise (key of m_elements.rbegin()) + 1.
         * Assumes m_elements iterates in ascending key order (e.g. std::map) —
         * TODO confirm. Because only the current maximum is consulted, the tag
         * of a deleted highest-numbered element is handed out again.
         */
        template
        int Hypergraph::nextAvailableTag() const
        {
            if(m_elements.empty())
                return 1;
            return m_elements.rbegin()->first + 1;
        }
    }
}