class message_has_parameters(param_foo: Callable[[Any], Tuple[Tensor, ...]])

A decorator for message functions that use learnable parameters.

You must use this decorator to tell SAR which learnable parameters the message function uses, so that those parameters receive the correct gradients. The decorator takes a single argument: a callable that returns a tuple containing the message function's parameters. If the message function is an instance method, the callable receives the instance as its first argument; otherwise, it receives None. Example:

from torch import nn
from sar import message_has_parameters
import dgl.function as fn  # type: ignore

class ParameterizedAggregation(nn.Module):
    def __init__(self, dim):
        # nn.Module.__init__ must run before any submodule is assigned
        super().__init__()
        self.message_transformation = nn.Linear(dim, dim)

    # The callable receives the module instance and returns the parameters used inside message()
    @message_has_parameters(lambda self: tuple(self.message_transformation.parameters()))
    def message(self, edges):
        m = self.message_transformation(edges.src['h'])
        return {'m': m}

    def forward(self, graph, features):
        with graph.local_scope():
            graph.srcdata['h'] = features
            # Compute a learnable message per edge, then sum the messages at each destination node
            graph.update_all(self.message, fn.sum('m', 'result'))
            result = graph.dstdata['result']
        return result

Parameters:
param_foo (Callable[[Any], Tuple[Tensor, ...]]) – A callable that returns a tuple of the parameters used by the message function.

Returns:
The decorated message function.
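The contract described above (a callable stored with the message function, invoked with the bound instance or with None) can be illustrated with a minimal, pure-Python sketch. It does not use SAR, PyTorch, or DGL, and it is not SAR's implementation: `message_has_parameters_sketch`, `collect_message_params`, `ToyLinear`, and `ToyAggregation` are hypothetical names, and the sketch assumes the decorator simply attaches the callable to the wrapped function and that a bound method's `__self__` is what gets passed as the instance.

```python
import functools
from typing import Any, Callable, Optional, Tuple


class message_has_parameters_sketch:
    """Hypothetical stand-in for SAR's decorator (not SAR's actual code).

    It only records the parameter callable on the wrapped message function.
    """

    def __init__(self, param_foo: Callable[[Any], Tuple[Any, ...]]):
        self.param_foo = param_foo

    def __call__(self, func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # The message function itself runs unchanged.
            return func(*args, **kwargs)

        # Assumption: the callable is stored as a plain function attribute.
        wrapper.param_foo = self.param_foo
        return wrapper


def collect_message_params(message_func: Callable) -> Tuple[Any, ...]:
    """Hypothetical helper that resolves a message function's parameters.

    A bound method passes its instance to param_foo; a plain function passes None.
    """
    instance: Optional[Any] = getattr(message_func, "__self__", None)
    # Attribute lookup on a bound method falls through to the underlying function.
    param_foo = getattr(message_func, "param_foo", None)
    if param_foo is None:
        return ()  # undecorated: no parameters declared
    return tuple(param_foo(instance))


class ToyLinear:
    """Toy layer standing in for nn.Linear; its 'parameters' are plain lists."""

    def __init__(self):
        self.weight = [1.0, 2.0]
        self.bias = [0.5]

    def parameters(self):
        return iter([self.weight, self.bias])


class ToyAggregation:
    def __init__(self):
        self.message_transformation = ToyLinear()

    # Instance method: param_foo is called with the instance.
    @message_has_parameters_sketch(
        lambda self: tuple(self.message_transformation.parameters()))
    def message(self, edges):
        return {"m": edges}


shared = ToyLinear()


# Plain function: param_foo is called with None, which it ignores.
@message_has_parameters_sketch(lambda _: tuple(shared.parameters()))
def free_message(edges):
    return {"m": edges}


agg = ToyAggregation()
print(len(collect_message_params(agg.message)))  # 2
print(collect_message_params(free_message) == (shared.weight, shared.bias))  # True
```

In this sketch the decorator stores the callable rather than the parameters themselves, so the same decorated method can be resolved lazily against whichever instance it is bound to.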