Skip to content

Commit

Permalink
ONNX-TensorRT 8.0 GA release (#706)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Chen <[email protected]>
  • Loading branch information
kevinch-nv authored Jul 2, 2021
1 parent 868e636 commit 8fea430
Show file tree
Hide file tree
Showing 26 changed files with 2,541 additions and 1,504 deletions.
23 changes: 5 additions & 18 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ set(PARSER_LINKER_SCRIPT ${ONNX2TRT_ROOT}/libnvonnxparser.version)
#--------------------------------------------------
# Version information
#--------------------------------------------------
set(ONNX2TRT_MAJOR 7)
set(ONNX2TRT_MINOR 2)
set(ONNX2TRT_PATCH 2)
set(ONNX2TRT_MAJOR 8)
set(ONNX2TRT_MINOR 0)
set(ONNX2TRT_PATCH 1)

#--------------------------------------------------
# Build configurations, global to all projects
Expand All @@ -36,6 +36,7 @@ set(IMPORTER_SOURCES
ModelImporter.cpp
builtin_op_importers.cpp
onnx2trt_utils.cpp
onnxErrorRecorder.cpp
ShapedWeights.cpp
ShapeTensor.cpp
LoopHelpers.cpp
Expand Down Expand Up @@ -72,10 +73,6 @@ if (NOT DEFINED BUILD_LIBRARY_ONLY)
)
endif()

set(HEADERS
NvOnnxParser.h
)

if (NOT TARGET protobuf::libprotobuf)
FIND_PACKAGE(Protobuf REQUIRED)
else()
Expand All @@ -102,16 +99,7 @@ find_library(TENSORRT_LIBRARY_INFER nvinfer
find_library(TENSORRT_LIBRARY_INFER_PLUGIN nvinfer_plugin
HINTS ${TENSORRT_ROOT} ${TENSORRT_BUILD} ${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES lib lib64 lib/x64)
if(WIN32)
find_library(TENSORRT_LIBRARY_MYELIN myelin64_1
HINTS ${TENSORRT_ROOT} ${TENSORRT_BUILD} ${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES lib lib64 lib/x64)
else()
find_library(TENSORRT_LIBRARY_MYELIN myelin
HINTS ${TENSORRT_ROOT} ${TENSORRT_BUILD} ${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES lib lib64 lib/x64)
endif()
set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN} ${TENSORRT_LIBRARY_MYELIN})
set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN})
MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}")
find_package_handle_standard_args(
TENSORRT DEFAULT_MSG TENSORRT_INCLUDE_DIR TENSORRT_LIBRARY)
Expand Down Expand Up @@ -175,7 +163,6 @@ install(TARGETS
install(FILES ${HEADERS}
DESTINATION include
)

if (NOT DEFINED BUILD_LIBRARY_ONLY)
install(TARGETS
onnx2trt
Expand Down
177 changes: 115 additions & 62 deletions ImporterContext.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,88 +6,138 @@

#include "onnx2trt.hpp"
#include "onnx2trt_utils.hpp"

#include "onnxErrorRecorder.hpp"
#include "onnx/common/stl_backports.h"
#include <list>
#include <unordered_map>

namespace onnx2trt
{

//! RAII helper that installs an ONNXParserErrorRecorder on a network for the
//! duration of a parse. It chains to any recorder the user had already set on
//! the network and restores that recorder (with correct ref-counting) when the
//! wrapper is destroyed.
class ErrorRecorderWrapper
{
public:
    //! \param network Network to attach the parser error recorder to; may be null,
    //!                in which case the wrapper is inert.
    //! \param logger  Logger forwarded to ONNXParserErrorRecorder::create.
    ErrorRecorderWrapper(nvinfer1::INetworkDefinition* network, nvinfer1::ILogger* logger)
        : mNetwork(network)
        , mLogger(logger)
    {
        if (mNetwork)
        {
            // Remember the user's recorder so it can be chained to and restored later.
            mUserErrorRecorder = mNetwork->getErrorRecorder();
            mOnnxErrorRecorder = ONNXParserErrorRecorder::create(logger, mUserErrorRecorder);
            if (mOnnxErrorRecorder)
            {
                if (mUserErrorRecorder)
                {
                    // Hold a reference for the duration of the swap.
                    mUserErrorRecorder->incRefCount();
                }
                mNetwork->setErrorRecorder(mOnnxErrorRecorder);
            }
        }
    }

    ~ErrorRecorderWrapper()
    {
        if (mNetwork && mOnnxErrorRecorder)
        {
            // Put the user's recorder back and drop the reference taken in the ctor.
            mNetwork->setErrorRecorder(mUserErrorRecorder);
            if (mUserErrorRecorder)
            {
                mUserErrorRecorder->decRefCount();
            }
            ONNXParserErrorRecorder::destroy(mOnnxErrorRecorder);
        }
    }

    // Copying or moving would double-restore the network's recorder and corrupt the
    // ref-count bookkeeping in the destructor, so both are forbidden (Rule of Five).
    ErrorRecorderWrapper(ErrorRecorderWrapper const&) = delete;
    ErrorRecorderWrapper& operator=(ErrorRecorderWrapper const&) = delete;

    //! True if the parser recorder was installed and has recorded at least one error.
    bool hasError() const
    {
        return mOnnxErrorRecorder != nullptr && mOnnxErrorRecorder->getNbErrors() != 0;
    }

    //! Return recorder used by hasError(), or nullptr if none was installed.
    nvinfer1::IErrorRecorder* getErrorRecorder() const
    {
        // fix: original "mOnnxErrorRecorder ? mOnnxErrorRecorder : nullptr" was a
        // redundant ternary — a null pointer already yields nullptr.
        return mOnnxErrorRecorder;
    }

private:
    nvinfer1::INetworkDefinition* mNetwork{nullptr};
    nvinfer1::ILogger* mLogger{nullptr};
    ONNXParserErrorRecorder* mOnnxErrorRecorder{nullptr};
    nvinfer1::IErrorRecorder* mUserErrorRecorder{nullptr};
};

class ImporterContext final : public IImporterContext
{
nvinfer1::INetworkDefinition* _network;
nvinfer1::ILogger* _logger;
std::list<std::vector<uint8_t>> _temp_bufs;
StringMap<nvinfer1::ITensor*> _user_inputs;
StringMap<nvinfer1::ITensor**> _user_outputs;
StringMap<int64_t> _opsets;
nvinfer1::INetworkDefinition* mNetwork;
nvinfer1::ILogger* mLogger;
std::list<std::vector<uint8_t>> mTempBufs;
StringMap<nvinfer1::ITensor*> mUserInputs;
StringMap<nvinfer1::ITensor**> mUserOutputs;
StringMap<int64_t> mOpsets;
StringMap<TensorOrWeights> mTensors; // All tensors in the graph mapped to their names.
StringMap<nvinfer1::TensorLocation> mTensorLocations;
StringMap<float> mTensorRangeMins;
StringMap<float> mTensorRangeMaxes;
StringMap<nvinfer1::DataType> mLayerPrecisions;
std::set<std::string> mTensorNames; // Keep track of how many times a tensor name shows up, to avoid duplicate naming in TRT.
std::set<std::string> mLayerNames; // Keep track of how many times a tensor name shows up, to avoid duplicate naming in TRT.
int64_t mSuffixCounter = 0; // increasing suffix counter used to uniquify layer names.
int64_t mSuffixCounter{0}; // increasing suffix counter used to uniquify layer names.
std::unordered_set<std::string> mUnsupportedShapeTensors; // Container to hold output tensor names of layers that produce shape tensor outputs but do not natively support them.
StringMap<std::string> mLoopTensors; // Container to map subgraph tensors to their original outer graph names.
std::string mOnnxFileLocation; // Keep track of the directory of the parsed ONNX file
std::list<std::string> mInitializerNames; // Keep track of unique names of any initializers
RefitMap_t* mRefitMap; // Keep track of names of ONNX refittable weights with their corresponding TRT layer and role
std::unique_ptr<ErrorRecorderWrapper> mErrorWrapper; // error recorder to control TRT errors

public:
ImporterContext(nvinfer1::INetworkDefinition* network, nvinfer1::ILogger* logger, RefitMap_t* refitMap)
: _network(network)
, _logger(logger)
, mRefitMap(refitMap)
ImporterContext(nvinfer1::INetworkDefinition* network, nvinfer1::ILogger* logger)
: mNetwork(network)
, mLogger(logger)
// Disable ErrorRecorder for now due to incompatibilities with ONNXRT.
// , mErrorWrapper(ONNX_NAMESPACE::make_unique<ErrorRecorderWrapper>(mNetwork, logger))
, mErrorWrapper(nullptr)
{
}
virtual nvinfer1::INetworkDefinition* network() override
nvinfer1::INetworkDefinition* network() override
{
return _network;
return mNetwork;
}
virtual StringMap<TensorOrWeights>& tensors() override
StringMap<TensorOrWeights>& tensors() override
{
return mTensors;
}
virtual StringMap<nvinfer1::TensorLocation>& tensorLocations() override
StringMap<nvinfer1::TensorLocation>& tensorLocations() override
{
return mTensorLocations;
}
virtual StringMap<float>& tensorRangeMins() override
StringMap<float>& tensorRangeMins() override
{
return mTensorRangeMins;
}
virtual StringMap<float>& tensorRangeMaxes() override
StringMap<float>& tensorRangeMaxes() override
{
return mTensorRangeMaxes;
}
virtual StringMap<nvinfer1::DataType>& layerPrecisions() override
StringMap<nvinfer1::DataType>& layerPrecisions() override
{
return mLayerPrecisions;
}
virtual std::unordered_set<std::string>& unsupportedShapeTensors() override
std::unordered_set<std::string>& unsupportedShapeTensors() override
{
return mUnsupportedShapeTensors;
}
virtual StringMap<std::string>& loopTensors() override
StringMap<std::string>& loopTensors() override
{
return mLoopTensors;
}
virtual void setOnnxFileLocation(std::string location) override
void setOnnxFileLocation(std::string location) override
{
mOnnxFileLocation = location;
}
virtual std::string getOnnxFileLocation() override
std::string getOnnxFileLocation() override
{
return mOnnxFileLocation;
}
virtual void insertRefitMap(std::string weightsName, std::string layerName, nvinfer1::WeightsRole role) override
{
mRefitMap->insert({weightsName, WeightsPair_t{layerName, role}});
}
// This actually handles weights as well, but is named this way to be consistent with the tensors()
virtual void registerTensor(TensorOrWeights tensor, const std::string& basename) override
void registerTensor(TensorOrWeights tensor, const std::string& basename) override
{
// TRT requires unique tensor names.
const std::string uniqueName = generateUniqueName(mTensorNames, basename);
Expand All @@ -103,22 +153,22 @@ class ImporterContext final : public IImporterContext
}
else if (tensor.is_weights())
{
mInitializerNames.push_back(uniqueName);
const auto& weights = tensor.weights();
if (tensor.weights().type == ::ONNX_NAMESPACE::TensorProto::INT64)
{
tensor = ShapedWeights{::ONNX_NAMESPACE::TensorProto::INT32,
convertINT64(reinterpret_cast<int64_t*>(weights.values), weights.shape, ctx), weights.shape};
}
tensor.weights().setName(mInitializerNames.back().c_str());
tensor.weights().setName(basename.c_str());
}

}
// Overwrite previous tensors registered with the same name (this only happens when there are subgraphs,
// and in that case, overwriting is the desired behavior).
this->tensors()[basename] = std::move(tensor);
}

virtual void registerLayer(nvinfer1::ILayer* layer, const std::string& basename) override
void registerLayer(nvinfer1::ILayer* layer, const std::string& basename) override
{
// No layer will be added for Constant nodes in ONNX.
if (layer)
Expand All @@ -127,99 +177,102 @@ class ImporterContext final : public IImporterContext
const std::string uniqueName = generateUniqueName(mLayerNames, name);

auto* ctx = this; // To enable logging.
if (layer->getType() == nvinfer1::LayerType::kCONSTANT)
{
LOG_VERBOSE("Registering constant layer: " << uniqueName << " for ONNX initializer: " << basename);
}
else
{
LOG_VERBOSE("Registering layer: " << uniqueName << " for ONNX node: " << basename);
}
LOG_VERBOSE("Registering layer: " << uniqueName << " for ONNX node: " << basename);

layer->setName(uniqueName.c_str());
}
}

virtual nvinfer1::ILogger& logger() override
nvinfer1::ILogger& logger() override
{
return *_logger;
return *mLogger;
}

virtual ShapedWeights createTempWeights(ShapedWeights::DataType type, nvinfer1::Dims shape) override
ShapedWeights createTempWeights(ShapedWeights::DataType type, nvinfer1::Dims shape, uint8_t value = 0) override
{
ShapedWeights weights(type, nullptr, shape);
// Need special logic for handling scalars.
if (shape.nbDims == 0)
{
_temp_bufs.push_back(std::vector<uint8_t>(getDtypeSize(type)));
mTempBufs.push_back(std::vector<uint8_t>(getDtypeSize(type), value));
}
else
{
_temp_bufs.push_back(std::vector<uint8_t>(weights.size_bytes()));
mTempBufs.push_back(std::vector<uint8_t>(weights.size_bytes(), value));
}
weights.values = _temp_bufs.back().data();
weights.values = mTempBufs.back().data();
return weights;
}

bool setUserInput(const char* name, nvinfer1::ITensor* input)
{
_user_inputs[name] = input;
mUserInputs[name] = input;
return true;
}
bool setUserOutput(const char* name, nvinfer1::ITensor** output)
{
_user_outputs[name] = output;
mUserOutputs[name] = output;
return true;
}
nvinfer1::ITensor* getUserInput(const char* name)
{
if (!_user_inputs.count(name))
if (!mUserInputs.count(name))
{
return nullptr;
}
else
{
return _user_inputs.at(name);
return mUserInputs.at(name);
}
}
nvinfer1::ITensor** getUserOutput(const char* name)
{
if (!_user_outputs.count(name))
if (!mUserOutputs.count(name))
{
return nullptr;
}
else
{
return _user_outputs.at(name);
return mUserOutputs.at(name);
}
}
StringMap<nvinfer1::ITensor**> const& getUserOutputs() const
{
return _user_outputs;
return mUserOutputs;
}
void clearOpsets()
{
_opsets.clear();
mOpsets.clear();
}
void addOpset(std::string domain, int64_t version)
{
_opsets.emplace(domain, version);
mOpsets.emplace(domain, version);
}
virtual int64_t getOpsetVersion(const char* domain = "") const override
int64_t getOpsetVersion(const char* domain = "") const override
{
if (_opsets.empty())
if (mOpsets.empty())
{
return 1;
}
else if (_opsets.size() == 1)
else if (mOpsets.size() == 1)
{
return _opsets.begin()->second;
return mOpsets.begin()->second;
}
else
{
assert(_opsets.count(domain));
return _opsets.at(domain);
assert(mOpsets.count(domain));
return mOpsets.at(domain);
}
}
bool hasError() const noexcept override
{
return mErrorWrapper != nullptr && mErrorWrapper->hasError();
}

nvinfer1::IErrorRecorder* getErrorRecorder() const noexcept override
{
return mErrorWrapper ? mErrorWrapper->getErrorRecorder() : nullptr;
}
private:
std::string generateUniqueName(std::set<std::string>& namesSet, const std::string& basename)
{
Expand Down
Loading

0 comments on commit 8fea430

Please sign in to comment.