Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--  tensorflow/contrib/BUILD | 5
-rw-r--r--  tensorflow/contrib/android/README.md | 5
-rw-r--r--  tensorflow/contrib/boosted_trees/python/utils/losses.py | 4
-rw-r--r--  tensorflow/contrib/cmake/CMakeLists.txt | 22
-rw-r--r--  tensorflow/contrib/cmake/external/zlib.cmake | 108
-rw-r--r--  tensorflow/contrib/cmake/python_modules.txt | 3
-rw-r--r--  tensorflow/contrib/crf/python/ops/crf.py | 12
-rw-r--r--  tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py | 2
-rw-r--r--  tensorflow/contrib/eager/python/g3doc/guide.md | 9
-rw-r--r--  tensorflow/contrib/factorization/python/ops/clustering_ops.py | 4
-rw-r--r--  tensorflow/contrib/framework/__init__.py | 6
-rw-r--r--  tensorflow/contrib/framework/python/framework/graph_util.py | 12
-rw-r--r--  tensorflow/contrib/image/kernels/bipartite_match_op.cc | 2
-rw-r--r--  tensorflow/contrib/layers/python/layers/layers.py | 12
-rw-r--r--  tensorflow/contrib/lite/README.md | 17
-rw-r--r--  tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_imagenet_slim.txt | 1001
-rw-r--r--  tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt | 1001
-rw-r--r--  tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java | 6
-rw-r--r--  tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java | 137
-rw-r--r--  tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java | 103
-rw-r--r--  tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java | 94
-rw-r--r--  tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py | 48
-rw-r--r--  tensorflow/contrib/makefile/README.md | 99
-rwxr-xr-x  tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh | 6
-rw-r--r--  tensorflow/contrib/rnn/python/ops/lstm_ops.py | 5
-rw-r--r--  tensorflow/contrib/rnn/python/ops/rnn_cell.py | 5
-rw-r--r--  tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc | 8
-rw-r--r--  tensorflow/contrib/signal/python/ops/spectral_ops.py | 2
-rw-r--r--  tensorflow/contrib/slim/python/slim/evaluation_test.py | 3
-rw-r--r--  tensorflow/contrib/tensor_forest/BUILD | 1
-rw-r--r--  tensorflow/contrib/tensorrt/BUILD | 204
-rw-r--r--  tensorflow/contrib/tensorrt/README.md | 40
-rw-r--r--  tensorflow/contrib/tensorrt/__init__.py | 23
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_graph.cc | 273
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_graph.h | 47
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.cc | 1601
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.h | 52
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc | 140
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_engine_op.h | 62
-rw-r--r--  tensorflow/contrib/tensorrt/log/trt_logger.cc | 57
-rw-r--r--  tensorflow/contrib/tensorrt/log/trt_logger.h | 42
-rw-r--r--  tensorflow/contrib/tensorrt/ops/trt_engine_op.cc | 43
-rw-r--r--  tensorflow/contrib/tensorrt/python/__init__.py | 24
-rw-r--r--  tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py | 34
-rw-r--r--  tensorflow/contrib/tensorrt/python/trt_convert.py | 103
-rw-r--r--  tensorflow/contrib/tensorrt/segment/segment.cc | 253
-rw-r--r--  tensorflow/contrib/tensorrt/segment/segment.h | 56
-rw-r--r--  tensorflow/contrib/tensorrt/segment/segment_test.cc | 367
-rw-r--r--  tensorflow/contrib/tensorrt/segment/union_find.h | 79
-rw-r--r--  tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc | 89
-rw-r--r--  tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h | 33
-rw-r--r--  tensorflow/contrib/tensorrt/test/test_tftrt.py | 88
-rw-r--r--  tensorflow/contrib/tensorrt/trt_conversion.i | 131
-rw-r--r--  tensorflow/contrib/tpu/profiler/pip_package/setup.py | 2
54 files changed, 6430 insertions, 155 deletions
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 6b3343bb2f..bab37e8906 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -7,6 +7,7 @@ package(default_visibility = ["//tensorflow:__subpackages__"])
load("//third_party/mpi:mpi.bzl", "if_mpi")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt")
py_library(
name = "contrib_py",
@@ -107,7 +108,9 @@ py_library(
"//tensorflow/contrib/training:training_py",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:util",
- ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]),
+ ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_tensorrt([
+ "//tensorflow/contrib/tensorrt:init_py",
+ ]),
)
cc_library(
diff --git a/tensorflow/contrib/android/README.md b/tensorflow/contrib/android/README.md
index b8d73bf24c..db37bcf73d 100644
--- a/tensorflow/contrib/android/README.md
+++ b/tensorflow/contrib/android/README.md
@@ -81,6 +81,11 @@ For documentation on building a self-contained AAR file with cmake, see
[tensorflow/contrib/android/cmake](cmake).
+### Makefile
+
+For documentation on building native TF libraries with make, including a CUDA-enabled variant for devices like the Nvidia Shield TV, see [tensorflow/contrib/makefile/README.md](../makefile/README.md).
+
+
## AssetManagerFileSystem
This directory also contains a TensorFlow filesystem supporting the Android
diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py
index 1e8b3ac08a..ab7ac2aba6 100644
--- a/tensorflow/contrib/boosted_trees/python/utils/losses.py
+++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py
@@ -78,7 +78,7 @@ def per_example_maxent_loss(labels, weights, logits, num_classes, eps=1e-15):
# Calculate softmax probabilities for each class.
unnormalized_probs = math_ops.exp(logits)
- normalizers = math_ops.reduce_sum(unnormalized_probs, 1, keep_dims=True)
+ normalizers = math_ops.reduce_sum(unnormalized_probs, 1, keepdims=True)
softmax_predictions = math_ops.divide(unnormalized_probs,
math_ops.add(normalizers, eps))
@@ -120,7 +120,7 @@ def per_example_squared_loss(labels, weights, predictions):
update_op: An update operation to update the loss's internal state.
"""
unweighted_loss = math_ops.reduce_sum(
- math_ops.square(predictions - labels), 1, keep_dims=True)
+ math_ops.square(predictions - labels), 1, keepdims=True)
return unweighted_loss * weights, control_flow_ops.no_op()
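
Note: these hunks are part of the repo-wide rename of `reduce_sum`'s `keep_dims` keyword to `keepdims` (the same rename recurs in the distributions, factorization, and layers hunks below). A minimal sketch of the argument's effect, assuming TF 1.5+ where the `keepdims` spelling is accepted:

```python
import tensorflow as tf

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
# keepdims=True keeps the reduced axis as size 1, so the result still
# broadcasts cleanly against x: shape (2, 1).
row_sums = tf.reduce_sum(x, axis=1, keepdims=True)
# Without keepdims the axis is dropped: shape (2,).
flat_sums = tf.reduce_sum(x, axis=1)
```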
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index 524946a9a5..23b31ae1dc 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -52,6 +52,7 @@ if (NOT WIN32)
# for targets that link ${CMAKE_THREAD_LIBS_INIT}.
find_package (Threads)
+ # Options for linking CUDA/CUDNN libraries
option(tensorflow_PATH_STATIC_LIB "Additional library search path for libcudnn_static.a, libnccl_static.a, libculibos.a" /usr/local/cuda/lib64/)
option(tensorflow_CUDNN_INCLUDE "cudnn.h header install path" /usr/include/)
if (NOT tensorflow_CUDNN_INCLUDE)
@@ -73,6 +74,14 @@ if (NOT WIN32)
# option's default value is OFF. Fill it with real default values
set(tensorflow_CUDA_LIBRARY_PATH /usr/local/cuda/lib64)
endif (NOT tensorflow_CUDA_LIBRARY_PATH)
+
+ # Options for linking other libraries
+ option(systemlib_ZLIB "Use the system installed library as shared objects instead of downloading ZLIB and statically linking to it: ZLIB" OFF)
+
+ option(systemlib_ALL "Turn on every possible systemlib_* options" OFF)
+ if (systemlib_ALL)
+    set (systemlib_ZLIB ON)
+ endif (systemlib_ALL)
endif()
if (WIN32)
@@ -188,8 +197,10 @@ if (tensorflow_BUILD_CC_TESTS)
include(googletest)
endif()
+add_definitions(${ADD_CFLAGS})
+link_directories(${ADD_LINK_DIRECTORY})
+
set(tensorflow_EXTERNAL_LIBRARIES
- ${zlib_STATIC_LIBRARIES}
${gif_STATIC_LIBRARIES}
${png_STATIC_LIBRARIES}
${jpeg_STATIC_LIBRARIES}
@@ -203,6 +214,15 @@ set(tensorflow_EXTERNAL_LIBRARIES
${re2_STATIC_LIBRARIES}
${sqlite_STATIC_LIBRARIES}
)
+
+if (systemlib_ZLIB)
+ set(tensorflow_EXTERNAL_LIBRARIES ${tensorflow_EXTERNAL_LIBRARIES}
+ ${ZLIB_LIBRARIES})
+else (systemlib_ZLIB)
+ set(tensorflow_EXTERNAL_LIBRARIES ${tensorflow_EXTERNAL_LIBRARIES}
+ ${zlib_STATIC_LIBRARIES})
+endif (systemlib_ZLIB)
+
set(tensorflow_EXTERNAL_DEPENDENCIES
zlib_copy_headers_to_destination
gif_copy_headers_to_destination
diff --git a/tensorflow/contrib/cmake/external/zlib.cmake b/tensorflow/contrib/cmake/external/zlib.cmake
index c5eb0cbcc7..116d423093 100644
--- a/tensorflow/contrib/cmake/external/zlib.cmake
+++ b/tensorflow/contrib/cmake/external/zlib.cmake
@@ -12,61 +12,75 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-include (ExternalProject)
+if (systemlib_ZLIB)
+ find_package(PkgConfig)
+ pkg_search_module(ZLIB REQUIRED zlib)
+ set(zlib_INCLUDE_DIR ${ZLIB_INCLUDE_DIRS})
+ set(ADD_LINK_DIRECTORY ${ADD_LINK_DIRECTORY} ${ZLIB_LIBRARY_DIRS})
+ set(ADD_CFLAGS ${ADD_CFLAGS} ${ZLIB_CFLAGS_OTHER})
-set(zlib_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/zlib_archive)
-set(ZLIB_URL https://github.com/madler/zlib)
-set(ZLIB_BUILD ${CMAKE_CURRENT_BINARY_DIR}/zlib/src/zlib)
-set(ZLIB_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/zlib/install)
-set(ZLIB_TAG 50893291621658f355bc5b4d450a8d06a563053d)
+  # Satisfy DEPENDS on zlib from other projects.
+ # If we hit this line, zlib is already built and installed to the system.
+ add_custom_target(zlib)
+ add_custom_target(zlib_copy_headers_to_destination)
-if(WIN32)
- if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
- set(zlib_STATIC_LIBRARIES
- debug ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstaticd.lib
- optimized ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstatic.lib)
- else()
- if(CMAKE_BUILD_TYPE EQUAL Debug)
+else (systemlib_ZLIB)
+ include (ExternalProject)
+
+ set(zlib_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/zlib_archive)
+ set(ZLIB_URL https://github.com/madler/zlib)
+ set(ZLIB_BUILD ${CMAKE_CURRENT_BINARY_DIR}/zlib/src/zlib)
+ set(ZLIB_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/zlib/install)
+ set(ZLIB_TAG 50893291621658f355bc5b4d450a8d06a563053d)
+
+ if(WIN32)
+ if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
set(zlib_STATIC_LIBRARIES
- ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstaticd.lib)
+ debug ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstaticd.lib
+ optimized ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstatic.lib)
else()
- set(zlib_STATIC_LIBRARIES
- ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstatic.lib)
+ if(CMAKE_BUILD_TYPE EQUAL Debug)
+ set(zlib_STATIC_LIBRARIES
+ ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstaticd.lib)
+ else()
+ set(zlib_STATIC_LIBRARIES
+ ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstatic.lib)
+ endif()
endif()
+ else()
+ set(zlib_STATIC_LIBRARIES
+ ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/libz.a)
endif()
-else()
- set(zlib_STATIC_LIBRARIES
- ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/libz.a)
-endif()
-set(ZLIB_HEADERS
- "${ZLIB_INSTALL}/include/zconf.h"
- "${ZLIB_INSTALL}/include/zlib.h"
-)
+ set(ZLIB_HEADERS
+ "${ZLIB_INSTALL}/include/zconf.h"
+ "${ZLIB_INSTALL}/include/zlib.h"
+ )
-ExternalProject_Add(zlib
- PREFIX zlib
- GIT_REPOSITORY ${ZLIB_URL}
- GIT_TAG ${ZLIB_TAG}
- INSTALL_DIR ${ZLIB_INSTALL}
- BUILD_IN_SOURCE 1
- BUILD_BYPRODUCTS ${zlib_STATIC_LIBRARIES}
- DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
- CMAKE_CACHE_ARGS
- -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE}
- -DCMAKE_BUILD_TYPE:STRING=Release
- -DCMAKE_INSTALL_PREFIX:STRING=${ZLIB_INSTALL}
-)
+ ExternalProject_Add(zlib
+ PREFIX zlib
+ GIT_REPOSITORY ${ZLIB_URL}
+ GIT_TAG ${ZLIB_TAG}
+ INSTALL_DIR ${ZLIB_INSTALL}
+ BUILD_IN_SOURCE 1
+ BUILD_BYPRODUCTS ${zlib_STATIC_LIBRARIES}
+ DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
+ CMAKE_CACHE_ARGS
+ -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE}
+ -DCMAKE_BUILD_TYPE:STRING=Release
+ -DCMAKE_INSTALL_PREFIX:STRING=${ZLIB_INSTALL}
+ )
-# put zlib includes in the directory where they are expected
-add_custom_target(zlib_create_destination_dir
- COMMAND ${CMAKE_COMMAND} -E make_directory ${zlib_INCLUDE_DIR}
- DEPENDS zlib)
+ # put zlib includes in the directory where they are expected
+ add_custom_target(zlib_create_destination_dir
+ COMMAND ${CMAKE_COMMAND} -E make_directory ${zlib_INCLUDE_DIR}
+ DEPENDS zlib)
-add_custom_target(zlib_copy_headers_to_destination
- DEPENDS zlib_create_destination_dir)
+ add_custom_target(zlib_copy_headers_to_destination
+ DEPENDS zlib_create_destination_dir)
-foreach(header_file ${ZLIB_HEADERS})
- add_custom_command(TARGET zlib_copy_headers_to_destination PRE_BUILD
- COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${zlib_INCLUDE_DIR})
-endforeach()
+ foreach(header_file ${ZLIB_HEADERS})
+ add_custom_command(TARGET zlib_copy_headers_to_destination PRE_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${zlib_INCLUDE_DIR})
+ endforeach()
+endif (systemlib_ZLIB)
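
Note: with this change, a system-zlib build is selected at configure time, e.g. `cmake -Dsystemlib_ZLIB=ON` (or `-Dsystemlib_ALL=ON`). The stub `zlib` and `zlib_copy_headers_to_destination` targets keep existing `DEPENDS` edges valid, while pkg-config supplies the include directories, link directories, and extra compile flags.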
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index f55043c93d..bfe53c01b3 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -413,6 +413,9 @@ tensorflow/contrib/tensorboard
tensorflow/contrib/tensorboard/plugins
tensorflow/contrib/tensorboard/plugins/projector
tensorflow/contrib/tensorboard/plugins/trace
+# TODO(sami): Add cmake implementations.
+# tensorflow/contrib/tensorrt/python
+# tensorflow/contrib/tensorrt/python/ops
tensorflow/contrib/tensor_forest
tensorflow/contrib/tensor_forest/client
tensorflow/contrib/tensor_forest/hybrid
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py
index faa78769b9..1233c8f251 100644
--- a/tensorflow/contrib/crf/python/ops/crf.py
+++ b/tensorflow/contrib/crf/python/ops/crf.py
@@ -105,8 +105,8 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths,
return utils.smart_cond(
pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1],
1),
- fn1=_single_seq_fn,
- fn2=_multi_seq_fn)
+ true_fn=_single_seq_fn,
+ false_fn=_multi_seq_fn)
def crf_log_norm(inputs, sequence_lengths, transition_params):
@@ -511,7 +511,7 @@ def crf_decode(potentials, transition_params, sequence_length):
return decode_tags, best_score
return utils.smart_cond(
- pred=math_ops.equal(
- potentials.shape[1].value or array_ops.shape(potentials)[1], 1),
- fn1=_single_seq_fn,
- fn2=_multi_seq_fn)
+ pred=math_ops.equal(potentials.shape[1].value or
+ array_ops.shape(potentials)[1], 1),
+ true_fn=_single_seq_fn,
+ false_fn=_multi_seq_fn)
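
Note: the crf.py hunks track the rename of `smart_cond`'s `fn1`/`fn2` keywords to `true_fn`/`false_fn`, matching `tf.cond`. A hedged sketch of the call pattern (shapes are illustrative; `smart_cond` is the symbol this commit also re-exports from `tf.contrib.framework` below):

```python
import tensorflow as tf

# Hypothetical [batch, max_seq_len, num_tags] score tensor.
inputs = tf.placeholder(tf.float32, [None, 1, 5])

def _single_seq_fn():
    # Sequences of length 1 need no dynamic program.
    return tf.squeeze(inputs, axis=1)

def _multi_seq_fn():
    return tf.reduce_logsumexp(inputs, axis=1)

# smart_cond resolves the predicate at graph-construction time when it is
# statically known, so only one branch is actually built.
result = tf.contrib.framework.smart_cond(
    pred=tf.equal(inputs.shape[1].value or tf.shape(inputs)[1], 1),
    true_fn=_single_seq_fn,
    false_fn=_multi_seq_fn)
```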
diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
index b6becfa9fc..2aa771a71e 100644
--- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
+++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
@@ -278,7 +278,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution):
* math_ops.log(self.temperature))
# compute the unnormalized density
log_softmax = nn_ops.log_softmax(logits_2d - x_2d * self._temperature_2d)
- log_unnorm_prob = math_ops.reduce_sum(log_softmax, [-1], keep_dims=False)
+ log_unnorm_prob = math_ops.reduce_sum(log_softmax, [-1], keepdims=False)
# combine unnormalized density with normalization constant
log_prob = log_norm_const + log_unnorm_prob
# Reshapes log_prob to be consistent with shape of user-supplied logits
diff --git a/tensorflow/contrib/eager/python/g3doc/guide.md b/tensorflow/contrib/eager/python/g3doc/guide.md
index 4724aa4aee..ebb05051f2 100644
--- a/tensorflow/contrib/eager/python/g3doc/guide.md
+++ b/tensorflow/contrib/eager/python/g3doc/guide.md
@@ -22,11 +22,10 @@ to models defined without using eager execution.
Eager execution is included in TensorFlow versions 1.5 and above.
Installation instructions at https://www.tensorflow.org/install/
-The contents of this guide are compatible with TensorFlow 1.5.
-However, if you run into bugs that are fixed in source but not the
-release, you may want to either either [building from
-source](https://www.tensorflow.org/install/install_sources)
-or the try latest nightly builds. The nightly builds are available as:
+The contents of this guide are compatible with TensorFlow 1.5. However, if you
+run into bugs that are fixed in source but not the release, you may want to
+either [build from source](https://www.tensorflow.org/install/install_sources)
+or try a nightly build. The nightly builds are available as:
- [`pip` packages](https://github.com/tensorflow/tensorflow/blob/master/README.md#installation) and
diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py
index 6d3acb2750..23137e0a97 100644
--- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py
@@ -192,11 +192,11 @@ class KMeans(object):
# Computes Euclidean distance. Note the first and third terms are
# broadcast additions.
squared_distance = (
- math_ops.reduce_sum(math_ops.square(inp), 1, keep_dims=True) -
+ math_ops.reduce_sum(math_ops.square(inp), 1, keepdims=True) -
2 * math_ops.matmul(inp, clusters, transpose_b=True) +
array_ops.transpose(
math_ops.reduce_sum(
- math_ops.square(clusters), 1, keep_dims=True)))
+ math_ops.square(clusters), 1, keepdims=True)))
output.append(squared_distance)
return output
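
Note: the `squared_distance` expression above uses the expansion ‖x − c‖² = ‖x‖² − 2·xᵀc + ‖c‖², which turns the full pairwise distance matrix into one matmul plus two broadcast additions. A small NumPy sketch of the same identity (shapes are illustrative):

```python
import numpy as np

inp = np.random.rand(4, 3)       # 4 points in 3 dimensions
clusters = np.random.rand(2, 3)  # 2 cluster centers

# ||x||^2 broadcasts down columns, ||c||^2 broadcasts across rows.
sq_dist = (np.sum(inp ** 2, axis=1, keepdims=True)
           - 2.0 * inp @ clusters.T
           + np.sum(clusters ** 2, axis=1, keepdims=True).T)

# Agrees with the naive elementwise computation.
naive = ((inp[:, None, :] - clusters[None, :, :]) ** 2).sum(axis=-1)
assert np.allclose(sq_dist, naive)
```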
diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py
index fb101c3653..deeb5bec79 100644
--- a/tensorflow/contrib/framework/__init__.py
+++ b/tensorflow/contrib/framework/__init__.py
@@ -85,6 +85,8 @@ See the @{$python/contrib.framework} guide.
@@py_func
@@sort
+@@get_placeholders
+
@@CriticalSection
@@BoundedTensorSpec
@@ -102,10 +104,10 @@ from tensorflow.contrib.framework.python.ops import *
from tensorflow.python.framework.ops import prepend_name_scope
from tensorflow.python.framework.ops import strip_name_scope
-
from tensorflow.python.framework.tensor_spec import BoundedTensorSpec
from tensorflow.python.framework.tensor_spec import TensorSpec
-
+from tensorflow.python.ops.control_flow_ops import smart_cond
+from tensorflow.python.ops.control_flow_ops import smart_constant_value
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = ['nest']
diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py
index a18ff2320d..49eec3a3f1 100644
--- a/tensorflow/contrib/framework/python/framework/graph_util.py
+++ b/tensorflow/contrib/framework/python/framework/graph_util.py
@@ -133,6 +133,18 @@ def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
def get_placeholders(graph):
"""Get placeholders of a graph.
+ For example:
+
+ ```python
+ a = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='a')
+  b = tf.placeholder(dtype=tf.int32, shape=[3, 2], name='b')
+
+ tf.contrib.framework.get_placeholders(tf.get_default_graph())
+ # Returns:
+ # [<tf.Tensor 'a:0' shape=(2, 2) dtype=float32>,
+ # <tf.Tensor 'b:0' shape=(3, 2) dtype=int32>]
+ ```
+
Args:
graph: A tf.Graph.
Returns:
diff --git a/tensorflow/contrib/image/kernels/bipartite_match_op.cc b/tensorflow/contrib/image/kernels/bipartite_match_op.cc
index 7d207c388b..726adb0777 100644
--- a/tensorflow/contrib/image/kernels/bipartite_match_op.cc
+++ b/tensorflow/contrib/image/kernels/bipartite_match_op.cc
@@ -85,7 +85,7 @@ class BipartiteMatchOp : public OpKernel {
context->allocate_output(1, TensorShape({num_input_columns}),
&column_to_row_match_indices));
- typename TTypes<float, 2>::ConstTensor distance_mat =
+ TTypes<float, 2>::ConstTensor distance_mat =
input_distance_mat.shaped<float, 2>(
{num_input_rows, num_input_columns});
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 45ddfbfc9f..b2ea75c7e1 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -517,8 +517,8 @@ def batch_norm(inputs,
then the batch normalization uses weighted mean and
variance. (This can be used to correct for bias in training
example selection.)
- fused: if `True`, use a faster, fused implementation if possible.
- If `None`, use the system recommended implementation.
+ fused: if `None` or `True`, use a faster, fused implementation if possible.
+ If `False`, use the system recommended implementation.
data_format: A string. `NHWC` (default) and `NCHW` are supported.
zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new
pair of variables 'moving_mean/biased' and 'moving_mean/local_step'.
@@ -778,7 +778,7 @@ def batch_norm(inputs,
else:
if data_format == DATA_FORMAT_NCHW:
mean, variance = nn.weighted_moments(
- inputs, moments_axes, batch_weights, keep_dims=True)
+ inputs, moments_axes, batch_weights, keepdims=True)
mean = array_ops.reshape(mean, [-1])
variance = array_ops.reshape(variance, [-1])
else:
@@ -2836,9 +2836,9 @@ def spatial_softmax(features,
softmax_attention = nn.softmax(features / temperature)
expected_x = math_ops.reduce_sum(
- pos_x * softmax_attention, [1], keep_dims=True)
+ pos_x * softmax_attention, [1], keepdims=True)
expected_y = math_ops.reduce_sum(
- pos_y * softmax_attention, [1], keep_dims=True)
+ pos_y * softmax_attention, [1], keepdims=True)
expected_xy = array_ops.concat([expected_x, expected_y], 1)
feature_keypoints = array_ops.reshape(expected_xy,
[-1, num_channels.value * 2])
@@ -3018,7 +3018,7 @@ def poincare_normalize(x, axis=1, epsilon=1e-5, name=None):
"""
with ops.name_scope(name, 'poincare_normalize', [x]) as name:
x = ops.convert_to_tensor(x, name='x')
- square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keep_dims=True)
+ square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True)
x_inv_norm = math_ops.rsqrt(square_sum)
x_inv_norm = math_ops.minimum((1. - epsilon) * x_inv_norm, 1.)
return math_ops.multiply(x, x_inv_norm, name=name)
diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md
index 3e55d2a496..00e93d2c4f 100644
--- a/tensorflow/contrib/lite/README.md
+++ b/tensorflow/contrib/lite/README.md
@@ -6,7 +6,7 @@ TensorFlow Lite uses many techniques for achieving low latency like optimizing t
![image](g3doc/TFLite-Architecture.jpg)
# Getting Started with an Android Demo App
-This section contains an example application using TensorFlow Lite for Android devices. The demo is a sample camera app that classifies images continuously using a quantized Mobilenet model. A device running Android 5.0 ( API 21) or higher is required to run the demo.
+This section contains an example application using TensorFlow Lite for Android devices. The demo is a sample camera app that classifies images continuously using either a quantized Mobilenet model or a floating point Inception-v3 model. A device running Android 5.0 (API 21) or higher is required to run the demo.
There are 3 ways to get the demo app to your device
- Download the prebuilt binary or
@@ -29,9 +29,16 @@ The simplest way to compile the demo app, and try out changes to the project cod
- Make sure the Android SDK version is greater than 26 and NDK version is greater than 14 (in the Android Studio Settings).
- Import the `tensorflow/contrib/lite/java/demo` directory as a new Android Studio project.
- Click through installing all the Gradle extensions it requests.
- - Download the quantized Mobilenet TensorFlow Lite model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip)
- - unzip and copy mobilenet_quant_v1_224.tflite to the assets directory:
- `tensorflow/contrib/lite/java/demo/app/src/main/assets/`
+ - Either
+ - Download the quantized Mobilenet TensorFlow Lite model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip)
+ - unzip and copy mobilenet_quant_v1_224.tflite to the assets directory:
+ `tensorflow/contrib/lite/java/demo/app/src/main/assets/`
+ - Or download the floating point Inception-v3 model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip)
+ - unzip and copy inceptionv3_non_slim_2015.tflite to the assets directory
+ - change the chosen classifier in [Camera2BasicFragment.java](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java) from
+ `classifier = new ImageClassifierQuantizedMobileNet(getActivity());`
+ to
+ `classifier = new ImageClassifierFloatInception(getActivity());`
- Build and run the demo app
## Building TensorFlow Lite and the demo app from source
@@ -84,7 +91,7 @@ Currently, we only support building the Android demo app within a Python 2
environment (due to a Bazel bug).
### More about the demo
-The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used. The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch 224 * 224 is the width and height of the image 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. The Mobilenet model has 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The Mobilenet quantized model is bundled within the assets directory of the app.
+The demo resizes each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used (299 * 299 for Inception-v3). The resized image is converted row by row into a ByteBuffer of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch, 224 * 224 (299 * 299) is the width and height of the image, and 3 bytes represent the three color channels of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. Both models have 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The model file must be downloaded and bundled within the assets directory of the app.
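
Note: the input-buffer arithmetic in that paragraph generalizes to both bundled models; a hedged sketch of the size computation (the function name is illustrative, mirroring the demo's `getNumBytesPerChannel()` abstraction):

```python
def input_buffer_bytes(batch, height, width, channels, bytes_per_channel):
    """Size of the direct ByteBuffer handed to the TFLite interpreter."""
    return batch * height * width * channels * bytes_per_channel

input_buffer_bytes(1, 224, 224, 3, 1)  # quantized MobileNet: 150,528 bytes
input_buffer_bytes(1, 299, 299, 3, 4)  # float Inception-v3: 1,072,812 bytes
```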
# iOS Demo App
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_imagenet_slim.txt b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_imagenet_slim.txt
new file mode 100644
index 0000000000..572eccf900
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_imagenet_slim.txt
@@ -0,0 +1,1001 @@
+dummy
+tench
+goldfish
+great white shark
+tiger shark
+hammerhead
+electric ray
+stingray
+cock
+hen
+ostrich
+brambling
+goldfinch
+house finch
+junco
+indigo bunting
+robin
+bulbul
+jay
+magpie
+chickadee
+water ouzel
+kite
+bald eagle
+vulture
+great grey owl
+European fire salamander
+common newt
+eft
+spotted salamander
+axolotl
+bullfrog
+tree frog
+tailed frog
+loggerhead
+leatherback turtle
+mud turtle
+terrapin
+box turtle
+banded gecko
+common iguana
+American chameleon
+whiptail
+agama
+frilled lizard
+alligator lizard
+Gila monster
+green lizard
+African chameleon
+Komodo dragon
+African crocodile
+American alligator
+triceratops
+thunder snake
+ringneck snake
+hognose snake
+green snake
+king snake
+garter snake
+water snake
+vine snake
+night snake
+boa constrictor
+rock python
+Indian cobra
+green mamba
+sea snake
+horned viper
+diamondback
+sidewinder
+trilobite
+harvestman
+scorpion
+black and gold garden spider
+barn spider
+garden spider
+black widow
+tarantula
+wolf spider
+tick
+centipede
+black grouse
+ptarmigan
+ruffed grouse
+prairie chicken
+peacock
+quail
+partridge
+African grey
+macaw
+sulphur-crested cockatoo
+lorikeet
+coucal
+bee eater
+hornbill
+hummingbird
+jacamar
+toucan
+drake
+red-breasted merganser
+goose
+black swan
+tusker
+echidna
+platypus
+wallaby
+koala
+wombat
+jellyfish
+sea anemone
+brain coral
+flatworm
+nematode
+conch
+snail
+slug
+sea slug
+chiton
+chambered nautilus
+Dungeness crab
+rock crab
+fiddler crab
+king crab
+American lobster
+spiny lobster
+crayfish
+hermit crab
+isopod
+white stork
+black stork
+spoonbill
+flamingo
+little blue heron
+American egret
+bittern
+crane
+limpkin
+European gallinule
+American coot
+bustard
+ruddy turnstone
+red-backed sandpiper
+redshank
+dowitcher
+oystercatcher
+pelican
+king penguin
+albatross
+grey whale
+killer whale
+dugong
+sea lion
+Chihuahua
+Japanese spaniel
+Maltese dog
+Pekinese
+Shih-Tzu
+Blenheim spaniel
+papillon
+toy terrier
+Rhodesian ridgeback
+Afghan hound
+basset
+beagle
+bloodhound
+bluetick
+black-and-tan coonhound
+Walker hound
+English foxhound
+redbone
+borzoi
+Irish wolfhound
+Italian greyhound
+whippet
+Ibizan hound
+Norwegian elkhound
+otterhound
+Saluki
+Scottish deerhound
+Weimaraner
+Staffordshire bullterrier
+American Staffordshire terrier
+Bedlington terrier
+Border terrier
+Kerry blue terrier
+Irish terrier
+Norfolk terrier
+Norwich terrier
+Yorkshire terrier
+wire-haired fox terrier
+Lakeland terrier
+Sealyham terrier
+Airedale
+cairn
+Australian terrier
+Dandie Dinmont
+Boston bull
+miniature schnauzer
+giant schnauzer
+standard schnauzer
+Scotch terrier
+Tibetan terrier
+silky terrier
+soft-coated wheaten terrier
+West Highland white terrier
+Lhasa
+flat-coated retriever
+curly-coated retriever
+golden retriever
+Labrador retriever
+Chesapeake Bay retriever
+German short-haired pointer
+vizsla
+English setter
+Irish setter
+Gordon setter
+Brittany spaniel
+clumber
+English springer
+Welsh springer spaniel
+cocker spaniel
+Sussex spaniel
+Irish water spaniel
+kuvasz
+schipperke
+groenendael
+malinois
+briard
+kelpie
+komondor
+Old English sheepdog
+Shetland sheepdog
+collie
+Border collie
+Bouvier des Flandres
+Rottweiler
+German shepherd
+Doberman
+miniature pinscher
+Greater Swiss Mountain dog
+Bernese mountain dog
+Appenzeller
+EntleBucher
+boxer
+bull mastiff
+Tibetan mastiff
+French bulldog
+Great Dane
+Saint Bernard
+Eskimo dog
+malamute
+Siberian husky
+dalmatian
+affenpinscher
+basenji
+pug
+Leonberg
+Newfoundland
+Great Pyrenees
+Samoyed
+Pomeranian
+chow
+keeshond
+Brabancon griffon
+Pembroke
+Cardigan
+toy poodle
+miniature poodle
+standard poodle
+Mexican hairless
+timber wolf
+white wolf
+red wolf
+coyote
+dingo
+dhole
+African hunting dog
+hyena
+red fox
+kit fox
+Arctic fox
+grey fox
+tabby
+tiger cat
+Persian cat
+Siamese cat
+Egyptian cat
+cougar
+lynx
+leopard
+snow leopard
+jaguar
+lion
+tiger
+cheetah
+brown bear
+American black bear
+ice bear
+sloth bear
+mongoose
+meerkat
+tiger beetle
+ladybug
+ground beetle
+long-horned beetle
+leaf beetle
+dung beetle
+rhinoceros beetle
+weevil
+fly
+bee
+ant
+grasshopper
+cricket
+walking stick
+cockroach
+mantis
+cicada
+leafhopper
+lacewing
+dragonfly
+damselfly
+admiral
+ringlet
+monarch
+cabbage butterfly
+sulphur butterfly
+lycaenid
+starfish
+sea urchin
+sea cucumber
+wood rabbit
+hare
+Angora
+hamster
+porcupine
+fox squirrel
+marmot
+beaver
+guinea pig
+sorrel
+zebra
+hog
+wild boar
+warthog
+hippopotamus
+ox
+water buffalo
+bison
+ram
+bighorn
+ibex
+hartebeest
+impala
+gazelle
+Arabian camel
+llama
+weasel
+mink
+polecat
+black-footed ferret
+otter
+skunk
+badger
+armadillo
+three-toed sloth
+orangutan
+gorilla
+chimpanzee
+gibbon
+siamang
+guenon
+patas
+baboon
+macaque
+langur
+colobus
+proboscis monkey
+marmoset
+capuchin
+howler monkey
+titi
+spider monkey
+squirrel monkey
+Madagascar cat
+indri
+Indian elephant
+African elephant
+lesser panda
+giant panda
+barracouta
+eel
+coho
+rock beauty
+anemone fish
+sturgeon
+gar
+lionfish
+puffer
+abacus
+abaya
+academic gown
+accordion
+acoustic guitar
+aircraft carrier
+airliner
+airship
+altar
+ambulance
+amphibian
+analog clock
+apiary
+apron
+ashcan
+assault rifle
+backpack
+bakery
+balance beam
+balloon
+ballpoint
+Band Aid
+banjo
+bannister
+barbell
+barber chair
+barbershop
+barn
+barometer
+barrel
+barrow
+baseball
+basketball
+bassinet
+bassoon
+bathing cap
+bath towel
+bathtub
+beach wagon
+beacon
+beaker
+bearskin
+beer bottle
+beer glass
+bell cote
+bib
+bicycle-built-for-two
+bikini
+binder
+binoculars
+birdhouse
+boathouse
+bobsled
+bolo tie
+bonnet
+bookcase
+bookshop
+bottlecap
+bow
+bow tie
+brass
+brassiere
+breakwater
+breastplate
+broom
+bucket
+buckle
+bulletproof vest
+bullet train
+butcher shop
+cab
+caldron
+candle
+cannon
+canoe
+can opener
+cardigan
+car mirror
+carousel
+carpenter's kit
+carton
+car wheel
+cash machine
+cassette
+cassette player
+castle
+catamaran
+CD player
+cello
+cellular telephone
+chain
+chainlink fence
+chain mail
+chain saw
+chest
+chiffonier
+chime
+china cabinet
+Christmas stocking
+church
+cinema
+cleaver
+cliff dwelling
+cloak
+clog
+cocktail shaker
+coffee mug
+coffeepot
+coil
+combination lock
+computer keyboard
+confectionery
+container ship
+convertible
+corkscrew
+cornet
+cowboy boot
+cowboy hat
+cradle
+crane
+crash helmet
+crate
+crib
+Crock Pot
+croquet ball
+crutch
+cuirass
+dam
+desk
+desktop computer
+dial telephone
+diaper
+digital clock
+digital watch
+dining table
+dishrag
+dishwasher
+disk brake
+dock
+dogsled
+dome
+doormat
+drilling platform
+drum
+drumstick
+dumbbell
+Dutch oven
+electric fan
+electric guitar
+electric locomotive
+entertainment center
+envelope
+espresso maker
+face powder
+feather boa
+file
+fireboat
+fire engine
+fire screen
+flagpole
+flute
+folding chair
+football helmet
+forklift
+fountain
+fountain pen
+four-poster
+freight car
+French horn
+frying pan
+fur coat
+garbage truck
+gasmask
+gas pump
+goblet
+go-kart
+golf ball
+golfcart
+gondola
+gong
+gown
+grand piano
+greenhouse
+grille
+grocery store
+guillotine
+hair slide
+hair spray
+half track
+hammer
+hamper
+hand blower
+hand-held computer
+handkerchief
+hard disc
+harmonica
+harp
+harvester
+hatchet
+holster
+home theater
+honeycomb
+hook
+hoopskirt
+horizontal bar
+horse cart
+hourglass
+iPod
+iron
+jack-o'-lantern
+jean
+jeep
+jersey
+jigsaw puzzle
+jinrikisha
+joystick
+kimono
+knee pad
+knot
+lab coat
+ladle
+lampshade
+laptop
+lawn mower
+lens cap
+letter opener
+library
+lifeboat
+lighter
+limousine
+liner
+lipstick
+Loafer
+lotion
+loudspeaker
+loupe
+lumbermill
+magnetic compass
+mailbag
+mailbox
+maillot
+maillot
+manhole cover
+maraca
+marimba
+mask
+matchstick
+maypole
+maze
+measuring cup
+medicine chest
+megalith
+microphone
+microwave
+military uniform
+milk can
+minibus
+miniskirt
+minivan
+missile
+mitten
+mixing bowl
+mobile home
+Model T
+modem
+monastery
+monitor
+moped
+mortar
+mortarboard
+mosque
+mosquito net
+motor scooter
+mountain bike
+mountain tent
+mouse
+mousetrap
+moving van
+muzzle
+nail
+neck brace
+necklace
+nipple
+notebook
+obelisk
+oboe
+ocarina
+odometer
+oil filter
+organ
+oscilloscope
+overskirt
+oxcart
+oxygen mask
+packet
+paddle
+paddlewheel
+padlock
+paintbrush
+pajama
+palace
+panpipe
+paper towel
+parachute
+parallel bars
+park bench
+parking meter
+passenger car
+patio
+pay-phone
+pedestal
+pencil box
+pencil sharpener
+perfume
+Petri dish
+photocopier
+pick
+pickelhaube
+picket fence
+pickup
+pier
+piggy bank
+pill bottle
+pillow
+ping-pong ball
+pinwheel
+pirate
+pitcher
+plane
+planetarium
+plastic bag
+plate rack
+plow
+plunger
+Polaroid camera
+pole
+police van
+poncho
+pool table
+pop bottle
+pot
+potter's wheel
+power drill
+prayer rug
+printer
+prison
+projectile
+projector
+puck
+punching bag
+purse
+quill
+quilt
+racer
+racket
+radiator
+radio
+radio telescope
+rain barrel
+recreational vehicle
+reel
+reflex camera
+refrigerator
+remote control
+restaurant
+revolver
+rifle
+rocking chair
+rotisserie
+rubber eraser
+rugby ball
+rule
+running shoe
+safe
+safety pin
+saltshaker
+sandal
+sarong
+sax
+scabbard
+scale
+school bus
+schooner
+scoreboard
+screen
+screw
+screwdriver
+seat belt
+sewing machine
+shield
+shoe shop
+shoji
+shopping basket
+shopping cart
+shovel
+shower cap
+shower curtain
+ski
+ski mask
+sleeping bag
+slide rule
+sliding door
+slot
+snorkel
+snowmobile
+snowplow
+soap dispenser
+soccer ball
+sock
+solar dish
+sombrero
+soup bowl
+space bar
+space heater
+space shuttle
+spatula
+speedboat
+spider web
+spindle
+sports car
+spotlight
+stage
+steam locomotive
+steel arch bridge
+steel drum
+stethoscope
+stole
+stone wall
+stopwatch
+stove
+strainer
+streetcar
+stretcher
+studio couch
+stupa
+submarine
+suit
+sundial
+sunglass
+sunglasses
+sunscreen
+suspension bridge
+swab
+sweatshirt
+swimming trunks
+swing
+switch
+syringe
+table lamp
+tank
+tape player
+teapot
+teddy
+television
+tennis ball
+thatch
+theater curtain
+thimble
+thresher
+throne
+tile roof
+toaster
+tobacco shop
+toilet seat
+torch
+totem pole
+tow truck
+toyshop
+tractor
+trailer truck
+tray
+trench coat
+tricycle
+trimaran
+tripod
+triumphal arch
+trolleybus
+trombone
+tub
+turnstile
+typewriter keyboard
+umbrella
+unicycle
+upright
+vacuum
+vase
+vault
+velvet
+vending machine
+vestment
+viaduct
+violin
+volleyball
+waffle iron
+wall clock
+wallet
+wardrobe
+warplane
+washbasin
+washer
+water bottle
+water jug
+water tower
+whiskey jug
+whistle
+wig
+window screen
+window shade
+Windsor tie
+wine bottle
+wing
+wok
+wooden spoon
+wool
+worm fence
+wreck
+yawl
+yurt
+web site
+comic book
+crossword puzzle
+street sign
+traffic light
+book jacket
+menu
+plate
+guacamole
+consomme
+hot pot
+trifle
+ice cream
+ice lolly
+French loaf
+bagel
+pretzel
+cheeseburger
+hotdog
+mashed potato
+head cabbage
+broccoli
+cauliflower
+zucchini
+spaghetti squash
+acorn squash
+butternut squash
+cucumber
+artichoke
+bell pepper
+cardoon
+mushroom
+Granny Smith
+strawberry
+orange
+lemon
+fig
+pineapple
+banana
+jackfruit
+custard apple
+pomegranate
+hay
+carbonara
+chocolate sauce
+dough
+meat loaf
+pizza
+potpie
+burrito
+red wine
+espresso
+cup
+eggnog
+alp
+bubble
+cliff
+coral reef
+geyser
+lakeside
+promontory
+sandbar
+seashore
+valley
+volcano
+ballplayer
+groom
+scuba diver
+rapeseed
+daisy
+yellow lady's slipper
+corn
+acorn
+hip
+buckeye
+coral fungus
+agaric
+gyromitra
+stinkhorn
+earthstar
+hen-of-the-woods
+bolete
+ear
+toilet tissue
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt
new file mode 100644
index 0000000000..fe811239d8
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt
@@ -0,0 +1,1001 @@
+background
+tench
+goldfish
+great white shark
+tiger shark
+hammerhead
+electric ray
+stingray
+cock
+hen
+ostrich
+brambling
+goldfinch
+house finch
+junco
+indigo bunting
+robin
+bulbul
+jay
+magpie
+chickadee
+water ouzel
+kite
+bald eagle
+vulture
+great grey owl
+European fire salamander
+common newt
+eft
+spotted salamander
+axolotl
+bullfrog
+tree frog
+tailed frog
+loggerhead
+leatherback turtle
+mud turtle
+terrapin
+box turtle
+banded gecko
+common iguana
+American chameleon
+whiptail
+agama
+frilled lizard
+alligator lizard
+Gila monster
+green lizard
+African chameleon
+Komodo dragon
+African crocodile
+American alligator
+triceratops
+thunder snake
+ringneck snake
+hognose snake
+green snake
+king snake
+garter snake
+water snake
+vine snake
+night snake
+boa constrictor
+rock python
+Indian cobra
+green mamba
+sea snake
+horned viper
+diamondback
+sidewinder
+trilobite
+harvestman
+scorpion
+black and gold garden spider
+barn spider
+garden spider
+black widow
+tarantula
+wolf spider
+tick
+centipede
+black grouse
+ptarmigan
+ruffed grouse
+prairie chicken
+peacock
+quail
+partridge
+African grey
+macaw
+sulphur-crested cockatoo
+lorikeet
+coucal
+bee eater
+hornbill
+hummingbird
+jacamar
+toucan
+drake
+red-breasted merganser
+goose
+black swan
+tusker
+echidna
+platypus
+wallaby
+koala
+wombat
+jellyfish
+sea anemone
+brain coral
+flatworm
+nematode
+conch
+snail
+slug
+sea slug
+chiton
+chambered nautilus
+Dungeness crab
+rock crab
+fiddler crab
+king crab
+American lobster
+spiny lobster
+crayfish
+hermit crab
+isopod
+white stork
+black stork
+spoonbill
+flamingo
+little blue heron
+American egret
+bittern
+crane
+limpkin
+European gallinule
+American coot
+bustard
+ruddy turnstone
+red-backed sandpiper
+redshank
+dowitcher
+oystercatcher
+pelican
+king penguin
+albatross
+grey whale
+killer whale
+dugong
+sea lion
+Chihuahua
+Japanese spaniel
+Maltese dog
+Pekinese
+Shih-Tzu
+Blenheim spaniel
+papillon
+toy terrier
+Rhodesian ridgeback
+Afghan hound
+basset
+beagle
+bloodhound
+bluetick
+black-and-tan coonhound
+Walker hound
+English foxhound
+redbone
+borzoi
+Irish wolfhound
+Italian greyhound
+whippet
+Ibizan hound
+Norwegian elkhound
+otterhound
+Saluki
+Scottish deerhound
+Weimaraner
+Staffordshire bullterrier
+American Staffordshire terrier
+Bedlington terrier
+Border terrier
+Kerry blue terrier
+Irish terrier
+Norfolk terrier
+Norwich terrier
+Yorkshire terrier
+wire-haired fox terrier
+Lakeland terrier
+Sealyham terrier
+Airedale
+cairn
+Australian terrier
+Dandie Dinmont
+Boston bull
+miniature schnauzer
+giant schnauzer
+standard schnauzer
+Scotch terrier
+Tibetan terrier
+silky terrier
+soft-coated wheaten terrier
+West Highland white terrier
+Lhasa
+flat-coated retriever
+curly-coated retriever
+golden retriever
+Labrador retriever
+Chesapeake Bay retriever
+German short-haired pointer
+vizsla
+English setter
+Irish setter
+Gordon setter
+Brittany spaniel
+clumber
+English springer
+Welsh springer spaniel
+cocker spaniel
+Sussex spaniel
+Irish water spaniel
+kuvasz
+schipperke
+groenendael
+malinois
+briard
+kelpie
+komondor
+Old English sheepdog
+Shetland sheepdog
+collie
+Border collie
+Bouvier des Flandres
+Rottweiler
+German shepherd
+Doberman
+miniature pinscher
+Greater Swiss Mountain dog
+Bernese mountain dog
+Appenzeller
+EntleBucher
+boxer
+bull mastiff
+Tibetan mastiff
+French bulldog
+Great Dane
+Saint Bernard
+Eskimo dog
+malamute
+Siberian husky
+dalmatian
+affenpinscher
+basenji
+pug
+Leonberg
+Newfoundland
+Great Pyrenees
+Samoyed
+Pomeranian
+chow
+keeshond
+Brabancon griffon
+Pembroke
+Cardigan
+toy poodle
+miniature poodle
+standard poodle
+Mexican hairless
+timber wolf
+white wolf
+red wolf
+coyote
+dingo
+dhole
+African hunting dog
+hyena
+red fox
+kit fox
+Arctic fox
+grey fox
+tabby
+tiger cat
+Persian cat
+Siamese cat
+Egyptian cat
+cougar
+lynx
+leopard
+snow leopard
+jaguar
+lion
+tiger
+cheetah
+brown bear
+American black bear
+ice bear
+sloth bear
+mongoose
+meerkat
+tiger beetle
+ladybug
+ground beetle
+long-horned beetle
+leaf beetle
+dung beetle
+rhinoceros beetle
+weevil
+fly
+bee
+ant
+grasshopper
+cricket
+walking stick
+cockroach
+mantis
+cicada
+leafhopper
+lacewing
+dragonfly
+damselfly
+admiral
+ringlet
+monarch
+cabbage butterfly
+sulphur butterfly
+lycaenid
+starfish
+sea urchin
+sea cucumber
+wood rabbit
+hare
+Angora
+hamster
+porcupine
+fox squirrel
+marmot
+beaver
+guinea pig
+sorrel
+zebra
+hog
+wild boar
+warthog
+hippopotamus
+ox
+water buffalo
+bison
+ram
+bighorn
+ibex
+hartebeest
+impala
+gazelle
+Arabian camel
+llama
+weasel
+mink
+polecat
+black-footed ferret
+otter
+skunk
+badger
+armadillo
+three-toed sloth
+orangutan
+gorilla
+chimpanzee
+gibbon
+siamang
+guenon
+patas
+baboon
+macaque
+langur
+colobus
+proboscis monkey
+marmoset
+capuchin
+howler monkey
+titi
+spider monkey
+squirrel monkey
+Madagascar cat
+indri
+Indian elephant
+African elephant
+lesser panda
+giant panda
+barracouta
+eel
+coho
+rock beauty
+anemone fish
+sturgeon
+gar
+lionfish
+puffer
+abacus
+abaya
+academic gown
+accordion
+acoustic guitar
+aircraft carrier
+airliner
+airship
+altar
+ambulance
+amphibian
+analog clock
+apiary
+apron
+ashcan
+assault rifle
+backpack
+bakery
+balance beam
+balloon
+ballpoint
+Band Aid
+banjo
+bannister
+barbell
+barber chair
+barbershop
+barn
+barometer
+barrel
+barrow
+baseball
+basketball
+bassinet
+bassoon
+bathing cap
+bath towel
+bathtub
+beach wagon
+beacon
+beaker
+bearskin
+beer bottle
+beer glass
+bell cote
+bib
+bicycle-built-for-two
+bikini
+binder
+binoculars
+birdhouse
+boathouse
+bobsled
+bolo tie
+bonnet
+bookcase
+bookshop
+bottlecap
+bow
+bow tie
+brass
+brassiere
+breakwater
+breastplate
+broom
+bucket
+buckle
+bulletproof vest
+bullet train
+butcher shop
+cab
+caldron
+candle
+cannon
+canoe
+can opener
+cardigan
+car mirror
+carousel
+carpenter's kit
+carton
+car wheel
+cash machine
+cassette
+cassette player
+castle
+catamaran
+CD player
+cello
+cellular telephone
+chain
+chainlink fence
+chain mail
+chain saw
+chest
+chiffonier
+chime
+china cabinet
+Christmas stocking
+church
+cinema
+cleaver
+cliff dwelling
+cloak
+clog
+cocktail shaker
+coffee mug
+coffeepot
+coil
+combination lock
+computer keyboard
+confectionery
+container ship
+convertible
+corkscrew
+cornet
+cowboy boot
+cowboy hat
+cradle
+crane
+crash helmet
+crate
+crib
+Crock Pot
+croquet ball
+crutch
+cuirass
+dam
+desk
+desktop computer
+dial telephone
+diaper
+digital clock
+digital watch
+dining table
+dishrag
+dishwasher
+disk brake
+dock
+dogsled
+dome
+doormat
+drilling platform
+drum
+drumstick
+dumbbell
+Dutch oven
+electric fan
+electric guitar
+electric locomotive
+entertainment center
+envelope
+espresso maker
+face powder
+feather boa
+file
+fireboat
+fire engine
+fire screen
+flagpole
+flute
+folding chair
+football helmet
+forklift
+fountain
+fountain pen
+four-poster
+freight car
+French horn
+frying pan
+fur coat
+garbage truck
+gasmask
+gas pump
+goblet
+go-kart
+golf ball
+golfcart
+gondola
+gong
+gown
+grand piano
+greenhouse
+grille
+grocery store
+guillotine
+hair slide
+hair spray
+half track
+hammer
+hamper
+hand blower
+hand-held computer
+handkerchief
+hard disc
+harmonica
+harp
+harvester
+hatchet
+holster
+home theater
+honeycomb
+hook
+hoopskirt
+horizontal bar
+horse cart
+hourglass
+iPod
+iron
+jack-o'-lantern
+jean
+jeep
+jersey
+jigsaw puzzle
+jinrikisha
+joystick
+kimono
+knee pad
+knot
+lab coat
+ladle
+lampshade
+laptop
+lawn mower
+lens cap
+letter opener
+library
+lifeboat
+lighter
+limousine
+liner
+lipstick
+Loafer
+lotion
+loudspeaker
+loupe
+lumbermill
+magnetic compass
+mailbag
+mailbox
+maillot
+maillot
+manhole cover
+maraca
+marimba
+mask
+matchstick
+maypole
+maze
+measuring cup
+medicine chest
+megalith
+microphone
+microwave
+military uniform
+milk can
+minibus
+miniskirt
+minivan
+missile
+mitten
+mixing bowl
+mobile home
+Model T
+modem
+monastery
+monitor
+moped
+mortar
+mortarboard
+mosque
+mosquito net
+motor scooter
+mountain bike
+mountain tent
+mouse
+mousetrap
+moving van
+muzzle
+nail
+neck brace
+necklace
+nipple
+notebook
+obelisk
+oboe
+ocarina
+odometer
+oil filter
+organ
+oscilloscope
+overskirt
+oxcart
+oxygen mask
+packet
+paddle
+paddlewheel
+padlock
+paintbrush
+pajama
+palace
+panpipe
+paper towel
+parachute
+parallel bars
+park bench
+parking meter
+passenger car
+patio
+pay-phone
+pedestal
+pencil box
+pencil sharpener
+perfume
+Petri dish
+photocopier
+pick
+pickelhaube
+picket fence
+pickup
+pier
+piggy bank
+pill bottle
+pillow
+ping-pong ball
+pinwheel
+pirate
+pitcher
+plane
+planetarium
+plastic bag
+plate rack
+plow
+plunger
+Polaroid camera
+pole
+police van
+poncho
+pool table
+pop bottle
+pot
+potter's wheel
+power drill
+prayer rug
+printer
+prison
+projectile
+projector
+puck
+punching bag
+purse
+quill
+quilt
+racer
+racket
+radiator
+radio
+radio telescope
+rain barrel
+recreational vehicle
+reel
+reflex camera
+refrigerator
+remote control
+restaurant
+revolver
+rifle
+rocking chair
+rotisserie
+rubber eraser
+rugby ball
+rule
+running shoe
+safe
+safety pin
+saltshaker
+sandal
+sarong
+sax
+scabbard
+scale
+school bus
+schooner
+scoreboard
+screen
+screw
+screwdriver
+seat belt
+sewing machine
+shield
+shoe shop
+shoji
+shopping basket
+shopping cart
+shovel
+shower cap
+shower curtain
+ski
+ski mask
+sleeping bag
+slide rule
+sliding door
+slot
+snorkel
+snowmobile
+snowplow
+soap dispenser
+soccer ball
+sock
+solar dish
+sombrero
+soup bowl
+space bar
+space heater
+space shuttle
+spatula
+speedboat
+spider web
+spindle
+sports car
+spotlight
+stage
+steam locomotive
+steel arch bridge
+steel drum
+stethoscope
+stole
+stone wall
+stopwatch
+stove
+strainer
+streetcar
+stretcher
+studio couch
+stupa
+submarine
+suit
+sundial
+sunglass
+sunglasses
+sunscreen
+suspension bridge
+swab
+sweatshirt
+swimming trunks
+swing
+switch
+syringe
+table lamp
+tank
+tape player
+teapot
+teddy
+television
+tennis ball
+thatch
+theater curtain
+thimble
+thresher
+throne
+tile roof
+toaster
+tobacco shop
+toilet seat
+torch
+totem pole
+tow truck
+toyshop
+tractor
+trailer truck
+tray
+trench coat
+tricycle
+trimaran
+tripod
+triumphal arch
+trolleybus
+trombone
+tub
+turnstile
+typewriter keyboard
+umbrella
+unicycle
+upright
+vacuum
+vase
+vault
+velvet
+vending machine
+vestment
+viaduct
+violin
+volleyball
+waffle iron
+wall clock
+wallet
+wardrobe
+warplane
+washbasin
+washer
+water bottle
+water jug
+water tower
+whiskey jug
+whistle
+wig
+window screen
+window shade
+Windsor tie
+wine bottle
+wing
+wok
+wooden spoon
+wool
+worm fence
+wreck
+yawl
+yurt
+web site
+comic book
+crossword puzzle
+street sign
+traffic light
+book jacket
+menu
+plate
+guacamole
+consomme
+hot pot
+trifle
+ice cream
+ice lolly
+French loaf
+bagel
+pretzel
+cheeseburger
+hotdog
+mashed potato
+head cabbage
+broccoli
+cauliflower
+zucchini
+spaghetti squash
+acorn squash
+butternut squash
+cucumber
+artichoke
+bell pepper
+cardoon
+mushroom
+Granny Smith
+strawberry
+orange
+lemon
+fig
+pineapple
+banana
+jackfruit
+custard apple
+pomegranate
+hay
+carbonara
+chocolate sauce
+dough
+meat loaf
+pizza
+potpie
+burrito
+red wine
+espresso
+cup
+eggnog
+alp
+bubble
+cliff
+coral reef
+geyser
+lakeside
+promontory
+sandbar
+seashore
+valley
+volcano
+ballplayer
+groom
+scuba diver
+rapeseed
+daisy
+yellow lady's slipper
+corn
+acorn
+hip
+buckeye
+coral fungus
+agaric
+gyromitra
+stinkhorn
+earthstar
+hen-of-the-woods
+bolete
+ear
+toilet tissue
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
index 74737a8b88..9b9fdffab5 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
@@ -296,7 +296,8 @@ public class Camera2BasicFragment extends Fragment
public void onActivityCreated(Bundle savedInstanceState) {
super.onActivityCreated(savedInstanceState);
try {
- classifier = new ImageClassifier(getActivity());
+ // create either a new ImageClassifierQuantizedMobileNet or an ImageClassifierFloatInception
+ classifier = new ImageClassifierQuantizedMobileNet(getActivity());
} catch (IOException e) {
Log.e(TAG, "Failed to initialize an image classifier.");
}
@@ -658,8 +659,7 @@ public class Camera2BasicFragment extends Fragment
showToast("Uninitialized Classifier or invalid context.");
return;
}
- Bitmap bitmap =
- textureView.getBitmap(ImageClassifier.DIM_IMG_SIZE_X, ImageClassifier.DIM_IMG_SIZE_Y);
+ Bitmap bitmap = textureView.getBitmap(classifier.getImageSizeX(), classifier.getImageSizeY());
String textToShow = classifier.classifyFrame(bitmap);
bitmap.recycle();
showToast(textToShow);
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
index e44c5ae6b4..2c91be9d62 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
@@ -37,17 +37,11 @@ import java.util.PriorityQueue;
import org.tensorflow.lite.Interpreter;
/** Classifies images with Tensorflow Lite. */
-public class ImageClassifier {
+public abstract class ImageClassifier {
/** Tag for the {@link Log}. */
private static final String TAG = "TfLiteCameraDemo";
- /** Name of the model file stored in Assets. */
- private static final String MODEL_PATH = "mobilenet_quant_v1_224.tflite";
-
- /** Name of the label file stored in Assets. */
- private static final String LABEL_PATH = "labels.txt";
-
/** Number of results to show in the UI. */
private static final int RESULTS_TO_SHOW = 3;
@@ -56,23 +50,18 @@ public class ImageClassifier {
private static final int DIM_PIXEL_SIZE = 3;
- static final int DIM_IMG_SIZE_X = 224;
- static final int DIM_IMG_SIZE_Y = 224;
-
/* Preallocated buffers for storing image data in. */
- private int[] intValues = new int[DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y];
+ private int[] intValues = new int[getImageSizeX() * getImageSizeY()];
/** An instance of the driver class to run model inference with Tensorflow Lite. */
- private Interpreter tflite;
+ protected Interpreter tflite;
/** Labels corresponding to the output of the vision model. */
private List<String> labelList;
/** A ByteBuffer to hold image data, to be feed into Tensorflow Lite as inputs. */
- private ByteBuffer imgData = null;
+ protected ByteBuffer imgData = null;
- /** An array to hold inference results, to be feed into Tensorflow Lite as outputs. */
- private byte[][] labelProbArray = null;
/** multi-stage low pass filter * */
private float[][] filterLabelProbArray = null;
@@ -95,10 +84,13 @@ public class ImageClassifier {
labelList = loadLabelList(activity);
imgData =
ByteBuffer.allocateDirect(
- DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
+ DIM_BATCH_SIZE
+ * getImageSizeX()
+ * getImageSizeY()
+ * DIM_PIXEL_SIZE
+ * getNumBytesPerChannel());
imgData.order(ByteOrder.nativeOrder());
- labelProbArray = new byte[1][labelList.size()];
- filterLabelProbArray = new float[FILTER_STAGES][labelList.size()];
+ filterLabelProbArray = new float[FILTER_STAGES][getNumLabels()];
Log.d(TAG, "Created a Tensorflow Lite Image Classifier.");
}
@@ -111,7 +103,7 @@ public class ImageClassifier {
convertBitmapToByteBuffer(bitmap);
// Here's where the magic happens!!!
long startTime = SystemClock.uptimeMillis();
- tflite.run(imgData, labelProbArray);
+ runInference();
long endTime = SystemClock.uptimeMillis();
Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime));
@@ -125,12 +117,12 @@ public class ImageClassifier {
}
void applyFilter() {
- int numLabels = labelList.size();
+ int numLabels = getNumLabels();
// Low pass filter `labelProbArray` into the first stage of the filter.
for (int j = 0; j < numLabels; ++j) {
filterLabelProbArray[0][j] +=
- FILTER_FACTOR * (labelProbArray[0][j] - filterLabelProbArray[0][j]);
+ FILTER_FACTOR * (getProbability(j) - filterLabelProbArray[0][j]);
}
// Low pass filter each stage into the next.
for (int i = 1; i < FILTER_STAGES; ++i) {
@@ -142,7 +134,7 @@ public class ImageClassifier {
// Copy the last stage filter output back to `labelProbArray`.
for (int j = 0; j < numLabels; ++j) {
- labelProbArray[0][j] = (byte)filterLabelProbArray[FILTER_STAGES - 1][j];
+ setProbability(j, filterLabelProbArray[FILTER_STAGES - 1][j]);
}
}
@@ -156,7 +148,7 @@ public class ImageClassifier {
private List<String> loadLabelList(Activity activity) throws IOException {
List<String> labelList = new ArrayList<String>();
BufferedReader reader =
- new BufferedReader(new InputStreamReader(activity.getAssets().open(LABEL_PATH)));
+ new BufferedReader(new InputStreamReader(activity.getAssets().open(getLabelPath())));
String line;
while ((line = reader.readLine()) != null) {
labelList.add(line);
@@ -167,7 +159,7 @@ public class ImageClassifier {
/** Memory-map the model file in Assets. */
private MappedByteBuffer loadModelFile(Activity activity) throws IOException {
- AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_PATH);
+ AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(getModelPath());
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
@@ -185,12 +177,10 @@ public class ImageClassifier {
// Convert the image to floating point.
int pixel = 0;
long startTime = SystemClock.uptimeMillis();
- for (int i = 0; i < DIM_IMG_SIZE_X; ++i) {
- for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) {
+ for (int i = 0; i < getImageSizeX(); ++i) {
+ for (int j = 0; j < getImageSizeY(); ++j) {
final int val = intValues[pixel++];
- imgData.put((byte) ((val >> 16) & 0xFF));
- imgData.put((byte) ((val >> 8) & 0xFF));
- imgData.put((byte) (val & 0xFF));
+ addPixelValue(val);
}
}
long endTime = SystemClock.uptimeMillis();
@@ -199,9 +189,9 @@ public class ImageClassifier {
/** Prints top-K labels, to be shown in UI as the results. */
private String printTopKLabels() {
- for (int i = 0; i < labelList.size(); ++i) {
+ for (int i = 0; i < getNumLabels(); ++i) {
sortedLabels.add(
- new AbstractMap.SimpleEntry<>(labelList.get(i), (labelProbArray[0][i] & 0xff) / 255.0f));
+ new AbstractMap.SimpleEntry<>(labelList.get(i), getNormalizedProbability(i)));
if (sortedLabels.size() > RESULTS_TO_SHOW) {
sortedLabels.poll();
}
@@ -214,4 +204,89 @@ public class ImageClassifier {
}
return textToShow;
}
+
+ /**
+ * Get the name of the model file stored in Assets.
+ *
+   * @return the name of the model file
+ */
+ protected abstract String getModelPath();
+
+ /**
+ * Get the name of the label file stored in Assets.
+ *
+   * @return the name of the label file
+ */
+ protected abstract String getLabelPath();
+
+ /**
+ * Get the image size along the x axis.
+ *
+   * @return the image size along the x axis, in pixels
+ */
+ protected abstract int getImageSizeX();
+
+ /**
+ * Get the image size along the y axis.
+ *
+   * @return the image size along the y axis, in pixels
+ */
+ protected abstract int getImageSizeY();
+
+ /**
+   * Get the number of bytes used to store a single color channel value.
+   *
+   * @return the number of bytes per color channel
+ */
+ protected abstract int getNumBytesPerChannel();
+
+ /**
+   * Add pixelValue to the imgData byte buffer.
+   *
+   * @param pixelValue the packed ARGB pixel value to add
+ */
+ protected abstract void addPixelValue(int pixelValue);
+
+ /**
+   * Read the probability value for the specified label. This is either the original value as it
+   * was read from the net's output or the updated value after the filter was applied.
+   *
+   * @param labelIndex the index of the label
+   * @return the probability value for the label
+ */
+ protected abstract float getProbability(int labelIndex);
+
+ /**
+ * Set the probability value for the specified label.
+ *
+   * @param labelIndex the index of the label
+   * @param value the new probability value
+ */
+ protected abstract void setProbability(int labelIndex, Number value);
+
+ /**
+ * Get the normalized probability value for the specified label. This is the final value as it
+ * will be shown to the user.
+ *
+   * @param labelIndex the index of the label
+   * @return the normalized probability value
+ */
+ protected abstract float getNormalizedProbability(int labelIndex);
+
+ /**
+ * Run inference using the prepared input in {@link #imgData}. Afterwards, the result will be
+ * provided by getProbability().
+ *
+   * <p>This additional method is necessary because we don't have a common base for different
+ * primitive data types.
+ */
+ protected abstract void runInference();
+
+ /**
+ * Get the total number of labels.
+ *
+   * @return the total number of labels
+ */
+ protected int getNumLabels() {
+ return labelList.size();
+ }
}
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java
new file mode 100644
index 0000000000..3108422952
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java
@@ -0,0 +1,103 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package com.example.android.tflitecamerademo;
+
+import android.app.Activity;
+import java.io.IOException;
+
+/**
+ * This classifier works with the Inception-v3 slim model. It applies floating point inference
+ * rather than using a quantized model.
+ */
+public class ImageClassifierFloatInception extends ImageClassifier {
+
+  /** The Inception net requires additional normalization of its input. */
+ private static final int IMAGE_MEAN = 128;
+
+ private static final float IMAGE_STD = 128.0f;
+
+ /**
+   * An array to hold inference results, to be fed into TensorFlow Lite as outputs. This isn't
+   * part of the superclass because we need a primitive float array here.
+ */
+ private float[][] labelProbArray = null;
+
+ /**
+ * Initializes an {@code ImageClassifier}.
+ *
+   * @param activity the activity from which the model and labels are loaded
+ */
+ ImageClassifierFloatInception(Activity activity) throws IOException {
+ super(activity);
+ labelProbArray = new float[1][getNumLabels()];
+ }
+
+ @Override
+ protected String getModelPath() {
+ // you can download this file from
+ // https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip
+ return "inceptionv3_slim_2016.tflite";
+ }
+
+ @Override
+ protected String getLabelPath() {
+ return "labels_imagenet_slim.txt";
+ }
+
+ @Override
+ protected int getImageSizeX() {
+ return 299;
+ }
+
+ @Override
+ protected int getImageSizeY() {
+ return 299;
+ }
+
+ @Override
+ protected int getNumBytesPerChannel() {
+    // A 32-bit float value requires 4 bytes.
+ return 4;
+ }
+
+ @Override
+ protected void addPixelValue(int pixelValue) {
+ imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ }
+
+ @Override
+ protected float getProbability(int labelIndex) {
+ return labelProbArray[0][labelIndex];
+ }
+
+ @Override
+ protected void setProbability(int labelIndex, Number value) {
+ labelProbArray[0][labelIndex] = value.floatValue();
+ }
+
+ @Override
+ protected float getNormalizedProbability(int labelIndex) {
+ // TODO the following value isn't in [0,1] yet, but may be greater. Why?
+ return getProbability(labelIndex);
+ }
+
+ @Override
+ protected void runInference() {
+ tflite.run(imgData, labelProbArray);
+ }
+}
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java
new file mode 100644
index 0000000000..5f341f0f5b
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java
@@ -0,0 +1,94 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package com.example.android.tflitecamerademo;
+
+import android.app.Activity;
+import java.io.IOException;
+
+/** This classifier works with the quantized MobileNet model. */
+public class ImageClassifierQuantizedMobileNet extends ImageClassifier {
+
+ /**
+   * An array to hold inference results, to be fed into TensorFlow Lite as outputs. This isn't
+   * part of the superclass because we need a primitive byte array here.
+ */
+ private byte[][] labelProbArray = null;
+
+ /**
+ * Initializes an {@code ImageClassifier}.
+ *
+   * @param activity the activity from which the model and labels are loaded
+ */
+ ImageClassifierQuantizedMobileNet(Activity activity) throws IOException {
+ super(activity);
+ labelProbArray = new byte[1][getNumLabels()];
+ }
+
+ @Override
+ protected String getModelPath() {
+ // you can download this file from
+ // https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip
+ return "mobilenet_quant_v1_224.tflite";
+ }
+
+ @Override
+ protected String getLabelPath() {
+ return "labels_mobilenet_quant_v1_224.txt";
+ }
+
+ @Override
+ protected int getImageSizeX() {
+ return 224;
+ }
+
+ @Override
+ protected int getImageSizeY() {
+ return 224;
+ }
+
+ @Override
+ protected int getNumBytesPerChannel() {
+    // The quantized model uses a single byte per channel value.
+ return 1;
+ }
+
+ @Override
+ protected void addPixelValue(int pixelValue) {
+ imgData.put((byte) ((pixelValue >> 16) & 0xFF));
+ imgData.put((byte) ((pixelValue >> 8) & 0xFF));
+ imgData.put((byte) (pixelValue & 0xFF));
+ }
+
+ @Override
+ protected float getProbability(int labelIndex) {
+ return labelProbArray[0][labelIndex];
+ }
+
+ @Override
+ protected void setProbability(int labelIndex, Number value) {
+ labelProbArray[0][labelIndex] = value.byteValue();
+ }
+
+ @Override
+ protected float getNormalizedProbability(int labelIndex) {
+ return (labelProbArray[0][labelIndex] & 0xff) / 255.0f;
+ }
+
+ @Override
+ protected void runInference() {
+ tflite.run(imgData, labelProbArray);
+ }
+}
diff --git a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py
index c3a57ba51b..2b9eee4ef7 100644
--- a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py
+++ b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py
@@ -50,16 +50,12 @@ def pairwise_distance(feature, squared=False):
pairwise_distances: 2-D Tensor of size [number of data, number of data].
"""
pairwise_distances_squared = math_ops.add(
+ math_ops.reduce_sum(math_ops.square(feature), axis=[1], keepdims=True),
math_ops.reduce_sum(
- math_ops.square(feature),
- axis=[1],
- keep_dims=True),
- math_ops.reduce_sum(
- math_ops.square(
- array_ops.transpose(feature)),
+ math_ops.square(array_ops.transpose(feature)),
axis=[0],
- keep_dims=True)) - 2.0 * math_ops.matmul(
- feature, array_ops.transpose(feature))
+ keepdims=True)) - 2.0 * math_ops.matmul(feature,
+ array_ops.transpose(feature))
# Deal with numerical inaccuracies. Set small negatives to zero.
pairwise_distances_squared = math_ops.maximum(pairwise_distances_squared, 0.0)
@@ -132,10 +128,10 @@ def masked_maximum(data, mask, dim=1):
masked_maximums: N-D `Tensor`.
The maximized dimension is of size 1 after the operation.
"""
- axis_minimums = math_ops.reduce_min(data, dim, keep_dims=True)
+ axis_minimums = math_ops.reduce_min(data, dim, keepdims=True)
masked_maximums = math_ops.reduce_max(
- math_ops.multiply(
- data - axis_minimums, mask), dim, keep_dims=True) + axis_minimums
+ math_ops.multiply(data - axis_minimums, mask), dim,
+ keepdims=True) + axis_minimums
return masked_maximums
@@ -151,10 +147,10 @@ def masked_minimum(data, mask, dim=1):
masked_minimums: N-D `Tensor`.
The minimized dimension is of size 1 after the operation.
"""
- axis_maximums = math_ops.reduce_max(data, dim, keep_dims=True)
+ axis_maximums = math_ops.reduce_max(data, dim, keepdims=True)
masked_minimums = math_ops.reduce_min(
- math_ops.multiply(
- data - axis_maximums, mask), dim, keep_dims=True) + axis_maximums
+ math_ops.multiply(data - axis_maximums, mask), dim,
+ keepdims=True) + axis_maximums
return masked_minimums
@@ -202,8 +198,7 @@ def triplet_semihard_loss(labels, embeddings, margin=1.0):
mask_final = array_ops.reshape(
math_ops.greater(
math_ops.reduce_sum(
- math_ops.cast(
- mask, dtype=dtypes.float32), 1, keep_dims=True),
+ math_ops.cast(mask, dtype=dtypes.float32), 1, keepdims=True),
0.0), [batch_size, batch_size])
mask_final = array_ops.transpose(mask_final)
@@ -290,7 +285,7 @@ def npairs_loss(labels, embeddings_anchor, embeddings_positive,
labels_remapped = math_ops.to_float(
math_ops.equal(labels, array_ops.transpose(labels)))
- labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keep_dims=True)
+ labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keepdims=True)
# Add the softmax loss.
xent_loss = nn.softmax_cross_entropy_with_logits(
@@ -395,7 +390,7 @@ def npairs_loss_multilabel(sparse_labels, embeddings_anchor,
multilabel_adjacency_matrix = _build_multilabel_adjacency(sparse_labels)
labels_remapped = math_ops.to_float(multilabel_adjacency_matrix)
- labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keep_dims=True)
+ labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keepdims=True)
# Add the softmax loss.
xent_loss = nn.softmax_cross_entropy_with_logits(
@@ -448,10 +443,10 @@ def lifted_struct_loss(labels, embeddings, margin=1.0):
# Safe maximum: Temporarily shift negative distances
# above zero before taking max.
# this is to take the max only among negatives.
- row_minimums = math_ops.reduce_min(diff, 1, keep_dims=True)
+ row_minimums = math_ops.reduce_min(diff, 1, keepdims=True)
row_negative_maximums = math_ops.reduce_max(
- math_ops.multiply(
- diff - row_minimums, mask), 1, keep_dims=True) + row_minimums
+ math_ops.multiply(diff - row_minimums, mask), 1,
+ keepdims=True) + row_minimums
# Compute the loss.
# Keep track of matrix of maximums where M_ij = max(m_i, m_j)
@@ -467,10 +462,11 @@ def lifted_struct_loss(labels, embeddings, margin=1.0):
array_ops.transpose(max_elements), [-1, 1])
loss_exp_left = array_ops.reshape(
- math_ops.reduce_sum(math_ops.multiply(
- math_ops.exp(
- diff_tiled - max_elements_vect),
- mask_tiled), 1, keep_dims=True), [batch_size, batch_size])
+ math_ops.reduce_sum(
+ math_ops.multiply(
+ math_ops.exp(diff_tiled - max_elements_vect), mask_tiled),
+ 1,
+ keepdims=True), [batch_size, batch_size])
loss_mat = max_elements + math_ops.log(
loss_exp_left + array_ops.transpose(loss_exp_left))
@@ -686,7 +682,7 @@ def _find_loss_augmented_facility_idx(pairwise_distances, labels, chosen_ids,
array_ops.reshape(pairwise_distances_candidate, [1, -1])
], 0),
axis=0,
- keep_dims=True), [num_candidates, -1]),
+ keepdims=True), [num_candidates, -1]),
axis=1)
nmi_scores = array_ops.zeros([num_candidates])
diff --git a/tensorflow/contrib/makefile/README.md b/tensorflow/contrib/makefile/README.md
index 6959ca344f..995230dfa8 100644
--- a/tensorflow/contrib/makefile/README.md
+++ b/tensorflow/contrib/makefile/README.md
@@ -130,6 +130,105 @@ adb shell '/data/local/tmp/benchmark \
For more details, see the [benchmark documentation](../../tools/benchmark).
+## CUDA support for Tegra devices running Android (Nvidia Shield TV, etc.)
+
+With the release of TF 1.6 and JetPack for Android 3.2 (currently pending), you can follow the instructions below to build a version of TensorFlow for compatible devices that receives the full benefit of GPU acceleration.
+
+#### Environment setup:
+
+First, download and install JetPack for Android version 3.2 or greater from [Nvidia](https://developers.nvidia.com). Note that, as of the TF 1.6 release, JetPack for Android 3.2 is still pending; the regular JetPack for L4T will not work.
+
+```bash
+git clone https://github.com/tensorflow/tensorflow.git
+cd tensorflow
+JETPACK=$HOME/JetPack_Android_3.2
+TEGRA_LIBS="$JETPACK/cuDNN/aarch64/cuda/lib64/libcudnn.so $JETPACK/cuda-9.0/extras/CUPTI/lib64/libcupti.so $JETPACK/cuda/targets/aarch64-linux-androideabi/lib64/libcufft.so"
+```
+
+#### Building all CUDA-enabled native binaries:
+This will build CUDA-enabled versions of libtensorflow_inference.so and the benchmark binary. (libtensorflow_demo.so will also be built incidentally, but it does not support CUDA.)
+
+```bash
+NDK_ROOT=$JETPACK/android-ndk-r13b
+CC_PREFIX=ccache tensorflow/contrib/makefile/build_all_android.sh -s tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in -t "libtensorflow_inference.so libtensorflow_demo.so all" -a tegra
+```
+(Add -T on subsequent builds to skip downloading and building protobuf.)
+
+
+#### Testing the CUDA-enabled benchmark via adb:
+Build binaries first as above, then run:
+
+```bash
+adb shell mkdir -p /data/local/tmp/lib64
+adb push $TEGRA_LIBS /data/local/tmp/lib64
+adb push tensorflow/contrib/makefile/gen/bin/android_arm64-v8a/benchmark /data/local/tmp
+wget https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk
+unzip tensorflow_demo.apk -d /tmp/tensorflow_demo
+adb push /tmp/tensorflow_demo/assets/*.pb /data/local/tmp
+adb shell "LD_LIBRARY_PATH=/data/local/tmp/lib64 /data/local/tmp/benchmark --graph=/data/local/tmp/tensorflow_inception_graph.pb"
+```
+
+#### Building the CUDA-enabled TensorFlow AAR with Bazel:
+Build the native binaries first as above. Then build the AAR and package the native libs by executing the following:
+```bash
+mkdir -p /tmp/tf/jni/arm64-v8a
+cp tensorflow/contrib/makefile/gen/lib/android_tegra/libtensorflow_*.so /tmp/tf/jni/arm64-v8a/
+cp $TEGRA_LIBS /tmp/tf/jni/arm64-v8a
+bazel build //tensorflow/contrib/android:android_tensorflow_inference_java.aar
+cp bazel-bin/tensorflow/contrib/android/android_tensorflow_inference_java.aar /tmp/tf/tensorflow.aar
+cd /tmp/tf
+chmod +w tensorflow.aar
+zip -ur tensorflow.aar $(find jni -name '*.so')
+```
+
+#### Building the CUDA-enabled TensorFlow Android demo with Bazel:
+Build binaries first as above, then edit tensorflow/examples/android/BUILD and replace:
+```
+ srcs = [
+ ":libtensorflow_demo.so",
+ "//tensorflow/contrib/android:libtensorflow_inference.so",
+ ],
+```
+with:
+```
+srcs = glob(["libs/arm64-v8a/*.so"]),
+```
+
+Then run:
+```bash
+# Create dir for native libs
+mkdir -p tensorflow/examples/android/libs/arm64-v8a
+
+# Copy JetPack libs
+cp $TEGRA_LIBS tensorflow/examples/android/libs/arm64-v8a
+
+# Copy native TensorFlow libraries
+cp tensorflow/contrib/makefile/gen/lib/android_arm64-v8a/libtensorflow_*.so tensorflow/examples/android/libs/arm64-v8a/
+
+# Build APK
+bazel build -c opt --fat_apk_cpu=arm64-v8a //tensorflow/examples/android:tensorflow_demo
+
+# Install
+adb install -r -f bazel-bin/tensorflow/examples/android/tensorflow_demo.apk
+```
+
+#### Building the CUDA-enabled Android demo with gradle/Android Studio:
+
+Add tensorflow/examples/android as an Android project in Android Studio as normal.
+
+Edit build.gradle as follows (a sketch of these edits appears after the list):
+* set nativeBuildSystem = 'makefile'
+* set cpuType = 'arm64-v8a'
+* in "buildNativeMake", replace cpuType with 'tegra' (optional speedups like -T and ccache also work)
+* set the NDK_ROOT environment variable to $JETPACK/android-ndk-r13b
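+
+A minimal sketch of these edits, assuming the variable names used in the demo's build.gradle (your copy may differ slightly):
+```
+def nativeBuildSystem = 'makefile'
+def cpuType = 'arm64-v8a'
+// ...and inside the buildNativeMake task, pass 'tegra' in place of cpuType.
+```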
+
+Click "Build APK" to build.
+
+Install:
+```bash
+adb install -r -f tensorflow/examples/android/gradleBuild/outputs/apk/debug/android-debug.apk
+```
+
## iOS
_Note: To use this library in an iOS application, see related instructions in
diff --git a/tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh b/tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh
index 203ff4f890..421ddd210f 100755
--- a/tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh
+++ b/tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh
@@ -36,7 +36,7 @@ while getopts "bc:Eps" opt_name; do
b) BUILD_ONLY="true";;
c) TEST_COUNT="${OPTARG}";;
E) ENABLE_EXPERIMENTAL_HEXNN_OPS="true";;
- p) USE_PREBUILT_HEXAOGON_BINARIES="true";;
+ p) USE_PREBUILT_HEXAGON_BINARIES="true";;
s) SKIP_DOWNLOAD_IF_EXIST="true";;
*) usage;;
esac
@@ -49,7 +49,7 @@ if [[ -z "${NDK_ROOT}" ]]; then
exit 1
fi
-if [[ "${USE_PREBUILT_HEXAOGON_BINARIES}" != "true" &&
+if [[ "${USE_PREBUILT_HEXAGON_BINARIES}" != "true" &&
-z "${QUALCOMM_SDK}" ]]; then
echo "QUALCOMM_SDK is empty" 1>&2
usage
@@ -84,7 +84,7 @@ rm -rf "${GEN_DIR}"
mkdir -p "${GEN_LIBS_DIR}"
mkdir -p "${GEN_DOWNLOAD_DIR}"
-if [[ "${USE_PREBUILT_HEXAOGON_BINARIES}" == "true" ]]; then
+if [[ "${USE_PREBUILT_HEXAGON_BINARIES}" == "true" ]]; then
echo "Download prebuilt hexagon binaries"
if [[ "${BUILD_ONLY}" != "true" ]]; then
CONTROLLER_PUSH_DEST="/data/local/tmp"
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
index f700717394..4eb4fbcd92 100644
--- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py
+++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
@@ -572,9 +572,8 @@ class LSTMBlockWrapper(base_layer.Layer):
def _gather_states(self, data, indices, batch_size):
"""Produce `out`, s.t. out(i, j) = data(indices(i), i, j)."""
- mod_indices = indices * batch_size + math_ops.range(batch_size)
- return array_ops.gather(
- array_ops.reshape(data, [-1, self.num_units]), mod_indices)
+ return array_ops.gather_nd(
+ data, array_ops.stack([indices, math_ops.range(batch_size)], axis=1))
class LSTMBlockFusedCell(LSTMBlockWrapper):
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index dce71c393a..a6c2d9cdbb 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -424,8 +424,9 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
"W_O_diag", shape=[self._num_units], dtype=dtype)
# initialize the first freq state to be zero
- m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]), self._num_units],
- dtype)
+ m_prev_freq = array_ops.zeros(
+ [inputs.shape[0].value or inputs.get_shape()[0], self._num_units],
+ dtype)
for fq in range(len(freq_inputs)):
c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units],
[-1, self._num_units])
diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
index 64973ccccd..dfa12e873a 100644
--- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
+++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
@@ -80,12 +80,12 @@ class GatherTreeOp : public OpKernel {
max_sequence_lengths.shape().DebugString()));
Tensor* beams;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams));
- typename TTypes<T, 3>::ConstTensor step_ids_t = step_ids.tensor<T, 3>();
- typename TTypes<T, 3>::ConstTensor parent_ids_t = parent_ids.tensor<T, 3>();
+ typename TTypes<T, 3>::ConstTensor step_ids_t(step_ids.tensor<T, 3>());
+ typename TTypes<T, 3>::ConstTensor parent_ids_t(parent_ids.tensor<T, 3>());
typename TTypes<int32>::ConstVec max_seq_lens_t =
max_sequence_lengths.vec<int32>();
- typename TTypes<T>::ConstScalar end_token_t = end_token.scalar<T>();
- typename TTypes<T, 3>::Tensor beams_t = beams->tensor<T, 3>();
+ typename TTypes<T>::ConstScalar end_token_t(end_token.scalar<T>());
+ typename TTypes<T, 3>::Tensor beams_t(beams->tensor<T, 3>());
const T end_token_value = end_token_t();
functor::GatherTree<Device, T>()(ctx, device, step_ids_t, parent_ids_t,
max_seq_lens_t, end_token_value, beams_t);
diff --git a/tensorflow/contrib/signal/python/ops/spectral_ops.py b/tensorflow/contrib/signal/python/ops/spectral_ops.py
index bca2e01d7b..a8b5deff6c 100644
--- a/tensorflow/contrib/signal/python/ops/spectral_ops.py
+++ b/tensorflow/contrib/signal/python/ops/spectral_ops.py
@@ -144,7 +144,7 @@ def inverse_stft_window_fn(frame_step,
overlaps = -(-frame_length // frame_step) # Ceiling division.
denom = array_ops.pad(denom, [(0, overlaps * frame_step - frame_length)])
denom = array_ops.reshape(denom, [overlaps, frame_step])
- denom = math_ops.reduce_sum(denom, 0, keep_dims=True)
+ denom = math_ops.reduce_sum(denom, 0, keepdims=True)
denom = array_ops.tile(denom, [overlaps, 1])
denom = array_ops.reshape(denom, [overlaps * frame_step])
diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py
index 7ab6805fac..c24bd04851 100644
--- a/tensorflow/contrib/slim/python/slim/evaluation_test.py
+++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py
@@ -29,6 +29,7 @@ from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.metrics.python.ops import metric_ops
from tensorflow.contrib.slim.python.slim import evaluation
from tensorflow.contrib.training.python.training import evaluation as evaluation_lib
+from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.wrappers import hooks
from tensorflow.python.framework import constant_op
@@ -235,7 +236,7 @@ class SingleEvaluationTest(test.TestCase):
def _prepareCheckpoint(self, checkpoint_path):
init_op = control_flow_ops.group(variables.global_variables_initializer(),
variables.local_variables_initializer())
- saver = saver_lib.Saver()
+ saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1)
with self.test_session() as sess:
sess.run(init_op)
saver.save(sess, checkpoint_path)
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
index 58a7fa095d..1e4cc3f095 100644
--- a/tensorflow/contrib/tensor_forest/BUILD
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -497,6 +497,7 @@ py_library(
":tensor_forest_v4_ops_py",
"//tensorflow/contrib/decision_trees/proto:generic_tree_model_py",
"//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/contrib/tensor_forest/proto:fertile_stats_proto_py",
"//tensorflow/contrib/tensor_forest/proto:tensor_forest_params_proto_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 28f571e1f0..65a0e903a7 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -1,5 +1,6 @@
# Description:
-# Wrap NVIDIA TensorRT (http://developer.nvidia.com/tensorrt) with tensorflow.
+# Wrap NVIDIA TensorRT (http://developer.nvidia.com/tensorrt) with tensorflow
+# and provide TensorRT operators and converter package.
# APIs are meant to change over time.
package(default_visibility = ["//tensorflow:__subpackages__"])
@@ -8,7 +9,19 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+ "tf_copts",
+ "tf_cuda_library",
+ "tf_custom_op_library",
+ "tf_custom_op_library_additional_deps",
+ "tf_gen_op_libs",
+ "tf_gen_op_wrapper_py",
+)
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
+load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
load(
"@local_config_tensorrt//:build_defs.bzl",
"if_tensorrt",
@@ -32,6 +45,195 @@ tf_cuda_cc_test(
]),
)
+tf_custom_op_library(
+ name = "python/ops/_trt_engine_op.so",
+ srcs = ["ops/trt_engine_op.cc"],
+ deps = [
+ ":trt_engine_op_kernel",
+ ":trt_shape_function",
+ "//tensorflow/core:lib_proto_parsing",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]),
+)
+
+tf_cuda_library(
+ name = "trt_shape_function",
+ srcs = ["shape_fn/trt_shfn.cc"],
+ hdrs = ["shape_fn/trt_shfn.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":trt_logging",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]) + tf_custom_op_library_additional_deps(),
+)
+
+cc_library(
+ name = "trt_engine_op_kernel",
+ srcs = ["kernels/trt_engine_op.cc"],
+ hdrs = ["kernels/trt_engine_op.h"],
+ copts = tf_copts(),
+ deps = [
+ ":trt_logging",
+ "//tensorflow/core:gpu_headers_lib",
+ "//tensorflow/core:lib_proto_parsing",
+ "//tensorflow/core:stream_executor_headers_lib",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]) + tf_custom_op_library_additional_deps(),
+ # TODO(laigd)
+ alwayslink = 1, # buildozer: disable=alwayslink-with-hdrs
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["trt_engine_op"],
+ deps = if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]),
+)
+
+tf_cuda_library(
+ name = "trt_logging",
+ srcs = ["log/trt_logger.cc"],
+ hdrs = ["log/trt_logger.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:lib_proto_parsing",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]),
+)
+
+tf_gen_op_wrapper_py(
+ name = "trt_engine_op",
+ deps = [
+ ":trt_engine_op_op_lib",
+ ":trt_logging",
+ ":trt_shape_function",
+ ],
+)
+
+tf_custom_op_py_library(
+ name = "trt_engine_op_loader",
+ srcs = ["python/ops/trt_engine_op.py"],
+ dso = [
+ ":python/ops/_trt_engine_op.so",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]),
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:resources",
+ ],
+)
+
+py_library(
+ name = "init_py",
+ srcs = [
+ "__init__.py",
+ "python/__init__.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":trt_convert_py",
+ ":trt_ops_py",
+ ],
+)
+
+py_library(
+ name = "trt_ops_py",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":trt_engine_op",
+ ":trt_engine_op_loader",
+ ],
+)
+
+py_library(
+ name = "trt_convert_py",
+ srcs = ["python/trt_convert.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":wrap_conversion",
+ ],
+)
+
+tf_py_wrap_cc(
+ name = "wrap_conversion",
+ srcs = ["trt_conversion.i"],
+ copts = tf_copts(),
+ deps = [
+ ":trt_conversion",
+ "//tensorflow/core:framework_lite",
+ "//util/python:python_headers",
+ ],
+)
+
+# Library for the node-level conversion portion of TensorRT operation creation
+tf_cuda_library(
+ name = "trt_conversion",
+ srcs = [
+ "convert/convert_graph.cc",
+ "convert/convert_nodes.cc",
+ ],
+ hdrs = [
+ "convert/convert_graph.h",
+ "convert/convert_nodes.h",
+ ],
+ deps = [
+ ":segment",
+ ":trt_logging",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_lite",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:devices",
+ "//tensorflow/core/grappler/clusters:virtual_cluster",
+ "//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/optimizers:constant_folding",
+ "//tensorflow/core/grappler/optimizers:layout_optimizer",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]) + tf_custom_op_library_additional_deps(),
+)
+
+# Library for the segmenting portion of TensorRT operation creation
+cc_library(
+ name = "segment",
+ srcs = ["segment/segment.cc"],
+ hdrs = [
+ "segment/segment.h",
+ "segment/union_find.h",
+ ],
+ linkstatic = 1,
+ deps = [
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib_proto_parsing",
+ "//tensorflow/core:protos_all_cc",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+)
+
+tf_cc_test(
+ name = "segment_test",
+ size = "small",
+ srcs = ["segment/segment_test.cc"],
+ deps = [
+ ":segment",
+ "//tensorflow/c:c_api",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/tensorrt/README.md b/tensorflow/contrib/tensorrt/README.md
new file mode 100644
index 0000000000..dfcce0fd00
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/README.md
@@ -0,0 +1,40 @@
+Using TensorRT in TensorFlow
+============================
+
+This module provides the necessary bindings and introduces the TRT_engine_op
+operator that wraps a subgraph in TensorRT.
+
+Compilation
+-----------
+
+In order to compile the module, you need a local TensorRT installation
+(libnvinfer.so and the respective include files). During the configuration
+step, TensorRT should be enabled and its installation path should be set. If
+TensorRT was installed through a package manager (deb, rpm), the configure
+script should find the necessary components on the system automatically. If it
+was installed from a tar package, you have to set the path to the installation
+location during configuration.
+
+
+```
+bazel build --config=cuda --config=opt //tensorflow/tools/pip_package:build_pip_package
+bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/
+```
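+
+The second command writes a pip wheel into /tmp/; install it with pip (for
+example, `pip install /tmp/tensorflow-*.whl`; the exact filename varies by
+version and platform) before using the module.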
+
+After the TensorFlow package is installed, the TensorRT transformation
+will be available. An example use is shown below.
+
+```python
+import tensorflow as tf
+import tensorflow.contrib.tensorrt as trt
+# ... create and train or load model
+gdef = sess.graph.as_graph_def()
+trt_gdef = trt.create_inference_graph(
+    gdef,                      # original graph_def
+    ["output"],                # name of output node(s)
+    max_batch_size,            # maximum batch size to run inference with
+    max_workspace_size_bytes)  # max memory for TensorRT to use
+tf.reset_default_graph()
+tf.import_graph_def(graph_def=trt_gdef)
+# ... run inference
+```
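+
+In the snippet above, `max_batch_size` and `max_workspace_size_bytes` are
+placeholders for integer values of your choosing (for example, 1 and 1 << 25).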
diff --git a/tensorflow/contrib/tensorrt/__init__.py b/tensorflow/contrib/tensorrt/__init__.py
new file mode 100644
index 0000000000..fd551d70b4
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Exposes the python wrapper for TensorRT graph transforms."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.tensorrt.python import *
+# pylint: enable=unused-import,wildcard-import
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
new file mode 100644
index 0000000000..970f810473
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -0,0 +1,273 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
+
+#include <map>
+#include <set>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
+#include "tensorflow/contrib/tensorrt/segment/segment.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/devices.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/constant_folding.h"
+#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/device_properties.pb.h" // NOLINT
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "tensorrt/include/NvInfer.h"
+
+namespace tensorflow {
+namespace tensorrt {
+namespace convert {
+namespace {
+
+static bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) {
+ // LINT.IfChange
+  // TODO(jie): Segmentation shouldn't be associated with op name.
+ // Split it into a registration for each kernel.
+ static const std::set<string> candidate_ops = {
+ "Identity", "Const", "Conv2D", "MaxPool", "BiasAdd", "Relu",
+ "Add", "Mul", "Sub", "Rsqrt", "Pad" // "Placeholder" ,"Mean"
+ };
+ // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h)
+ return candidate_ops.count(node_def.op());
+}
+
+void GetSubGraphIncomingEdges(const tensorflow::Graph& graph,
+ const std::set<int>& subgraph_node_ids,
+ tensorflow::EdgeSet* incoming_edges) {
+ for (int node_id : subgraph_node_ids) {
+ const tensorflow::Node* node = graph.FindNodeId(node_id);
+ for (const tensorflow::Edge* edge : node->in_edges()) {
+ if (!subgraph_node_ids.count(edge->src()->id()) &&
+ !edge->src()->IsSource()) {
+ incoming_edges->insert(edge);
+ }
+ }
+ }
+}
+
+void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph,
+ const std::set<int>& subgraph_node_ids,
+ tensorflow::EdgeSet* outgoing_edges) {
+ for (int node_id : subgraph_node_ids) {
+ const tensorflow::Node* node = graph.FindNodeId(node_id);
+ for (const tensorflow::Edge* edge : node->out_edges()) {
+ if (!subgraph_node_ids.count(edge->dst()->id()) &&
+ !edge->dst()->IsSink()) {
+ outgoing_edges->insert(edge);
+ }
+ }
+ }
+}
+
+std::pair<string, int> ParseTensorName(string name, int default_idx = 0) {
+ int idx = default_idx;
+ size_t sep = name.find_last_of(':');
+ if (sep != string::npos) {
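+    // Parse the index before truncating the name; calling substr(sep + 1) on
+    // the already-truncated string would be out of range.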
+    idx = std::stoi(name.substr(sep + 1));
+    name = name.substr(0, sep);
+ }
+ return std::make_pair(name, idx);
+}
+
+std::unordered_map<string, std::vector<int>> BuildTensorNameMap(
+ const std::vector<string>& tensor_names) {
+ std::unordered_map<string, std::vector<int>> result;
+ for (string const& tensor_name : tensor_names) {
+ string node_name;
+ int index;
+ std::tie(node_name, index) = ParseTensorName(tensor_name);
+ result[node_name].push_back(index);
+ }
+ return result;
+}
+
+tensorflow::Status ConvertSubGraphToTensorRT(
+ const std::vector<string>& output_names,
+ const std::set<int>& subgraph_node_ids,
+ size_t max_batch_size, // Max batch size that engine will be created for
+ // Max amount of memory that engine will be allowed to consume, in bytes
+ size_t max_workspace_size_bytes,
+ const tensorflow::grappler::GraphProperties& graph_properties,
+ tensorflow::Graph* graph) {
+ tensorflow::EdgeSet subgraph_incoming_edges;
+ GetSubGraphIncomingEdges(*graph, subgraph_node_ids, &subgraph_incoming_edges);
+
+ std::vector<std::pair<int, int>> subgraph_inputs;
+
+ // Collect inputs by looking for incoming edges
+ for (const tensorflow::Edge* edge : subgraph_incoming_edges) {
+ subgraph_inputs.push_back({edge->src()->id(), edge->src_output()});
+ }
+ std::set<std::pair<int, int>> subgraph_outputs_set;
+ // Collect outputs referenced from output_names
+ auto output_name_to_index_map = BuildTensorNameMap(output_names);
+ for (int node_id : subgraph_node_ids) {
+ tensorflow::Node* node = graph->FindNodeId(node_id);
+ if (output_name_to_index_map.count(node->name())) {
+ for (int index : output_name_to_index_map.at(node->name())) {
+ subgraph_outputs_set.insert({node_id, index});
+ }
+ }
+ }
+ // Collect outputs referenced from outgoing edges
+ tensorflow::EdgeSet subgraph_outgoing_edges;
+ GetSubGraphOutgoingEdges(*graph, subgraph_node_ids, &subgraph_outgoing_edges);
+ for (const tensorflow::Edge* edge : subgraph_outgoing_edges) {
+ subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()});
+ }
+ // Impose an ordering on the outputs
+ std::vector<std::pair<int, int>> subgraph_outputs(
+ subgraph_outputs_set.begin(), subgraph_outputs_set.end());
+ // Build TensorRT node and add it to the graph
+ tensorflow::NodeDef trt_node_def;
+ TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(
+ *graph, subgraph_node_ids, subgraph_inputs, subgraph_outputs,
+ max_batch_size, max_workspace_size_bytes, graph_properties,
+ &trt_node_def));
+ tensorflow::Status status;
+ tensorflow::Node* trt_node = graph->AddNode(trt_node_def, &status);
+ TF_RETURN_IF_ERROR(status);
+
+ // Re-map outgoing edges to use the new TRT node instead of the orig subgraph
+ std::map<std::pair<int, int>, int> subgraph_edge_to_output_map;
+ for (size_t i = 0; i < subgraph_outputs.size(); ++i) {
+ subgraph_edge_to_output_map.insert({subgraph_outputs.at(i), i});
+ }
+ TF_RETURN_IF_ERROR(status);
+ for (const tensorflow::Edge* edge : subgraph_outgoing_edges) {
+ std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
+ int new_src_output = subgraph_edge_to_output_map.at(old_src);
+ TF_RETURN_IF_ERROR(graph->UpdateEdge(trt_node, new_src_output, edge->dst(),
+ edge->dst_input()));
+ }
+ // Remove the original subgraph
+ for (int node_id : subgraph_node_ids) {
+ tensorflow::Node* node = graph->FindNodeId(node_id);
+ // Don't remove the input placeholders
+ if (node->type_string() == "Placeholder") {
+ continue;
+ }
+ graph->RemoveNode(node);
+ }
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status BuildNodeMap(
+ const tensorflow::Graph& graph,
+ std::unordered_map<string, tensorflow::Node*>* node_map) {
+ for (auto* node : graph.op_nodes()) {
+ if (!node_map->insert({node->name(), node}).second) {
+ return tensorflow::errors::AlreadyExists(
+ "Node name is not unique in graph: " + node->name());
+ }
+ }
+ return tensorflow::Status::OK();
+}
+
+} // namespace
+
+tensorflow::Status ConvertGraphDefToTensorRT(
+ const tensorflow::GraphDef& graph_def,
+ const std::vector<string>& output_names, size_t max_batch_size,
+ size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def) {
+ // Optimization pass
+ tensorflow::grappler::GrapplerItem item;
+ item.fetch = output_names;
+ tensorflow::GraphDef gdef;
+
+ // Layout optimization
+ item.graph = graph_def;
+ tensorflow::grappler::LayoutOptimizer optimizer;
+ tensorflow::grappler::Cluster* cluster;
+
+ // Virtual cluster
+ tensorflow::DeviceProperties device_properties;
+ device_properties.set_type("GPU");
+ device_properties.mutable_environment()->insert({"architecture", "6"});
+ cluster =
+ new tensorflow::grappler::VirtualCluster({{"/GPU:0", device_properties}});
+
+ TF_RETURN_IF_ERROR(optimizer.Optimize(cluster, item, &gdef));
+
+ // Constant folding
+ item.graph = gdef;
+ tensorflow::grappler::ConstantFolding fold(nullptr);
+ TF_RETURN_IF_ERROR(fold.Optimize(nullptr, item, &gdef));
+
+ // AJ refactoring shape inference through grappler/GraphProperties.
+ tensorflow::grappler::GraphProperties static_graph_properties(item);
+ TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(false));
+
+ // Build full graph
+ tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
+ gdef.library());
+ tensorflow::Graph graph(flib);
+ TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
+ tensorflow::GraphConstructorOptions(), gdef, &graph));
+
+ // Segment the graph into subgraphs that can be converted to TensorRT
+ tensorflow::tensorrt::segment::SegmentOptions segment_options;
+
+ // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
+ for (auto node : output_names) {
+ segment_options.exclude_node_list.insert(node);
+ }
+
+ // TODO(sami): this should be passed as a knob!!!!
+ segment_options.minimum_segment_size = 2;
+ tensorflow::tensorrt::segment::SegmentNodesVector segments;
+ TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
+ gdef, IsTensorRTCandidate, segment_options, &segments));
+ if (segments.size() > 1) {
+ VLOG(0) << "MULTIPLE tensorrt candidate conversion: " << segments.size();
+ }
+ std::unordered_map<string, tensorflow::Node*> node_map;
+ TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
+ for (const std::set<string>& subgraph_node_names : segments) {
+ std::set<int> subgraph_node_ids;
+ for (const string& node_name : subgraph_node_names) {
+ subgraph_node_ids.insert(node_map.at(node_name)->id());
+ }
+ TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRT(
+ output_names, subgraph_node_ids, max_batch_size,
+ max_workspace_size_bytes, static_graph_properties, &graph));
+ }
+ graph.ToGraphDef(new_graph_def);
+ return tensorflow::Status::OK();
+}
+
+} // namespace convert
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h
new file mode 100644
index 0000000000..154ad3f2e8
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
+
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
+namespace tensorflow {
+namespace tensorrt {
+namespace convert {
+
+// max_batch_size: the maximum batch size which can be used for inference; the
+// engine is optimized for inference runs with this batch size.
+// max_workspace_size_bytes: the upper bound of memory allowance for engine
+// building.
+tensorflow::Status ConvertGraphDefToTensorRT(
+ const tensorflow::GraphDef& graph_def,
+ const std::vector<string>& output_names, size_t max_batch_size,
+ size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def);
+
+} // namespace convert
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
new file mode 100644
index 0000000000..4003ba056d
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -0,0 +1,1601 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
+
+#include <algorithm>
+#include <list>
+#include <map>
+#include <memory>
+#include <set>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h" // NOLINT
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/tensor_coding.h"
+#include "tensorflow/core/platform/types.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+#include "tensorrt/include/NvInfer.h"
+
+// Check if the types are equal. Cast to int first so that the failure log
+// message works.
+#define CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2)
+
+namespace tensorflow {
+namespace tensorrt {
+namespace convert {
+
+namespace {
+
+inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype,
+ nvinfer1::DataType* trt_dtype) {
+ switch (tf_dtype) {
+ case tensorflow::DataType::DT_FLOAT:
+ *trt_dtype = nvinfer1::DataType::kFLOAT;
+ break;
+ case tensorflow::DataType::DT_INT8:
+ *trt_dtype = nvinfer1::DataType::kINT8;
+ break;
+ case tensorflow::DataType::DT_HALF:
+ *trt_dtype = nvinfer1::DataType::kHALF;
+ break;
+ default:
+ return tensorflow::errors::InvalidArgument("Unsupported data type");
+ }
+ return tensorflow::Status::OK();
+}
+
+inline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) {
+ nvinfer1::Dims dims;
+ dims.nbDims = tensor.dims();
+ for (int i = 0; i < dims.nbDims; i++) {
+ dims.d[i] = tensor.dim_size(i);
+ }
+ return dims;
+}
+
+inline int64_t GetShapeSize(nvinfer1::Dims shape) {
+ // Returns total number of elements in shape
+ int64_t count = 1;
+ for (int d = 0; d < shape.nbDims; ++d) {
+ count *= shape.d[d];
+ }
+ return count;
+}
+
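+// Computes per-dimension (left, right) padding so that the output size equals
+// ceil(input / stride), matching TensorFlow's SAME padding.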
+static std::vector<std::pair<int, int>> CreateSamePadding(
+ const nvinfer1::DimsHW& stride, const nvinfer1::DimsHW& kernel,
+ const std::vector<int64_t>& input_dims) {
+ std::vector<std::pair<int, int>> padding(input_dims.size());
+ CHECK_EQ((size_t)stride.nbDims, input_dims.size()); // TODO(jie): N+C? NC+?
+
+ for (size_t i = 0; i < input_dims.size(); ++i) {
+ // Formula to calculate the padding
+ int p = ((input_dims[i] - 1) / stride.d[i]) * stride.d[i] + kernel.d[i] -
+ input_dims[i];
+ p = (p > 0) ? p : 0;
+
+ // Right precedence padding, like in TensorFlow
+ int left = p / 2;
+ int right = p - left;
+
+    VLOG(2) << "PADDING_" << i << " pre: " << left << ", post: " << right
+            << ", params: " << input_dims[i] << ", " << stride.d[i] << ", "
+            << "kernel: " << kernel.d[i];
+ padding[i] = {left, right};
+ }
+ return padding;
+}
+
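+// Holds a typed, shaped weight blob for TensorRT. The values are not owned
+// and may be empty; convertible to nvinfer1::Weights.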
+class TRT_ShapedWeights {
+ public:
+ TRT_ShapedWeights(tensorflow::DataType type, const void* values,
+ nvinfer1::Dims shape)
+ : shape_(shape), type_(type), values_(values), empty_weight_flag_(false) {
+ // Note: this->shape.type[] is not used
+ }
+
+ explicit TRT_ShapedWeights(tensorflow::DataType type)
+ : shape_(), type_(type), values_(nullptr), empty_weight_flag_(true) {}
+
+ TRT_ShapedWeights(const TRT_ShapedWeights& rhs)
+ : shape_(rhs.shape_),
+ type_(rhs.type_),
+ values_(rhs.values_),
+ empty_weight_flag_(rhs.empty_weight_flag_) {}
+
+ int64_t count() const {
+ int64_t c = 1;
+ for (int i = 0; i < shape_.nbDims; i++) c *= shape_.d[i];
+ return c;
+ }
+
+ nvinfer1::Weights GetWeightsForTRT() const {
+ nvinfer1::DataType trt_type(nvinfer1::DataType::kFLOAT);
+ TF_CHECK_OK(ConvertDType(type_, &trt_type));
+ if (empty_weight_flag_) return nvinfer1::Weights{trt_type, nullptr, 0};
+
+ // Note: this->shape.type[] is not used
+ return nvinfer1::Weights{trt_type, GetValues(), GetShapeSize(shape_)};
+ }
+
+ const void* GetValues() const { return values_; }
+
+ void SetValues(const void* values) { values_ = values; }
+
+ size_t size_bytes() const {
+ int type_size = tensorflow::DataTypeSize(this->type_);
+ return this->count() * type_size;
+ }
+
+ // Default converter
+ operator nvinfer1::Weights() const { return GetWeightsForTRT(); }
+
+ nvinfer1::Dims shape_;
+ tensorflow::DataType type_;
+
+ private:
+ const void* values_;
+ bool empty_weight_flag_;
+};
+
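+// A tagged union of either a TensorRT ITensor or TRT_ShapedWeights, the two
+// kinds of values passed between op converters.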
+class TRT_TensorOrWeights {
+ public:
+ explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor)
+ : tensor_(tensor), weights_(DT_FLOAT), variant_(TRT_NODE_TENSOR) {}
+ explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights)
+ : tensor_(nullptr), weights_(weights), variant_(TRT_NODE_WEIGHTS) {}
+ TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs)
+ : tensor_(rhs.tensor_), weights_(rhs.weights_), variant_(rhs.variant_) {}
+ ~TRT_TensorOrWeights() {}
+
+ bool is_tensor() const { return variant_ == TRT_NODE_TENSOR; }
+ bool is_weights() const { return variant_ == TRT_NODE_WEIGHTS; }
+
+ nvinfer1::ITensor* tensor() {
+ CHECK_EQ(is_tensor(), true);
+ return tensor_;
+ }
+ const nvinfer1::ITensor* tensor() const {
+ CHECK_EQ(is_tensor(), true);
+ return tensor_;
+ }
+ TRT_ShapedWeights& weights() {
+ CHECK_EQ(is_weights(), true);
+ return weights_;
+ }
+ const TRT_ShapedWeights& weights() const {
+ CHECK_EQ(is_weights(), true);
+ return weights_;
+ }
+ nvinfer1::Dims shape() const {
+ if (is_tensor()) {
+ return tensor()->getDimensions();
+ } else {
+ return weights().shape_;
+ }
+ }
+
+ private:
+ nvinfer1::ITensor* tensor_;
+ TRT_ShapedWeights weights_;
+ enum { TRT_NODE_TENSOR, TRT_NODE_WEIGHTS } variant_;
+};
+
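+// Read-only view over a NodeDef's attribute map with typed accessors.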
+class TFAttrs {
+ public:
+ explicit TFAttrs(const tensorflow::NodeDef& tf_node) {
+ for (const auto& attr : tf_node.attr()) {
+ attrs_.insert({attr.first, &attr.second});
+ }
+ }
+ bool count(string key) const { return attrs_.count(key); }
+ tensorflow::AttrValue const* at(string key) const {
+ if (!attrs_.count(key)) {
+ LOG(FATAL) << "Attribute not found: " << key;
+ }
+ return attrs_.at(key);
+ }
+ template <typename T>
+ T get(string key) const;
+ template <typename T>
+ T get(string key, const T& default_value) const {
+ return attrs_.count(key) ? this->get<T>(key) : default_value;
+ }
+
+ private:
+ typedef std::map<string, tensorflow::AttrValue const*> AttrMap;
+ AttrMap attrs_;
+};
+
+template <>
+string TFAttrs::get<string>(string key) const {
+ return this->at(key)->s();
+}
+
+template <>
+std::vector<int> TFAttrs::get<std::vector<int>>(string key) const {
+ auto attr = this->at(key)->list().i();
+ return std::vector<int>(attr.begin(), attr.end());
+}
+
+template <>
+nvinfer1::Dims TFAttrs::get<nvinfer1::Dims>(string key) const {
+ auto values = this->get<std::vector<int>>(key);
+ nvinfer1::Dims dims;
+ dims.nbDims = values.size();
+ std::copy(values.begin(), values.end(), dims.d);
+ // Note: No dimension type information is included
+ return dims;
+}
+
+template <>
+nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(string key) const {
+ nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
+ TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype));
+ return trt_dtype;
+}
+
+template <>
+tensorflow::DataType TFAttrs::get<tensorflow::DataType>(string key) const {
+ return this->at(key)->type();
+}
+
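+// Copies a 4-D array from idata to odata using the given input and output
+// strides; used below to reorder convolution weight layouts.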
+template <typename T>
+void Reorder4(nvinfer1::DimsNCHW shape, const T* idata,
+ nvinfer1::DimsNCHW istrides, T* odata,
+ nvinfer1::DimsNCHW ostrides) {
+ for (int n = 0; n < shape.n(); ++n) {
+ for (int c = 0; c < shape.c(); ++c) {
+ for (int h = 0; h < shape.h(); ++h) {
+ for (int w = 0; w < shape.w(); ++w) {
+ odata[n * ostrides.n() + c * ostrides.c() + h * ostrides.h() +
+ w * ostrides.w()] = idata[n * istrides.n() + c * istrides.c() +
+ h * istrides.h() + w * istrides.w()];
+ }
+ }
+ }
+ }
+}
+
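+// Reorders convolution weights from TensorFlow's RSCK layout (rows, columns,
+// input channels, output channels) to TensorRT's KCRS layout.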
+void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
+ TRT_ShapedWeights* oweights) {
+ CHECK_EQ(iweights.type_, oweights->type_);
+ CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
+ int r = iweights.shape_.d[0];
+ int s = iweights.shape_.d[1];
+ int c = iweights.shape_.d[2];
+ int k = iweights.shape_.d[3];
+ oweights->shape_.d[0] = k;
+ oweights->shape_.d[1] = c;
+ oweights->shape_.d[2] = r;
+ oweights->shape_.d[3] = s;
+ nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k};
+ nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1};
+ switch (iweights.type_) {
+ case tensorflow::DataType::DT_FLOAT:
+ Reorder4({k, c, r, s}, static_cast<float const*>(iweights.GetValues()),
+ istrides,
+ static_cast<float*>(const_cast<void*>(oweights->GetValues())),
+ ostrides);
+ break;
+ default:
+      LOG(FATAL) << "Unsupported data type in weight reorder";
+ }
+}
+
+struct InferDeleter {
+ template <typename T>
+ void operator()(T* obj) const {
+ if (obj) {
+ obj->destroy();
+ }
+ }
+};
+
+template <typename T>
+inline std::shared_ptr<T> infer_object(T* obj) {
+ return std::shared_ptr<T>(obj, InferDeleter());
+}
+
+// Forward declaration for the OpConverter alias below.
+class Converter;
+
+using OpConverter =
+ std::function<tensorflow::Status(Converter&, const tensorflow::NodeDef&,
+ std::vector<TRT_TensorOrWeights> const&,
+ std::vector<TRT_TensorOrWeights>*)>;
+
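+// Drives conversion of a TF subgraph: tracks the tensors and weights produced
+// so far, owns temporary weight buffers, and dispatches each NodeDef to the
+// registered converter for its op.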
+class Converter {
+ std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;
+ std::unordered_map<string, OpConverter> op_registry_;
+ nvinfer1::INetworkDefinition* trt_network_;
+ std::list<std::vector<uint8_t>> temp_bufs_;
+
+ void register_op_converters();
+
+ std::vector<TRT_TensorOrWeights> get_inputs(
+ const tensorflow::NodeDef& node_def) {
+ std::vector<TRT_TensorOrWeights> inputs;
+ for (const auto& input_name : node_def.input()) {
+ VLOG(2) << "Retrieve input: " << input_name;
+ inputs.push_back(trt_tensors_.at(input_name));
+ }
+ return inputs;
+ }
+
+ public:
+ explicit Converter(nvinfer1::INetworkDefinition* trt_network)
+ : trt_network_(trt_network) {
+ this->register_op_converters();
+ }
+
+ TRT_ShapedWeights get_temp_weights(tensorflow::DataType type,
+ nvinfer1::Dims shape) {
+ TRT_ShapedWeights weights(type, nullptr, shape);
+ // TODO(jie): check weights size_bytes. 0 means type error
+ temp_bufs_.push_back(std::vector<uint8_t>(weights.size_bytes()));
+ weights.SetValues(temp_bufs_.back().data());
+ return weights;
+ }
+
+ TRT_ShapedWeights get_temp_weights_like(const TRT_ShapedWeights& weights) {
+ return this->get_temp_weights(weights.type_, weights.shape_);
+ }
+
+ tensorflow::Status convert_node(const tensorflow::NodeDef& node_def) {
+ std::vector<TRT_TensorOrWeights> inputs = this->get_inputs(node_def);
+ string op = node_def.op();
+ if (!op_registry_.count(op)) {
+ return tensorflow::errors::Unimplemented(
+ "No converter registered for op: " + op);
+ }
+ OpConverter op_converter = op_registry_.at(op);
+ std::vector<TRT_TensorOrWeights> outputs;
+ TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs));
+ for (size_t i = 0; i < outputs.size(); ++i) {
+ TRT_TensorOrWeights output = outputs.at(i);
+ // TODO(jie): tf protobuf seems to be omitting the :0 suffix
+ string output_name = node_def.name();
+ if (i != 0) output_name = output_name + ":" + std::to_string(i);
+ if (output.is_tensor()) {
+ output.tensor()->setName(output_name.c_str());
+ }
+ VLOG(2) << "Write out tensor: " << output_name;
+ if (!trt_tensors_.insert({output_name, output}).second) {
+ return tensorflow::errors::AlreadyExists(
+ "Output tensor already exists for op: " + op);
+ }
+ }
+ return tensorflow::Status::OK();
+ }
+
+ nvinfer1::INetworkDefinition* network() { return trt_network_; }
+
+ TRT_TensorOrWeights get_tensor(string name) {
+ if (!trt_tensors_.count(name)) {
+ return TRT_TensorOrWeights(nullptr);
+ }
+ return trt_tensors_.at(name);
+ }
+
+ bool insert_input_tensor(string name, nvinfer1::ITensor* tensor) {
+ return trt_tensors_.insert({name, TRT_TensorOrWeights(tensor)}).second;
+ }
+
+ nvinfer1::ITensor* TransposeTensor(nvinfer1::ITensor* input_tensor,
+ std::vector<int> order) {
+ auto dims = input_tensor->getDimensions();
+
+ // TODO(jie): change the return to status and properly exit
+ if (order.size() - 1 != size_t(dims.nbDims))
+ LOG(ERROR) << "Dimension does not match, fail gracefully";
+
+ nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor);
+ nvinfer1::Permutation permutation;
+ for (int32_t i = 0; i < dims.nbDims; ++i) {
+ permutation.order[i] = order[i + 1] - 1;
+ }
+ layer->setFirstTranspose(permutation);
+
+ nvinfer1::Dims reshape_dims;
+ reshape_dims.nbDims = dims.nbDims;
+ for (int32_t i = 0; i < reshape_dims.nbDims; ++i) {
+ reshape_dims.d[i] = 0;
+ reshape_dims.type[i] = dims.type[i];
+ }
+ layer->setReshapeDimensions(reshape_dims);
+ return layer->getOutput(0);
+ }
+};
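+
+// Naming convention used by convert_node() above: a node's output 0 is keyed
+// by the bare node name and output i > 0 by "<name>:<i>", mirroring TF tensor
+// names (where ":0" is implied). get_inputs() assumes NodeDef.input() follows
+// the same convention, which keeps lookups in trt_tensors_ consistent.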
+
+// ****************************************************************************
+// Constant folding functions.
+// TODO(jie): once the optimizer kicks in, constant folding should happen
+// there.
+// ****************************************************************************
+struct LambdaFactory {
+ enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB };
+ OP_CATEGORY op;
+
+ template <typename T>
+ std::function<T(T)> unary() {
+ switch (op) {
+ case OP_CATEGORY::RSQRT: {
+ VLOG(2) << "RSQRT GETS DONE";
+ return [](T t) -> T { return 1.0 / std::sqrt(t); };
+ }
+ case OP_CATEGORY::NEG:
+ return [](T t) -> T { return -t; };
+ default:
+ VLOG(2) << "Not supported op for unary: " << static_cast<int>(op);
+ return nullptr;
+ }
+ }
+
+ template <typename T>
+ std::function<T(T, T)> binary() {
+ switch (op) {
+ case OP_CATEGORY::ADD:
+ return [](T l, T r) -> T { return l + r; };
+ case OP_CATEGORY::SUB:
+ return [](T l, T r) -> T { return l - r; };
+ case OP_CATEGORY::MUL:
+ return [](T l, T r) -> T { return l * r; };
+ default:
+ LOG(WARNING) << "Not supported op for binary: " << static_cast<int>(op);
+ }
+ return [](T l, T r) -> T {
+ LOG(FATAL) << "Unsupported op type ";
+ return l;
+ };
+ }
+
+ template <typename T>
+ std::function<T(T)> broadcast_r(T val) {
+ VLOG(2) << "LAMBDA VAL : " << val;
+ switch (op) {
+ case OP_CATEGORY::ADD:
+ return [val](T l) -> T {
+ VLOG(2) << "LAMBDA VAL : " << val;
+ return l + val;
+ };
+ case OP_CATEGORY::SUB:
+ return [val](T l) -> T {
+ VLOG(2) << "LAMBDA VAL : " << val;
+ return l - val;
+ };
+ case OP_CATEGORY::MUL:
+ return [val](T l) -> T {
+ VLOG(2) << "LAMBDA VAL : " << val;
+ return l * val;
+ };
+ default:
+ LOG(WARNING) << "Not supported op for binary: " << static_cast<int>(op);
+ }
+ return [val](T l) -> T {
+ LOG(FATAL) << "Unsupported op type ";
+ return l;
+ };
+ }
+
+ template <typename T>
+ std::function<T(T)> broadcast_l(T val) {
+ VLOG(2) << "LAMBDA VAL : " << val;
+ switch (op) {
+ case OP_CATEGORY::ADD:
+ return [val](T l) -> T {
+ VLOG(2) << "LAMBDA VAL : " << val;
+ return val + l;
+ };
+ case OP_CATEGORY::SUB:
+ return [val](T l) -> T {
+ VLOG(2) << "LAMBDA VAL : " << val;
+ return val - l;
+ };
+ case OP_CATEGORY::MUL:
+ return [val](T l) -> T {
+ VLOG(2) << "LAMBDA VAL : " << val;
+ return val * l;
+ };
+ default:
+ LOG(ERROR) << "Not supported op for binary: " << static_cast<int>(op);
+ }
+ return [val](T l) -> T {
+ LOG(FATAL) << "Unsupported op type ";
+ return l;
+ };
+ }
+};
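+
+// Sketch of how the factory drives constant folding (see UnaryCompute below):
+// folding Rsqrt over float weights is essentially
+//   LambdaFactory op;
+//   op.op = LambdaFactory::OP_CATEGORY::RSQRT;
+//   std::transform(in, in + n, out, op.unary<float>());
+// with broadcast_l/broadcast_r covering the scalar-operand cases.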
+
+tensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights,
+ TRT_ShapedWeights* oweights,
+ LambdaFactory unary_op) {
+ CHECK_EQ(iweights.type_, oweights->type_);
+ switch (iweights.type_) {
+ case tensorflow::DataType::DT_FLOAT: {
+ auto inp = static_cast<float const*>(iweights.GetValues());
+ auto oup = static_cast<float*>(const_cast<void*>(oweights->GetValues()));
+ std::transform(inp, inp + iweights.count(), oup, unary_op.unary<float>());
+ break;
+ }
+ default:
+ return tensorflow::errors::Unimplemented(
+ "Data type not supported: " +
+ tensorflow::DataTypeString(iweights.type_));
+ }
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l,
+ const TRT_ShapedWeights& iweights_r,
+ TRT_ShapedWeights* oweights,
+ LambdaFactory binary_op) {
+  // Assume iweights_l.type_ == iweights_r.type_
+  CHECK_EQ(iweights_l.type_, oweights->type_);
+  CHECK_EQ(iweights_r.type_, oweights->type_);
+  VLOG(2) << "BinaryCompute: operand types checked";
+
+ switch (iweights_l.type_) {
+ case tensorflow::DataType::DT_FLOAT: {
+ auto inp_l = static_cast<const float*>(iweights_l.GetValues());
+ auto inp_r = static_cast<const float*>(iweights_r.GetValues());
+ auto oup = static_cast<float*>(const_cast<void*>(oweights->GetValues()));
+
+ if (iweights_l.count() != iweights_r.count()) {
+        // We only support broadcasting of rank-zero (scalar) operands.
+        if (iweights_l.count() == 1) {
+          VLOG(2) << "Broadcasting left-hand scalar: " << (*inp_l);
+ std::transform(inp_r, inp_r + iweights_r.count(), oup,
+ binary_op.broadcast_l<float>(*inp_l));
+ } else if (iweights_r.count() == 1) {
+ VLOG(2) << "I bet it is not working!" << (*inp_r);
+ std::transform(inp_l, inp_l + iweights_l.count(), oup,
+ binary_op.broadcast_r<float>(*inp_r));
+ } else {
+ return tensorflow::errors::Unimplemented(
+ "Binary op with non-rankZero broadcast not supported");
+ }
+ } else {
+ std::transform(inp_l, inp_l + iweights_l.count(), inp_r, oup,
+ binary_op.binary<float>());
+ }
+ break;
+ }
+ default:
+ return tensorflow::errors::Unimplemented(
+ "Data type not supported: " +
+ tensorflow::DataTypeString(iweights_l.type_));
+ }
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConstantFoldUnary(
+ Converter& ctx, const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ TRT_ShapedWeights weights_input = inputs.at(0).weights();
+
+ // Allocate output weights
+ TRT_ShapedWeights weights_output = ctx.get_temp_weights_like(weights_input);
+
+  // FIXME: assumes the output type matches the input weights.
+  // Get TRT type & shape; this may need to move into the rsqrt block later.
+  // Check type consistency.
+ CHECK_EQ(weights_input.type_,
+ TFAttrs(node_def).get<tensorflow::DataType>("T"));
+
+  // TODO(jie): use a switch on the op name.
+ LambdaFactory unary_op;
+ if (node_def.op() == "Rsqrt") {
+ // Compute rsqrt
+ unary_op.op = LambdaFactory::OP_CATEGORY::RSQRT;
+ auto ret = UnaryCompute(weights_input, &weights_output, unary_op);
+    // Pass the output
+ if (ret == tensorflow::Status::OK()) {
+ outputs->push_back(TRT_TensorOrWeights(weights_output));
+ }
+ return ret;
+ } else {
+ return tensorflow::errors::Unimplemented("Binary op not supported: " +
+ node_def.op());
+ }
+}
+
+// TODO(jie,ben): broadcast is needed yet not implemented.
+// Let's get the simple cases working first. Maybe we should fall back to the
+// TF approach for constant folding.
+tensorflow::Status ConstantFoldBinary(
+ Converter& ctx, const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ TRT_ShapedWeights weights_input_l = inputs.at(0).weights();
+ TRT_ShapedWeights weights_input_r = inputs.at(1).weights();
+
+ // Check type consistency
+ CHECK_EQ(weights_input_l.type_, weights_input_r.type_);
+
+ if (weights_input_l.shape_.nbDims != weights_input_r.shape_.nbDims)
+ return tensorflow::errors::Unimplemented(
+ "Binary op implicit broadcast not supported: " + node_def.op());
+
+ // TODO(jie): constant fold should really fall back to TF.
+ int nb_dims = weights_input_l.shape_.nbDims;
+ nvinfer1::Dims output_shape;
+ output_shape.nbDims = nb_dims;
+ VLOG(2) << "nb_dims: " << nb_dims
+ << ", the other: " << weights_input_r.shape_.nbDims;
+ for (int i = 0; i < nb_dims; i++) {
+ if (weights_input_l.shape_.d[i] == weights_input_r.shape_.d[i]) {
+ output_shape.d[i] = weights_input_l.shape_.d[i];
+ } else if (weights_input_l.shape_.d[i] == 1 ||
+ weights_input_r.shape_.d[i] == 1) {
+ output_shape.d[i] =
+ std::max(weights_input_l.shape_.d[i], weights_input_r.shape_.d[i]);
+ } else {
+ return tensorflow::errors::Unimplemented(
+ "Binary op with incompatible shape at, " + node_def.op());
+ }
+ VLOG(2) << "left: " << weights_input_l.shape_.d[i]
+ << "right: " << weights_input_r.shape_.d[i]
+ << "output: " << output_shape.d[i];
+ }
+
+  // FIXME: assumes the output type matches the input weights.
+  TFAttrs attrs(node_def);
+  tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("T");
+
+ // Allocate output weights
+ TRT_ShapedWeights weights_output = ctx.get_temp_weights(dtype, output_shape);
+
+  // TODO(jie): use a switch on the op name.
+ LambdaFactory binary_op;
+ if (node_def.op() == "Sub") {
+ binary_op.op = LambdaFactory::OP_CATEGORY::SUB;
+ } else if (node_def.op() == "Mul") {
+ binary_op.op = LambdaFactory::OP_CATEGORY::MUL;
+ } else if (node_def.op() == "Add") {
+ binary_op.op = LambdaFactory::OP_CATEGORY::ADD;
+ } else {
+ return tensorflow::errors::Unimplemented("Binary op not supported: " +
+ node_def.op());
+ }
+ auto ret = BinaryCompute(weights_input_l, weights_input_r, &weights_output,
+ binary_op);
+
+ // Pass the output
+ if (ret == tensorflow::Status::OK()) {
+ outputs->push_back(TRT_TensorOrWeights(weights_output));
+ }
+
+ return ret;
+}
+
+// TODO(jie): broadcast is needed yet not implemented.
+// Only implemented channel wise for the time being
+tensorflow::Status BinaryTensorOpWeight(
+ Converter& ctx, const tensorflow::NodeDef& node_def,
+ const nvinfer1::ITensor* tensor, TRT_ShapedWeights weights,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+  // FIXME: assumes the weight type matches the tensor type.
+
+ // Check type consistency
+ auto dtype = TFAttrs(node_def).get<nvinfer1::DataType>("T");
+ CHECK_EQ_TYPE(tensor->getType(), dtype); // Cast to int for error messages
+ nvinfer1::DataType ttype;
+ TF_CHECK_OK(ConvertDType(weights.type_, &ttype));
+ CHECK_EQ_TYPE(ttype, dtype); // Cast to int for error message
+
+ // Check scale mode
+ auto dims_w = weights.shape_;
+ auto dims_t = tensor->getDimensions();
+
+ // Default to channel-wise
+ auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
+
+ if (weights.count() == 1) {
+ VLOG(2) << "UNIFORM";
+ scale_mode = nvinfer1::ScaleMode::kUNIFORM;
+ } else {
+    // No broadcasting on the batch dimension.
+    assert(dims_w.d[0] == 1);
+
+    // Broadcasting on the channel dimension alone is only allowed in kUNIFORM.
+    assert(dims_w.d[1] == dims_t.d[0]);
+    assert(dims_w.nbDims == dims_t.nbDims);
+
+    // Default is element-wise; fall back to channel-wise when a trailing
+    // dimension differs.
+ for (int i = 2; i < dims_w.nbDims; i++) {
+ if (dims_w.d[i] != dims_t.d[i - 1]) {
+ scale_mode = nvinfer1::ScaleMode::kCHANNEL;
+ break;
+ }
+ }
+    if (scale_mode == nvinfer1::ScaleMode::kELEMENTWISE) {
+ for (int i = 2; i < dims_w.nbDims; i++) {
+ if (dims_w.d[i] != 1)
+ return tensorflow::errors::InvalidArgument(
+ "Weight shape not compatible at, " + node_def.name());
+ }
+ }
+ }
+
+ // Prepare weights
+ TRT_ShapedWeights shift_weights(weights.type_);
+ TRT_ShapedWeights scale_weights(weights.type_);
+ TRT_ShapedWeights power_weights(weights.type_);
+
+  // TODO(jie): use a switch on the op name.
+ if (node_def.op() == "Sub") {
+ TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights);
+ LambdaFactory unary_op;
+ unary_op.op = LambdaFactory::OP_CATEGORY::NEG;
+ TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op));
+ shift_weights = neg_weights;
+ } else if (node_def.op() == "Mul") {
+ scale_weights = weights;
+ } else if (node_def.op() == "Add") {
+ shift_weights = weights;
+ } else {
+ return tensorflow::errors::Unimplemented("Binary op not supported: " +
+ node_def.op());
+ }
+
+ nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
+ *const_cast<nvinfer1::ITensor*>(tensor), scale_mode, shift_weights,
+ scale_weights, power_weights);
+
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+
+ // Pass the output
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
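+
+// Design note: TensorRT's IScaleLayer computes (input * scale + shift)^power
+// element-wise, per channel, or uniformly depending on the scale mode. That
+// is why Add/Sub/Mul against constant weights collapse into one scale layer:
+// Add fills shift, Mul fills scale, and Sub negates the weights first and
+// uses them as shift, leaving the unused slots as empty weights.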
+
+tensorflow::Status BinaryTensorOpTensor(
+ Converter& ctx, const tensorflow::NodeDef& node_def,
+ const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
+ {"Add", nvinfer1::ElementWiseOperation::kSUM},
+ {"Mul", nvinfer1::ElementWiseOperation::kPROD},
+ // {"max", nvinfer1::ElementWiseOperation::kMAX},
+ // {"min", nvinfer1::ElementWiseOperation::kMIN},
+ {"Sub", nvinfer1::ElementWiseOperation::kSUB},
+ {"Div", nvinfer1::ElementWiseOperation::kDIV},
+ };
+
+  // FIXME: assumes both operand types match the attribute type.
+  TFAttrs attrs(node_def);
+ nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");
+
+ // Check type consistency
+ CHECK_EQ_TYPE(tensor_l->getType(), dtype);
+ CHECK_EQ_TYPE(tensor_r->getType(), dtype);
+ auto op_pair = ops.find(node_def.op());
+ if (op_pair == ops.end())
+ return tensorflow::errors::Unimplemented(
+ "binary op: " + node_def.op() +
+ " not supported at: " + node_def.name());
+
+ nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise(
+ *const_cast<nvinfer1::ITensor*>(tensor_l),
+ *const_cast<nvinfer1::ITensor*>(tensor_r), op_pair->second);
+
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+
+ // Pass the output
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertPlaceholder(
+ Converter& ctx, const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ VLOG(2) << "Placeholder should have been replace already";
+ return tensorflow::errors::Unimplemented(", cannot convert Placeholder op");
+ // OK this make sense since we are supposed to replace it with input
+ TFAttrs attrs(node_def);
+ nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("dtype");
+ nvinfer1::Dims dims = attrs.get<nvinfer1::Dims>("shape");
+
+ dims.nbDims--;
+ for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1];
+
+ nvinfer1::ITensor* output =
+ ctx.network()->addInput(node_def.name().c_str(), dtype, dims);
+ if (!output) {
+ return tensorflow::errors::InvalidArgument("Failed to create Input layer");
+ }
+ outputs->push_back(TRT_TensorOrWeights(output));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertConv2D(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ const std::vector<TRT_TensorOrWeights>& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+ // TODO(jie): handle NHWC/NCHW transpose;
+ TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
+ TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck);
+ ReorderRSCKToKCRS(weights_rsck, &weights);
+ TRT_ShapedWeights biases(weights.type_);
+ int noutput = weights.shape_.d[0];
+ nvinfer1::DimsHW kernel_size;
+ kernel_size.h() = weights.shape_.d[2];
+ kernel_size.w() = weights.shape_.d[3];
+ TFAttrs attrs(node_def);
+
+ int h_index = 2;
+ int w_index = 3;
+ auto data_format = attrs.get<string>("data_format");
+ if (data_format == "NHWC") {
+ tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+ {0, 3, 1, 2});
+ h_index = 1;
+ w_index = 2;
+ // TODO(jie): transpose it
+ }
+
+ // TODO(jie): stride. (NHWC/NCHW)
+ auto tf_stride = attrs.get<std::vector<int>>("strides");
+ nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
+
+ auto tensor_dim = tensor->getDimensions();
+ std::vector<std::pair<int, int>> padding;
+ // TODO(jie): padding.
+ if (attrs.get<string>("padding") == "SAME") {
+ // This is NCHW tensor with no batch dimension.
+ // 1 -> h
+ // 2 -> w
+ padding = CreateSamePadding(
+ stride, kernel_size,
+ {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
+ } else {
+ padding = {{0, 0}, {0, 0}};
+ }
+
+ if (padding[0].first != padding[0].second ||
+ padding[1].first != padding[1].second) {
+ // TODO(jie): handle asymmetric padding
+ VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second
+ << padding[1].first << padding[1].second;
+
+ auto dim_before = tensor->getDimensions();
+ VLOG(2) << "TENSOR before: " << dim_before.d[0] << ", " << dim_before.d[1]
+            << ", " << dim_before.d[2] << ", " << dim_before.d[3];
+ auto pad_layer = ctx.network()->addPadding(
+ *const_cast<nvinfer1::ITensor*>(tensor),
+ nvinfer1::DimsHW(padding[0].first, padding[1].first),
+ nvinfer1::DimsHW(padding[0].second, padding[1].second));
+ padding = {{0, 0}, {0, 0}};
+ tensor = pad_layer->getOutput(0);
+ auto dim_after = tensor->getDimensions();
+ VLOG(2) << "TENSOR after: " << dim_after.d[0] << ", " << dim_after.d[1]
+            << ", " << dim_after.d[2] << ", " << dim_after.d[3];
+ }
+
+ nvinfer1::IConvolutionLayer* layer =
+ ctx.network()->addConvolution(*const_cast<nvinfer1::ITensor*>(tensor),
+ noutput, kernel_size, weights, biases);
+
+ layer->setStride(stride);
+ layer->setPadding({padding[0].first, padding[1].first});
+ layer->setName(node_def.name().c_str());
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+
+ auto dim_after = output_tensor->getDimensions();
+ VLOG(2) << "TENSOR out: " << dim_after.d[0] << ", " << dim_after.d[1]
+          << ", " << dim_after.d[2] << ", " << dim_after.d[3];
+
+ if (data_format == "NHWC") {
+ // TODO(jie): transpose it back!
+ output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
+ } else {
+ VLOG(2) << "NCHW !!!!";
+ }
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
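+
+// CreateSamePadding (defined earlier in this file) is assumed to implement
+// TF's SAME rule: total = max((ceil(in / stride) - 1) * stride + kernel - in,
+// 0), split as pre = total / 2 and post = total - pre. For example, in = 5,
+// stride = 2, kernel = 2 gives total = 1, i.e. (0, 1); such asymmetric pairs
+// are exactly the cases routed through the explicit IPadding layer above.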
+
+tensorflow::Status ConvertPool(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+ TFAttrs attrs(node_def);
+
+ int h_index = 2;
+ int w_index = 3;
+ auto data_format = attrs.get<string>("data_format");
+ if (data_format == "NHWC") {
+ h_index = 1;
+ w_index = 2;
+ tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+ {0, 3, 1, 2});
+ } else {
+ VLOG(2) << "NCHW !!!!";
+ }
+ nvinfer1::PoolingType type;
+ // TODO(jie): support other pooling type
+ if (node_def.op() == "MaxPool")
+ type = nvinfer1::PoolingType::kMAX;
+ else
+ return tensorflow::errors::Unimplemented("Only supports Max pool");
+
+ // TODO(jie): NCHW
+ auto tf_stride = attrs.get<std::vector<int>>("strides");
+ nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
+
+ auto tf_kernel = attrs.get<std::vector<int>>("ksize");
+ nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]);
+
+ auto tensor_dim = tensor->getDimensions();
+ std::vector<std::pair<int, int>> padding;
+ // TODO(jie): padding.
+ if (attrs.get<string>("padding") == "SAME") {
+ // This is NCHW tensor with no batch dimension.
+ // 1 -> h
+ // 2 -> w
+ padding = CreateSamePadding(
+ stride, ksize,
+ {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
+ } else if (attrs.get<string>("padding") == "VALID") {
+ // No padding for valid padding here
+ VLOG(2) << "No padding added for VALID padding in pool" << node_def.name();
+ padding = {{0, 0}, {0, 0}};
+ } else {
+ return tensorflow::errors::Unimplemented(
+ "Current MaxPool cannot support padding other than SAME");
+ }
+
+ if (padding[0].first != padding[0].second ||
+ padding[1].first != padding[1].second) {
+ // TODO(jie): handle asymmetric padding
+ VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second
+ << padding[1].first << padding[1].second;
+ auto pad_layer = ctx.network()->addPadding(
+ *const_cast<nvinfer1::ITensor*>(tensor),
+ nvinfer1::DimsHW(padding[0].first, padding[1].first),
+ nvinfer1::DimsHW(padding[0].second, padding[1].second));
+ padding = {{0, 0}, {0, 0}};
+ tensor = pad_layer->getOutput(0);
+ }
+
+ nvinfer1::IPoolingLayer* layer = ctx.network()->addPooling(
+ *const_cast<nvinfer1::ITensor*>(tensor), type, ksize);
+
+ layer->setStride(stride);
+ layer->setPadding({padding[0].first, padding[1].first});
+ layer->setName(node_def.name().c_str());
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+
+ if (data_format == "NHWC") {
+ // TODO(jie): transpose it back!
+ output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
+ } else {
+ VLOG(2) << "NCHW !!!!";
+ }
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertActivation(
+ Converter& ctx, const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+ nvinfer1::IActivationLayer* layer = ctx.network()->addActivation(
+ *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ActivationType::kRELU);
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertScale(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
+ !inputs.at(1).is_weights())
+ return tensorflow::errors::Unimplemented(
+ "Only supports tensor op weight for now, at " + node_def.name());
+ // Implement tensor binaryOp weight [channel wise] for now;
+ nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+
+ // TODO(jie): handle NHWC/NCHW transpose;
+ TRT_ShapedWeights weights = inputs.at(1).weights();
+ TRT_ShapedWeights empty_weights(weights.type_);
+
+ TFAttrs attrs(node_def);
+
+ // Transpose NHWC
+ auto data_format = attrs.get<string>("data_format");
+ if (data_format == "NHWC") {
+ tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+ {0, 3, 1, 2});
+ // TODO(jie): transpose it
+ } else {
+ VLOG(2) << "NCHW !!!!";
+ }
+ nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
+ *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ScaleMode::kCHANNEL,
+ weights, empty_weights, empty_weights);
+
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+ if (data_format == "NHWC") {
+ // TODO(jie): transpose it back!
+ output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
+ } else {
+ VLOG(2) << "NCHW !!!!";
+ }
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertConst(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ const auto& weights_tensor = node_def.attr().at("value").tensor();
+
+ // Get trt type & shape
+ TFAttrs attrs(node_def);
+ const tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("dtype");
+
+ // Create shaped weights as output
+ tensorflow::Tensor tensor;
+ if (!tensor.FromProto(weights_tensor))
+ return tensorflow::errors::Internal("Cannot parse weight tensor proto: " +
+ node_def.name());
+
+ TRT_ShapedWeights weights(dtype);
+ if (!weights_tensor.float_val().empty()) {
+ VLOG(2) << "SCALAR!!!" << node_def.name();
+ nvinfer1::Dims scalar_shape;
+ if (tensor.dims() > 0) {
+ VLOG(2) << "Dimensions: " << tensor.dims();
+ weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
+ GetTensorShape(tensor));
+ } else {
+ VLOG(2) << "Dimensions: " << tensor.dims();
+ scalar_shape.nbDims = 1;
+ scalar_shape.d[0] = 1;
+ scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
+ for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) {
+ scalar_shape.d[i] = 0;
+ scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL;
+ }
+ weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
+ scalar_shape);
+ }
+ } else if (!weights_tensor.tensor_content().empty()) {
+ VLOG(2) << "TENSOR!!!" << node_def.name();
+ const auto& content = weights_tensor.tensor_content();
+
+ weights = ctx.get_temp_weights(dtype, GetTensorShape(tensor));
+ if (content.size() > 0) {
+ const int dtype_size = tensorflow::DataTypeSize(dtype);
+ CHECK_EQ(0, content.size() % dtype_size)
+ << "Tensor content size (" << content.size()
+ << ") is not a multiple of " << dtype_size;
+ port::CopyToArray(
+ content, static_cast<char*>(const_cast<void*>(weights.GetValues())));
+ }
+ } else {
+ return tensorflow::errors::Unimplemented(
+ "Not supported constant type, at " + node_def.name());
+ }
+ // Pass the output
+ outputs->push_back(TRT_TensorOrWeights(weights));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertIdentity(
+ Converter& ctx, const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ outputs->push_back(inputs.at(0));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertBinary(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ if (inputs.size() != 2)
+ return tensorflow::errors::FailedPrecondition(
+ "Binary ops require two tensor input, at " + node_def.name());
+
+ if (inputs.at(0).is_weights() && inputs.at(1).is_weights())
+ return ConstantFoldBinary(ctx, node_def, inputs, outputs);
+
+ if (inputs.at(0).is_tensor() && inputs.at(1).is_weights())
+ return BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(),
+ inputs.at(1).weights(), outputs);
+
+ if (inputs.at(0).is_weights() && inputs.at(1).is_tensor())
+ return BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(),
+ inputs.at(0).weights(), outputs);
+
+ if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor())
+ return BinaryTensorOpTensor(ctx, node_def, inputs.at(0).tensor(),
+ inputs.at(1).tensor(), outputs);
+
+ return tensorflow::errors::Unknown("Binary op input error, at " +
+ node_def.name());
+}
+
+tensorflow::Status ConvertUnary(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ if (inputs.size() != 1)
+ return tensorflow::errors::FailedPrecondition(
+ "Unary ops require single tensor input, at " + node_def.name());
+
+ if (inputs.at(0).is_weights())
+ return ConstantFoldUnary(ctx, node_def, inputs, outputs);
+ else if (inputs.at(0).is_tensor())
+ return tensorflow::errors::Unimplemented(
+ "Unary op for tensor not supported, at " + node_def.name());
+
+ return tensorflow::errors::Unknown("Binary op input error, at " +
+ node_def.name());
+}
+
+tensorflow::Status ConvertReduce(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
+ !inputs.at(1).is_weights())
+ return tensorflow::errors::InvalidArgument(
+ "Input expects tensor and weights, at" + node_def.name());
+
+ nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+ auto dims = tensor->getDimensions();
+ // Restore implicit batch dimension
+ int nb_dims = dims.nbDims + 1;
+
+ TRT_ShapedWeights index_list = inputs.at(1).weights();
+
+ TFAttrs attrs(node_def);
+ // TODO(jie): handle data type.
+ // Index type here is done through TF type, so I can leverage their
+ // EnumToDataType for my cast
+ auto index_type = attrs.get<tensorflow::DataType>("Tidx");
+
+ // Only expect to handle INT32 as attributes for now
+ if (index_type != tensorflow::DataType::DT_INT32)
+ return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32");
+ auto index_list_data =
+ static_cast<int*>(const_cast<void*>(index_list.GetValues()));
+
+ // Hack warning: have to fall back to pool layer since reduce is not in public
+ // TRT yet.
+ if (nb_dims != 4)
+ return tensorflow::errors::InvalidArgument(
+ "TRT only support reduce on 4 dimensional tensors, at" +
+ node_def.name());
+ if (index_list.count() > 2)
+ return tensorflow::errors::InvalidArgument(
+ "TRT cannot support reduce on more than 2 dimensions, at" +
+ node_def.name());
+
+ std::set<int> idx_set;
+  // We cannot pool over the channel dimension directly; the permutation flag
+  // marks a transpose that swaps it with a spatial axis.
+ int permuted_index = -1;
+ for (int i = 0; i < index_list.count(); i++) {
+ if (index_list_data[i] == 0)
+ return tensorflow::errors::InvalidArgument("TRT cannot reduce at 0, at" +
+ node_def.name());
+ if (index_list_data[i] == 1) permuted_index = 1;
+ idx_set.emplace(index_list_data[i]);
+ }
+
+ std::vector<int> permutation_order(nb_dims);
+ nvinfer1::DimsHW pool_kernel;
+ if (permuted_index == 1) {
+ for (int i = 2; i < nb_dims; i++) {
+ if (idx_set.count(i)) {
+ permuted_index = i;
+ break;
+ }
+ }
+ for (int i = 0; i < nb_dims; i++) permutation_order[i] = i;
+
+ permutation_order[permuted_index] = 1;
+ permutation_order[1] = permuted_index;
+
+ // Apply permutation before extracting dimension for pool_kernel
+ tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+ permutation_order);
+ }
+
+ // Apply permutation before extracting dimension for pool_kernel
+ pool_kernel.d[0] = (idx_set.count(2) || permuted_index == 2) ? dims.d[1] : 1;
+ pool_kernel.d[1] = (idx_set.count(3) || permuted_index == 3) ? dims.d[2] : 1;
+
+ nvinfer1::ITensor* output_tensor;
+
+ if (node_def.op() == "Mean") {
+ nvinfer1::IPoolingLayer* layer =
+ ctx.network()->addPooling(*const_cast<nvinfer1::ITensor*>(tensor),
+ nvinfer1::PoolingType::kAVERAGE, pool_kernel);
+ output_tensor = layer->getOutput(0);
+ } else {
+ return tensorflow::errors::Unimplemented(
+ "Op not supported " + node_def.op() + " , at " + node_def.name());
+ }
+ if (permuted_index != -1) {
+    // Transpose the result back to the original layout.
+ output_tensor = ctx.TransposeTensor(
+ const_cast<nvinfer1::ITensor*>(output_tensor), permutation_order);
+ }
+  outputs->push_back(TRT_TensorOrWeights(output_tensor));
+  return tensorflow::Status::OK();
+}
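+
+// Rationale (sketch): TensorRT exposes no public reduce layer at this point,
+// so Mean over spatial axes is emulated by average pooling whose window spans
+// the reduced dimensions, e.g. Mean over axes {2, 3} of an NCHW tensor uses a
+// kAVERAGE pool with kernel (H, W). Reducing axis 1 first swaps it with a
+// spatial axis via the transpose above, since pooling cannot touch channels.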
+
+tensorflow::Status ConvertPad(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
+ !inputs.at(1).is_weights())
+ return tensorflow::errors::InvalidArgument(
+ "Input expects tensor and weights, at" + node_def.name());
+
+ nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+ auto dims = tensor->getDimensions();
+ // Restore implicit batch dimension
+ int nb_dims = dims.nbDims + 1;
+
+ TRT_ShapedWeights pads = inputs.at(1).weights();
+
+ TFAttrs attrs(node_def);
+  // The padding index type is carried as a TF type so we can leverage
+  // EnumToDataType for the cast.
+ auto padding_type = attrs.get<tensorflow::DataType>("Tpaddings");
+ // TODO(jie): handle data type conversion for TRT?
+
+ if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2)
+ return tensorflow::errors::InvalidArgument(
+ "Pad only supports explicit padding on 4 dimensional tensor, at " +
+ node_def.name());
+
+ // Only expect to handle INT32 as attributes for now
+ if (padding_type != tensorflow::DataType::DT_INT32)
+ return tensorflow::errors::Unimplemented(
+ "Tpaddings supports only DT_INT32");
+ auto pad_data = static_cast<int*>(const_cast<void*>(pads.GetValues()));
+
+ std::vector<int32_t> pad_index;
+ for (int i = 0; i < nb_dims; i++) {
+ if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0)
+ pad_index.push_back(i);
+ }
+
+  // No padding at all; pass the input through.
+ if (pad_index.size() == 0) {
+ outputs->push_back(inputs.at(0));
+ return tensorflow::Status::OK();
+ }
+
+  // Only supports padding on at most 2 axes (TRT limitation GIE-2579).
+ if (pad_index.size() > 2)
+ return tensorflow::errors::InvalidArgument(
+ "Padding layer does not support padding on > 2");
+
+ // Padding on batch dimension is not supported
+ if (pad_index[0] == 0)
+ return tensorflow::errors::InvalidArgument(
+ "Padding layer does not support padding on batch dimension");
+
+  // Not fully general: simultaneous padding on dimensions 1 and 3 is not
+  // handled. TODO(jie): implement pad the way the UFF parser does.
+  if (pad_index.size() == 2 && pad_index[0] == 1 && pad_index[1] == 3)
+ return tensorflow::errors::Unimplemented(
+ "Padding layer does not support padding on dimension 1 and 3 yet");
+
+ bool legit_pad = true;
+ nvinfer1::DimsHW pre_padding(0, 0);
+ nvinfer1::DimsHW post_padding(0, 0);
+
+ std::vector<int32_t> permuted_pad_index(pad_index);
+ if (pad_index[0] == 1) {
+ legit_pad = false;
+ tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+ {0, 3, 2, 1});
+ permuted_pad_index[0] = 3;
+ }
+
+ for (size_t i = 0; i < pad_index.size(); i++) {
+ int index = pad_index[i];
+ if (permuted_pad_index[i] == 2) {
+ pre_padding.h() = pad_data[index * 2];
+ post_padding.h() = pad_data[index * 2 + 1];
+ } else if (permuted_pad_index[i] == 3) {
+ pre_padding.w() = pad_data[index * 2];
+ post_padding.w() = pad_data[index * 2 + 1];
+ }
+ }
+
+ nvinfer1::IPaddingLayer* layer = ctx.network()->addPadding(
+ *const_cast<nvinfer1::ITensor*>(tensor), pre_padding, post_padding);
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+
+ if (!legit_pad)
+ output_tensor = ctx.TransposeTensor(
+ const_cast<nvinfer1::ITensor*>(output_tensor), {0, 3, 2, 1});
+
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
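+
+// The transpose detour above exists because IPaddingLayer only pads H and W:
+// padding on the channel axis (index 1) is handled by swapping axes 1 and 3,
+// padding along W, then swapping back; legit_pad records whether the detour
+// was taken.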
+
+void Converter::register_op_converters() {
+ // vgg_16 slim implementation
+ op_registry_["Placeholder"] = ConvertPlaceholder;
+ op_registry_["Conv2D"] = ConvertConv2D;
+ op_registry_["Relu"] = ConvertActivation;
+ op_registry_["MaxPool"] = ConvertPool;
+  // BiasAdd could really be handled by ConvertBinary.
+ op_registry_["BiasAdd"] = ConvertScale;
+ op_registry_["Const"] = ConvertConst;
+ // op_registry_["MatMul"] = ConvertFullyConnected; // Not used in vgg
+ // TODO(ben,jie): this is a temp hack.
+ op_registry_["Identity"] = ConvertIdentity; // Identity should be removed
+ // op_registry_["AvgPool"] = ConvertPool;
+
+ // resnet_50_v1 slim implementation
+ op_registry_["Add"] = ConvertBinary;
+ op_registry_["Mul"] = ConvertBinary;
+ op_registry_["Sub"] = ConvertBinary;
+ op_registry_["Rsqrt"] = ConvertUnary;
+ op_registry_["Mean"] = ConvertReduce;
+ op_registry_["Pad"] = ConvertPad;
+ // TODO(ben,jie): Add more ops
+}
+
+} // namespace
+
+tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
+ const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
+ const std::vector<std::pair<int, int>>& input_inds,
+ const std::vector<std::pair<int, int>>& output_inds, size_t max_batch_size,
+ size_t max_workspace_size_bytes,
+ const tensorflow::grappler::GraphProperties& graph_properties,
+ tensorflow::NodeDef* trt_node) {
+ // Visit nodes in reverse topological order and construct the TRT network.
+
+ // Toposort
+ std::vector<tensorflow::Node*> order_vec;
+ tensorflow::GetPostOrder(graph, &order_vec);
+ // Select just the subgraph
+ std::list<tensorflow::Node*> order;
+ for (tensorflow::Node* node : order_vec) {
+ if (subgraph_node_ids.count(node->id())) {
+      // We need topological order to construct the network layer by layer.
+ order.push_front(node);
+ }
+ }
+
+ tensorflow::tensorrt::Logger trt_logger;
+
+ auto trt_builder = infer_object(nvinfer1::createInferBuilder(trt_logger));
+ if (!trt_builder) {
+ return tensorflow::errors::Internal(
+ "Failed to create TensorRT builder object");
+ }
+
+ auto trt_network = infer_object(trt_builder->createNetwork());
+ if (!trt_network) {
+ return tensorflow::errors::Internal(
+ "Failed to create TensorRT network object");
+ }
+
+ // Build the network
+ Converter converter(trt_network.get());
+
+ std::vector<string> input_names;
+ std::vector<tensorflow::DataType> input_dtypes;
+ for (std::pair<int, int> const& input : input_inds) {
+ int node_id = input.first;
+ int output_idx = input.second;
+ tensorflow::Node* node = graph.FindNodeId(node_id);
+ auto node_name = node->name();
+ input_names.push_back(node_name); // Insert original node name without port
+ // TODO(jie): alternative :)
+ if (!graph_properties.HasOutputProperties(node_name))
+ return tensorflow::errors::Internal("Failed to find input node: " +
+ node_name);
+
+ auto op_info_vec = graph_properties.GetOutputProperties(node_name);
+    if (static_cast<int>(op_info_vec.size()) <= output_idx)
+ return tensorflow::errors::Internal(
+ "Accessing output index of: " + std::to_string(output_idx) +
+ ", at node: " + node_name + " with output entry from shape_map: " +
+ std::to_string(op_info_vec.size()));
+
+ auto op_info = op_info_vec.at(output_idx);
+
+ tensorflow::DataType tf_dtype = op_info.dtype();
+ input_dtypes.push_back(tf_dtype);
+
+ nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
+ TF_CHECK_OK(ConvertDType(tf_dtype, &dtype));
+
+ VLOG(2) << "Accessing output index of: " << std::to_string(output_idx)
+ << ", at node: " << node_name
+ << " with output entry from shape_map: "
+ << std::to_string(op_info_vec.size());
+
+ // TODO(ben,jie): update TRT input format/dimension
+ nvinfer1::DimsCHW input_dim_pseudo_chw;
+ for (int i = 0; i < 3; i++) input_dim_pseudo_chw.d[i] = 1;
+
+ for (int i = 1; i < op_info.shape().dim_size(); i++) {
+ VLOG(2) << "dimension: " << i
+ << " , size: " << op_info.shape().dim(i).size();
+ input_dim_pseudo_chw.d[i - 1] = op_info.shape().dim(i).size();
+ }
+
+ // TODO(ben,jie): proper way to restore input tensor name?
+ auto input_tensor_name = node_name;
+ if (output_idx != 0)
+ input_tensor_name = node_name + ":" + std::to_string(output_idx);
+
+ nvinfer1::ITensor* input_tensor = converter.network()->addInput(
+ input_tensor_name.c_str(), dtype, input_dim_pseudo_chw);
+
+ if (!input_tensor)
+ return tensorflow::errors::InvalidArgument(
+ "Failed to create Input layer");
+ VLOG(2) << "Input tensor name :" << input_tensor_name;
+
+ if (!converter.insert_input_tensor(input_tensor_name, input_tensor))
+ return tensorflow::errors::AlreadyExists(
+ "Output tensor already exists for op: " + input_tensor_name);
+ }
+
+ VLOG(2) << "Finished sorting";
+
+ for (const tensorflow::Node* node : order) {
+ const tensorflow::NodeDef& node_def = node->def();
+ VLOG(2) << "Converting node: " << node_def.name() << " , " << node_def.op();
+ TF_RETURN_IF_ERROR(converter.convert_node(node_def));
+ }
+
+ VLOG(2) << "Finished conversion";
+
+ // Gather output metadata
+ std::vector<string> output_names;
+ std::vector<tensorflow::DataType> output_dtypes;
+ for (std::pair<int, int> const& output : output_inds) {
+ int node_id = output.first;
+ int output_idx = output.second;
+ tensorflow::Node* node = graph.FindNodeId(node_id);
+ string op_name = node->name();
+ string tensor_name = op_name;
+ if (output_idx != 0)
+ tensor_name = tensor_name + ":" + std::to_string(output_idx);
+ VLOG(2) << "Output tensor name: " << tensor_name;
+ output_names.push_back(tensor_name);
+ auto tensor_or_weights = converter.get_tensor(tensor_name);
+ if (!tensor_or_weights.is_tensor()) {
+ return tensorflow::errors::InvalidArgument(
+ "Output node is weights not tensor");
+ }
+ nvinfer1::ITensor* tensor = tensor_or_weights.tensor();
+ if (!tensor) {
+ return tensorflow::errors::NotFound("Output tensor not found: " +
+ tensor_name);
+ }
+ converter.network()->markOutput(*tensor);
+ tensorflow::DataType tf_dtype = node->output_type(output_idx);
+ output_dtypes.push_back(tf_dtype);
+ nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT;
+ TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype));
+ tensor->setType(trt_dtype);
+ }
+
+ VLOG(2) << "Finished output";
+ // TODO(jie): static_id is not thread safe.
+ static int static_id = 0;
+
+ // Build the engine
+ trt_builder->setMaxBatchSize(max_batch_size);
+ trt_builder->setMaxWorkspaceSize(max_workspace_size_bytes);
+ VLOG(0) << "Starting build engine " << static_id;
+ // TODO(ben,jie): half2 and int8 mode support
+ string engine_plan_string;
+ {
+ auto trt_engine =
+ infer_object(trt_builder->buildCudaEngine(*converter.network()));
+ VLOG(0) << "Built network";
+ auto engine_plan = infer_object(trt_engine->serialize());
+ VLOG(0) << "Serialized engine";
+ const char* engine_plan_data =
+ static_cast<const char*>(engine_plan->data());
+ engine_plan_string =
+ string(engine_plan_data, engine_plan_data + engine_plan->size());
+ }
+
+ VLOG(0) << "Finished engine";
+
+ // Build the TRT op
+ // TODO(sami,ben,jie): proper naming!
+ tensorflow::NodeDefBuilder op_builder(
+ tensorflow::strings::StrCat("my_trt_op", static_id++), "TRTEngineOp");
+ std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges;
+ for (size_t i = 0; i < input_names.size(); ++i) {
+ int output_idx = input_inds.at(i).second;
+    // The input is already wired up here; doing it again in
+    // ConvertSubGraphToTensorRT (convert_graph.cc) would be redundant.
+ auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut(
+ input_names.at(i), output_idx, input_dtypes.at(i));
+ income_edges.push_back(incoming_edge);
+ }
+ tensorflow::gtl::ArraySlice<tensorflow::NodeDefBuilder::NodeOut> input_list(
+ income_edges);
+ op_builder.Input(input_list);
+
+ VLOG(0) << "Finished op preparation";
+
+ auto status = op_builder.Attr("serialized_engine", engine_plan_string)
+ .Attr("input_nodes", input_names)
+ .Attr("output_nodes", output_names)
+ .Attr("OutT", output_dtypes)
+ .Finalize(trt_node);
+
+ VLOG(0) << status.ToString() << " finished op building";
+
+ return tensorflow::Status::OK();
+}
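+
+// For orientation, the function above proceeds in four steps: (1) a
+// post-order walk restricted to the subgraph, reversed into topological
+// order; (2) each subgraph input becomes a TRT network input with the batch
+// dimension stripped; (3) every node runs through Converter::convert_node();
+// (4) the requested outputs are marked, the engine is built and serialized,
+// and the plan is embedded as the serialized_engine attribute of the single
+// TRTEngineOp that replaces the subgraph.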
+
+} // namespace convert
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
new file mode 100644
index 0000000000..2e7fd19566
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -0,0 +1,52 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
+
+#include <set>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/lib/core/status.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
+namespace tensorflow {
+namespace tensorrt {
+namespace convert {
+
+tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
+ const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
+ const std::vector<std::pair<int, int>>&
+ input_inds, // {node_id, output_idx}
+ const std::vector<std::pair<int, int>>&
+ output_inds, // {node_id, output_idx}
+ size_t max_batch_size, size_t max_workspace_size_bytes,
+ const tensorflow::grappler::GraphProperties& graph_prop,
+ tensorflow::NodeDef* trt_node);
+
+} // namespace convert
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
new file mode 100644
index 0000000000..8efdf63ebe
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -0,0 +1,140 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h"
+
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor.h"
+#include "tensorflow/core/platform/types.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "cuda/include/cuda_runtime_api.h"
+
+namespace tensorflow {
+namespace tensorrt {
+static ::tensorflow::tensorrt::Logger logger;
+
+TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
+ // read serialized_engine
+ string serialized_engine;
+ OP_REQUIRES_OK(context,
+ context->GetAttr("serialized_engine", &serialized_engine));
+
+ // register input output node name in trt_sub_graph
+ OP_REQUIRES_OK(context, context->GetAttr("input_nodes", &input_nodes_));
+ OP_REQUIRES_OK(context, context->GetAttr("output_nodes", &output_nodes_));
+
+  // TODO(samikama): the runtime should also come from a resource manager;
+  // only the engine should live in the op, with the context and runtime
+  // managed externally.
+ nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
+ trt_engine_ptr_.reset(infer->deserializeCudaEngine(
+ serialized_engine.c_str(), serialized_engine.size(), nullptr));
+
+ trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
+ // Runtime is safe to delete after engine creation
+ infer->destroy();
+}
+
+void TRTEngineOp::Compute(OpKernelContext* context) {
+ int num_binding = context->num_inputs() + context->num_outputs();
+ std::vector<void*> buffers(num_binding);
+
+  int binding_index;  // getBindingIndex() returns an int (-1 if not found).
+ int num_batch = 0;
+ bool valid = true;
+ for (int i = 0; i < context->num_inputs(); i++) {
+ // Grab the input tensor
+ binding_index = trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str());
+
+ const Tensor& input_tensor = context->input(i);
+ const TensorShape& input_shape = input_tensor.shape();
+ if (i == 0) {
+ num_batch = input_shape.dim_size(0);
+ } else if (num_batch != input_shape.dim_size(0)) {
+ valid = false;
+ break;
+ }
+ switch (trt_engine_ptr_->getBindingDataType(binding_index)) {
+ case nvinfer1::DataType::kFLOAT:
+        buffers[binding_index] =
+            const_cast<float*>(input_tensor.flat<float>().data());
+ break;
+ case nvinfer1::DataType::kHALF:
+ LOG(FATAL) << "half size is not supported yet!";
+ break;
+ case nvinfer1::DataType::kINT8:
+ LOG(FATAL) << "int8 is not supported yet!";
+ break;
+ }
+ }
+
+  // Might want a different way to inform the user of batch-size inconsistency.
+  if (!valid) LOG(WARNING) << "Input tensors have inconsistent batch sizes";
+
+ for (int i = 0; i < static_cast<int>(output_nodes_.size()); i++) {
+    // It is unfortunate that the output buffer is reallocated on every run.
+ // Create an output tensor
+ binding_index = trt_engine_ptr_->getBindingIndex(output_nodes_[i].c_str());
+ Tensor* output_tensor = nullptr;
+
+ TensorShape output_shape;
+ if (binding_index != -1) {
+ auto dims = trt_engine_ptr_->getBindingDimensions(binding_index);
+ std::vector<int> trt_shape(dims.nbDims + 1);
+ trt_shape[0] = num_batch;
+ for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j];
+ OP_REQUIRES_OK(context,
+ TensorShapeUtils::MakeShape(
+ trt_shape.data(), trt_shape.size(), &output_shape));
+ } else {
+ LOG(FATAL) << "output node not found, at " << output_nodes_[i];
+ break;
+ }
+
+ OP_REQUIRES_OK(context,
+ context->allocate_output(i, output_shape, &output_tensor));
+ switch (trt_engine_ptr_->getBindingDataType(binding_index)) {
+ case nvinfer1::DataType::kFLOAT:
+ buffers[binding_index] =
+ reinterpret_cast<void*>(output_tensor->flat<float>().data());
+ break;
+ case nvinfer1::DataType::kHALF:
+ LOG(FATAL) << "half size is not supported yet!";
+ break;
+ case nvinfer1::DataType::kINT8:
+ LOG(FATAL) << "int8 is not supported yet!";
+ break;
+ }
+ }
+  // Copied from cuda_kernel_helper since it is only valid in *.cu.cc files.
+ const cudaStream_t* stream = CHECK_NOTNULL(
+ reinterpret_cast<const cudaStream_t*>(context->op_device_context()
+ ->stream()
+ ->implementation()
+ ->CudaStreamMemberHack()));
+
+  // Execution is scheduled by TF since the stream comes from TF. It is safe
+  // for the CPU pointer array (buffers) to go out of scope after enqueue.
+ trt_execution_context_ptr_->enqueue(num_batch, &buffers[0], *stream, nullptr);
+}
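+
+// Note on buffers: TensorRT addresses I/O by binding index, so both loops in
+// Compute() look up getBindingIndex() by node name and drop raw device
+// pointers into that slot; enqueue() then consumes the whole array at once.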
+
+REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
new file mode 100644
index 0000000000..0964b4b18a
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "cuda/include/cuda_runtime_api.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorrt/include/NvInfer.h"
+
+namespace tensorflow {
+namespace tensorrt {
+class Logger;
+
+class TRTEngineOp : public OpKernel {
+ public:
+ explicit TRTEngineOp(OpKernelConstruction* context);
+
+ void Compute(OpKernelContext* context) override;
+
+ private:
+ template <typename T>
+ struct Destroyer {
+ void operator()(T* d) { d->destroy(); }
+ };
+
+ template <typename T>
+ using destroyed_ptr = std::unique_ptr<T, Destroyer<T>>;
+ destroyed_ptr<nvinfer1::ICudaEngine> trt_engine_ptr_;
+ // TODO(samikama): context should go to a resource manager!
+ destroyed_ptr<nvinfer1::IExecutionContext> trt_execution_context_ptr_;
+
+ std::vector<string> input_nodes_;
+ std::vector<string> output_nodes_;
+};
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_
diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.cc b/tensorflow/contrib/tensorrt/log/trt_logger.cc
new file mode 100644
index 0000000000..7add8cb8b3
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/log/trt_logger.cc
@@ -0,0 +1,57 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace tensorrt {
+
+// Use TF logging for TensorRT information.
+void Logger::log(Severity severity, const char* msg) {
+ // Suppress info-level messages
+ switch (severity) {
+ case Severity::kINFO: { // Mark TRT info messages as debug!
+ VLOG(2) << msg;
+ break;
+ }
+ case Severity::kWARNING: {
+ LOG(WARNING) << msg;
+ break;
+ }
+ case Severity::kERROR: {
+ LOG(ERROR) << msg;
+ break;
+ }
+ case Severity::kINTERNAL_ERROR: {
+ LOG(FATAL) << msg;
+ break;
+ }
+    // Unreachable today, but catches new severity levels if the enum ever
+    // grows. It is always good to have a default case.
+ default: {
+ LOG(FATAL) << name_ << "Got unknown severity level from TRT " << msg;
+ break;
+ }
+ }
+}
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
+#endif // GOOGLE_TENSORRT
diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/contrib/tensorrt/log/trt_logger.h
new file mode 100644
index 0000000000..d71f66b933
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/log/trt_logger.h
@@ -0,0 +1,42 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_
+
+#include "tensorflow/core/platform/types.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "tensorrt/include/NvInfer.h"
+
+namespace tensorflow {
+namespace tensorrt {
+
+// Logger for GIE info/warning/errors
+class Logger : public nvinfer1::ILogger {
+ private:
+ void log(nvinfer1::ILogger::Severity severity, const char* msg) override;
+
+ string name_;
+};
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_
diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
new file mode 100644
index 0000000000..079d73f7be
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
@@ -0,0 +1,43 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+
+namespace shape_inference {
+extern Status TRTEngineOpShapeInference(InferenceContext* c);
+}
+
+REGISTER_OP("TRTEngineOp")
+ .Attr("serialized_engine: string")
+ .Attr("input_nodes: list(string)")
+ .Attr("output_nodes: list(string)")
+ .Attr("InT: list({float32})")
+ .Attr("OutT: list({float32})")
+ .Input("in_tensor: InT")
+ .Output("out_tensor: OutT")
+ .SetShapeFn(shape_inference::TRTEngineOpShapeInference);
+
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
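
Nodes of this type are normally emitted by the converter in
convert_graph.cc, but for reference, a NodeDef matching the registration
above could be hand-built roughly as in the sketch below; the node name,
tensor names, and engine blob are placeholders:

    from tensorflow.core.framework import node_def_pb2
    from tensorflow.core.framework import types_pb2

    node = node_def_pb2.NodeDef(name="my_trt_op0", op="TRTEngineOp")
    node.input.append("conv")                       # hypothetical input
    node.attr["serialized_engine"].s = b"<engine>"  # placeholder blob
    node.attr["input_nodes"].list.s.append(b"conv")
    node.attr["output_nodes"].list.s.append(b"relu")
    node.attr["InT"].list.type.append(types_pb2.DT_FLOAT)
    node.attr["OutT"].list.type.append(types_pb2.DT_FLOAT)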
diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py
new file mode 100644
index 0000000000..7e050a768c
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/python/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Exposes the python wrapper for TensorRT graph transforms."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,line-too-long
+from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
+from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph
+# pylint: enable=unused-import,line-too-long
diff --git a/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py b/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py
new file mode 100644
index 0000000000..31a313182b
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py
@@ -0,0 +1,34 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Exposes the Python wrapper of TRTEngineOp."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import platform
+
+if platform.system() != "Windows":
+ # pylint: disable=wildcard-import,unused-import,g-import-not-at-top
+ from tensorflow.contrib.tensorrt.ops.gen_trt_engine_op import *
+
+ from tensorflow.contrib.util import loader
+ from tensorflow.python.platform import resource_loader
+ # pylint: enable=wildcard-import,unused-import,g-import-not-at-top
+
+ _trt_engine_op = loader.load_op_library(
+ resource_loader.get_path_to_datafile("_trt_engine_op.so"))
+else:
+ raise RuntimeError("Windows platforms are not supported")
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py
new file mode 100644
index 0000000000..9454862f85
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/python/trt_convert.py
@@ -0,0 +1,103 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Exposes the Python wrapper conversion to trt_graph."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,line-too-long
+import six as _six
+from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import errors_impl as _impl
+from tensorflow.python.framework import ops
+
+
+# TODO(skama): get outputs from session when implemented as c++
+# optimization pass
+def create_inference_graph(input_graph_def,
+ outputs,
+ max_batch_size=1,
+ max_workspace_size_bytes=2 << 20):
+ """Python wrapper for the TRT transormation.
+
+
+ Args:
+ input_graph_def: GraphDef object containing a model to be transformed.
+ outputs: List of tensors or node names for the model outputs.
+    max_batch_size: maximum batch size for the inputs to the engine.
+    max_workspace_size_bytes: parameter to control memory allocation (in bytes).
+
+ Returns:
+ New GraphDef with TRTEngineOps placed in graph replacing subgraphs.
+
+ Raises:
+ RuntimeError: if the returned status message is malformed.
+ """
+
+ def py2bytes(inp):
+ return inp
+
+ def py3bytes(inp):
+ return inp.encode("utf-8", errors="surrogateescape")
+
+ def py2string(inp):
+ return inp
+
+ def py3string(inp):
+ return inp.decode("utf-8")
+
+ if _six.PY2:
+ to_bytes = py2bytes
+ to_string = py2string
+ else:
+ to_bytes = py3bytes
+ to_string = py3string
+
+ out_names = []
+ for i in outputs:
+ if isinstance(i, ops.Tensor):
+ out_names.append(to_bytes(i.name))
+ else:
+ out_names.append(to_bytes(i))
+
+ input_graph_def_str = input_graph_def.SerializeToString()
+
+  # TODO(sami): Fix this when we can return status from C++ library.
+  # There is a problem with the TF internal library setup that doesn't
+  # allow us to return a status object from C++. Thus we return a
+  # pair of strings, where the first one is the encoded status and the
+  # second one is the transformed graph's protobuf string.
+ out = trt_convert(input_graph_def_str, out_names, max_batch_size,
+ max_workspace_size_bytes)
+ status = to_string(out[0])
+ output_graph_def_string = out[1]
+ del input_graph_def_str # Save some memory
+ if len(status) < 2:
+ raise _impl.UnknownError(None, None, status)
+ if status[:2] != "OK":
+ msg = status.split(";")
+ if len(msg) == 1:
+      raise RuntimeError("Status message is malformed: {}".format(status))
+ # pylint: disable=protected-access
+ raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
+ int(msg[0]))
+ # pylint: enable=protected-access
+ output_graph_def = graph_pb2.GraphDef()
+ output_graph_def.ParseFromString(output_graph_def_string)
+ del output_graph_def_string # Save some memory
+ return output_graph_def
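
A typical end-to-end call is sketched below with hypothetical file and node
names; test_tftrt.py later in this change exercises the same flow:

    import tensorflow as tf
    from tensorflow.contrib import tensorrt as trt

    graph_def = tf.GraphDef()
    with tf.gfile.GFile("frozen_model.pb", "rb") as f:  # hypothetical path
        graph_def.ParseFromString(f.read())

    trt_graph = trt.create_inference_graph(
        input_graph_def=graph_def,
        outputs=["output"],                # hypothetical output node
        max_batch_size=8,
        max_workspace_size_bytes=1 << 25)  # 32 MiB of TRT workspace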
diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc
new file mode 100644
index 0000000000..6193f0b0a1
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/segment/segment.cc
@@ -0,0 +1,253 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/segment/segment.h"
+
+#include <set>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/tensorrt/segment/union_find.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace tensorrt {
+namespace segment {
+
+namespace {
+
+bool CanContractEdge(const tensorflow::Edge* edge,
+ const tensorflow::Graph& graph) {
+ const tensorflow::Node* src = edge->src();
+ const tensorflow::Node* dst = edge->dst();
+
+ // Can't contract edge if doing so would cause a cycle in the
+ // graph. So, if there is a directed path from 'src' to 'dst', other
+ // than 'edge' (or any other direct edge from 'src' to 'dst'), then
+ // combining 'src' and 'dst' will cause a cycle along that path.
+ //
+  // In practice, to avoid modifying the graph and to take advantage
+  // of existing graph functions, we perform an equivalent check:
+  // 1. Get all nodes incoming to 'dst', excluding 'src'.
+  // 2. Reverse DFS from those nodes.
+  // 3. If the reverse DFS reaches 'src', then we have a cycle.
+ std::vector<tensorflow::Node*> dfs_start_nodes;
+ for (tensorflow::Node* node : dst->in_nodes()) {
+ if (node != src) {
+ dfs_start_nodes.push_back(node);
+ }
+ }
+
+ bool is_cycle = false;
+ if (!dfs_start_nodes.empty()) {
+ tensorflow::ReverseDFSFrom(graph, dfs_start_nodes, {},
+ [&is_cycle, src](tensorflow::Node* node) {
+ if (node == src) {
+ is_cycle = true;
+ }
+ });
+ }
+
+ return !is_cycle;
+}
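
The same check as a small Python sketch, on an illustrative predecessor map
(node -> set of predecessors):

    def can_contract_edge(src, dst, preds):
        # Reverse DFS from dst's other predecessors; contracting is safe
        # iff the walk never reaches src.
        stack = [n for n in preds[dst] if n != src]
        seen = set()
        while stack:
            node = stack.pop()
            if node == src:
                return False  # merging src and dst would close a cycle
            if node in seen:
                continue
            seen.add(node)
            stack.extend(preds[node])
        return True

    # add0 -> add2 cannot be contracted while add2 also consumes add1,
    # which itself consumes add0: the path add0 -> add1 -> add2 would
    # become a cycle after the merge.
    preds = {"add0": set(), "add1": {"add0"}, "add2": {"add0", "add1"}}
    assert not can_contract_edge("add0", "add2", preds)
    assert can_contract_edge("add1", "add2", preds)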
+
+void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph,
+ std::vector<const tensorflow::Edge*>* remove_edges) {
+ // Transfer all inputs and outputs of 'dst' to 'src' except edges
+ // connecting the two.
+ tensorflow::Node* src = edge->src();
+ tensorflow::Node* dst = edge->dst();
+
+ // We can use '0' for input/output index because we don't need them
+ // to be accurate for the way we are using the graph.
+ std::vector<const tensorflow::Edge*> in_edges(dst->in_edges().begin(),
+ dst->in_edges().end());
+ for (const tensorflow::Edge* in_edge : in_edges) {
+ if (in_edge->src() != src) {
+ tensorflow::Edge* e = const_cast<tensorflow::Edge*>(in_edge);
+ if (e->src() == graph->source_node()) {
+ graph->AddEdge(e->src(), e->src_output(), src,
+ tensorflow::Graph::kControlSlot);
+ } else {
+ graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */);
+ }
+ }
+ }
+
+ std::vector<const tensorflow::Edge*> out_edges(dst->out_edges().begin(),
+ dst->out_edges().end());
+ for (const tensorflow::Edge* out_edge : out_edges) {
+ tensorflow::Edge* e = const_cast<tensorflow::Edge*>(out_edge);
+ if (e->dst() == graph->sink_node()) {
+ graph->AddEdge(src, tensorflow::Graph::kControlSlot, e->dst(),
+ e->dst_input());
+ } else {
+ graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input());
+ }
+ }
+
+ // Return the edges that must be removed to disconnect 'dst' from
+ // the graph. We don't actually remove 'dst' since the caller holds
+ // references to all the nodes.
+ for (const auto& in_edge : dst->in_edges()) {
+ remove_edges->push_back(in_edge);
+ }
+ for (const auto& out_edge : dst->out_edges()) {
+ remove_edges->push_back(out_edge);
+ }
+}
+
+} // namespace
+
+tensorflow::Status SegmentGraph(
+ const tensorflow::GraphDef& gdef,
+ const std::function<bool(const tensorflow::NodeDef&)>& candidate_fn,
+ const SegmentOptions& options, SegmentNodesVector* segments) {
+ // Create a Graph representation of the GraphDef.
+ tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
+ gdef.library());
+ tensorflow::Graph graph(flib);
+ TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
+ tensorflow::GraphConstructorOptions(), gdef, &graph));
+
+ // tensorflow::DumpGraph("Pre-Segment", &graph);
+
+ // Use a union-find to collect the nodes that belong to the same
+ // segment. A node value of nullptr indicates that the node is not a
+ // candidate for TRT.
+ std::vector<UnionFind<tensorflow::Node*>> node_segments;
+ for (int i = 0; i < graph.num_node_ids(); ++i) {
+ tensorflow::Node* node = graph.FindNodeId(i);
+ if (options.exclude_node_list.count(node->name()) != 0 ||
+ !candidate_fn(node->def())) {
+ node = nullptr;
+ }
+ node_segments.emplace_back(node);
+ }
+
+ // The segmentation algorithm below visits nodes in reverse
+ // topological order and attempts to merge nodes along output
+ // edges. That means that subgraphs grow from the output-side of the
+ // network towards the inputs. In general this is not guaranteed to
+ // produce a globally optimal segmentation. In the future if we have
+ // a measure of how beneficial it is to include a given node in a
+ // TRT subgraph then we can revisit this algorithm to take advantage
+ // of that information.
+ std::vector<tensorflow::Node*> order;
+ tensorflow::GetPostOrder(graph, &order);
+
+ for (const tensorflow::Node* node : order) {
+ // All output nodes of 'node' have been visited...
+ VLOG(2) << "Trying node " << node->name();
+
+ // 'node' must be a TRT candidate...
+ if (node_segments[node->id()].Value() == nullptr) {
+ VLOG(2) << "... not a TRT candidate";
+ continue;
+ }
+
+ // Contract output edges to combine 'node' with output
+ // nodes. Iterate since combining two nodes may unblock other
+ // combining.
+ while (true) {
+ std::set<const tensorflow::Edge*> contract_edges;
+ for (const tensorflow::Edge* out_edge : node->out_edges()) {
+ VLOG(2) << "... out node " << out_edge->dst()->name();
+
+ // Out node must be TRT candidate...
+ if (node_segments[out_edge->dst()->id()].Value() == nullptr) {
+ VLOG(2) << "... ... not a TRT candidate";
+ continue;
+ }
+
+ if (CanContractEdge(out_edge, graph)) {
+ VLOG(2) << "... ... can contract";
+ contract_edges.insert(out_edge);
+ } else {
+ VLOG(2) << "... ... cannot contract, would form cycle";
+ }
+ }
+
+ if (contract_edges.empty()) {
+ break;
+ }
+
+ // Contract edges and collect the adjacent nodes into the same
+ // segment/subgraph.
+ while (!contract_edges.empty()) {
+ const tensorflow::Edge* contract_edge = *contract_edges.begin();
+ const tensorflow::Node* src = contract_edge->src();
+ const tensorflow::Node* dst = contract_edge->dst();
+
+ VLOG(2) << "Merge " << src->name() << " <- " << dst->name();
+ node_segments[src->id()].Merge(&node_segments[dst->id()]);
+
+ // Contracting the edge leaves disconnected graph edges.
+ // Remove these from the graph and from 'contract_edges' so we
+ // don't visit them again.
+ tensorflow::Edge* e = const_cast<tensorflow::Edge*>(contract_edge);
+ std::vector<const tensorflow::Edge*> remove_edges;
+ ContractEdge(e, &graph, &remove_edges);
+
+ for (const tensorflow::Edge* r : remove_edges) {
+ contract_edges.erase(r);
+ graph.RemoveEdge(r);
+ }
+ }
+ }
+ }
+
+ // Collect the segments/subgraphs. Each subgraph is represented by a
+ // set of the names of the nodes in that subgraph.
+ std::unordered_map<string, std::set<string>> sg_map;
+ for (auto& u : node_segments) {
+ if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) {
+ sg_map[u.ParentValue()->name()].insert(u.Value()->name());
+ }
+ }
+
+ // Convert the segments into the expected return format
+ for (const auto& itr : sg_map) {
+ const auto& segment_node_names = itr.second;
+ if (VLOG_IS_ON(1)) {
+ string s;
+ for (const auto& name : segment_node_names) {
+ s += " " + name;
+ }
+ VLOG(1) << "Segment " << segments->size() << ":" << s;
+ }
+
+ // Don't use small segments.
+ if (static_cast<int>(segment_node_names.size()) <
+ options.minimum_segment_size) {
+ VLOG(1) << "Segment " << segments->size() << " has only "
+ << segment_node_names.size() << " nodes, dropping";
+ continue;
+ }
+
+ segments->emplace_back(segment_node_names);
+ }
+
+ return tensorflow::Status::OK();
+}
+
+} // namespace segment
+} // namespace tensorrt
+} // namespace tensorflow
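
The merging itself happens through the union-find; once it is done, the
collection step at the end of SegmentGraph reduces to grouping nodes by
their representative and dropping undersized groups, roughly as in this
sketch (the root map is illustrative; None marks non-candidates):

    from collections import defaultdict

    def collect_segments(node_to_root, minimum_segment_size=2):
        sg_map = defaultdict(set)
        for node, root in node_to_root.items():
            if root is not None:
                sg_map[root].add(node)
        # Don't use small segments, as the pass above does.
        return [names for names in sg_map.values()
                if len(names) >= minimum_segment_size]

    roots = {"feed": None, "add0": "add0", "add1": "add0",
             "add2": "add0", "add3": "add3"}
    assert collect_segments(roots) == [{"add0", "add1", "add2"}]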
diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h
new file mode 100644
index 0000000000..ee6e2b3ed2
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/segment/segment.h
@@ -0,0 +1,56 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_
+
+#include <set>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace tensorrt {
+namespace segment {
+
+using SegmentNodesVector = std::vector<std::set<string>>;
+
+struct SegmentOptions {
+ // Segment must contain at least this many nodes.
+ int minimum_segment_size = 2;
+ std::set<string> exclude_node_list;
+};
+
+// Get the subgraphs of a graph that can be handled by TensorRT.
+//
+// @param gdef The GraphDef describing the network
+// @param candidate_fn A function that returns true for a NodeDef if
+// that node can be handled by TensorRT.
+// @param segments Returns the TensorRT segments/subgraphs. Each entry
+// in the vector describes a subgraph by giving a set of the names of
+// all the NodeDefs in that subgraph.
+// @return the status.
+tensorflow::Status SegmentGraph(
+ const tensorflow::GraphDef& gdef,
+ const std::function<bool(const tensorflow::NodeDef&)>& candidate_fn,
+ const SegmentOptions& options, SegmentNodesVector* segments);
+
+} // namespace segment
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_
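
A candidate function in the spirit of this interface, sketched in Python
with a hypothetical op whitelist (the real support list is determined by
the converter):

    TRT_COMPATIBLE_OPS = {"Conv2D", "BiasAdd", "Relu", "MaxPool"}

    def is_trt_candidate(node_def):
        # Only whitelisted op types may join a TensorRT segment.
        return node_def.op in TRT_COMPATIBLE_OPS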
diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc
new file mode 100644
index 0000000000..74cbc5f2b3
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc
@@ -0,0 +1,367 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/segment/segment.h"
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace tensorrt {
+namespace segment {
+namespace test {
+
+class SegmentTest : public ::testing::Test {
+ public:
+ bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def);
+
+ TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name);
+ TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+ TF_Status* s, const char* name);
+
+ std::function<bool(const NodeDef&)> MakeCandidateFn(
+ const std::set<string>& node_names);
+
+ protected:
+ void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
+ TF_Operation** op);
+ void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+ TF_Status* s, const char* name, TF_Operation** op, bool check);
+
+ SegmentOptions default_options_;
+};
+
+bool SegmentTest::GetGraphDef(TF_Graph* graph,
+ tensorflow::GraphDef* graph_def) {
+ TF_Status* s = TF_NewStatus();
+ TF_Buffer* buffer = TF_NewBuffer();
+ TF_GraphToGraphDef(graph, buffer, s);
+ bool ret = TF_GetCode(s) == TF_OK;
+ EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length);
+ TF_DeleteBuffer(buffer);
+ TF_DeleteStatus(s);
+ return ret;
+}
+
+std::function<bool(const NodeDef&)> SegmentTest::MakeCandidateFn(
+ const std::set<string>& node_names) {
+ return [node_names](const NodeDef& node) -> bool {
+ return node_names.find(node.name()) != node_names.end();
+ };
+}
+
+void SegmentTest::PlaceholderHelper(TF_Graph* graph, TF_Status* s,
+ const char* name, TF_Operation** op) {
+ TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
+ TF_SetAttrType(desc, "dtype", TF_INT32);
+ *op = TF_FinishOperation(desc, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_NE(*op, nullptr);
+}
+
+TF_Operation* SegmentTest::Placeholder(TF_Graph* graph, TF_Status* s,
+ const char* name) {
+ TF_Operation* op;
+ PlaceholderHelper(graph, s, name, &op);
+ return op;
+}
+
+void SegmentTest::AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+ TF_Status* s, const char* name, TF_Operation** op,
+ bool check) {
+ TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
+ TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
+ TF_AddInputList(desc, add_inputs, 2);
+ *op = TF_FinishOperation(desc, s);
+ if (check) {
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_NE(*op, nullptr);
+ }
+}
+
+TF_Operation* SegmentTest::Add(TF_Operation* l, TF_Operation* r,
+ TF_Graph* graph, TF_Status* s,
+ const char* name) {
+ TF_Operation* op;
+ AddHelper(l, r, graph, s, name, &op, true);
+ return op;
+}
+
+TEST_F(SegmentTest, Empty) {
+ TF_Graph* graph = TF_NewGraph();
+
+ GraphDef graph_def;
+ ASSERT_TRUE(GetGraphDef(graph, &graph_def));
+
+ SegmentNodesVector segments;
+ ASSERT_EQ(
+ SegmentGraph(graph_def, MakeCandidateFn({}), default_options_, &segments),
+ tensorflow::Status::OK());
+
+ // Expect no segments/subgraphs.
+ EXPECT_TRUE(segments.empty());
+ TF_DeleteGraph(graph);
+}
+
+TEST_F(SegmentTest, Simple) {
+ TF_Status* s = TF_NewStatus();
+ TF_Graph* graph = TF_NewGraph();
+
+ // feed
+ // // ||
+ // add0 add1
+ // | | /
+ // | add2
+ // | / ||
+ // add3 add4
+ // | /
+ // <sink>
+ //
+ TF_Operation* feed = Placeholder(graph, s, "feed");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
+
+ TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
+ TF_Operation* add4 = Add(add2, add2, graph, s, "add4");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
+
+ GraphDef graph_def;
+ ASSERT_TRUE(GetGraphDef(graph, &graph_def));
+
+ SegmentNodesVector segments;
+ ASSERT_EQ(
+ SegmentGraph(graph_def,
+ MakeCandidateFn({"add0", "add1", "add2", "add3", "add4"}),
+ default_options_, &segments),
+ tensorflow::Status::OK());
+
+ // Expect all Add operations to be collapsed into a single segment
+ ASSERT_EQ(segments.size(), 1);
+ std::vector<string> expected{"add0", "add1", "add2", "add3", "add4"};
+ for (const auto& ex : expected) {
+ EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
+ << "Missing expected node " << ex;
+ }
+ TF_DeleteGraph(graph);
+ TF_DeleteStatus(s);
+}
+
+TEST_F(SegmentTest, AvoidCycle) {
+ TF_Status* s = TF_NewStatus();
+ TF_Graph* graph = TF_NewGraph();
+
+  // add2 is not a TRT candidate, so add0/add3 cannot be merged into a
+  // subgraph: the merge would create a cycle through add2.
+ //
+ // feed
+ // // ||
+ // add0 add1
+ // | | /
+ // | add2
+ // | / ||
+ // add3 add4
+ // | /
+ // <sink>
+ //
+ TF_Operation* feed = Placeholder(graph, s, "feed");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
+
+ TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
+ TF_Operation* add4 = Add(add2, add2, graph, s, "add4");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
+
+ GraphDef graph_def;
+ ASSERT_TRUE(GetGraphDef(graph, &graph_def));
+
+ SegmentNodesVector segments;
+ ASSERT_EQ(
+ SegmentGraph(graph_def, MakeCandidateFn({"add0", "add1", "add3", "add4"}),
+ default_options_, &segments),
+ tensorflow::Status::OK());
+
+ // Expect no subgraphs
+ EXPECT_EQ(segments.size(), 0);
+ TF_DeleteGraph(graph);
+ TF_DeleteStatus(s);
+}
+
+TEST_F(SegmentTest, Multiple) {
+ TF_Status* s = TF_NewStatus();
+ TF_Graph* graph = TF_NewGraph();
+
+  // add5 is not a TRT candidate, so two subgraphs should be formed.
+ //
+ // feed
+ // // || ||
+ // add0 add1 add7
+ // | | / / ||
+ // | add2-----add5 add8
+ // | / | | | |
+ // add3 add4 add6
+ // | | /
+ // <sink>
+ //
+ TF_Operation* feed = Placeholder(graph, s, "feed");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
+
+ TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add7 = Add(feed, feed, graph, s, "add7");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add5 = Add(add2, add7, graph, s, "add5");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add8 = Add(add7, add7, graph, s, "add8");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
+ TF_Operation* add4 = Add(add2, add5, graph, s, "add4");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
+ TF_Operation* add6 = Add(add5, add8, graph, s, "add6");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add6"), string(TF_OperationName(add6)));
+
+ GraphDef graph_def;
+ ASSERT_TRUE(GetGraphDef(graph, &graph_def));
+
+ SegmentNodesVector segments;
+ ASSERT_EQ(SegmentGraph(graph_def,
+ MakeCandidateFn({"add0", "add1", "add2", "add3",
+ "add4", "add6", "add7", "add8"}),
+ default_options_, &segments),
+ tensorflow::Status::OK());
+
+ // Expect two subgraphs
+ EXPECT_EQ(segments.size(), 2);
+
+ std::vector<string> expected0{"add0", "add1", "add2", "add3"};
+ for (const auto& ex : expected0) {
+ EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
+ << "Missing expected node " << ex;
+ }
+
+ std::vector<string> expected1{"add6", "add8"};
+ for (const auto& ex : expected1) {
+ EXPECT_TRUE(segments[1].find(ex) != segments[1].end())
+ << "Missing expected node " << ex;
+ }
+ TF_DeleteGraph(graph);
+ TF_DeleteStatus(s);
+}
+
+TEST_F(SegmentTest, BigIfElse) {
+ TF_Status* s = TF_NewStatus();
+ TF_Graph* graph = TF_NewGraph();
+
+ // add2 is not a TRT candidate
+ //
+ // feed
+ // ||
+ // add0
+ // // ||
+ // add1 add4
+ // || ||
+ // add2 add5
+ // || ||
+ // add3 add6
+ // || //
+ // add7
+ // ||
+ // <sink>
+ //
+ TF_Operation* feed = Placeholder(graph, s, "feed");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
+
+ TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add1 = Add(add0, add0, graph, s, "add1");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add2 = Add(add1, add1, graph, s, "add2");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add3 = Add(add2, add2, graph, s, "add3");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add4 = Add(add0, add0, graph, s, "add4");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add5 = Add(add4, add4, graph, s, "add5");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add6 = Add(add5, add5, graph, s, "add6");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add7 = Add(add3, add6, graph, s, "add7");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add7"), string(TF_OperationName(add7)));
+
+ GraphDef graph_def;
+ ASSERT_TRUE(GetGraphDef(graph, &graph_def));
+
+ SegmentNodesVector segments;
+ ASSERT_EQ(SegmentGraph(graph_def,
+ MakeCandidateFn({"add0", "add1", "add3", "add4",
+ "add5", "add6", "add7"}),
+ default_options_, &segments),
+ tensorflow::Status::OK());
+
+ // Expect 2 subgraphs
+ EXPECT_EQ(segments.size(), 2);
+
+ std::vector<string> expected0{"add3", "add4", "add5", "add6", "add7"};
+ for (const auto& ex : expected0) {
+ EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
+ << "Missing expected node " << ex;
+ }
+
+ std::vector<string> expected1{"add0", "add1"};
+ for (const auto& ex : expected1) {
+ EXPECT_TRUE(segments[1].find(ex) != segments[1].end())
+ << "Missing expected node " << ex;
+ }
+ TF_DeleteGraph(graph);
+ TF_DeleteStatus(s);
+}
+
+} // namespace test
+} // namespace segment
+} // namespace tensorrt
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/segment/union_find.h b/tensorflow/contrib/tensorrt/segment/union_find.h
new file mode 100644
index 0000000000..1c64ebbb0a
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/segment/union_find.h
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_
+
+namespace tensorflow {
+namespace tensorrt {
+namespace segment {
+
+// Union-Find data structure.
+// Each cluster has an associated value; when merging clusters we can control
+// which value becomes the representative of the merged clusters. Values must be
+// copyable.
+template <typename T>
+class UnionFind {
+ public:
+ UnionFind() : size_(1), parent_(nullptr) {}
+ explicit UnionFind(const T& v) : size_(1), parent_(nullptr), value_(v) {}
+
+ // Returns the number of elements in a cluster.
+ int Size() { return FindRoot()->size_; }
+
+ // Merges this cluster with 'other'. This cluster's value becomes
+ // the value of the merged cluster; the value of 'other' is ignored.
+ void Merge(UnionFind* other);
+
+ // Each cluster has an associated value. Retrieves the value associated
+ // with this cluster.
+ T& ParentValue() { return FindRoot()->value_; }
+
+ // Get the original value of this node.
+ T& Value() { return value_; }
+
+ private:
+ // Finds the root element of the cluster. Performs path compression.
+ UnionFind* FindRoot();
+
+ int size_;
+ UnionFind* parent_;
+ T value_;
+};
+
+template <typename T>
+void UnionFind<T>::Merge(UnionFind* other) {
+ UnionFind<T>* a = FindRoot();
+ UnionFind<T>* b = other->FindRoot();
+ if (a == b) return;
+
+ b->parent_ = a;
+ a->size_ += b->size_;
+}
+
+template <typename T>
+UnionFind<T>* UnionFind<T>::FindRoot() {
+ if (!parent_) return this;
+ // Path compression: update intermediate nodes to point to the root of the
+ // equivalence class.
+ parent_ = parent_->FindRoot();
+ return parent_;
+}
+
+} // namespace segment
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_
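
A Python mirror of this class, with the same interface and the same
left-biased merge, may make the semantics easier to see:

    class UnionFind(object):
        def __init__(self, value=None):
            self.size = 1
            self.parent = None
            self.value = value

        def find_root(self):
            if self.parent is None:
                return self
            self.parent = self.parent.find_root()  # path compression
            return self.parent

        def merge(self, other):
            a, b = self.find_root(), other.find_root()
            if a is b:
                return
            b.parent = a          # the left-hand cluster's value wins
            a.size += b.size

        def parent_value(self):
            return self.find_root().value

    a, b, c = UnionFind("a"), UnionFind("b"), UnionFind("c")
    a.merge(b)
    b.merge(c)   # merging through a non-root member works too
    assert a.parent_value() == c.parent_value() == "a"
    assert a.find_root().size == 3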
diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
new file mode 100644
index 0000000000..8b475177bc
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
@@ -0,0 +1,89 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h"
+
+#include <string>
+#include <vector>
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorrt/include/NvInfer.h"
+
+namespace tensorflow {
+namespace shape_inference {
+
+tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) {
+ tensorflow::tensorrt::Logger logger;
+ string serialized_engine;
+ TF_RETURN_IF_ERROR(context->GetAttr("serialized_engine", &serialized_engine));
+ nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
+ nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine(
+ serialized_engine.c_str(), serialized_engine.size(), nullptr);
+
+ int num_batch = -1;
+ std::vector<::tensorflow::DataType> input_type;
+ TF_RETURN_IF_ERROR(context->GetAttr("InT", &input_type));
+ for (size_t i = 0; i < context->num_inputs(); i++) {
+ // Check if input shape is legit
+ auto input_shape = context->input(i);
+ for (int j = 0; j < context->Rank(input_shape); j++) {
+      auto dim_handle = context->Dim(input_shape, j);
+      if (j == 0) {
+        if (i == 0) {
+          num_batch = context->Value(dim_handle);
+        } else if (num_batch != context->Value(dim_handle)) {
+          // TODO(jie): A TensorRT engine requires a consistent batch size
+          // across its input tensors; the segmenter should be aware of this.
+ LOG(FATAL) << "TensorRT engine requires consistent batch size";
+ }
+ }
+ }
+ }
+
+ // Arrange input here
+ std::vector<string> input_nodes;
+ TF_RETURN_IF_ERROR(context->GetAttr("input_nodes", &input_nodes));
+
+ // Arrange output here
+ std::vector<string> output_nodes;
+ TF_RETURN_IF_ERROR(context->GetAttr("output_nodes", &output_nodes));
+ for (size_t i = 0; i < output_nodes.size(); i++) {
+ int binding_index = trt_engine->getBindingIndex(output_nodes[i].c_str());
+ ShapeHandle output_shape;
+ std::vector<DimensionHandle> dim_vec;
+ dim_vec.emplace_back(context->MakeDim(num_batch));
+ if (binding_index != -1) {
+ auto dims = trt_engine->getBindingDimensions(binding_index);
+ for (int j = 0; j < dims.nbDims; j++) {
+ dim_vec.emplace_back(context->MakeDim(dims.d[j]));
+ }
+ } else {
+ LOG(FATAL) << "TensorRT engine cannot find binding: " << output_nodes[i];
+ }
+ output_shape = context->MakeShape(dim_vec);
+ context->set_output(i, output_shape);
+ }
+
+  // Release the TensorRT objects deserialized above; they are only needed
+  // here to read the binding dimensions.
+  trt_engine->destroy();
+  infer->destroy();
+  return Status::OK();
+}
+
+} // namespace shape_inference
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
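
In effect the shape function prepends the shared input batch dimension to
each output's engine binding dimensions. A Python sketch of that rule, with
illustrative shapes:

    def infer_output_shapes(input_shapes, binding_dims_by_output):
        # All inputs must agree on the leading (batch) dimension.
        batch_sizes = {shape[0] for shape in input_shapes}
        if len(batch_sizes) != 1:
            raise ValueError("TensorRT engine requires consistent batch size")
        (num_batch,) = batch_sizes
        return {name: [num_batch] + list(dims)
                for name, dims in binding_dims_by_output.items()}

    shapes = infer_output_shapes([(8, 24, 24, 2)], {"relu": (12, 12, 6)})
    assert shapes == {"relu": [8, 12, 12, 6]}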
diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h
new file mode 100644
index 0000000000..4b50f66699
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h
@@ -0,0 +1,33 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace shape_inference {
+Status TRTEngineOpShapeInference(InferenceContext* c);
+} // namespace shape_inference
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py
new file mode 100644
index 0000000000..c78f6f2224
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py
@@ -0,0 +1,88 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script to test TF-TensorRT integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+# Normally we would do `import tensorflow as tf` and then use
+# tf.placeholder, tf.constant, tf.nn.conv2d, etc., but it looks like
+# internal builds don't like that, so we import every module
+# individually instead.
+
+from tensorflow.contrib import tensorrt as trt
+from tensorflow.core.protobuf import config_pb2 as cpb2
+from tensorflow.python.client import session as csess
+from tensorflow.python.framework import constant_op as cop
+from tensorflow.python.framework import dtypes as dtypes
+from tensorflow.python.framework import importer as importer
+from tensorflow.python.framework import ops as ops
+from tensorflow.python.ops import array_ops as aops
+from tensorflow.python.ops import nn as nn
+from tensorflow.python.ops import nn_ops as nn_ops
+
+
+def get_simple_graph_def():
+ """Create a simple graph and return its graph_def."""
+ g = ops.Graph()
+ with g.as_default():
+ a = aops.placeholder(
+ dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input")
+ e = cop.constant(
+ [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
+ name="weights",
+ dtype=dtypes.float32)
+ conv = nn.conv2d(
+ input=a, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv")
+ b = cop.constant(
+ [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtypes.float32)
+ t = nn.bias_add(conv, b, name="biasAdd")
+ relu = nn.relu(t, "relu")
+ idty = aops.identity(relu, "ID")
+ v = nn_ops.max_pool(
+ idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
+ aops.squeeze(v, name="output")
+ return g.as_graph_def()
+
+
+def run_graph(gdef, dummy_inp):
+ gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
+ ops.reset_default_graph()
+ g = ops.Graph()
+ with g.as_default():
+ inp, out = importer.import_graph_def(
+ graph_def=gdef, return_elements=["input", "output"])
+ inp = inp.outputs[0]
+ out = out.outputs[0]
+ with csess.Session(
+ config=cpb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess:
+      val = sess.run(out, {inp: dummy_inp})
+ return val
+
+
+if "__main__" in __name__:
+ inp_dims = (100, 24, 24, 2)
+ dummy_input = np.random.random_sample(inp_dims)
+ gdef = get_simple_graph_def()
+ # Get optimized graph
+ trt_graph = trt.create_inference_graph(gdef, ["output"], inp_dims[0])
+ o1 = run_graph(gdef, dummy_input)
+ o2 = run_graph(trt_graph, dummy_input)
+ o3 = run_graph(trt_graph, dummy_input)
+ assert np.array_equal(o1, o2)
+ assert np.array_equal(o3, o2) # sanity check
+ print("Pass")
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i
new file mode 100644
index 0000000000..d679945d56
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/trt_conversion.i
@@ -0,0 +1,131 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/* Wrap trt_conversion */
+%{
+#define SWIG_FILE_WITH_INIT
+%}
+%include "std_pair.i"
+%include "tensorflow/python/platform/base.i"
+
+%{
+PyObject* pair_helper(std::pair<string, string>* in) {
+ PyObject *first(nullptr), *second(nullptr), *tuple(nullptr);
+ first = PyBytes_FromStringAndSize(in->first.data(), in->first.length());
+ if (!first) {
+ if (!PyErr_Occurred()) {
+ PyErr_SetString(PyExc_TypeError, "Pair conversion first argument failed");
+ }
+ return NULL;
+ }
+  second = PyBytes_FromStringAndSize(in->second.data(), in->second.length());
+  if (!second) {
+    Py_DECREF(first);
+    if (!PyErr_Occurred()) {
+      PyErr_SetString(PyExc_TypeError,
+                      "Pair conversion second argument failed");
+    }
+    return NULL;
+  }
+  tuple = Py_BuildValue("(OO)", first, second);
+  // "O" in Py_BuildValue increments the refcount, so release our
+  // references whether or not the tuple was created.
+  Py_DECREF(first);
+  Py_DECREF(second);
+  if (!tuple) {
+    if (!PyErr_Occurred()) {
+      PyErr_SetString(PyExc_TypeError,
+                      "Tuple creation from pair<string,string> failed!");
+    }
+    return NULL;
+  }
+  return tuple;
+}
+%}
+%typemap(out) std::pair<string, string> {
+ PyObject *tuple = pair_helper(&$1);
+ if (!tuple) SWIG_fail;
+ $result = tuple;
+}
+%{
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/stat_summarizer.h"
+#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
+%}
+
+%ignoreall
+%unignore tensorflow;
+%unignore trt_convert;
+
+%{
+std::pair<string, string> trt_convert(
+ string graph_def_string, // The serialized GraphDef string.
+ std::vector<string> output_names,
+ size_t max_batch_size,
+ size_t max_workspace_size_bytes
+    // Unfortunately we can't use TF_Status here since it
+    // is in c/c_api and brings in a lot of other libraries
+    // which in turn declare ops. These ops are included
+    // statically in our library and cause an abort when the
+    // module is loaded, due to double registration.
+    // Until TensorFlow properly exposes these headers, we
+    // have to work around this by returning a string and
+    // converting it to an exception on the Python side.
+ //,TF_Status* out_status) {
+) {
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+ string out_status;
+
+ tensorflow::GraphDef graph_def;
+ if (!graph_def.ParseFromString(graph_def_string)) {
+ out_status = "InvalidArgument;Couldn't interpret input as a GraphDef";
+ return std::pair<string, string>{out_status, ""};
+ }
+
+  if (output_names.empty()) {
+ out_status = "InvalidArgument;Size of the output_names vector is 0";
+ return std::pair<string, string>{out_status, ""};
+ }
+ tensorflow::GraphDef outGraph;
+ tensorflow::Status conversion_status =
+ tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT(
+ graph_def, output_names, max_batch_size, max_workspace_size_bytes,
+ &outGraph);
+ if (!conversion_status.ok()) {
+    auto ret_code = static_cast<int>(conversion_status.code());
+    char buff[2000];
+    snprintf(buff, 2000, "%d;%s", ret_code,
+             conversion_status.error_message().c_str());
+ out_status = buff;
+ return std::pair<string, string>{out_status, ""};
+ }
+ string result;
+ if (!outGraph.SerializeToString(&result)) {
+ out_status = "InvalidArgument;Couldn't serialize output as a GraphDef";
+ return std::pair<string, string>{out_status, ""};
+ }
+ out_status = "OK;All good!";
+ return std::pair<string, string>{out_status, result};
+#else
+ // Returns FAILED_PRECONDITION.
+ return std::pair<string, string>{"9;TensorRT is not enabled!", ""};
+#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
+}
+%}
+
+std::pair<string, string> trt_convert(string graph_def_string,
+ std::vector<string> output_names,
+ size_t max_batch_size,
+ size_t max_workspace_size_bytes);
+
+
+%unignoreall
diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
index 76f1dd2a56..8d99835b64 100644
--- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py
+++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
@@ -20,7 +20,7 @@ from __future__ import print_function
from setuptools import setup
-_VERSION = '1.6.0-rc0'
+_VERSION = '1.6.0-rc1'
CONSOLE_SCRIPTS = [
'capture_tpu_profile=cloud_tpu_profiler.main:run_main',