GraphShardManager

class GraphShardManager(graph_shards: List[GraphShard], local_src_seeds: Tensor, local_tgt_seeds: Tensor)

Manages the local graph partition and exposes a subset of the interface of dgl.heterograph.DGLGraph. Most importantly, it implements distributed versions of the update_all and apply_edges functions, which GNN layers use extensively to exchange messages. By default, both update_all and apply_edges use sequential aggregation and rematerialization (SAR) to minimize data duplication across the workers. In some cases, this may introduce extra communication and computation overhead. SAR can be disabled by setting Config.disable_sr to True to avoid this overhead; memory consumption may jump up significantly, however. You should not construct GraphShardManager directly; use sar.construct_mfgs() and sar.construct_full_graph() instead.

Parameters:
  • graph_shards (List[GraphShard]) – List of N graph shards, where N is the number of partitions/workers. graph_shards[i] contains the information for edges originating from partition i

  • local_src_seeds (torch.Tensor) – The indices of the input nodes relative to the starting node index of the local partition. The input nodes are the nodes needed to produce the output node features, assuming one-hop aggregation

  • local_tgt_seeds (torch.Tensor) – The node indices of the output nodes relative to the starting node index of the local partition
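
For illustration, here is a minimal sketch of how a GraphShardManager is typically obtained and how SAR can be disabled. The partition file path, rank, and device are placeholders, and the exact call to sar.load_dgl_partition_data is an assumption about the surrounding API:

    import sar

    # Load this worker's partition data (path, rank, and device are placeholders)
    partition_data = sar.load_dgl_partition_data(
        'partition_data/graph.json', rank, device)

    # construct_full_graph returns a GraphShardManager; the constructor
    # should not be called directly
    shard_manager = sar.construct_full_graph(partition_data)

    # Optionally trade higher memory consumption for lower communication
    # and computation overhead by disabling SAR
    sar.Config.disable_sr = True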

get_full_partition_graph(delete_shard_data: bool = True) → DistributedBlock

Returns a graph representing all the edges incoming to this partition. The update_all and apply_edges functions in this graph will execute one-shot communication and aggregation in the forward and backward passes.

Parameters:

delete_shard_data (bool) – If True, delete the shard information (the graph data) in the GraphShardManager object. You almost always want this, as you will not be using the GraphShardManager object after obtaining the full partition graph

Returns:

A graph-like object representing all the incoming edges to nodes in the local partition
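
A minimal usage sketch, assuming shard_manager was created by sar.construct_full_graph and gnn_layer is any DGL-style layer that calls update_all internally (both names are placeholders):

    # Obtain the one-shot full partition graph; shard data is deleted
    # by default, so shard_manager should not be used afterwards
    full_graph = shard_manager.get_full_partition_graph()

    # The returned object can stand in for a DGL graph in message passing
    logits = gnn_layer(full_graph, node_features)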

property sampling_graph

Returns a non-compacted graph for sampling. The node ids in the returned graph are the global node ids. Message passing on the sampling_graph would be very memory-inefficient, as the node feature tensor would have to include a feature vector for every node in the whole graph.
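
A minimal sketch of the intended use, driving a standard DGL neighbor sampler rather than message passing. The fanouts, batch size, and seed_nodes are placeholders, and seed_nodes must be global node ids:

    import dgl

    sampling_graph = shard_manager.sampling_graph

    # Sample mini-batch blocks using global node ids as seeds
    sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10])
    dataloader = dgl.dataloading.DataLoader(
        sampling_graph, seed_nodes, sampler, batch_size=1024, shuffle=True)

    for input_nodes, output_nodes, blocks in dataloader:
        # Train on the sampled blocks here
        pass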