construct_mfgs

construct_mfgs(partition_data: PartitionData, seed_nodes: Tensor, n_layers: int, keep_seed_nodes: bool = True) → List[GraphShardManager]

Constructs a list of GraphShardManager objects, one for each GNN layer, that compute only the node features needed to produce the output features of the seed_nodes at the top layer. This is analogous to the Message Flow Graphs (MFGs) created by DGL's sampling DataLoaders. MFGs are particularly useful in node classification tasks, where they can avoid a large amount of redundant computation compared to running every layer on the full graph: a node that cannot reach any seed node within n_layers hops along incoming edges never influences the output, so its features need not be computed.
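The saving can be illustrated without SAR or DGL. The following is a minimal sketch on a toy in-memory graph, not SAR's implementation: starting from the seed nodes at the top layer, each lower layer needs the in-neighbors of the nodes required one layer up. The function name `required_nodes_per_layer` and the `keep_dst` flag are hypothetical; `keep_dst=True` assumes each node's own feature is also needed, loosely mirroring `keep_seed_nodes`.

```python
from typing import Dict, Iterable, List, Set


def required_nodes_per_layer(
    in_neighbors: Dict[int, List[int]],
    seed_nodes: Iterable[int],
    n_layers: int,
    keep_dst: bool = True,  # hypothetical: also keep each required node's own feature
) -> List[Set[int]]:
    """Return the nodes whose features are needed as input to each layer.

    Index 0 holds the nodes needed at the input layer; index n_layers holds the seeds.
    """
    needed: List[Set[int]] = [set() for _ in range(n_layers + 1)]
    needed[n_layers] = set(seed_nodes)
    # Walk from the top layer down, expanding along incoming edges.
    for layer in range(n_layers, 0, -1):
        frontier: Set[int] = set(needed[layer]) if keep_dst else set()
        for v in needed[layer]:
            frontier.update(in_neighbors.get(v, []))
        needed[layer - 1] = frontier
    return needed


# Toy graph with 7 nodes (0..6): edges 0->1->2->3->4 and 5->4; node 6 is isolated.
graph = {1: [0], 2: [1], 3: [2], 4: [3, 5]}
print(required_nodes_per_layer(graph, seed_nodes=[4], n_layers=2))
# [{2, 3, 4, 5}, {3, 4, 5}, {4}]
```

With two layers and a single seed, at most 4 of the 7 nodes are touched at any layer, whereas full-graph computation would process all 7 nodes at every layer.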

Parameters:
  • partition_data (PartitionData) – The partition data of the local worker.

  • seed_nodes (Tensor) – The global indices of the graph nodes whose features need to be computed at the top layer. Typically, these are the labeled nodes in a node classification task.

  • n_layers (int) – The number of layers in the GNN; one MFG is constructed per layer.

  • keep_seed_nodes (bool) – If True, the seed nodes are also included among the source nodes. Default: True

Returns:

A list of GraphShardManager objects, one for each layer.
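To picture the shape of the returned list, here is a hypothetical pure-Python analogue that builds one bipartite "block" (source nodes, destination nodes, edges) per layer. It is a sketch under stated assumptions, not SAR's API: the name `build_blocks` is invented, the flag is assumed to apply at every layer (similar to DGL's `to_block` with `include_dst_in_src=True`), and the list is assumed to be ordered from the input layer to the top layer, as in DGL's MFG lists.

```python
from typing import Dict, Iterable, List, Set, Tuple

# Hypothetical block representation: (source nodes, destination nodes, (src, dst) edges).
Block = Tuple[Set[int], Set[int], List[Tuple[int, int]]]


def build_blocks(
    in_neighbors: Dict[int, List[int]],
    seed_nodes: Iterable[int],
    n_layers: int,
    keep_seed_nodes: bool = True,  # assumption: destination nodes are kept as sources at every layer
) -> List[Block]:
    blocks: List[Block] = []
    dst: Set[int] = set(seed_nodes)
    for _ in range(n_layers):
        # Only edges that point into the current destination nodes are kept.
        edges = [(u, v) for v in sorted(dst) for u in in_neighbors.get(v, [])]
        src = {u for u, _ in edges}
        if keep_seed_nodes:
            src |= dst
        blocks.append((src, dst, edges))
        dst = src  # this layer's sources are the next lower layer's destinations
    blocks.reverse()  # assumed order: blocks[0] feeds the first GNN layer
    return blocks


graph = {1: [0], 2: [1], 3: [2], 4: [3, 5]}
for src, dst, edges in build_blocks(graph, seed_nodes=[4], n_layers=2):
    print(sorted(src), "->", sorted(dst), edges)
# [2, 3, 4, 5] -> [3, 4, 5] [(2, 3), (3, 4), (5, 4)]
# [3, 4, 5] -> [4] [(3, 4), (5, 4)]
```

Each layer's destination nodes are exactly the source nodes of the block above it, which is what lets the stack of blocks replace the full graph layer by layer.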