MNIST

Feedforward Fully Connected SNN

Open In Colab

This tutorial goes over how to train a simple feedforward SNN and deploy it on HiAER Spike using our conversion pipeline.

Define a Feedforward SNN

To build a simple feedforward spiking neural network with PyTorch, we can use snnTorch, SpikingJelly, or other deep learning frameworks that are based on PyTorch. Currently, our conversion pipeline supports snnTorch and SpikingJelly. In this tutorial, we will be using SpikingJelly.

Install the PyPI distribution of SpikingJelly

$ pip install spikingjelly

Import necessary libraries from SpikingJelly and PyTorch

from spikingjelly.activation_based import neuron, functional, surrogate, layer
import torch 
import torch.nn as nn

Model Architecture

Using SpikingJelly, we can define a simple 2-layer feedforward SNN model with 1000 hidden neurons. The PyTorch linear layers act as the synapses between the spiking neuron layers.

Surrogate Function

SpikingJelly and snnTorch both use backpropagation through time to train spiking neural networks. However, because spikes are non-differentiable, a surrogate gradient is used in place of the gradient of the Heaviside function in the backward pass.

class model(nn.Module):
    """Two-layer fully connected spiking network for 28x28 inputs.

    The input is flattened and passed through two bias-free ``nn.Linear``
    layers, each followed by a ``neuron.LIFNode`` that uses the
    ``surrogate.ATan`` surrogate function. The final layer has 10 outputs.

    Args:
        features: Number of units in the hidden layer.
    """

    def __init__(self, features=1000):
        super().__init__()
        # Keep this registration order: the conversion settings further down
        # the tutorial refer to layers by position (1 and 3), which presumably
        # index the sub-modules in the order they are declared here.
        self.flat = nn.Flatten()
        self.linear1 = nn.Linear(in_features=28 * 28, out_features=features, bias=False)
        self.lif1 = neuron.LIFNode(surrogate_function=surrogate.ATan())
        self.linear2 = nn.Linear(in_features=features, out_features=10, bias=False)
        self.lif2 = neuron.LIFNode(surrogate_function=surrogate.ATan())

    def forward(self, x):
        """Run one step of the network on ``x`` and return the output of ``lif2``."""
        hidden = self.lif1(self.linear1(self.flat(x)))
        return self.lif2(self.linear2(hidden))
# Instantiate the network with the default hidden width (features=1000)
net = model()

Setting up the MNIST Dataset

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Download the MNIST train and test splits into data/mnist (skipped if the
# files are already there); each image is converted to a tensor.
mnist_train = datasets.MNIST('data/mnist', train=True, download=True, transform=transforms.Compose(
    [transforms.ToTensor()]))
mnist_test = datasets.MNIST('data/mnist', train=False, download=True, transform=transforms.Compose(
    [transforms.ToTensor()]))

# Create DataLoaders
# Training: reshuffle every epoch and drop the incomplete final batch.
train_loader = DataLoader(mnist_train, batch_size=128, shuffle=True, drop_last=True)
# Evaluation: keep every sample. The MNIST test split (10,000 images) is not a
# multiple of 128, so drop_last=True would silently exclude the last 16 images
# from every reported test metric. Shuffling is unnecessary for evaluation, so
# it is disabled to keep the iteration order reproducible.
test_loader = DataLoader(mnist_test, batch_size=128, shuffle=False, drop_last=False)

Training the SNN

Since we are using a static image dataset, we will first encode each image into spikes using the rate encoding function from SpikingJelly. With rate encoding, the input feature determines the firing frequency, and the output neuron that fires the most is selected as the predicted class.

from spikingjelly.activation_based import encoding
import time
from tqdm import tqdm
# Poisson (rate) encoder and the number of time steps simulated per batch
encoder = encoding.PoissonEncoder()
num_steps = 20

# Training parameters
epochs = 20
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Copy the network to the device
net.to(device)

# Optimizer, learning-rate scheduler and loss function
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
loss_fun = torch.nn.MSELoss()


def _mean_firing_rate(batch):
    """Return the network output averaged over ``num_steps`` encodings of ``batch``.

    The batch is re-encoded with ``encoder`` at every time step, and the
    network outputs are summed and then divided by ``num_steps``.
    """
    summed = 0.
    for _ in range(num_steps):
        summed += net(encoder(batch))
    return summed / num_steps


for epoch in range(epochs):
    start_time = time.time()

    # ---- training pass ----
    net.train()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    for img, label in train_loader:
        optimizer.zero_grad()
        img, label = img.to(device), label.to(device)
        label_onehot = torch.nn.functional.one_hot(label, 10).float()

        out_fr = _mean_firing_rate(img)
        loss = loss_fun(out_fr, label_onehot)
        loss.backward()
        optimizer.step()

        # Accumulate sample-weighted loss and the number of correct predictions
        n_batch = label.numel()
        train_samples += n_batch
        train_loss += loss.item() * n_batch
        train_acc += (out_fr.argmax(1) == label).float().sum().item()

        # Reset the network state (membrane potential) after each batch
        functional.reset_net(net)

    train_time = time.time()
    train_speed = train_samples / (train_time - start_time)
    train_loss /= train_samples
    train_acc /= train_samples

    lr_scheduler.step()

    # ---- evaluation pass ----
    net.eval()
    test_loss = 0
    test_acc = 0
    test_samples = 0

    with torch.no_grad():
        for img, label in test_loader:
            img, label = img.to(device), label.to(device)
            label_onehot = torch.nn.functional.one_hot(label, 10).float()

            out_fr = _mean_firing_rate(img)
            loss = loss_fun(out_fr, label_onehot)

            n_batch = label.numel()
            test_samples += n_batch
            test_loss += loss.item() * n_batch
            test_acc += (out_fr.argmax(1) == label).float().sum().item()
            functional.reset_net(net)

    test_time = time.time()
    test_speed = test_samples / (test_time - train_time)
    test_loss /= test_samples
    test_acc /= test_samples

    print(f'epoch = {epoch}, train_loss ={train_loss: .4f}, train_acc ={train_acc: .4f}, test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}')
    print(f'train speed ={train_speed: .4f} images/s, test speed ={test_speed: .4f} images/s')

Converting the trained SNN to HiAER Spike Format

from hs_api.converter import CRI_Converter, Quantize_Network, BN_Folder
from hs_api.api import CRI_network
# import hs_bridge #Uncomment when running on FPGA

# Fold the BN layer
# NOTE(review): the model defined above has no batch-norm layers, so this is
# presumably a pass-through here — confirm against BN_Folder.fold.
bn = BN_Folder()
net_bn = bn.fold(net)

# Weight, Bias Quantization
qn = Quantize_Network()
net_quan = qn.quantize(net_bn)

# Set the parameters for conversion
input_layer = 1  # first pytorch layer that acts as synapses
output_layer = 3  # last pytorch layer that acts as synapses
input_shape = (1, 28, 28)
backend = 'spikingjelly'
v_threshold = qn.v_threshold

# Build the converter from the parameters above and convert the quantized
# network. The rest of the tutorial reads cri_convert.axon_dict,
# cri_convert.neuron_dict and cri_convert.output_neurons, so the converter
# must exist (and have processed the network) before the HiAER Spike
# networks are created.
# NOTE(review): keyword names follow the hs_api CRI_Converter constructor as
# documented — confirm against the installed hs_api version.
cri_convert = CRI_Converter(num_steps=num_steps,
                            input_layer=input_layer,
                            output_layer=output_layer,
                            input_shape=input_shape,
                            backend=backend,
                            v_threshold=v_threshold)
cri_convert.layer_converter(net_quan)

Initiate the HiAER Spike SNN

# Network configuration: neuron type plus the global threshold (v_thr), taken
# from the quantizer and cast to int. The quantizer object is named `qn`; the
# previous reference to `quan_fun` was an undefined name.
config = {
    'neuron_type': "I&F",
    'global_neuron_params': {'v_thr': int(qn.v_threshold)},
}

# Create a network running on the FPGA
# NOTE(review): presumably this needs the FPGA environment (see the
# commented-out hs_bridge import) — confirm before running without hardware.
hardwareNetwork = CRI_network(dict(cri_convert.axon_dict),
                              connections=dict(cri_convert.neuron_dict),
                              config=config,
                              target='CRI',
                              outputs=cri_convert.output_neurons,
                              coreID=1)

# Create a network running on the software simulation
softwareNetwork = CRI_network(dict(cri_convert.axon_dict),
                              connections=dict(cri_convert.neuron_dict),
                              config=config,
                              target='simpleSim',
                              outputs=cri_convert.output_neurons,
                              coreID=1)

Deploying the SNN on HiAER Spike

Using the run_CRI_hw and run_CRI_sw methods of the CRI_Converter class, we can deploy the converted SNN on the HiAER Spike platform.

# Set the converter's bias start index to the first output neuron's index
cri_convert.bias_start_idx = int(cri_convert.output_neurons[0])
loss_fun = nn.MSELoss()
start_time = time.time()

# Running totals over the whole test set
test_loss = 0
test_acc = 0
test_samples = 0
num_batches = 0

RUN_HARDWARE = False  # Set to True if running on FPGA

for img, label in tqdm(test_loader):
    cri_input = cri_convert.input_converter(img)

    # Run the batch on the FPGA network or on the software simulation
    if RUN_HARDWARE:
        raw_output = cri_convert.run_CRI_hw(cri_input, hardwareNetwork)
    else:
        raw_output = cri_convert.run_CRI_sw(cri_input, softwareNetwork)
    output = torch.tensor(raw_output, dtype=float)

    loss = loss_fun(output, label)

    # Accumulate sample-weighted loss and the number of outputs equal to the label
    n_batch = label.numel()
    test_samples += n_batch
    test_loss += loss.item() * n_batch
    test_acc += (output == label).float().sum().item()
    num_batches += 1

test_time = time.time()
test_speed = test_samples / (test_time - start_time)
test_loss /= test_samples
test_acc /= test_samples

print(f'test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}')
print(f'test speed ={test_speed: .4f} images/s')