DistNeighborSampler

class DistNeighborSampler(fanouts: List[int], prob: str | None = None, replace: bool | None = False, copy_edata: bool | None = True, input_node_features: Dict[str, Tensor] | None = None, output_node_features: Dict[str, Tensor] | None = None, output_device: device | None = None)

A neighbor sampler that performs multi-layer sampling on a distributed graph.

Parameters:
  • fanouts (List[int]) – A list of per-layer fanouts, where the i-th entry is the number of neighbors to sample for each node at layer i. The length of the list should match the number of layers in the GNN model.

  • prob (Optional[str]) – Name of the edge feature used as the (unnormalized) sampling probability of each neighboring edge of a node. The feature must contain exactly one value per edge.

  • replace (bool) – If True, sample with replacement.

  • copy_edata (bool) – If True, the edge features of the new graph are copied from the original graph. If False, the new graph will not have any edge features.

  • input_node_features (Optional[Dict[str, Tensor]]) – An optional dictionary of node features to add to the srcdata of the sampled block closest to the input. The first dimension of each feature tensor must equal the number of nodes in the local partition. If not specified, the sampled blocks will not have any input features.

  • output_node_features (Optional[Dict[str, Tensor]]) – An optional dictionary of node features to add to the dstdata of the top (output-side) sampled block. The first dimension of each feature tensor must equal the number of nodes in the local partition. In a node classification setting, this is typically the node labels. If not specified, the sampled blocks will not have any output features.

  • output_device (Optional[device]) – The device on which the sampled blocks are returned.
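The per-node behavior described by `fanouts`, `prob`, and `replace` can be sketched in plain Python. This is a minimal illustration, not this library's implementation: `sample_neighbors` is a hypothetical helper, the neighbor list stands in for one node's incoming edges in the local partition, and `weights` stands in for the edge feature named by `prob`. It assumes DGL-style semantics, where sampling without replacement keeps every neighbor of a node whose degree is below the fanout and zero-probability edges are never drawn.

```python
import random
from typing import List, Optional, Sequence


def sample_neighbors(
    neighbors: Sequence[int],
    fanout: int,
    weights: Optional[Sequence[float]] = None,
    replace: bool = False,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Hypothetical helper: draw up to `fanout` neighbors of one node.

    `weights` stands in for the edge feature named by `prob`: one
    unnormalized, non-negative value per edge, aligned with `neighbors`.
    """
    rng = rng or random.Random()
    if weights is None:
        weights = [1.0] * len(neighbors)  # uniform sampling when prob is None
    # Edges with zero weight can never be drawn (assumed, as in DGL).
    pool = [(n, w) for n, w in zip(neighbors, weights) if w > 0]
    if not pool:
        return []
    if replace:
        # With replacement: exactly `fanout` independent weighted draws.
        candidates, ws = zip(*pool)
        return rng.choices(candidates, weights=ws, k=fanout)
    if fanout >= len(pool):
        # Without replacement, a low-degree node keeps all of its neighbors.
        return [n for n, _ in pool]
    picked: List[int] = []
    for _ in range(fanout):
        # Weighted draw of one remaining edge, then remove it from the pool.
        i = rng.choices(range(len(pool)), weights=[w for _, w in pool])[0]
        picked.append(pool.pop(i)[0])
    return picked
```

Drawing one edge at a time and renormalizing over the remaining pool is the simplest correct form of weighted sampling without replacement; a real sampler would vectorize this over all seed nodes at once.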

sample(full_graph_manager: GraphShardManager, seeds: Tensor) → List[DGLBlock]

Performs distributed neighbor sampling starting from the given seed nodes.

Parameters:
  • full_graph_manager (GraphShardManager) – The distributed graph from which to sample.

  • seeds (Tensor) – The seed nodes from which sampling starts.

Returns:

A list of DGLBlock objects whose length equals the length of fanouts.

Return type:

List[DGLBlock]
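The contract of `sample` (one block per fanout, with input features on the block nearest the input and output features on the top block) can be illustrated with a single-process sketch. Everything here is hypothetical: `sample_blocks` is not part of the library, plain dicts replace `DGLBlock`, and a local adjacency dict replaces `GraphShardManager`, so no cross-partition communication takes place. The block ordering (`blocks[0]` nearest the input, `blocks[i]` sampled with `fanouts[i]`) follows the convention of DGL's `NeighborSampler` and is assumed, not confirmed, to match this class.

```python
import random
from typing import Any, Dict, List, Optional, Sequence

# Hypothetical stand-in for a DGLBlock: a dict of node lists, edges, and features.
Block = Dict[str, Any]


def sample_blocks(
    in_neighbors: Dict[int, List[int]],
    seeds: Sequence[int],
    fanouts: List[int],
    input_node_features: Optional[Dict[str, list]] = None,
    output_node_features: Optional[Dict[str, list]] = None,
    rng: Optional[random.Random] = None,
) -> List[Block]:
    """Multi-layer neighbor sampling on one in-memory graph (illustration only).

    Returns one block per entry of `fanouts`, ordered from the input layer
    (blocks[0]) to the output layer (blocks[-1]); blocks[i] uses fanouts[i].
    """
    rng = rng or random.Random()
    blocks: List[Block] = []
    dst = list(seeds)
    # Start at the seeds (output layer) and expand toward the input layer.
    for fanout in reversed(fanouts):
        edges = []
        for v in dst:
            nbrs = in_neighbors.get(v, [])
            # Uniform sampling without replacement, capped at the node's degree.
            for u in rng.sample(nbrs, min(fanout, len(nbrs))):
                edges.append((u, v))
        # As in DGL blocks, destination nodes come first among the source nodes.
        new_src = sorted({u for u, _ in edges} - set(dst))
        src = dst + new_src
        blocks.insert(0, {"src": src, "dst": dst, "edges": edges,
                          "srcdata": {}, "dstdata": {}})
        dst = src
    # Input features go on the source side of the block closest to the input.
    for name, values in (input_node_features or {}).items():
        blocks[0]["srcdata"][name] = [values[n] for n in blocks[0]["src"]]
    # Output features (e.g. labels) go on the destination side of the top block.
    for name, values in (output_node_features or {}).items():
        blocks[-1]["dstdata"][name] = [values[n] for n in blocks[-1]["dst"]]
    return blocks
```

Because each layer's source nodes become the next layer's destination nodes, features only need to be gathered once at the input side, and labels only for the seeds at the output side.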