I have trained an autoencoder and saved it using Keras's built-in save() method. Now I want to split it into two parts: an encoder and a decoder. I can successfully load the model and get the encoder part by creating a new model from the old one:

    encoder_model = keras.models.Model(
        inputs=self.model.input,
        outputs=self.model.get_layer(layer_of_activations).get_output_at(0))

However, I cannot do the equivalent for the decoder. I tried several approaches, none of which worked. Then I found a similar question (Keras replacing input layer) and tried its method with the code below:

    for i, l in enumerate(self.model.layers[0:19]):
        self.model.layers.pop(0)
    newInput = Input(batch_shape=(None, None, None, 64))
    newOutputs = self.model(newInput)
    newModel = keras.models.Model(newInput, newOutputs)

The output shape of the last layer I remove is (None, None, None, 64), but this code produces the following error:

ValueError: number of input channels does not match corresponding dimension of filter, 64 != 3

I assume this is because the input dimensions of the model are not updated after popping the original layers, as noted in the second comment on the first answer to that question (Keras replacing input layer).

Simply looping through the layers and recreating them in a new model does not work as my model is not sequential.

1 Answer

I resolved this by building a new model with exactly the same architecture as the decoder part of the original autoencoder and then copying the weights over.

Here's the code:

    # Build a clean model with exactly the same architecture as the decoder part of the autoencoder
    new_model = nb.build_decoder()

    # Skip the first 19 layers (the encoder part + encoded layer) by slicing
    # rather than popping them, so the loaded model is left untouched
    decoder_layers = self.model.layers[19:]

    # Copy the weights layer by layer onto the new decoder; the i + 1 offset
    # skips the first layer of new_model
    for i, l in enumerate(decoder_layers):
        new_model.layers[i + 1].set_weights(l.get_weights())
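
For anyone who wants to try the idea in isolation, here is a minimal, self-contained sketch. It is not my actual network: the toy autoencoder, the layer names (enc_conv, encoded, dec_conv, dec_up, dec_out), the shapes, and this build_decoder() are all made up for illustration. It also copies the weights by layer name instead of by index, a variation on the code above that may be easier to keep correct when the layer order of a non-sequential model is not obvious.

```python
import numpy as np
import keras
from keras import layers

# Toy autoencoder standing in for the real model (all names and shapes are made up).
inputs = keras.Input(shape=(8, 8, 3))
x = layers.Conv2D(64, 3, padding="same", activation="relu", name="enc_conv")(inputs)
encoded = layers.MaxPooling2D(2, name="encoded")(x)
x = layers.Conv2D(64, 3, padding="same", activation="relu", name="dec_conv")(encoded)
x = layers.UpSampling2D(2, name="dec_up")(x)
outputs = layers.Conv2D(3, 3, padding="same", activation="sigmoid", name="dec_out")(x)
autoencoder = keras.Model(inputs, outputs)


def build_decoder():
    """Rebuild only the decoder half, reusing the original layer names."""
    dec_in = keras.Input(shape=(4, 4, 64))  # shape of the "encoded" output
    d = layers.Conv2D(64, 3, padding="same", activation="relu", name="dec_conv")(dec_in)
    d = layers.UpSampling2D(2, name="dec_up")(d)
    d = layers.Conv2D(3, 3, padding="same", activation="sigmoid", name="dec_out")(d)
    return keras.Model(dec_in, d)


decoder = build_decoder()

# Copy weights from the layer with the same name in the autoencoder.
# Layers without weights (the new input layer, UpSampling2D) are skipped.
for layer in decoder.layers:
    if layer.weights:
        layer.set_weights(autoencoder.get_layer(layer.name).get_weights())

# Sanity check: encoder followed by the rebuilt decoder should reproduce
# the output of the full autoencoder.
encoder = keras.Model(inputs, encoded)
batch = np.random.rand(2, 8, 8, 3).astype("float32")
reconstructed = decoder.predict(encoder.predict(batch, verbose=0), verbose=0)
expected = autoencoder.predict(batch, verbose=0)
```

If the weights were copied correctly, reconstructed and expected should agree up to floating-point tolerance, which can be checked with np.allclose.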