Advertisement
aurko96

SNN

Apr 24th, 2024
686
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.73 KB | None | 0 0
  1. from keras.datasets import mnist
  2. from brian2 import *
  3. import brian2.numpy_ as np
  4. import matplotlib.pyplot as plt
  5. from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
  6.  
  7. # Load MNIST dataset
  8. (X_train, y_train), (X_test, y_test) = mnist.load_data()
  9.  
  10. # # Simplified classification (0, 1, and 8)
  11. # X_train = X_train[(y_train == 1) | (y_train == 0) | (y_train == 8)]
  12. # y_train = y_train[(y_train == 1) | (y_train == 0) | (y_train == 8)]
  13. # X_test = X_test[(y_test == 1) | (y_test == 0) | (y_test == 8)]
  14. # y_test = y_test[(y_test == 1) | (y_test == 0) | (y_test == 8)]
  15.  
  16. # Pixel intensity to firing rate (255 becomes ~63Hz)
  17. X_train = X_train / 4
  18. X_test = X_test / 4
  19.  
  20. # Flatten the images
  21. X_train = X_train.reshape(X_train.shape[0], -1)
  22. X_test = X_test.reshape(X_test.shape[0], -1)
  23.  
  24. # Define SNN parameters
  25. n_input = 28*28  # Input layer
  26. n_e = 100        # Excitatory neurons
  27. n_i = n_e        # Inhibitory neurons
  28.  
  29. v_rest_e = -60.*mV  # Membrane potential
  30. v_reset_e = -65.*mV
  31. v_thresh_e = -52.*mV
  32.  
  33. v_rest_i = -60.*mV
  34. v_reset_i = -45.*mV
  35. v_thresh_i = -40.*mV
  36.  
  37. taupre = 20*ms
  38. taupost = taupre
  39. gmax = .05
  40. dApre = .01
  41. dApost = -dApre * taupre / taupost * 1.05
  42. dApost *= gmax
  43. dApre *= gmax
  44.  
  45. # Define STDP equations
  46. stdp_eqs = '''
  47.    w : 1
  48.    lr : 1 (shared)
  49.    dApre/dt = -Apre / taupre : 1 (event-driven)
  50.    dApost/dt = -Apost / taupost : 1 (event-driven)'''
  51.  
  52. # Pre-synaptic spike update
  53. stdp_pre = '''
  54.    ge += w
  55.    Apre += dApre
  56.    w = clip(w + lr*Apost, 0, gmax)'''
  57.  
  58. # Post-synaptic spike update
  59. stdp_post = '''
  60.    Apost += dApost
  61.    w = clip(w + lr*Apre, 0, gmax)'''
  62.  
  63. class Model():
  64.    
  65.     def __init__(self):
  66.         # Input Poisson Group
  67.         self.PG = PoissonGroup(n_input, rates=np.zeros(n_input)*Hz, name='PG')
  68.        
  69.         # Excitatory Neuron Group
  70.         self.EG = NeuronGroup(n_e, '''
  71.            dv/dt = (ge*(0*mV-v) + gi*(-100*mV-v) + (v_rest_e-v)) / (100*ms) : volt
  72.            dge/dt = -ge / (5*ms) : 1
  73.            dgi/dt = -gi / (10*ms) : 1
  74.            ''',
  75.             threshold='v>v_thresh_e', refractory=5*ms, reset='v=v_reset_e', method='euler', name='EG')
  76.         self.EG.v = v_rest_e - 20.*mV
  77.        
  78.         # Inhibitory Neuron Group
  79.         self.IG = NeuronGroup(n_i, '''
  80.            dv/dt = (ge*(0*mV-v) + (v_rest_i-v)) / (10*ms) : volt
  81.            dge/dt = -ge / (5*ms) : 1
  82.            ''',
  83.             threshold='v>v_thresh_i', refractory=2*ms, reset='v=v_reset_i', method='euler', name='IG')
  84.         self.IG.v = v_rest_i - 20.*mV
  85.        
  86.         # Synapses between Poisson Group and Excitatory Neurons
  87.         self.S1 = Synapses(self.PG, self.EG, stdp_eqs, on_pre=stdp_pre, on_post=stdp_post, method='euler', name='S1')
  88.         self.S1.connect()
  89.         self.S1.w = 'rand()*gmax'  # Random weights initialization
  90.         self.S1.lr = 1               # Enable STDP
  91.        
  92.         # Synapses between Excitatory and Inhibitory Neurons
  93.         self.S2 = Synapses(self.EG, self.IG, 'w : 1', on_pre='ge += w', name='S2')
  94.         self.S2.connect(j='i')
  95.         self.S2.delay = 'rand()*10*ms'
  96.         self.S2.w = 3                # Very strong fixed weights
  97.        
  98.         # Synapses between Inhibitory and Excitatory Neurons
  99.         self.S3 = Synapses(self.IG, self.EG, 'w : 1', on_pre='gi += w', name='S3')
  100.         self.S3.connect(condition='i!=j')
  101.         self.S3.delay = 'rand()*5*ms'
  102.         self.S3.w = .03              # Balanced weights
  103.        
  104.         # Initialize Brian2 Network
  105.         self.net = Network(self.PG, self.EG, self.IG, self.S1, self.S2, self.S3)
  106.         self.net.run(0*second)
  107.        
  108.     def train(self, X, epoch=1):        
  109.         self.S1.lr = 1  # Enable STDP
  110.        
  111.         for ep in range(epoch):
  112.             for idx in range(len(X)):
  113.                 # Active mode
  114.                 self.PG.rates = X[idx].ravel()*Hz
  115.                 self.net.run(0.35*second)
  116.  
  117.                 # Passive mode
  118.                 self.PG.rates = np.zeros(n_input)*Hz
  119.                 self.net.run(0.15*second)
  120.        
  121.     def evaluate(self, X):      
  122.         self.S1.lr = 0  # Disable STDP
  123.        
  124.         features = []
  125.         for idx in range(len(X)):
  126.             # Rate monitor to count spikes
  127.             mon = SpikeMonitor(self.EG, name='RM')
  128.             self.net.add(mon)
  129.            
  130.             # Active mode
  131.             self.PG.rates = X[idx].ravel()*Hz
  132.             self.net.run(0.35*second)
  133.            
  134.             # Spikes per neuron for each image
  135.             features.append(np.array(mon.count, dtype=int8))
  136.            
  137.             # Passive mode
  138.             self.PG.rates = np.zeros(n_input)*Hz
  139.             self.net.run(0.15*second)
  140.            
  141.             self.net.remove(self.net['RM'])
  142.            
  143.         return features
  144.  
  145.  
  146. import seaborn as sns
  147.  
  148. # Test the SNN model with evaluation metrics and confusion matrix plotting
  149. def test_snn(train_items=500, assign_items=100, eval_items=100):
  150.     seed(0)
  151.    
  152.     model = Model()
  153.     model.train(X_train[:train_items], epoch=1)
  154.    
  155.     train_features = model.evaluate(X_train[:assign_items])
  156.     test_features = model.evaluate(X_test[:eval_items])
  157.    
  158.     # Perform classification using a simple thresholding method
  159.     threshold = 10  # Example threshold value
  160.     train_predictions = [1 if np.sum(f) > threshold else 0 for f in train_features]
  161.     test_predictions = [1 if np.sum(f) > threshold else 0 for f in test_features]
  162.    
  163. #     # Perform classification using argmax to determine the predicted class
  164. #     train_predictions = np.argmax(train_features, axis=1)
  165. #     test_predictions = np.argmax(test_features, axis=1)
  166.    
  167.     # Calculate evaluation metrics
  168.     train_accuracy = accuracy_score(y_train[:assign_items], train_predictions)
  169.     test_accuracy = accuracy_score(y_test[:eval_items], test_predictions)
  170.    
  171.     train_precision = precision_score(y_train[:assign_items], train_predictions, average='weighted')
  172.     test_precision = precision_score(y_test[:eval_items], test_predictions, average='weighted')
  173.    
  174.     train_recall = recall_score(y_train[:assign_items], train_predictions, average='weighted')
  175.     test_recall = recall_score(y_test[:eval_items], test_predictions, average='weighted')
  176.    
  177.     train_f1 = f1_score(y_train[:assign_items], train_predictions, average='weighted')
  178.     test_f1 = f1_score(y_test[:eval_items], test_predictions, average='weighted')
  179.    
  180.     train_confusion_matrix = confusion_matrix(y_train[:assign_items], train_predictions)
  181.     test_confusion_matrix = confusion_matrix(y_test[:eval_items], test_predictions)
  182.    
  183.     print("Train Accuracy:", train_accuracy)
  184.     print("Test Accuracy:", test_accuracy)
  185.    
  186.     print("Train Precision:", train_precision)
  187.     print("Test Precision:", test_precision)
  188.    
  189.     print("Train Recall:", train_recall)
  190.     print("Test Recall:", test_recall)
  191.    
  192.     print("Train F1 Score:", train_f1)
  193.     print("Test F1 Score:", test_f1)
  194.    
  195.     print("Train Confusion Matrix:\n", train_confusion_matrix)
  196.     print("Test Confusion Matrix:\n", test_confusion_matrix)
  197.    
  198.     # Plot confusion matrices
  199.     plot_confusion_matrix(train_confusion_matrix, np.arange(10)) # np.arange(10)
  200.     plot_confusion_matrix(test_confusion_matrix, np.arange(10)) # np.arange(10)
  201.    
  202.     return train_features, test_features
  203.  
  204. # Function to plot confusion matrix
  205. def plot_confusion_matrix(confusion_matrix, labels):
  206.     plt.figure(figsize=(10, 8))
  207.     sns.heatmap(confusion_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
  208.     plt.xlabel('Predicted Labels')
  209.     plt.ylabel('True Labels')
  210.     plt.title('Confusion Matrix')
  211.     plt.show()
  212.  
  213. # Example usage
  214. train_features, test_features = test_snn(train_items=500, assign_items=100, eval_items=100)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement