class DistributedBN1D(n_feats: int, eps: float = 1e-05, affine: bool = True, distributed: bool | None = None)

Distributed batch normalization layer

Normalizes a 2D feature tensor using the global mean and standard deviation calculated across all workers.

  • n_feats (int) – Size of the second (feature) dimension of the 2D input tensor

  • eps (float) – A value added to the variance for numerical stability

  • affine (bool) – When True, the module will use learnable affine parameters

  • distributed (Optional[bool]) – Boolean specifying whether to run in distributed mode, where the normalizing statistics are calculated across all workers, or in local mode, where they are calculated using only the local input feature tensor. If not specified, it is set to True if the user has called sar.initialize_comms(), and False otherwise
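
The idea of normalizing with statistics gathered across all workers can be illustrated with a small pure-Python sketch. This is not the library's implementation: the helper names (local_stats, global_mean_var, normalize) are hypothetical, the workers are simulated as plain lists, the combining step stands in for whatever collective communication the library actually uses, and the learnable affine parameters are omitted.

```python
import math

def local_stats(rows):
    """Per-worker partial sums: row count, per-feature sum, per-feature sum of squares."""
    n_feats = len(rows[0])
    sums = [0.0] * n_feats
    sq_sums = [0.0] * n_feats
    for row in rows:
        for j, x in enumerate(row):
            sums[j] += x
            sq_sums[j] += x * x
    return len(rows), sums, sq_sums

def global_mean_var(partials):
    """Combine the partial sums of all workers (a stand-in for an all-reduce)."""
    total = sum(n for n, _, _ in partials)
    n_feats = len(partials[0][1])
    mean = [sum(p[1][j] for p in partials) / total for j in range(n_feats)]
    # Biased variance via E[x^2] - E[x]^2
    var = [sum(p[2][j] for p in partials) / total - mean[j] ** 2
           for j in range(n_feats)]
    return mean, var

def normalize(rows, mean, var, eps=1e-5):
    """Normalize a local shard with the shared (global) mean and variance."""
    return [[(x - mean[j]) / math.sqrt(var[j] + eps) for j, x in enumerate(row)]
            for row in rows]

# Two simulated workers, each holding a shard of a 2D (rows x n_feats) input
worker_inputs = [
    [[1.0, 10.0], [2.0, 20.0]],
    [[3.0, 30.0], [4.0, 40.0], [5.0, 50.0]],
]
partials = [local_stats(shard) for shard in worker_inputs]
mean, var = global_mean_var(partials)
outputs = [normalize(shard, mean, var) for shard in worker_inputs]
```

Because every worker normalizes with the same mean and variance, the result matches what a single worker would compute on the concatenated input. In local mode, each shard would instead be normalized with its own statistics.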


forward implementation of DistributedBN1D