From 235ca351c5de637061c96390186987412f2c7625 Mon Sep 17 00:00:00 2001 From: Taehoon Lee Date: Mon, 16 Jul 2018 10:56:23 +0900 Subject: [PATCH] Update `kernel_initializer` for `ResNet50` --- keras_applications/resnet50.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/keras_applications/resnet50.py b/keras_applications/resnet50.py index c4ca2ea..ff5f8ca 100644 --- a/keras_applications/resnet50.py +++ b/keras_applications/resnet50.py @@ -59,16 +59,21 @@ def identity_block(input_tensor, kernel_size, filters, stage, block): bn_name_base = 'bn' + str(stage) + block + '_branch' x = layers.Conv2D(filters1, (1, 1), + kernel_initializer='he_normal', name=conv_name_base + '2a')(input_tensor) x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) x = layers.Activation('relu')(x) x = layers.Conv2D(filters2, kernel_size, - padding='same', name=conv_name_base + '2b')(x) + padding='same', + kernel_initializer='he_normal', + name=conv_name_base + '2b')(x) x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) x = layers.Activation('relu')(x) - x = layers.Conv2D(filters3, (1, 1), name=conv_name_base + '2c')(x) + x = layers.Conv2D(filters3, (1, 1), + kernel_initializer='he_normal', + name=conv_name_base + '2c')(x) x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) x = layers.add([x, input_tensor]) @@ -109,19 +114,24 @@ def conv_block(input_tensor, bn_name_base = 'bn' + str(stage) + block + '_branch' x = layers.Conv2D(filters1, (1, 1), strides=strides, + kernel_initializer='he_normal', name=conv_name_base + '2a')(input_tensor) x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) x = layers.Activation('relu')(x) x = layers.Conv2D(filters2, kernel_size, padding='same', + kernel_initializer='he_normal', name=conv_name_base + '2b')(x) x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) x = layers.Activation('relu')(x) - x = layers.Conv2D(filters3, (1, 1), name=conv_name_base + '2c')(x) + x = layers.Conv2D(filters3, (1, 1), + kernel_initializer='he_normal', + name=conv_name_base + '2c')(x) x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) shortcut = layers.Conv2D(filters3, (1, 1), strides=strides, + kernel_initializer='he_normal', name=conv_name_base + '1')(input_tensor) shortcut = layers.BatchNormalization( axis=bn_axis, name=bn_name_base + '1')(shortcut) @@ -214,6 +224,7 @@ def ResNet50(include_top=True, x = layers.Conv2D(64, (7, 7), strides=(2, 2), padding='valid', + kernel_initializer='he_normal', name='conv1')(x) x = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')(x) x = layers.Activation('relu')(x)