author     Sami Kama <skama@nvidia.com>  2018-01-19 22:58:50 +0000
committer  Sami Kama <skama@nvidia.com>  2018-01-19 22:58:50 +0000
commit     825e7a32e9f4dbad21a9ddb9d8a34bd3e32b1d0e (patch)
tree       ab8f8406065cac5ada579b571a1ebcc6338f5365
parent     e810b107d81a0016417b100bd89fd53e065e8d14 (diff)
Introducing a TensorRT operator to TF which can run (sub)graphs in
highly optimized TensorRT engines. This commit is a merged version of many commits by benbarsdell <bbarsdell at nvidia.com>, deadeyegoodwin <davidg at nvidia.com>, jjsjann123 <jiej at nvidia.com>, and samikama <skama at nvidia.com>.
-rw-r--r--  configure.py | 126
-rw-r--r--  tensorflow/BUILD | 8
-rw-r--r--  tensorflow/contrib/BUILD | 5
-rw-r--r--  tensorflow/contrib/tensorrt/BUILD | 266
-rw-r--r--  tensorflow/contrib/tensorrt/README.md | 42
-rw-r--r--  tensorflow/contrib/tensorrt/__init__.py | 19
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_graph.cc | 253
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_graph.h | 34
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.cc | 1737
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.h | 42
-rw-r--r--  tensorflow/contrib/tensorrt/convert/inferShapes.cc | 125
-rw-r--r--  tensorflow/contrib/tensorrt/convert/inferShapes.h | 39
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc | 183
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_engine_op.h | 55
-rw-r--r--  tensorflow/contrib/tensorrt/log/trt_logger.cc | 56
-rw-r--r--  tensorflow/contrib/tensorrt/log/trt_logger.h | 41
-rw-r--r--  tensorflow/contrib/tensorrt/ops/trt_engine_op.cc | 37
-rw-r--r--  tensorflow/contrib/tensorrt/python/__init__.py | 8
-rw-r--r--  tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py | 35
-rw-r--r--  tensorflow/contrib/tensorrt/python/trt_convert.py | 91
-rw-r--r--  tensorflow/contrib/tensorrt/segment/segment.cc | 259
-rw-r--r--  tensorflow/contrib/tensorrt/segment/segment.h | 53
-rw-r--r--  tensorflow/contrib/tensorrt/segment/segment_test.cc | 363
-rw-r--r--  tensorflow/contrib/tensorrt/segment/union_find.h | 77
-rw-r--r--  tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc | 123
-rw-r--r--  tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h | 28
-rw-r--r--  tensorflow/contrib/tensorrt/trt_conversion.i | 84
-rw-r--r--  tensorflow/tensorflow.bzl | 62
-rw-r--r--  tensorflow/tools/pip_package/BUILD | 4
-rw-r--r--  tensorflow/workspace.bzl | 2
-rw-r--r--  third_party/tensorrt/BUILD | 0
-rw-r--r--  third_party/tensorrt/BUILD.tpl | 42
-rw-r--r--  third_party/tensorrt/LICENSE | 203
-rw-r--r--  third_party/tensorrt/build_defs.bzl | 85
-rw-r--r--  third_party/tensorrt/build_defs.bzl.tpl | 18
35 files changed, 4589 insertions(+), 16 deletions(-)
diff --git a/configure.py b/configure.py
index cf16ef4837..580bbc0ebe 100644
--- a/configure.py
+++ b/configure.py
@@ -37,12 +37,14 @@ _TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)),
_TF_WORKSPACE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'WORKSPACE')
_DEFAULT_CUDA_VERSION = '9.0'
+_DEFAULT_TENSORRT_VERSION = '4'
_DEFAULT_CUDNN_VERSION = '7'
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2'
_DEFAULT_CUDA_PATH = '/usr/local/cuda'
_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda'
_DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION)
+_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/x86_64-linux-gnu'
_TF_OPENCL_VERSION = '1.2'
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
@@ -382,13 +384,12 @@ def set_build_var(environ_cp, var_name, query_item, option_name,
var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default)))
environ_cp[var_name] = var
- if var == '1':
- write_to_bazelrc('build --define %s=true' % option_name)
- elif bazel_config_name is not None:
- # TODO(mikecase): Migrate all users of configure.py to use --config Bazel
- # options and not to set build configs through environment variables.
- write_to_bazelrc('build:%s --define %s=true'
- % (bazel_config_name, option_name))
+ # TODO(mikecase): Migrate all users of configure.py to use --config Bazel
+ # options and not to set build configs through environment variables.
+ if var=='1':
+ setting='true'
+ confname=":%s"%(bazel_config_name) if bazel_config_name is not None else ""
+ write_to_bazelrc('build%s --define %s=%s' % (confname,option_name,setting))
def set_action_env_var(environ_cp,
@@ -438,13 +439,12 @@ def convert_version_to_int(version):
for seg in version_segments:
if not seg.isdigit():
return None
-
version_str = ''.join(['%03d' % int(seg) for seg in version_segments])
return int(version_str)
def check_bazel_version(min_version):
- """Check installed bezel version is at least min_version.
+ """Check installed bazel version is at least min_version.
Args:
min_version: string for minimum bazel version.
@@ -1056,6 +1056,108 @@ def set_other_cuda_vars(environ_cp):
write_to_bazelrc('test --config=cuda')
+def set_tf_trt_version(environ_cp):
+ """Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION."""
+ ask_trt_version = (
+ 'Please specify the TensorRT (libnvinfer) version you want to use. '
+ '[Leave empty to default to libnvinfer %s]: ') % _DEFAULT_TENSORRT_VERSION
+
+ while True:
+ tf_trt_version = get_from_env_or_user_or_default(
+ environ_cp, 'TF_TENSORRT_VERSION', ask_trt_version,
+ _DEFAULT_TENSORRT_VERSION)
+ # if library version is passed and known
+ default_trt_path = environ_cp.get('TENSORRT_INSTALL_PATH',_DEFAULT_TENSORRT_PATH_LINUX)
+ ask_trt_path = (r'Please specify the location where libnvinfer %s library is '
+ 'installed. Refer to README.md for more details. [Default'
+ ' is %s]:') % (tf_trt_version, default_trt_path)
+ trt_install_path = get_from_env_or_user_or_default(
+ environ_cp, 'TENSORRT_INSTALL_PATH', ask_trt_path, default_trt_path)
+
+    # Result returned from "read" will be used unexpanded. That makes "~"
+ # unusable. Going through one more level of expansion to handle that.
+ trt_install_path = os.path.realpath(
+ os.path.expanduser(trt_install_path))
+    # Simple function to search for libnvinfer in the install path.
+    # It finds all libnvinfer.so* files in the user-defined install path
+    # and its lib64 subdirectory and returns their absolute paths.
+ def find_libs(search_path):
+ fl=set()
+ if os.path.exists(search_path) and os.path.isdir(search_path):
+ fl.update([os.path.realpath(os.path.join(search_path,x)) \
+ for x in os.listdir(search_path) if 'libnvinfer.so' in x])
+ return fl
+ possible_files=find_libs(trt_install_path)
+ possible_files.update(find_libs(os.path.join(trt_install_path,'lib64')))
+ if is_linux():
+ cudnnpatt=re.compile(".*libcudnn.so\.?(.*) =>.*$")
+ cudapatt =re.compile(".*libcudart.so\.?(.*) =>.*$")
+ def is_compatible(lib,cudaver,cudnnver):
+ ldd_bin=which('ldd') or '/usr/bin/ldd'
+ ldd_out=run_shell([ldd_bin,lib]).split(os.linesep)
+ for l in ldd_out:
+ if 'libcudnn.so' in l:
+ cudnn=cudnnpatt.search(l)
+ elif 'libcudart.so' in l:
+ cudart=cudapatt.search(l)
+ if cudnn:
+ cudnn=convert_version_to_int(cudnn.group(1)) if len(cudnn.group(1)) else 0
+ if cudart:
+ cudart=convert_version_to_int(cudart.group(1)) if len(cudart.group(1)) else 0
+ return (cudnn==cudnnver) and (cudart==cudaver)
+ cudaver=convert_version_to_int(environ_cp['TF_CUDA_VERSION'])
+ cudnnver=convert_version_to_int(environ_cp['TF_CUDNN_VERSION'])
+ valid_libs=[]
+ vfinder=re.compile('.*libnvinfer.so.?(.*)$')
+ highest_ver=[0,None,None]
+
+ for l in possible_files:
+ if is_compatible(l,cudaver,cudnnver):
+ valid_libs.append(l)
+ vstr=vfinder.search(l).group(1)
+ currver=convert_version_to_int(vstr) if len(vstr) else 0
+ if currver > highest_ver[0]:
+ highest_ver= [currver,vstr,l]
+ if highest_ver[1] is not None:
+ trt_install_path=os.path.dirname(highest_ver[2])
+ tf_trt_version=highest_ver[1]
+ break
+ ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
+ libnvinfer_path_from_ldconfig = run_shell([ldconfig_bin, '-p'])
+ libnvinfer_path_from_ldconfig = re.search('.*libnvinfer.so.* => (.*)',
+ libnvinfer_path_from_ldconfig)
+ if libnvinfer_path_from_ldconfig:
+ libnvinfer_path_from_ldconfig = libnvinfer_path_from_ldconfig.group(1)
+ if os.path.exists('%s.%s' % (libnvinfer_path_from_ldconfig,
+ tf_trt_version)):
+ trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig)
+ break
+
+ # Reset and Retry
+ if len(possible_files):
+ print(
+ 'Invalid path to TensorRT %s. libnvinfer.so* files found are for incompatible cuda versions '
+ % tf_trt_version)
+ print(trt_install_path)
+ print(os.path.join(trt_install_path,'lib64'))
+ else:
+ print(
+          'Invalid path to TensorRT %s. No libnvinfer.so* files found in:'
+          % tf_trt_version)
+ print(trt_install_path)
+ print(os.path.join(trt_install_path,'lib64'))
+ if is_linux():
+ print('%s.%s' % (libnvinfer_path_from_ldconfig, tf_trt_version))
+
+ environ_cp['TF_TENSORRT_VERSION'] = ''
+
+  # Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION
+ environ_cp['TENSORRT_INSTALL_PATH'] = trt_install_path
+ write_action_env_to_bazelrc('TENSORRT_INSTALL_PATH', trt_install_path)
+ environ_cp['TF_TENSORRT_VERSION'] = tf_trt_version
+ write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_trt_version)
+ write_to_bazelrc('build:tensorrt --define using_tensorrt=true')
+
def set_host_cxx_compiler(environ_cp):
"""Set HOST_CXX_COMPILER."""
default_cxx_host_compiler = which('g++') or ''
@@ -1244,9 +1346,11 @@ def main():
environ_cp['TF_NEED_COMPUTECPP'] = '0'
environ_cp['TF_NEED_OPENCL'] = '0'
environ_cp['TF_CUDA_CLANG'] = '0'
+ environ_cp['TF_NEED_TENSORRT'] = '0'
if is_macos():
environ_cp['TF_NEED_JEMALLOC'] = '0'
+ environ_cp['TF_NEED_TENSORRT'] = '0'
set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc',
'with_jemalloc', True)
@@ -1301,6 +1405,10 @@ def main():
if not is_windows():
set_gcc_host_compiler_path(environ_cp)
set_other_cuda_vars(environ_cp)
+  # Enable TensorRT if desired. Disabled on non-Linux platforms.
+ set_action_env_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False)
+ if environ_cp.get('TF_NEED_TENSORRT') == '1':
+ set_tf_trt_version(environ_cp)
set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False)
if environ_cp.get('TF_NEED_MPI') == '1':
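The library discovery added in `set_tf_trt_version` above shells out to `ldd` for each candidate `libnvinfer.so*` and compares the CUDA and cuDNN versions it links against with the configured `TF_CUDA_VERSION` and `TF_CUDNN_VERSION`. A minimal standalone sketch of that idea, assuming `ldd` is on the PATH and using a hypothetical library path, might look like:

```python
import re
import subprocess

def linked_cuda_versions(lib_path):
    """Return the (libcudart, libcudnn) version strings that lib_path links against."""
    ldd_out = subprocess.check_output(['ldd', lib_path]).decode()
    cudart = re.search(r'libcudart\.so\.?([\d.]*) =>', ldd_out)
    cudnn = re.search(r'libcudnn\.so\.?([\d.]*) =>', ldd_out)
    return (cudart.group(1) if cudart else None,
            cudnn.group(1) if cudnn else None)

# Hypothetical usage: inspect a candidate libnvinfer before accepting it.
print(linked_cuda_versions('/usr/lib/x86_64-linux-gnu/libnvinfer.so.4'))
```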
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index da37564697..b374462d32 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -359,6 +359,14 @@ config_setting(
)
config_setting(
+ name = "using_tensorrt",
+ define_values = {
+ "using_tensorrt":"true",
+ },
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
name = "with_mpi_support",
values = {"define": "with_mpi_support=true"},
visibility = ["//visibility:public"],
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 8bed0fabd7..e5c3017426 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -7,6 +7,7 @@ package(default_visibility = ["//tensorflow:__subpackages__"])
load("//third_party/mpi:mpi.bzl", "if_mpi")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+load("@local_config_tensorrt//:build_defs.bzl", "if_trt")
py_library(
name = "contrib_py",
@@ -104,7 +105,9 @@ py_library(
"//tensorflow/contrib/training:training_py",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:util",
- ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]),
+ ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_ops_py"])
+ + if_trt(["//tensorflow/contrib/tensorrt:init_py"]),
+
)
cc_library(
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
new file mode 100644
index 0000000000..723c9f5434
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -0,0 +1,266 @@
+# -*- python -*-
+# Description:
+# provide tensorrt operators and converter package
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+licenses(["notice"]) # Apache 2.0
+
+load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_custom_op_library",
+ "tf_gen_op_libs",
+ "tf_gen_op_wrapper_py",
+ "tf_py_wrap_cc",
+ "tf_cc_test",
+ "tf_kernel_library",
+ "tf_custom_op_py_library",
+ "tf_copts",
+)
+
+
+
+tf_custom_op_library(
+ name = "python/ops/_trt_engine_op.so",
+ srcs = [
+ "kernels/trt_engine_op.cc",
+ "ops/trt_engine_op.cc",
+ "kernels/trt_engine_op.h",
+ ],
+ gpu_srcs = [],
+ deps = [
+ "@local_config_tensorrt//:tensorrt",
+ ":trt_shape_function",
+ "//tensorflow/core:lib_proto_parsing",
+ "//tensorflow/core/kernels:bounds_check_lib",
+ "//tensorflow/core/kernels:ops_util_hdrs",
+ ],
+)
+
+cc_library(
+ name = "trt_shape_function",
+ srcs=[
+ "shape_fn/trt_shfn.cc",
+ ],
+ hdrs=["shape_fn/trt_shfn.h"],
+ copts=tf_copts(),
+ deps=[
+ ":trt_logging",
+ "//third_party/eigen3",
+ "@local_config_tensorrt//:tensorrt",
+ "@protobuf_archive//:protobuf",
+ "@nsync//:nsync_headers",
+ "//tensorflow/core:framework_headers_lib",
+ ]
+)
+
+
+tf_kernel_library(
+ name = "trt_engine_op_kernel",
+ srcs = [
+ "kernels/trt_engine_op.cc",
+ ],
+ hdrs=[
+ "kernels/trt_engine_op.h",
+ ],
+ gpu_srcs = [
+ ],
+ deps = [
+ ":trt_logging",
+ ":trt_shape_function",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ "//tensorflow/core:gpu_headers_lib",
+ "@local_config_tensorrt//:tensorrt",
+ "//tensorflow/core:lib_proto_parsing",
+ ],
+ alwayslink=1,
+)
+
+tf_gen_op_libs(
+ op_lib_names = [
+ "trt_engine_op",
+ ],
+ deps=[
+ "@local_config_tensorrt//:tensorrt",
+ ]
+)
+
+
+cc_library(
+ name="trt_logging",
+ srcs = [
+ "log/trt_logger.cc",
+ ],
+ hdrs=[
+ "log/trt_logger.h",
+ ],
+ deps=[
+ "@local_config_tensorrt//:tensorrt",
+ "//tensorflow/core:lib_proto_parsing",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+tf_gen_op_wrapper_py(
+ name = "trt_engine_op",
+ deps = [
+ ":trt_engine_op_op_lib",
+ ":trt_shape_function",
+ ],
+)
+
+
+tf_custom_op_py_library(
+ name = "trt_engine_op_loader",
+ srcs = ["python/ops/trt_engine_op.py"],
+ dso = [":python/ops/_trt_engine_op.so",
+ "@local_config_tensorrt//:tensorrt",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:resources",
+ ],
+)
+
+py_library(
+ name = "init_py",
+ srcs = [
+ "__init__.py",
+ "python/__init__.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":trt_ops_py",
+ ":trt_convert_py",
+
+ ],
+)
+
+py_library(
+ name="trt_ops_py",
+ srcs_version = "PY2AND3",
+ deps=[":trt_engine_op",
+ ":trt_engine_op_loader",
+ ],
+
+)
+
+py_library(
+ name="trt_convert_py",
+ srcs=["python/trt_convert.py"],
+ srcs_version = "PY2AND3",
+ deps=[
+ ":wrap_conversion"
+ ],
+)
+
+tf_py_wrap_cc(
+ name="wrap_conversion",
+ srcs=["trt_conversion.i"],
+ deps=[
+ ":trt_conversion",
+ "//tensorflow/core:framework_lite",
+ "//util/python:python_headers",
+ ],
+)
+
+cc_library(
+ name= "trt_conversion",
+ srcs=[
+ "convert/convert_nodes.cc",
+ "convert/convert_graph.cc",
+ "segment/segment.cc",
+ "convert/inferShapes.cc",
+ ],
+ hdrs=[
+ "convert/convert_nodes.h",
+ "convert/convert_graph.h",
+ "convert/inferShapes.h",
+ "segment/segment.h",
+ "segment/union_find.h",
+ ],
+ deps=[
+ "@local_config_tensorrt//:tensorrt",
+ "@protobuf_archive//:protobuf_headers",
+ "@nsync//:nsync_headers",
+ ":trt_logging",
+ "//tensorflow/core:framework_lite",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:core_cpu_base",
+ #"//third_party/eigen3",
+ ],
+)
+
+tf_custom_op_library(
+ name = "tensorrt_ops.so",
+ srcs = [
+ "ops/tensorrt_ops.cc",
+ ],
+ deps = [
+ "@local_config_tensorrt//:tensorrt",
+ ],
+)
+
+
+# Library for the segmenting portion of TensorRT operation creation
+cc_library(
+ name = "segment",
+ srcs = [
+ "segment/segment.cc",
+ ],
+ hdrs = [
+ "segment/union_find.h",
+ "segment/segment.h",
+ ],
+ deps = [
+ "@protobuf_archive//:protobuf_headers",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:lib_proto_parsing",
+ "//third_party/eigen3",
+ ],
+ linkstatic = 1,
+)
+
+tf_cc_test(
+ name = "segment_test",
+ size = "small",
+ srcs = ["segment/segment_test.cc"],
+ deps = [
+ ":segment",
+ "//tensorflow/c:c_api",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+
+# Library for the node-level conversion portion of TensorRT operation creation
+
+filegroup(
+ name = "cppfiles",
+ srcs = glob(["**/*.cc"]),
+ visibility=["//visibility:private"],
+)
+
+filegroup(
+ name = "headers",
+ srcs = glob(["**/*.h"]),
+ visibility=["//visibility:private"],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/tensorrt/README.md b/tensorflow/contrib/tensorrt/README.md
new file mode 100644
index 0000000000..61b348fc60
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/README.md
@@ -0,0 +1,42 @@
+Using TensorRT in TensorFlow
+============================
+
+This module provides the necessary bindings and introduces the
+TRT_engine_op operator, which wraps a subgraph in a TensorRT engine.
+
+Compilation
+-----------
+
+In order to compile the module, you need a local TensorRT installation
+(libnvinfer.so and the corresponding include files). During the
+configuration step, TensorRT should be enabled and its installation path
+should be set. If TensorRT was installed through a package manager
+(deb, rpm), the configure script should find the necessary components on
+the system automatically. If it was installed from a tar package, the
+user has to point the configuration at the installed library location.
+
+To enable TensorRT support, the user has to add `--config=tensorrt` to
+the build flags during compilation, for example:
+
+```
+bazel build --config=cuda --config=opt --config=tensorrt //tensorflow/tools/pip_package:build_pip_package
+bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/
+```
+
+After the TensorFlow package is installed, the TensorRT transformation
+will be available. An example use is shown below.
+
+```python
+import tensorflow as tf
+import tensorflow.contrib.tensorrt as trt
+#... create and train or load model
+gdef=sess.graph.as_graph_def()
+trt_gdef=trt.CreateInferenceGraph(gdef, #original graph_def
+ ["output"], #name of output node(s)
+ max_batch_size, #maximum batch size to run the inference
+ max_workspace_size # max memory for TensorRT to use
+ )
+tf.reset_default_graph()
+tf.import_graph_def(graph_def=trt_gdef)
+#...... run inference
+```
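A more complete sketch of running inference with the converted graph is shown below; the frozen-graph path, the `input`/`output` node names, and the batch and workspace sizes are placeholders and not part of this change:

```python
import numpy as np
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt

# Load a frozen GraphDef (hypothetical file and node names).
with tf.gfile.GFile('frozen_model.pb', 'rb') as f:
    gdef = tf.GraphDef()
    gdef.ParseFromString(f.read())

# Convert: output node names, max batch size, max TensorRT workspace in bytes.
trt_gdef = trt.CreateInferenceGraph(gdef, ['output'], 8, 1 << 30)

tf.reset_default_graph()
tf.import_graph_def(trt_gdef, name='')  # keep the original tensor names

with tf.Session() as sess:
    result = sess.run('output:0',
                      feed_dict={'input:0': np.random.rand(8, 224, 224, 3)})
```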
diff --git a/tensorflow/contrib/tensorrt/__init__.py b/tensorflow/contrib/tensorrt/__init__.py
new file mode 100644
index 0000000000..0d69ffe466
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.tensorrt.python import *
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
new file mode 100644
index 0000000000..29aa555467
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -0,0 +1,253 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
+
+#include <list>
+#include <set>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+#include <map>
+#include <utility>
+
+#include "NvInfer.h"
+
+#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
+#include "tensorflow/contrib/tensorrt/convert/inferShapes.h"
+#include "tensorflow/contrib/tensorrt/segment/segment.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+
+#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1)
+//------------------------------------------------------------------------------
+namespace tensorrt {
+namespace convert {
+
+namespace {
+
+static std::unordered_set<std::string> output_nodes;
+bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) {
+ static const std::set<std::string> candidate_ops = {
+ "Identity", "Const", "Conv2D", "MaxPool", "BiasAdd", "Relu",
+ "Add", "Mul", "Sub", "Rsqrt", "Pad" // "Placeholder" ,"Mean"
+ // TODO(ben,jie): ...
+ };
+ if (output_nodes.count(node_def.name())) return false;
+ return candidate_ops.count(node_def.op());
+}
+
+void GetSubGraphIncomingEdges(tensorflow::Graph const& graph,
+ std::set<int> const& subgraph_node_ids,
+ tensorflow::EdgeSet* incoming_edges) {
+ for (int node_id : subgraph_node_ids) {
+ tensorflow::Node const* node = graph.FindNodeId(node_id);
+ LOG(DEBUG) << node->name() << " has incoming edges: ";
+ for (tensorflow::Edge const* edge : node->in_edges()) {
+ if (!subgraph_node_ids.count(edge->src()->id()) &&
+ !edge->src()->IsSource()) {
+ LOG(DEBUG) << edge->src()->name() << ", ";
+ incoming_edges->insert(edge);
+ }
+ }
+ }
+}
+
+void GetSubGraphOutgoingEdges(tensorflow::Graph const& graph,
+ std::set<int> const& subgraph_node_ids,
+ tensorflow::EdgeSet* outgoing_edges) {
+ for (int node_id : subgraph_node_ids) {
+ tensorflow::Node const* node = graph.FindNodeId(node_id);
+ LOG(DEBUG) << node->name() << " has outgoing edges: ";
+ for (tensorflow::Edge const* edge : node->out_edges()) {
+ if (!subgraph_node_ids.count(edge->dst()->id()) &&
+ !edge->dst()->IsSink()) {
+ outgoing_edges->insert(edge);
+ }
+ }
+ }
+}
+
+std::pair<std::string, int> ParseTensorName(std::string name,
+ int default_idx = 0) {
+ int idx = default_idx;
+ size_t sep = name.find_last_of(':');
+ if (sep != std::string::npos) {
+    idx = std::stoi(name.substr(sep + 1));
+    name = name.substr(0, sep);
+ }
+ return std::make_pair(name, idx);
+}
+
+std::unordered_map<std::string, std::vector<int>> BuildTensorNameMap(
+ const std::vector<std::string>& tensor_names) {
+ std::unordered_map<std::string, std::vector<int>> result;
+ for (std::string const& tensor_name : tensor_names) {
+ std::string node_name;
+ int index;
+ std::tie(node_name, index) = ParseTensorName(tensor_name);
+ result[node_name].push_back(index);
+ }
+ return result;
+}
+
+tensorflow::Status ConvertSubGraphToTensorRT(
+ tensorflow::Graph& graph, const std::vector<std::string>& output_names,
+ const std::set<int>& subgraph_node_ids, size_t max_batch_size,
+ size_t max_workspace_size, const ShapeMap& shape_map) {
+ tensorflow::EdgeSet subgraph_incoming_edges;
+ GetSubGraphIncomingEdges(graph, subgraph_node_ids, &subgraph_incoming_edges);
+
+ std::vector<std::pair<int, int>> subgraph_inputs;
+
+
+ // Collect inputs by looking for incoming edges
+ for (tensorflow::Edge const* edge : subgraph_incoming_edges) {
+ subgraph_inputs.push_back({edge->src()->id(), edge->src_output()});
+ }
+ std::set<std::pair<int, int>> subgraph_outputs_set;
+ // Collect outputs referenced from output_names
+ auto output_name_to_index_map = BuildTensorNameMap(output_names);
+ // for (int node_id : subgraph_node_ids_no_placeholder) {
+ for (int node_id : subgraph_node_ids) {
+ tensorflow::Node* node = graph.FindNodeId(node_id);
+ if (output_name_to_index_map.count(node->name())) {
+ for (int index : output_name_to_index_map.at(node->name())) {
+ subgraph_outputs_set.insert({node_id, index});
+ }
+ }
+ }
+ // Collect outputs referenced from outgoing edges
+ tensorflow::EdgeSet subgraph_outgoing_edges;
+ // GetSubGraphOutgoingEdges(graph, subgraph_node_ids_no_placeholder,
+ // &subgraph_outgoing_edges);
+ GetSubGraphOutgoingEdges(graph, subgraph_node_ids, &subgraph_outgoing_edges);
+ for (tensorflow::Edge const* edge : subgraph_outgoing_edges) {
+ subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()});
+ }
+ // Impose an ordering on the outputs
+ std::vector<std::pair<int, int>> subgraph_outputs(
+ subgraph_outputs_set.begin(), subgraph_outputs_set.end());
+ // Build TensorRT node and add it to the graph
+ tensorflow::NodeDef trt_node_def;
+ TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(
+ graph, subgraph_node_ids, subgraph_inputs, subgraph_outputs,
+ max_batch_size, max_workspace_size, shape_map, &trt_node_def));
+ tensorflow::Status status;
+ tensorflow::Node* trt_node = graph.AddNode(trt_node_def, &status);
+
+ TF_RETURN_IF_ERROR(status);
+
+ // Re-map outgoing edges to use the new TRT node instead of the orig subgraph
+ std::map<std::pair<int, int>, int> subgraph_edge_to_output_map;
+ for (size_t i = 0; i < subgraph_outputs.size(); ++i) {
+ subgraph_edge_to_output_map.insert({subgraph_outputs.at(i), i});
+ }
+ TF_RETURN_IF_ERROR(status);
+ for (tensorflow::Edge const* edge : subgraph_outgoing_edges) {
+ std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
+ int new_src_output = subgraph_edge_to_output_map.at(old_src);
+ graph.UpdateEdge(trt_node, new_src_output, edge->dst(), edge->dst_input());
+ }
+ // Remove the original subgraph
+ for (int node_id : subgraph_node_ids) {
+ tensorflow::Node* node = graph.FindNodeId(node_id);
+ // Don't remove the input placeholders
+ if (node->type_string() == "Placeholder") {
+ continue;
+ }
+ graph.RemoveNode(node);
+ }
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status BuildNodeMap(
+ const tensorflow::Graph& graph,
+ std::unordered_map<std::string, tensorflow::Node*>* node_map) {
+ for (auto* node : graph.op_nodes()) {
+ if (!node_map->insert({node->name(), node}).second) {
+ return tensorflow::errors::AlreadyExists(
+ "Node name is not unique in graph: " + node->name());
+ }
+ }
+ return tensorflow::Status::OK();
+}
+
+} // namespace
+
+tensorflow::Status ConvertGraphDefToTensorRT(
+ const tensorflow::GraphDef& graph_def,
+ const std::vector<std::string>& output_names, size_t max_batch_size,
+ size_t max_workspace_size, tensorflow::GraphDef* new_graph_def) {
+ ShapeMap shape_map;
+ TF_RETURN_IF_ERROR(
+ tensorflow::trt::inferShapes(graph_def, output_names, shape_map));
+ std::stringstream oss;
+ for (auto& n : shape_map) { // nodes
+ oss << " Node= " << n.first << ", ";
+ for (auto o : n.second) { // outputs
+ oss << o.first.DebugString() << " T= " << o.second << ", ";
+ }
+ LOG(DEBUG) << oss.str();
+ oss.str("");
+ }
+ // Build full graph
+ tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
+ graph_def.library());
+ tensorflow::Graph graph(flib);
+ TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
+ tensorflow::GraphConstructorOptions(), graph_def, &graph));
+
+ // Segment the graph into subgraphs that can be converted to TensorRT
+ tensorrt::segment::SegmentOptions segment_options;
+ // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
+ for (auto node : output_names) output_nodes.insert(node);
+
+ // TODO(sami): this should be passed as a knob!!!!
+ segment_options.minimum_segment_size = 2;
+ tensorrt::segment::SegmentNodesVector segments;
+ TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
+ graph_def, IsTensorRTCandidate, segment_options, &segments));
+ if (segments.size() > 1) {
+ // LOG(WARNING) << "Multiple TensorRT candidate subgraphs were found, "
+ //<< "but only the first can be converted.";
+ // segments.erase(++segments.begin(), segments.end());
+ LOG(INFO) << "MULTIPLE tensorrt candidate conversion: " << segments.size();
+ }
+ std::unordered_map<std::string, tensorflow::Node*> node_map;
+ TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
+ for (std::set<std::string> const& subgraph_node_names : segments) {
+ std::set<int> subgraph_node_ids;
+ for (std::string const& node_name : subgraph_node_names) {
+ subgraph_node_ids.insert(node_map.at(node_name)->id());
+ }
+ TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRT(
+ graph, output_names, subgraph_node_ids, max_batch_size,
+ max_workspace_size, shape_map));
+ }
+ graph.ToGraphDef(new_graph_def);
+ return tensorflow::Status::OK();
+}
+
+} // namespace convert
+} // namespace tensorrt
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h
new file mode 100644
index 0000000000..cd713de888
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorrt {
+namespace convert {
+
+tensorflow::Status ConvertGraphDefToTensorRT(
+ const tensorflow::GraphDef& graph_def,
+ const std::vector<std::string>& output_names, size_t max_batch_size,
+ size_t max_workspace_size, tensorflow::GraphDef* new_graph_def);
+}
+} // namespace tensorrt
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
new file mode 100644
index 0000000000..03146b1b54
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -0,0 +1,1737 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
+
+#include <algorithm>
+#include <fstream>
+#include <list>
+#include <map>
+#include <memory>
+#include <set>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+#include "NvInfer.h"
+
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+
+#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1)
+// Check if the types are equal. Cast to int first so that failure log message
+// would work!
+#define CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2)
+//------------------------------------------------------------------------------
+namespace tensorrt {
+namespace convert {
+
+namespace {
+
+inline int get_dtype_size(nvinfer1::DataType trt_dtype) {
+ switch (trt_dtype) {
+ case nvinfer1::DataType::kFLOAT:
+ return 4;
+ case nvinfer1::DataType::kINT8:
+ return 1;
+ case nvinfer1::DataType::kHALF:
+ return 2;
+ default:
+ return -1;
+ }
+}
+
+inline int get_dtype_size(tensorflow::DataType trt_dtype) {
+ switch (trt_dtype) {
+ case tensorflow::DataType::DT_FLOAT:
+ return 4;
+ case tensorflow::DataType::DT_INT8:
+ return 1;
+ case tensorflow::DataType::DT_HALF:
+ return 2;
+ case tensorflow::DataType::DT_INT32:
+ return 4;
+ default:
+ return -1;
+ }
+}
+
+inline tensorflow::Status convert_dtype(tensorflow::DataType tf_dtype,
+ nvinfer1::DataType* trt_dtype) {
+ switch (tf_dtype) {
+ case tensorflow::DataType::DT_FLOAT:
+ *trt_dtype = nvinfer1::DataType::kFLOAT;
+ break;
+ case tensorflow::DataType::DT_INT8:
+ *trt_dtype = nvinfer1::DataType::kINT8;
+ break;
+ case tensorflow::DataType::DT_HALF:
+ *trt_dtype = nvinfer1::DataType::kHALF;
+ break;
+ default:
+ return tensorflow::errors::InvalidArgument("Unsupported data type");
+ }
+ return tensorflow::Status::OK();
+}
+
+inline nvinfer1::Dims get_tensor_shape(const tensorflow::Tensor& tensor) {
+ nvinfer1::Dims dims;
+ dims.nbDims = tensor.dims();
+ for (int i = 0; i < dims.nbDims; i++) {
+ dims.d[i] = tensor.dim_size(i);
+ }
+ return dims;
+}
+
+inline int64_t get_shape_size(nvinfer1::Dims shape) {
+ // Returns total number of elements in shape
+ int64_t count = 1;
+ for (int d = 0; d < shape.nbDims; ++d) {
+ count *= shape.d[d];
+ }
+ return count;
+}
+
+static std::vector<std::pair<int, int>> createSamePadding(
+ nvinfer1::DimsHW& stride, nvinfer1::DimsHW& kernel,
+ std::vector<int64_t> inputDims) {
+ std::vector<std::pair<int, int>> padding(inputDims.size());
+ CHECK_EQ((size_t)stride.nbDims, inputDims.size()); // TODO(jie): N+C? NC+?
+
+ for (size_t i = 0; i < inputDims.size(); ++i) {
+ /* formula to calculate the padding */
+ int p = ((inputDims[i] - 1) / stride.d[i]) * stride.d[i] + kernel.d[i] -
+ inputDims[i];
+ p = (p > 0) ? p : 0;
+
+ /* right precedence padding, like in TensorFlow */
+ int left = p / 2;
+ int right = p - left;
+
+ padding[i] = {left, right};
+ }
+ return padding;
+}
+
+// class TRT_ShapedWeights : public nvinfer1::Weights {
+class TRT_ShapedWeights {
+ public:
+ nvinfer1::Dims shape_;
+ tensorflow::DataType type_;
+ const void* values_;
+ bool dummy_flag_;
+ int64_t count() const {
+ int64_t c = 1;
+ for (int i = 0; i < shape_.nbDims; i++) c *= shape_.d[i];
+ return c;
+ }
+ TRT_ShapedWeights(tensorflow::DataType type, const void* values,
+ nvinfer1::Dims shape)
+ : shape_(shape), type_(type), values_(values), dummy_flag_(false) {
+ // Note: this->shape.type[] is not used
+ }
+ explicit TRT_ShapedWeights(tensorflow::DataType type)
+ : type_(type), values_(nullptr), dummy_flag_(true) {}
+ nvinfer1::Weights getWeightsForTRT() const {
+ nvinfer1::DataType trt_type(nvinfer1::DataType::kFLOAT);
+ TF_CHECK_OK(convert_dtype(type_, &trt_type));
+ if (dummy_flag_) return nvinfer1::Weights{trt_type, nullptr, 0};
+
+ // Note: this->shape.type[] is not used
+ return nvinfer1::Weights{trt_type, values_, get_shape_size(shape_)};
+ }
+ size_t size_bytes() const {
+ return this->count() * get_dtype_size(this->type_);
+ }
+ // default converter
+ operator nvinfer1::Weights() const { return getWeightsForTRT(); }
+};
+
+class TRT_TensorOrWeights {
+ union {
+ nvinfer1::ITensor* _tensor_;
+ TRT_ShapedWeights _weights_;
+ };
+ enum { TRT_NODE_TENSOR, TRT_NODE_WEIGHTS } _variant_;
+
+ public:
+ explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor)
+ : _tensor_(tensor), _variant_(TRT_NODE_TENSOR) {}
+ explicit TRT_TensorOrWeights(TRT_ShapedWeights const& weights)
+ : _weights_(weights), _variant_(TRT_NODE_WEIGHTS) {}
+ TRT_TensorOrWeights() = delete;
+ bool is_tensor() const { return _variant_ == TRT_NODE_TENSOR; }
+ bool is_weights() const { return _variant_ == TRT_NODE_WEIGHTS; }
+ nvinfer1::ITensor* tensor() {
+ CHECK_EQ(this->is_tensor(), true);
+ return _tensor_;
+ }
+ nvinfer1::ITensor const* tensor() const {
+ CHECK_EQ(this->is_tensor(), true);
+ return _tensor_;
+ }
+ TRT_ShapedWeights& weights() {
+ CHECK_EQ(this->is_weights(), true);
+ return _weights_;
+ }
+ TRT_ShapedWeights const& weights() const {
+ CHECK_EQ(this->is_weights(), true);
+ return _weights_;
+ }
+ nvinfer1::Dims shape() const {
+ if (this->is_tensor()) {
+ return this->tensor()->getDimensions();
+ } else {
+ return this->weights().shape_;
+ }
+ }
+};
+
+class TRT_LayerOrWeights {
+ union {
+ nvinfer1::ILayer* _layer_;
+ TRT_ShapedWeights _weights_;
+ };
+ enum { TRT_NODE_LAYER, TRT_NODE_WEIGHTS } _variant_;
+
+ public:
+ explicit TRT_LayerOrWeights(nvinfer1::ILayer* layer)
+ : _layer_(layer), _variant_(TRT_NODE_LAYER) {}
+ explicit TRT_LayerOrWeights(TRT_ShapedWeights const& weights)
+ : _weights_(weights), _variant_(TRT_NODE_WEIGHTS) {}
+ bool is_layer() const { return _variant_ == TRT_NODE_LAYER; }
+ bool is_weights() const { return _variant_ == TRT_NODE_WEIGHTS; }
+ nvinfer1::ILayer* layer() {
+ CHECK_EQ(this->is_layer(), true);
+ return _layer_;
+ }
+ TRT_ShapedWeights& weights() {
+ CHECK_EQ(this->is_weights(), true);
+ return _weights_;
+ }
+ TRT_TensorOrWeights output(int index = 0) const {
+ if (this->is_layer()) {
+ nvinfer1::ITensor* tensor = _layer_->getOutput(index);
+ return TRT_TensorOrWeights(tensor);
+ } else {
+ CHECK_EQ(index, 0);
+ return TRT_TensorOrWeights(_weights_);
+ }
+ }
+};
+
+class TFAttrs {
+ typedef std::map<std::string, tensorflow::AttrValue const*> AttrMap;
+ AttrMap _attrs;
+
+ public:
+ explicit TFAttrs(tensorflow::NodeDef const& tf_node) {
+ for (auto const& attr : tf_node.attr()) {
+ _attrs.insert({attr.first, &attr.second});
+ }
+ }
+ bool count(std::string key) const { return _attrs.count(key); }
+ tensorflow::AttrValue const* at(std::string key) const {
+ if (!_attrs.count(key)) {
+ throw std::out_of_range("Attribute not found: " + key);
+ }
+ return _attrs.at(key);
+ }
+ template <typename T>
+ T get(std::string key) const;
+ template <typename T>
+ T getShape(std::string key) const;
+ template <typename T>
+ T get(std::string key, T const& default_value) const {
+ return _attrs.count(key) ? this->get<T>(key) : default_value;
+ }
+};
+// template <>
+// float TFAttrs::get<float>(std::string key) const {
+// return this->at(key)->f();
+//}
+
+// template <>
+// int TFAttrs::get<int>(std::string key) const {
+// return (int)this->at(key)->i();
+//}
+
+// template <>
+// bool TFAttrs::get<bool>(std::string key) const {
+// auto value = this->at(key)->i();
+// return bool(value);
+//}
+
+template <>
+std::string TFAttrs::get<std::string>(std::string key) const {
+ return this->at(key)->s();
+}
+template <>
+std::vector<int> TFAttrs::get<std::vector<int>>(std::string key) const {
+ auto attr = this->at(key)->list().i();
+ return std::vector<int>(attr.begin(), attr.end());
+}
+template <>
+nvinfer1::Dims TFAttrs::get<nvinfer1::Dims>(std::string key) const {
+ auto values = this->get<std::vector<int>>(key);
+ nvinfer1::Dims dims;
+ dims.nbDims = values.size();
+ std::copy(values.begin(), values.end(), dims.d);
+ // Note: No dimension type information is included
+ return dims;
+}
+// template <>
+// nvinfer1::DimsHW TFAttrs::get<nvinfer1::DimsHW>(std::string key) const {
+// nvinfer1::Dims dims = this->get<nvinfer1::Dims>(key);
+// CHECK_EQ(dims.nbDims, 2);
+// return nvinfer1::DimsHW(dims.d[0], dims.d[1]);
+//}
+// template <>
+// nvinfer1::Permutation TFAttrs::get<nvinfer1::Permutation>(
+// std::string key) const {
+// auto values = this->get<std::vector<int>>(key);
+// nvinfer1::Permutation perm;
+// std::copy(values.begin(), values.end(), perm.order);
+// // Fill unused values with -1 to aid debugging
+// std::fill(perm.order + values.size(), perm.order + nvinfer1::Dims::MAX_DIMS,
+// -1);
+// return perm;
+//}
+// template <>
+// nvinfer1::Dims TFAttrs::getShape<nvinfer1::Dims>(std::string key) const {
+// auto attr = this->at(key)->shape();
+// nvinfer1::Dims dims;
+// dims.nbDims = attr.dim_size();
+// for (int i = 0; i < dims.nbDims; i++) dims.d[i] = attr.dim(i).size();
+// return dims;
+//}
+// template<> TRT_ShapedWeights TFAttrs::get<TRT_ShapedWeights>(std::string key)
+// const {
+// tensorflow::TensorProto const* tf_weights_tensor = &this->at(key)->tensor();
+// TODO(jie): Implement this
+// return convert_tf_weights(tf_weights_tensor);
+//}
+template <>
+nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(std::string key) const {
+ nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
+ TF_CHECK_OK(convert_dtype(this->at(key)->type(), &trt_dtype));
+ return trt_dtype;
+}
+template <>
+tensorflow::DataType TFAttrs::get<tensorflow::DataType>(std::string key) const {
+ return this->at(key)->type();
+}
+
+template <typename T>
+void reorder4(nvinfer1::DimsNCHW shape, T const* idata,
+ nvinfer1::DimsNCHW istrides, T* odata,
+ nvinfer1::DimsNCHW ostrides) {
+ for (int n = 0; n < shape.n(); ++n) {
+ for (int c = 0; c < shape.c(); ++c) {
+ for (int h = 0; h < shape.h(); ++h) {
+ for (int w = 0; w < shape.w(); ++w) {
+ odata[n * ostrides.n() + c * ostrides.c() + h * ostrides.h() +
+ w * ostrides.w()] = idata[n * istrides.n() + c * istrides.c() +
+ h * istrides.h() + w * istrides.w()];
+ }
+ }
+ }
+ }
+}
+
+void reorder_rsck_to_kcrs(TRT_ShapedWeights const& iweights,
+ TRT_ShapedWeights* oweights) {
+ CHECK_EQ(iweights.type_, oweights->type_);
+ CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
+ int r = iweights.shape_.d[0];
+ int s = iweights.shape_.d[1];
+ int c = iweights.shape_.d[2];
+ int k = iweights.shape_.d[3];
+ oweights->shape_.d[0] = k;
+ oweights->shape_.d[1] = c;
+ oweights->shape_.d[2] = r;
+ oweights->shape_.d[3] = s;
+ // nvinfer1::DimsNCHW istrides = {1, s, c*r*s, r*s};
+ nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k};
+ nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1};
+ switch (iweights.type_) {
+ case tensorflow::DataType::DT_FLOAT:
+ reorder4(
+ {k, c, r, s}, static_cast<float const*>(iweights.values_), istrides,
+ static_cast<float*>(const_cast<void*>(oweights->values_)), ostrides);
+ break;
+ default:
+ LOG(FATAL) << "!!!!!!!!!!!!!!!!!!!!!!!!broke!!!!!!!!!!!!";
+ }
+}
+
+/* not used. clean up needed.
+nvinfer1::Weights make_dummy_weights(nvinfer1::DataType
+dtype=nvinfer1::DataType::kFLOAT) { nvinfer1::Weights w; w.count = 0; w.values
+= nullptr; w.type = dtype; return w;
+}
+*/
+
+struct InferDeleter {
+ template <typename T>
+ void operator()(T* obj) const {
+ if (obj) {
+ obj->destroy();
+ }
+ }
+};
+
+template <typename T>
+inline std::shared_ptr<T> infer_object(T* obj) {
+ return std::shared_ptr<T>(obj, InferDeleter());
+}
+
+// Logger for GIE info/warning/errors
+class Converter;
+
+using OpConverter =
+ std::function<tensorflow::Status(Converter&, tensorflow::NodeDef const&,
+ std::vector<TRT_TensorOrWeights> const&,
+ std::vector<TRT_TensorOrWeights>*)>;
+
+class Converter {
+ std::unordered_map<std::string, TRT_TensorOrWeights> _trt_tensors;
+ std::unordered_map<std::string, OpConverter> _op_registry;
+ nvinfer1::INetworkDefinition* _trt_network;
+ std::list<std::vector<uint8_t>> _temp_bufs;
+
+ void register_op_converters();
+
+ std::vector<TRT_TensorOrWeights> get_inputs(
+ tensorflow::NodeDef const& node_def) {
+ std::vector<TRT_TensorOrWeights> inputs;
+ for (auto const& input_name : node_def.input()) {
+ LOG(DEBUG) << "retrieve input: " << input_name;
+ inputs.push_back(_trt_tensors.at(input_name));
+ }
+ return inputs;
+ }
+
+ public:
+ explicit Converter(nvinfer1::INetworkDefinition* trt_network)
+ : _trt_network(trt_network) {
+ this->register_op_converters();
+ }
+
+ TRT_ShapedWeights get_temp_weights(tensorflow::DataType type,
+ nvinfer1::Dims shape) {
+ TRT_ShapedWeights weights(type, nullptr, shape);
+ _temp_bufs.push_back(std::vector<uint8_t>(weights.size_bytes()));
+ weights.values_ = _temp_bufs.back().data();
+ return weights;
+ }
+
+ TRT_ShapedWeights get_temp_weights_like(TRT_ShapedWeights const& weights) {
+ return this->get_temp_weights(weights.type_, weights.shape_);
+ }
+
+ tensorflow::Status convert_node(tensorflow::NodeDef const& node_def) {
+ std::vector<TRT_TensorOrWeights> inputs = this->get_inputs(node_def);
+ std::string op = node_def.op();
+ if (!_op_registry.count(op)) {
+ return tensorflow::errors::Unimplemented(
+ "no converter registered for op: " + op);
+ }
+ OpConverter op_converter = _op_registry.at(op);
+ std::vector<TRT_TensorOrWeights> outputs;
+ TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs));
+ for (size_t i = 0; i < outputs.size(); ++i) {
+ TRT_TensorOrWeights output = outputs.at(i);
+ // TODO(jie): tf protobuf seems to be omitting the :0 suffix
+ std::string output_name = node_def.name();
+ if (i != 0) output_name = output_name + ":" + std::to_string(i);
+ if (output.is_tensor()) {
+ output.tensor()->setName(output_name.c_str());
+ }
+ LOG(DEBUG) << "write out tensor: " << output_name;
+ if (!_trt_tensors.insert({output_name, output}).second) {
+ return tensorflow::errors::AlreadyExists(
+ "output tensor already exists for op: " + op);
+ }
+ }
+ return tensorflow::Status::OK();
+ }
+
+ nvinfer1::INetworkDefinition* network() { return _trt_network; }
+
+ TRT_TensorOrWeights get_tensor(std::string name) {
+ if (!_trt_tensors.count(name)) {
+ return TRT_TensorOrWeights(nullptr);
+ }
+ return _trt_tensors.at(name);
+ }
+
+ bool insert_input_tensor(std::string name, nvinfer1::ITensor* tensor) {
+ return _trt_tensors.insert({name, TRT_TensorOrWeights(tensor)}).second;
+ }
+
+ nvinfer1::ITensor* transposeTensor(nvinfer1::ITensor* input_tensor,
+ std::vector<int> order) {
+ auto dims = input_tensor->getDimensions();
+
+ // TODO(jie): change the return to status and properly exit
+ if (order.size() - 1 != size_t(dims.nbDims))
+ LOG(ERROR) << "dimension does not match, fail gracefully";
+
+ nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor);
+ nvinfer1::Permutation permutation;
+ for (int32_t i = 0; i < dims.nbDims; ++i) {
+ permutation.order[i] = order[i + 1] - 1;
+ }
+ layer->setFirstTranspose(permutation);
+
+ nvinfer1::Dims reshapeDims;
+ reshapeDims.nbDims = dims.nbDims;
+ for (int32_t i = 0; i < reshapeDims.nbDims; ++i) {
+ reshapeDims.d[i] = 0;
+ reshapeDims.type[i] = dims.type[i];
+ }
+ layer->setReshapeDimensions(reshapeDims);
+ return layer->getOutput(0);
+ }
+};
+
+/*******************************************************************************
+ Constant folding functions
+ TODO(jie): once optimizer kicks in, we should have done constant folding
+there.
+*******************************************************************************/
+struct LambdaFactory {
+ enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB };
+ OP_CATEGORY op;
+
+ template <typename T>
+ std::function<T(T)> unary() {
+ switch (op) {
+ case OP_CATEGORY::RSQRT: {
+ LOG(DEBUG) << "RSQRT GETS DONE";
+ return [](T t) -> T { return 1.0 / std::sqrt(t); };
+ }
+ case OP_CATEGORY::NEG:
+ return [](T t) -> T { return -t; };
+ default:
+ LOG(DEBUG) << "not supported op for unary: " << static_cast<int>(op);
+ return nullptr;
+ }
+ }
+
+ template <typename T>
+ std::function<T(T, T)> binary() {
+ switch (op) {
+ case OP_CATEGORY::ADD:
+ return [](T l, T r) -> T { return l + r; };
+ case OP_CATEGORY::SUB:
+ return [](T l, T r) -> T { return l - r; };
+ case OP_CATEGORY::MUL:
+ return [](T l, T r) -> T { return l * r; };
+ default:
+ LOG(WARNING) << "not supported op for binary: " << static_cast<int>(op);
+ }
+ return [](T l, T r) -> T {
+ LOG(FATAL) << "Unsupported op type ";
+ return l;
+ };
+ }
+
+ template <typename T>
+ std::function<T(T)> broadcast_r(T val) {
+ LOG(DEBUG) << "LAMBDA VAL : " << val;
+ switch (op) {
+ case OP_CATEGORY::ADD:
+ return [val](T l) -> T {
+ LOG(DEBUG) << "LAMBDA VAL : " << val;
+ return l + val;
+ };
+ // return [val](T l)-> T {return l+val;};
+ case OP_CATEGORY::SUB:
+ return [val](T l) -> T {
+ LOG(DEBUG) << "LAMBDA VAL : " << val;
+ return l - val;
+ };
+ case OP_CATEGORY::MUL:
+ return [val](T l) -> T {
+ LOG(DEBUG) << "LAMBDA VAL : " << val;
+ return l * val;
+ };
+ default:
+ LOG(WARNING) << "not supported op for binary: " << static_cast<int>(op);
+ }
+ return [val](T l) -> T {
+ LOG(FATAL) << "Unsupported op type ";
+ return l;
+ };
+ }
+
+ template <typename T>
+ std::function<T(T)> broadcast_l(T val) {
+ LOG(DEBUG) << "LAMBDA VAL : " << val;
+ switch (op) {
+ case OP_CATEGORY::ADD:
+ return [val](T l) -> T {
+ LOG(DEBUG) << "LAMBDA VAL : " << val;
+ return val + l;
+ };
+ case OP_CATEGORY::SUB:
+ return [val](T l) -> T {
+ LOG(DEBUG) << "LAMBDA VAL : " << val;
+ return val - l;
+ };
+ case OP_CATEGORY::MUL:
+ return [val](T l) -> T {
+ LOG(DEBUG) << "LAMBDA VAL : " << val;
+ return val * l;
+ };
+ default:
+ LOG(ERROR) << "not supported op for binary: " << static_cast<int>(op);
+ }
+ return [val](T l) -> T {
+ LOG(FATAL) << "Unsupported op type ";
+ return l;
+ };
+ }
+};
+
+tensorflow::Status UnaryCompute(TRT_ShapedWeights const& iweights,
+ TRT_ShapedWeights* oweights,
+ LambdaFactory unary_op) {
+ // assume iweights.type == oweights.type
+ CHECK_EQ(iweights.type_, oweights->type_);
+
+ switch (iweights.type_) {
+ case tensorflow::DataType::DT_FLOAT: {
+ auto inp = static_cast<float const*>(iweights.values_);
+ auto oup = static_cast<float*>(const_cast<void*>(oweights->values_));
+ std::transform(inp, inp + iweights.count(), oup, unary_op.unary<float>());
+ break;
+ }
+ default:
+ return tensorflow::errors::Unimplemented("data type not supported: " +
+ iweights.type_);
+ }
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status BinaryCompute(TRT_ShapedWeights const& iweights_l,
+ TRT_ShapedWeights const& iweights_r,
+ TRT_ShapedWeights* oweights,
+ LambdaFactory binary_op) {
+ // assume iweights_l.type == iweight_r.type
+ CHECK_EQ(iweights_l.type_, oweights->type_);
+ CHECK_EQ(iweights_r.type_, oweights->type_);
+ LOG(DEBUG) << "SANITY CHECK!";
+
+ switch (iweights_l.type_) {
+ case tensorflow::DataType::DT_FLOAT: {
+ auto inp_l = static_cast<float const*>(iweights_l.values_);
+ auto inp_r = static_cast<float const*>(iweights_r.values_);
+ auto oup = static_cast<float*>(const_cast<void*>(oweights->values_));
+
+ if (iweights_l.count() != iweights_r.count()) {
+      // we only support broadcasting when one operand is rank zero (a single value)
+ if (iweights_l.count() == 1) {
+ LOG(DEBUG) << "I bet it is not working!" << (*inp_l);
+ std::transform(inp_r, inp_r + iweights_r.count(), oup,
+ binary_op.broadcast_l<float>(*inp_l));
+ } else if (iweights_r.count() == 1) {
+ LOG(DEBUG) << "I bet it is not working!" << (*inp_r);
+ std::transform(inp_l, inp_l + iweights_l.count(), oup,
+ binary_op.broadcast_r<float>(*inp_r));
+ } else {
+ return tensorflow::errors::Unimplemented(
+ "Binary op with non-rankZero broadcast not supported");
+ }
+ } else {
+ std::transform(inp_l, inp_l + iweights_l.count(), inp_r, oup,
+ binary_op.binary<float>());
+ }
+ break;
+ }
+ default:
+ return tensorflow::errors::Unimplemented("data type not supported: " +
+ iweights_l.type_);
+ }
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConstantFoldUnary(
+ Converter& ctx, tensorflow::NodeDef const& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ TRT_ShapedWeights weights_input = inputs.at(0).weights();
+
+ // allocate output weights
+ TRT_ShapedWeights weights_output = ctx.get_temp_weights_like(weights_input);
+
+ // FIXME assume type matches input weights
+ // get trt type & shape
+ // maybe this part has to be moved into the block of rsqrt later
+ // check type consistency
+ CHECK_EQ(weights_input.type_,
+ TFAttrs(node_def).get<tensorflow::DataType>("T"));
+
+ // Maybe I should do a switch
+ LambdaFactory unary_op;
+ if (node_def.op() == "Rsqrt") {
+ // compute rsqrt
+ unary_op.op = LambdaFactory::OP_CATEGORY::RSQRT;
+ auto ret = UnaryCompute(weights_input, &weights_output, unary_op);
+ // pass the output
+ if (ret == tensorflow::Status::OK()) {
+ outputs->push_back(TRT_TensorOrWeights(weights_output));
+ }
+ return ret;
+ } else {
+ return tensorflow::errors::Unimplemented("Binary op not supported: " +
+ node_def.op());
+ }
+}
+
+// TODO(jie,ben) broadcast is needed yet not implemented
+// Let's get the simple stuff working first. Maybe we should fall back to TF
+// approach for constant folding
+tensorflow::Status ConstantFoldBinary(
+ Converter& ctx, tensorflow::NodeDef const& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ TRT_ShapedWeights weights_input_l = inputs.at(0).weights();
+ TRT_ShapedWeights weights_input_r = inputs.at(1).weights();
+
+ // check type consistency
+ CHECK_EQ(weights_input_l.type_, weights_input_r.type_);
+
+ if (weights_input_l.shape_.nbDims != weights_input_r.shape_.nbDims)
+ return tensorflow::errors::Unimplemented(
+ "Binary op implicit broadcast not supported: " + node_def.op());
+
+ // TODO(jie): constant fold should really fall back to TF.
+ int nbDims = weights_input_l.shape_.nbDims;
+ nvinfer1::Dims output_shape;
+ output_shape.nbDims = nbDims;
+ LOG(DEBUG) << "nbDims: " << nbDims
+ << "the other: " << weights_input_r.shape_.nbDims;
+ for (int i = 0; i < nbDims; i++) {
+ if (weights_input_l.shape_.d[i] == weights_input_r.shape_.d[i]) {
+ output_shape.d[i] = weights_input_l.shape_.d[i];
+ } else if (weights_input_l.shape_.d[i] == 1 ||
+ weights_input_r.shape_.d[i] == 1) {
+ output_shape.d[i] =
+ std::max(weights_input_l.shape_.d[i], weights_input_r.shape_.d[i]);
+ } else {
+ return tensorflow::errors::Unimplemented(
+ "Binary op with incompatible shape at, " + node_def.op());
+ }
+ LOG(DEBUG) << "left: " << weights_input_l.shape_.d[i]
+ << "right: " << weights_input_r.shape_.d[i]
+ << "output: " << output_shape.d[i];
+ }
+
+ // FIXME assume type matches input weights
+ // get trt type & shape
+ TFAttrs attrs(node_def);
+ // maybe this part has to be moved into the block of rsqrt later
+ tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("T");
+
+ // allocate output weights
+ TRT_ShapedWeights weights_output = ctx.get_temp_weights(dtype, output_shape);
+
+ // Maybe I should do a switch
+ LambdaFactory binary_op;
+ if (node_def.op() == "Sub") {
+ binary_op.op = LambdaFactory::OP_CATEGORY::SUB;
+ } else if (node_def.op() == "Mul") {
+ binary_op.op = LambdaFactory::OP_CATEGORY::MUL;
+ } else if (node_def.op() == "Add") {
+ binary_op.op = LambdaFactory::OP_CATEGORY::ADD;
+ } else {
+ return tensorflow::errors::Unimplemented("Binary op not supported: " +
+ node_def.op());
+ }
+ auto ret = BinaryCompute(weights_input_l, weights_input_r, &weights_output,
+ binary_op);
+
+ // pass the output
+ if (ret == tensorflow::Status::OK()) {
+ outputs->push_back(TRT_TensorOrWeights(weights_output));
+ }
+
+ return ret;
+}
+
+// TODO(jie): broadcast is needed yet not implemented
+// only implemented channel wise for the time being
+tensorflow::Status BinaryTensorOpWeight(
+ Converter& ctx, tensorflow::NodeDef const& node_def,
+ const nvinfer1::ITensor* tensor, TRT_ShapedWeights weights,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ // FIXME assume type matches input weights
+ // get trt type & shape
+ // maybe this part has to be moved into the block of rsqrt later
+
+ // check type consistency
+ auto dtype = TFAttrs(node_def).get<nvinfer1::DataType>("T");
+ CHECK_EQ_TYPE(tensor->getType(), dtype); // cast to int for error messages
+ nvinfer1::DataType ttype;
+ TF_CHECK_OK(convert_dtype(weights.type_, &ttype));
+ CHECK_EQ_TYPE(ttype, dtype); // cast to int for error message
+
+ // check scale mode
+ auto dims_w = weights.shape_;
+ auto dims_t = tensor->getDimensions();
+
+ // default to channel-wise
+ auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
+
+ /*
+ if (weights.count() == 1) {
+ LOG(DEBUG) << "UNIFORM";
+ scale_mode = nvinfer1::ScaleMode::kUNIFORM;
+ } else if (dims_w.nbDims == 1) {
+  // TODO(jie): should we check for implicit channel wise binary op
+ // where weights has shape 1x1xC?
+ LOG(DEBUG) << "CHANNEL";
+ scale_mode = nvinfer1::ScaleMode::kCHANNEL;
+ } else {
+ // TODO(jie): check weight shape.
+ // broadcast is not fully supported
+ LOG(DEBUG) << "ELEMENTWISE";
+ scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
+ } */
+
+ if (weights.count() == 1) {
+ LOG(DEBUG) << "UNIFORM";
+ scale_mode = nvinfer1::ScaleMode::kUNIFORM;
+ } else {
+    // no broadcasting on the batch dimension
+    assert(dims_w.d[0] == 1);
+
+    // broadcasting on the channel dimension is only allowed in kUNIFORM
+    assert(dims_w.d[1] == dims_t.d[0]);
+    assert(dims_w.nbDims == dims_t.nbDims);
+
+    // default to element-wise; fall back to channel-wise if any later
+    // weight dimension differs from the corresponding tensor dimension
+    for (int i = 2; i < dims_w.nbDims; i++) {
+      if (dims_w.d[i] != dims_t.d[i - 1]) {
+        scale_mode = nvinfer1::ScaleMode::kCHANNEL;
+        break;
+      }
+    }
+    if (scale_mode == nvinfer1::ScaleMode::kELEMENTWISE) {
+      for (int i = 2; i < dims_w.nbDims; i++) {
+        if (dims_w.d[i] != 1)
+          return tensorflow::errors::InvalidArgument(
+              "Weight shape not compatible, at " + node_def.name());
+      }
+    }
+ }
+
+ // transpose last dimension
+ /*
+ std::vector<int> permutation(dims_t.nbDims + 1);
+ if (scale_mode == nvinfer1::ScaleMode::kCHANNEL && dims_t.nbDims > 1) {
+ // we swap the last dimension into channel for trt.
+ // because of tensorflow default broadcasting rules.
+ for (int i = 0; i < static_cast<int>(permutation.size()); i++) {
+ permutation[i] = i;
+ }
+ permutation[1] = dims_t.nbDims;
+ permutation[dims_t.nbDims] = 1;
+ tensor = ctx.transposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+ permutation);
+ }
+ */
+
+ // prepare weights
+ TRT_ShapedWeights shiftWeights(weights.type_);
+ TRT_ShapedWeights scaleWeights(weights.type_);
+ TRT_ShapedWeights powerWeights(weights.type_);
+
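+  // TensorRT's IScaleLayer computes (input * scale + shift) ^ power
+  // element-wise, so Add maps to the shift weights, Mul to the scale weights,
+  // and Sub to the shift weights with the operand negated below.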
+ // Maybe I should do a switch
+ if (node_def.op() == "Sub") {
+ TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights);
+ LambdaFactory unary_op;
+ unary_op.op = LambdaFactory::OP_CATEGORY::NEG;
+ UnaryCompute(weights, &neg_weights, unary_op);
+ shiftWeights = neg_weights;
+ } else if (node_def.op() == "Mul") {
+ scaleWeights = weights;
+ } else if (node_def.op() == "Add") {
+ shiftWeights = weights;
+ } else {
+ return tensorflow::errors::Unimplemented("Binary op not supported: " +
+ node_def.op());
+ }
+
+ nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
+ *const_cast<nvinfer1::ITensor*>(tensor), scale_mode, shiftWeights,
+ scaleWeights, powerWeights);
+
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+ // transpose back dimension
+ /*
+ if (scale_mode == nvinfer1::ScaleMode::kCHANNEL && dims_t.nbDims > 1) {
+ output_tensor = ctx.transposeTensor(output_tensor, permutation);
+ }
+ */
+
+ // pass the output
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status BinaryTensorOpTensor(
+ Converter& ctx, tensorflow::NodeDef const& node_def,
+ const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ static const std::unordered_map<std::string, nvinfer1::ElementWiseOperation>
+ ops{
+ {"Add", nvinfer1::ElementWiseOperation::kSUM},
+ {"Mul", nvinfer1::ElementWiseOperation::kPROD},
+ // {"max", nvinfer1::ElementWiseOperation::kMAX},
+ // {"min", nvinfer1::ElementWiseOperation::kMIN},
+ {"Sub", nvinfer1::ElementWiseOperation::kSUB},
+ {"Div", nvinfer1::ElementWiseOperation::kDIV},
+ };
+
+ // FIXME assume type matches input weights
+ // get trt type & shape
+ TFAttrs attrs(node_def);
+ // maybe this part has to be moved into the block of rsqrt later
+ nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");
+
+ // check type consistency
+ CHECK_EQ_TYPE(tensor_l->getType(), dtype);
+ CHECK_EQ_TYPE(tensor_r->getType(), dtype);
+ auto op_pair = ops.find(node_def.op());
+ if (op_pair == ops.end())
+ return tensorflow::errors::Unimplemented(
+ "binary op: " + node_def.op() +
+ " not supported at: " + node_def.name());
+
+ nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise(
+ *const_cast<nvinfer1::ITensor*>(tensor_l),
+ *const_cast<nvinfer1::ITensor*>(tensor_r), op_pair->second);
+
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+
+ // pass the output
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertPlaceholder(
+ Converter& ctx, tensorflow::NodeDef const& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ LOG(DEBUG) << "Placeholder should have been replace already";
+ return tensorflow::errors::Unimplemented("cannot convert Placeholder op");
+ // OK this make sense since we are supposed to replace it with input
+ TFAttrs attrs(node_def);
+ nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("dtype");
+ nvinfer1::Dims dims = attrs.get<nvinfer1::Dims>("shape");
+
+ dims.nbDims--;
+ for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1];
+
+ nvinfer1::ITensor* output =
+ ctx.network()->addInput(node_def.name().c_str(), dtype, dims);
+ if (!output) {
+ return tensorflow::errors::InvalidArgument("Failed to create Input layer");
+ }
+ outputs->push_back(TRT_TensorOrWeights(output));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertConv2D(Converter& ctx,
+ tensorflow::NodeDef const& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+ // nvinfer1::ITensor* tensor = inputs.at(0).tensor();
+ // TODO(jie): handle NHWC/NCHW transpose;
+ TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
+ TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck);
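+  // TensorFlow stores Conv2D filters as RSCK ([height, width, in_channels,
+  // out_channels]); TensorRT expects KCRS ([out_channels, in_channels,
+  // height, width]), hence the reorder below.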
+ reorder_rsck_to_kcrs(weights_rsck, &weights);
+ TRT_ShapedWeights biases(weights.type_);
+ int noutput = weights.shape_.d[0];
+ nvinfer1::DimsHW kernel_size;
+ kernel_size.h() = weights.shape_.d[2];
+ kernel_size.w() = weights.shape_.d[3];
+ TFAttrs attrs(node_def);
+
+ int h_index = 2;
+ int w_index = 3;
+ auto data_format = attrs.get<std::string>("data_format");
+ if (data_format == "NHWC") {
+ tensor = ctx.transposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+ {0, 3, 1, 2});
+ h_index = 1;
+ w_index = 2;
+ // TODO(jie): transpose it
+ } else {
+ LOG(DEBUG) << "NCHW !!!!";
+ }
+ // TODO(jie): stride. (NHWC/NCHW)
+ auto tf_stride = attrs.get<std::vector<int>>("strides");
+ nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
+
+ auto tensor_dim = tensor->getDimensions();
+ std::vector<std::pair<int, int>> padding;
+ // TODO(jie): padding.
+ if (attrs.get<std::string>("padding") == "SAME") {
+ // This is NCHW tensor with no batch dimension.
+ // 1 -> h
+ // 2 -> w
+ padding = createSamePadding(stride, kernel_size,
+ {static_cast<int>(tensor_dim.d[h_index]),
+ static_cast<int>(tensor_dim.d[w_index])});
+ } else {
+ // return tensorflow::errors::Unimplemented(
+ // "Current Conv2D cannot support padding other than SAME");
+ padding = {{0, 0}, {0, 0}};
+ }
+
+ if (padding[0].first != padding[0].second ||
+ padding[1].first != padding[1].second) {
+ // TODO(jie): handle asymmetric padding
+ // return tensorflow::errors::Unimplemented(
+ // "Asymmetric padding not implemented yet");
+ auto padLayer = ctx.network()->addPadding(
+ *const_cast<nvinfer1::ITensor*>(tensor),
+ nvinfer1::DimsHW(padding[1].first, padding[0].first),
+ nvinfer1::DimsHW(padding[1].second, padding[0].second));
+ tensor = padLayer->getOutput(0);
+ }
+
+ nvinfer1::IConvolutionLayer* layer =
+ ctx.network()->addConvolution(*const_cast<nvinfer1::ITensor*>(tensor),
+ noutput, kernel_size, weights, biases);
+
+ layer->setStride(stride);
+ layer->setPadding({padding[0].first, padding[1].first});
+ layer->setName(node_def.name().c_str());
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+
+ if (data_format == "NHWC") {
+ // TODO(jie): transpose it back!
+ output_tensor = ctx.transposeTensor(output_tensor, {0, 2, 3, 1});
+ } else {
+ LOG(DEBUG) << "NCHW !!!!";
+ }
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertPool(Converter& ctx,
+ tensorflow::NodeDef const& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+ TFAttrs attrs(node_def);
+
+ int h_index = 2;
+ int w_index = 3;
+ auto data_format = attrs.get<std::string>("data_format");
+ if (data_format == "NHWC") {
+ h_index = 1;
+ w_index = 2;
+ tensor = ctx.transposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+ {0, 3, 1, 2});
+ } else {
+ LOG(DEBUG) << "NCHW !!!!";
+ }
+ nvinfer1::PoolingType type;
+ // TODO(jie): support other pooling type
+ if (node_def.op() == "MaxPool")
+ type = nvinfer1::PoolingType::kMAX;
+ else
+ return tensorflow::errors::Unimplemented("only supports Max pool");
+
+ // TODO(jie): NCHW
+ auto tf_stride = attrs.get<std::vector<int>>("strides");
+ nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
+
+ auto tf_kernel = attrs.get<std::vector<int>>("ksize");
+ nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]);
+
+ auto tensor_dim = tensor->getDimensions();
+ std::vector<std::pair<int, int>> padding;
+ // TODO(jie): padding.
+ if (attrs.get<std::string>("padding") == "SAME") {
+ // This is NCHW tensor with no batch dimension.
+ // 1 -> h
+ // 2 -> w
+ padding = createSamePadding(
+ stride, ksize,
+ {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
+ } else if (attrs.get<std::string>("padding") == "VALID") {
+ // No padding for valid padding here
+ LOG(DEBUG) << "no padding added for VALID padding in pool"
+ << node_def.name();
+ padding = {{0, 0}, {0, 0}};
+ } else {
+ return tensorflow::errors::Unimplemented(
+ "Current MaxPool cannot support padding other than SAME");
+ }
+
+ if (padding[0].first != padding[0].second ||
+ padding[1].first != padding[1].second) {
+ // TODO(jie): handle asymmetric padding
+ // return tensorflow::errors::Unimplemented(
+ // "Asymmetric padding not implemented yet");
+ auto padLayer = ctx.network()->addPadding(
+ *const_cast<nvinfer1::ITensor*>(tensor),
+ nvinfer1::DimsHW(padding[1].first, padding[0].first),
+ nvinfer1::DimsHW(padding[1].second, padding[0].second));
+ tensor = padLayer->getOutput(0);
+ }
+
+ nvinfer1::IPoolingLayer* layer = ctx.network()->addPooling(
+ *const_cast<nvinfer1::ITensor*>(tensor), type, ksize);
+
+ layer->setStride(stride);
+ layer->setPadding({padding[0].first, padding[1].first});
+ layer->setName(node_def.name().c_str());
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+
+ if (data_format == "NHWC") {
+ // TODO(jie): transpose it back!
+ output_tensor = ctx.transposeTensor(output_tensor, {0, 2, 3, 1});
+ } else {
+ LOG(DEBUG) << "NCHW !!!!";
+ }
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertActivation(
+ Converter& ctx, tensorflow::NodeDef const& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+ nvinfer1::IActivationLayer* layer = ctx.network()->addActivation(
+ *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ActivationType::kRELU);
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertScale(Converter& ctx,
+ tensorflow::NodeDef const& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
+ !inputs.at(1).is_weights())
+ return tensorflow::errors::Unimplemented(
+ "only supports tensor op weight for now, at " + node_def.name());
+ // implement tensor binaryOp weight [channel wise] for now;
+ nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+ // nvinfer1::ITensor* tensor = inputs.at(0).tensor();
+
+ // TODO(jie): handle NHWC/NCHW transpose;
+ TRT_ShapedWeights weights = inputs.at(1).weights();
+ // nvinfer1::Weights empty_weights{weights.type, nullptr, 0};
+ TRT_ShapedWeights empty_weights(weights.type_);
+
+ TFAttrs attrs(node_def);
+
+ // transpose NHWC
+ auto data_format = attrs.get<std::string>("data_format");
+ if (data_format == "NHWC") {
+ tensor = ctx.transposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+ {0, 3, 1, 2});
+ // TODO(jie): transpose it
+ } else {
+ LOG(DEBUG) << "NCHW !!!!";
+ }
+ nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
+ *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ScaleMode::kCHANNEL,
+ weights, empty_weights, empty_weights);
+
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+ if (data_format == "NHWC") {
+ // TODO(jie): transpose it back!
+ output_tensor = ctx.transposeTensor(output_tensor, {0, 2, 3, 1});
+ } else {
+ LOG(DEBUG) << "NCHW !!!!";
+ }
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertConst(Converter& ctx,
+ tensorflow::NodeDef const& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ auto const& weights_tensor = node_def.attr().at("value").tensor();
+
+ // get trt type & shape
+ TFAttrs attrs(node_def);
+ // nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("dtype");
+ tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("dtype");
+
+ // create shaped weights as output
+ tensorflow::Tensor tensor;
+ if (!tensor.FromProto(weights_tensor))
+ return tensorflow::errors::Internal("cannot parse weight tensor proto: " +
+ node_def.name());
+
+ TRT_ShapedWeights weights(dtype);
+ if (!weights_tensor.float_val().empty()) {
+ LOG(DEBUG) << "SCALAR!!!" << node_def.name();
+ nvinfer1::Dims scalar_shape;
+ if (tensor.dims() > 0) {
+ LOG(DEBUG) << "dimensions: " << tensor.dims();
+ weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
+ get_tensor_shape(tensor));
+ } else {
+ LOG(DEBUG) << "dimensions: " << tensor.dims();
+ scalar_shape.nbDims = 1;
+ scalar_shape.d[0] = 1;
+ scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
+ for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) {
+ scalar_shape.d[i] = 0;
+ scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL;
+ }
+ weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
+ scalar_shape);
+ }
+ // LOG(INFO) << " add: " << weights_tensor.float_val().data();
+ // LOG(INFO) << " value: " << (*weights_tensor.float_val().data());
+
+ // weights = ctx.get_temp_weights(dtype, scalar_shape);
+ // std::memcpy(const_cast<void*>(weights.values),
+ // weights_tensor.float_val().data(), weights.size_bytes());
+ } else if (!weights_tensor.tensor_content().empty()) {
+ LOG(DEBUG) << "TENSOR!!!" << node_def.name();
+ weights = TRT_ShapedWeights(dtype, weights_tensor.tensor_content().data(),
+ get_tensor_shape(tensor));
+ } else {
+ return tensorflow::errors::Unimplemented(
+ "not supported constant type, at " + node_def.name());
+ }
+ // pass the output
+ outputs->push_back(TRT_TensorOrWeights(weights));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertIdentity(
+ Converter& ctx, tensorflow::NodeDef const& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ outputs->push_back(inputs.at(0));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertBinary(Converter& ctx,
+ tensorflow::NodeDef const& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ if (inputs.size() != 2)
+ return tensorflow::errors::FailedPrecondition(
+ "Binary ops require two tensor input, at " + node_def.name());
+
+ if (inputs.at(0).is_weights() && inputs.at(1).is_weights())
+ return ConstantFoldBinary(ctx, node_def, inputs, outputs);
+
+ if (inputs.at(0).is_tensor() && inputs.at(1).is_weights())
+ return BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(),
+ inputs.at(1).weights(), outputs);
+
+ if (inputs.at(0).is_weights() && inputs.at(1).is_tensor())
+ return BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(),
+ inputs.at(0).weights(), outputs);
+
+ if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor())
+ return BinaryTensorOpTensor(ctx, node_def, inputs.at(0).tensor(),
+ inputs.at(1).tensor(), outputs);
+
+ return tensorflow::errors::Unknown("Binary op input error, at " +
+ node_def.name());
+}
+
+tensorflow::Status ConvertUnary(Converter& ctx,
+ tensorflow::NodeDef const& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ if (inputs.size() != 1)
+ return tensorflow::errors::FailedPrecondition(
+ "Unary ops require single tensor input, at " + node_def.name());
+
+ if (inputs.at(0).is_weights())
+ return ConstantFoldUnary(ctx, node_def, inputs, outputs);
+ else if (inputs.at(0).is_tensor())
+ return tensorflow::errors::Unimplemented(
+ "Unary op for tensor not supported, at " + node_def.name());
+
+ return tensorflow::errors::Unknown("Binary op input error, at " +
+ node_def.name());
+}
+
+tensorflow::Status ConvertReduce(Converter& ctx,
+ tensorflow::NodeDef const& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
+ !inputs.at(1).is_weights())
+ return tensorflow::errors::InvalidArgument(
+ "Input expects tensor and weights, at" + node_def.name());
+
+ // implement tensor binaryOp weight [channel wise] for now;
+ nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+ auto dims = tensor->getDimensions();
+ // restore implicit batch dimension
+ int nbDims = dims.nbDims + 1;
+
+ TRT_ShapedWeights index_list = inputs.at(1).weights();
+
+ TFAttrs attrs(node_def);
+ // TODO(jie): handle data type
+ // auto data_type = attrs.get<nvinfer1::DataType>("T");
+ // index type here is done through TF type
+ // so I can leverage their EnumToDataType for my cast
+ auto index_type = attrs.get<tensorflow::DataType>("Tidx");
+ // auto keep_dims_flag = attrs.get<bool>("keep_dims");
+
+ // Only expect to handle INT32 as attributes for now
+ if (index_type != tensorflow::DataType::DT_INT32)
+ return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32");
+ // auto pad_data = const_cast<tensorflow::EnumToDataType<padding_type>::Type*>
+ // (pads.values);
+ auto index_list_data =
+ static_cast<int*>(const_cast<void*>(index_list.values_));
+ // auto index_list_data =
+ // const_cast<tensorflow::EnumToDataType<index_type>::Type*>
+ // (index_list.values);
+
+ // hack warning:
+ // have to fall back to pool layer since reduce is not in public TRT yet.
+  if (nbDims != 4)
+    return tensorflow::errors::InvalidArgument(
+        "TRT only supports reduce on 4-dimensional tensors, at " +
+        node_def.name());
+  if (index_list.count() > 2)
+    return tensorflow::errors::InvalidArgument(
+        "TRT cannot support reduce on more than 2 dimensions, at " +
+        node_def.name());
+
+ std::set<int> idx_set;
+ // we cannot operate on Channel. permutation flag used to transpose tensor
+ int permuted_index = -1;
+ for (int i = 0; i < index_list.count(); i++) {
+ if (index_list_data[i] == 0)
+ return tensorflow::errors::InvalidArgument("TRT cannot reduce at 0, at" +
+ node_def.name());
+ if (index_list_data[i] == 1) permuted_index = 1;
+ idx_set.emplace(index_list_data[i]);
+ }
+
+ std::vector<int> permutation_order(nbDims);
+ nvinfer1::DimsHW pool_kernel;
+ if (permuted_index == 1) {
+ for (int i = 2; i < nbDims; i++) {
+ if (idx_set.count(i)) {
+ permuted_index = i;
+ break;
+ }
+ }
+ for (int i = 0; i < nbDims; i++) permutation_order[i] = i;
+
+ permutation_order[permuted_index] = 1;
+ permutation_order[1] = permuted_index;
+
+ // apply permutation before extracting dimension for pool_kernel
+ tensor = ctx.transposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+ permutation_order);
+ }
+
+ // apply permutation before extracting dimension for pool_kernel
+ pool_kernel.d[0] = (idx_set.count(2) || permuted_index == 2) ? dims.d[1] : 1;
+ pool_kernel.d[1] = (idx_set.count(3) || permuted_index == 3) ? dims.d[2] : 1;
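+  // Mean over the selected spatial dimensions is emulated with an average
+  // pooling layer whose kernel spans the full extent of those dimensions.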
+
+ nvinfer1::ITensor* output_tensor;
+
+ if (node_def.op() == "Mean") {
+ nvinfer1::IPoolingLayer* layer =
+ ctx.network()->addPooling(*const_cast<nvinfer1::ITensor*>(tensor),
+ nvinfer1::PoolingType::kAVERAGE, pool_kernel);
+ output_tensor = layer->getOutput(0);
+ } else {
+ return tensorflow::errors::Unimplemented(
+ "Op not supported " + node_def.op() + " , at " + node_def.name());
+ }
+  if (permuted_index != -1) {
+    // transpose the result back to the original layout after pooling
+    output_tensor = ctx.transposeTensor(
+        const_cast<nvinfer1::ITensor*>(output_tensor), permutation_order);
+  }
+  outputs->push_back(TRT_TensorOrWeights(output_tensor));
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertPad(Converter& ctx,
+ tensorflow::NodeDef const& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
+ !inputs.at(1).is_weights())
+ return tensorflow::errors::InvalidArgument(
+ "Input expects tensor and weights, at" + node_def.name());
+
+ // implement tensor binaryOp weight [channel wise] for now;
+ nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+ auto dims = tensor->getDimensions();
+ // restore implicit batch dimension
+ int nbDims = dims.nbDims + 1;
+
+ TRT_ShapedWeights pads = inputs.at(1).weights();
+
+ TFAttrs attrs(node_def);
+ // padding type here is done through TF type
+ // so I can leverage their EnumToDataType for my cast
+ auto padding_type = attrs.get<tensorflow::DataType>("Tpaddings");
+ // TODO(jie): handle data type conversion for TRT?
+ // auto data_type = attrs.get<nvinfer1::DataType>("T");
+
+ if (pads.shape_.d[0] != nbDims || pads.shape_.d[1] != 2)
+ return tensorflow::errors::InvalidArgument(
+ "Pad only supports explicit padding on 4 dimensional tensor, at " +
+ node_def.name());
+
+ // Only expect to handle INT32 as attributes for now
+ if (padding_type != tensorflow::DataType::DT_INT32)
+ return tensorflow::errors::Unimplemented(
+ "Tpaddings supports only DT_INT32");
+ // auto pad_data = const_cast<tensorflow::EnumToDataType<padding_type>::Type*>
+ // (pads.values);
+ auto pad_data = static_cast<int*>(const_cast<void*>(pads.values_));
+
+ std::vector<int32_t> pad_index;
+ for (int i = 0; i < nbDims; i++) {
+ if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0)
+ pad_index.push_back(i);
+ }
+
+ // no padding at all, we should exit
+ if (pad_index.size() == 0) {
+ outputs->push_back(inputs.at(0));
+ return tensorflow::Status::OK();
+ }
+
+ // only supports padding on less than 2 axis GIE-2579
+ if (pad_index.size() > 2)
+ return tensorflow::errors::InvalidArgument(
+ "Padding layer does not support padding on > 2");
+
+ // padding on batch dimension is not supported
+ if (pad_index[0] == 0)
+ return tensorflow::errors::InvalidArgument(
+ "Padding layer does not support padding on batch dimension");
+
+  // not doing the legit thing here: padding on both dimension 1 and
+  // dimension 3 is rejected instead of being handled
+  // TODO(jie): implement pad as in the UFF parser
+  if (pad_index.size() == 2 && pad_index[0] == 1 && pad_index[1] == 3)
+    return tensorflow::errors::Unimplemented(
+        "Padding layer does not support padding on dimensions 1 and 3 yet");
+
+ bool legit_pad = true;
+ nvinfer1::DimsHW pre_padding(0, 0);
+ nvinfer1::DimsHW post_padding(0, 0);
+
+ std::vector<int32_t> permuted_pad_index(pad_index);
+ if (pad_index[0] == 1) {
+ legit_pad = false;
+ tensor = ctx.transposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+ {0, 3, 2, 1});
+ permuted_pad_index[0] = 3;
+ }
+
+ for (size_t i = 0; i < pad_index.size(); i++) {
+ int index = pad_index[i];
+ if (permuted_pad_index[i] == 2) {
+ pre_padding.h() = pad_data[index * 2];
+ post_padding.h() = pad_data[index * 2 + 1];
+ } else if (permuted_pad_index[i] == 3) {
+ pre_padding.w() = pad_data[index * 2];
+ post_padding.w() = pad_data[index * 2 + 1];
+ }
+ }
+
+ nvinfer1::IPaddingLayer* layer = ctx.network()->addPadding(
+ *const_cast<nvinfer1::ITensor*>(tensor), pre_padding, post_padding);
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+
+ if (!legit_pad)
+ output_tensor = ctx.transposeTensor(
+ const_cast<nvinfer1::ITensor*>(output_tensor), {0, 3, 2, 1});
+
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
+void Converter::register_op_converters() {
+ // vgg_16 slim implementation
+ _op_registry["Placeholder"] = ConvertPlaceholder;
+ _op_registry["Conv2D"] = ConvertConv2D;
+ _op_registry["Relu"] = ConvertActivation;
+ _op_registry["MaxPool"] = ConvertPool;
+  // This could really be handled by ConvertBinary
+ _op_registry["BiasAdd"] = ConvertScale;
+ _op_registry["Const"] = ConvertConst;
+ // _op_registry["MatMul"] = ConvertFullyConnected; // not used in vgg
+ // TODO(ben,jie): this is a temp hack.
+ _op_registry["Identity"] = ConvertIdentity; // Identity should be removed
+ // _op_registry["AvgPool"] = ConvertPool;
+
+ // resnet_50_v1 slim implementation
+ _op_registry["Add"] = ConvertBinary;
+ _op_registry["Mul"] = ConvertBinary;
+ _op_registry["Sub"] = ConvertBinary;
+ _op_registry["Rsqrt"] = ConvertUnary;
+ _op_registry["Mean"] = ConvertReduce;
+ _op_registry["Pad"] = ConvertPad;
+ // TODO(ben,jie): Add more ops
+}
+
+} // namespace
+
+tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
+ const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
+ const std::vector<std::pair<int, int>>& input_inds,
+ const std::vector<std::pair<int, int>>& output_inds, size_t max_batch_size,
+ size_t max_workspace_size, const ShapeMap& shape_map,
+ tensorflow::NodeDef* trt_node) {
+ // Visit nodes in reverse topological order and construct the TRT network.
+
+ // Toposort
+ std::vector<tensorflow::Node*> order_vec;
+ tensorflow::GetPostOrder(graph, &order_vec);
+ // Select just the subgraph
+ std::list<tensorflow::Node*> order;
+ for (tensorflow::Node* node : order_vec) {
+ if (subgraph_node_ids.count(node->id())) {
+ // order.push_back(node);
+      order.push_front(node); // we want topological order to construct the
+ // network layer by layer
+ }
+ }
+ // topological order is needed to build TRT network
+ LOG(DEBUG) << "BUILDING 1";
+
+ // nvinfer1::ILogger::Severity verbosity =
+ // nvinfer1::ILogger::Severity::kWARNING;
+ tensorflow::tensorrt::Logger trt_logger;
+ // TRT_Logger trt_logger(verbosity);
+
+ LOG(DEBUG) << "BUILDING 2";
+
+ auto trt_builder = infer_object(nvinfer1::createInferBuilder(trt_logger));
+ if (!trt_builder) {
+ return tensorflow::errors::Internal(
+ "failed to create TensorRT builder object");
+ }
+
+ LOG(DEBUG) << "BUILDING 3";
+
+ auto trt_network = infer_object(trt_builder->createNetwork());
+ if (!trt_network) {
+ return tensorflow::errors::Internal(
+ "failed to create TensorRT network object");
+ }
+
+ LOG(DEBUG) << "BUILDING 4";
+
+ // Build the network
+ Converter converter(trt_network.get());
+
+ LOG(DEBUG) << "BUILDING 5";
+ std::vector<std::string> input_names;
+ std::vector<tensorflow::DataType> input_dtypes;
+ for (std::pair<int, int> const& input : input_inds) {
+ LOG(DEBUG) << "parsing input!!!!!";
+ int node_id = input.first;
+ int output_idx = input.second;
+ tensorflow::Node* node = graph.FindNodeId(node_id);
+ auto node_name = node->name();
+ input_names.push_back(node_name); // insert original node name without port
+ // TODO(jie): alternative :)
+ // tensorflow::DataType tf_dtype = node->output_type(output_idx);
+ if (shape_map.count(node_name) == 0)
+ return tensorflow::errors::Internal("failed to find input node: " +
+ node_name);
+
+ auto input_entry_vec = shape_map.at(node_name);
+    if (static_cast<int>(input_entry_vec.size()) <= output_idx)
+      return tensorflow::errors::Internal(
+          "accessing output index of: " + std::to_string(output_idx) +
+          ", at node: " + node_name + " with output entry from shape_map: " +
+          std::to_string(input_entry_vec.size()));
+
+ auto input_entry = input_entry_vec.at(output_idx);
+
+ tensorflow::DataType tf_dtype = input_entry.second;
+ input_dtypes.push_back(tf_dtype);
+
+ nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
+ TF_CHECK_OK(convert_dtype(tf_dtype, &dtype));
+
+ LOG(DEBUG) << "accessing output index of: " << std::to_string(output_idx)
+ << ", at node: " << node_name
+ << "with output entry from shape_map: "
+ << std::to_string(input_entry_vec.size());
+ // TODO(ben,jie): update TRT input format/dimension
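+    // TensorRT networks use an implicit batch dimension, so dimension 0 of
+    // the TF shape is dropped and the remaining dimensions are mapped into a
+    // pseudo-CHW layout below.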
+ nvinfer1::DimsCHW input_dim_psuedo_chw;
+ for (int i = 0; i < 3; i++) input_dim_psuedo_chw.d[i] = 1;
+
+ for (int i = 1; i < input_entry.first.dims(); i++) {
+ LOG(DEBUG) << "dimension: " << i
+ << " , size: " << input_entry.first.dim_size(i);
+ input_dim_psuedo_chw.d[i - 1] = input_entry.first.dim_size(i);
+ }
+
+ // TODO(ben,jie): proper way to restore input tensor name?
+ auto input_tensor_name = node_name;
+ if (output_idx != 0)
+ input_tensor_name = node_name + ":" + std::to_string(output_idx);
+
+ nvinfer1::ITensor* input_tensor = converter.network()->addInput(
+ input_tensor_name.c_str(), dtype, input_dim_psuedo_chw);
+
+ if (!input_tensor)
+ return tensorflow::errors::InvalidArgument(
+ "Failed to create Input layer");
+ LOG(DEBUG) << "input tensor name :" << input_tensor_name;
+
+ if (!converter.insert_input_tensor(input_tensor_name, input_tensor))
+ return tensorflow::errors::AlreadyExists(
+ "output tensor already exists for op: " + input_tensor_name);
+ }
+
+ LOG(DEBUG) << "finished sorting";
+
+ for (const tensorflow::Node* node : order) {
+ tensorflow::NodeDef const& node_def = node->def();
+ LOG(DEBUG) << "converting node: " << node_def.name() << " , "
+ << node_def.op();
+ TF_RETURN_IF_ERROR(converter.convert_node(node_def));
+ }
+
+ LOG(DEBUG) << "finished conversion";
+
+ // Gather output metadata
+ std::vector<std::string> output_names;
+ std::vector<tensorflow::DataType> output_dtypes;
+ for (std::pair<int, int> const& output : output_inds) {
+ int node_id = output.first;
+ int output_idx = output.second;
+ tensorflow::Node* node = graph.FindNodeId(node_id);
+ std::string op_name = node->name();
+ std::string tensor_name = op_name;
+ if (output_idx != 0)
+ tensor_name = tensor_name + ":" + std::to_string(output_idx);
+ LOG(DEBUG) << "output tensor name: " << tensor_name;
+ output_names.push_back(tensor_name);
+ auto tensor_or_weights = converter.get_tensor(tensor_name);
+ if (!tensor_or_weights.is_tensor()) {
+ return tensorflow::errors::InvalidArgument(
+ "Output node is weights not tensor");
+ }
+ nvinfer1::ITensor* tensor = tensor_or_weights.tensor();
+ if (!tensor) {
+ return tensorflow::errors::NotFound("Output tensor not found: " +
+ tensor_name);
+ }
+ converter.network()->markOutput(*tensor);
+ tensorflow::DataType tf_dtype = node->output_type(output_idx);
+ output_dtypes.push_back(tf_dtype);
+ nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT;
+ TF_RETURN_IF_ERROR(convert_dtype(tf_dtype, &trt_dtype));
+ tensor->setType(trt_dtype);
+ }
+
+ LOG(DEBUG) << "finished output";
+
+ // Build the engine
+ trt_builder->setMaxBatchSize(max_batch_size);
+ trt_builder->setMaxWorkspaceSize(max_workspace_size);
+ LOG(INFO) << "starting build engine";
+ // TODO(ben,jie): half2 and int8 mode support
+ std::string engine_plan_string;
+ {
+ auto trt_engine =
+ infer_object(trt_builder->buildCudaEngine(*converter.network()));
+ LOG(INFO) << "built network";
+ auto engine_plan = infer_object(trt_engine->serialize());
+ LOG(INFO) << "serialized engine";
+ const char* engine_plan_data =
+ static_cast<const char*>(engine_plan->data());
+ engine_plan_string = std::move(
+ std::string(engine_plan_data, engine_plan_data + engine_plan->size()));
+ }
+ // std::ofstream engine_out("mini.engine");
+ // engine_out << engine_plan_string;
+ // engine_out.close();
+
+ LOG(INFO) << "finished engine";
+
+ // Build the TRT op
+ // TODO(sami,ben,jie): proper naming!
+ static int static_id = 0;
+ tensorflow::NodeDefBuilder op_builder(
+ "my_trt_op" + std::to_string(static_id++), "TRTEngineOp");
+ std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges;
+ for (size_t i = 0; i < input_names.size(); ++i) {
+ int output_idx = input_inds.at(i).second;
+ // we wired up the input here already, it is redundant to do it again in
+ // ConvertSubGraphToTensorRT(convert_graph.cc)
+ auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut(input_names.at(i),
+ output_idx, input_dtypes.at(i));
+ income_edges.push_back(incoming_edge);
+ }
+ tensorflow::gtl::ArraySlice<tensorflow::NodeDefBuilder::NodeOut>
+ input_list(income_edges);
+ op_builder.Input(input_list);
+
+ LOG(INFO) << "finished op preparation";
+
+ auto status = op_builder.Attr("serialized_engine", engine_plan_string)
+ .Attr("input_nodes", input_names)
+ .Attr("output_nodes", output_names)
+ .Attr("OutT", output_dtypes)
+ .Finalize(trt_node);
+
+ LOG(INFO) << status.ToString();
+ LOG(INFO) << "finished op building";
+
+ return tensorflow::Status::OK();
+}
+
+} // namespace convert
+} // namespace tensorrt
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
new file mode 100644
index 0000000000..a624582dec
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
+
+#include <set>
+#include <vector>
+#include <utility>
+
+#include "tensorflow/contrib/tensorrt/convert/inferShapes.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorrt {
+namespace convert {
+
+tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
+ const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
+ const std::vector<std::pair<int, int>>&
+ input_inds, // {node_id, output_idx}
+ const std::vector<std::pair<int, int>>&
+ output_inds, // {node_id, output_idx}
+ size_t max_batch_size, size_t max_workspace_size, const ShapeMap& shape_map,
+ tensorflow::NodeDef* trt_node);
+} // namespace convert
+} // namespace tensorrt
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
diff --git a/tensorflow/contrib/tensorrt/convert/inferShapes.cc b/tensorflow/contrib/tensorrt/convert/inferShapes.cc
new file mode 100644
index 0000000000..c7f0f0023d
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/inferShapes.cc
@@ -0,0 +1,125 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/tensorrt/convert/inferShapes.h"
+#include <functional>
+#include "tensorflow/core/common_runtime/shape_refiner.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.pb_text.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+
+#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1)
+
+namespace tensorflow {
+namespace trt {
+std::vector<tensorflow::DataType> getTypes(const tensorflow::OpDef& op,
+ const tensorflow::NodeDef& nd,
+ bool inp = true) {
+ const auto& attrMap = nd.attr();
+ auto getType = [&attrMap](decltype(
+ op.input_arg(0)) a) -> std::vector<tensorflow::DataType> {
+ std::vector<tensorflow::DataType> tvec;
+ if (!a.type_list_attr().empty()) { // get the list types
+ const auto& tl = attrMap.at(a.type_list_attr()).list();
+ int tsize = tl.type_size();
+ tvec.reserve(tsize);
+ for (int t = 0; t < tsize; t++) {
+ tvec.push_back(tl.type(t));
+ }
+ return tvec;
+ }
+ tensorflow::DataType cType = tensorflow::DT_INVALID;
+ if (a.type() != tensorflow::DT_INVALID) { // get defined types
+ cType = a.type();
+ } else if (!a.type_attr().empty()) {
+ cType = attrMap.at(a.type_attr()).type();
+ }
+ if (!a.number_attr().empty()) { // numbertypes
+ int64 nTensors = attrMap.at(a.number_attr()).i();
+ tvec = std::vector<tensorflow::DataType>(nTensors, cType);
+ return tvec;
+ }
+ tvec.push_back(cType);
+ return tvec;
+ };
+ std::vector<tensorflow::DataType> types;
+ if (inp) {
+ int n_inputs = op.input_arg_size();
+ for (int i = 0; i < n_inputs; i++) {
+ auto tout = getType(op.input_arg(i));
+ LOG(DEBUG) << "Node= " << nd.name() << " #inputs" << tout.size();
+ types.insert(types.end(), tout.begin(), tout.end());
+ }
+ } else {
+ int n_outputs = op.output_arg_size();
+ // types.resize(n_outputs);
+ for (int i = 0; i < n_outputs; i++) {
+ auto tout = getType(op.output_arg(i));
+ LOG(DEBUG) << "Node= " << nd.name() << " #outputs" << tout.size();
+ types.insert(types.end(), tout.begin(), tout.end());
+ }
+ }
+ return types;
+}
+
+tensorflow::Status inferShapes(const tensorflow::GraphDef& graph_def,
+ const std::vector<std::string>& output_names,
+ ShapeMap& shapes) {
+ tensorflow::Graph g(OpRegistry::Global());
+ TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
+ tensorflow::GraphConstructorOptions(), graph_def, &g));
+ std::vector<tensorflow::Node*> POnodes;
+ tensorflow::GetPostOrder(g, &POnodes);
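+  // Reverse post-order is a topological order, so iterating POnodes from
+  // rbegin to rend adds every node to the refiner after all of its inputs.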
+ tensorflow::ShapeRefiner refiner(graph_def.versions().producer(),
+ OpRegistry::Global());
+ for (auto n = POnodes.rbegin(); n != POnodes.rend(); ++n) {
+ TF_CHECK_OK(refiner.AddNode(*n));
+ }
+
+ auto shape2PTS = [](tensorflow::shape_inference::InferenceContext* ic,
+ const tensorflow::shape_inference::ShapeHandle& sh)
+ -> tensorflow::PartialTensorShape {
+ std::vector<int64> dims;
+ int64 rank = ic->Rank(sh);
+ for (int64 i = 0; i < rank; i++) {
+ auto dh = ic->Dim(sh, i);
+ dims.push_back(ic->Value(dh));
+ }
+ return tensorflow::PartialTensorShape(dims);
+ };
+ for (const auto& n : POnodes) {
+ auto ic = refiner.GetContext(n);
+ if (ic) {
+ int nOuts = ic->num_outputs();
+ auto types = getTypes(n->op_def(), n->def(), false);
+ std::vector<
+ std::pair<tensorflow::PartialTensorShape, tensorflow::DataType>>
+ SAT;
+ for (int i = 0; i < nOuts; i++) {
+ auto PTS = shape2PTS(ic, ic->output(i));
+ SAT.push_back({PTS, types.at(i)});
+ }
+ shapes[n->name()] = SAT;
+ } else {
+ LOG(WARNING) << "Node " << n->name() << " doesn't have InferenceContext!";
+ }
+ }
+ return tensorflow::Status::OK();
+}
+} // namespace trt
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/convert/inferShapes.h b/tensorflow/contrib/tensorrt/convert/inferShapes.h
new file mode 100644
index 0000000000..b94f1ee893
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/inferShapes.h
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_INFERSHAPES_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_INFERSHAPES_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include <utility>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/status.h"
+
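+// Maps a node name to the inferred shape and data type of each of its
+// outputs.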
+typedef std::unordered_map<std::string,
+ std::vector<std::pair<tensorflow::PartialTensorShape,
+ tensorflow::DataType>>>
+ ShapeMap;
+namespace tensorflow {
+namespace trt {
+tensorflow::Status inferShapes(const tensorflow::GraphDef& graph_def,
+ const std::vector<std::string>& output_names,
+ ShapeMap& shapes);
+}
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_INFERSHAPES_H_
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
new file mode 100644
index 0000000000..a1524a592a
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -0,0 +1,183 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h"
+#include <cuda_runtime_api.h>
+#include <sstream>
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor.h"
+// Use TF logging for TensorRT information
+
+
+namespace tensorflow {
+static ::tensorflow::tensorrt::Logger gLogger;
+
+using namespace nvinfer1;
+
+namespace tensorrt {
+
+TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
+ // char *gieModelStream{nullptr};
+ // size_t size{0};
+
+ // read serialized_engine
+ std::string serialized_engine;
+ OP_REQUIRES_OK(context,
+ context->GetAttr("serialized_engine", &serialized_engine));
+
+ // register input output node name in trt_sub_graph
+ OP_REQUIRES_OK(context, context->GetAttr("input_nodes", &input_nodes_));
+ OP_REQUIRES_OK(context, context->GetAttr("output_nodes", &output_nodes_));
+
+ // TODO(samikama) runtime should be taken from a resourcemanager as well.
+ // Only engine should be in the op and context and runtime should be taken
+ // from resourcemanager
+ IRuntime* infer = createInferRuntime(gLogger);
+ trt_engine_ptr_.reset(infer->deserializeCudaEngine(
+ serialized_engine.c_str(), serialized_engine.size(), nullptr));
+
+ trt_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
+ // runtime is safe to delete after engine creation
+ infer->destroy();
+ std::stringstream oss;
+ // debug iterate through all binding instances
+ for (int i = 0; i < trt_engine_ptr_->getNbBindings(); i++) {
+ LOG(INFO) << "index: " << i
+ << ", binding name: " << trt_engine_ptr_->getBindingName(i);
+
+ if (trt_engine_ptr_->bindingIsInput(i)) {
+ LOG(INFO) << "INPUT";
+ } else {
+ LOG(INFO) << "OUTPUT";
+ }
+ oss << "Dimension: ";
+ auto dims = trt_engine_ptr_->getBindingDimensions(i);
+ oss << " nbDims: " << dims.nbDims << " -> ";
+ for (int j = 0; j < Dims::MAX_DIMS; j++) {
+ oss << dims.d[j] << ", ";
+ }
+ LOG(INFO) << oss.str();
+ oss.str("");
+ switch (trt_engine_ptr_->getBindingDataType(i)) {
+ case nvinfer1::DataType::kFLOAT:
+ LOG(INFO) << "data type float" << std::endl;
+ break;
+ case nvinfer1::DataType::kHALF:
+ LOG(INFO) << "data type half" << std::endl;
+ break;
+ case nvinfer1::DataType::kINT8:
+ LOG(INFO) << "data type int8" << std::endl;
+ break;
+ }
+ }
+
+ // CHECK_NE(cudaStreamCreate(&stream_),0); // logic here is wrong
+ // cudaStreamCreate(&stream_);
+}
+
+void TRTEngineOp::Compute(OpKernelContext* context) {
+ int nbBindings = context->num_inputs() + context->num_outputs();
+ // TODO(jjsjann123) multiple input/output
+ std::vector<void*> buffers(nbBindings);
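+  // buffers holds one device pointer per engine binding (inputs and outputs),
+  // indexed by TensorRT binding index rather than TF argument order.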
+
+ size_t bindingIndex;
+ int nbBatch = 0;
+ bool valid = true;
+ for (int i = 0; i < context->num_inputs(); i++) {
+ // Grab the input tensor
+ bindingIndex = trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str());
+
+ const Tensor& input_tensor = context->input(i);
+ const TensorShape& input_shape = input_tensor.shape();
+ if (i == 0) {
+ nbBatch = input_shape.dim_size(0);
+ } else if (nbBatch != input_shape.dim_size(0)) {
+ valid = false;
+ break;
+ }
+ // int64 input_shape.dim_size(int d)
+ // int input_shape.dims()
+ switch (trt_engine_ptr_->getBindingDataType(bindingIndex)) {
+ case nvinfer1::DataType::kFLOAT:
+ LOG(INFO) << "float";
+ buffers[bindingIndex] = (void*)(input_tensor.flat<float>().data());
+ break;
+ case nvinfer1::DataType::kHALF:
+ LOG(INFO) << "half";
+ // buffers[bindingIndex] = (void*)input_tensor.flat<float16>().data();
+ break;
+ case nvinfer1::DataType::kINT8:
+ LOG(INFO) << "int8";
+ // buffers[bindingIndex] = (void*)input_tensor.flat<int8>().data();
+ break;
+ }
+ }
+
+ if (!valid) LOG(WARNING) << "input data inconsistent batch size";
+
+ for (int i = 0; i < static_cast<int>(output_nodes_.size()); i++) {
+    // It is unfortunate that the output buffer is reallocated on every run.
+ // Create an output tensor
+ bindingIndex = trt_engine_ptr_->getBindingIndex(output_nodes_[i].c_str());
+ Tensor* output_tensor = NULL;
+
+ TensorShape output_shape;
+ if (bindingIndex != -1) {
+ LOG(INFO) << "got binding " << bindingIndex;
+ auto dims = trt_engine_ptr_->getBindingDimensions(bindingIndex);
+ std::vector<int> trt_shape(dims.nbDims + 1);
+ trt_shape[0] = nbBatch;
+ for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j];
+ TensorShapeUtils::MakeShape(trt_shape.data(), trt_shape.size(),
+ &output_shape);
+ } else {
+ LOG(INFO) << "no binding ";
+ break;
+ }
+
+ OP_REQUIRES_OK(context,
+ context->allocate_output(i, output_shape, &output_tensor));
+ // buffers[bindingIndex] = (void*)output_tensor->flat<float>();
+ // buffers[bindingIndex] = output_tensor->flat<float>().data();
+ switch (trt_engine_ptr_->getBindingDataType(bindingIndex)) {
+ case nvinfer1::DataType::kFLOAT:
+ LOG(INFO) << "float";
+ buffers[bindingIndex] =
+ reinterpret_cast<void*>(output_tensor->flat<float>().data());
+ break;
+ case nvinfer1::DataType::kHALF:
+ LOG(INFO) << "half";
+ // buffers[bindingIndex] = (void*)output_tensor->flat<float16>().data();
+ break;
+ case nvinfer1::DataType::kINT8:
+ LOG(INFO) << "int8";
+ // buffers[bindingIndex] = (void*)output_tensor->flat<int8>().data();
+ break;
+ }
+ }
+ // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
+ const cudaStream_t* stream = CHECK_NOTNULL(
+ reinterpret_cast<const cudaStream_t*>(context->op_device_context()
+ ->stream()
+ ->implementation()
+ ->CudaStreamMemberHack()));
+
+ trt_context_ptr_->enqueue(nbBatch, &buffers[0], *stream, nullptr);
+ cudaStreamSynchronize(*stream);
+}
+
+REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
+} // namespace tensorrt
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
new file mode 100644
index 0000000000..631fc114f2
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_
+
+#include <NvInfer.h>
+#include <cuda_runtime_api.h>
+#include <memory>
+#include <string>
+#include <vector>
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+namespace tensorrt {
+class Logger;
+class TRTEngineOp : public OpKernel {
+ public:
+ explicit TRTEngineOp(OpKernelConstruction* context);
+
+ void Compute(OpKernelContext* context) override;
+
+ private:
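+  // TensorRT objects are released through their destroy() method rather than
+  // operator delete, hence the custom deleter used by destroyed_ptr below.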
+ template <typename T>
+ struct Destroyer {
+ void operator()(T* d) { d->destroy(); }
+ };
+ template <typename T>
+ using destroyed_ptr = std::unique_ptr<T, Destroyer<T>>;
+ destroyed_ptr<nvinfer1::ICudaEngine> trt_engine_ptr_;
+ // TODO(samikama) context should go to a resource manager!
+ destroyed_ptr<nvinfer1::IExecutionContext> trt_context_ptr_;
+ std::vector<string> input_nodes_;
+ std::vector<string> output_nodes_;
+};
+
+} // namespace tensorrt
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_
diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.cc b/tensorflow/contrib/tensorrt/log/trt_logger.cc
new file mode 100644
index 0000000000..545a4aac50
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/log/trt_logger.cc
@@ -0,0 +1,56 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+// Use TF logging for TensorRT information
+#include "tensorflow/core/platform/logging.h"
+
+#define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1)
+//------------------------------------------------------------------------------
+namespace tensorflow {
+
+//------------------------------------------------------------------------------
+namespace tensorrt {
+
+void Logger::log(Severity severity, const char* msg) {
+ // suppress info-level messages
+ switch (severity) {
+ case Severity::kINFO: { // mark TRT info messages as debug!
+ LOG(DEBUG) << msg;
+ break;
+ }
+ case Severity::kWARNING: {
+ LOG(WARNING) << msg;
+ break;
+ }
+ case Severity::kERROR: {
+ LOG(ERROR) << msg;
+ break;
+ }
+ case Severity::kINTERNAL_ERROR: {
+ LOG(FATAL) << msg;
+ break;
+ }
+    // Not reachable today, but this catches future additions to the severity
+    // enum; it is always good to have a default case.
+ default: {
+      LOG(FATAL) << name_ << ": got unknown severity level from TRT: " << msg;
+ break;
+ }
+ }
+}
+
+} // namespace tensorrt
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/contrib/tensorrt/log/trt_logger.h
new file mode 100644
index 0000000000..10a78b7a1d
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/log/trt_logger.h
@@ -0,0 +1,41 @@
+// -*- c++ -*-
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_
+
+// Use TF logging for TensorRT information
+#include <NvInfer.h>
+#include <string>
+
+//------------------------------------------------------------------------------
+namespace tensorflow {
+
+//------------------------------------------------------------------------------
+namespace tensorrt {
+
+// Logger for GIE info/warning/errors
+class Logger : public nvinfer1::ILogger {
+ void log(nvinfer1::ILogger::Severity severity, const char* msg) override;
+
+ private:
+ std::string name_;
+};
+
+} // namespace tensorrt
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_
diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
new file mode 100644
index 0000000000..38d3707190
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
@@ -0,0 +1,37 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+
+namespace shape_inference {
+extern Status TRTEngineOpShapeInference(InferenceContext* c);
+}
+
+REGISTER_OP("TRTEngineOp")
+ .Attr("serialized_engine: string")
+ .Attr("input_nodes: list(string)")
+ .Attr("output_nodes: list(string)")
+ .Attr("InT: list({int8, float16, float32})")
+ .Attr("OutT: list({int8, float16, float32})")
+ .Input("in_tensor: InT")
+ .Output("out_tensor: OutT")
+ .SetShapeFn(shape_inference::TRTEngineOpShapeInference);
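+
+// For illustration only, the converter emits nodes that look roughly like the
+// following (attribute values abbreviated; the "my_trt_op0" name comes from
+// the converter's static counter):
+//
+//   node {
+//     name: "my_trt_op0"
+//     op: "TRTEngineOp"
+//     attr { key: "serialized_engine" value { s: "..." } }
+//     attr { key: "input_nodes" value { list { s: "..." } } }
+//     attr { key: "output_nodes" value { list { s: "..." } } }
+//     attr { key: "OutT" value { list { type: DT_FLOAT } } }
+//   }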
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py
new file mode 100644
index 0000000000..4aeea48515
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/python/__init__.py
@@ -0,0 +1,8 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
+from tensorflow.contrib.tensorrt.python.trt_convert import CreateInferenceGraph
+# pylint: enable=unused-import,wildcard-import
diff --git a/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py b/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py
new file mode 100644
index 0000000000..ce78d328de
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py
@@ -0,0 +1,35 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import platform
+
+if platform.system() != "Windows":
+ # pylint: disable=wildcard-import,unused-import,g-import-not-at-top
+ from tensorflow.contrib.tensorrt.ops.gen_trt_engine_op import *
+
+ from tensorflow.contrib.util import loader
+ from tensorflow.python.platform import resource_loader
+ # pylint: enable=wildcard-import,unused-import,g-import-not-at-top
+
+ _trt_engine_op = loader.load_op_library(
+ resource_loader.get_path_to_datafile("_trt_engine_op.so"))
+else:
+ raise RuntimeError("Windows platforms are not supported")
+
+
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py
new file mode 100644
index 0000000000..a66afa8d05
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/python/trt_convert.py
@@ -0,0 +1,91 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Exposes the Python wrapper conversion to trt_graph."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import, line-too-long
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import errors_impl as _impl
+from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
+from tensorflow.python.util import compat
+import tensorflow as tf
+from tensorflow.python.grappler import tf_optimizer
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import meta_graph
+from tensorflow.python.framework import ops
+
+
+def CreateInferenceGraph(input_graph_def, outputs, max_batch_size=1,
+                         max_workspace_size=2 << 20):
+  """Python wrapper for the TRT transformation.
+
+  Args:
+    input_graph_def: GraphDef object containing a model to be transformed.
+    outputs: List of node names for the model outputs.
+    max_batch_size: maximum batch size that the generated TensorRT engines
+      will accept.
+    max_workspace_size: maximum scratch workspace, in bytes, that TensorRT is
+      allowed to allocate.
+
+  Returns:
+    New GraphDef with TRTEngineOp nodes replacing the supported subgraphs.
+  """
+
+ # with errors.raise_exception_on_not_ok_status() as status:
+ # output_graph_def_string = trt_convert(
+ # input_graph_def_string,outputs,
+ # max_batch_size,max_workspace_size, status)
+ g = tf.Graph()
+ with g.as_default():
+ tf.import_graph_def(input_graph_def, name="")
+ rewriter_config = rewriter_config_pb2.RewriterConfig()
+ rewriter_config.optimizers.append('layout')
+ rewriter_config.optimizers.append('constfold')
+
+ # mark output nodes as fetch
+ train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+ for node_name in outputs:
+ out_node = g.get_operation_by_name(node_name)
+    for out_tensor in out_node.outputs:
+      train_op.append(out_tensor)
+
+ # constant folding
+ mg = meta_graph.create_meta_graph_def(graph=g)
+ meta_graph.add_collection_def(mg, ops.GraphKeys.TRAIN_OP)
+ optimized_graph_def_str = \
+ tf_optimizer.OptimizeGraph(rewriter_config, mg).SerializeToString()
+
+  # TODO(sami): Fix this when we can return a status from the C++ library.
+  # The TF internal library setup doesn't allow us to return a status object
+  # from C++, so we return a pair of strings instead: the first is the encoded
+  # status and the second is the serialized protobuf of the transformed graph.
+  status, output_graph_def_string = trt_convert(
+      optimized_graph_def_str, outputs, max_batch_size, max_workspace_size)
+  del optimized_graph_def_str  # save some memory
+  if len(status) < 2:
+    raise _impl.UnknownError(None, None, status)
+  if status[:2] != "OK":
+    msg = status.split(";")
+    if len(msg) == 1:
+      raise RuntimeError("Status message is malformed: {}".format(status))
+    raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
+                                         int(msg[0]))
+  output_graph_def = graph_pb2.GraphDef()
+  output_graph_def.ParseFromString(output_graph_def_string)
+  del output_graph_def_string  # save some memory
+ return output_graph_def
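
A minimal calling sequence for the wrapper above might look like the sketch below, assuming the contrib package is importable as shown once this change is built in; the model path, output node name, and size limits are illustrative placeholders only:

    import tensorflow as tf
    from tensorflow.contrib import tensorrt as trt

    # Load a frozen GraphDef (path and output name are hypothetical).
    graph_def = tf.GraphDef()
    with tf.gfile.GFile("frozen_model.pb", "rb") as f:
      graph_def.ParseFromString(f.read())

    trt_graph = trt.CreateInferenceGraph(
        input_graph_def=graph_def,
        outputs=["logits"],           # model output node names
        max_batch_size=8,             # largest batch the engines must handle
        max_workspace_size=1 << 30)   # TensorRT scratch memory budget, in bytes

    # trt_graph is a GraphDef in which supported subgraphs are replaced by
    # TRTEngineOp nodes; it can be imported and run like any other GraphDef.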
diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc
new file mode 100644
index 0000000000..41da528247
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/segment/segment.cc
@@ -0,0 +1,259 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/segment/segment.h"
+
+#include <set>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/tensorrt/segment/union_find.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+//------------------------------------------------------------------------------
+namespace tensorrt {
+namespace segment {
+
+//------------------------------------------------------------------------------
+namespace {
+
+//------------------------------------------------------------------------------
+bool CanContractEdge(const tensorflow::Edge* edge,
+ const tensorflow::Graph& graph) {
+ const tensorflow::Node* src = edge->src();
+ const tensorflow::Node* dst = edge->dst();
+
+ // Can't contract edge if doing so would cause a cycle in the
+ // graph. So, if there is a directed path from 'src' to 'dst', other
+ // than 'edge' (or any other direct edge from 'src' to 'dst'), then
+ // combining 'src' and 'dst' will cause a cycle along that path.
+ //
+ // In practice, to avoid modifying the graph and to take advantage
+ // of existing graph functions, we perform an equivalent.
+ // 1. Get all nodes incoming to 'dst', excluding 'src'
+ // 2. Reverse DFS from those nodes
+ // 3. If reverse DFS reaches 'src' then we have a cycle
+ std::vector<tensorflow::Node*> dfs_start_nodes;
+ for (tensorflow::Node* node : dst->in_nodes()) {
+ if (node != src) {
+ dfs_start_nodes.push_back(node);
+ }
+ }
+
+ bool is_cycle = false;
+ if (!dfs_start_nodes.empty()) {
+ tensorflow::ReverseDFSFrom(graph, dfs_start_nodes, {},
+ [&is_cycle, src](tensorflow::Node* node) {
+ if (node == src) {
+ is_cycle = true;
+ }
+ });
+ }
+
+ return !is_cycle;
+}
+
+//------------------------------------------------------------------------------
+void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph,
+ std::vector<const tensorflow::Edge*>* remove_edges) {
+ // Transfer all inputs and outputs of 'dst' to 'src' except edges
+ // connecting the two.
+ tensorflow::Node* src = edge->src();
+ tensorflow::Node* dst = edge->dst();
+
+ // We can use '0' for input/output index because we don't need them
+ // to be accurate for the way we are using the graph.
+ std::vector<const tensorflow::Edge*> in_edges(dst->in_edges().begin(),
+ dst->in_edges().end());
+ for (const tensorflow::Edge* in_edge : in_edges) {
+ if (in_edge->src() != src) {
+ tensorflow::Edge* e = const_cast<tensorflow::Edge*>(in_edge);
+ if (e->src() == graph->source_node()) {
+ graph->AddEdge(e->src(), e->src_output(), src,
+ tensorflow::Graph::kControlSlot);
+ } else {
+ graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */);
+ }
+ }
+ }
+
+ std::vector<const tensorflow::Edge*> out_edges(dst->out_edges().begin(),
+ dst->out_edges().end());
+ for (const tensorflow::Edge* out_edge : out_edges) {
+ tensorflow::Edge* e = const_cast<tensorflow::Edge*>(out_edge);
+ if (e->dst() == graph->sink_node()) {
+ graph->AddEdge(src, tensorflow::Graph::kControlSlot, e->dst(),
+ e->dst_input());
+ } else {
+ graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input());
+ }
+ }
+
+ // Return the edges that must be removed to disconnect 'dst' from
+ // the graph. We don't actually remove 'dst' since the caller holds
+ // references to all the nodes.
+ for (const auto& in_edge : dst->in_edges()) {
+ remove_edges->push_back(in_edge);
+ }
+ for (const auto& out_edge : dst->out_edges()) {
+ remove_edges->push_back(out_edge);
+ }
+}
+
+} // namespace
+
+//------------------------------------------------------------------------------
+tensorflow::Status SegmentGraph(
+ const tensorflow::GraphDef& gdef,
+ const std::function<bool(const tensorflow::NodeDef&)>& candidate_fn,
+ const SegmentOptions& options, SegmentNodesVector* segments) {
+ // Create a Graph representation of the GraphDef.
+ tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
+ gdef.library());
+ tensorflow::Graph graph(flib);
+ TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
+ tensorflow::GraphConstructorOptions(), gdef, &graph));
+
+ // tensorflow::DumpGraph("Pre-Segment", &graph);
+
+ // Use a union-find to collect the nodes that belong to the same
+ // segment. A node value of nullptr indicates that the node is not a
+ // candidate for TRT.
+ std::vector<UnionFind<tensorflow::Node*>> node_segments;
+ for (int i = 0; i < graph.num_node_ids(); ++i) {
+ tensorflow::Node* node = graph.FindNodeId(i);
+ if (!candidate_fn(node->def())) {
+ node = nullptr;
+ }
+ node_segments.emplace_back(node);
+ }
+
+ // Visit nodes in reverse topological order and use edge
+ // contraction to merge candidate nodes.
+ std::vector<tensorflow::Node*> order;
+ tensorflow::GetPostOrder(graph, &order);
+
+ for (const tensorflow::Node* node : order) {
+ // All output nodes of 'node' have been visited...
+ VLOG(2) << "Trying node " << node->name();
+
+ // 'node' must be a TRT candidate...
+ if (node_segments[node->id()].Value() == nullptr) {
+ VLOG(2) << "... not a TRT candidate";
+ continue;
+ }
+
+ // Contract output edges to combine 'node' with output
+ // nodes. Iterate since combining two nodes may unblock other
+ // combining.
+ while (true) {
+ std::set<const tensorflow::Edge*> contract_edges;
+ for (const tensorflow::Edge* out_edge : node->out_edges()) {
+ VLOG(2) << "... out node " << out_edge->dst()->name();
+
+ // Out node must be TRT candidate...
+ if (node_segments[out_edge->dst()->id()].Value() == nullptr) {
+ VLOG(2) << "... ... not a TRT candidate";
+ continue;
+ }
+
+ if (CanContractEdge(out_edge, graph)) {
+ VLOG(2) << "... ... can contract";
+ contract_edges.insert(out_edge);
+ } else {
+ VLOG(2) << "... ... cannot contract, would form cycle";
+ }
+ }
+
+ if (contract_edges.empty()) {
+ break;
+ }
+
+ // Contract edges and collect the adjacent nodes into the same
+ // segment/subgraph.
+ while (!contract_edges.empty()) {
+ const tensorflow::Edge* contract_edge = *contract_edges.begin();
+ const tensorflow::Node* src = contract_edge->src();
+ const tensorflow::Node* dst = contract_edge->dst();
+
+ VLOG(2) << "Merge " << src->name() << " <- " << dst->name();
+ node_segments[src->id()].Merge(&node_segments[dst->id()]);
+
+ // Contracting the edge leaves disconnected graph edges.
+ // Remove these from the graph and from 'contract_edges' so we
+ // don't visit them again.
+ tensorflow::Edge* e = const_cast<tensorflow::Edge*>(contract_edge);
+ std::vector<const tensorflow::Edge*> remove_edges;
+ ContractEdge(e, &graph, &remove_edges);
+
+ for (const tensorflow::Edge* r : remove_edges) {
+ contract_edges.erase(r);
+ graph.RemoveEdge(r);
+ }
+ }
+ }
+ }
+
+ // Collect the segments/subgraphs. Each subgraph is represented by a
+ // set of the names of the nodes in that subgraph.
+ std::unordered_map<std::string, std::set<std::string>> sg_map;
+ for (auto& u : node_segments) {
+ if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) {
+ sg_map[u.ParentValue()->name()].insert(u.Value()->name());
+ }
+ }
+
+ // Cleanup the graph to remove disconnected nodes before outputting
+ if (VLOG_IS_ON(2)) {
+ for (tensorflow::Node* node : graph.nodes()) {
+ if ((node->in_edges().size() == 0) && (node->out_edges().size() == 0)) {
+ graph.RemoveNode(node);
+ }
+ }
+ // tensorflow::DumpGraph("Post-Segment", &graph);
+ }
+
+ // Convert the segments into the expected return format
+ for (const auto& itr : sg_map) {
+ const auto& segment_node_names = itr.second;
+ if (VLOG_IS_ON(1)) {
+ std::string s;
+ for (const auto& name : segment_node_names) {
+ s += " " + name;
+ }
+ VLOG(1) << "Segment " << segments->size() << ":" << s;
+ }
+
+ // Don't use small segments.
+ if (static_cast<int>(segment_node_names.size()) <
+ options.minimum_segment_size) {
+ VLOG(1) << "Segment " << segments->size() << " has only "
+ << segment_node_names.size() << " nodes, dropping";
+ continue;
+ }
+
+ segments->emplace_back(segment_node_names);
+ }
+
+ return tensorflow::Status::OK();
+}
+
+} // namespace segment
+} // namespace tensorrt
diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h
new file mode 100644
index 0000000000..b5aee5bc34
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/segment/segment.h
@@ -0,0 +1,53 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_
+
+#include <set>
+#include <vector>
+#include <string>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorrt {
+namespace segment {
+
+using SegmentNodesVector = std::vector<std::set<std::string>>;
+
+struct SegmentOptions {
+ // Segment must contain at least this many nodes.
+ int minimum_segment_size = 2;
+};
+
+// Get the subgraphs of a graph that can be handled by TensorRT.
+//
+// @param gdef The GraphDef describing the network
+// @param candidate_fn A function that returns true for a NodeDef if
+// that node can be handled by TensorRT.
+// @param options Segmentation options, e.g. the minimum number of nodes a
+// segment must contain.
+// @param segments Returns the TensorRT segments/subgraphs. Each entry
+// in the vector describes a subgraph by giving a set of the names of
+// all the NodeDefs in that subgraph.
+// @return the status.
+tensorflow::Status SegmentGraph(
+ const tensorflow::GraphDef& gdef,
+ const std::function<bool(const tensorflow::NodeDef&)>& candidate_fn,
+ const SegmentOptions& options, SegmentNodesVector* segments);
+
+} // namespace segment
+} // namespace tensorrt
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_
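
Purely as an illustration of the documented return format (not an API introduced here), the segmentation of the "Multiple" test graph in segment_test.cc below comes back as a vector with one set of node names per subgraph, roughly:

    # SegmentNodesVector, rendered as Python data: one set of node names per subgraph.
    segments = [
        {"add0", "add1", "add2", "add3"},   # first TensorRT-compatible subgraph
        {"add6", "add8"},                   # second subgraph
    ]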
diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc
new file mode 100644
index 0000000000..dcd0c71ed7
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc
@@ -0,0 +1,363 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/segment/segment.h"
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/test.h"
+
+//------------------------------------------------------------------------------
+using namespace tensorflow;
+
+namespace tensorrt {
+namespace segment {
+namespace test {
+
+class SegmentTest : public ::testing::Test {
+ public:
+ bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def);
+
+ TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name);
+ TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+ TF_Status* s, const char* name);
+
+ std::function<bool(const NodeDef&)> MakeCandidateFn(
+ const std::set<std::string>& node_names);
+
+ protected:
+ void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
+ TF_Operation** op);
+ void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+ TF_Status* s, const char* name, TF_Operation** op, bool check);
+
+ SegmentOptions default_options_;
+};
+
+bool SegmentTest::GetGraphDef(TF_Graph* graph,
+ tensorflow::GraphDef* graph_def) {
+ TF_Status* s = TF_NewStatus();
+ TF_Buffer* buffer = TF_NewBuffer();
+ TF_GraphToGraphDef(graph, buffer, s);
+ bool ret = TF_GetCode(s) == TF_OK;
+ EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length);
+ TF_DeleteBuffer(buffer);
+ TF_DeleteStatus(s);
+ return ret;
+}
+
+std::function<bool(const NodeDef&)> SegmentTest::MakeCandidateFn(
+ const std::set<std::string>& node_names) {
+ return [node_names](const NodeDef& node) -> bool {
+ return node_names.find(node.name()) != node_names.end();
+ };
+}
+
+void SegmentTest::PlaceholderHelper(TF_Graph* graph, TF_Status* s,
+ const char* name, TF_Operation** op) {
+ TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
+ TF_SetAttrType(desc, "dtype", TF_INT32);
+ *op = TF_FinishOperation(desc, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_NE(*op, nullptr);
+}
+
+TF_Operation* SegmentTest::Placeholder(TF_Graph* graph, TF_Status* s,
+ const char* name) {
+ TF_Operation* op;
+ PlaceholderHelper(graph, s, name, &op);
+ return op;
+}
+
+void SegmentTest::AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+ TF_Status* s, const char* name, TF_Operation** op,
+ bool check) {
+ TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
+ TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
+ TF_AddInputList(desc, add_inputs, 2);
+ *op = TF_FinishOperation(desc, s);
+ if (check) {
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_NE(*op, nullptr);
+ }
+}
+
+TF_Operation* SegmentTest::Add(TF_Operation* l, TF_Operation* r,
+ TF_Graph* graph, TF_Status* s,
+ const char* name) {
+ TF_Operation* op;
+ AddHelper(l, r, graph, s, name, &op, true);
+ return op;
+}
+
+//------------------------------------------------------------------------------
+TEST_F(SegmentTest, Empty) {
+ TF_Graph* graph = TF_NewGraph();
+
+ GraphDef graph_def;
+ ASSERT_TRUE(GetGraphDef(graph, &graph_def));
+
+ SegmentNodesVector segments;
+ ASSERT_EQ(
+ SegmentGraph(graph_def, MakeCandidateFn({}), default_options_, &segments),
+ tensorflow::Status::OK());
+
+ // Expect no segments/subgraphs.
+ EXPECT_TRUE(segments.empty());
+}
+
+//------------------------------------------------------------------------------
+TEST_F(SegmentTest, Simple) {
+ TF_Status* s = TF_NewStatus();
+ TF_Graph* graph = TF_NewGraph();
+
+ // feed
+ // // ||
+ // add0 add1
+ // | | /
+ // | add2
+ // | / ||
+ // add3 add4
+ // | /
+ // <sink>
+ //
+ TF_Operation* feed = Placeholder(graph, s, "feed");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
+
+ TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
+ TF_Operation* add4 = Add(add2, add2, graph, s, "add4");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
+
+ GraphDef graph_def;
+ ASSERT_TRUE(GetGraphDef(graph, &graph_def));
+
+ SegmentNodesVector segments;
+ ASSERT_EQ(
+ SegmentGraph(graph_def,
+ MakeCandidateFn({"add0", "add1", "add2", "add3", "add4"}),
+ default_options_, &segments),
+ tensorflow::Status::OK());
+
+ // Expect all Add operations to be collapsed into a single segment
+ ASSERT_EQ(segments.size(), 1);
+ std::vector<std::string> expected{"add0", "add1", "add2", "add3", "add4"};
+ for (const auto& ex : expected) {
+ EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
+ << "Missing expected node " << ex;
+ }
+}
+
+//------------------------------------------------------------------------------
+TEST_F(SegmentTest, AvoidCycle) {
+ TF_Status* s = TF_NewStatus();
+ TF_Graph* graph = TF_NewGraph();
+
+ // add2 is not a TRT candidate so add0/add3 cannot be formed as a
+ // subgraph
+ //
+ // feed
+ // // ||
+ // add0 add1
+ // | | /
+ // | add2
+ // | / ||
+ // add3 add4
+ // | /
+ // <sink>
+ //
+ TF_Operation* feed = Placeholder(graph, s, "feed");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
+
+ TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
+ TF_Operation* add4 = Add(add2, add2, graph, s, "add4");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
+
+ GraphDef graph_def;
+ ASSERT_TRUE(GetGraphDef(graph, &graph_def));
+
+ SegmentNodesVector segments;
+ ASSERT_EQ(
+ SegmentGraph(graph_def, MakeCandidateFn({"add0", "add1", "add3", "add4"}),
+ default_options_, &segments),
+ tensorflow::Status::OK());
+
+ // Expect no subgraphs
+ EXPECT_EQ(segments.size(), 0);
+}
+
+//------------------------------------------------------------------------------
+TEST_F(SegmentTest, Multiple) {
+ TF_Status* s = TF_NewStatus();
+ TF_Graph* graph = TF_NewGraph();
+
+ // add5 is not a TRT candidate so two subgraphs should be formed
+ //
+ // feed
+ // // || ||
+ // add0 add1 add7
+ // | | / / ||
+ // | add2-----add5 add8
+ // | / | | | |
+ // add3 add4 add6
+ // | | /
+ // <sink>
+ //
+ TF_Operation* feed = Placeholder(graph, s, "feed");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
+
+ TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add7 = Add(feed, feed, graph, s, "add7");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add5 = Add(add2, add7, graph, s, "add5");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add8 = Add(add7, add7, graph, s, "add8");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
+ TF_Operation* add4 = Add(add2, add5, graph, s, "add4");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
+ TF_Operation* add6 = Add(add5, add8, graph, s, "add6");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add6"), string(TF_OperationName(add6)));
+
+ GraphDef graph_def;
+ ASSERT_TRUE(GetGraphDef(graph, &graph_def));
+
+ SegmentNodesVector segments;
+ ASSERT_EQ(SegmentGraph(graph_def,
+ MakeCandidateFn({"add0", "add1", "add2", "add3",
+ "add4", "add6", "add7", "add8"}),
+ default_options_, &segments),
+ tensorflow::Status::OK());
+
+ // Expect two subgraphs
+ EXPECT_EQ(segments.size(), 2);
+
+ std::vector<std::string> expected0{"add0", "add1", "add2", "add3"};
+ for (const auto& ex : expected0) {
+ EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
+ << "Missing expected node " << ex;
+ }
+
+ std::vector<std::string> expected1{"add6", "add8"};
+ for (const auto& ex : expected1) {
+ EXPECT_TRUE(segments[1].find(ex) != segments[1].end())
+ << "Missing expected node " << ex;
+ }
+}
+
+//------------------------------------------------------------------------------
+TEST_F(SegmentTest, BigIfElse) {
+ TF_Status* s = TF_NewStatus();
+ TF_Graph* graph = TF_NewGraph();
+
+ // add2 is not a TRT candidate
+ //
+ // feed
+ // ||
+ // add0
+ // // ||
+ // add1 add4
+ // || ||
+ // add2 add5
+ // || ||
+ // add3 add6
+ // || //
+ // add7
+ // ||
+ // <sink>
+ //
+ TF_Operation* feed = Placeholder(graph, s, "feed");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
+
+ TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add1 = Add(add0, add0, graph, s, "add1");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add2 = Add(add1, add1, graph, s, "add2");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add3 = Add(add2, add2, graph, s, "add3");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add4 = Add(add0, add0, graph, s, "add4");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add5 = Add(add4, add4, graph, s, "add5");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add6 = Add(add5, add5, graph, s, "add6");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add7 = Add(add3, add6, graph, s, "add7");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add7"), string(TF_OperationName(add7)));
+
+ GraphDef graph_def;
+ ASSERT_TRUE(GetGraphDef(graph, &graph_def));
+
+ SegmentNodesVector segments;
+ ASSERT_EQ(SegmentGraph(graph_def,
+ MakeCandidateFn({"add0", "add1", "add3", "add4",
+ "add5", "add6", "add7"}),
+ default_options_, &segments),
+ tensorflow::Status::OK());
+
+ // Expect 2 subgraphs
+ EXPECT_EQ(segments.size(), 2);
+
+ std::vector<std::string> expected0{"add3", "add4", "add5", "add6", "add7"};
+ for (const auto& ex : expected0) {
+ EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
+ << "Missing expected node " << ex;
+ }
+
+ std::vector<std::string> expected1{"add0", "add1"};
+ for (const auto& ex : expected1) {
+ EXPECT_TRUE(segments[1].find(ex) != segments[1].end())
+ << "Missing expected node " << ex;
+ }
+}
+
+} // namespace test
+} // namespace segment
+} // namespace tensorrt
diff --git a/tensorflow/contrib/tensorrt/segment/union_find.h b/tensorflow/contrib/tensorrt/segment/union_find.h
new file mode 100644
index 0000000000..8ae877cd05
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/segment/union_find.h
@@ -0,0 +1,77 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_
+
+namespace tensorrt {
+namespace segment {
+
+// Union-Find data structure.
+// Each cluster has an associated value; when merging clusters we can control
+// which value becomes the representative of the merged clusters. Values must be
+// copyable.
+template <typename T>
+class UnionFind {
+ public:
+ UnionFind() : size_(1), parent_(nullptr) {}
+ explicit UnionFind(const T& v) : size_(1), parent_(nullptr), value_(v) {}
+
+ // Returns the number of elements in a cluster.
+ int Size() { return FindRoot()->size_; }
+
+ // Merges this cluster with 'other'. This cluster's value becomes
+ // the value of the merged cluster; the value of 'other' is ignored.
+ void Merge(UnionFind* other);
+
+ // Each cluster has an associated value. Retrieves the value associated
+ // with this cluster.
+ T& ParentValue() { return FindRoot()->value_; }
+
+ // Get the original value of this node.
+ T& Value() { return value_; }
+
+ private:
+ // Finds the root element of the cluster. Performs path compression.
+ UnionFind* FindRoot();
+
+ int size_;
+ UnionFind* parent_;
+ T value_;
+};
+
+template <typename T>
+void UnionFind<T>::Merge(UnionFind* other) {
+ UnionFind<T>* a = FindRoot();
+ UnionFind<T>* b = other->FindRoot();
+ if (a == b) return;
+
+ b->parent_ = a;
+ a->size_ += b->size_;
+}
+
+template <typename T>
+UnionFind<T>* UnionFind<T>::FindRoot() {
+ if (!parent_) return this;
+ // Path compression: update intermediate nodes to point to the root of the
+ // equivalence class.
+ parent_ = parent_->FindRoot();
+ return parent_;
+}
+
+} // namespace segment
+} // namespace tensorrt
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_
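
For readers less familiar with this value-carrying variant of union-find, the following Python sketch (illustration only, not part of the change) captures the same behaviour: each element carries a value, Merge keeps the calling side's representative, and FindRoot performs path compression:

    class UnionFind(object):
      def __init__(self, value=None):
        self.size = 1
        self.parent = None
        self.value = value               # value carried by this element

      def find_root(self):
        if self.parent is None:
          return self
        self.parent = self.parent.find_root()   # path compression
        return self.parent

      def merge(self, other):
        a, b = self.find_root(), other.find_root()
        if a is b:
          return
        b.parent = a                     # 'self' supplies the merged cluster's value
        a.size += b.size

    # After x.merge(y), x.find_root().value plays the role of ParentValue() above.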
diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
new file mode 100644
index 0000000000..72022b99e2
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
@@ -0,0 +1,123 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h"
+#include <string>
+#include <vector>
+#include "NvInfer.h"
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+
+namespace tensorflow {
+namespace shape_inference {
+tensorflow::Status TRTEngineOpShapeInference(InferenceContext* c) {
+ tensorflow::tensorrt::Logger gLogger;
+ string serialized_engine;
+ c->GetAttr("serialized_engine", &serialized_engine);
+ nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(gLogger);
+ nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine(
+ serialized_engine.c_str(), serialized_engine.size(), nullptr);
+
+ // debug print out engine binding;
+ std::stringstream oss;
+ for (int i = 0; i < trt_engine->getNbBindings(); i++) {
+ LOG(INFO) << "index: " << i
+ << ", binding name: " << trt_engine->getBindingName(i);
+
+ bool input_flag = trt_engine->bindingIsInput(i);
+ oss << "input?: " << (input_flag ? "Y" : "N");
+
+ oss << "Dimension: ";
+ auto dims = trt_engine->getBindingDimensions(i);
+ oss << " nbDims: " << dims.nbDims << " -> ";
+ for (int j = 0; j < dims.nbDims; j++) oss << dims.d[j] << ", ";
+ LOG(INFO) << oss.str();
+ oss.str("");
+ switch (trt_engine->getBindingDataType(i)) {
+ case nvinfer1::DataType::kFLOAT:
+ LOG(INFO) << "data type: float" << std::endl;
+ break;
+ case nvinfer1::DataType::kHALF:
+ LOG(INFO) << "data type: half" << std::endl;
+ break;
+ case nvinfer1::DataType::kINT8:
+ LOG(INFO) << "data type: int8" << std::endl;
+ break;
+ }
+ }
+
+ int nbBatch = -1;
+ // debug print out input arrays
+ std::vector<::tensorflow::DataType> input_type;
+ c->GetAttr("InT", &input_type);
+ oss.str("");
+ for (size_t i = 0; i < c->num_inputs(); i++) {
+ // check if input shape is legit
+ auto input_shape = c->input(i);
+ int index = i;
+ oss << "input:" << i << " type: " << input_type[index] << " shape: ";
+ for (int j = 0; j < c->Rank(input_shape); j++) {
+ auto dimHandler = c->Dim(input_shape, j);
+ if (c->ValueKnown(dimHandler))
+ oss << c->Value(dimHandler) << ", ";
+ else
+ oss << "?" << c->Value(dimHandler) << ", ";
+ if (j == 0) {
+ if (i == 0)
+ nbBatch = c->Value(dimHandler);
+ else if (nbBatch != c->Value(dimHandler))
+        LOG(WARNING) << "Batch dimensions do not match between inputs";
+ // assert(nbBatch == c->Value(dimHandler);
+ }
+ }
+ LOG(INFO) << oss.str();
+ }
+
+ // arrange input here
+ std::vector<string> input_nodes;
+ c->GetAttr("input_nodes", &input_nodes);
+ for (size_t i = 0; i < input_nodes.size(); i++) {
+ int index = i;
+ LOG(INFO) << "input:" << i << " name: " << input_nodes[index];
+ }
+
+ // arrange output here
+ std::vector<string> output_nodes;
+ c->GetAttr("output_nodes", &output_nodes);
+ oss.str("");
+ for (size_t i = 0; i < output_nodes.size(); i++) {
+ int index = i;
+ int binding_index =
+ trt_engine->getBindingIndex(output_nodes[index].c_str());
+ oss << "string name " << output_nodes[index];
+ ShapeHandle output_shape;
+ std::vector<DimensionHandle> vecDim;
+ vecDim.emplace_back(c->MakeDim(nbBatch));
+ if (binding_index != -1) {
+ oss << "got binding " << binding_index;
+ auto dims = trt_engine->getBindingDimensions(binding_index);
+ for (int j = 0; j < dims.nbDims; j++)
+ vecDim.emplace_back(c->MakeDim(dims.d[j]));
+ } else {
+ oss << "no binding ";
+ }
+ output_shape = c->MakeShape(vecDim);
+ c->set_output(i, output_shape);
+ LOG(INFO) << oss.str();
+  }
+
+  // Release the engine and runtime created for shape inference.
+  trt_engine->destroy();
+  infer->destroy();
+  return Status::OK();
+}
+} // namespace shape_inference
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h
new file mode 100644
index 0000000000..90a226d91d
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h
@@ -0,0 +1,28 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
+
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace shape_inference {
+Status TRTEngineOpShapeInference(InferenceContext* c);
+} // namespace shape_inference
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i
new file mode 100644
index 0000000000..5f8e73a59f
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/trt_conversion.i
@@ -0,0 +1,84 @@
+/*
+
+ wrap trt_conversion
+
+ */
+%{
+#define SWIG_FILE_WITH_INIT
+%}
+%include "std_string.i"
+%include "std_pair.i"
+%include "tensorflow/python/lib/core/strings.i"
+%include "tensorflow/python/platform/base.i"
+%template(StringPair) std::pair<string,string>;
+%template() std::pair<swig::SwigPtr_PyObject, swig::SwigPtr_PyObject>;
+
+%{
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/stat_summarizer.h"
+#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
+%}
+
+%ignoreall
+%unignore tensorflow;
+%unignore trt_convert;
+
+%{
+ std::pair<string,string> trt_convert(string graph_def_string,//const tensorflow::GraphDef&
+ std::vector<string> output_names,
+ size_t max_batch_size,
+ size_t max_workspace_size
+                                        // Unfortunately we can't use TF_Status here, since it
+                                        // lives in c/c_api and brings in a lot of other
+                                        // libraries which in turn declare ops. Those ops are
+                                        // also linked statically into our library and cause an
+                                        // abort from double registration when the module is
+                                        // loaded. Until TensorFlow properly exposes these
+                                        // headers, we work around this by returning a string
+                                        // and converting it to an exception on the Python side.
+                                        //,TF_Status* out_status) {
+ ) {
+ string out_status;
+
+ tensorflow::GraphDef graph_def;
+ if (!graph_def.ParseFromString(graph_def_string)) {
+ out_status="InvalidArgument;Couldn't interpret input as a GraphDef";
+ return std::pair<string,string>{out_status,""};
+ }
+
+ if (!output_names.size()) {
+ out_status="InvalidArgument;Size of the output_names vector is 0";
+ return std::pair<string,string>{out_status,""};
+ //return "";
+ }
+ tensorflow::GraphDef outGraph;
+ tensorflow::Status conversion_status =
+ tensorrt::convert::ConvertGraphDefToTensorRT(graph_def,
+ output_names,
+ max_batch_size,
+ max_workspace_size,
+ &outGraph);
+ if (!conversion_status.ok()) {
+ auto retCode=(int)conversion_status.code();
+ char buff[2000];
+ snprintf(buff,2000,"%d;%s",retCode,conversion_status.error_message().c_str());
+ out_status=buff;
+ return std::pair<string,string>{out_status,""};
+ }
+ string result;
+ if (!outGraph.SerializeToString(&result)) {
+ out_status="InvalidArgument;Couldn't serialize output as a GraphDef";
+ return std::pair<string,string>{out_status,""};
+ }
+ out_status="OK;All good!";
+ return std::pair<string,string>{out_status,result};
+ }
+%}
+
+std::pair<string,string> trt_convert(string graph_def_string,
+ std::vector<string> output_names,
+ size_t max_batch_size,
+ size_t max_workspace_size);
+
+%unignoreall
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 383c97344a..838b1218a4 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -279,7 +279,7 @@ def tf_cc_shared_object(
linkopts=[],
framework_so=tf_binary_additional_srcs(),
**kwargs):
- native.cc_binary(
+ native.cc_binary(
name=name,
srcs=srcs + framework_so,
deps=deps,
@@ -1281,6 +1281,45 @@ def tf_extension_linkopts():
def tf_extension_copts():
return [] # No extension c opts
+# In libraries generated by tf_py_wrap_cc, module init functions are not
+# exported unless they match one of the keywords in the version script, which
+# breaks custom Python modules. This rule appends init_<module_name> to the
+# list of exported symbols in the version script (or exported-symbols list).
+def _append_init_to_versionscript_impl(ctx):
+ modName=ctx.attr.module_name
+ isVS=ctx.attr.is_version_script
+ if isVS:
+ ctx.actions.expand_template(
+ template=ctx.file.template_file,
+ output=ctx.outputs.versionscript,
+ substitutions={
+ "global:":"global:\n init_%s;"%modName,
+ },
+ is_executable=False,
+ )
+ else:
+ ctx.actions.expand_template(
+ template=ctx.file.template_file,
+ output=ctx.outputs.versionscript,
+ substitutions={
+ "*tensorflow*":"*tensorflow*\ninit_%s"%modName,
+ },
+ is_executable=False,
+ )
+
+
+_append_init_to_versionscript= rule(
+ implementation=_append_init_to_versionscript_impl,
+ attrs={
+ "module_name":attr.string(mandatory=True),
+ "template_file":attr.label(allow_files=True,single_file=True,mandatory=True),
+ "is_version_script":attr.bool(default=True,doc='whether target is a ld version script or exported symbol list',mandatory=False),
+ },
+ outputs={"versionscript":"%{name}.lds"},
+)
+
def tf_py_wrap_cc(name,
srcs,
swig_includes=[],
@@ -1302,26 +1341,39 @@ def tf_py_wrap_cc(name,
toolchain_deps=["//tools/defaults:crosstool"],
module_name=module_name,
py_module_name=name)
+ vscriptname=name+"_versionscript"
+ _append_init_to_versionscript(
+ name=vscriptname,
+ module_name=module_name,
+ is_version_script=select({
+ "@local_config_cuda//cuda:darwin":False,
+ "//conditions:default":True,
+ }),
+ template_file=select({
+ "@local_config_cuda//cuda:darwin":clean_dep("//tensorflow:tf_exported_symbols.lds"),
+ "//conditions:default":clean_dep("//tensorflow:tf_version_script.lds")
+ })
+ )
extra_linkopts = select({
"@local_config_cuda//cuda:darwin": [
"-Wl,-exported_symbols_list",
- clean_dep("//tensorflow:tf_exported_symbols.lds")
+ "%s.lds"%vscriptname,
],
clean_dep("//tensorflow:windows"): [],
clean_dep("//tensorflow:windows_msvc"): [],
"//conditions:default": [
"-Wl,--version-script",
- clean_dep("//tensorflow:tf_version_script.lds")
+ "%s.lds"%vscriptname,
]
})
extra_deps += select({
"@local_config_cuda//cuda:darwin": [
- clean_dep("//tensorflow:tf_exported_symbols.lds")
+ "%s.lds"%vscriptname,
],
clean_dep("//tensorflow:windows"): [],
clean_dep("//tensorflow:windows_msvc"): [],
"//conditions:default": [
- clean_dep("//tensorflow:tf_version_script.lds")
+ "%s.lds"%vscriptname,
]
})
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index ff5dd6a0b0..f47df0e25d 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -11,6 +11,7 @@ load(
)
load("//third_party/mkl:build_defs.bzl", "if_mkl")
load("//tensorflow:tensorflow.bzl", "if_cuda")
+load("@local_config_tensorrt//:build_defs.bzl", "if_trt")
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps")
# This returns a list of headers of all public header libraries (e.g.,
@@ -201,7 +202,8 @@ sh_binary(
"//tensorflow/python:test_ops",
"//tensorflow/tools/dist_test/server:grpc_tensorflow_server",
],
- }) + if_mkl(["//third_party/mkl:intel_binary_blob"]),
+ }) + if_mkl(["//third_party/mkl:intel_binary_blob"])
+ + if_trt(["//tensorflow/contrib/tensorrt:init_py"]),
)
# A genrule for generating a marker file for the pip package on Windows
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 0ba3cca991..8850610cdb 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -1,6 +1,7 @@
# TensorFlow external dependencies that can be loaded in WORKSPACE files.
load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
+load("//third_party/tensorrt:build_defs.bzl", "trt_repository")
load("//third_party/mkl:build_defs.bzl", "mkl_repository")
load("//third_party/git:git_configure.bzl", "git_configure")
load("//third_party/py:python_configure.bzl", "python_configure")
@@ -66,6 +67,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
# version we require here.
check_bazel_version_at_least("0.5.4")
cuda_configure(name="local_config_cuda")
+ trt_repository(name="local_config_tensorrt")
git_configure(name="local_config_git")
sycl_configure(name="local_config_sycl")
python_configure(name="local_config_python")
diff --git a/third_party/tensorrt/BUILD b/third_party/tensorrt/BUILD
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/third_party/tensorrt/BUILD
diff --git a/third_party/tensorrt/BUILD.tpl b/third_party/tensorrt/BUILD.tpl
new file mode 100644
index 0000000000..8962751f56
--- /dev/null
+++ b/third_party/tensorrt/BUILD.tpl
@@ -0,0 +1,42 @@
+# -*- python -*-
+# Description:
+#   Provides the TensorRT library and headers to the TensorFlow build.
+
+# TODO(Sami): these need to be defined.
+
+licenses(["notice"])
+
+exports_files(["LICENSE"])
+
+load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "if_cuda")
+
+config_setting(
+ name = "trt_enabled",
+ define_values = {
+ "using_tensorrt":"true"
+ },
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "tensorrt",
+ srcs =[%{tensorrt_lib}],
+ hdrs = ["include/NvInfer.h",
+ "include/NvUtils.h",
+ ],
+ copts= cuda_default_copts(),
+ deps =["@local_config_cuda//cuda:cuda",
+ "@local_config_cuda//cuda:cudnn",],
+ linkstatic = 1,
+ #include_prefix="include/",
+ includes=["include/"],
+ visibility = ["//visibility:public"],
+)
+
+%{tensorrt_genrules}
+
+# filegroup(
+# name = "%{tensorrt_lib}",
+# srcs = ["%{tensorrt_lib}"],
+# visibility = ["//visibility:public"],
+# )
diff --git a/third_party/tensorrt/LICENSE b/third_party/tensorrt/LICENSE
new file mode 100644
index 0000000000..d3da228420
--- /dev/null
+++ b/third_party/tensorrt/LICENSE
@@ -0,0 +1,203 @@
+Copyright 2015 The TensorFlow Authors. All rights reserved.
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2015, The TensorFlow Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/third_party/tensorrt/build_defs.bzl b/third_party/tensorrt/build_defs.bzl
new file mode 100644
index 0000000000..392c5e0621
--- /dev/null
+++ b/third_party/tensorrt/build_defs.bzl
@@ -0,0 +1,85 @@
+# -*- python -*-
+"""
+ add a repo_generator rule for tensorrt
+
+"""
+
+_TENSORRT_INSTALLATION_PATH="TENSORRT_INSTALL_PATH"
+_TF_TENSORRT_VERSION="TF_TENSORRT_VERSION"
+
+def _is_trt_enabled(repo_ctx):
+ if "TF_NEED_TENSORRT" in repo_ctx.os.environ:
+ enable_trt = repo_ctx.os.environ["TF_NEED_TENSORRT"].strip()
+ return enable_trt == "1"
+ return False
+
+def _dummy_repo(repo_ctx):
+
+ repo_ctx.template("BUILD",Label("//third_party/tensorrt:BUILD.tpl"),
+ {"%{tensorrt_lib}":"","%{tensorrt_genrules}":""},
+ False)
+ repo_ctx.template("build_defs.bzl",Label("//third_party/tensorrt:build_defs.bzl.tpl"),
+ {"%{trt_configured}":"False"},False)
+ repo_ctx.file("include/NvUtils.h","",False)
+ repo_ctx.file("include/NvInfer.h","",False)
+
+def _trt_repo_impl(repo_ctx):
+ """
+ Implements local_config_tensorrt
+ """
+
+ if not _is_trt_enabled(repo_ctx):
+ _dummy_repo(repo_ctx)
+ return
+ trt_libdir=repo_ctx.os.environ[_TENSORRT_INSTALLATION_PATH]
+ trt_ver=repo_ctx.os.environ[_TF_TENSORRT_VERSION]
+# Special-case the deb installation layout. Once installation is
+# standardized between the tar and deb packages, this won't be needed.
+ if trt_libdir == '/usr/lib/x86_64-linux-gnu':
+ incPath='/usr/include/x86_64-linux-gnu'
+ incname='/usr/include/x86_64-linux-gnu/NvInfer.h'
+ else:
+ incPath=str(repo_ctx.path("%s/../include"%trt_libdir).realpath)
+ incname=incPath+'/NvInfer.h'
+ if len(trt_ver)>0:
+ origLib="%s/libnvinfer.so.%s"%(trt_libdir,trt_ver)
+ else:
+ origLib="%s/libnvinfer.so"%trt_libdir
+ objdump=repo_ctx.which("objdump")
+ if objdump == None:
+ if len(trt_ver)>0:
+ targetlib="lib/libnvinfer.so.%s"%(trt_ver[0])
+ else:
+ targetlib="lib/libnvinfer.so"
+ else:
+ soname=repo_ctx.execute([objdump,"-p",origLib])
+ for l in soname.stdout.splitlines():
+ if "SONAME" in l:
+ lib=l.strip().split(" ")[-1]
+ targetlib="lib/%s"%(lib)
+
+  repo_ctx.symlink(origLib,targetlib)
+ grule=('genrule(\n name = "trtlinks",\n'+
+ ' outs = [\n "%s",\n "include/NvInfer.h",\n "include/NvUtils.h",\n ],\n'%targetlib +
+ ' cmd="""ln -sf %s $(@D)/%s '%(origLib,targetlib) +
+ '&&\n ln -sf %s $(@D)/include/NvInfer.h '%(incname) +
+ '&&\n ln -sf %s/NvUtils.h $(@D)/include/NvUtils.h""",\n)\n'%(incPath))
+ repo_ctx.template("BUILD",Label("//third_party/tensorrt:BUILD.tpl"),
+ {"%{tensorrt_lib}":'"%s"'%targetlib,"%{tensorrt_genrules}":grule},
+ False)
+ repo_ctx.template("build_defs.bzl",Label("//third_party/tensorrt:build_defs.bzl.tpl"),
+ {"%{trt_configured}":"True"},False)
+
+trt_repository=repository_rule(
+ implementation= _trt_repo_impl,
+ local=True,
+ environ=[
+ "TF_NEED_TENSORRT",
+ _TF_TENSORRT_VERSION,
+ _TENSORRT_INSTALLATION_PATH,
+ ],
+ )
diff --git a/third_party/tensorrt/build_defs.bzl.tpl b/third_party/tensorrt/build_defs.bzl.tpl
new file mode 100644
index 0000000000..18f354ee5a
--- /dev/null
+++ b/third_party/tensorrt/build_defs.bzl.tpl
@@ -0,0 +1,18 @@
+# -*- python -*-
+"""
+template file for trt functions
+
+"""
+
+def is_trt_enabled():
+ return %{trt_configured}
+
+def if_trt(if_true,if_false=[]):
+ # if is_trt_enabled():
+ # return if_true
+ # return if_false
+
+ return select({
+ "@local_config_tensorrt//:trt_enabled":if_true,
+ "//conditions:default":if_false,
+ })
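
Finally, as a sketch of how a BUILD file is expected to consume the generated helper (the library target and its sources are hypothetical; only if_trt, init_py, and the trt_enabled setting come from this change):

    load("@local_config_tensorrt//:build_defs.bzl", "if_trt")

    py_library(
        name = "my_inference_lib",   # hypothetical target
        srcs = ["inference.py"],
        deps = ["//tensorflow:tensorflow_py"] + if_trt(
            ["//tensorflow/contrib/tensorrt:init_py"],   # only when trt_enabled is set
        ),
    )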