Advertisement
aurko96

SNN-RandomForest

Apr 24th, 2024
806
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 9.88 KB | None | 0 0
  1. from keras.datasets import mnist
  2. from brian2 import *
  3. import brian2.numpy_ as np
  4. import matplotlib.pyplot as plt
  5. from sklearn.ensemble import RandomForestClassifier
  6. from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
  7. import seaborn as sns
  8.  
  9. (X_train, y_train), (X_test, y_test) = mnist.load_data()
  10.  
  11. # pixel intensity to Hz (255 becoms ~63Hz)
  12. X_train = X_train / 4
  13. X_test = X_test / 4
  14.  
  15. X_train.shape, X_test.shape
  16.  
  17. plt.figure(figsize=(16,8))
  18. for img in range(32):
  19.     plt.subplot(4,8,1+img)
  20.     plt.title(y_train[img])
  21.     plt.imshow(X_train[img])
  22.     plt.axis('off')
  23.  
# --- Network size -----------------------------------------------------------
n_input = 28*28 # input layer: one Poisson generator per pixel of a 28x28 image
n_e = 100 # e - excitatory
n_i = n_e # i - inhibitory (same count as the excitatory group)

# --- Membrane potentials ----------------------------------------------------
# These names are referenced from the threshold/reset/equation strings used
# when the neuron groups are built, so they must keep these exact names.
v_rest_e = -60.*mV # v - membrane potential
v_reset_e = -65.*mV
v_thresh_e = -52.*mV

v_rest_i = -60.*mV
v_reset_i = -45.*mV
v_thresh_i = -40.*mV

# --- STDP parameters --------------------------------------------------------
# taupre, taupost, dApre, dApost and gmax are referenced by name inside the
# equation strings below; do not rename them.
taupre = 20*ms
taupost = taupre
gmax = .05 # upper clip bound for the plastic weights (.01 noted by the author as an alternative)
dApre = .01
# Post-synaptic increment is negative and 5% larger in magnitude than the
# pre-synaptic one (scaled by taupre/taupost, which is 1 here).
dApost = -dApre * taupre / taupost * 1.05
# Both increments are expressed relative to the maximum weight.
dApost *= gmax
dApre *= gmax

# Apre and Apost - presynaptic and postsynaptic traces, lr - learning rate
# (lr is shared by all synapses; setting it to 0 makes the weight updates
# in `pre` and `post` no-ops apart from the clip).
stdp='''w : 1
   lr : 1 (shared)
   dApre/dt = -Apre / taupre : 1 (event-driven)
   dApost/dt = -Apost / taupost : 1 (event-driven)'''
# On a presynaptic spike: drive the target conductance, bump the pre trace,
# and move the weight by lr*Apost, clipped to [0, gmax].
pre='''ge += w
   Apre += dApre
   w = clip(w + lr*Apost, 0, gmax)'''
