MNIST
Feedforward Fully Connected SNN
This tutorial shows how to train a simple feedforward SNN and deploy it on HiAER Spike using our conversion pipeline.
Define a Feedforward SNN
To build a simple feedforward spiking neural network with PyTorch, we can use snnTorch, SpikingJelly, or other deep learning frameworks built on PyTorch. Currently, our conversion pipeline supports snnTorch and SpikingJelly. In this tutorial, we will use SpikingJelly.
Install the PyPI distribution of SpikingJelly
$ pip install spikingjelly
Import necessary libraries from SpikingJelly and PyTorch
from spikingjelly.activation_based import neuron, functional, surrogate, layer
import torch
import torch.nn as nn
Model Architecture
Using SpikingJelly, we can define a simple 2-layer feedforward SNN with 1000 hidden neurons. The PyTorch linear layers act as the synapses between the spiking neuron layers.
#### Surrogate Function
SpikingJelly and snnTorch both use backpropagation through time to train spiking neural networks. However, because spikes are non-differentiable, a surrogate gradient (here `surrogate.ATan()`) replaces the derivative of the Heaviside spike function in the backward pass.
class model(nn.Module):
    """Two-layer fully connected spiking network for 28x28 single-channel images.

    Each bias-free ``nn.Linear`` layer acts as the synapses feeding a layer of
    LIF neurons (trained with an ATan surrogate gradient). The output layer
    has 10 neurons, one per MNIST class.

    NOTE: the conversion step later in this tutorial selects synapse layers by
    index (input_layer = 1, output_layer = 3), which presumably refers to
    linear1 / linear2 in the registration order below — keep this order.

    Args:
        features: Number of hidden LIF neurons (default 1000).
    """

    def __init__(self, features=1000):
        super().__init__()
        self.flat = nn.Flatten()
        self.linear1 = nn.Linear(28 * 28, features, bias=False)
        self.lif1 = neuron.LIFNode(surrogate_function=surrogate.ATan())
        self.linear2 = nn.Linear(features, 10, bias=False)
        self.lif2 = neuron.LIFNode(surrogate_function=surrogate.ATan())

    def forward(self, x):
        """Run one simulation time step and return the output-layer activity.

        ``x`` is flattened from dim 1 onward, so each sample must contain
        28 * 28 values. Neuron state carries over between calls until the
        network is reset (the training loop calls ``functional.reset_net``).
        """
        x = self.flat(x)
        x = self.linear1(x)
        x = self.lif1(x)
        x = self.linear2(x)
        return self.lif2(x)
# Instantiate the network
net = model()
Setting up the MNIST Dataset
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Download MNIST via torchvision (train and test splits share one transform)
transform = transforms.Compose([transforms.ToTensor()])
mnist_train = datasets.MNIST('data/mnist', train=True, download=True, transform=transform)
mnist_test = datasets.MNIST('data/mnist', train=False, download=True, transform=transform)

# Create DataLoaders. The training loader shuffles and drops the last partial
# batch; the test loader keeps a fixed order and every sample so that the
# reported test metrics cover the whole test set.
train_loader = DataLoader(mnist_train, batch_size=128, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=128, shuffle=False, drop_last=False)
Training the SNN
Since we are using a static image dataset, we first encode each image into spikes using SpikingJelly's rate encoder (`encoding.PoissonEncoder`). With rate encoding, each input feature determines the firing frequency of its spike train, and the output neuron that fires the most is selected as the predicted class.
from spikingjelly.activation_based import encoding
import time
from tqdm import tqdm
# Rate encoder for the static images and the number of simulation time steps
# each image is presented for
encoder = encoding.PoissonEncoder()
num_steps = 20

# Training parameters
epochs = 20
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Copy the network to the selected device
net.to(device)

# Optimizer, a cosine learning-rate schedule spanning all epochs, and an MSE
# loss (applied between output firing rates and one-hot targets in the loop)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
loss_fun = torch.nn.MSELoss()
def _output_firing_rate(img):
    """Return ``net``'s mean output spike rate for ``img`` over ``num_steps`` steps.

    A fresh encoding is drawn from ``encoder`` at every step. The caller must
    reset the network state (``functional.reset_net``) afterwards.
    """
    out_fr = 0.
    for _ in range(num_steps):
        out_fr += net(encoder(img))
    return out_fr / num_steps


