Skip to content
This repository has been archived by the owner on Nov 3, 2022. It is now read-only.

Commit

Permalink
Update kernel_initializer for ResNet50
Browse files Browse the repository at this point in the history
  • Loading branch information
taehoonlee committed Jul 16, 2018
1 parent fbf035b commit 235ca35
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions keras_applications/resnet50.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,21 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
bn_name_base = 'bn' + str(stage) + block + '_branch'

x = layers.Conv2D(filters1, (1, 1),
kernel_initializer='he_normal',
name=conv_name_base + '2a')(input_tensor)
x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
x = layers.Activation('relu')(x)

x = layers.Conv2D(filters2, kernel_size,
padding='same', name=conv_name_base + '2b')(x)
padding='same',
kernel_initializer='he_normal',
name=conv_name_base + '2b')(x)
x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
x = layers.Activation('relu')(x)

x = layers.Conv2D(filters3, (1, 1), name=conv_name_base + '2c')(x)
x = layers.Conv2D(filters3, (1, 1),
kernel_initializer='he_normal',
name=conv_name_base + '2c')(x)
x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)

x = layers.add([x, input_tensor])
Expand Down Expand Up @@ -109,19 +114,24 @@ def conv_block(input_tensor,
bn_name_base = 'bn' + str(stage) + block + '_branch'

x = layers.Conv2D(filters1, (1, 1), strides=strides,
kernel_initializer='he_normal',
name=conv_name_base + '2a')(input_tensor)
x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
x = layers.Activation('relu')(x)

x = layers.Conv2D(filters2, kernel_size, padding='same',
kernel_initializer='he_normal',
name=conv_name_base + '2b')(x)
x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
x = layers.Activation('relu')(x)

x = layers.Conv2D(filters3, (1, 1), name=conv_name_base + '2c')(x)
x = layers.Conv2D(filters3, (1, 1),
kernel_initializer='he_normal',
name=conv_name_base + '2c')(x)
x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)

shortcut = layers.Conv2D(filters3, (1, 1), strides=strides,
kernel_initializer='he_normal',
name=conv_name_base + '1')(input_tensor)
shortcut = layers.BatchNormalization(
axis=bn_axis, name=bn_name_base + '1')(shortcut)
Expand Down Expand Up @@ -214,6 +224,7 @@ def ResNet50(include_top=True,
x = layers.Conv2D(64, (7, 7),
strides=(2, 2),
padding='valid',
kernel_initializer='he_normal',
name='conv1')(x)
x = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
x = layers.Activation('relu')(x)
Expand Down

0 comments on commit 235ca35

Please sign in to comment.