DistNeighborSampler
- class DistNeighborSampler(fanouts: List[int], prob: str | None = None, replace: bool | None = False, copy_edata: bool | None = True, input_node_features: Dict[str, Tensor] | None = None, output_node_features: Dict[str, Tensor] | None = None, output_device: device | None = None)
A neighbor sampler that performs multi-layer sampling on a distributed graph.
- Parameters:
fanouts (List[int]) – A list of fanouts, where the i-th entry is the number of neighbors sampled for each node at layer i. The length of the list should match the number of layers in the GNN model.
prob (Optional[str]) – Name of the edge feature used as the (unnormalized) sampling probability of each neighboring edge of a node. The feature must have exactly one element per edge.
replace (bool) – If True, sample with replacement.
copy_edata (bool) – If True, the edge features of the new graph are copied from the original graph. If False, the new graph will not have any edge features.
input_node_features (Optional[Dict[str, Tensor]]) – An optional dictionary of node features to add to the srcdata of the sampled block closest to the input. Each feature tensor's first dimension must equal the number of nodes in the local partition. If not specified, the sampled blocks will not have any input features.
output_node_features (Optional[Dict[str, Tensor]]) – An optional dictionary of node features to add to the dstdata of the top (output) sampled block. Each feature tensor's first dimension must equal the number of nodes in the local partition. In a node classification setting, these are typically the node labels. If not specified, the sampled blocks will not have any output features.
output_device (Optional[torch.device]) – The output device for the sampled blocks.
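To make the roles of fanouts, replace and prob concrete, here is a minimal single-layer sketch on a toy in-memory graph. It is not SAR's distributed implementation. The adjacency dict, the per-edge weight lists that stand in for the prob edge feature, and the function name sample_neighbors are all hypothetical. Treating zero-weight edges as never drawable is also an assumption of this sketch.

```python
import random
from typing import Dict, List, Optional, Sequence


def sample_neighbors(
    adj: Dict[int, List[int]],
    nodes: Sequence[int],
    fanout: int,
    replace: bool = False,
    edge_weights: Optional[Dict[int, List[float]]] = None,
    rng: Optional[random.Random] = None,
) -> Dict[int, List[int]]:
    """Sample up to `fanout` in-neighbors of each node in `nodes` (one layer).

    adj[v] lists the source nodes of v's incoming edges. edge_weights[v], if
    given, holds one unnormalized weight per entry of adj[v] and plays the
    role of the `prob` edge feature (one element per edge).
    """
    rng = rng or random.Random()
    sampled: Dict[int, List[int]] = {}
    for v in nodes:
        nbrs = adj.get(v, [])
        weights = edge_weights.get(v) if edge_weights else None
        if weights is None:
            # No `prob` feature: every neighboring edge is equally likely.
            weights = [1.0] * len(nbrs)
        # Assumption of this sketch: zero-weight edges can never be drawn.
        cands = [(u, w) for u, w in zip(nbrs, weights) if w > 0]
        if not cands:
            sampled[v] = []
        elif replace:
            # With replacement: always draw exactly `fanout` neighbors,
            # possibly repeating an edge.
            sampled[v] = rng.choices(
                [u for u, _ in cands], weights=[w for _, w in cands], k=fanout
            )
        else:
            # Without replacement: draw one edge at a time and remove it from
            # the pool, so a node keeps all its neighbors if it has <= fanout.
            pool = list(cands)
            picks: List[int] = []
            while pool and len(picks) < fanout:
                idx = rng.choices(range(len(pool)), weights=[w for _, w in pool], k=1)[0]
                picks.append(pool.pop(idx)[0])
            sampled[v] = picks
    return sampled
```

Without replacement, a node with fewer neighbors than its fanout simply keeps all of them. With replacement, the sketch always returns exactly fanout samples, so the same edge may appear more than once.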
- sample(full_graph_manager: GraphShardManager, seeds: Tensor) → List[DGLBlock]
Performs distributed multi-layer neighbor sampling, starting from the given seed nodes.
- Parameters:
full_graph_manager (GraphShardManager) – The distributed graph from which to sample.
seeds (Tensor) – The seed nodes for sampling.
- Returns:
A list of DGLBlock objects with the same length as fanouts.
- Return type:
List[DGLBlock]
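The following stdlib-only sketch is a rough illustration of what the returned list looks like. It builds one toy block per entry of fanouts, sampling from the seeds toward the input layer, and then reverses the list so that the first block is the one closest to the input. It attaches input_node_features to the srcdata of that first block and output_node_features to the dstdata of the last block, following the parameter descriptions above. ToyBlock and toy_sample are hypothetical stand-ins for DGLBlock and sample. Consuming fanouts in reverse order follows the convention of DGL's NeighborSampler and is assumed, not confirmed, for SAR.

```python
import random
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence, Tuple


@dataclass
class ToyBlock:
    """Hypothetical stand-in for a DGLBlock: edges go from src to dst nodes."""
    src_nodes: List[int]
    dst_nodes: List[int]
    edges: List[Tuple[int, int]]  # (src, dst) pairs, using local node IDs
    srcdata: Dict[str, list] = field(default_factory=dict)
    dstdata: Dict[str, list] = field(default_factory=dict)


def toy_sample(
    adj: Dict[int, List[int]],
    seeds: Sequence[int],
    fanouts: List[int],
    input_node_features: Optional[Dict[str, list]] = None,
    output_node_features: Optional[Dict[str, list]] = None,
    seed: int = 0,
) -> List[ToyBlock]:
    rng = random.Random(seed)
    blocks: List[ToyBlock] = []
    dst = list(dict.fromkeys(seeds))  # de-duplicate seeds, keep their order
    # fanouts[i] belongs to layer i (layer 0 is the input), so sampling starts
    # at the seeds with the last fanout and walks toward the input.
    for fanout in reversed(fanouts):
        edges = []
        for v in dst:
            nbrs = adj.get(v, [])
            for u in rng.sample(nbrs, min(fanout, len(nbrs))):
                edges.append((u, v))
        # As in DGL blocks, the dst nodes are listed first among the src nodes.
        src = list(dict.fromkeys(dst + [u for u, _ in edges]))
        blocks.append(ToyBlock(src_nodes=src, dst_nodes=dst, edges=edges))
        dst = src  # this layer's inputs are the next (lower) layer's outputs
    blocks.reverse()  # blocks[0] is now the block closest to the input

    # Feature rows are indexed by local node ID, so each feature list must
    # have one row for every node in the local partition.
    for name, feat in (input_node_features or {}).items():
        blocks[0].srcdata[name] = [feat[n] for n in blocks[0].src_nodes]
    for name, feat in (output_node_features or {}).items():
        blocks[-1].dstdata[name] = [feat[n] for n in blocks[-1].dst_nodes]
    return blocks
```

For example, calling toy_sample with seeds [0] and fanouts [2, 2] returns two blocks. The dst nodes of the last block are exactly the seeds, and its dstdata carries their labels.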