Adapted from Chris Tralie's CS 477 at Ursinus College
# Import libraries
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras import layers, losses
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.layers import Dense, Conv2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
import string
import time
import numpy as np
import matplotlib.pyplot as plt
import skimage.io
from skimage.transform import resize
import glob
imgres = 64
# Setup model
class AutoencoderCNN(Model):
    def __init__(self, imgres, d, k):
        """
        Parameters
        ----------
        imgres: int
            Resolution of each grayscale image
        d: int
            Dimension of the embedding
        k: int
            Kernel size at each layer
        """
        super(AutoencoderCNN, self).__init__()
        self.imgres = imgres
        self.d = d
        self.k = k
        print(imgres, d, k)
        # Encoder: three strided convolutions halve the resolution each time
        # (imgres -> imgres/8), then a dense layer maps to the d-dimensional code
        self.encoder = Sequential([
            layers.Input(shape=(imgres, imgres, 1)),
            layers.Conv2D(32, (k, k), activation='leaky_relu', padding='same', strides=2),
            layers.Dropout(0.2),
            layers.Conv2D(64, (k, k), activation='leaky_relu', padding='same', strides=2),
            layers.Dropout(0.2),
            layers.Conv2D(128, (k, k), activation='leaky_relu', padding='same', strides=2),
            layers.Dropout(0.2),
            layers.Flatten(),
            layers.Dense(d, activation='leaky_relu')
        ])
        # Decoder: mirror of the encoder; the dense layer has 2*imgres**2 units,
        # which equals 128*(imgres//8)**2, so it reshapes to an (imgres/8, imgres/8, 128) volume
        self.decoder = Sequential([
            layers.Input(shape=(d,)),
            layers.Dense(2*imgres**2),
            layers.Reshape((imgres//8, imgres//8, 128)),
            layers.Conv2DTranspose(128, kernel_size=k, strides=2, activation='leaky_relu', padding='same'),
            layers.Dropout(0.2),
            layers.Conv2DTranspose(64, kernel_size=k, strides=2, activation='leaky_relu', padding='same'),
            layers.Dropout(0.2),
            layers.Conv2DTranspose(32, kernel_size=k, strides=2, activation='leaky_relu', padding='same'),
            layers.Dropout(0.2),
            layers.Conv2DTranspose(1, kernel_size=(k, k), activation='sigmoid', padding='same')
        ])

    def call(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded
autoencoder = AutoencoderCNN(imgres=imgres, d=128, k=3)
print(autoencoder.encoder.summary())
print(autoencoder.decoder.summary())
64 128 3
Model: "sequential_69"
_________________________________________________________________
 Layer (type)                 Output Shape              Param #
=================================================================
 conv2d_187 (Conv2D)          (None, 32, 32, 32)        320
 dropout_202 (Dropout)        (None, 32, 32, 32)        0
 conv2d_188 (Conv2D)          (None, 16, 16, 64)        18496
 dropout_203 (Dropout)        (None, 16, 16, 64)        0
 conv2d_189 (Conv2D)          (None, 8, 8, 128)         73856
 dropout_204 (Dropout)        (None, 8, 8, 128)         0
 flatten_36 (Flatten)         (None, 8192)              0
 dense_72 (Dense)             (None, 128)               1048704
=================================================================
Total params: 1,141,376
Trainable params: 1,141,376
Non-trainable params: 0
_________________________________________________________________
None
Model: "sequential_70"
_________________________________________________________________
 Layer (type)                           Output Shape         Param #
=================================================================
 dense_73 (Dense)                       (None, 8192)         1056768
 reshape_36 (Reshape)                   (None, 8, 8, 128)    0
 conv2d_transpose_57 (Conv2DTranspose)  (None, 16, 16, 128)  147584
 dropout_205 (Dropout)                  (None, 16, 16, 128)  0
 conv2d_transpose_58 (Conv2DTranspose)  (None, 32, 32, 64)   73792
 dropout_206 (Dropout)                  (None, 32, 32, 64)   0
 conv2d_transpose_59 (Conv2DTranspose)  (None, 64, 64, 32)   18464
 dropout_207 (Dropout)                  (None, 64, 64, 32)   0
 conv2d_transpose_60 (Conv2DTranspose)  (None, 64, 64, 1)    289
=================================================================
Total params: 1,296,897
Trainable params: 1,296,897
Non-trainable params: 0
_________________________________________________________________
None
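As a quick sanity check (an addition, not part of the original notebook), a random batch can be pushed through the untrained model to confirm that the code has dimension d and that the reconstruction matches the input shape; the batch size of 4 below is arbitrary.
# Sanity check (illustrative): the encoder should map a batch to d-dimensional
# codes and the decoder should map those codes back to 64x64x1 images
x_check = np.random.rand(4, imgres, imgres, 1).astype(np.float32)
z_check = autoencoder.encoder(x_check)
x_hat = autoencoder.decoder(z_check)
print(z_check.shape)  # expected: (4, 128)
print(x_hat.shape)    # expected: (4, 64, 64, 1)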
class PlotterCallback(tf.keras.callbacks.Callback):
    """Plot originals, reconstructions, and differences for a fixed validation batch at the end of each epoch"""
    def __init__(self, n=10):
        super(PlotterCallback, self).__init__()
        self.epoch = 0
        self.n = n
        # Uses the global datagen defined below; grab one fixed batch of n validation images
        self.plot_imgs = datagen.flow_from_directory('../bonus/autoencoder/Valid',
                                                     class_mode="input", batch_size=n, shuffle=True,
                                                     target_size=(imgres, imgres), color_mode="grayscale")
        self.X = next(self.plot_imgs)[0]

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))
        X = self.X
        # Uses the global autoencoder to encode and decode the fixed batch
        encoded_imgs = autoencoder.encoder(X).numpy()
        print(encoded_imgs.shape)
        decoded_imgs = autoencoder.decoder(encoded_imgs).numpy()
        plt.figure(figsize=(20, 5))
        n = self.n
        for i in range(n):
            # display original
            plt.subplot(3, n, i + 1)
            plt.imshow(X[i, :, :, 0], cmap='gray')
            plt.title("original")
            plt.axis("off")
            # display reconstruction
            plt.subplot(3, n, n+i+1)
            plt.imshow(decoded_imgs[i, :, :, 0], cmap='gray')
            plt.title("reconstructed")
            plt.axis("off")
            # display difference
            plt.subplot(3, n, 2*n+i+1)
            plt.imshow(X[i, :, :, 0] - decoded_imgs[i, :, :, 0], cmap='gray')
            plt.title("difference")
            plt.axis("off")
        self.epoch += 1
        plt.savefig("Epoch{}.png".format(self.epoch), bbox_inches='tight')
        plt.show()
img = load_img('../bonus/autoencoder/Train/0.jpg')
plt.imshow(img, cmap='gray')
datagen = ImageDataGenerator(rescale=1/255)
train_it = datagen.flow_from_directory('../bonus/autoencoder/Train',
                                       class_mode="input", batch_size=64, shuffle=True,
                                       target_size=(imgres, imgres), color_mode="grayscale")
test_it = datagen.flow_from_directory('../bonus/autoencoder/Valid',
                                      class_mode="input", batch_size=64, shuffle=True,
                                      target_size=(imgres, imgres), color_mode="grayscale")
Found 9381 images belonging to 1 classes.
Found 94 images belonging to 1 classes.
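Before training, it can help to pull one batch from the iterator and confirm the rescaling and shapes; this check is an addition to the original notebook, and the variable names are arbitrary.
# Illustrative check: with class_mode="input", each batch is (images, images),
# and rescale=1/255 should put pixel values into [0, 1]
xb, yb = next(train_it)
print(xb.shape, yb.shape)   # expected: (64, 64, 64, 1) for both
print(xb.min(), xb.max())   # expected: values within [0, 1]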
autoencoder.compile(optimizer=Adam(3e-4), loss=losses.MeanSquaredError())
history = autoencoder.fit(train_it,
                          steps_per_epoch=len(train_it),
                          validation_data=test_it,
                          validation_steps=len(test_it),
                          epochs=25,
                          verbose=1,
                          callbacks=[PlotterCallback()])
Found 94 images belonging to 1 classes.
Epoch 1/25
End epoch 0 of training; got log keys: ['loss', 'val_loss']
(10, 128)
147/147 [==============================] - 116s 777ms/step - loss: 0.0454 - val_loss: 0.0221
Epoch 2/25
End epoch 1 of training; got log keys: ['loss', 'val_loss']
(10, 128)
147/147 [==============================] - 125s 853ms/step - loss: 0.0212 - val_loss: 0.0163
Epoch 3/25
End epoch 2 of training; got log keys: ['loss', 'val_loss']
(10, 128)
147/147 [==============================] - 116s 784ms/step - loss: 0.0174 - val_loss: 0.0142
Epoch 4/25
End epoch 3 of training; got log keys: ['loss', 'val_loss']
(10, 128)
147/147 [==============================] - 121s 819ms/step - loss: 0.0157 - val_loss: 0.0129
Epoch 5/25
End epoch 4 of training; got log keys: ['loss', 'val_loss']
(10, 128)
147/147 [==============================] - 130s 882ms/step - loss: 0.0145 - val_loss: 0.0121
Epoch 6/25
End epoch 5 of training; got log keys: ['loss', 'val_loss']
(10, 128)
147/147 [==============================] - 121s 817ms/step - loss: 0.0137 - val_loss: 0.0115
Epoch 7/25
End epoch 6 of training; got log keys: ['loss', 'val_loss']
(10, 128)
147/147 [==============================] - 123s 836ms/step - loss: 0.0131 - val_loss: 0.0109
Epoch 8/25
End epoch 7 of training; got log keys: ['loss', 'val_loss']
(10, 128)
147/147 [==============================] - 126s 857ms/step - loss: 0.0126 - val_loss: 0.0105
Epoch 9/25
End epoch 8 of training; got log keys: ['loss', 'val_loss']
(10, 128)
147/147 [==============================] - 128s 866ms/step - loss: 0.0122 - val_loss: 0.0102
Epoch 10/25
End epoch 9 of training; got log keys: ['loss', 'val_loss']
(10, 128)
147/147 [==============================] - 127s 862ms/step - loss: 0.0119 - val_loss: 0.0099
Epoch 11/25
End epoch 10 of training; got log keys: ['loss', 'val_loss']
(10, 128)
147/147 [==============================] - 141s 959ms/step - loss: 0.0116 - val_loss: 0.0097
Epoch 12/25
  7/147 [>.............................] - ETA: 1:55 - loss: 0.0117
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-82-353d6d2b9e92> in <module>
      1 autoencoder.compile(optimizer=Adam(3e-4), loss=losses.MeanSquaredError())
----> 2 history = autoencoder.fit(train_it,
      3                           steps_per_epoch=len(train_it),
      4                           validation_data=test_it,
      5                           validation_steps=len(test_it),
...
/opt/miniconda3/lib/python3.8/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     57         ctx.ensure_initialized()
---> 58         tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     59                                             inputs, attrs, num_outputs)
KeyboardInterrupt:
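Since this run was interrupted by hand, one option (not used in the original notebook) is to checkpoint the weights after every epoch so a stopped run still leaves the best model on disk; the checkpoint path below is just an example.
# Hypothetical variant: a ModelCheckpoint callback keeps the best weights so far,
# so an interrupted fit() does not lose the partially trained model
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    "models/catautoencoder_ckpt",   # example path, not from the original notebook
    save_weights_only=True,
    save_best_only=True,
    monitor="val_loss")
# history = autoencoder.fit(train_it, validation_data=test_it, epochs=25,
#                           callbacks=[PlotterCallback(), checkpoint])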
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Training and Validation Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Training Data', 'Validation Data'], loc='upper left')
<matplotlib.legend.Legend at 0x7f9bdb412bb0>
noise = np.random.randn(1, 128)
print(noise)
plt.imshow(noise, cmap='gray')
[[ 0.36365099 -0.57124146 -2.24982094 -0.04572127 -0.7623621 0.27381251 -0.46977491 -0.28089856 -1.84885116 0.89819736 -0.05504846 -0.08784695 -2.45880068 0.19547292 1.11674366 -0.73704871 -0.37318174 1.28677856 -0.5078776 1.86896335 0.38249207 -1.20543122 -0.55466725 -2.23541235 -0.13419347 -1.52287624 -0.69169099 -1.49436769 -1.29809662 -0.64439226 0.70374901 -0.70314049 -0.45782775 -0.36123666 -0.23116846 -0.18226894 0.21925094 -0.65435838 -0.09766458 1.46278022 0.52346102 -0.98701142 0.59982174 -1.29426715 0.41508662 0.06726141 -0.38352217 0.76913507 0.16934828 -1.29085469 -0.96013425 0.4173405 -2.00407608 -0.1325844 -0.1890487 0.26696987 -0.9545931 -0.94540622 0.56162225 -0.04396481 -0.14732489 -0.42755104 -0.03470711 -1.70312179 0.10209181 0.30309277 -0.55009911 0.38287504 -0.99443193 0.63477 -0.43802978 -1.83313165 1.16701813 0.44465108 0.12858884 0.23753583 -0.46203722 -1.32656806 -1.09334076 -0.94791051 1.42706351 -0.26378179 1.28744467 1.04426902 0.87610991 -1.26928409 1.07233458 1.5931745 -1.07138529 -0.78709491 -1.95043178 1.23528422 1.00247891 0.90650039 0.05646435 -1.25317373 0.23526039 -0.60300706 1.02105325 -0.0844902 -0.74507911 0.01892909 1.6358378 -0.15490854 -1.29048747 0.08010853 1.69385517 1.59656401 1.54696693 -0.34664569 0.20329954 1.22005823 -0.53798122 -0.28738952 0.76955664 -0.6904962 -1.20409866 -1.02080907 0.84095057 -0.27847734 0.79985154 -0.4083079 0.17286629 -0.63939139 0.13572822 0.56072906 1.90363916 -0.1986177 ]]
<matplotlib.image.AxesImage at 0x7f9bdb232730>
x = autoencoder.decoder(noise).numpy()
plt.imshow(x[0, :, :, 0], cmap='gray')
<matplotlib.image.AxesImage at 0x7f9cf9506700>
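Beyond decoding raw noise, a common way to probe the learned embedding is to interpolate between the codes of two real images; the sketch below is an addition that assumes the test_it iterator defined earlier and simply takes the first two images of a batch.
# Illustrative latent-space interpolation: encode two validation images and
# decode points along the straight line between their 128-dimensional codes
Xb = next(test_it)[0]
z = autoencoder.encoder(Xb[0:2]).numpy()    # two codes, shape (2, 128)
n_steps = 8
plt.figure(figsize=(20, 3))
for i, t in enumerate(np.linspace(0, 1, n_steps)):
    z_mix = (1 - t)*z[0] + t*z[1]           # linear blend of the two codes
    x_mix = autoencoder.decoder(z_mix[np.newaxis, :]).numpy()
    plt.subplot(1, n_steps, i + 1)
    plt.imshow(x_mix[0, :, :, 0], cmap='gray')
    plt.axis("off")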
autoencoder.save_weights("models/catautoencoder")
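To reuse the model later, the saved weights can be loaded back into a fresh instance; this is a sketch assuming the same constructor arguments used above.
# Example of restoring the saved weights into a new model (not in the original
# notebook); the constructor arguments must match the trained model
restored = AutoencoderCNN(imgres=imgres, d=128, k=3)
restored.load_weights("models/catautoencoder")
_ = restored(np.zeros((1, imgres, imgres, 1), dtype=np.float32))  # run one batch to verify the wiring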