Port "Using the Forward-Forward Algorithm for Image Classification" to Keras 3.0 (TensorFlow backend only) #1932

Open · wants to merge 1 commit into base: master
88 changes: 50 additions & 38 deletions examples/vision/forwardforward.py
@@ -2,7 +2,7 @@
Title: Using the Forward-Forward Algorithm for Image Classification
Author: [Suvaditya Mukherjee](https://twitter.com/halcyonrayes)
Date created: 2023/01/08
Last modified: 2023/01/08
Last modified: 2024/09/17
Description: Training a Dense-layer model using the Forward-Forward algorithm.
Accelerator: GPU
"""
@@ -59,9 +59,13 @@
"""
## Setup imports
"""
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
from tensorflow import keras
import keras
from keras import ops
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
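
The new imports pin the backend before `keras` is loaded: in Keras 3 the backend is chosen at import time from the `KERAS_BACKEND` environment variable and cannot be switched later in the same process. A minimal sketch of that ordering, assuming a Keras 3 installation with TensorFlow available:

    import os

    # Must happen before the first `import keras`; changing the variable
    # afterwards has no effect in the same Python process.
    os.environ["KERAS_BACKEND"] = "tensorflow"

    import keras

    print(keras.backend.backend())  # expected to print "tensorflow"
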
@@ -143,7 +147,7 @@ class FFDense(keras.layers.Layer):
def __init__(
self,
units,
optimizer,
init_optimizer,
loss_metric,
num_epochs=50,
use_bias=True,
@@ -163,7 +167,7 @@ def __init__(
bias_regularizer=bias_regularizer,
)
self.relu = keras.layers.ReLU()
self.optimizer = optimizer
self.optimizer = init_optimizer()
self.loss_metric = loss_metric
self.threshold = 1.5
self.num_epochs = num_epochs
@@ -172,7 +176,7 @@ def call(self, x):
# layer.

def call(self, x):
x_norm = tf.norm(x, ord=2, axis=1, keepdims=True)
x_norm = ops.norm(x, ord=2, axis=1, keepdims=True)
x_norm = x_norm + 1e-4
x_dir = x / x_norm
res = self.dense(x_dir)
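
`call` divides the input by its L2 norm (plus a small epsilon) before the dense layer, so only the direction of the previous layer's activations is passed forward and its goodness (the magnitude) cannot leak into the next layer. A small sketch of that normalization with `keras.ops`, on made-up activations:

    import numpy as np
    from keras import ops

    # Two hypothetical activation vectors with the same direction but very
    # different magnitudes (i.e. very different "goodness").
    x = ops.convert_to_tensor(
        np.array([[3.0, 4.0], [30.0, 40.0]], dtype="float32")
    )

    x_norm = ops.norm(x, ord=2, axis=1, keepdims=True) + 1e-4
    x_dir = x / x_norm

    # Both rows collapse to roughly the same unit vector, so only the
    # direction reaches the next layer.
    print(ops.convert_to_numpy(x_dir))
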
@@ -192,22 +196,24 @@ def call(self, x):
def forward_forward(self, x_pos, x_neg):
for i in range(self.num_epochs):
with tf.GradientTape() as tape:
g_pos = tf.math.reduce_mean(tf.math.pow(self.call(x_pos), 2), 1)
g_neg = tf.math.reduce_mean(tf.math.pow(self.call(x_neg), 2), 1)
g_pos = ops.mean(ops.power(self.call(x_pos), 2), 1)
g_neg = ops.mean(ops.power(self.call(x_neg), 2), 1)

loss = tf.math.log(
loss = ops.log(
1
+ tf.math.exp(
tf.concat([-g_pos + self.threshold, g_neg - self.threshold], 0)
+ ops.exp(
ops.concatenate(
[-g_pos + self.threshold, g_neg - self.threshold], 0
)
)
)
mean_loss = tf.cast(tf.math.reduce_mean(loss), tf.float32)
mean_loss = ops.cast(ops.mean(loss), dtype="float32")
self.loss_metric.update_state([mean_loss])
gradients = tape.gradient(mean_loss, self.dense.trainable_weights)
self.optimizer.apply_gradients(zip(gradients, self.dense.trainable_weights))
return (
tf.stop_gradient(self.call(x_pos)),
tf.stop_gradient(self.call(x_neg)),
ops.stop_gradient(self.call(x_pos)),
ops.stop_gradient(self.call(x_neg)),
self.loss_metric.result(),
)
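
The per-layer loss above is the softplus of two margins: the positive sample's goodness (mean squared activation) should exceed the threshold, and the negative sample's goodness should fall below it. A standalone sketch of the same computation with `keras.ops`, on made-up activations:

    import numpy as np
    from keras import ops

    def ff_layer_loss(h_pos, h_neg, threshold=1.5):
        # Goodness = mean of squared activations per sample.
        g_pos = ops.mean(ops.power(h_pos, 2), 1)
        g_neg = ops.mean(ops.power(h_neg, 2), 1)
        # log(1 + exp(m)) is the softplus of the margin m, so the loss falls
        # as positive goodness rises above the threshold and negative
        # goodness drops below it.
        margins = ops.concatenate([-g_pos + threshold, g_neg - threshold], 0)
        return ops.mean(ops.log(1.0 + ops.exp(margins)))

    # Hypothetical activations: 4 positive and 4 negative samples, 8 units each.
    h_pos = np.random.rand(4, 8).astype("float32")
    h_neg = np.random.rand(4, 8).astype("float32")
    print(float(ff_layer_loss(h_pos, h_neg)))
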

@@ -248,25 +254,24 @@ class FFNetwork(keras.Model):
# the `Adam` optimizer with a default learning rate of 0.03 as that was
# found to be the best rate after experimentation.
# Loss is tracked using `loss_var` and `loss_count` variables.
# Use legacy optimizer for Layer Optimizer to fix issue
# https://github.com/keras-team/keras-io/issues/1241

def __init__(
self,
dims,
layer_optimizer=keras.optimizers.legacy.Adam(learning_rate=0.03),
init_layer_optimizer=lambda: keras.optimizers.Adam(learning_rate=0.03),
**kwargs,
):
super().__init__(**kwargs)
self.layer_optimizer = layer_optimizer
self.loss_var = tf.Variable(0.0, trainable=False, dtype=tf.float32)
self.loss_count = tf.Variable(0.0, trainable=False, dtype=tf.float32)
self.init_layer_optimizer = init_layer_optimizer
self.loss_var = keras.Variable(0.0, trainable=False, dtype="float32")
self.loss_count = keras.Variable(0.0, trainable=False, dtype="float32")
self.layer_list = [keras.Input(shape=(dims[0],))]
self.metrics_built = False
for d in range(len(dims) - 1):
self.layer_list += [
FFDense(
dims[d + 1],
optimizer=self.layer_optimizer,
init_optimizer=self.init_layer_optimizer,
loss_metric=keras.metrics.Mean(),
)
]
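
Both constructors now take a factory (`init_optimizer` / `init_layer_optimizer`) rather than a shared optimizer instance, so each `FFDense` layer builds its own `Adam` state over its own trainable weights, presumably to avoid sharing one Keras 3 optimizer's internal state across layers with disjoint variable sets (the situation the removed legacy-optimizer comment pointed at). A tiny sketch of the factory pattern, with a hypothetical name:

    import keras

    # Hypothetical factory: called once per FFDense layer, so each layer owns
    # an independent Adam instance and independent slot variables.
    make_layer_optimizer = lambda: keras.optimizers.Adam(learning_rate=0.03)

    opt_a = make_layer_optimizer()
    opt_b = make_layer_optimizer()
    assert opt_a is not opt_b  # two distinct optimizer states
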
@@ -280,9 +285,9 @@ def __init__(
@tf.function(reduce_retracing=True)
def overlay_y_on_x(self, data):
X_sample, y_sample = data
max_sample = tf.reduce_max(X_sample, axis=0, keepdims=True)
max_sample = tf.cast(max_sample, dtype=tf.float64)
X_zeros = tf.zeros([10], dtype=tf.float64)
max_sample = ops.amax(X_sample, axis=0, keepdims=True)
max_sample = ops.cast(max_sample, dtype="float64")
X_zeros = ops.zeros([10], dtype="float64")
X_update = xla.dynamic_update_slice(X_zeros, max_sample, [y_sample])
X_sample = xla.dynamic_update_slice(X_sample, X_update, [0])
return X_sample, y_sample
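
`overlay_y_on_x` stamps the label onto the flattened image: the first 10 pixels are zeroed and position `y_sample` is set to the image's maximum pixel value. A plain NumPy stand-in for the `xla.dynamic_update_slice` version above (hypothetical helper, not the code the example uses):

    import numpy as np

    def overlay_label(x_flat, label, num_classes=10):
        # One-hot marker in the first `num_classes` pixels, whose "hot" value
        # equals the sample's maximum pixel intensity.
        x_flat = x_flat.copy()
        marker = np.zeros(num_classes, dtype=x_flat.dtype)
        marker[label] = x_flat.max()
        x_flat[:num_classes] = marker
        return x_flat

    sample = np.random.rand(784)  # hypothetical flattened 28x28 image
    print(overlay_label(sample, label=3)[:10])
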
@@ -297,25 +302,23 @@ def overlay_y_on_x(self, data):
@tf.function(reduce_retracing=True)
def predict_one_sample(self, x):
goodness_per_label = []
x = tf.reshape(x, [tf.shape(x)[0] * tf.shape(x)[1]])
x = ops.reshape(x, [ops.shape(x)[0] * ops.shape(x)[1]])
for label in range(10):
h, label = self.overlay_y_on_x(data=(x, label))
h = tf.reshape(h, [-1, tf.shape(h)[0]])
h = ops.reshape(h, [-1, ops.shape(h)[0]])
goodness = []
for layer_idx in range(1, len(self.layer_list)):
layer = self.layer_list[layer_idx]
h = layer(h)
goodness += [tf.math.reduce_mean(tf.math.pow(h, 2), 1)]
goodness_per_label += [
tf.expand_dims(tf.reduce_sum(goodness, keepdims=True), 1)
]
goodness += [ops.mean(ops.power(h, 2), 1)]
goodness_per_label += [ops.expand_dims(ops.sum(goodness, keepdims=True), 1)]
goodness_per_label = tf.concat(goodness_per_label, 1)
return tf.cast(tf.argmax(goodness_per_label, 1), tf.float64)
return ops.cast(ops.argmax(goodness_per_label, 1), dtype="float64")

def predict(self, data):
x = data
preds = list()
preds = tf.map_fn(fn=self.predict_one_sample, elems=x)
preds = ops.vectorized_map(fn=self.predict_one_sample, elems=x)
return np.asarray(preds, dtype=int)

# This custom `train_step` function overrides the internal `train_step`
@@ -328,14 +331,23 @@ def predict(self, data):
# the Forward-Forward computation on it. The returned loss is the final
# loss value over all the layers.

@tf.function(jit_compile=True)
@tf.function(jit_compile=False)
def train_step(self, data):
x, y = data

if not self.metrics_built:
# build metrics to ensure they can be queried without erroring out.
# We can't update the metrics' state, as we would usually do, since
# we do not perform predictions within the train step
for metric in self.metrics:
if hasattr(metric, "build"):
metric.build(y, y)
self.metrics_built = True

# Flatten op
x = tf.reshape(x, [-1, tf.shape(x)[1] * tf.shape(x)[2]])
x = ops.reshape(x, [-1, ops.shape(x)[1] * ops.shape(x)[2]])

x_pos, y = tf.map_fn(fn=self.overlay_y_on_x, elems=(x, y))
x_pos, y = ops.vectorized_map(fn=self.overlay_y_on_x, elems=(x, y))

random_y = tf.random.shuffle(y)
x_neg, y = tf.map_fn(fn=self.overlay_y_on_x, elems=(x, random_y))
Expand All @@ -351,7 +363,7 @@ def train_step(self, data):
else:
print(f"Passing layer {idx+1} now : ")
x = layer(x)
mean_res = tf.math.divide(self.loss_var, self.loss_count)
mean_res = ops.divide(self.loss_var, self.loss_count)
return {"FinalLoss": mean_res}


@@ -386,8 +398,8 @@ def train_step(self, data):
model.compile(
optimizer=keras.optimizers.Adam(learning_rate=0.03),
loss="mse",
jit_compile=True,
metrics=[keras.metrics.Mean()],
jit_compile=False,
metrics=[],
)

epochs = 250
@@ -400,7 +412,7 @@ def train_step(self, data):
test set. We calculate the Accuracy Score to understand the results closely.
"""

preds = model.predict(tf.convert_to_tensor(x_test))
preds = model.predict(ops.convert_to_tensor(x_test))

preds = preds.reshape((preds.shape[0], preds.shape[1]))