for epoch in range(epochs):
    start_time = time.time()

    # ---- Training pass ----
    net.train()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    for img, label in train_loader:
        optimizer.zero_grad()
        img = img.to(device)
        label = label.to(device)
        label_onehot = torch.nn.functional.one_hot(label, 10).float()

        out_fr = _output_firing_rate(img)
        loss = loss_fun(out_fr, label_onehot)
        loss.backward()
        optimizer.step()

        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        # Predicted class = output neuron with the highest firing rate
        train_acc += (out_fr.argmax(1) == label).float().sum().item()

        # Reset the membrane potentials before the next batch
        functional.reset_net(net)

    train_time = time.time()
    train_speed = train_samples / (train_time - start_time)
    train_loss /= train_samples
    train_acc /= train_samples

    lr_scheduler.step()

    # ---- Evaluation pass ----
    net.eval()
    test_loss = 0
    test_acc = 0
    test_samples = 0
    with torch.no_grad():
        for img, label in test_loader:
            img = img.to(device)
            label = label.to(device)
            label_onehot = torch.nn.functional.one_hot(label, 10).float()

            out_fr = _output_firing_rate(img)
            loss = loss_fun(out_fr, label_onehot)

            test_samples += label.numel()
            test_loss += loss.item() * label.numel()
            test_acc += (out_fr.argmax(1) == label).float().sum().item()

            functional.reset_net(net)

    test_time = time.time()
    # Timed from the end of the training pass, so this covers evaluation only
    test_speed = test_samples / (test_time - train_time)
    test_loss /= test_samples
    test_acc /= test_samples

    print(f'epoch = {epoch}, train_loss ={train_loss: .4f}, train_acc ={train_acc: .4f}, test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}')
    print(f'train speed ={train_speed: .4f} images/s, test speed ={test_speed: .4f} images/s')
Converting the trained SNN to HiAER Spike Format
from hs_api.converter import CRI_Converter, Quantize_Network, BN_Folder
from hs_api.api import CRI_network
# import hs_bridge #Uncomment when running on FPGA
# Fold the BN layers (the model defined above has none)
bn = BN_Folder()
net_bn = bn.fold(net)

# Weight and bias quantization
qn = Quantize_Network()
net_quan = qn.quantize(net_bn)

# Set the parameters for conversion
input_layer = 1   # first PyTorch layer that acts as synapses (presumably net.linear1)
output_layer = 3  # last PyTorch layer that acts as synapses (presumably net.linear2)
input_shape = (1, 28, 28)
backend = 'spikingjelly'
v_threshold = qn.v_threshold

# Build the converter used below (cri_convert.axon_dict, .neuron_dict,
# .output_neurons, .input_converter, .run_CRI_sw / .run_CRI_hw).
# NOTE(review): the constructor keywords and the layer_converter() call are
# not shown anywhere else in this tutorial — confirm them against
# hs_api.converter.CRI_Converter.
cri_convert = CRI_Converter(num_steps=num_steps,
                            input_layer=input_layer,
                            output_layer=output_layer,
                            input_shape=input_shape,
                            backend=backend,
                            v_threshold=v_threshold)
cri_convert.layer_converter(net_quan)
Initiate the HiAER Spike SNN
# Neuron configuration shared by the hardware and software networks
config = {}
config['neuron_type'] = "I&F"
config['global_neuron_params'] = {}
# Firing threshold taken from the quantizer (qn) created during conversion
config['global_neuron_params']['v_thr'] = int(qn.v_threshold)

# Create a network running on the FPGA
# NOTE(review): presumably only usable on an FPGA host (cf. the commented-out
# `import hs_bridge` above) — confirm before running this cell elsewhere.
hardwareNetwork = CRI_network(dict(cri_convert.axon_dict),
                              connections=dict(cri_convert.neuron_dict),
                              config=config,
                              target='CRI',
                              outputs=cri_convert.output_neurons,
                              coreID=1)

# Create a network running on the software simulation
softwareNetwork = CRI_network(dict(cri_convert.axon_dict),
                              connections=dict(cri_convert.neuron_dict),
                              config=config,
                              target='simpleSim',
                              outputs=cri_convert.output_neurons,
                              coreID=1)
Deploying the SNN on HiAER Spike
Using the run_CRI_hw and run_CRI_sw methods from the CRI_Converter class, we can deploy the converted SNN on the HiAER Spike FPGA or run it in the software simulation (set RUN_HARDWARE below to choose).
# Index of the first output neuron; presumably used by the converter to map
# output spikes back to class indices — confirm in hs_api.
cri_convert.bias_start_idx = int(cri_convert.output_neurons[0])
loss_fun = nn.MSELoss()
start_time = time.time()
test_loss = 0
test_acc = 0
test_samples = 0
num_batches = 0

RUN_HARDWARE = False  # Set to True if running on FPGA

for img, label in tqdm(test_loader):
    cri_input = cri_convert.input_converter(img)
    if RUN_HARDWARE:
        predictions = cri_convert.run_CRI_hw(cri_input, hardwareNetwork)
    else:
        predictions = cri_convert.run_CRI_sw(cri_input, softwareNetwork)

    # One predicted class index per image (it is compared element-wise with
    # `label` below).
    output = torch.as_tensor(predictions, dtype=torch.long)

    # An MSE between raw class indices is meaningless (and mixes float/int
    # dtypes), so compare one-hot encodings, matching the training targets.
    loss = loss_fun(torch.nn.functional.one_hot(output, 10).float(),
                    torch.nn.functional.one_hot(label, 10).float())

    test_samples += label.numel()
    test_loss += loss.item() * label.numel()
    test_acc += (output == label).float().sum().item()
    num_batches += 1

test_time = time.time()
test_speed = test_samples / (test_time - start_time)
test_loss /= test_samples
test_acc /= test_samples
print(f'test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}')
print(f'test speed ={test_speed: .4f} images/s')