# Source code for pymonntorch.NetworkBehavior.Basics.Normalization
import torch
from pymonntorch.NetworkCore.Behavior import Behavior
from pymonntorch.utils import check_is_torch_tensor
class SynapticNormalization(Behavior):
    """Rescale afferent weights so each neuron's incoming total equals ``norm_factor``.

    On every ``forward`` call, the weights of all afferent synapse groups of
    type ``synapse_type`` are summed per postsynaptic neuron (accumulated
    across every such group). Each neuron's incoming weights are then divided
    by ``total / norm_factor``, so afterwards they sum to ``norm_factor``.
    Neurons whose incoming total is exactly zero are left unchanged rather
    than being driven to NaN/inf.

    NOTE(review): assumes ``syn.w`` is laid out as (dst, src). That is what
    summing over ``dim=1`` and dividing ``w.T`` by a per-dst vector implies.
    TODO confirm against the SynapseGroup weight convention.

    Args:
        synapse_type: key into ``neurons.afferent_synapses`` selecting which
            synapse groups are normalized. Default ``"glutamate"``.
        norm_factor: target total incoming weight per neuron; converted to a
            tensor on the group's device/dtype. Default ``1.0``.
    """

    def __init__(self, *args, synapse_type="glutamate", norm_factor=1.0, **kwargs):
        super().__init__(
            *args, synapse_type=synapse_type, norm_factor=norm_factor, **kwargs
        )

    def initialize(self, neurons):
        """Resolve parameters and allocate the per-neuron ``sum_w`` buffer.

        Sets ``neurons.sum_w`` to a vector of ``neurons.def_dtype``. Its
        contents are recomputed from scratch on every ``forward`` call.
        """
        super().initialize(neurons)
        self.synapse_type = self.parameter("synapse_type", "glutamate", neurons)
        # Presumably validates/registers that this synapse type exists on the
        # group -- confirm against NeuronGroup.require_synapses.
        neurons.require_synapses(self.synapse_type)
        self.norm_factor = check_is_torch_tensor(
            self.parameter("norm_factor", 1.0, neurons),
            device=neurons.device,
            dtype=neurons.def_dtype,
        )
        neurons.sum_w = neurons.vector(kwargs={"dtype": neurons.def_dtype})

    def forward(self, neurons):
        """Normalize afferent weights in place (see the class docstring)."""
        # Materialize once: the collection is iterated twice below.
        afferent = list(neurons.afferent_synapses[self.synapse_type])

        # Accumulate total incoming weight per postsynaptic neuron over all groups.
        neurons.sum_w.zero_()
        for syn in afferent:
            syn.dst.sum_w.add_(syn.w.sum(dim=1))
        # Dividing weights by (total / norm_factor) makes the new total norm_factor.
        neurons.sum_w.div_(self.norm_factor)

        for syn in afferent:
            divisor = syn.dst.sum_w
            # A zero total would turn 0/0 into NaN (or x/0 into inf) and corrupt
            # the weights permanently; divide such neurons' weights by 1 instead.
            # neurons.sum_w itself is left as computed.
            divisor = torch.where(divisor == 0, torch.ones_like(divisor), divisor)
            # w.T has the dst axis last, so the per-dst divisor broadcasts over it.
            syn.w.T.div_(divisor)