diff --git a/bindsnet/conversion/conversion.py b/bindsnet/conversion/conversion.py index bf9b367d..dba1b37e 100644 --- a/bindsnet/conversion/conversion.py +++ b/bindsnet/conversion/conversion.py @@ -121,10 +121,10 @@ def set_requires_grad(module, value): extractor2 = FeatureExtractor(module) all_activations2 = extractor2.forward(data) - for name2, module2 in module.named_children(): + for name2, module in module.named_children(): activations = all_activations2[name2] - if isinstance(module2, nn.ReLU): + if isinstance(module, nn.ReLU): if prev_module is not None: scale_factor = np.percentile(activations.cpu(), percentile) @@ -136,7 +136,7 @@ def set_requires_grad(module, value): elif isinstance(module2, nn.Linear) or isinstance(module2, nn.Conv2d): prev_module = module2 - if isinstance(module2, nn.Linear): + if isinstance(module, nn.Linear): if prev_module is not None: scale_factor = np.percentile(activations.cpu(), percentile) prev_module.weight *= prev_factor / scale_factor diff --git a/bindsnet/datasets/collate.py b/bindsnet/datasets/collate.py index 17079730..8aea85f0 100644 --- a/bindsnet/datasets/collate.py +++ b/bindsnet/datasets/collate.py @@ -8,7 +8,8 @@ """ import torch -from torch._six import container_abcs, string_classes, int_classes +from torch._six import string_classes +import collections from torch.utils.data._utils import collate as pytorch_collate diff --git a/bindsnet/pipeline/base_pipeline.py b/bindsnet/pipeline/base_pipeline.py index 5cbb3dda..fda2ad0f 100644 --- a/bindsnet/pipeline/base_pipeline.py +++ b/bindsnet/pipeline/base_pipeline.py @@ -2,7 +2,8 @@ from typing import Tuple, Dict, Any import torch -from torch._six import container_abcs, string_classes +from torch._six import string_classes +import collections from ..network import Network from ..network.monitors import Monitor diff --git a/requirements.txt b/requirements.txt index 63296a19..14fe2432 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ foolbox -scipy>=1.1.0 -numpy>=1.14.2 -cython>=0.28.5 -torch==1.8.1 -torchvision==0.9.1 +scipy>=1.5.4 +numpy>=1.19.5 +cython>=0.29.5 +torch==1.9.0 +torchvision==0.10.0 tensorboardX==2.2 -tqdm>=4.19.9 +tqdm>=4.60.0 setuptools>=39.0.1 matplotlib>=2.1.0 gym>=0.10.4 diff --git a/setup.py b/setup.py index 45c753d5..7860a443 100644 --- a/setup.py +++ b/setup.py @@ -16,20 +16,20 @@ packages=find_packages(), zip_safe=False, install_requires=[ - "numpy>=1.14.2", - "torch==1.8.1", - "torchvision==0.9.1", + "numpy>=1.19.5", + "torch==1.9.0", + "torchvision==0.10.0", "tensorboardX==2.2", - "tqdm>=4.19.9", + "tqdm>=4.60.0", "matplotlib>=2.1.0", "gym>=0.10.4", "scikit-build>=0.11.1", "scikit_image>=0.13.1", "scikit_learn>=0.19.1", "opencv-python>=3.4.0.12", - "pytest>=3.4.0", - "scipy>=1.1.0", - "cython>=0.28.5", + "pytest>=6.2.0", + "scipy>=1.5.4", + "cython>=0.29.0", "pandas>=0.23.4", ], ) diff --git a/test/conversion/test_conversion.py b/test/conversion/test_conversion.py index 5f8653b3..8f5deb07 100644 --- a/test/conversion/test_conversion.py +++ b/test/conversion/test_conversion.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn import torch.nn.functional as F @@ -24,14 +25,20 @@ def forward(self, x): return x -def test_conversion(): +def test_conversion_1(): ann = FullyConnectedNetwork() snn = ann_to_snn(ann, input_shape=(784,)) -def main(): +def test_conversion_2(): + data = torch.rand(784, 20) ann = FullyConnectedNetwork() - return ann_to_snn(ann, input_shape=(28, 28)) + snn = ann_to_snn(ann, data=data, input_shape=(784,)) + + +def main(): + test_conversion_1() + test_conversion_2() if __name__ == "__main__":