diff --git a/CMakeLists.txt b/CMakeLists.txt
index c3692aee..80c51182 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -29,7 +29,7 @@ add_definitions("-DSOURCE_LENGTH=${SOURCE_LENGTH}")
 #--------------------------------------------------
 set(ONNX2TRT_MAJOR 8)
 set(ONNX2TRT_MINOR 2)
-set(ONNX2TRT_PATCH 0)
+set(ONNX2TRT_PATCH 1)
 set(ONNX2TRT_VERSION "${ONNX2TRT_MAJOR}.${ONNX2TRT_MINOR}.${ONNX2TRT_PATCH}" CACHE STRING "ONNX2TRT version")
 
 #--------------------------------------------------
diff --git a/ImporterContext.hpp b/ImporterContext.hpp
index af45e1ee..88273607 100644
--- a/ImporterContext.hpp
+++ b/ImporterContext.hpp
@@ -84,9 +84,8 @@ class ImporterContext final : public IImporterContext
     int64_t mSuffixCounter{0}; // increasing suffix counter used to uniquify layer names.
     std::unordered_set<std::string> mUnsupportedShapeTensors; // Container to hold output tensor names of layers that produce shape tensor outputs but do not natively support them.
     StringMap<std::string> mLoopTensors; // Container to map subgraph tensors to their original outer graph names.
-    std::string mOnnxFileLocation;       // Keep track of the directory of the parsed ONNX file
+    std::string mOnnxFileLocation; // Keep track of the directory of the parsed ONNX file
     std::unique_ptr<ErrorRecorderWrapper> mErrorWrapper; // error recorder to control TRT errors
-    StringMap<nvinfer1::IConstantLayer*> mConstantLayers;
 
 public:
     ImporterContext(nvinfer1::INetworkDefinition* network, nvinfer1::ILogger* logger)
@@ -179,15 +178,6 @@ class ImporterContext final : public IImporterContext
 
             LOG_VERBOSE("Registering layer: " << uniqueName << " for ONNX node: " << basename);
             layer->setName(uniqueName.c_str());
-            if (layer->getType() == nvinfer1::LayerType::kCONSTANT)
-            {
-                if (basename != uniqueName)
-                {
-                    LOG_ERROR("Constant layer: " << uniqueName << " can be a duplicate of: " << basename);
-                    assert(!"Internal error: duplicate constant layers for the same weights");
-                }
-                mConstantLayers.insert({uniqueName, static_cast<nvinfer1::IConstantLayer*>(layer)});
-            }
         }
     }
 
@@ -281,20 +271,6 @@ class ImporterContext final : public IImporterContext
     {
         return mErrorWrapper ? mErrorWrapper->getErrorRecorder() : nullptr;
     }
-    nvinfer1::IConstantLayer* getConstantLayer(const char* name) const final
-    {
-        if (name == nullptr)
-        {
-            return nullptr;
-        }
-        auto const iter = mConstantLayers.find(name);
-        if (iter == mConstantLayers.end())
-        {
-            return nullptr;
-        }
-        return iter->second;
-    }
-
 private:
     std::string generateUniqueName(std::set<std::string>& namesSet, const std::string& basename)
     {
diff --git a/ModelImporter.cpp b/ModelImporter.cpp
index a2100601..6c76e025 100644
--- a/ModelImporter.cpp
+++ b/ModelImporter.cpp
@@ -287,12 +287,16 @@ static Status assertDimsWithSameNameAreEqual(ImporterContext* ctx, std::vector<NamedDimension>& namedDims)
         if (j - i > 1)
         {
             std::ostringstream message;
-            message << "Named dimensions for input " << i->tensor->getName() << " must be equal";
+            message << "For input: '" << i->tensor->getName()
+                    << "' all named dimensions that share the same name must be equal. Note: Named dimensions were present on the following axes: ";
             // prev is the current end of the daisy chain.
             nvinfer1::ITensor* prev = nullptr;
             for (auto k = i; k < j; ++k)
             {
+                message << (prev ? ", " : "") << k->index << " (name: "
+                        << "'" << k->dimParam << "')";
+
                 // Create ITensor "next" with dimension length for record k.
                 auto& shape = shapeMap[k->tensor];
                 if (shape == nullptr)
                 {
@@ -489,6 +493,7 @@ bool ModelImporter::supportsModel(
     }
     return allSupported;
 }
+
 // Mark experimental ops as unsupported, mark plugin ops as supported
 bool ModelImporter::supportsOperator(const char* op_name) const
 {
@@ -496,13 +501,12 @@ bool ModelImporter::supportsOperator(const char* op_name) const
     {
         return false;
     }
-    if (std::string(op_name) == "EfficientNMS_TRT" || std::string(op_name) == "PyramidROIAlign_TRT")
+    if (std::string(op_name) == "EfficientNMS_TRT" || std::string(op_name) == "PyramidROIAlign_TRT" || std::string(op_name) == "MultilevelCropAndResize_TRT")
     {
         return true;
     }
     return _op_importers.count(op_name);
 }
-
 bool ModelImporter::parseWithWeightDescriptors(void const* serialized_onnx_model, size_t serialized_onnx_model_size)
 {
     _current_node = -1;
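The ModelImporter.cpp hunk above upgrades the named-dimension diagnostic: when several axes of one input share an ONNX `dim_param`, TensorRT requires those dimensions to be equal, and the new message names the input and lists each axis that carries the shared name. The following is a minimal standalone sketch of that group-and-compare logic, with hypothetical `NamedDim`/`checkNamedDims` names; it is not the parser's API, just an illustration of the check.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for the parser's named-dimension records.
struct NamedDim
{
    std::string dimParam; // ONNX dim_param shared between axes, e.g. "batch"
    int32_t index;        // axis on the input tensor
    int64_t length;       // resolved length of that axis
};

// Group records by name; all lengths within one group must agree.
bool checkNamedDims(std::vector<NamedDim> dims, std::string const& inputName)
{
    std::sort(dims.begin(), dims.end(),
        [](NamedDim const& a, NamedDim const& b) { return a.dimParam < b.dimParam; });
    for (auto i = dims.begin(); i != dims.end();)
    {
        // j is the first record with a different name, so [i, j) is one group.
        auto j = std::find_if(
            i, dims.end(), [&](NamedDim const& d) { return d.dimParam != i->dimParam; });
        for (auto k = i; k != j; ++k)
        {
            if (k->length != i->length)
            {
                // Mirrors the shape of the new parser message: name the input
                // and the dim_param whose axes disagree.
                std::ostringstream message;
                message << "For input: '" << inputName << "' all named dimensions that share "
                        << "the same name must be equal (dim_param '" << i->dimParam << "')";
                std::cerr << message.str() << '\n';
                return false;
            }
        }
        i = j;
    }
    return true;
}
```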
diff --git a/README.md b/README.md
index 0317069f..ccf2edcd 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@ For press and other inquiries, please contact Hector Marinez at hmarinez@nvidia.com
 
 ## Supported TensorRT Versions
 
-Development on the Master branch is for the latest version of [TensorRT 8.2.0.6](https://developer.nvidia.com/nvidia-tensorrt-download) with full-dimensions and dynamic shape support.
+Development on the Master branch is for the latest version of [TensorRT 8.2.1.8](https://developer.nvidia.com/nvidia-tensorrt-download) with full-dimensions and dynamic shape support.
 
 For previous versions of TensorRT, refer to their respective branches.
 
@@ -48,8 +48,8 @@ Current supported ONNX operators are found in the [operator support matrix](docs/operators.md).
 ### Dependencies
 
  - [Protobuf >= 3.0.x](https://github.com/google/protobuf/releases)
- - [TensorRT 8.2.0.6](https://developer.nvidia.com/tensorrt)
- - [TensorRT 8.2.0.6 open source libaries (master branch)](https://github.com/NVIDIA/TensorRT/)
+ - [TensorRT 8.2.1.8](https://developer.nvidia.com/tensorrt)
+ - [TensorRT 8.2.1.8 open source libraries (master branch)](https://github.com/NVIDIA/TensorRT/)
 
 ### Building
 
@@ -101,9 +101,9 @@ Python bindings for the ONNX-TensorRT parser are packaged in the shipped `.whl` files. Install them with
 
     python3 -m pip install <tensorrt_install_dir>/python/tensorrt-8.x.x.x-cp<python_ver>-none-linux_x86_64.whl
 
-TensorRT 8.2.0.6 supports ONNX release 1.6.0. Install it with:
+TensorRT 8.2.1.8 supports ONNX release 1.8.0. Install it with:
 
-    python3 -m pip install onnx==1.6.0
+    python3 -m pip install onnx==1.8.0
 
 The ONNX-TensorRT backend can be installed by running:
diff --git a/builtin_op_importers.cpp b/builtin_op_importers.cpp
index 6e049c3f..27edc37c 100644
--- a/builtin_op_importers.cpp
+++ b/builtin_op_importers.cpp
@@ -1381,12 +1381,9 @@ DEFINE_BUILTIN_OP_IMPORTER(Expand)
     const ShapeTensor starts = similar(ctx, newDims, 0);
     // Do the broadcast rule.
     const ShapeTensor sizes = broadcast(ctx, newDims, newShape);
-
-    const ShapeTensor delta = sub(ctx, sizes, newDims);
+    // Compute (x > 1 ? 1 : 0) for x in newDims, assuming positive x, using only TensorRT operations.
     const ShapeTensor one = shapeVector(1);
-    // stride 1 for dims where sizes same as Slice input, 0 for not the same.
-    // delta is non-negative for Expand here
-    const ShapeTensor strides = sub(ctx, one, min(ctx, one, delta));
+    const ShapeTensor strides = min(ctx, one, sub(ctx, newDims, one));
     nvinfer1::ISliceLayer* sliceLayer = addSlice(ctx, newInputTensor, starts, sizes, strides);
     ctx->registerLayer(sliceLayer, getNodeName(node));
     RETURN_FIRST_OUTPUT(sliceLayer);
@@ -3470,7 +3467,7 @@ DEFINE_BUILTIN_OP_IMPORTER(ReduceSum)
 }
 DEFINE_BUILTIN_OP_IMPORTER(ReduceSumSquare)
 {
-    nvinfer1::ITensor& tensor = inputs.at(0).tensor();
+    nvinfer1::ITensor& tensor = convertToTensor(inputs.at(0), ctx);
     auto* sqr_layer = ctx->network()->addElementWise(tensor, tensor, nvinfer1::ElementWiseOperation::kPROD);
     ASSERT(sqr_layer && "Failed to add an ElementWise layer.", ErrorCode::kUNSUPPORTED_NODE);
     nvinfer1::ITensor* sqr_tensorPtr = sqr_layer->getOutput(0);
@@ -4070,21 +4067,16 @@ DEFINE_BUILTIN_OP_IMPORTER(Scan)
 
 DEFINE_BUILTIN_OP_IMPORTER(ScatterND)
 {
-    auto* layer = addScatterLayer(ctx, inputs, nvinfer1::ScatterMode::kND);
-    ctx->registerLayer(layer, getNodeName(node));
-    RETURN_FIRST_OUTPUT(layer);
+    return addScatterLayer(ctx, node, inputs, nvinfer1::ScatterMode::kND);
 }
 
 DEFINE_BUILTIN_OP_IMPORTER(ScatterElements)
 {
-    auto* layer = addScatterLayer(ctx, inputs, nvinfer1::ScatterMode::kELEMENT);
     OnnxAttrs attrs(node, ctx);
     int32_t axis = attrs.get<int32_t>("axis", 0);
     int32_t nbDims = inputs.at(0).shape().nbDims;
     CHECK(convertAxis(axis, nbDims));
-    layer->setAxis(axis);
-    ctx->registerLayer(layer, getNodeName(node));
-    RETURN_FIRST_OUTPUT(layer);
+    return addScatterLayer(ctx, node, inputs, nvinfer1::ScatterMode::kELEMENT, axis);
 }
 
 DEFINE_BUILTIN_OP_IMPORTER(Scatter)
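A note on the Expand hunk above: the old stride formula `1 - min(1, sizes - newDims)` depended on the requested output `sizes`, which broke down for some dynamic-shape cases, while the replacement `min(1, newDims - 1)` is computed from the input dimensions alone. It yields stride 0 (re-read the single element, i.e. broadcast) exactly where an input dimension equals 1, and stride 1 everywhere else. A plain-C++ check of the identity follows; it is a scalar stand-in for the ShapeTensor arithmetic, not TensorRT code.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Scalar model of the new stride rule: for an input dimension x >= 1,
// min(1, x - 1) is 0 when x == 1 (broadcast axis) and 1 when x > 1.
int64_t expandStride(int64_t x)
{
    return std::min<int64_t>(1, x - 1);
}

int main()
{
    assert(expandStride(1) == 0); // broadcast axis: stride 0 repeats the element
    assert(expandStride(4) == 1); // ordinary axis: normal copy stride
    // The old formula needed the output size: 1 - min(1, size - x).
    // For x == 1 expanded to size 5 it also gives 1 - min(1, 4) == 0,
    // but only when "size" is known, which dynamic shapes may not provide.
    return 0;
}
```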
diff --git a/docs/Changelog.md b/docs/Changelog.md
index 1fab021d..b4825272 100644
--- a/docs/Changelog.md
+++ b/docs/Changelog.md
@@ -2,6 +2,17 @@
 
 # ONNX-TensorRT Changelog
 
+## TensorRT 8.2 GA Release - 2021-11-23
+
+### Added
+
+See the 8.2 EA release notes for new features added in TensorRT 8.2.
+
+### Fixes
+- Removed duplicate constant layer checks that caused some performance regressions
+- Fixed expand dynamic shape calculations
+- Added parser-side checks for Scatter layer support
+
 ## TensorRT 8.2 EA Release - 2021-10-04
 ### Added
 - Added support for the following ONNX operators:
diff --git a/docs/operators.md b/docs/operators.md
index 5972ca37..5e75e002 100644
--- a/docs/operators.md
+++ b/docs/operators.md
@@ -31,13 +31,13 @@ See below for the support matrix of ONNX operators in ONNX-TensorRT.
 | Cast | Y | FP32, FP16, INT32, INT8, BOOL | |
 | Ceil | Y | FP32, FP16 |
 | Celu | Y | FP32, FP16 |
-| Clip | Y | FP32, FP16, INT8 | `min` and `max` clip values must be initializers |
+| Clip | Y | FP32, FP16, INT8 | |
 | Compress | N |
 | Concat | Y | FP32, FP16, INT32, INT8, BOOL |
 | ConcatFromSequence | N |
 | Constant | Y | FP32, FP16, INT32, INT8, BOOL |
 | ConstantOfShape | Y | FP32 |
-| Conv | Y | FP32, FP16, INT8 | 2D or 3D convolutions only |
+| Conv | Y | FP32, FP16, INT8 | 2D or 3D convolutions only\. Weights `W` must be an initializer |
 | ConvInteger | N |
 | ConvTranspose | Y | FP32, FP16, INT8 | 2D or 3D deconvolutions only\. Weights `W` must be an initializer |
 | Cos | Y | FP32, FP16 |
@@ -49,7 +49,7 @@ See below for the support matrix of ONNX operators in ONNX-TensorRT.
 | Div | Y | FP32, FP16, INT32 |
 | Dropout | Y | FP32, FP16 |
 | DynamicQuantizeLinear | N |
-| Einsum | Y | FP32, FP16 | Ellipsis and diagonal operations are not supported.
+| Einsum | Y | FP32, FP16 | Ellipsis and diagonal operations are not supported. Broadcasting between inputs is not supported
 | Elu | Y | FP32, FP16, INT8 |
 | Equal | Y | FP32, FP16, INT32 |
 | Erf | Y | FP32, FP16 |
@@ -89,7 +89,7 @@ See below for the support matrix of ONNX operators in ONNX-TensorRT.
 | MatMul | Y | FP32, FP16 |
 | MatMulInteger | N |
 | Max | Y | FP32, FP16, INT32 |
-| MaxPool | Y | FP32, FP16, INT8 |
+| MaxPool | Y | FP32, FP16, INT8 | 2D or 3D pooling only. `Indices` output tensor unsupported
 | MaxRoiPool | N |
 | MaxUnpool | N |
 | Mean | Y | FP32, FP16, INT32 |
@@ -114,8 +114,8 @@ See below for the support matrix of ONNX operators in ONNX-TensorRT.
 | QuantizeLinear | Y | FP32, FP16 | `y_zero_point` must be 0 |
 | RandomNormal | N |
 | RandomNormalLike | N |
-| RandomUniform | Y | FP32, FP16 |
-| RandomUniformLike | Y | FP32, FP16 |
+| RandomUniform | Y | FP32, FP16 | `seed` value is ignored by TensorRT
+| RandomUniformLike | Y | FP32, FP16 | `seed` value is ignored by TensorRT
 | Range | Y | FP32, FP16, INT32 | Floating point inputs are only supported if `start`, `limit`, and `delta` inputs are initializers |
 | Reciprocal | N |
 | ReduceL1 | Y | FP32, FP16 |
@@ -160,7 +160,7 @@ See below for the support matrix of ONNX operators in ONNX-TensorRT.
 | Softplus | Y | FP32, FP16, INT8 |
 | Softsign | Y | FP32, FP16, INT8 |
 | SpaceToDepth | Y | FP32, FP16, INT32 |
-| Split | Y | FP32, FP16, INT32, BOOL | `split` must be an initializer |
+| Split | Y | FP32, FP16, INT32, BOOL | |
 | SplitToSequence | N |
 | Sqrt | Y | FP32, FP16 |
 | Squeeze | Y | FP32, FP16, INT32, INT8, BOOL | `axes` must be an initializer |
@@ -172,7 +172,7 @@ See below for the support matrix of ONNX operators in ONNX-TensorRT.
 | TfIdfVectorizer | N |
 | ThresholdedRelu | Y | FP32, FP16, INT8 |
 | Tile | Y | FP32, FP16, INT32, BOOL |
-| TopK | Y | FP32, FP16 |
+| TopK | Y | FP32, FP16 | `K` input must be an initializer
 | Transpose | Y | FP32, FP16, INT32, INT8, BOOL |
 | Unique | N |
 | Unsqueeze | Y | FP32, FP16, INT32, INT8, BOOL | `axes` must be a constant tensor |
diff --git a/onnx2trt.hpp b/onnx2trt.hpp
index 4ee38e04..680ef900 100644
--- a/onnx2trt.hpp
+++ b/onnx2trt.hpp
@@ -54,10 +54,11 @@ class IImporterContext
     virtual nvinfer1::ILogger& logger() = 0;
     virtual bool hasError() const = 0;
     virtual nvinfer1::IErrorRecorder* getErrorRecorder() const = 0;
-    virtual nvinfer1::IConstantLayer* getConstantLayer(const char* name) const = 0;
 
 protected:
-    virtual ~IImporterContext() {}
+    virtual ~IImporterContext()
+    {
+    }
 };
 
 } // namespace onnx2trt
diff --git a/onnx2trt_utils.cpp b/onnx2trt_utils.cpp
index cf50bb9f..9138a2c4 100644
--- a/onnx2trt_utils.cpp
+++ b/onnx2trt_utils.cpp
@@ -776,12 +776,6 @@ nvinfer1::ITensor& convertToTensor(TensorOrWeights& input, IImporterContext* ctx)
     }
     // Handle non-tensor indices input by adding a new constant layer to the network.
     ShapedWeights& weights = input.weights();
-
-    auto const existingConstantLayer = ctx->getConstantLayer(weights.getName());
-    if (existingConstantLayer != nullptr)
-    {
-        return *(existingConstantLayer->getOutput(0));
-    }
     // Note the TRT doesn't natively handle boolean weights. First create an INT32 weights copy of the boolean weights,
     // then cast it back to bool within TRT.
     if (weights.type == ::ONNX_NAMESPACE::TensorProto::BOOL)
@@ -2314,13 +2308,42 @@ nvinfer1::ITensor* addSoftmax(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ITensor& input)
     return softMax->getOutput(0);
 }
 
-nvinfer1::IScatterLayer* addScatterLayer(
-    IImporterContext* ctx, std::vector<TensorOrWeights>& inputs, nvinfer1::ScatterMode mode)
+NodeImportResult addScatterLayer(
+    IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, std::vector<TensorOrWeights>& inputs, nvinfer1::ScatterMode mode, int32_t axis)
 {
     nvinfer1::ITensor& data = convertToTensor(inputs.at(0), ctx);
     nvinfer1::ITensor& indices = convertToTensor(inputs.at(1), ctx);
     nvinfer1::ITensor& updates = convertToTensor(inputs.at(2), ctx);
-    return ctx->network()->addScatter(data, indices, updates, mode);
+
+    // Validate input dimensions
+    if (mode == nvinfer1::ScatterMode::kELEMENT)
+    {
+        const auto dataDims = data.getDimensions();
+        const auto indicesDims = indices.getDimensions();
+        const auto updatesDims = updates.getDimensions();
+
+        // Ranks must all be the same
+        ASSERT(dataDims.nbDims == indicesDims.nbDims && dataDims.nbDims == updatesDims.nbDims && "Input dimensions to ScatterElements must have the same rank!",
+            ErrorCode::kUNSUPPORTED_NODE);
+
+        // Corresponding dimensions of indices and updates must be <= data
+        for (int32_t i = 0; i < dataDims.nbDims; ++i)
+        {
+            if (indicesDims.d[i] != -1 && dataDims.d[i] != -1)
+            {
+                ASSERT(indicesDims.d[i] <= dataDims.d[i] && "Indices dimensions must be less than or equal to data dimensions!", ErrorCode::kUNSUPPORTED_NODE);
+            }
+            if (updatesDims.d[i] != -1 && dataDims.d[i] != -1)
+            {
+                ASSERT(updatesDims.d[i] <= dataDims.d[i] && "Updates dimensions must be less than or equal to data dimensions!", ErrorCode::kUNSUPPORTED_NODE);
+            }
+        }
+    }
+
+    auto* layer = ctx->network()->addScatter(data, indices, updates, mode);
+    layer->setAxis(axis);
+    ctx->registerLayer(layer, getNodeName(node));
+    return {{layer->getOutput(0)}};
 }
 
 } // namespace onnx2trt
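The `addScatterLayer` rewrite above moves ScatterElements shape checking into the parser, so malformed nodes fail at parse time with a clear error instead of deep inside engine building: all three inputs must have the same rank, and every static dimension of `indices` and `updates` must not exceed the corresponding dimension of `data`. Below is a standalone sketch of the same rule with a hypothetical `scatterElementsShapesOk` helper, not the parser's API; dynamic dimensions, encoded as -1, are skipped just as in the patch.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Each vector holds one tensor's dimensions; -1 marks a dynamic dimension,
// which can only be validated at runtime and is therefore skipped here.
bool scatterElementsShapesOk(std::vector<int64_t> const& data,
    std::vector<int64_t> const& indices, std::vector<int64_t> const& updates)
{
    // Ranks must all be the same.
    if (data.size() != indices.size() || data.size() != updates.size())
    {
        return false;
    }
    for (std::size_t i = 0; i < data.size(); ++i)
    {
        // Static dimensions of indices/updates must not exceed data's.
        if (indices[i] != -1 && data[i] != -1 && indices[i] > data[i])
        {
            return false;
        }
        if (updates[i] != -1 && data[i] != -1 && updates[i] > data[i])
        {
            return false;
        }
    }
    return true;
}

// e.g. data {4, 5} with indices {2, 5} and updates {2, 5} passes, while
// indices {6, 5} would now be rejected at parse time.
```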
diff --git a/onnx2trt_utils.hpp b/onnx2trt_utils.hpp
index db2c0cf6..0a9e12c0 100644
--- a/onnx2trt_utils.hpp
+++ b/onnx2trt_utils.hpp
@@ -373,7 +373,7 @@ ShapeTensor axesToInterlaceSubscripts(const ShapeTensor& axes, int nbDims);
 nvinfer1::ITensor* addSoftmax(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ITensor& input);
 
 // Helper function to import ONNX scatter nodes into TRT
-nvinfer1::IScatterLayer* addScatterLayer(
-    IImporterContext* ctx, std::vector<TensorOrWeights>& inputs, nvinfer1::ScatterMode mode);
+NodeImportResult addScatterLayer(
+    IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, std::vector<TensorOrWeights>& inputs, nvinfer1::ScatterMode mode, int32_t axis = 0);
 
 } // namespace onnx2trt
diff --git a/onnx_tensorrt/__init__.py b/onnx_tensorrt/__init__.py
index 2a3f701a..e241103d 100644
--- a/onnx_tensorrt/__init__.py
+++ b/onnx_tensorrt/__init__.py
@@ -4,4 +4,4 @@
 
 from . import backend
 
-__version__ = "8.2.0"
+__version__ = "8.2.1"