
If I build the decoder as a mirror of the encoder, the output size of its last layer does not match the input size.

This is the model summary:

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv_1_j (Conv2D)            (None, 28, 28, 64)        640       
_________________________________________________________________
batch_normalization_v2 (Batc (None, 28, 28, 64)        256       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 64)        0         
_________________________________________________________________
conv_2_j (Conv2D)            (None, 14, 14, 64)        36928     
_________________________________________________________________
batch_normalization_v2_1 (Ba (None, 14, 14, 64)        256       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 64)          0         
_________________________________________________________________
conv_3_j (Conv2D)            (None, 7, 7, 64)          36928     
_________________________________________________________________
batch_normalization_v2_2 (Ba (None, 7, 7, 64)          256       
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 3, 3, 64)          0         
_________________________________________________________________
conv_4_j (Conv2D)            (None, 3, 3, 64)          36928     
_________________________________________________________________
batch_normalization_v2_3 (Ba (None, 3, 3, 64)          256       
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 1, 1, 64)          0         
_________________________________________________________________
flatten (Flatten)            (None, 64)                0         
_________________________________________________________________
dense_1_j (Dense)            (None, 64)                4160      
_________________________________________________________________
reshape_out (Lambda)         (None, 1, 1, 64)          0         
_________________________________________________________________
conv2d (Conv2D)              (None, 1, 1, 64)          36928     
_________________________________________________________________
batch_normalization_v2_4 (Ba (None, 1, 1, 64)          256       
_________________________________________________________________
up_sampling2d (UpSampling2D) (None, 2, 2, 64)          0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 2, 2, 64)          36928     
_________________________________________________________________
batch_normalization_v2_5 (Ba (None, 2, 2, 64)          256       
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 4, 4, 64)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 4, 4, 64)          36928     
_________________________________________________________________
batch_normalization_v2_6 (Ba (None, 4, 4, 64)          256       
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 8, 8, 64)          0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 8, 8, 64)          36928     
_________________________________________________________________
batch_normalization_v2_7 (Ba (None, 8, 8, 64)          256       
_________________________________________________________________
up_sampling2d_3 (UpSampling2 (None, 16, 16, 64)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 16, 16, 1)         577       
=================================================================
Total params: 265,921
Trainable params: 264,897
Non-trainable params: 1,024
_________________________________________________________________

Code to reproduce:

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.layers import Lambda


def resize(example):
    image = example['image']
    image = tf.image.resize(image, [28, 28])
    image = tf.image.rgb_to_grayscale(image)
    image = image / 255  # scale pixel values to [0, 1]

    example['image'] = image
    return example


def get_tupel(example):
    return example['image'], example['image']  # autoencoder: the target is the input itself


def gen_dataset(dataset, batch_size):
    dataset = dataset.map(resize, num_parallel_calls=4)
    dataset = dataset.map(get_tupel, num_parallel_calls=4)
    dataset = dataset.shuffle(batch_size * 50).repeat()  # infinite stream
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(10)  # prefetch whole batches, after batching
    return dataset


def main():
    builder = tfds.builder("cifar10")
    builder.download_and_prepare()
    datasets = builder.as_dataset()
    train_dataset, test_dataset = datasets['train'], datasets['test']

    batch_size = 48

    train_dataset = gen_dataset(train_dataset, batch_size)
    test_dataset = gen_dataset(test_dataset, batch_size)

    device = '/cpu:0' if not tf.test.is_gpu_available() else tf.test.gpu_device_name()
    print(tf.test.gpu_device_name())
    with tf.device(device):
        filters = 64
        kernel = 3
        pooling = 2
        image_size = 28

        inp_layer = tf.keras.layers.Input(shape=(image_size, image_size, 1))
        cnn_embedding_out = cnn_encoder(inp_layer, filters, kernel, pooling)
        cnn_decoder_out = cnn_decoder(cnn_embedding_out, filters, kernel, pooling)

        model = tf.keras.Model(inputs=inp_layer, outputs=cnn_decoder_out)

        model.compile(optimizer=tf.optimizers.Adam(0.0001), loss='binary_crossentropy',
                      metrics=['accuracy'])

        model.summary()  # summary() prints the table itself; wrapping it in print() adds a stray "None"

        model.fit(train_dataset, validation_data=test_dataset,
                  steps_per_epoch=100,  # 1000
                  validation_steps=100,
                  epochs=150,)


def cnn_encoder(inp_layer, filters, kernel, pooling):
    cnn1 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_1_j')(inp_layer)
    bn1 = tf.keras.layers.BatchNormalization()(cnn1)
    max1 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn1)
    cnn2 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_2_j')(max1)
    bn2 = tf.keras.layers.BatchNormalization()(cnn2)
    max2 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn2)
    cnn3 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_3_j')(max2)
    bn3 = tf.keras.layers.BatchNormalization()(cnn3)
    max3 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn3)
    cnn4 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_4_j')(max3)
    bn4 = tf.keras.layers.BatchNormalization()(cnn4)
    max4 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn4)
    flat = tf.keras.layers.Flatten()(max4)
    fc = tf.keras.layers.Dense(64, name='dense_1_j')(flat)  # this is the encoder layer!

    return fc


def cnn_decoder(inp_layer, filters, kernel, pooling):
    res1 = reshape([1, 1, filters], name="reshape_out")(inp_layer)
    cnn1 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(res1)
    bn1 = tf.keras.layers.BatchNormalization()(cnn1)
    up1 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn1)
    cnn2 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up1)
    bn2 = tf.keras.layers.BatchNormalization()(cnn2)
    up2 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn2)
    cnn3 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up2)
    bn3 = tf.keras.layers.BatchNormalization()(cnn3)
    up3 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn3)
    cnn4 = tf.keras.layers.Conv2D(filters, kernel,  padding="same", activation='relu',)(up3)
    bn4 = tf.keras.layers.BatchNormalization()(cnn4)
    up4 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn4)
    decoded = tf.keras.layers.Conv2D(1, kernel,  padding="same", activation='sigmoid')(up4)
    return decoded


def reshape(dim, name="", complete=False):
    # wraps tf.reshape in a Lambda layer; with complete=False this is
    # equivalent to tf.keras.layers.Reshape(dim), which keeps the batch axis
    def func(x):
        if complete:
            ret = tf.reshape(x, dim)
        else:
            ret = tf.reshape(x, [-1] + dim)
        return ret
    return Lambda(func, name=name)


if __name__ == "__main__":
    main()

I tried using Conv2DTranspose and different upsampling sizes, but this doesn't feel right.

I would expect the output to have the same shape as the input: (48, 28, 28, 1).

What am I doing wrong?

2 Answers


As you can see at reshape_out (Lambda), you start from a spatial size of 1, so operations like UpSampling2D can only reach sizes 2, 4, 8, 16, 32. To keep your approach, you can apply UpSampling2D until you reach size 32 and then shrink back to 28 by applying, for instance, two Conv2D layers with kernel=3 and padding="valid", giving a result of shape (x, 28, 28, 64).
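A minimal sketch of that tail, assuming x is the (None, 16, 16, 64) tensor after the decoder's fourth UpSampling2D (the names here are illustrative):

up5 = tf.keras.layers.UpSampling2D((2, 2))(x)  # (None, 32, 32, 64)
cnn6 = tf.keras.layers.Conv2D(64, 3, padding="valid", activation='relu')(up5)  # (None, 30, 30, 64)
decoded = tf.keras.layers.Conv2D(1, 3, padding="valid", activation='sigmoid')(cnn6)  # (None, 28, 28, 1)

Each padding="valid" convolution with kernel=3 trims one pixel from every border, so the spatial size drops 32 → 30 → 28.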


3 Comments

This works only if I drop the padding="same" flag; is this correct?
Yes. As you can see in your model summary, dimensions 2 and 3 stay the same after a Conv2D with padding="same"; you need padding="valid" so the convolution crops the edges.
Now it trains and the loss is dropping, although it still seems strange to me. Thank you very much for your answer.

I changed the cnn_decoder function:

def cnn_decoder(inp_layer, filters, kernel, pooling):
    res1 = reshape([1, 1, filters], name="reshape_out")(inp_layer)
    cnn1 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(res1)
    bn1 = tf.keras.layers.BatchNormalization()(cnn1)
    up1 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn1)
    cnn2 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up1)
    bn2 = tf.keras.layers.BatchNormalization()(cnn2)
    up2 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn2)
    cnn3 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up2)
    bn3 = tf.keras.layers.BatchNormalization()(cnn3)
    up3 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn3)
    cnn4 = tf.keras.layers.Conv2D(filters, kernel,  padding="same", activation='relu',)(up3)
    bn4 = tf.keras.layers.BatchNormalization()(cnn4)
    up4 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn4)
    cnn5 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu')(up4)
    bn5 = tf.keras.layers.BatchNormalization()(cnn5)
    up5 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn5)  # 16 -> 32
    cnn6 = tf.keras.layers.Conv2D(filters, kernel, padding="valid", activation='relu')(up5)  # 32 -> 30
    bn6 = tf.keras.layers.BatchNormalization()(cnn6)  # was bn5, which shadowed the variable above
    decoded = tf.keras.layers.Conv2D(1, kernel, padding="valid", activation='sigmoid')(bn6)  # 30 -> 28
    return decoded
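
Tracing the spatial size: the five UpSampling2D layers go 1 → 2 → 4 → 8 → 16 → 32, and the two padding="valid" convolutions with kernel=3 then reduce 32 → 30 → 28. A quick sanity check, assuming the cnn_encoder from the question:

inp = tf.keras.layers.Input(shape=(28, 28, 1))
out = cnn_decoder(cnn_encoder(inp, filters=64, kernel=3, pooling=2), filters=64, kernel=3, pooling=2)
assert tf.keras.Model(inp, out).output_shape == (None, 28, 28, 1)  # matches the input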

Does it seem correct to you?
