GraphShardManager
- class GraphShardManager(graph_shards: List[GraphShard], local_src_seeds: Tensor, local_tgt_seeds: Tensor)
Manages the local graph partition and exposes a subset of the interface of dgl.heterograph.DGLGraph. Most importantly, it implements distributed versions of the update_all and apply_edges functions, which are used extensively by GNN layers to exchange messages. By default, both update_all and apply_edges use sequential aggregation and rematerialization (SAR) to minimize data duplication across the workers. In some cases, this might introduce extra communication and computation overhead. SAR can be disabled by setting Config.disable_sr to True to avoid this overhead; memory consumption may jump up significantly, however. You should not construct GraphShardManager directly; use sar.construct_mfgs() and sar.construct_full_graph() instead.
- Parameters:
graph_shards (List[GraphShard]) – List of N graph shards where N is the number of partitions/workers. graph_shards[i] contains information for edges originating from partition i
local_src_seeds (torch.Tensor) – The indices of the input nodes relative to the starting node index of the local partition. The input nodes are the nodes needed to produce the output node features assuming one-hop aggregation.
local_tgt_seeds (torch.Tensor) – The node indices of the output nodes relative to the starting node index of the local partition
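The memory benefit of sequential aggregation and rematerialization can be illustrated with a small pure-Python sketch. The shard layout and the reduce function below are illustrative assumptions, not the library's internal API; the point is only the aggregation order: SAR fetches one remote shard at a time, folds it into the running result, and frees it, whereas one-shot aggregation must hold features from all partitions at once.

```python
# Illustrative sketch of sequential aggregation (SAR) vs. one-shot
# aggregation. The shard layout is a made-up stand-in for the real
# GraphShard data; only the aggregation order is the point.

def sar_style_sum(remote_shards):
    """Aggregate one shard at a time; peak memory ~= one shard."""
    total = 0.0
    for fetch_shard in remote_shards:
        shard = fetch_shard()   # communicate: bring in one shard
        total += sum(shard)     # aggregate immediately
        del shard               # free now; rematerialize later if needed
    return total

def one_shot_sum(remote_shards):
    """Fetch all shards first; peak memory ~= all shards combined."""
    all_feats = [fetch_shard() for fetch_shard in remote_shards]
    return sum(sum(shard) for shard in all_feats)

# Three "partitions", each providing a feature shard on demand.
shards = [lambda: [1.0, 2.0], lambda: [3.0], lambda: [4.0, 5.0]]
print(sar_style_sum(shards))  # 15.0
print(one_shot_sum(shards))   # 15.0
```

Both strategies produce the same result; SAR trades the extra communication/recomputation mentioned above for a much smaller peak memory footprint.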
- get_full_partition_graph(delete_shard_data: bool = True) → DistributedBlock
Returns a graph representing all the edges incoming to this partition. The update_all and apply_edges functions in this graph will execute one-shot communication and aggregation in the forward and backward passes.
- Parameters:
delete_shard_data (bool) – Delete the shard data stored in the GraphShardManager object. You almost always want this, as you will not be using the GraphShardManager object after obtaining the full partition graph.
- Returns:
A graph-like object representing all the incoming edges to nodes in the local partition
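The hand-off implied by delete_shard_data can be sketched in a few lines of plain Python. The class and attribute names here are illustrative assumptions, not the real implementation; the sketch only shows why releasing the shard data is the sensible default: after the merged edge data is built, keeping the per-shard copy alive would double memory for an object you no longer use.

```python
# Sketch of the ownership hand-off implied by delete_shard_data. After
# the full partition graph is built, the manager releases its shard
# data so only one copy of the edge data stays alive. Names are
# illustrative, not the real GraphShardManager internals.

class ShardManagerSketch:
    def __init__(self, shards):
        # One edge list per source partition, as in graph_shards above.
        self.shards = shards

    def get_full_partition_graph(self, delete_shard_data=True):
        # Merge per-source-partition edge lists into one edge list.
        merged = [edge for shard in self.shards for edge in shard]
        if delete_shard_data:
            self.shards = None  # manager no longer holds the edge data
        return merged

mgr = ShardManagerSketch([[(0, 1)], [(2, 1), (3, 0)]])
full = mgr.get_full_partition_graph()
print(full)        # [(0, 1), (2, 1), (3, 0)]
print(mgr.shards)  # None
```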
- property sampling_graph
Returns a non-compacted graph for sampling. The node ids in the returned graph are the same as their global ids. Message passing on the sampling_graph is very memory-inefficient, as the node feature tensor must include a feature vector for every node in the whole graph.
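The memory cost of keeping global ids can be made concrete with a small sketch. The relabeling scheme below is an illustrative assumption rather than the library's actual compaction logic; it only shows how the size of a dense feature tensor scales with the id space it must cover.

```python
# Sketch contrasting compacted (local) node ids with the global ids
# used by the sampling graph. The mapping is illustrative; the point
# is how many feature rows a dense, id-indexed tensor would need.

# Global ids of the nodes this partition actually touches (made up).
touched_global_ids = [5, 900, 74210]

# Compacted graph: relabel touched nodes to 0..N-1, so a feature
# tensor needs only N rows.
compact_of = {g: i for i, g in enumerate(touched_global_ids)}
rows_compacted = len(compact_of)           # 3 feature rows

# Sampling graph keeps global ids: a dense feature tensor indexed by
# global id needs max_id + 1 rows, even for untouched nodes.
rows_global = max(touched_global_ids) + 1  # 74211 feature rows

print(rows_compacted, rows_global)  # 3 74211
```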