# On a postsynaptic spike: bump the post trace and move the weight by
# lr*Apre, clipped to [0, gmax].
post='''Apost += dApost
   w = clip(w + lr*Apre, 0, gmax)'''
  54.  
  55. class Model():
  56.    
  57.     def __init__(self, debug=False):
  58.         app = {}
  59.                
  60.         # input images as rate encoded Poisson generators
  61.         app['PG'] = PoissonGroup(n_input, rates=np.zeros(n_input)*Hz, name='PG')
  62.        
  63.         # excitatory group
  64.         neuron_e = '''
  65.            dv/dt = (ge*(0*mV-v) + gi*(-100*mV-v) + (v_rest_e-v)) / (100*ms) : volt
  66.            dge/dt = -ge / (5*ms) : 1
  67.            dgi/dt = -gi / (10*ms) : 1
  68.            '''
  69.         app['EG'] = NeuronGroup(n_e, neuron_e, threshold='v>v_thresh_e', refractory=5*ms, reset='v=v_reset_e', method='euler', name='EG')
  70.         app['EG'].v = v_rest_e - 20.*mV
  71.        
  72.         if (debug):
  73.             app['ESP'] = SpikeMonitor(app['EG'], name='ESP')
  74.             app['ESM'] = StateMonitor(app['EG'], ['v'], record=True, name='ESM')
  75.             app['ERM'] = PopulationRateMonitor(app['EG'], name='ERM')
  76.        
  77.         # ibhibitory group
  78.         neuron_i = '''
  79.            dv/dt = (ge*(0*mV-v) + (v_rest_i-v)) / (10*ms) : volt
  80.            dge/dt = -ge / (5*ms) : 1
  81.            '''
  82.         app['IG'] = NeuronGroup(n_i, neuron_i, threshold='v>v_thresh_i', refractory=2*ms, reset='v=v_reset_i', method='euler', name='IG')
  83.         app['IG'].v = v_rest_i - 20.*mV
  84.  
  85.         if (debug):
  86.             app['ISP'] = SpikeMonitor(app['IG'], name='ISP')
  87.             app['ISM'] = StateMonitor(app['IG'], ['v'], record=True, name='ISM')
  88.             app['IRM'] = PopulationRateMonitor(app['IG'], name='IRM')
  89.        
  90.         # poisson generators one-to-all excitatory neurons with plastic connections
  91.         app['S1'] = Synapses(app['PG'], app['EG'], stdp, on_pre=pre, on_post=post, method='euler', name='S1')
  92.         app['S1'].connect()
  93.         app['S1'].w = 'rand()*gmax' # random weights initialisation
  94.         app['S1'].lr = 1 # enable stdp        
  95.        
  96.         if (debug):
  97.             # some synapses
  98.             app['S1M'] = StateMonitor(app['S1'], ['w', 'Apre', 'Apost'], record=app['S1'][380,:4], name='S1M')
  99.        
  100.         # excitatory neurons one-to-one inhibitory neurons
  101.         app['S2'] = Synapses(app['EG'], app['IG'], 'w : 1', on_pre='ge += w', name='S2')
  102.         app['S2'].connect(j='i')
  103.         app['S2'].delay = 'rand()*10*ms'
  104.         app['S2'].w = 3 # very strong fixed weights to ensure corresponding inhibitory neuron will always fire
  105.  
  106.         # inhibitory neurons one-to-all-except-one excitatory neurons
  107.         app['S3'] = Synapses(app['IG'], app['EG'], 'w : 1', on_pre='gi += w', name='S3')
  108.         app['S3'].connect(condition='i!=j')
  109.         app['S3'].delay = 'rand()*5*ms'
  110.         app['S3'].w = .03 # weights are selected in such a way as to maintain a balance between excitation and ibhibition
  111.        
  112.         self.net = Network(app.values())
  113.         self.net.run(0*second)
  114.        
  115.     def __getitem__(self, key):
  116.         return self.net[key]
  117.    
  118.     def train(self, X, epoch=1):        
  119.         self.net['S1'].lr = 1 # stdp on
  120.        
  121.         for ep in range(epoch):
  122.             for idx in range(len(X)):
  123.                 # active mode
  124.                 self.net['PG'].rates = X[idx].ravel()*Hz
  125.                 self.net.run(0.35*second)
  126.  
  127.                 # passive mode
  128.                 self.net['PG'].rates = np.zeros(n_input)*Hz
  129.                 self.net.run(0.15*second)
  130.        
  131.     def evaluate(self, X):      
  132.         self.net['S1'].lr = 0  # stdp off
  133.        
  134.         features = []
  135.         for idx in range(len(X)):
  136.             # rate monitor to count spikes
  137.             mon = SpikeMonitor(self.net['EG'], name='RM')
  138.             self.net.add(mon)
  139.            
  140.             # active mode
  141.             self.net['PG'].rates = X[idx].ravel()*Hz
  142.             self.net.run(0.35*second)
  143.            
  144.             # spikes per neuron foreach image
  145.             features.append(np.array(mon.count, dtype=int8))
  146.            
  147.             # passive mode
  148.             self.net['PG'].rates = np.zeros(n_input)*Hz
  149.             self.net.run(0.15*second)
  150.            
  151.             self.net.remove(self.net['RM'])
  152.            
  153.         return features
  154.  
  155. def plot_w(S1M):
  156.     plt.rcParams["figure.figsize"] = (20,10)
  157.     subplot(311)
  158.     plot(S1M.t/ms, S1M.w.T/gmax)
  159.     ylabel('w / wmax')
  160.     subplot(312)
  161.     plot(S1M.t/ms, S1M.Apre.T)
  162.     ylabel('apre')
  163.     subplot(313)
  164.     plot(S1M.t/ms, S1M.Apost.T)
  165.     ylabel('apost')
  166.     tight_layout()
  167.     show();
  168.    
  169. def plot_v(ESM, ISM, neuron=13):
  170.     plt.rcParams["figure.figsize"] = (20,6)
  171.     cnt = -50000 # tail
  172.     plot(ESM.t[cnt:]/ms, ESM.v[neuron][cnt:]/mV, label='exc', color='r')
  173.     plot(ISM.t[cnt:]/ms, ISM.v[neuron][cnt:]/mV, label='inh', color='b')
  174.     plt.axhline(y=v_thresh_e/mV, color='pink', label='v_thresh_e')
  175.     plt.axhline(y=v_thresh_i/mV, color='silver', label='v_thresh_i')
  176.     legend()
  177.     ylabel('v')
  178.     show();
  179.    
  180. def plot_rates(ERM, IRM):
  181.     plt.rcParams["figure.figsize"] = (20,6)
  182.     plot(ERM.t/ms, ERM.smooth_rate(window='flat', width=0.1*ms)*Hz, color='r')
  183.     plot(IRM.t/ms, IRM.smooth_rate(window='flat', width=0.1*ms)*Hz, color='b')
  184.     ylabel('Rate')
  185.     show();
  186.    
  187. def plot_spikes(ESP, ISP):
  188.     plt.rcParams["figure.figsize"] = (20,6)
  189.     plot(ESP.t/ms, ESP.i, '.r')
  190.     plot(ISP.t/ms, ISP.i, '.b')
  191.     ylabel('Neuron index')
  192.     show();
  193.  
  194. def test0(train_items=30):
  195.     '''
  196.    STDP visualisation
  197.    '''
  198.     seed(0)
  199.    
  200.     model = Model(debug=True)
  201.     model.train(X_train[:train_items], epoch=1)
  202.    
  203.     plot_w(model['S1M'])
  204.     plot_v(model['ESM'], model['ISM'])
  205.     plot_rates(model['ERM'], model['IRM'])
  206.     plot_spikes(model['ESP'], model['ISP'])
  207.    
