python – Tensorflow Federated Learning on ResNet fails

I am doing some experiments with the TensorFlow Federated Learning API. Currently I am trying to train a simple ResNet on 10 clients. Based on the data and metrics, the training seems to be successful, but the evaluation, both local and federated, fails.

Does anyone have any advice?

The model:

def create_keras_resnet_model(): 

    inputs = tf.keras.layers.Input(shape=(28,28,1))
    bn0 = tf.keras.layers.BatchNormalization(scale=True)(inputs)

    conv1 = tf.keras.layers.Conv2D(filters=32, 
                               kernel_size=(7,7), 
                               padding='same', 
                               activation='relu', 
                               kernel_initializer="uniform")(bn0)
    conv1 = tf.keras.layers.Conv2D(filters=32, 
                               kernel_size=(7,7), 
                               padding='same', 
                               activation='relu', 
                               kernel_initializer="uniform")(conv1)
    bn1 = tf.keras.layers.BatchNormalization(scale=True)(conv1)
    max_pool1 = tf.keras.layers.MaxPooling2D(pool_size=(2,2), padding = 'same')(bn1)


    conv2 = tf.keras.layers.Conv2D(filters=32, 
                               kernel_size=(5,5), 
                               padding='same', 
                               activation='relu', 
                               kernel_initializer="uniform")(max_pool1)
    conv2 = tf.keras.layers.Conv2D(filters=32, 
                               kernel_size=(5,5), 
                               padding='same', 
                               activation='relu', 
                               kernel_initializer="uniform")(conv2)
    conv2 = tf.keras.layers.Conv2D(filters=32, 
                               kernel_size=(5,5), 
                               padding='same', 
                               activation='relu', 
                               kernel_initializer="uniform")(conv2)
    bn2 = tf.keras.layers.BatchNormalization(scale=True)(conv2)

    res1_conv = tf.keras.layers.Conv2D(filters = 32,
                                  kernel_size = (3,3),
                                  padding = 'same',
                                  kernel_initializer="uniform")(max_pool1)
    res1_bn = tf.keras.layers.BatchNormalization(scale=True)(res1_conv)

    add1 = tf.keras.layers.Add()([res1_bn, bn2])


    act1 = tf.keras.layers.Activation('relu')(add1)
    max_pool2 = tf.keras.layers.MaxPooling2D(pool_size=(2,2), padding = 'same')(act1)

    conv3 = tf.keras.layers.Conv2D(filters=32, 
                               kernel_size=(5,5), 
                               padding='same', 
                               activation='relu', 
                               kernel_initializer="uniform")(max_pool2)
    conv3 = tf.keras.layers.Conv2D(filters=32, 
                               kernel_size=(5,5), 
                               padding='same', 
                               activation='relu', 
                               kernel_initializer="uniform")(conv3)
    conv3 = tf.keras.layers.Conv2D(filters=32, 
                               kernel_size=(5,5), 
                               padding='same', 
                               activation='relu', 
                               kernel_initializer="uniform")(conv3)
    bn3 = tf.keras.layers.BatchNormalization(scale=True)(conv3)

    res2_conv = tf.keras.layers.Conv2D(filters = 32,
                                  kernel_size = (3,3),
                                  padding = 'same',
                                  kernel_initializer="uniform")(max_pool2)
    res2_bn = tf.keras.layers.BatchNormalization(scale=True)(res2_conv)

    add2 = tf.keras.layers.Add()([res2_bn, bn3])

    act2 = tf.keras.layers.Activation('relu')(add2)
    max_pool3 = tf.keras.layers.MaxPooling2D(pool_size=(2,2), padding = 'same')(act2)

    flatten = tf.keras.layers.Flatten()(max_pool3)

    dense1 = tf.keras.layers.Dense(128, activation='relu')(flatten)

    do = tf.keras.layers.Dropout(0.20)(dense1)
    dense2 = tf.keras.layers.Dense(10, activation='softmax')(do)

    model = tf.keras.models.Model(inputs=[inputs], outputs=[dense2])

    return model

The model is just a simple ResNet. For training I use the TensorFlow Federated simulation dataset for EMNIST, here with 10 clients for 10 epochs.

Everything looks fine so far…

I have adjusted the provided function for preparing the data. I have already tested the whole process with a simple CNN, and everything works quite well.

def preprocess(dataset):

    def batch_format_fn(element):
        return collections.OrderedDict(
            x=tf.reshape(element['pixels'], [-1, 28, 28, 1]),
            y=tf.reshape(element['label'], [-1, 1])
        )

    return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=42).batch(BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)
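To illustrate what batch_format_fn produces, here is the same reshape applied to a toy batch with NumPy (a sketch; the real elements come from the EMNIST client datasets, and BATCH_SIZE = 20 is just an assumed value of my constant):

```python
import numpy as np

BATCH_SIZE = 20  # assumed value of the BATCH_SIZE constant
pixels = np.random.rand(BATCH_SIZE, 28, 28).astype(np.float32)  # like element['pixels']
labels = np.random.randint(0, 10, size=(BATCH_SIZE,))           # like element['label']

x = pixels.reshape(-1, 28, 28, 1)  # add the channel axis the Conv2D input expects
y = labels.reshape(-1, 1)          # labels as a (batch, 1) column of integer ids

print(x.shape, y.shape)  # (20, 28, 28, 1) (20, 1)
```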

Running the evaluation with TensorFlow shows a strange result: the accuracy stays at around 11 percent and the loss is somewhere between 7 and 8.

If I copy the weights to a local model and evaluate locally, I get the same result. If I try to predict a single image from the test data, an exception is thrown:

ValueError: Input 0 of layer dense_10 is incompatible with the layer: expected axis -1 of input shape to have value 512 but received input with shape (None, 128)

Here is the model summary:

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_13 (InputLayer)           [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
batch_normalization_72 (BatchNo (None, 28, 28, 1)    4           input_13[0][0]                   
__________________________________________________________________________________________________
conv2d_120 (Conv2D)             (None, 28, 28, 32)   1600        batch_normalization_72[0][0]     
__________________________________________________________________________________________________
conv2d_121 (Conv2D)             (None, 28, 28, 32)   50208       conv2d_120[0][0]                 
__________________________________________________________________________________________________
batch_normalization_73 (BatchNo (None, 28, 28, 32)   128         conv2d_121[0][0]                 
__________________________________________________________________________________________________
max_pooling2d_36 (MaxPooling2D) (None, 14, 14, 32)   0           batch_normalization_73[0][0]     
__________________________________________________________________________________________________
conv2d_122 (Conv2D)             (None, 14, 14, 32)   25632       max_pooling2d_36[0][0]           
__________________________________________________________________________________________________
conv2d_123 (Conv2D)             (None, 14, 14, 32)   25632       conv2d_122[0][0]                 
__________________________________________________________________________________________________
conv2d_125 (Conv2D)             (None, 14, 14, 32)   9248        max_pooling2d_36[0][0]           
__________________________________________________________________________________________________
conv2d_124 (Conv2D)             (None, 14, 14, 32)   25632       conv2d_123[0][0]                 
__________________________________________________________________________________________________
batch_normalization_75 (BatchNo (None, 14, 14, 32)   128         conv2d_125[0][0]                 
__________________________________________________________________________________________________
batch_normalization_74 (BatchNo (None, 14, 14, 32)   128         conv2d_124[0][0]                 
__________________________________________________________________________________________________
add_24 (Add)                    (None, 14, 14, 32)   0           batch_normalization_75[0][0]     
                                                                 batch_normalization_74[0][0]     
__________________________________________________________________________________________________
activation_24 (Activation)      (None, 14, 14, 32)   0           add_24[0][0]                     
__________________________________________________________________________________________________
max_pooling2d_37 (MaxPooling2D) (None, 7, 7, 32)     0           activation_24[0][0]              
__________________________________________________________________________________________________
conv2d_126 (Conv2D)             (None, 7, 7, 32)     25632       max_pooling2d_37[0][0]           
__________________________________________________________________________________________________
conv2d_127 (Conv2D)             (None, 7, 7, 32)     25632       conv2d_126[0][0]                 
__________________________________________________________________________________________________
conv2d_129 (Conv2D)             (None, 7, 7, 32)     9248        max_pooling2d_37[0][0]           
__________________________________________________________________________________________________
conv2d_128 (Conv2D)             (None, 7, 7, 32)     25632       conv2d_127[0][0]                 
__________________________________________________________________________________________________
batch_normalization_77 (BatchNo (None, 7, 7, 32)     128         conv2d_129[0][0]                 
__________________________________________________________________________________________________
batch_normalization_76 (BatchNo (None, 7, 7, 32)     128         conv2d_128[0][0]                 
__________________________________________________________________________________________________
add_25 (Add)                    (None, 7, 7, 32)     0           batch_normalization_77[0][0]     
                                                                 batch_normalization_76[0][0]     
__________________________________________________________________________________________________
activation_25 (Activation)      (None, 7, 7, 32)     0           add_25[0][0]                     
__________________________________________________________________________________________________
max_pooling2d_38 (MaxPooling2D) (None, 4, 4, 32)     0           activation_25[0][0]              
__________________________________________________________________________________________________
flatten_12 (Flatten)            (None, 512)          0           max_pooling2d_38[0][0]           
__________________________________________________________________________________________________
dense_24 (Dense)                (None, 128)          65664       flatten_12[0][0]                 
__________________________________________________________________________________________________
dropout_12 (Dropout)            (None, 128)          0           dense_24[0][0]                   
__________________________________________________________________________________________________
dense_25 (Dense)                (None, 10)           1290        dropout_12[0][0]                 
==================================================================================================
Total params: 291,694
Trainable params: 291,372
Non-trainable params: 322
__________________________________________________________________________________________________
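As a sanity check, the summary's numbers are internally consistent with a (28, 28, 1) input: three same-padded 2x2 poolings take 28 → 14 → 7 → 4, so Flatten produces 4·4·32 = 512 features, exactly what the final dense layers expect (a quick arithmetic sketch):

```python
import math

# Spatial size after three MaxPooling2D(pool_size=(2, 2), padding='same') layers
size = 28
for _ in range(3):
    size = math.ceil(size / 2)  # 'same' padding rounds up
flatten_features = size * size * 32
print(flatten_features)  # 512

# Parameter counts from the summary, recomputed:
conv1_params = (7 * 7 * 1) * 32 + 32          # conv2d_120: 1600
dense1_params = flatten_features * 128 + 128  # dense_24: 65664
print(conv1_params, dense1_params)  # 1600 65664
```

So when predict reports a 128-feature input to the dense layer, the feature map reaching Flatten was apparently not 4x4x32, which suggests the tensor handed to predict did not have the expected (batch, 28, 28, 1) shape.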

I did not convert the labels with the to_categorical function from the Keras utils package. But why does the exception say the input of the dense layer is wrong? And why does the training work?
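For reference, integer labels of shape (batch, 1), as my pipeline emits, match SparseCategoricalCrossentropy; to_categorical would only be needed with CategoricalCrossentropy. A minimal NumPy stand-in, just to show the two label formats side by side (a sketch; to_one_hot is my illustrative helper, not a library function):

```python
import numpy as np

def to_one_hot(labels, num_classes=10):
    # minimal stand-in for tf.keras.utils.to_categorical
    labels = np.asarray(labels).reshape(-1)
    one_hot = np.zeros((labels.size, num_classes), dtype=np.float32)
    one_hot[np.arange(labels.size), labels] = 1.0
    return one_hot

sparse = np.array([[3], [0], [9]])        # what my pipeline emits: (batch, 1) ints
dense = to_one_hot(sparse)                # what CategoricalCrossentropy would need
print(dense.shape, dense.argmax(axis=1))  # (3, 10) [3 0 9]
```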
