/******************************************************************************* * * MIT License * * Copyright 2024-2025 AMD ROCm(TM) Software * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * *******************************************************************************/ #pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include namespace rocRoller { namespace Graph { /** * @brief Kind of graph element: a node or an edge. `Count` follows the last kind and so * equals the number of kinds. */ enum class ElementType : int { Node = 0, Edge, Count }; std::string toString(ElementType e); /** * @brief Returns the complementary type to `t`. */ ElementType getConnectingType(ElementType t); /** * @brief Direction in which connections are followed from an element. */ enum class Direction : int { Upstream = 0, Downstream, Count }; std::string toString(Direction d); constexpr Direction opposite(Direction d); /** * @brief Kind of modification made to a graph; the argument type of * Hypergraph::clearCache(). */ enum class GraphModification : int { DeleteElement = 0, AddElement, SetElement, Count }; std::string toString(GraphModification m); /** * @brief A directed Hypergraph. 
* * @tparam Node What objects represent the nodes in the graph? Typically a std::variant * @tparam Edge What objects represent the edges in the graph? Typically a std::variant * @tparam Hyper A Hypergraph (true) is a graph where an edge can connect to any number of * incoming and outgoing nodes. A "calm" graph (false) is an ordinary graph where an edge * connects to at most one incoming node and at most one outgoing node. */ template class Hypergraph { public: using Element = std::variant; /** * @brief The value of the `Hyper` template argument. */ inline static constexpr bool IsHyper = Hyper; static std::string ElementName(Element const& el); /** * @brief An element together with its tag and its incoming and outgoing connections. */ struct Location { int tag; std::vector incoming; std::vector outgoing; Element element; constexpr inline bool operator==(Location const& rhs) const; }; bool exists(int tag) const; /** * @brief Returns whether `tag` points to a node or an edge. */ ElementType getElementType(int tag) const; template T getNode(int tag) const; template T getEdge(int tag) const; /** * @brief Returns whether `e` is a node or an edge. */ ElementType getElementType(Element const& e) const; virtual void clearCache(GraphModification modification); template int addElement(T&& element); /** * Adds each element to the graph and returns a tuple of int. */ template auto addElements(T&& element, Ts&&... rest); /** * @brief Set (overwrite) existing element. * * Asserts that the tag exists already. 
*/ template void setElement(int tag, T&& element); template int addElement(T&& element, std::initializer_list inputs, std::initializer_list outputs); template T_Inputs, CForwardRangeOf T_Outputs> int addElement(T&& element, T_Inputs const& inputs, T_Outputs const& outputs); template T_Inputs, CForwardRangeOf T_Outputs> void addElement(int tag, T&& element, T_Inputs const& inputs, T_Outputs const& outputs); void deleteElement(int tag); template T_Inputs, CForwardRangeOf T_Outputs, std::predicate T_Predicate> void deleteElement(T_Inputs const& inputs, T_Outputs const& outputs, T_Predicate edgePredicate); template T_Inputs, CForwardRangeOf T_Outputs> requires(std::constructible_from) void deleteElement(T_Inputs const& inputs, T_Outputs const& outputs); size_t getElementCount() const; Element const& getElement(int tag) const; /** * @brief Returns a Location info object detailing connections to the element `tag`. */ Location getLocation(int tag) const; /** * @brief Yields the tags of elements without any incoming connections. */ Generator roots() const; /** * @brief Yields the tags of elements without any outgoing connections. */ Generator leaves() const; /** * @brief Yields the tags of the child nodes of the element `parent`. */ Generator childNodes(int parent) const; /** * @brief Yields the tags of the parent nodes of the element `child`. */ Generator parentNodes(int child) const; Generator allElements() const; /** * @brief Yields element tags connected in the specified direction to start, in * depth-first order. */ Generator depthFirstVisit(int start, Direction dir = Direction::Downstream) const; /** * @brief Yields the tags of nodes connected in the specified direction to start that * satisfy the node selector. */ template Predicate> Generator findNodes(int start, Predicate nodeSelector, Direction dir = Direction::Downstream) const; /** * @brief Yields the tags of nodes connected in the specified direction to starts that * satisfy the node selector. 
*/ template Range, std::predicate Predicate> Generator findNodes(Range const& starts, Predicate nodeSelector, Direction dir = Direction::Downstream) const; /** * @brief Yields node tags that satisfy the node selector. * * NOTE(review): the function is named findElements but is documented as yielding node * tags -- confirm whether edge tags can be yielded as well. */ template Predicate> Generator findElements(Predicate nodeSelector) const; /** * @brief Yields element tags connected in the specified direction to starts, in * depth-first order. */ template Range> Generator depthFirstVisit(Range const& starts, Direction dir = Direction::Downstream) const; /** * @brief Yields element tags connected in the specified direction to starts, in * depth-first order. * * Will only visit through edges if the edgePredicate returns true. */ template Range, std::predicate Predicate> Generator depthFirstVisit(Range const& starts, Predicate edgePredicate, Direction dir = Direction::Downstream) const; /** * @brief Yields element tags connected in the specified direction to the start, in * depth-first order. * * Will only visit through edges if the edgePredicate returns true. */ template Predicate> Generator depthFirstVisit(int start, Predicate edgePredicate, Direction dir) const; /** * @brief Yields element tags connected in the specified direction to start, in * depth-first order. * * Will not yield any nodes in `visitedNodes`, and will insert nodes into `visitedNodes` to * track already visited nodes. */ template Generator depthFirstVisit(int start, std::unordered_set& visitedNodes) const; template Predicate> Generator depthFirstVisit(int start, Predicate edgePredicate, std::unordered_set& visitedNodes) const; /** * @brief Yields element tags connected in the specified direction of start, in * breadth-first order. 
*/ Generator breadthFirstVisit(int start, Direction dir = Direction::Downstream) const; /** * @brief Yields element tags (both nodes and edges) that form the paths * from the starts to the ends. */ template RangeStart, CForwardRangeOf RangeEnd, std::predicate Predicate> Generator path(RangeStart const& starts, RangeEnd const& ends, Predicate edgeSelector, std::map& visitedElements) const; /** * @brief Yields element tags (both nodes and edges) that form the paths * from the starts to the ends. */ template RangeStart, CForwardRangeOf RangeEnd, std::predicate Predicate> Generator path(RangeStart const& starts, RangeEnd const& ends, Predicate edgeSelector) const; /** * @brief Yields element tags (both nodes and edges) that form the paths * from the starts to the ends. */ template RangeStart, CForwardRangeOf RangeEnd> Generator path(RangeStart const& starts, RangeEnd const& ends) const; /** * @brief Returns element tags of immediately connecting elements in the direction * specified. */ template std::vector getNeighbours(int const tag) const; /** * @brief Returns element tags of immediately connecting elements in the direction * specified. */ std::vector getNeighbours(int const tag, Direction Dir) const; /** * @brief Return edges in topological order. * * Traversing edges in topological order preserves edge * dependencies. */ Generator topologicalSort() const; /** * @brief Return edges in reverse topological order. * * Traversing edges in reverse topological order can preserve edge * dependencies. */ Generator reverseTopologicalSort() const; std::string toDOT(std::string const& prefix = "", bool standalone = true) const; /** * @brief Edge predicate that accepts every edge; the default predicate for the * predicate-taking toDOT() overload. */ static bool identity(Edge const&) { return true; } template Predicate> std::string toDOT(Predicate edgePredicate = identity) const; /** * @brief Yields all element tags that match the supplied element type * * @tparam T Type of element to filter by. Can be Node, Edge, or a subvariant of those. 
*/ template requires(std::constructible_from || std::constructible_from) Generator getElements() const; /** * @brief Yields tags of all Nodes * * @tparam T Subvariant of Node to filter by. Returns all Node tags by default. */ template requires(std::constructible_from) Generator getNodes() const; /** * @brief Yields tags of all Edges * * @tparam T Subvariant of Edge to filter by. Returns all Edge tags by default. */ template requires(std::constructible_from) Generator getEdges() const; /** * @brief Yields tags of nodes immediately connected to `dst` through Edges of type T, * in direction Dir. */ template requires(std::constructible_from) Generator getConnectedNodeIndices( int const dst) const; /** * @brief Yields tags of nodes immediately connected to `dst` through Edges that satisfy * the edgePredicate, in direction Dir. */ template Predicate> Generator getConnectedNodeIndices(int const dst, Predicate edgePredicate) const; /** * @brief Yields tags of nodes that immediately precede `dst` where the Edges are of * type T. */ template requires(std::constructible_from) Generator getInputNodeIndices( int const dst) const; /** * @brief Yields tags of nodes that immediately precede `dst` where the Edges satisfy * the edgePredicate. */ template Predicate> Generator getInputNodeIndices(int const dst, Predicate edgePredicate) const; Generator> getInputNodesAndEdges(int dst); /** * @brief Yields tags of nodes that immediately follow `src` where the Edges are of type * T. */ template requires(std::constructible_from) Generator getOutputNodeIndices( int const src) const; /** * @brief Yields tags of nodes that immediately follow `src` where the Edges satisfy the * edgePredicate. */ template Predicate> Generator getOutputNodeIndices(int const src, Predicate edgePredicate) const; /** * @brief Finds all downstream node tags that are connected to `candidates` via the * specified Edge type. The set of original candidates is included in the returned set. 
* * Note that this function recursively follows edges. * * @param candidates Set of Node tags * @return std::set Set of Node tags expanded */ template requires(std::constructible_from) std::set followEdges( std::set const& candidates) const; /** * @brief Find an edge that connects two nodes, if one exists. * * @param tail Tag of Node that is downstream of Edge * @param head Tag of Node that is upstream of Edge * @return std::optional Tag of connecting Edge, if one exists. */ std::optional findEdge(int tail, int head) const; /** * This function is called to disallow certain operations * (add/delete certain elements). * * The implementation given here returns true for every argument. */ virtual bool isModificationAllowed(int) const { return true; } // clang-format off private: // clang-format on template friend struct rocRoller::Serialization::MappingTraits; /** * @brief Map of Element tags and Elements * */ std::map m_elements; /** * @brief Container holding the incident connections between Hypergraph Elements * */ HypergraphIncidenceContainer m_incidence; template bool edgeSatisfied(int const edge, std::map const& visitedElements) const; /** * @brief Gets a tag that hasn't been used in the Hypergraph for a new Element * * @return int Next available tag that can be used to store an Element */ int nextAvailableTag() const; }; template std::ostream& operator<<(std::ostream& stream, Hypergraph const& graph); /** * @brief Converts the alternative currently held by the variant `el` to a string by * calling `toString` on it through std::visit. */ template std::string variantToString(Cls const& el) { return std::visit([](auto const& v) { return toString(v); }, el); } /** * Yields nodes connected to start: * * - In direction `Dir` * - Connected to `start` by edges satisfying `edgePredicate` and nodes satisfying * `nodePredicate` * - The nodes that are yielded must satisfy destNodePredicate. * * @param graph A graph * @param start The starting node * @param nodePredicate Only traverse nodes that satisfy this predicate. * @param edgePredicate Only traverse edges that satisfy this predicate. * @param destNodePredicate Only yield nodes that satisfy this predicate. 
*/ template Generator reachableNodes(Graph::Hypergraph const& graph, int start, auto nodePredicate, auto edgePredicate, auto destNodePredicate); /** * @brief A Hypergraph is a graph where an edge can connect to any number of incoming and * outgoing nodes. A "calm" graph is an ordinary graph where an edge connects to exactly one * incoming node and exactly one outgoing node. * * Satisfied when `T::IsHyper` is false. * * NOTE(review): the Hypergraph class documentation describes a calm graph as "at most one" * incoming and outgoing node rather than "exactly one" -- confirm which is intended. * * @tparam T The graph to test */ template concept CCalmGraph = !T::IsHyper; } } #include