# Run the STDP visualisation with its default of 30 training images.
# This is a top-level call, so it executes whenever the file is run.
test0()
  209.  
  210. def test1(train_items=500, assign_items=100, eval_items=100):
  211.     '''
  212.    Feed train set to SNN with STDP
  213.    Freeze STDP
  214.    Feed train set to SNN again and collect generated features
  215.    Train RandomForest on the top of these features and labels provided
  216.    Feed test set to SNN and collect new features
  217.    Predict labels with RandomForest and calculate accuacy score
  218.    '''
  219.     seed(0)
  220.    
  221.     model = Model()
  222.     model.train(X_train[:train_items], epoch=1)
  223.     model.net.store('train', 'train.b2')
  224.     #model.net.restore('train', './train.b2')
  225.    
  226.     f_train = model.evaluate(X_train[:assign_items])
  227.     clf = RandomForestClassifier(max_depth=4, random_state=0)
  228.     clf.fit(f_train, y_train[:assign_items])
  229.     print(clf.score(f_train, y_train[:assign_items]))
  230.  
  231.     f_test = model.evaluate(X_test[:eval_items])
  232.     y_pred = clf.predict(f_test)
  233.     print(accuracy_score(y_pred, y_test[:eval_items]))
  234.  
  235. #     cm = confusion_matrix(y_pred, y_test[:eval_items])
  236. #     print(cm)
  237.    
  238.    
  239.    
  240. #     # Calculate evaluation metrics
  241. #     train_accuracy = accuracy_score(y_train[:assign_items], train_predictions)
  242.     test_accuracy = accuracy_score(y_test[:eval_items], y_pred)
  243.    
  244. #     train_precision = precision_score(y_train[:assign_items], train_predictions, average='weighted')
  245.     test_precision = precision_score(y_test[:eval_items], y_pred, average='weighted')
  246.    
  247. #     train_recall = recall_score(y_train[:assign_items], train_predictions, average='weighted')
  248.     test_recall = recall_score(y_test[:eval_items], y_pred, average='weighted')
  249.    
  250. #     train_f1 = f1_score(y_train[:assign_items], train_predictions, average='weighted')
  251.     test_f1 = f1_score(y_test[:eval_items], y_pred, average='weighted')
  252.    
  253. #     train_confusion_matrix = confusion_matrix(y_train[:assign_items], train_predictions)
  254.     test_confusion_matrix = confusion_matrix(y_test[:eval_items], y_pred)
  255.    
  256. #     print("Train Accuracy:", train_accuracy)
  257.     print("Test Accuracy:", test_accuracy)
  258.    
  259. #     print("Train Precision:", train_precision)
  260.     print("Test Precision:", test_precision)
  261.    
  262. #     print("Train Recall:", train_recall)
  263.     print("Test Recall:", test_recall)
  264.    
  265. #     print("Train F1 Score:", train_f1)
  266.     print("Test F1 Score:", test_f1)
  267.    
  268. #     print("Train Confusion Matrix:\n", train_confusion_matrix)
  269.     print("Test Confusion Matrix:\n", test_confusion_matrix)
  270.    
  271.     # Plot confusion matrices
  272. #     plot_confusion_matrix(train_confusion_matrix, np.arange(10)) # np.arange(10)
  273.     plot_confusion_matrix(test_confusion_matrix, np.arange(10)) # np.arange(10)
  274.    
  275.     # Function to plot confusion matrix
  276. def plot_confusion_matrix(confusion_matrix, labels):
  277.     plt.figure(figsize=(10, 8))
  278.     sns.heatmap(confusion_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
  279.     plt.xlabel('Predicted Labels')
  280.     plt.ylabel('True Labels')
  281.     plt.title('Confusion Matrix')
  282.     plt.show()
  283.    
  284.    
# Run the full SNN + RandomForest experiment with its default sizes
# (500 training, 100 assignment and 100 evaluation images).
# This is a top-level call, so it executes whenever the file is run.
test1()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement