Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add Numpy Backend #87

Closed
wants to merge 54 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
648c4ad
chore: adding numpy backend
ariG23498 May 4, 2023
d9627f3
creview comments
ariG23498 May 4, 2023
df8ec30
review comments
ariG23498 May 4, 2023
2309cb6
Merge branch 'aritra-np-backend' of https://github.com/keras-team/ker…
ariG23498 May 4, 2023
104571b
Merge branch 'main' into aritra-np-backend
ariG23498 May 5, 2023
2b4dbde
chore: adding math
ariG23498 May 5, 2023
99bca4a
Merge branch 'main' into aritra-np-backend
ariG23498 May 8, 2023
155a5b6
chore: adding random module
ariG23498 May 8, 2023
592fd6c
chore: adding ranndom in init
ariG23498 May 10, 2023
e97b21e
Merge branch 'main' into aritra-np-backend
ariG23498 May 10, 2023
8ce4450
review comments
ariG23498 May 10, 2023
65c8076
chore: adding numpy and nn for numpy backend
ariG23498 May 10, 2023
8f5dd4d
chore: adding generic pool, max, and average pool
ariG23498 May 18, 2023
8f22906
chore: adding the conv ops
ariG23498 May 18, 2023
b2c3184
Merge branch 'main' into aritra-np-backend
ariG23498 Jun 5, 2023
97f4e9e
chore: reformat code and using jax for conv and pool
ariG23498 Jun 5, 2023
d650587
chore: added self value
ariG23498 Jun 6, 2023
408f3e8
chore: activation tests pass
ariG23498 Jun 7, 2023
a846b34
chore: adding post build method
ariG23498 Jun 9, 2023
3e0283a
Merge branch 'main' into aritra-np-backend
ariG23498 Jun 9, 2023
5de62ef
Merge branch 'main' into aritra-np-backend
ariG23498 Jun 19, 2023
7143f06
chore: adding necessaity methods to the numpy trainer
ariG23498 Jun 19, 2023
5b4d800
chore: fixing utils test
ariG23498 Jun 19, 2023
4586e33
chore: fixing losses test suite
ariG23498 Jun 21, 2023
8373103
Merge branch 'main' into aritra-np-backend
ariG23498 Jun 21, 2023
82a1e5c
chore: fix backend tests
ariG23498 Jun 21, 2023
295e0e4
chore: fixing initializers test
ariG23498 Jun 21, 2023
dad6c9d
Merge branch 'main' into aritra-np-backend
ariG23498 Jun 22, 2023
55b5e09
chore: fixing accuracy metrics test
ariG23498 Jun 22, 2023
c66abaf
chore: fixing ops test
ariG23498 Jun 22, 2023
ca4869b
chore: review comments
ariG23498 Jun 27, 2023
926e169
chore: init with image and fixing random tests
ariG23498 Jun 27, 2023
23cd5b3
chore: skipping random seed set for numpy backend
ariG23498 Jun 27, 2023
52f8677
Merge branch 'main' into aritra-np-backend
ariG23498 Jun 30, 2023
e013d7d
chore: adding single resize image method
ariG23498 Jun 30, 2023
f6073cd
Merge branch 'main' into aritra-np-backend
ariG23498 Jul 7, 2023
17a5dda
chore: skipping tests for applications and layers
ariG23498 Jul 7, 2023
512c441
chore: skipping tests for models
ariG23498 Jul 7, 2023
f6f6442
chore: skipping testsor saving
ariG23498 Jul 7, 2023
bd38a79
chore: skipping tests for trainers
ariG23498 Jul 7, 2023
e29a54e
chore:ixing one hot
ariG23498 Jul 8, 2023
5694b25
Merge branch 'main' into aritra-np-backend
ariG23498 Jul 8, 2023
9d639cd
chore: fixing vmap in numpy and metrics test
ariG23498 Jul 8, 2023
6c8293b
chore: adding a wrapper to numpy sum, started fixing layer tests
ariG23498 Jul 8, 2023
f007fe0
fix: is_tensor now accepts numpy scalars
ariG23498 Jul 10, 2023
95abe6e
chore: adding draw seed
ariG23498 Jul 11, 2023
3547edc
Merge branch 'main' into aritra-np-backend
ariG23498 Jul 11, 2023
5bedccf
fix: warn message for numpy masking
ariG23498 Jul 11, 2023
f103ae0
fix: checking whether kernel are tensors
ariG23498 Jul 11, 2023
fe6bcf6
chore: adding rnn
ariG23498 Jul 11, 2023
360d913
chore: adding dynamic backend for numpy
ariG23498 Jul 11, 2023
b78500c
fix: axis cannot be None for normalize
ariG23498 Jul 11, 2023
13f256c
chore: adding jax resize for numpy image
ariG23498 Jul 11, 2023
9a88fa7
Merge branch 'main' into aritra-np-backend
ariG23498 Jul 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions keras_core/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,8 @@
elif backend() == "jax":
print_msg("Using JAX backend.")
from keras_core.backend.jax import * # noqa: F403
elif backend() == "numpy":
print_msg("Using NumPy backend.")
from keras_core.backend.numpy import * # noqa: F403
else:
raise ValueError(f"Unable to import backend : {backend()}")
150 changes: 150 additions & 0 deletions keras_core/backend/numpy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import numpy as np
from tensorflow import nest

from keras_core.backend.common import KerasVariable
from keras_core.backend.common import get_autocast_scope
from keras_core.backend.common import standardize_dtype
from keras_core.backend.common import standardize_shape
from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.backend.common.stateless_scope import StatelessScope
from keras_core.backend.common.stateless_scope import get_stateless_scope
from keras_core.backend.common.stateless_scope import in_stateless_scope
# from keras_core.backend.numpy import math
# from keras_core.backend.numpy import nn
# from keras_core.backend.numpy import numpy
# from keras_core.backend.numpy import random
from keras_core.utils.naming import auto_name

DYNAMIC_SHAPES_OK = False # Dynamic shapes NG


def convert_to_tensor(x, dtype=None):
if dtype is not None:
dtype = standardize_dtype(dtype)
if isinstance(x, Variable):
if dtype and dtype != x.dtype:
return x.value.astype(dtype)
return x.value
return np.array(x, dtype=dtype)


def is_tensor(x):
if isinstance(x, np.ndarray):
return True
return False


def shape(x):
# This will work as long as we disallow
# dynamic shapes in NumPy.
return x.shape


def cast(x, dtype):
return convert_to_tensor(x, dtype=dtype)


def cond(pred, true_fn, false_fn):
if pred:
return true_fn
return false_fn


class NamedScope:
ariG23498 marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, name):
self.name = name

def __enter__(self):
print(f"Starting named scope '{self.name}'")

def __exit__(self):
print(f"Ending named scope '{self.name}'")


def name_scope(name):
return NamedScope(name)
ariG23498 marked this conversation as resolved.
Show resolved Hide resolved


def vectorized_map(function, elements):
return np.vectorize(function)(elements)


class Variable(KerasVariable):
def _initialize(self, value):
self._value = np.array(value, dtype=self._dtype)

def assign(self, value):
value = convert_to_tensor(value, dtype=self.dtype)
if value.shape != self.shape:
raise ValueError(
"The shape of the target variable and "
"the shape of the target value in "
"`variable.assign(value)` must match. "
f"Received: value.shape={value.shape}; "
f"variable.shape={self.value.shape}"
)
if in_stateless_scope():
scope = get_stateless_scope()
scope.add_update((self, value))
else:
if isinstance(value, np.ndarray) and value.dtype == self.dtype:
# Avoid a memory copy
self._value = value
else:
self._value = np.array(value, dtype=self.dtype)

@property
def value(self):
ariG23498 marked this conversation as resolved.
Show resolved Hide resolved
if in_stateless_scope():
scope = get_stateless_scope()
value = scope.get_current_value(self)
if value is not None:
return self._maybe_autocast(value)
if self._value is None:
# Unitialized variable. Return a placeholder.
# This is fine because it's only ever used
# in during shape inference with JAX tracer objects
ariG23498 marked this conversation as resolved.
Show resolved Hide resolved
# (anything else would be a bug, to be fixed.)
return self._maybe_autocast(
np.array(
self._initializer(self._shape, dtype=self._dtype),
dtype=self._dtype,
)
)
return self._maybe_autocast(self._value)

def numpy(self):
return np.array(self.value)

# Overload native accessor.
def __array__(self):
return np.array(self.value)
ariG23498 marked this conversation as resolved.
Show resolved Hide resolved

def _convert_to_tensor(self, value, dtype=None):
return convert_to_tensor(value, dtype=dtype)


# Shape / dtype inference util
def compute_output_spec(fn, *args, **kwargs):
with StatelessScope():
np_out = fn(*args, **kwargs)
ariG23498 marked this conversation as resolved.
Show resolved Hide resolved

def convert_np_to_keras_tensor(x):
if isinstance(x, np.ndarray):
return KerasTensor(x.shape, x.dtype)
return nest.map_structure(convert_np_to_keras_tensor, np_out)


def traceable_tensor(shape, dtype=None):
ariG23498 marked this conversation as resolved.
Show resolved Hide resolved
"""Create a "traceable tensor".

That's a tensor that can be passed as input
to a stateful backend-native function to
create state during the trace.
"""
shape = list(shape)
dtype = dtype or "float32"
for i, x in enumerate(shape):
if x is None:
shape[i] = 1
ariG23498 marked this conversation as resolved.
Show resolved Hide resolved
return np.ones(shape, dtype=dtype)