import torch
from pymonntorch.NetworkCore.Base import TaggableObject
from pymonntorch.utils import is_number
class Behavior(TaggableObject):
    """Base class for behaviors. All behaviors are `TaggableObject`s.

    Attributes:
        tag (str): Tag of the behavior.
        device (str): Device of the behavior. This is overwritten by object's
            device upon calling `initialize`.
        behavior_enabled (bool): Whether the behavior is enabled. The default
            is True.
        init_kwargs (dict): Dictionary of the keyword arguments passed to the
            constructor. Positional arguments are stored under the keys
            ``"arg_0"``, ``"arg_1"``, ...
        used_attr_keys (list): Names of the keys that have been requested
            through `parameter` (one entry per call).
    """

    initialize_on_init = False
    initialize_last = False

    def __init__(self, *args, **kwargs):
        """Constructor of the `Behavior` class.

        Args:
            *args: Positional arguments; each one is stored in `init_kwargs`
                under the key ``"arg_<position>"``.
            **kwargs: Keyword arguments passed to the constructor. The keys
                ``behavior_enabled``, ``tag`` and ``device`` are read here.
        """
        self.init_kwargs = kwargs
        for i, arg in enumerate(args):
            self.init_kwargs["arg_" + str(i)] = arg

        # `parameter` appends to this list, so it has to exist before the
        # first `parameter` call below.
        self.used_attr_keys = []
        self.behavior_enabled = self.parameter("behavior_enabled", True, None)

        super().__init__(
            tag=self.parameter("tag", None, None),
            device=self.parameter("device", None, None),
        )

        self.empty_iteration_function = self.is_empty_iteration_function()
        # The plain list collected so far is re-wrapped once the base-class
        # constructor has run.
        self.used_attr_keys = torch.nn.ParameterList(self.used_attr_keys)

    def initialize(self, object):
        """Sets the variables of the object.

        This method is called by the `Network` class when the object is added
        to the network.

        **Note:** All sub-classes of `Behavior` overriding this method should
        call the super method to ensure everything is placed on the correct
        device.

        Args:
            object (TaggableObject): Object possessing the behavior.
        """
        self.device = object.device

    def forward(self, object):
        """Forward pass of the behavior.

        This method is called by the `Network` class per simulation iteration.
        The body is intentionally empty: `is_empty_iteration_function`
        detects un-overridden (empty) forwards by comparing bytecode.

        Args:
            object (TaggableObject): Object possessing the behavior.
        """
        pass

    def __repr__(self):
        """Returns ``ClassName(key=value,key=value,)`` built from `init_kwargs`."""
        # Every pair keeps its trailing comma, matching the historic format.
        pairs = "".join(
            f"{key!s}={value!s}," for key, value in self.init_kwargs.items()
        )
        return f"{self.__class__.__name__}({pairs})"

    def evaluate_diversity_string(self, ds, object):
        """Evaluates the diversity string describing tensors of an object.

        Three forms are handled:

        * ``"same(<tag>, <attr>)"`` returns attribute ``<attr>`` of
          ``object[<tag>, 0]``.
        * A string containing both ``(`` and ``)`` is forwarded to
          ``object.vector`` (for a ``NeuronGroup``) or ``object.matrix``
          (for a ``SynapseGroup``).
        * Anything else is returned unchanged.

        A ``";plot"`` marker anywhere in the string is removed before
        evaluation and, if the result is a tensor, a 30-bin histogram of it is
        shown with matplotlib.

        Args:
            ds (str): Diversity string describing the tensors of the object.
            object (NetworkObject): The object possessing the behavior.

        Returns:
            torch.tensor: The resulting tensor, or `ds` itself (without the
            ``";plot"`` marker) when it does not describe a function.
        """
        # The slice below assumes the "same(" prefix, so test for a real
        # prefix/suffix instead of a substring match.
        if ds.startswith("same(") and ds.endswith(")"):
            params = ds[5:-1].replace(" ", "").split(",")
            if len(params) == 2:
                return getattr(object[params[0], 0], params[1])

        plot = ";plot" in ds
        if plot:
            ds = ds.replace(";plot", "")

        result = ds
        if "(" in ds and ")" in ds:  # is function
            # Dispatch on the exact class name, as the original code did;
            # subclasses with a different name are not matched.
            object_kind = type(object).__name__
            if object_kind == "NeuronGroup":
                result = object.vector(ds)
            if object_kind == "SynapseGroup":
                result = object.matrix(ds)

        # `torch.Tensor` is the class; `torch.tensor` is only a factory
        # function, so comparing a type against it can never succeed.
        if plot and isinstance(result, torch.Tensor):
            import matplotlib.pyplot as plt

            plt.hist(result.detach().cpu().numpy(), bins=30)
            plt.show()

        return result

    def set_parameters_as_variables(self, object):
        """Set the variables defined in the init of behavior as the variables of the object.

        Every key of `init_kwargs` is resolved through `parameter` (with a
        default of None) and assigned to `object` under the same name.

        Args:
            object (NetworkObject): The object possessing the behavior.
        """
        for key in self.init_kwargs:
            setattr(object, key, self.parameter(key, None, object=object))
            # NOTE(review): looks like a leftover debug print; kept because
            # removing it would change what the program writes to stdout.
            print("init", key)

    def check_unused_attrs(self):
        """Warns about constructor arguments that were never requested.

        For every key in `init_kwargs` that does not appear in
        `used_attr_keys`, a warning listing the valid (requested) keys is
        printed.
        """
        used_keys = list(self.used_attr_keys)
        valid_attrs = ", ".join(f'"{param}"' for param in used_keys)
        for key in self.init_kwargs:
            if key in used_keys:
                continue
            print(
                f'Warning: "{key}" not used in initialize of {self!s} behavior! '
                f'Make sure that "{key}" is spelled correctly and '
                f"parameter({key},...) is called in initialize. "
                f"Valid attributes are: {valid_attrs}."
            )

    def parameter(
        self,
        key,
        default=None,
        object=None,
        do_not_diversify=False,
        search_other_behaviors=False,
        tensor=False,
        required=False,
    ):
        """Gets the value of an attribute.

        Args:
            key (str): Name of the attribute.
            default (any): Default value of the attribute. It is also used when
                the stored value is None, and its type is used to cast string
                values.
            object (NetworkObject): The object possessing the behavior.
            do_not_diversify (bool): If True, string values are not passed to
                `evaluate_diversity_string`. The default is False.
            search_other_behaviors (bool): Whether to search for the attribute
                in other behaviors of the object when this behavior does not
                define it. The default is False.
            tensor (bool): Whether to make a tensor out of value. Suitable for
                list and numbers.
            required (bool): Whether the attribute is required; a warning is
                printed when it is missing or None. The default is False.

        Returns:
            any: The value of the attribute.

        Raises:
            RuntimeError: If `tensor` is True and a non-None value was found
                but `object` is None.
        """
        if required and self.init_kwargs.get(key, None) is None:
            print(
                "Warning:",
                key,
                "has to be specified for the behavior with a non None value to run properly.",
                self,
            )

        # Record the request so `check_unused_attrs` can spot misspelled keys.
        self.used_attr_keys.append(key)

        result = self.init_kwargs.get(key, default)
        result = default if result is None else result

        if (
            key not in self.init_kwargs
            and object is not None
            and search_other_behaviors
        ):
            # No early exit: the last behavior defining the key wins.
            for b in object.behaviors:
                if key in b.init_kwargs:
                    result = b.init_kwargs.get(key, result)

        if not do_not_diversify and isinstance(result, str) and object is not None:
            result = self.evaluate_diversity_string(result, object)

        if isinstance(result, str) and default is not None:
            # "25%" style strings are turned into their fraction ("0.25").
            if "%" in result and is_number(result.replace("%", "")):
                result = str(float(result.replace("%", "")) / 100.0)
            # NOTE: the cast uses the default's constructor, so with a bool
            # default any non-empty string (even "False") becomes True.
            result = type(default)(result)

        if tensor and result is not None:
            if object is None:
                raise RuntimeError(
                    f'To turn parameter value of key "{key}" to a tensor, object should not be None.'
                )
            if isinstance(result, torch.Tensor):
                # Copy-constructing through `torch.tensor` emits a UserWarning;
                # this is the equivalent form PyTorch recommends.
                result = result.clone().detach().to(device=object.device)
            else:
                result = torch.tensor(result, device=object.device)
            if result.is_floating_point():
                result = result.to(dtype=object.def_dtype)

        return result

    def is_empty_iteration_function(self):
        """Checks whether `forward` does anything or not.

        Used to stop calling behaviors with empty forward method. The check
        compares the bytecode and constants of `self.forward` with those of two
        reference functions whose bodies are just ``pass`` (one without and one
        with a docstring).

        Returns:
            bool: True if `forward` is an empty function.
        """
        f = self.forward

        def empty_func():
            pass

        def empty_func_with_docstring():
            """Empty function with docstring."""
            pass

        def constants(f):
            # All constants of the function except the ones equal to its
            # docstring.
            return tuple(x for x in f.__code__.co_consts if x != f.__doc__)

        return (
            f.__code__.co_code == empty_func.__code__.co_code
            and constants(f) == constants(empty_func)
        ) or (
            f.__code__.co_code == empty_func_with_docstring.__code__.co_code
            and constants(f) == constants(empty_func_with_docstring)
        )