message_has_parameters

class message_has_parameters(param_foo: Callable[[Any], Tuple[Tensor, ...]])

A decorator for message functions that use learnable parameters.

You must use this decorator to tell SAR which learnable parameters a message function uses, so that these parameters receive the correct gradients. The decorator takes a single argument: a callable that returns a tuple of the message function's parameters. If the message function is an instance method, the callable receives the instance (self) as its argument; otherwise it receives None. Example:

from torch import nn
from sar import message_has_parameters
import dgl.function as fn  # type: ignore


class ParameterizedAggregation(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.message_transformation = nn.Linear(dim, dim)

    @message_has_parameters(lambda self: tuple(self.message_transformation.parameters()))
    def message(self, edges):
        m = self.message_transformation(edges.src['h'])
        return {'m': m}

    def forward(self, graph, features):
        with graph.local_scope():
            graph.srcdata['h'] = features
            graph.update_all(self.message, fn.sum('m', 'result'))
            result = graph.dstdata['result']
        return result
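The following sketch makes the contract described above concrete without depending on SAR, DGL, or PyTorch. It shows how a decorator of this shape could record `param_foo`, and how a framework could then resolve the parameters for two cases: an instance method, where the callable receives the instance, and a plain function, where it receives None. Everything here is hypothetical. `message_has_parameters_sketch` and `collect_message_params` are illustrative names, not part of SAR's API, and the code does not reproduce SAR's actual implementation. Plain strings stand in for learnable tensors.

```python
from typing import Any, Callable, Optional, Tuple


class message_has_parameters_sketch:
    """Hypothetical stand-in for a decorator like sar.message_has_parameters.

    It only illustrates the contract: store `param_foo` and let a framework
    call it later with the owning instance (or None) to get the parameters.
    """

    def __init__(self, param_foo: Callable[[Any], Tuple[Any, ...]]):
        self.param_foo = param_foo

    def __call__(self, func: Callable) -> Callable:
        # Attach the parameter getter to the function object and return it unchanged.
        func.param_foo = self.param_foo
        return func


def collect_message_params(message_func: Callable) -> Tuple[Any, ...]:
    """Hypothetical helper: resolve the parameters of a decorated message function."""
    # A bound method exposes its instance via __self__ and the underlying
    # function via __func__; a plain function has neither attribute.
    instance: Optional[Any] = getattr(message_func, "__self__", None)
    underlying = getattr(message_func, "__func__", message_func)
    param_foo = getattr(underlying, "param_foo", None)
    if param_foo is None:
        return ()  # undecorated function: no parameters registered
    return tuple(param_foo(instance))


class Aggregation:
    def __init__(self):
        self.weight = "w"  # placeholder for a learnable tensor
        self.bias = "b"    # placeholder for a learnable tensor

    # Instance method: the callable receives `self`.
    @message_has_parameters_sketch(lambda self: (self.weight, self.bias))
    def message(self, edges):
        return {"m": edges}


SCALE = "scale"  # placeholder for a module-level learnable tensor


# Plain function: the callable receives None, so it ignores its argument.
@message_has_parameters_sketch(lambda _: (SCALE,))
def scaled_message(edges):
    return {"m": edges}


agg = Aggregation()
print(collect_message_params(agg.message))     # ('w', 'b')
print(collect_message_params(scaled_message))  # ('scale',)
```

The callable is kept separate from the message function because the parameters of an instance method usually exist only once the module has been constructed. Deferring the lookup until an instance (or None) is available lets the same decorator handle both methods and free functions.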

Parameters:

param_foo (Callable[[Any], Tuple[Tensor, ...]]) – A callable that returns a tuple of the parameters used by the message function. It receives the instance for instance methods and None otherwise.

Returns:

The decorated message function