This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[FEATURE] Add feature of attach_grad to nonleaf variables in HybridizedBlock. #21091

Open · wants to merge 3 commits into master
7 changes: 7 additions & 0 deletions include/mxnet/c_api.h
@@ -1276,6 +1276,13 @@ MXNET_DLL int MXAutogradMarkVariables(uint32_t num_var,
NDArrayHandle* var_handles,
uint32_t* reqs_array,
NDArrayHandle* grad_handles);
/*!
* \brief mark nonleaf NDArrays as variables during deferred computation
* \param nleaf_handles handles of the nonleaf NDArrays to be marked
* \param num_nleafs number of nonleaf NDArrays
* \param cnt_var count of existing marked nonleaf variables
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayMarkDCVariables(NDArrayHandle* nleaf_handles, int num_nleafs, int cnt_var);
/*!
* \brief unmark nonleaf NDArrays to free the memory
* \param num_var number of variable NDArrays
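Note (illustrative, not part of the diff): the C API above can also be driven directly through ctypes. The helper name below is made up; the marshalling mirrors what the python/mxnet/gluon/block.py change further down performs.

import ctypes
from mxnet.base import _LIB, check_call

def mark_dc_variables(arrays, cnt_var):
    # Illustrative helper: mark already-traced (deferred-compute) NDArrays
    # as nonleaf variables, continuing the numbering from cnt_var.
    handles = (ctypes.c_void_p * len(arrays))(*[a.handle for a in arrays])
    check_call(_LIB.MXNDArrayMarkDCVariables(handles, len(arrays), cnt_var))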
2 changes: 2 additions & 0 deletions include/mxnet/imperative.h
@@ -290,6 +290,8 @@ class Imperative {
void MarkVariables(const std::vector<NDArray*>& variables,
const std::vector<uint32_t>& grad_reqs,
const std::vector<NDArray*>& gradients);
/*! \brief mark nonleaf variables during DC for computing gradients. */
void MarkDCVariables(const std::vector<NDArray*>& nleafs, int cnt_vars);
/*! \brief unmark nonleaf variables to free the memory. */
void DropGrads(const std::vector<NDArray*>& variables);
/*! \brief compute the gradient of outputs w.r.t variables. */
2 changes: 2 additions & 0 deletions include/mxnet/ndarray.h
@@ -351,6 +351,8 @@ class NDArray {
bool fresh_out_grad() const;
/*! \return updated grad state in autograd_entry_ */
void set_fresh_out_grad(bool state) const;
/*! \brief copy the autograd_entry_ from src NDArray */
void copy_autograd_entry_(const NDArray* src);
/*! \brief Returns true if a sparse ndarray's aux_data and storage are initialized
* Throws an exception if the indices array shape is inconsistent
* Returns false if the indices array is empty(nnz = 0) for csr/row_sparse
6 changes: 5 additions & 1 deletion python/mxnet/_ctypes/cached_op.py
@@ -77,6 +77,7 @@ def __call__(self, *args, **kwargs):
if not default_device:
default_device = kwargs.pop('default_ctx', None)
out = kwargs.pop('out', None)
nleaf_vars = [container.data() for container in kwargs.pop('_nleaf_vars', [])]
if kwargs:
raise TypeError(
"CachedOp.__call__ got unexpected keyword argument(s): " + \
@@ -93,7 +94,10 @@ def __call__(self, *args, **kwargs):
*args,
type_id,
device_id,
*out_arg
len(out_arg),
*out_arg,
len(nleaf_vars),
*nleaf_vars
)
if out is not None:
return out
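Note (illustrative, not part of the diff): the index arithmetic in src/api/cached_op_api.cc below depends on how these positional arguments are flattened. The sketch below is a hypothetical helper; the leading handle and input-count entries are assumptions based on the surrounding, unchanged call.

def pack_invoke_args(handle, inputs, type_id, device_id, outputs, nleaf_vars):
    # Flatten the CachedOp invoke arguments the way the updated __call__
    # above appears to pass them (illustration only).
    return [
        handle,            # [0] cached op handle (assumed, unchanged part of the call)
        len(inputs),       # [1] num_inputs (assumed, unchanged part of the call)
        *inputs,           # [2 .. num_inputs+1] input NDArrays
        type_id,           # [num_inputs+2] default device type id
        device_id,         # [num_inputs+3] default device id
        len(outputs),      # [num_inputs+4] num_outputs (new explicit count)
        *outputs,          # [num_inputs+5 ..] output NDArrays
        len(nleaf_vars),   # [num_inputs+num_outputs+5] num_nleafs (new)
        *nleaf_vars,       # [num_inputs+num_outputs+6 ..] marked intermediates (new)
    ]

# Example: 2 inputs, 1 output, 1 marked intermediate.
args = pack_invoke_args("op", ["x", "w"], 2, 0, ["y"], ["hidden"])
num_inputs = args[1]
num_outputs = args[num_inputs + 4]
num_nleafs = args[num_inputs + num_outputs + 5]
assert (num_inputs, num_outputs, num_nleafs) == (2, 1, 1)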
94 changes: 91 additions & 3 deletions python/mxnet/gluon/block.py
@@ -33,13 +33,14 @@
import json
import numpy as np

from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB
from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB, \
_as_list
from .. import symbol, ndarray, initializer, autograd, _deferred_compute as dc, name as _name, \
profiler as _profiler, device as _device
from ..symbol.numpy import _symbol as np_symbol
from ..symbol import Symbol, fromjson
from ..ndarray import NDArray, get_dtype_name
from .parameter import Parameter, DeferredInitializationError
from .parameter import Parameter, DeferredInitializationError, Intermediate
from .utils import _indent, _brief_print_list, HookHandle, shape_is_known
from .utils import _check_same_symbol_type, _check_all_np_ndarrays, _check_block_input_np_ndarrays
from .. import numpy_extension as _mx_npx
@@ -1091,6 +1092,7 @@ def __init__(self):
self._backend_opts = {}
self._partition_if_dynamic = True
self._first_forward = True
self._nleaf_vars = OrderedDict()

def __setattr__(self, name, value):
"""Registers parameters."""
@@ -1302,7 +1304,7 @@ def _call_cached_op(self, *args):
args_without_none = [ele for ele in args if ele is not None]
cargs = [args_without_none[i] if is_arg else i.data()
for is_arg, name, i in self._cached_op_args]
out = self._cached_op(*cargs)
out = self._cached_op(*cargs, _nleaf_vars=self._nleaf_vars.values())
if isinstance(out, NDArray):
out = [out]
return _regroup(out, self._out_format)
@@ -1678,6 +1680,92 @@ def reset_ctx(self, ctx):
self.reset_device(ctx)


def intermediate(self, names, var_arrays_inp, grad_req='write'):
"""Mark the intermediate variables.

Parameters
----------
names : str or tuple[str], name(s) under which the intermediate variable(s) are registered
var_arrays_inp : ndarray or tuple[ndarray], the output array(s) of the expression to mark
grad_req : str, gradient request ('write', 'add', or 'null'), default 'write'
"""
if not self._active:
var_arrays = _as_list(var_arrays_inp)
names = _as_list(names)
self._nleaf_vars.update(
{name : Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)})
else:
prev_val = dc.set_deferred_compute(False)
var_arrays = _as_list(var_arrays_inp)
names = _as_list(names)
# Prepare ctypes array types
import ctypes
var_handles_type = ctypes.c_void_p * len(var_arrays)
# Convert handles
var_handles = var_handles_type(*[arr.handle for arr in var_arrays])
check_call(_LIB.MXNDArrayMarkDCVariables(var_handles, len(var_arrays), len(self._nleaf_vars)))
self._nleaf_vars.update(
{name : Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)})
dc.set_deferred_compute(prev_val)
return var_arrays_inp

def attach_grad_intermediate(self):
"""Attach gradient to all the intermediate variables.
"""
for val in self._nleaf_vars.values():
val.data().attach_grad(grad_req=val.grad_req)

def get_intermediate(self, names):
"""Get the intermediate variables by names
"""
if isinstance(names, list):
return [self._nleaf_vars[n] for n in names]
else:
return self._nleaf_vars[names]


class SymbolBlock(HybridBlock):
"""Construct block from symbol. This is useful for using pre-trained models
as feature extractors. For example, you may want to extract the output
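Note (illustrative, not part of the diff): an end-to-end usage sketch of the new Block-level API. The network, shapes, and exact call order are assumptions inferred from the methods added above, not taken from an existing test.

import mxnet as mx
from mxnet.gluon import nn, HybridBlock

class Net(HybridBlock):
    # Hypothetical two-layer network used only to illustrate the new API.
    def __init__(self):
        super().__init__()
        self.dense0 = nn.Dense(4)
        self.dense1 = nn.Dense(2)

    def forward(self, x):
        h = self.dense0(x)
        # Mark the hidden activation so its gradient can be retrieved
        # even after hybridization.
        h = self.intermediate('hidden', h)
        return self.dense1(h)

net = Net()
net.initialize()
net.hybridize()

x = mx.np.ones((3, 8))
net(x)                          # first call traces the graph and marks 'hidden'
net.attach_grad_intermediate()  # attach grad buffers to all marked intermediates

with mx.autograd.record():
    y = net(x)
y.backward()

print(net.get_intermediate('hidden').data().grad)  # gradient w.r.t. the hidden activation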
37 changes: 37 additions & 0 deletions python/mxnet/gluon/parameter.py
@@ -773,3 +773,40 @@ def grad_req(self, req):
warnings.warn('Constant parameter "{}" does not support '
'grad_req other than "null", and new value "{}" '
'is ignored.'.format(self.name, req))

class Intermediate:
"""A Container holding marked intermediate variables of Blocks.

Parameters
----------
name : str
Name of this intermediate variable. It is used to retrieve the marked variable.
data : NDArray, default None
The marked output array held by this container.
grad_req : {'write', 'add', 'null'}, default 'write'
Specifies how the gradient is updated into the grad array.

- ``'write'`` means the gradient is written to the grad :py:class:`NDArray` every time.
- ``'add'`` means the gradient is added to the grad :py:class:`NDArray` every time. You need
to manually call ``zero_grad()`` to clear the gradient buffer before each
iteration when using this option.
- ``'null'`` means no gradient is requested for this variable; gradient arrays
will not be allocated.
"""
def __init__(self, name, data=None, grad_req='write'):
self._name = name
self._data = data
self._grad_req = grad_req

def __repr__(self):
s = 'Intermediate name={name}'
return s.format(name=self._name)

def data(self):
return self._data

@property
def name(self):
return self._name

@property
def grad_req(self):
return self._grad_req
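
Note (illustrative, not part of the diff): a minimal sketch of the Intermediate container in isolation; normally these objects are created by HybridBlock.intermediate() rather than by hand, and the array here is arbitrary.

import mxnet as mx
from mxnet.gluon.parameter import Intermediate

arr = mx.np.zeros((2, 3))
inter = Intermediate('hidden', data=arr, grad_req='add')
print(inter)           # Intermediate name=hidden
print(inter.name)      # 'hidden'
print(inter.grad_req)  # 'add'
# attach_grad_intermediate() does essentially this for every marked variable:
inter.data().attach_grad(grad_req=inter.grad_req)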
15 changes: 12 additions & 3 deletions src/api/cached_op_api.cc
@@ -44,16 +44,18 @@ MXNET_REGISTER_GLOBAL("cached_op.invoke")
ndinputs.push_back(static_cast<mxnet::NDArray*>(args[i]));
}

int num_outputs = args[num_inputs + 4];
int num_nleafs = args[num_inputs + num_outputs + 5];
std::vector<NDArray*> ndoutputs;
ndoutputs.reserve(op->num_outputs());
if (args[num_inputs + 4].type_code() == kNull) {
if (args[num_inputs + 5].type_code() == kNull) {
for (int i = 0; i < op->num_outputs(); ++i)
ndoutputs.push_back(new NDArray());
} else {
int array_size = args_size - num_inputs - 4;
int array_size = args_size - num_inputs - num_nleafs - 6;
CHECK_EQ(array_size, op->num_outputs()) << "CachedOp expects " << op->num_outputs()
<< " outputs, but " << array_size << " was given.";
for (int i = num_inputs + 4; i < array_size; ++i) {
for (int i = num_inputs + 5; i < num_inputs + num_outputs + 5; ++i) {
ndoutputs.push_back(args[i].operator mxnet::NDArray*());
}
}
@@ -69,6 +71,13 @@
default_dev_id = ctx.dev_id;
}

std::vector<NDArray*> nleafs;
nleafs.reserve(num_nleafs);
for (int i = 0; i < num_nleafs; ++i) {
nleafs.push_back(static_cast<mxnet::NDArray*>(args[i + num_inputs + num_outputs + 6]));
}
op->set_nleafs(nleafs);

// construct default context
Context ctx =
Context::Create(static_cast<Context::DeviceType>(default_dev_type), default_dev_id);
12 changes: 12 additions & 0 deletions src/c_api/c_api_ndarray.cc
@@ -495,3 +495,15 @@ int MXNDArrayGetDeferredComputeSymbol(NDArrayHandle* output_handles,
*out = s;
API_END_HANDLE_ERROR(delete s;);
}

int MXNDArrayMarkDCVariables(NDArrayHandle* nleaf_handles, int num_nleafs, int cnt_var) {
API_BEGIN();
std::vector<NDArray*> nleafs;
nleafs.reserve(num_nleafs);
for (int i = 0; i < num_nleafs; ++i) {
NDArray* array = reinterpret_cast<NDArray*>(nleaf_handles[i]);
nleafs.emplace_back(array);
}
Imperative::Get()->MarkDCVariables(nleafs, cnt_var);
API_END();
}
5 changes: 4 additions & 1 deletion src/imperative/cached_op.cc
@@ -801,7 +801,8 @@ OpStatePtr CachedOp::DynamicForward(const Context& default_ctx,
recording && inlining_,
nullptr,
monitor_callback_,
monitor_all_);
monitor_all_,
nleafs_);
} else {
mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
NaiveRunGraph(false,
@@ -1063,6 +1064,7 @@ void CachedOp::StaticBackward(const bool retain_graph,
if (!idx.exist(entry.node.get()))
continue;
auto eid = idx.entry_id(entry);
state.array_reqs[eid] = reqs[iter->second];
// An input and an output may share the same array.
INIT_DETACHED(outputs[iter->second], arrays[eid]);
arrays[eid] = outputs[iter->second];
@@ -1073,6 +1075,7 @@
if (!idx.exist(entry.node.get()))
continue;
auto eid = idx.entry_id(entry);
state.array_reqs[eid] = reqs[i];
// An input and an output may share the same array.
INIT_DETACHED(outputs[i], arrays[eid]);
arrays[eid] = outputs[i];
4 changes: 4 additions & 0 deletions src/imperative/cached_op.h
@@ -491,6 +491,9 @@ class CachedOp {
const std::unordered_set<uint32_t>& mutable_input_nodes() const {
return fwd_graph_.indexed_graph().mutable_input_nodes();
}
void set_nleafs(const std::vector<NDArray*>& nleafs) {
nleafs_ = nleafs;
}
virtual std::vector<nnvm::NodeEntry> Gradient(const nnvm::ObjectPtr& node,
const std::vector<nnvm::NodeEntry>& ograds) const;
virtual OpStatePtr Forward(const std::shared_ptr<CachedOp>& op_ptr,
Expand Down Expand Up @@ -649,6 +652,7 @@ class CachedOp {
std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
std::vector<bool> save_inputs_, save_outputs_;
std::vector<OpReqType> bwd_output_reqs_;
std::vector<NDArray*> nleafs_;

std::function<void(const char*, const char*, NDArrayHandle)> monitor_callback_{nullptr};
bool monitor_all_{false};
12 changes: 12 additions & 0 deletions src/imperative/imperative.cc
@@ -171,6 +171,18 @@ void Imperative::MarkVariables(const std::vector<NDArray*>& variables,
}
}

void Imperative::MarkDCVariables(const std::vector<NDArray*>& nleafs, int cnt_vars) {
for (NDArray* nleaf : nleafs) {
if (Imperative::DCInfo::IsNone(*nleaf)) {
LOG(WARNING) << "The marked node doesn't have deferred compute history.";
} else {
nnvm::ObjectPtr node = nleaf->deferredcompute_entry_.node;
node->attrs.dict["mark_id"] = std::to_string(cnt_vars);
}
cnt_vars++;
}
}

// Unmark the variables to free the memory.
void Imperative::DropGrads(const std::vector<NDArray*>& variables) {
for (auto variable : variables) {
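Note (illustrative, not part of the diff): the cnt_vars offset above keeps "mark_id" values unique across repeated intermediate() calls on the same block. A tiny Python illustration of the numbering only (not MXNet code):

marked = {}

def mark_dc_variables_sim(nleafs, cnt_vars):
    # Mirrors the counting in Imperative::MarkDCVariables: each marked node
    # receives the next id as its "mark_id" attribute.
    for leaf in nleafs:
        marked[leaf] = str(cnt_vars)   # node->attrs.dict["mark_id"]
        cnt_vars += 1

mark_dc_variables_sim(['h0', 'h1'], cnt_vars=0)        # ids "0", "1"
mark_dc_variables_sim(['h2'], cnt_vars=len(marked))    # id  "2"
assert marked == {'h0': '0', 'h1': '1', 'h2': '2'}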
12 changes: 11 additions & 1 deletion src/imperative/imperative_utils.cc
@@ -138,7 +138,8 @@ void RunGraph(const bool retain_graph,
bool recording,
mxnet::ShapeVector* shapes,
const imperative::CachedOpMonCallback& callback,
const bool monitor_all) {
const bool monitor_all,
const std::vector<NDArray*>& nleafs) {
CHECK(shapes == nullptr);
for (size_t i = node_start; i < node_end; ++i) {
const nnvm::IndexedGraph::Node& node = idx[i];
@@ -166,6 +167,15 @@
if (callback) {
mxnet::common::ExecuteMonOutputCallback(idx, arrays, i, callback);
}
// set the autograd_entry_ in marked nleafs
if (nleafs.size()) {
auto it = node.source->attrs.dict.find("mark_id");
if (it != node.source->attrs.dict.end()) {
int mark_id = std::stoi(it->second);
CHECK_LT(mark_id, nleafs.size()) << "Mark_id exceeds the nonleaf list size.";
nleafs[mark_id]->copy_autograd_entry_(ndoutputs[0]);
}
}
}
}

3 changes: 2 additions & 1 deletion src/imperative/imperative_utils.h
@@ -1386,7 +1386,8 @@ void RunGraph(const bool retain_graph,
bool recording,
mxnet::ShapeVector* shapes = nullptr,
const CachedOpMonCallback& callback = nullptr,
const bool monitor_all_ = false);
const bool monitor_all_ = false,
const std::vector<NDArray*>& nleafs = std::vector<NDArray*>());

void NaiveRunGraph(const bool retain_graph,
const Context& default_ctx,
4 changes: 4 additions & 0 deletions src/ndarray/ndarray.cc
@@ -513,6 +513,10 @@ void NDArray::set_fresh_out_grad(bool state) const {
info.fresh_out_grad = state;
}

void NDArray::copy_autograd_entry_(const NDArray* src) {
autograd_entry_ = nnvm::NodeEntry{src->autograd_entry_.node, 0, 0};
}

#if MXNET_USE_ONEDNN == 1

bool NDArray::Chunk::IsDNNL() const {