Advertisement
aurko96

CNN-RNN

Apr 24th, 2024 (edited)
661
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.03 KB | None | 0 0
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, SimpleRNN
from tensorflow.keras.utils import to_categorical
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
import seaborn as sns
# NOTE(review): no name used in this script appears to come from brian2, and a
# wildcard import can silently rebind names imported above. It is kept last (as
# in the original) only to avoid dropping an existing dependency; consider
# removing it.
from brian2 import *
  11.  
  12. # Load MNIST dataset
  13. (x_train, y_train), (x_test, y_test) = mnist.load_data()
  14.  
  15. # Preprocess the data
  16. x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32') / 255
  17. x_test = x_test.reshape(x_test.shape[0], 28, 28, 1).astype('float32') / 255
  18. y_train = to_categorical(y_train, 10)
  19. y_test = to_categorical(y_test, 10)
  20.  
  21. num_classes = 10
  22.  
  23. # Define the CNN model
  24. cnn_model = Sequential([
  25.     Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
  26.     MaxPooling2D(pool_size=(2, 2)),
  27.     Flatten(),
  28.     Dense(128, activation='relu'),
  29.     Dense(10, activation='softmax')
  30. ])
  31.  
  32. cnn_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
  33.  
  34. # Train the CNN model
  35. cnn_history = cnn_model.fit(x_train, y_train, batch_size=128, epochs=20, validation_split=0.1)
  36.  
  37. # Evaluate the CNN model
  38. cnn_loss, cnn_accuracy = cnn_model.evaluate(x_test, y_test)
  39. cnn_predictions = cnn_model.predict(x_test)
  40. cnn_predictions = np.argmax(cnn_predictions, axis=1)
  41. cnn_true_labels = np.argmax(y_test, axis=1)
  42.  
  43. cnn_precision = precision_score(cnn_true_labels, cnn_predictions, average='weighted')
  44. cnn_recall = recall_score(cnn_true_labels, cnn_predictions, average='weighted')
  45. cnn_f1 = f1_score(cnn_true_labels, cnn_predictions, average='weighted')
  46. cnn_conf_matrix = confusion_matrix(cnn_true_labels, cnn_predictions)
  47.  
  48. print(f'CNN Test Accuracy: {cnn_accuracy}')
  49. print(f'CNN Precision: {cnn_precision}')
  50. print(f'CNN Recall: {cnn_recall}')
  51. print(f'CNN F1 Score: {cnn_f1}')
  52.  
  53. # Plot confusion matrix for CNN
  54. plt.figure(figsize=(8, 6))
  55. sns.heatmap(cnn_conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=range(num_classes), yticklabels=range(num_classes))
  56. plt.title('Confusion Matrix - CNN')
  57. plt.xlabel('Predicted')
  58. plt.ylabel('True')
  59. plt.show()
  60.  
  61. # Define the RNN model
  62. rnn_model = Sequential([
  63.     SimpleRNN(128, input_shape=(28, 28), activation='relu', return_sequences=False),
  64.     Dense(10, activation='softmax')
  65. ])
  66.  
  67. rnn_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
  68.  
  69. # Train the RNN model
  70. rnn_history = rnn_model.fit(x_train, y_train, batch_size=128, epochs=20, validation_split=0.1)
  71.  
  72. # Evaluate the RNN model
  73. rnn_loss, rnn_accuracy = rnn_model.evaluate(x_test, y_test)
  74. rnn_predictions = rnn_model.predict(x_test)
  75. rnn_predictions = np.argmax(rnn_predictions, axis=1)
  76. rnn_true_labels = np.argmax(y_test, axis=1)
  77.  
  78. rnn_precision = precision_score(rnn_true_labels, rnn_predictions, average='weighted')
  79. rnn_recall = recall_score(rnn_true_labels, rnn_predictions, average='weighted')
  80. rnn_f1 = f1_score(rnn_true_labels, rnn_predictions, average='weighted')
  81. rnn_conf_matrix = confusion_matrix(rnn_true_labels, rnn_predictions)
  82.  
  83. print(f'RNN Test Accuracy: {rnn_accuracy}')
  84. print(f'RNN Precision: {rnn_precision}')
  85. print(f'RNN Recall: {rnn_recall}')
  86. print(f'RNN F1 Score: {rnn_f1}')
  87.  
  88. # Plot confusion matrix for RNN
  89. plt.figure(figsize=(8, 6))
  90. sns.heatmap(rnn_conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=range(num_classes), yticklabels=range(num_classes))
  91. plt.title('Confusion Matrix - RNN')
  92. plt.xlabel('Predicted')
  93. plt.ylabel('True')
  94. plt.show()
  95.  
  96. # Plot training history for CNN and RNN
  97. plt.plot(cnn_history.history['accuracy'], label='CNN Train Accuracy')
  98. plt.plot(cnn_history.history['val_accuracy'], label='CNN Validation Accuracy')
  99. plt.plot(rnn_history.history['accuracy'], label='RNN Train Accuracy')
  100. plt.plot(rnn_history.history['val_accuracy'], label='RNN Validation Accuracy')
  101. plt.title('CNN and RNN Model Training History')
  102. plt.xlabel('Epoch')
  103. plt.ylabel('Accuracy')
  104. plt.legend()
  105. plt.show()
  106.  
  107.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement