aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--RELEASE.md27
-rw-r--r--configure.py25
-rw-r--r--tensorflow/c/c_api_test.cc2
-rw-r--r--tensorflow/compiler/xla/literal_util.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h9
-rw-r--r--tensorflow/contrib/BUILD5
-rw-r--r--tensorflow/contrib/android/README.md5
-rw-r--r--tensorflow/contrib/boosted_trees/python/utils/losses.py4
-rw-r--r--tensorflow/contrib/cmake/CMakeLists.txt22
-rw-r--r--tensorflow/contrib/cmake/external/zlib.cmake108
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt3
-rw-r--r--tensorflow/contrib/crf/python/ops/crf.py12
-rw-r--r--tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py2
-rw-r--r--tensorflow/contrib/eager/python/g3doc/guide.md9
-rw-r--r--tensorflow/contrib/factorization/python/ops/clustering_ops.py4
-rw-r--r--tensorflow/contrib/framework/__init__.py6
-rw-r--r--tensorflow/contrib/framework/python/framework/graph_util.py12
-rw-r--r--tensorflow/contrib/image/kernels/bipartite_match_op.cc2
-rw-r--r--tensorflow/contrib/layers/python/layers/layers.py12
-rw-r--r--tensorflow/contrib/lite/README.md17
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_imagenet_slim.txt1001
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt1001
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java6
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java137
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java103
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java94
-rw-r--r--tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py48
-rw-r--r--tensorflow/contrib/makefile/README.md99
-rwxr-xr-xtensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh6
-rw-r--r--tensorflow/contrib/rnn/python/ops/lstm_ops.py5
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn_cell.py5
-rw-r--r--tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc8
-rw-r--r--tensorflow/contrib/signal/python/ops/spectral_ops.py2
-rw-r--r--tensorflow/contrib/slim/python/slim/evaluation_test.py3
-rw-r--r--tensorflow/contrib/tensor_forest/BUILD1
-rw-r--r--tensorflow/contrib/tensorrt/BUILD204
-rw-r--r--tensorflow/contrib/tensorrt/README.md40
-rw-r--r--tensorflow/contrib/tensorrt/__init__.py23
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_graph.cc273
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_graph.h47
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.cc1601
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.h52
-rw-r--r--tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc140
-rw-r--r--tensorflow/contrib/tensorrt/kernels/trt_engine_op.h62
-rw-r--r--tensorflow/contrib/tensorrt/log/trt_logger.cc57
-rw-r--r--tensorflow/contrib/tensorrt/log/trt_logger.h42
-rw-r--r--tensorflow/contrib/tensorrt/ops/trt_engine_op.cc43
-rw-r--r--tensorflow/contrib/tensorrt/python/__init__.py24
-rw-r--r--tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py34
-rw-r--r--tensorflow/contrib/tensorrt/python/trt_convert.py103
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment.cc253
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment.h56
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment_test.cc367
-rw-r--r--tensorflow/contrib/tensorrt/segment/union_find.h79
-rw-r--r--tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc89
-rw-r--r--tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h33
-rw-r--r--tensorflow/contrib/tensorrt/test/test_tftrt.py88
-rw-r--r--tensorflow/contrib/tensorrt/trt_conversion.i131
-rw-r--r--tensorflow/contrib/tpu/profiler/pip_package/setup.py2
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id.h2
-rw-r--r--tensorflow/core/common_runtime/mkl_cpu_allocator.h1
-rw-r--r--tensorflow/core/framework/tensor_shape.h3
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass.cc2
-rw-r--r--tensorflow/core/kernels/colorspace_op.cc2
-rw-r--r--tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc8
-rw-r--r--tensorflow/core/kernels/mkl_batch_matmul_op.cc28
-rw-r--r--tensorflow/core/kernels/mkl_input_conversion_op.cc4
-rw-r--r--tensorflow/core/kernels/mkl_matmul_op.cc28
-rw-r--r--tensorflow/core/kernels/mkl_tfconv_op.h2
-rw-r--r--tensorflow/core/kernels/mkl_transpose_op.cc34
-rw-r--r--tensorflow/core/kernels/non_max_suppression_op.cc5
-rw-r--r--tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/quantized_resize_bilinear_op.cc4
-rw-r--r--tensorflow/core/kernels/random_crop_op.cc4
-rw-r--r--tensorflow/core/kernels/resize_area_op.cc5
-rw-r--r--tensorflow/core/kernels/resize_bicubic_op.cc10
-rw-r--r--tensorflow/core/kernels/resize_bilinear_op.cc10
-rw-r--r--tensorflow/core/kernels/resize_nearest_neighbor_op.cc8
-rw-r--r--tensorflow/core/kernels/sample_distorted_bounding_box_op.cc6
-rw-r--r--tensorflow/core/kernels/slice_op.cc10
-rw-r--r--tensorflow/core/kernels/substr_op.cc20
-rw-r--r--tensorflow/core/kernels/xsmm_conv2d.cc12
-rw-r--r--tensorflow/core/lib/io/record_writer.cc2
-rw-r--r--tensorflow/core/ops/image_ops.cc8
-rw-r--r--tensorflow/core/platform/platform.h7
-rw-r--r--tensorflow/core/protobuf/config.proto2
-rw-r--r--tensorflow/core/public/version.h2
-rw-r--r--tensorflow/core/util/mkl_util.h10
-rw-r--r--tensorflow/docs_src/about/roadmap.md101
-rw-r--r--tensorflow/docs_src/about/uses.md8
-rw-r--r--tensorflow/docs_src/deploy/index.md2
-rw-r--r--tensorflow/docs_src/deploy/leftnav_files1
-rw-r--r--tensorflow/docs_src/deploy/s3.md40
-rw-r--r--tensorflow/docs_src/extend/add_filesys.md2
-rw-r--r--tensorflow/docs_src/install/install_c.md2
-rw-r--r--tensorflow/docs_src/install/install_go.md2
-rw-r--r--tensorflow/docs_src/install/install_java.md22
-rw-r--r--tensorflow/docs_src/install/install_linux.md23
-rw-r--r--tensorflow/docs_src/install/install_mac.md10
-rw-r--r--tensorflow/docs_src/install/install_sources.md16
-rw-r--r--tensorflow/docs_src/install/install_windows.md4
-rw-r--r--tensorflow/docs_src/mobile/mobile_intro.md2
-rw-r--r--tensorflow/docs_src/programmers_guide/low_level_intro.md6
-rw-r--r--tensorflow/docs_src/tutorials/layers.md2
-rw-r--r--tensorflow/examples/android/res/animator/color_animation.xml30
-rw-r--r--tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java21
-rw-r--r--tensorflow/examples/get_started/regression/imports85.py11
-rw-r--r--tensorflow/examples/image_retraining/retrain.py55
-rw-r--r--tensorflow/examples/speech_commands/label_wav_dir.py136
-rw-r--r--tensorflow/examples/speech_commands/train.py6
-rw-r--r--tensorflow/examples/udacity/5_word2vec.ipynb2
-rw-r--r--tensorflow/python/estimator/estimator.py2
-rw-r--r--tensorflow/python/estimator/estimator_test.py15
-rw-r--r--tensorflow/python/framework/common_shapes.py2
-rw-r--r--tensorflow/python/framework/function_test.py2
-rw-r--r--tensorflow/python/kernel_tests/reduction_ops_test.py88
-rw-r--r--tensorflow/python/kernel_tests/reduction_ops_test_big.py18
-rw-r--r--tensorflow/python/layers/core.py9
-rw-r--r--tensorflow/python/layers/normalization.py8
-rw-r--r--tensorflow/python/layers/utils.py83
-rw-r--r--tensorflow/python/ops/clip_ops.py2
-rw-r--r--tensorflow/python/ops/control_flow_ops.py93
-rw-r--r--tensorflow/python/ops/control_flow_ops_test.py36
-rw-r--r--tensorflow/python/ops/data_flow_ops.py2
-rw-r--r--tensorflow/python/ops/distributions/multinomial.py2
-rw-r--r--tensorflow/python/ops/distributions/util.py11
-rw-r--r--tensorflow/python/ops/image_ops_impl.py183
-rw-r--r--tensorflow/python/ops/image_ops_test.py165
-rw-r--r--tensorflow/python/ops/losses/losses_impl.py25
-rw-r--r--tensorflow/python/ops/math_ops_test.py2
-rw-r--r--tensorflow/python/ops/nn_grad.py14
-rw-r--r--tensorflow/python/ops/nn_ops.py30
-rw-r--r--tensorflow/python/ops/nn_test.py25
-rw-r--r--tensorflow/python/profiler/option_builder.py2
-rw-r--r--tensorflow/python/tools/freeze_graph.py20
-rw-r--r--tensorflow/python/training/saver.py12
-rw-r--r--tensorflow/tensorflow.bzl63
-rwxr-xr-xtensorflow/tools/ci_build/install/install_bazel.sh2
-rw-r--r--tensorflow/tools/graph_transforms/BUILD1
-rw-r--r--tensorflow/tools/graph_transforms/README.md7
-rw-r--r--tensorflow/tools/graph_transforms/remove_control_dependencies.cc47
-rw-r--r--tensorflow/tools/graph_transforms/remove_nodes.cc12
-rw-r--r--tensorflow/tools/pip_package/BUILD5
-rw-r--r--tensorflow/tools/pip_package/setup.py8
-rw-r--r--third_party/gpus/cuda_configure.bzl19
-rw-r--r--third_party/tensorrt/BUILD.tpl34
-rw-r--r--third_party/tensorrt/LICENSE203
-rw-r--r--third_party/tensorrt/tensorrt_configure.bzl7
148 files changed, 7965 insertions, 690 deletions
diff --git a/RELEASE.md b/RELEASE.md
index 0720a8c639..6f54dee58f 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -21,7 +21,7 @@ newcomers.
* Other:
* Add `tf.contrib.distributions.Kumaraswamy`.
* `RetryingFileSystem::FlushCaches()` calls the base FileSystem's `FlushCaches()`.
- * Add auto_correlation to distributions.
+ * Add `auto_correlation` to distributions.
* Add `tf.contrib.distributions.Autoregressive`.
* Add SeparableConv1D layer.
* Add convolutional Flipout layers.
@@ -31,12 +31,12 @@ newcomers.
* Output variance over trees predictions for classifications tasks.
* For `pt` and `eval` commands, allow writing tensor values to filesystem as numpy files.
* gRPC: Propagate truncated errors (instead of returning gRPC internal error).
- * Augment parallel_interleave to support 2 kinds of prefetching.
+ * Augment `parallel_interleave` to support 2 kinds of prefetching.
* Improved XLA support for C64-related ops log, pow, atan2, tanh.
* Add probabilistic convolutional layers.
## API Changes
-* Introducing prepare_variance boolean with default setting to False for backward compatibility.
+* Introducing `prepare_variance` boolean with default setting to False for backward compatibility.
* Move `layers_dense_variational_impl.py` to `layers_dense_variational.py`.
## Known Bugs
@@ -96,27 +96,6 @@ Yoni Tsafir, yordun, Yuan (Terry) Tang, Yuxin Wu, zhengdi, Zhengsheng Wei, ç”°ä¼
* Starting from 1.6 release, our prebuilt binaries will use AVX instructions.
This may break TF on older CPUs.
-## Known Bugs
-* Using XLA:GPU with CUDA 9 and CUDA 9.1 results in garbage results and/or
- `CUDA_ILLEGAL_ADDRESS` failures.
-
- Google discovered in mid-December 2017 that the PTX-to-SASS compiler in CUDA 9
- and CUDA 9.1 sometimes does not properly compute the carry bit when
- decomposing 64-bit address calculations with large offsets (e.g. `load [x +
- large_constant]`) into 32-bit arithmetic in SASS.
-
- As a result, these versions of `ptxas` miscompile most XLA programs which use
- more than 4GB of temp memory. This results in garbage results and/or
- `CUDA_ERROR_ILLEGAL_ADDRESS` failures.
-
- A fix in CUDA 9.1.121 is expected in late February 2018. We do not expect a
- fix for CUDA 9.0.x. Until the fix is available, the only workaround is to
- [downgrade](https://developer.nvidia.com/cuda-toolkit-archive) to CUDA 8.0.x
- or disable XLA:GPU.
-
- TensorFlow will print a warning if you use XLA:GPU with a known-bad version of
- CUDA; see e00ba24c4038e7644da417ddc639169b6ea59122.
-
## Major Features And Improvements
* [Eager execution](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/eager)
preview version is now available.
diff --git a/configure.py b/configure.py
index 6b1fa7f1a8..9744f6ac81 100644
--- a/configure.py
+++ b/configure.py
@@ -445,7 +445,7 @@ def convert_version_to_int(version):
def check_bazel_version(min_version):
- """Check installed bezel version is at least min_version.
+ """Check installed bazel version is at least min_version.
Args:
min_version: string for minimum bazel version.
@@ -1078,12 +1078,22 @@ def set_tf_tensorrt_install_path(environ_cp):
break
# Reset and Retry
- print('Invalid path to TensorRT. None of the following files can be found:')
- print(trt_install_path)
- print(os.path.join(trt_install_path, 'lib'))
- print(os.path.join(trt_install_path, 'lib64'))
- if search_result:
- print(libnvinfer_path_from_ldconfig)
+ if possible_files:
+ print('TensorRT libraries found in one the following directories',
+ 'are not compatible with selected cuda and cudnn installations')
+ print(trt_install_path)
+ print(os.path.join(trt_install_path, 'lib'))
+ print(os.path.join(trt_install_path, 'lib64'))
+ if search_result:
+ print(libnvinfer_path_from_ldconfig)
+ else:
+ print(
+ 'Invalid path to TensorRT. None of the following files can be found:')
+ print(trt_install_path)
+ print(os.path.join(trt_install_path, 'lib'))
+ print(os.path.join(trt_install_path, 'lib64'))
+ if search_result:
+ print(libnvinfer_path_from_ldconfig)
else:
raise UserInputError('Invalid TF_TENSORRT setting was provided %d '
@@ -1481,7 +1491,6 @@ def main():
'more details.')
config_info_line('mkl', 'Build with MKL support.')
config_info_line('monolithic', 'Config for mostly static monolithic build.')
- config_info_line('tensorrt', 'Build with TensorRT support.')
if __name__ == '__main__':
main()
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index 69fe5bec51..028f146be3 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -2081,7 +2081,7 @@ TEST_F(CApiAttributesTest, Tensor) {
}
TEST_F(CApiAttributesTest, StringTensor) {
- // Create the string-Tensor "atttribute" value.
+ // Create the string-Tensor "attribute" value.
char encoded[] = {
0, 0, 0, 0, 0, 0, 0, 0, // array[uint64] offsets
1, // varint encoded string length
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index ed9d2a187a..823da43b5a 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -234,7 +234,8 @@ Status Literal::CopySliceFromInternal(
int64 src_index = linear_index(src_literal.shape(), src_indexes);
int64 dest_index = linear_index(shape(), dest_indexes);
- StridedCopy(data<NativeT>(), dest_index, stride_config.dest_stride,
+ // `this->` is needed to workaround MSVC bug: #16882
+ StridedCopy(this->data<NativeT>(), dest_index, stride_config.dest_stride,
src_literal.data<NativeT>(), src_index,
stride_config.source_stride, stride_config.minor_loop_size);
return true;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 1762d227be..c4fe132d1d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -589,12 +589,9 @@ class HloInstruction {
if (opcode() != other.opcode()) {
return false;
}
- auto eq_shapes = layout_sensitive
- ? [](const Shape& a,
- const Shape& b) { return ShapeUtil::Equal(a, b); }
- : [](const Shape& a, const Shape& b) {
- return ShapeUtil::Compatible(a, b);
- };
+ using EqShapeFuncType = bool (*)(const Shape&, const Shape&);
+ EqShapeFuncType eq_shapes =
+ layout_sensitive ? ShapeUtil::Equal : ShapeUtil::Compatible;
if (!eq_shapes(shape(), other.shape())) {
return false;
}
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 6b3343bb2f..bab37e8906 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -7,6 +7,7 @@ package(default_visibility = ["//tensorflow:__subpackages__"])
load("//third_party/mpi:mpi.bzl", "if_mpi")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt")
py_library(
name = "contrib_py",
@@ -107,7 +108,9 @@ py_library(
"//tensorflow/contrib/training:training_py",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:util",
- ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]),
+ ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_tensorrt([
+ "//tensorflow/contrib/tensorrt:init_py",
+ ]),
)
cc_library(
diff --git a/tensorflow/contrib/android/README.md b/tensorflow/contrib/android/README.md
index b8d73bf24c..db37bcf73d 100644
--- a/tensorflow/contrib/android/README.md
+++ b/tensorflow/contrib/android/README.md
@@ -81,6 +81,11 @@ For documentation on building a self-contained AAR file with cmake, see
[tensorflow/contrib/android/cmake](cmake).
+### Makefile
+
+For documentation on building native TF libraries with make, including a CUDA-enabled variant for devices like the Nvidia Shield TV, see [tensorflow/contrib/makefile/README.md](../makefile/README.md)
+
+
## AssetManagerFileSystem
This directory also contains a TensorFlow filesystem supporting the Android
diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py
index 1e8b3ac08a..ab7ac2aba6 100644
--- a/tensorflow/contrib/boosted_trees/python/utils/losses.py
+++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py
@@ -78,7 +78,7 @@ def per_example_maxent_loss(labels, weights, logits, num_classes, eps=1e-15):
# Calculate softmax probabilities for each class.
unnormalized_probs = math_ops.exp(logits)
- normalizers = math_ops.reduce_sum(unnormalized_probs, 1, keep_dims=True)
+ normalizers = math_ops.reduce_sum(unnormalized_probs, 1, keepdims=True)
softmax_predictions = math_ops.divide(unnormalized_probs,
math_ops.add(normalizers, eps))
@@ -120,7 +120,7 @@ def per_example_squared_loss(labels, weights, predictions):
update_op: An update operation to update the loss's internal state.
"""
unweighted_loss = math_ops.reduce_sum(
- math_ops.square(predictions - labels), 1, keep_dims=True)
+ math_ops.square(predictions - labels), 1, keepdims=True)
return unweighted_loss * weights, control_flow_ops.no_op()
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index 524946a9a5..23b31ae1dc 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -52,6 +52,7 @@ if (NOT WIN32)
# for targets that link ${CMAKE_THREAD_LIBS_INIT}.
find_package (Threads)
+ # Options for linking CUDA/CUDNN libraries
option(tensorflow_PATH_STATIC_LIB "Additional library search path for libcudnn_static.a, libnccl_static.a, libculibos.a" /usr/local/cuda/lib64/)
option(tensorflow_CUDNN_INCLUDE "cudnn.h header install path" /usr/include/)
if (NOT tensorflow_CUDNN_INCLUDE)
@@ -73,6 +74,14 @@ if (NOT WIN32)
# option's default value is OFF. Fill it with real default values
set(tensorflow_CUDA_LIBRARY_PATH /usr/local/cuda/lib64)
endif (NOT tensorflow_CUDA_LIBRARY_PATH)
+
+ # Options for linking other libraries
+ option(systemlib_ZLIB "Use the system installed library as shared objects instead of downloading ZLIB and statically linking to it: ZLIB" OFF)
+
+ option(systemlib_ALL "Turn on every possible systemlib_* options" OFF)
+ if (systemlib_ALL)
+ set (systmelib_ZLIB ON)
+ endif (systemlib_ALL)
endif()
if (WIN32)
@@ -188,8 +197,10 @@ if (tensorflow_BUILD_CC_TESTS)
include(googletest)
endif()
+add_definitions(${ADD_CFLAGS})
+link_directories(${ADD_LINK_DIRECTORY})
+
set(tensorflow_EXTERNAL_LIBRARIES
- ${zlib_STATIC_LIBRARIES}
${gif_STATIC_LIBRARIES}
${png_STATIC_LIBRARIES}
${jpeg_STATIC_LIBRARIES}
@@ -203,6 +214,15 @@ set(tensorflow_EXTERNAL_LIBRARIES
${re2_STATIC_LIBRARIES}
${sqlite_STATIC_LIBRARIES}
)
+
+if (systemlib_ZLIB)
+ set(tensorflow_EXTERNAL_LIBRARIES ${tensorflow_EXTERNAL_LIBRARIES}
+ ${ZLIB_LIBRARIES})
+else (systemlib_ZLIB)
+ set(tensorflow_EXTERNAL_LIBRARIES ${tensorflow_EXTERNAL_LIBRARIES}
+ ${zlib_STATIC_LIBRARIES})
+endif (systemlib_ZLIB)
+
set(tensorflow_EXTERNAL_DEPENDENCIES
zlib_copy_headers_to_destination
gif_copy_headers_to_destination
diff --git a/tensorflow/contrib/cmake/external/zlib.cmake b/tensorflow/contrib/cmake/external/zlib.cmake
index c5eb0cbcc7..116d423093 100644
--- a/tensorflow/contrib/cmake/external/zlib.cmake
+++ b/tensorflow/contrib/cmake/external/zlib.cmake
@@ -12,61 +12,75 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-include (ExternalProject)
+if (systemlib_ZLIB)
+ find_package(PkgConfig)
+ pkg_search_module(ZLIB REQUIRED zlib)
+ set(zlib_INCLUDE_DIR ${ZLIB_INCLUDE_DIRS})
+ set(ADD_LINK_DIRECTORY ${ADD_LINK_DIRECTORY} ${ZLIB_LIBRARY_DIRS})
+ set(ADD_CFLAGS ${ADD_CFLAGS} ${ZLIB_CFLAGS_OTHER})
-set(zlib_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/zlib_archive)
-set(ZLIB_URL https://github.com/madler/zlib)
-set(ZLIB_BUILD ${CMAKE_CURRENT_BINARY_DIR}/zlib/src/zlib)
-set(ZLIB_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/zlib/install)
-set(ZLIB_TAG 50893291621658f355bc5b4d450a8d06a563053d)
+ # To meet DEPENDS zlib from other projects.
+ # If we hit this line, zlib is already built and installed to the system.
+ add_custom_target(zlib)
+ add_custom_target(zlib_copy_headers_to_destination)
-if(WIN32)
- if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
- set(zlib_STATIC_LIBRARIES
- debug ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstaticd.lib
- optimized ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstatic.lib)
- else()
- if(CMAKE_BUILD_TYPE EQUAL Debug)
+else (systemlib_ZLIB)
+ include (ExternalProject)
+
+ set(zlib_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/zlib_archive)
+ set(ZLIB_URL https://github.com/madler/zlib)
+ set(ZLIB_BUILD ${CMAKE_CURRENT_BINARY_DIR}/zlib/src/zlib)
+ set(ZLIB_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/zlib/install)
+ set(ZLIB_TAG 50893291621658f355bc5b4d450a8d06a563053d)
+
+ if(WIN32)
+ if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
set(zlib_STATIC_LIBRARIES
- ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstaticd.lib)
+ debug ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstaticd.lib
+ optimized ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstatic.lib)
else()
- set(zlib_STATIC_LIBRARIES
- ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstatic.lib)
+ if(CMAKE_BUILD_TYPE EQUAL Debug)
+ set(zlib_STATIC_LIBRARIES
+ ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstaticd.lib)
+ else()
+ set(zlib_STATIC_LIBRARIES
+ ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstatic.lib)
+ endif()
endif()
+ else()
+ set(zlib_STATIC_LIBRARIES
+ ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/libz.a)
endif()
-else()
- set(zlib_STATIC_LIBRARIES
- ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/libz.a)
-endif()
-set(ZLIB_HEADERS
- "${ZLIB_INSTALL}/include/zconf.h"
- "${ZLIB_INSTALL}/include/zlib.h"
-)
+ set(ZLIB_HEADERS
+ "${ZLIB_INSTALL}/include/zconf.h"
+ "${ZLIB_INSTALL}/include/zlib.h"
+ )
-ExternalProject_Add(zlib
- PREFIX zlib
- GIT_REPOSITORY ${ZLIB_URL}
- GIT_TAG ${ZLIB_TAG}
- INSTALL_DIR ${ZLIB_INSTALL}
- BUILD_IN_SOURCE 1
- BUILD_BYPRODUCTS ${zlib_STATIC_LIBRARIES}
- DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
- CMAKE_CACHE_ARGS
- -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE}
- -DCMAKE_BUILD_TYPE:STRING=Release
- -DCMAKE_INSTALL_PREFIX:STRING=${ZLIB_INSTALL}
-)
+ ExternalProject_Add(zlib
+ PREFIX zlib
+ GIT_REPOSITORY ${ZLIB_URL}
+ GIT_TAG ${ZLIB_TAG}
+ INSTALL_DIR ${ZLIB_INSTALL}
+ BUILD_IN_SOURCE 1
+ BUILD_BYPRODUCTS ${zlib_STATIC_LIBRARIES}
+ DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
+ CMAKE_CACHE_ARGS
+ -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE}
+ -DCMAKE_BUILD_TYPE:STRING=Release
+ -DCMAKE_INSTALL_PREFIX:STRING=${ZLIB_INSTALL}
+ )
-# put zlib includes in the directory where they are expected
-add_custom_target(zlib_create_destination_dir
- COMMAND ${CMAKE_COMMAND} -E make_directory ${zlib_INCLUDE_DIR}
- DEPENDS zlib)
+ # put zlib includes in the directory where they are expected
+ add_custom_target(zlib_create_destination_dir
+ COMMAND ${CMAKE_COMMAND} -E make_directory ${zlib_INCLUDE_DIR}
+ DEPENDS zlib)
-add_custom_target(zlib_copy_headers_to_destination
- DEPENDS zlib_create_destination_dir)
+ add_custom_target(zlib_copy_headers_to_destination
+ DEPENDS zlib_create_destination_dir)
-foreach(header_file ${ZLIB_HEADERS})
- add_custom_command(TARGET zlib_copy_headers_to_destination PRE_BUILD
- COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${zlib_INCLUDE_DIR})
-endforeach()
+ foreach(header_file ${ZLIB_HEADERS})
+ add_custom_command(TARGET zlib_copy_headers_to_destination PRE_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${zlib_INCLUDE_DIR})
+ endforeach()
+endif (systemlib_ZLIB)
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index f55043c93d..bfe53c01b3 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -413,6 +413,9 @@ tensorflow/contrib/tensorboard
tensorflow/contrib/tensorboard/plugins
tensorflow/contrib/tensorboard/plugins/projector
tensorflow/contrib/tensorboard/plugins/trace
+# TODO(sami): Add cmake implementations.
+# tensorflow/contrib/tensorrt/python
+# tensorflow/contrib/tensorrt/python/ops
tensorflow/contrib/tensor_forest
tensorflow/contrib/tensor_forest/client
tensorflow/contrib/tensor_forest/hybrid
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py
index faa78769b9..1233c8f251 100644
--- a/tensorflow/contrib/crf/python/ops/crf.py
+++ b/tensorflow/contrib/crf/python/ops/crf.py
@@ -105,8 +105,8 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths,
return utils.smart_cond(
pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1],
1),
- fn1=_single_seq_fn,
- fn2=_multi_seq_fn)
+ true_fn=_single_seq_fn,
+ false_fn=_multi_seq_fn)
def crf_log_norm(inputs, sequence_lengths, transition_params):
@@ -511,7 +511,7 @@ def crf_decode(potentials, transition_params, sequence_length):
return decode_tags, best_score
return utils.smart_cond(
- pred=math_ops.equal(
- potentials.shape[1].value or array_ops.shape(potentials)[1], 1),
- fn1=_single_seq_fn,
- fn2=_multi_seq_fn)
+ pred=math_ops.equal(potentials.shape[1].value or
+ array_ops.shape(potentials)[1], 1),
+ true_fn=_single_seq_fn,
+ false_fn=_multi_seq_fn)
diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
index b6becfa9fc..2aa771a71e 100644
--- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
+++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
@@ -278,7 +278,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution):
* math_ops.log(self.temperature))
# compute the unnormalized density
log_softmax = nn_ops.log_softmax(logits_2d - x_2d * self._temperature_2d)
- log_unnorm_prob = math_ops.reduce_sum(log_softmax, [-1], keep_dims=False)
+ log_unnorm_prob = math_ops.reduce_sum(log_softmax, [-1], keepdims=False)
# combine unnormalized density with normalization constant
log_prob = log_norm_const + log_unnorm_prob
# Reshapes log_prob to be consistent with shape of user-supplied logits
diff --git a/tensorflow/contrib/eager/python/g3doc/guide.md b/tensorflow/contrib/eager/python/g3doc/guide.md
index 4724aa4aee..ebb05051f2 100644
--- a/tensorflow/contrib/eager/python/g3doc/guide.md
+++ b/tensorflow/contrib/eager/python/g3doc/guide.md
@@ -22,11 +22,10 @@ to models defined without using eager execution.
Eager execution is included in TensorFlow versions 1.5 and above.
Installation instructions at https://www.tensorflow.org/install/
-The contents of this guide are compatible with TensorFlow 1.5.
-However, if you run into bugs that are fixed in source but not the
-release, you may want to either either [building from
-source](https://www.tensorflow.org/install/install_sources)
-or the try latest nightly builds. The nightly builds are available as:
+The contents of this guide are compatible with TensorFlow 1.5. However, if you
+run into bugs that are fixed in source but not the release, you may want to
+either [build from source](https://www.tensorflow.org/install/install_sources)
+or try a nightly build. The nightly builds are available as:
- [`pip` packages](https://github.com/tensorflow/tensorflow/blob/master/README.md#installation) and
diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py
index 6d3acb2750..23137e0a97 100644
--- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py
@@ -192,11 +192,11 @@ class KMeans(object):
# Computes Euclidean distance. Note the first and third terms are
# broadcast additions.
squared_distance = (
- math_ops.reduce_sum(math_ops.square(inp), 1, keep_dims=True) -
+ math_ops.reduce_sum(math_ops.square(inp), 1, keepdims=True) -
2 * math_ops.matmul(inp, clusters, transpose_b=True) +
array_ops.transpose(
math_ops.reduce_sum(
- math_ops.square(clusters), 1, keep_dims=True)))
+ math_ops.square(clusters), 1, keepdims=True)))
output.append(squared_distance)
return output
diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py
index fb101c3653..deeb5bec79 100644
--- a/tensorflow/contrib/framework/__init__.py
+++ b/tensorflow/contrib/framework/__init__.py
@@ -85,6 +85,8 @@ See the @{$python/contrib.framework} guide.
@@py_func
@@sort
+@@get_placeholders
+
@@CriticalSection
@@BoundedTensorSpec
@@ -102,10 +104,10 @@ from tensorflow.contrib.framework.python.ops import *
from tensorflow.python.framework.ops import prepend_name_scope
from tensorflow.python.framework.ops import strip_name_scope
-
from tensorflow.python.framework.tensor_spec import BoundedTensorSpec
from tensorflow.python.framework.tensor_spec import TensorSpec
-
+from tensorflow.python.ops.control_flow_ops import smart_cond
+from tensorflow.python.ops.control_flow_ops import smart_constant_value
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = ['nest']
diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py
index a18ff2320d..49eec3a3f1 100644
--- a/tensorflow/contrib/framework/python/framework/graph_util.py
+++ b/tensorflow/contrib/framework/python/framework/graph_util.py
@@ -133,6 +133,18 @@ def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
def get_placeholders(graph):
"""Get placeholders of a graph.
+ For example:
+
+ ```python
+ a = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='a')
+ a = tf.placeholder(dtype=tf.int32, shape=[3, 2], name='b')
+
+ tf.contrib.framework.get_placeholders(tf.get_default_graph())
+ # Returns:
+ # [<tf.Tensor 'a:0' shape=(2, 2) dtype=float32>,
+ # <tf.Tensor 'b:0' shape=(3, 2) dtype=int32>]
+ ```
+
Args:
graph: A tf.Graph.
Returns:
diff --git a/tensorflow/contrib/image/kernels/bipartite_match_op.cc b/tensorflow/contrib/image/kernels/bipartite_match_op.cc
index 7d207c388b..726adb0777 100644
--- a/tensorflow/contrib/image/kernels/bipartite_match_op.cc
+++ b/tensorflow/contrib/image/kernels/bipartite_match_op.cc
@@ -85,7 +85,7 @@ class BipartiteMatchOp : public OpKernel {
context->allocate_output(1, TensorShape({num_input_columns}),
&column_to_row_match_indices));
- typename TTypes<float, 2>::ConstTensor distance_mat =
+ TTypes<float, 2>::ConstTensor distance_mat =
input_distance_mat.shaped<float, 2>(
{num_input_rows, num_input_columns});
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 45ddfbfc9f..b2ea75c7e1 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -517,8 +517,8 @@ def batch_norm(inputs,
then the batch normalization uses weighted mean and
variance. (This can be used to correct for bias in training
example selection.)
- fused: if `True`, use a faster, fused implementation if possible.
- If `None`, use the system recommended implementation.
+ fused: if `None` or `True`, use a faster, fused implementation if possible.
+ If `False`, use the system recommended implementation.
data_format: A string. `NHWC` (default) and `NCHW` are supported.
zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new
pair of variables 'moving_mean/biased' and 'moving_mean/local_step'.
@@ -778,7 +778,7 @@ def batch_norm(inputs,
else:
if data_format == DATA_FORMAT_NCHW:
mean, variance = nn.weighted_moments(
- inputs, moments_axes, batch_weights, keep_dims=True)
+ inputs, moments_axes, batch_weights, keepdims=True)
mean = array_ops.reshape(mean, [-1])
variance = array_ops.reshape(variance, [-1])
else:
@@ -2836,9 +2836,9 @@ def spatial_softmax(features,
softmax_attention = nn.softmax(features / temperature)
expected_x = math_ops.reduce_sum(
- pos_x * softmax_attention, [1], keep_dims=True)
+ pos_x * softmax_attention, [1], keepdims=True)
expected_y = math_ops.reduce_sum(
- pos_y * softmax_attention, [1], keep_dims=True)
+ pos_y * softmax_attention, [1], keepdims=True)
expected_xy = array_ops.concat([expected_x, expected_y], 1)
feature_keypoints = array_ops.reshape(expected_xy,
[-1, num_channels.value * 2])
@@ -3018,7 +3018,7 @@ def poincare_normalize(x, axis=1, epsilon=1e-5, name=None):
"""
with ops.name_scope(name, 'poincare_normalize', [x]) as name:
x = ops.convert_to_tensor(x, name='x')
- square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keep_dims=True)
+ square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True)
x_inv_norm = math_ops.rsqrt(square_sum)
x_inv_norm = math_ops.minimum((1. - epsilon) * x_inv_norm, 1.)
return math_ops.multiply(x, x_inv_norm, name=name)
diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md
index 3e55d2a496..00e93d2c4f 100644
--- a/tensorflow/contrib/lite/README.md
+++ b/tensorflow/contrib/lite/README.md
@@ -6,7 +6,7 @@ TensorFlow Lite uses many techniques for achieving low latency like optimizing t
![image](g3doc/TFLite-Architecture.jpg)
# Getting Started with an Android Demo App
-This section contains an example application using TensorFlow Lite for Android devices. The demo is a sample camera app that classifies images continuously using a quantized Mobilenet model. A device running Android 5.0 ( API 21) or higher is required to run the demo.
+This section contains an example application using TensorFlow Lite for Android devices. The demo is a sample camera app that classifies images continuously using either a quantized Mobilenet model or a floating point Inception-v3 model. A device running Android 5.0 ( API 21) or higher is required to run the demo.
There are 3 ways to get the demo app to your device
- Download the prebuilt binary or
@@ -29,9 +29,16 @@ The simplest way to compile the demo app, and try out changes to the project cod
- Make sure the Android SDK version is greater than 26 and NDK version is greater than 14 (in the Android Studio Settings).
- Import the `tensorflow/contrib/lite/java/demo` directory as a new Android Studio project.
- Click through installing all the Gradle extensions it requests.
- - Download the quantized Mobilenet TensorFlow Lite model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip)
- - unzip and copy mobilenet_quant_v1_224.tflite to the assets directory:
- `tensorflow/contrib/lite/java/demo/app/src/main/assets/`
+ - Either
+ - Download the quantized Mobilenet TensorFlow Lite model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip)
+ - unzip and copy mobilenet_quant_v1_224.tflite to the assets directory:
+ `tensorflow/contrib/lite/java/demo/app/src/main/assets/`
+ - Or download the floating point Inception-v3 model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip)
+ - unzip and copy inceptionv3_non_slim_2015.tflite to the assets directory
+ - change the chosen classifier in [Camera2BasicFragment.java](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java) from
+ `classifier = new ImageClassifierQuantizedMobileNet(getActivity());`
+ to
+ `classifier = new ImageClassifierFloatInception(getActivity());`
- Build and run the demo app
## Building TensorFlow Lite and the demo app from source
@@ -84,7 +91,7 @@ Currently, we only support building the Android demo app within a Python 2
environment (due to a Bazel bug).
### More about the demo
-The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used. The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch 224 * 224 is the width and height of the image 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. The Mobilenet model has 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The Mobilenet quantized model is bundled within the assets directory of the app.
+The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used (229 * 229 for Inception-v3). The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch. 224 * 224 (299 * 299) is the width and height of the image. 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. Both models have 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The model file must be downloaded and bundled within the assets directory of the app.
# iOS Demo App
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_imagenet_slim.txt b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_imagenet_slim.txt
new file mode 100644
index 0000000000..572eccf900
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_imagenet_slim.txt
@@ -0,0 +1,1001 @@
+dummy
+tench
+goldfish
+great white shark
+tiger shark
+hammerhead
+electric ray
+stingray
+cock
+hen
+ostrich
+brambling
+goldfinch
+house finch
+junco
+indigo bunting
+robin
+bulbul
+jay
+magpie
+chickadee
+water ouzel
+kite
+bald eagle
+vulture
+great grey owl
+European fire salamander
+common newt
+eft
+spotted salamander
+axolotl
+bullfrog
+tree frog
+tailed frog
+loggerhead
+leatherback turtle
+mud turtle
+terrapin
+box turtle
+banded gecko
+common iguana
+American chameleon
+whiptail
+agama
+frilled lizard
+alligator lizard
+Gila monster
+green lizard
+African chameleon
+Komodo dragon
+African crocodile
+American alligator
+triceratops
+thunder snake
+ringneck snake
+hognose snake
+green snake
+king snake
+garter snake
+water snake
+vine snake
+night snake
+boa constrictor
+rock python
+Indian cobra
+green mamba
+sea snake
+horned viper
+diamondback
+sidewinder
+trilobite
+harvestman
+scorpion
+black and gold garden spider
+barn spider
+garden spider
+black widow
+tarantula
+wolf spider
+tick
+centipede
+black grouse
+ptarmigan
+ruffed grouse
+prairie chicken
+peacock
+quail
+partridge
+African grey
+macaw
+sulphur-crested cockatoo
+lorikeet
+coucal
+bee eater
+hornbill
+hummingbird
+jacamar
+toucan
+drake
+red-breasted merganser
+goose
+black swan
+tusker
+echidna
+platypus
+wallaby
+koala
+wombat
+jellyfish
+sea anemone
+brain coral
+flatworm
+nematode
+conch
+snail
+slug
+sea slug
+chiton
+chambered nautilus
+Dungeness crab
+rock crab
+fiddler crab
+king crab
+American lobster
+spiny lobster
+crayfish
+hermit crab
+isopod
+white stork
+black stork
+spoonbill
+flamingo
+little blue heron
+American egret
+bittern
+crane
+limpkin
+European gallinule
+American coot
+bustard
+ruddy turnstone
+red-backed sandpiper
+redshank
+dowitcher
+oystercatcher
+pelican
+king penguin
+albatross
+grey whale
+killer whale
+dugong
+sea lion
+Chihuahua
+Japanese spaniel
+Maltese dog
+Pekinese
+Shih-Tzu
+Blenheim spaniel
+papillon
+toy terrier
+Rhodesian ridgeback
+Afghan hound
+basset
+beagle
+bloodhound
+bluetick
+black-and-tan coonhound
+Walker hound
+English foxhound
+redbone
+borzoi
+Irish wolfhound
+Italian greyhound
+whippet
+Ibizan hound
+Norwegian elkhound
+otterhound
+Saluki
+Scottish deerhound
+Weimaraner
+Staffordshire bullterrier
+American Staffordshire terrier
+Bedlington terrier
+Border terrier
+Kerry blue terrier
+Irish terrier
+Norfolk terrier
+Norwich terrier
+Yorkshire terrier
+wire-haired fox terrier
+Lakeland terrier
+Sealyham terrier
+Airedale
+cairn
+Australian terrier
+Dandie Dinmont
+Boston bull
+miniature schnauzer
+giant schnauzer
+standard schnauzer
+Scotch terrier
+Tibetan terrier
+silky terrier
+soft-coated wheaten terrier
+West Highland white terrier
+Lhasa
+flat-coated retriever
+curly-coated retriever
+golden retriever
+Labrador retriever
+Chesapeake Bay retriever
+German short-haired pointer
+vizsla
+English setter
+Irish setter
+Gordon setter
+Brittany spaniel
+clumber
+English springer
+Welsh springer spaniel
+cocker spaniel
+Sussex spaniel
+Irish water spaniel
+kuvasz
+schipperke
+groenendael
+malinois
+briard
+kelpie
+komondor
+Old English sheepdog
+Shetland sheepdog
+collie
+Border collie
+Bouvier des Flandres
+Rottweiler
+German shepherd
+Doberman
+miniature pinscher
+Greater Swiss Mountain dog
+Bernese mountain dog
+Appenzeller
+EntleBucher
+boxer
+bull mastiff
+Tibetan mastiff
+French bulldog
+Great Dane
+Saint Bernard
+Eskimo dog
+malamute
+Siberian husky
+dalmatian
+affenpinscher
+basenji
+pug
+Leonberg
+Newfoundland
+Great Pyrenees
+Samoyed
+Pomeranian
+chow
+keeshond
+Brabancon griffon
+Pembroke
+Cardigan
+toy poodle
+miniature poodle
+standard poodle
+Mexican hairless
+timber wolf
+white wolf
+red wolf
+coyote
+dingo
+dhole
+African hunting dog
+hyena
+red fox
+kit fox
+Arctic fox
+grey fox
+tabby
+tiger cat
+Persian cat
+Siamese cat
+Egyptian cat
+cougar
+lynx
+leopard
+snow leopard
+jaguar
+lion
+tiger
+cheetah
+brown bear
+American black bear
+ice bear
+sloth bear
+mongoose
+meerkat
+tiger beetle
+ladybug
+ground beetle
+long-horned beetle
+leaf beetle
+dung beetle
+rhinoceros beetle
+weevil
+fly
+bee
+ant
+grasshopper
+cricket
+walking stick
+cockroach
+mantis
+cicada
+leafhopper
+lacewing
+dragonfly
+damselfly
+admiral
+ringlet
+monarch
+cabbage butterfly
+sulphur butterfly
+lycaenid
+starfish
+sea urchin
+sea cucumber
+wood rabbit
+hare
+Angora
+hamster
+porcupine
+fox squirrel
+marmot
+beaver
+guinea pig
+sorrel
+zebra
+hog
+wild boar
+warthog
+hippopotamus
+ox
+water buffalo
+bison
+ram
+bighorn
+ibex
+hartebeest
+impala
+gazelle
+Arabian camel
+llama
+weasel
+mink
+polecat
+black-footed ferret
+otter
+skunk
+badger
+armadillo
+three-toed sloth
+orangutan
+gorilla
+chimpanzee
+gibbon
+siamang
+guenon
+patas
+baboon
+macaque
+langur
+colobus
+proboscis monkey
+marmoset
+capuchin
+howler monkey
+titi
+spider monkey
+squirrel monkey
+Madagascar cat
+indri
+Indian elephant
+African elephant
+lesser panda
+giant panda
+barracouta
+eel
+coho
+rock beauty
+anemone fish
+sturgeon
+gar
+lionfish
+puffer
+abacus
+abaya
+academic gown
+accordion
+acoustic guitar
+aircraft carrier
+airliner
+airship
+altar
+ambulance
+amphibian
+analog clock
+apiary
+apron
+ashcan
+assault rifle
+backpack
+bakery
+balance beam
+balloon
+ballpoint
+Band Aid
+banjo
+bannister
+barbell
+barber chair
+barbershop
+barn
+barometer
+barrel
+barrow
+baseball
+basketball
+bassinet
+bassoon
+bathing cap
+bath towel
+bathtub
+beach wagon
+beacon
+beaker
+bearskin
+beer bottle
+beer glass
+bell cote
+bib
+bicycle-built-for-two
+bikini
+binder
+binoculars
+birdhouse
+boathouse
+bobsled
+bolo tie
+bonnet
+bookcase
+bookshop
+bottlecap
+bow
+bow tie
+brass
+brassiere
+breakwater
+breastplate
+broom
+bucket
+buckle
+bulletproof vest
+bullet train
+butcher shop
+cab
+caldron
+candle
+cannon
+canoe
+can opener
+cardigan
+car mirror
+carousel
+carpenter's kit
+carton
+car wheel
+cash machine
+cassette
+cassette player
+castle
+catamaran
+CD player
+cello
+cellular telephone
+chain
+chainlink fence
+chain mail
+chain saw
+chest
+chiffonier
+chime
+china cabinet
+Christmas stocking
+church
+cinema
+cleaver
+cliff dwelling
+cloak
+clog
+cocktail shaker
+coffee mug
+coffeepot
+coil
+combination lock
+computer keyboard
+confectionery
+container ship
+convertible
+corkscrew
+cornet
+cowboy boot
+cowboy hat
+cradle
+crane
+crash helmet
+crate
+crib
+Crock Pot
+croquet ball
+crutch
+cuirass
+dam
+desk
+desktop computer
+dial telephone
+diaper
+digital clock
+digital watch
+dining table
+dishrag
+dishwasher
+disk brake
+dock
+dogsled
+dome
+doormat
+drilling platform
+drum
+drumstick
+dumbbell
+Dutch oven
+electric fan
+electric guitar
+electric locomotive
+entertainment center
+envelope
+espresso maker
+face powder
+feather boa
+file
+fireboat
+fire engine
+fire screen
+flagpole
+flute
+folding chair
+football helmet
+forklift
+fountain
+fountain pen
+four-poster
+freight car
+French horn
+frying pan
+fur coat
+garbage truck
+gasmask
+gas pump
+goblet
+go-kart
+golf ball
+golfcart
+gondola
+gong
+gown
+grand piano
+greenhouse
+grille
+grocery store
+guillotine
+hair slide
+hair spray
+half track
+hammer
+hamper
+hand blower
+hand-held computer
+handkerchief
+hard disc
+harmonica
+harp
+harvester
+hatchet
+holster
+home theater
+honeycomb
+hook
+hoopskirt
+horizontal bar
+horse cart
+hourglass
+iPod
+iron
+jack-o'-lantern
+jean
+jeep
+jersey
+jigsaw puzzle
+jinrikisha
+joystick
+kimono
+knee pad
+knot
+lab coat
+ladle
+lampshade
+laptop
+lawn mower
+lens cap
+letter opener
+library
+lifeboat
+lighter
+limousine
+liner
+lipstick
+Loafer
+lotion
+loudspeaker
+loupe
+lumbermill
+magnetic compass
+mailbag
+mailbox
+maillot
+maillot
+manhole cover
+maraca
+marimba
+mask
+matchstick
+maypole
+maze
+measuring cup
+medicine chest
+megalith
+microphone
+microwave
+military uniform
+milk can
+minibus
+miniskirt
+minivan
+missile
+mitten
+mixing bowl
+mobile home
+Model T
+modem
+monastery
+monitor
+moped
+mortar
+mortarboard
+mosque
+mosquito net
+motor scooter
+mountain bike
+mountain tent
+mouse
+mousetrap
+moving van
+muzzle
+nail
+neck brace
+necklace
+nipple
+notebook
+obelisk
+oboe
+ocarina
+odometer
+oil filter
+organ
+oscilloscope
+overskirt
+oxcart
+oxygen mask
+packet
+paddle
+paddlewheel
+padlock
+paintbrush
+pajama
+palace
+panpipe
+paper towel
+parachute
+parallel bars
+park bench
+parking meter
+passenger car
+patio
+pay-phone
+pedestal
+pencil box
+pencil sharpener
+perfume
+Petri dish
+photocopier
+pick
+pickelhaube
+picket fence
+pickup
+pier
+piggy bank
+pill bottle
+pillow
+ping-pong ball
+pinwheel
+pirate
+pitcher
+plane
+planetarium
+plastic bag
+plate rack
+plow
+plunger
+Polaroid camera
+pole
+police van
+poncho
+pool table
+pop bottle
+pot
+potter's wheel
+power drill
+prayer rug
+printer
+prison
+projectile
+projector
+puck
+punching bag
+purse
+quill
+quilt
+racer
+racket
+radiator
+radio
+radio telescope
+rain barrel
+recreational vehicle
+reel
+reflex camera
+refrigerator
+remote control
+restaurant
+revolver
+rifle
+rocking chair
+rotisserie
+rubber eraser
+rugby ball
+rule
+running shoe
+safe
+safety pin
+saltshaker
+sandal
+sarong
+sax
+scabbard
+scale
+school bus
+schooner
+scoreboard
+screen
+screw
+screwdriver
+seat belt
+sewing machine
+shield
+shoe shop
+shoji
+shopping basket
+shopping cart
+shovel
+shower cap
+shower curtain
+ski
+ski mask
+sleeping bag
+slide rule
+sliding door
+slot
+snorkel
+snowmobile
+snowplow
+soap dispenser
+soccer ball
+sock
+solar dish
+sombrero
+soup bowl
+space bar
+space heater
+space shuttle
+spatula
+speedboat
+spider web
+spindle
+sports car
+spotlight
+stage
+steam locomotive
+steel arch bridge
+steel drum
+stethoscope
+stole
+stone wall
+stopwatch
+stove
+strainer
+streetcar
+stretcher
+studio couch
+stupa
+submarine
+suit
+sundial
+sunglass
+sunglasses
+sunscreen
+suspension bridge
+swab
+sweatshirt
+swimming trunks
+swing
+switch
+syringe
+table lamp
+tank
+tape player
+teapot
+teddy
+television
+tennis ball
+thatch
+theater curtain
+thimble
+thresher
+throne
+tile roof
+toaster
+tobacco shop
+toilet seat
+torch
+totem pole
+tow truck
+toyshop
+tractor
+trailer truck
+tray
+trench coat
+tricycle
+trimaran
+tripod
+triumphal arch
+trolleybus
+trombone
+tub
+turnstile
+typewriter keyboard
+umbrella
+unicycle
+upright
+vacuum
+vase
+vault
+velvet
+vending machine
+vestment
+viaduct
+violin
+volleyball
+waffle iron
+wall clock
+wallet
+wardrobe
+warplane
+washbasin
+washer
+water bottle
+water jug
+water tower
+whiskey jug
+whistle
+wig
+window screen
+window shade
+Windsor tie
+wine bottle
+wing
+wok
+wooden spoon
+wool
+worm fence
+wreck
+yawl
+yurt
+web site
+comic book
+crossword puzzle
+street sign
+traffic light
+book jacket
+menu
+plate
+guacamole
+consomme
+hot pot
+trifle
+ice cream
+ice lolly
+French loaf
+bagel
+pretzel
+cheeseburger
+hotdog
+mashed potato
+head cabbage
+broccoli
+cauliflower
+zucchini
+spaghetti squash
+acorn squash
+butternut squash
+cucumber
+artichoke
+bell pepper
+cardoon
+mushroom
+Granny Smith
+strawberry
+orange
+lemon
+fig
+pineapple
+banana
+jackfruit
+custard apple
+pomegranate
+hay
+carbonara
+chocolate sauce
+dough
+meat loaf
+pizza
+potpie
+burrito
+red wine
+espresso
+cup
+eggnog
+alp
+bubble
+cliff
+coral reef
+geyser
+lakeside
+promontory
+sandbar
+seashore
+valley
+volcano
+ballplayer
+groom
+scuba diver
+rapeseed
+daisy
+yellow lady's slipper
+corn
+acorn
+hip
+buckeye
+coral fungus
+agaric
+gyromitra
+stinkhorn
+earthstar
+hen-of-the-woods
+bolete
+ear
+toilet tissue
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt
new file mode 100644
index 0000000000..fe811239d8
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt
@@ -0,0 +1,1001 @@
+background
+tench
+goldfish
+great white shark
+tiger shark
+hammerhead
+electric ray
+stingray
+cock
+hen
+ostrich
+brambling
+goldfinch
+house finch
+junco
+indigo bunting
+robin
+bulbul
+jay
+magpie
+chickadee
+water ouzel
+kite
+bald eagle
+vulture
+great grey owl
+European fire salamander
+common newt
+eft
+spotted salamander
+axolotl
+bullfrog
+tree frog
+tailed frog
+loggerhead
+leatherback turtle
+mud turtle
+terrapin
+box turtle
+banded gecko
+common iguana
+American chameleon
+whiptail
+agama
+frilled lizard
+alligator lizard
+Gila monster
+green lizard
+African chameleon
+Komodo dragon
+African crocodile
+American alligator
+triceratops
+thunder snake
+ringneck snake
+hognose snake
+green snake
+king snake
+garter snake
+water snake
+vine snake
+night snake
+boa constrictor
+rock python
+Indian cobra
+green mamba
+sea snake
+horned viper
+diamondback
+sidewinder
+trilobite
+harvestman
+scorpion
+black and gold garden spider
+barn spider
+garden spider
+black widow
+tarantula
+wolf spider
+tick
+centipede
+black grouse
+ptarmigan
+ruffed grouse
+prairie chicken
+peacock
+quail
+partridge
+African grey
+macaw
+sulphur-crested cockatoo
+lorikeet
+coucal
+bee eater
+hornbill
+hummingbird
+jacamar
+toucan
+drake
+red-breasted merganser
+goose
+black swan
+tusker
+echidna
+platypus
+wallaby
+koala
+wombat
+jellyfish
+sea anemone
+brain coral
+flatworm
+nematode
+conch
+snail
+slug
+sea slug
+chiton
+chambered nautilus
+Dungeness crab
+rock crab
+fiddler crab
+king crab
+American lobster
+spiny lobster
+crayfish
+hermit crab
+isopod
+white stork
+black stork
+spoonbill
+flamingo
+little blue heron
+American egret
+bittern
+crane
+limpkin
+European gallinule
+American coot
+bustard
+ruddy turnstone
+red-backed sandpiper
+redshank
+dowitcher
+oystercatcher
+pelican
+king penguin
+albatross
+grey whale
+killer whale
+dugong
+sea lion
+Chihuahua
+Japanese spaniel
+Maltese dog
+Pekinese
+Shih-Tzu
+Blenheim spaniel
+papillon
+toy terrier
+Rhodesian ridgeback
+Afghan hound
+basset
+beagle
+bloodhound
+bluetick
+black-and-tan coonhound
+Walker hound
+English foxhound
+redbone
+borzoi
+Irish wolfhound
+Italian greyhound
+whippet
+Ibizan hound
+Norwegian elkhound
+otterhound
+Saluki
+Scottish deerhound
+Weimaraner
+Staffordshire bullterrier
+American Staffordshire terrier
+Bedlington terrier
+Border terrier
+Kerry blue terrier
+Irish terrier
+Norfolk terrier
+Norwich terrier
+Yorkshire terrier
+wire-haired fox terrier
+Lakeland terrier
+Sealyham terrier
+Airedale
+cairn
+Australian terrier
+Dandie Dinmont
+Boston bull
+miniature schnauzer
+giant schnauzer
+standard schnauzer
+Scotch terrier
+Tibetan terrier
+silky terrier
+soft-coated wheaten terrier
+West Highland white terrier
+Lhasa
+flat-coated retriever
+curly-coated retriever
+golden retriever
+Labrador retriever
+Chesapeake Bay retriever
+German short-haired pointer
+vizsla
+English setter
+Irish setter
+Gordon setter
+Brittany spaniel
+clumber
+English springer
+Welsh springer spaniel
+cocker spaniel
+Sussex spaniel
+Irish water spaniel
+kuvasz
+schipperke
+groenendael
+malinois
+briard
+kelpie
+komondor
+Old English sheepdog
+Shetland sheepdog
+collie
+Border collie
+Bouvier des Flandres
+Rottweiler
+German shepherd
+Doberman
+miniature pinscher
+Greater Swiss Mountain dog
+Bernese mountain dog
+Appenzeller
+EntleBucher
+boxer
+bull mastiff
+Tibetan mastiff
+French bulldog
+Great Dane
+Saint Bernard
+Eskimo dog
+malamute
+Siberian husky
+dalmatian
+affenpinscher
+basenji
+pug
+Leonberg
+Newfoundland
+Great Pyrenees
+Samoyed
+Pomeranian
+chow
+keeshond
+Brabancon griffon
+Pembroke
+Cardigan
+toy poodle
+miniature poodle
+standard poodle
+Mexican hairless
+timber wolf
+white wolf
+red wolf
+coyote
+dingo
+dhole
+African hunting dog
+hyena
+red fox
+kit fox
+Arctic fox
+grey fox
+tabby
+tiger cat
+Persian cat
+Siamese cat
+Egyptian cat
+cougar
+lynx
+leopard
+snow leopard
+jaguar
+lion
+tiger
+cheetah
+brown bear
+American black bear
+ice bear
+sloth bear
+mongoose
+meerkat
+tiger beetle
+ladybug
+ground beetle
+long-horned beetle
+leaf beetle
+dung beetle
+rhinoceros beetle
+weevil
+fly
+bee
+ant
+grasshopper
+cricket
+walking stick
+cockroach
+mantis
+cicada
+leafhopper
+lacewing
+dragonfly
+damselfly
+admiral
+ringlet
+monarch
+cabbage butterfly
+sulphur butterfly
+lycaenid
+starfish
+sea urchin
+sea cucumber
+wood rabbit
+hare
+Angora
+hamster
+porcupine
+fox squirrel
+marmot
+beaver
+guinea pig
+sorrel
+zebra
+hog
+wild boar
+warthog
+hippopotamus
+ox
+water buffalo
+bison
+ram
+bighorn
+ibex
+hartebeest
+impala
+gazelle
+Arabian camel
+llama
+weasel
+mink
+polecat
+black-footed ferret
+otter
+skunk
+badger
+armadillo
+three-toed sloth
+orangutan
+gorilla
+chimpanzee
+gibbon
+siamang
+guenon
+patas
+baboon
+macaque
+langur
+colobus
+proboscis monkey
+marmoset
+capuchin
+howler monkey
+titi
+spider monkey
+squirrel monkey
+Madagascar cat
+indri
+Indian elephant
+African elephant
+lesser panda
+giant panda
+barracouta
+eel
+coho
+rock beauty
+anemone fish
+sturgeon
+gar
+lionfish
+puffer
+abacus
+abaya
+academic gown
+accordion
+acoustic guitar
+aircraft carrier
+airliner
+airship
+altar
+ambulance
+amphibian
+analog clock
+apiary
+apron
+ashcan
+assault rifle
+backpack
+bakery
+balance beam
+balloon
+ballpoint
+Band Aid
+banjo
+bannister
+barbell
+barber chair
+barbershop
+barn
+barometer
+barrel
+barrow
+baseball
+basketball
+bassinet
+bassoon
+bathing cap
+bath towel
+bathtub
+beach wagon
+beacon
+beaker
+bearskin
+beer bottle
+beer glass
+bell cote
+bib
+bicycle-built-for-two
+bikini
+binder
+binoculars
+birdhouse
+boathouse
+bobsled
+bolo tie
+bonnet
+bookcase
+bookshop
+bottlecap
+bow
+bow tie
+brass
+brassiere
+breakwater
+breastplate
+broom
+bucket
+buckle
+bulletproof vest
+bullet train
+butcher shop
+cab
+caldron
+candle
+cannon
+canoe
+can opener
+cardigan
+car mirror
+carousel
+carpenter's kit
+carton
+car wheel
+cash machine
+cassette
+cassette player
+castle
+catamaran
+CD player
+cello
+cellular telephone
+chain
+chainlink fence
+chain mail
+chain saw
+chest
+chiffonier
+chime
+china cabinet
+Christmas stocking
+church
+cinema
+cleaver
+cliff dwelling
+cloak
+clog
+cocktail shaker
+coffee mug
+coffeepot
+coil
+combination lock
+computer keyboard
+confectionery
+container ship
+convertible
+corkscrew
+cornet
+cowboy boot
+cowboy hat
+cradle
+crane
+crash helmet
+crate
+crib
+Crock Pot
+croquet ball
+crutch
+cuirass
+dam
+desk
+desktop computer
+dial telephone
+diaper
+digital clock
+digital watch
+dining table
+dishrag
+dishwasher
+disk brake
+dock
+dogsled
+dome
+doormat
+drilling platform
+drum
+drumstick
+dumbbell
+Dutch oven
+electric fan
+electric guitar
+electric locomotive
+entertainment center
+envelope
+espresso maker
+face powder
+feather boa
+file
+fireboat
+fire engine
+fire screen
+flagpole
+flute
+folding chair
+football helmet
+forklift
+fountain
+fountain pen
+four-poster
+freight car
+French horn
+frying pan
+fur coat
+garbage truck
+gasmask
+gas pump
+goblet
+go-kart
+golf ball
+golfcart
+gondola
+gong
+gown
+grand piano
+greenhouse
+grille
+grocery store
+guillotine
+hair slide
+hair spray
+half track
+hammer
+hamper
+hand blower
+hand-held computer
+handkerchief
+hard disc
+harmonica
+harp
+harvester
+hatchet
+holster
+home theater
+honeycomb
+hook
+hoopskirt
+horizontal bar
+horse cart
+hourglass
+iPod
+iron
+jack-o'-lantern
+jean
+jeep
+jersey
+jigsaw puzzle
+jinrikisha
+joystick
+kimono
+knee pad
+knot
+lab coat
+ladle
+lampshade
+laptop
+lawn mower
+lens cap
+letter opener
+library
+lifeboat
+lighter
+limousine
+liner
+lipstick
+Loafer
+lotion
+loudspeaker
+loupe
+lumbermill
+magnetic compass
+mailbag
+mailbox
+maillot
+maillot
+manhole cover
+maraca
+marimba
+mask
+matchstick
+maypole
+maze
+measuring cup
+medicine chest
+megalith
+microphone
+microwave
+military uniform
+milk can
+minibus
+miniskirt
+minivan
+missile
+mitten
+mixing bowl
+mobile home
+Model T
+modem
+monastery
+monitor
+moped
+mortar
+mortarboard
+mosque
+mosquito net
+motor scooter
+mountain bike
+mountain tent
+mouse
+mousetrap
+moving van
+muzzle
+nail
+neck brace
+necklace
+nipple
+notebook
+obelisk
+oboe
+ocarina
+odometer
+oil filter
+organ
+oscilloscope
+overskirt
+oxcart
+oxygen mask
+packet
+paddle
+paddlewheel
+padlock
+paintbrush
+pajama
+palace
+panpipe
+paper towel
+parachute
+parallel bars
+park bench
+parking meter
+passenger car
+patio
+pay-phone
+pedestal
+pencil box
+pencil sharpener
+perfume
+Petri dish
+photocopier
+pick
+pickelhaube
+picket fence
+pickup
+pier
+piggy bank
+pill bottle
+pillow
+ping-pong ball
+pinwheel
+pirate
+pitcher
+plane
+planetarium
+plastic bag
+plate rack
+plow
+plunger
+Polaroid camera
+pole
+police van
+poncho
+pool table
+pop bottle
+pot
+potter's wheel
+power drill
+prayer rug
+printer
+prison
+projectile
+projector
+puck
+punching bag
+purse
+quill
+quilt
+racer
+racket
+radiator
+radio
+radio telescope
+rain barrel
+recreational vehicle
+reel
+reflex camera
+refrigerator
+remote control
+restaurant
+revolver
+rifle
+rocking chair
+rotisserie
+rubber eraser
+rugby ball
+rule
+running shoe
+safe
+safety pin
+saltshaker
+sandal
+sarong
+sax
+scabbard
+scale
+school bus
+schooner
+scoreboard
+screen
+screw
+screwdriver
+seat belt
+sewing machine
+shield
+shoe shop
+shoji
+shopping basket
+shopping cart
+shovel
+shower cap
+shower curtain
+ski
+ski mask
+sleeping bag
+slide rule
+sliding door
+slot
+snorkel
+snowmobile
+snowplow
+soap dispenser
+soccer ball
+sock
+solar dish
+sombrero
+soup bowl
+space bar
+space heater
+space shuttle
+spatula
+speedboat
+spider web
+spindle
+sports car
+spotlight
+stage
+steam locomotive
+steel arch bridge
+steel drum
+stethoscope
+stole
+stone wall
+stopwatch
+stove
+strainer
+streetcar
+stretcher
+studio couch
+stupa
+submarine
+suit
+sundial
+sunglass
+sunglasses
+sunscreen
+suspension bridge
+swab
+sweatshirt
+swimming trunks
+swing
+switch
+syringe
+table lamp
+tank
+tape player
+teapot
+teddy
+television
+tennis ball
+thatch
+theater curtain
+thimble
+thresher
+throne
+tile roof
+toaster
+tobacco shop
+toilet seat
+torch
+totem pole
+tow truck
+toyshop
+tractor
+trailer truck
+tray
+trench coat
+tricycle
+trimaran
+tripod
+triumphal arch
+trolleybus
+trombone
+tub
+turnstile
+typewriter keyboard
+umbrella
+unicycle
+upright
+vacuum
+vase
+vault
+velvet
+vending machine
+vestment
+viaduct
+violin
+volleyball
+waffle iron
+wall clock
+wallet
+wardrobe
+warplane
+washbasin
+washer
+water bottle
+water jug
+water tower
+whiskey jug
+whistle
+wig
+window screen
+window shade
+Windsor tie
+wine bottle
+wing
+wok
+wooden spoon
+wool
+worm fence
+wreck
+yawl
+yurt
+web site
+comic book
+crossword puzzle
+street sign
+traffic light
+book jacket
+menu
+plate
+guacamole
+consomme
+hot pot
+trifle
+ice cream
+ice lolly
+French loaf
+bagel
+pretzel
+cheeseburger
+hotdog
+mashed potato
+head cabbage
+broccoli
+cauliflower
+zucchini
+spaghetti squash
+acorn squash
+butternut squash
+cucumber
+artichoke
+bell pepper
+cardoon
+mushroom
+Granny Smith
+strawberry
+orange
+lemon
+fig
+pineapple
+banana
+jackfruit
+custard apple
+pomegranate
+hay
+carbonara
+chocolate sauce
+dough
+meat loaf
+pizza
+potpie
+burrito
+red wine
+espresso
+cup
+eggnog
+alp
+bubble
+cliff
+coral reef
+geyser
+lakeside
+promontory
+sandbar
+seashore
+valley
+volcano
+ballplayer
+groom
+scuba diver
+rapeseed
+daisy
+yellow lady's slipper
+corn
+acorn
+hip
+buckeye
+coral fungus
+agaric
+gyromitra
+stinkhorn
+earthstar
+hen-of-the-woods
+bolete
+ear
+toilet tissue
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
index 74737a8b88..9b9fdffab5 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
@@ -296,7 +296,8 @@ public class Camera2BasicFragment extends Fragment
public void onActivityCreated(Bundle savedInstanceState) {
super.onActivityCreated(savedInstanceState);
try {
- classifier = new ImageClassifier(getActivity());
+ // create either a new ImageClassifierQuantizedMobileNet or an ImageClassifierFloatInception
+ classifier = new ImageClassifierQuantizedMobileNet(getActivity());
} catch (IOException e) {
Log.e(TAG, "Failed to initialize an image classifier.");
}
@@ -658,8 +659,7 @@ public class Camera2BasicFragment extends Fragment
showToast("Uninitialized Classifier or invalid context.");
return;
}
- Bitmap bitmap =
- textureView.getBitmap(ImageClassifier.DIM_IMG_SIZE_X, ImageClassifier.DIM_IMG_SIZE_Y);
+ Bitmap bitmap = textureView.getBitmap(classifier.getImageSizeX(), classifier.getImageSizeY());
String textToShow = classifier.classifyFrame(bitmap);
bitmap.recycle();
showToast(textToShow);
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
index e44c5ae6b4..2c91be9d62 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
@@ -37,17 +37,11 @@ import java.util.PriorityQueue;
import org.tensorflow.lite.Interpreter;
/** Classifies images with Tensorflow Lite. */
-public class ImageClassifier {
+public abstract class ImageClassifier {
/** Tag for the {@link Log}. */
private static final String TAG = "TfLiteCameraDemo";
- /** Name of the model file stored in Assets. */
- private static final String MODEL_PATH = "mobilenet_quant_v1_224.tflite";
-
- /** Name of the label file stored in Assets. */
- private static final String LABEL_PATH = "labels.txt";
-
/** Number of results to show in the UI. */
private static final int RESULTS_TO_SHOW = 3;
@@ -56,23 +50,18 @@ public class ImageClassifier {
private static final int DIM_PIXEL_SIZE = 3;
- static final int DIM_IMG_SIZE_X = 224;
- static final int DIM_IMG_SIZE_Y = 224;
-
/* Preallocated buffers for storing image data in. */
- private int[] intValues = new int[DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y];
+ private int[] intValues = new int[getImageSizeX() * getImageSizeY()];
/** An instance of the driver class to run model inference with Tensorflow Lite. */
- private Interpreter tflite;
+ protected Interpreter tflite;
/** Labels corresponding to the output of the vision model. */
private List<String> labelList;
/** A ByteBuffer to hold image data, to be feed into Tensorflow Lite as inputs. */
- private ByteBuffer imgData = null;
+ protected ByteBuffer imgData = null;
- /** An array to hold inference results, to be feed into Tensorflow Lite as outputs. */
- private byte[][] labelProbArray = null;
/** multi-stage low pass filter * */
private float[][] filterLabelProbArray = null;
@@ -95,10 +84,13 @@ public class ImageClassifier {
labelList = loadLabelList(activity);
imgData =
ByteBuffer.allocateDirect(
- DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
+ DIM_BATCH_SIZE
+ * getImageSizeX()
+ * getImageSizeY()
+ * DIM_PIXEL_SIZE
+ * getNumBytesPerChannel());
imgData.order(ByteOrder.nativeOrder());
- labelProbArray = new byte[1][labelList.size()];
- filterLabelProbArray = new float[FILTER_STAGES][labelList.size()];
+ filterLabelProbArray = new float[FILTER_STAGES][getNumLabels()];
Log.d(TAG, "Created a Tensorflow Lite Image Classifier.");
}
@@ -111,7 +103,7 @@ public class ImageClassifier {
convertBitmapToByteBuffer(bitmap);
// Here's where the magic happens!!!
long startTime = SystemClock.uptimeMillis();
- tflite.run(imgData, labelProbArray);
+ runInference();
long endTime = SystemClock.uptimeMillis();
Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime));
@@ -125,12 +117,12 @@ public class ImageClassifier {
}
void applyFilter() {
- int numLabels = labelList.size();
+ int numLabels = getNumLabels();
// Low pass filter `labelProbArray` into the first stage of the filter.
for (int j = 0; j < numLabels; ++j) {
filterLabelProbArray[0][j] +=
- FILTER_FACTOR * (labelProbArray[0][j] - filterLabelProbArray[0][j]);
+ FILTER_FACTOR * (getProbability(j) - filterLabelProbArray[0][j]);
}
// Low pass filter each stage into the next.
for (int i = 1; i < FILTER_STAGES; ++i) {
@@ -142,7 +134,7 @@ public class ImageClassifier {
// Copy the last stage filter output back to `labelProbArray`.
for (int j = 0; j < numLabels; ++j) {
- labelProbArray[0][j] = (byte)filterLabelProbArray[FILTER_STAGES - 1][j];
+ setProbability(j, filterLabelProbArray[FILTER_STAGES - 1][j]);
}
}
@@ -156,7 +148,7 @@ public class ImageClassifier {
private List<String> loadLabelList(Activity activity) throws IOException {
List<String> labelList = new ArrayList<String>();
BufferedReader reader =
- new BufferedReader(new InputStreamReader(activity.getAssets().open(LABEL_PATH)));
+ new BufferedReader(new InputStreamReader(activity.getAssets().open(getLabelPath())));
String line;
while ((line = reader.readLine()) != null) {
labelList.add(line);
@@ -167,7 +159,7 @@ public class ImageClassifier {
/** Memory-map the model file in Assets. */
private MappedByteBuffer loadModelFile(Activity activity) throws IOException {
- AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_PATH);
+ AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(getModelPath());
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
@@ -185,12 +177,10 @@ public class ImageClassifier {
// Convert the image to floating point.
int pixel = 0;
long startTime = SystemClock.uptimeMillis();
- for (int i = 0; i < DIM_IMG_SIZE_X; ++i) {
- for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) {
+ for (int i = 0; i < getImageSizeX(); ++i) {
+ for (int j = 0; j < getImageSizeY(); ++j) {
final int val = intValues[pixel++];
- imgData.put((byte) ((val >> 16) & 0xFF));
- imgData.put((byte) ((val >> 8) & 0xFF));
- imgData.put((byte) (val & 0xFF));
+ addPixelValue(val);
}
}
long endTime = SystemClock.uptimeMillis();
@@ -199,9 +189,9 @@ public class ImageClassifier {
/** Prints top-K labels, to be shown in UI as the results. */
private String printTopKLabels() {
- for (int i = 0; i < labelList.size(); ++i) {
+ for (int i = 0; i < getNumLabels(); ++i) {
sortedLabels.add(
- new AbstractMap.SimpleEntry<>(labelList.get(i), (labelProbArray[0][i] & 0xff) / 255.0f));
+ new AbstractMap.SimpleEntry<>(labelList.get(i), getNormalizedProbability(i)));
if (sortedLabels.size() > RESULTS_TO_SHOW) {
sortedLabels.poll();
}
@@ -214,4 +204,89 @@ public class ImageClassifier {
}
return textToShow;
}
+
+ /**
+ * Get the name of the model file stored in Assets.
+ *
+ * @return
+ */
+ protected abstract String getModelPath();
+
+ /**
+ * Get the name of the label file stored in Assets.
+ *
+ * @return
+ */
+ protected abstract String getLabelPath();
+
+ /**
+ * Get the image size along the x axis.
+ *
+ * @return
+ */
+ protected abstract int getImageSizeX();
+
+ /**
+ * Get the image size along the y axis.
+ *
+ * @return
+ */
+ protected abstract int getImageSizeY();
+
+ /**
+ * Get the number of bytes that is used to store a single color channel value.
+ *
+ * @return
+ */
+ protected abstract int getNumBytesPerChannel();
+
+ /**
+ * Add pixelValue to byteBuffer.
+ *
+ * @param pixelValue
+ */
+ protected abstract void addPixelValue(int pixelValue);
+
+ /**
+ * Read the probability value for the specified label This is either the original value as it was
+ * read from the net's output or the updated value after the filter was applied.
+ *
+ * @param labelIndex
+ * @return
+ */
+ protected abstract float getProbability(int labelIndex);
+
+ /**
+ * Set the probability value for the specified label.
+ *
+ * @param labelIndex
+ * @param value
+ */
+ protected abstract void setProbability(int labelIndex, Number value);
+
+ /**
+ * Get the normalized probability value for the specified label. This is the final value as it
+ * will be shown to the user.
+ *
+ * @return
+ */
+ protected abstract float getNormalizedProbability(int labelIndex);
+
+ /**
+ * Run inference using the prepared input in {@link #imgData}. Afterwards, the result will be
+ * provided by getProbability().
+ *
+ * <p>This additional method is necessary, because we don't have a common base for different
+ * primitive data types.
+ */
+ protected abstract void runInference();
+
+ /**
+ * Get the total number of labels.
+ *
+ * @return
+ */
+ protected int getNumLabels() {
+ return labelList.size();
+ }
}
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java
new file mode 100644
index 0000000000..3108422952
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java
@@ -0,0 +1,103 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package com.example.android.tflitecamerademo;
+
+import android.app.Activity;
+import java.io.IOException;
+
+/**
+ * This classifier works with the Inception-v3 slim model. It applies floating point inference
+ * rather than using a quantized model.
+ */
+public class ImageClassifierFloatInception extends ImageClassifier {
+
+ /** The inception net requires additional normalization of the used input. */
+ private static final int IMAGE_MEAN = 128;
+
+ private static final float IMAGE_STD = 128.0f;
+
+ /**
+ * An array to hold inference results, to be feed into Tensorflow Lite as outputs. This isn't part
+ * of the super class, because we need a primitive array here.
+ */
+ private float[][] labelProbArray = null;
+
+ /**
+ * Initializes an {@code ImageClassifier}.
+ *
+ * @param activity
+ */
+ ImageClassifierFloatInception(Activity activity) throws IOException {
+ super(activity);
+ labelProbArray = new float[1][getNumLabels()];
+ }
+
+ @Override
+ protected String getModelPath() {
+ // you can download this file from
+ // https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip
+ return "inceptionv3_slim_2016.tflite";
+ }
+
+ @Override
+ protected String getLabelPath() {
+ return "labels_imagenet_slim.txt";
+ }
+
+ @Override
+ protected int getImageSizeX() {
+ return 299;
+ }
+
+ @Override
+ protected int getImageSizeY() {
+ return 299;
+ }
+
+ @Override
+ protected int getNumBytesPerChannel() {
+ // a 32bit float value requires 4 bytes
+ return 4;
+ }
+
+ @Override
+ protected void addPixelValue(int pixelValue) {
+ imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ }
+
+ @Override
+ protected float getProbability(int labelIndex) {
+ return labelProbArray[0][labelIndex];
+ }
+
+ @Override
+ protected void setProbability(int labelIndex, Number value) {
+ labelProbArray[0][labelIndex] = value.floatValue();
+ }
+
+ @Override
+ protected float getNormalizedProbability(int labelIndex) {
+ // TODO the following value isn't in [0,1] yet, but may be greater. Why?
+ return getProbability(labelIndex);
+ }
+
+ @Override
+ protected void runInference() {
+ tflite.run(imgData, labelProbArray);
+ }
+}
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java
new file mode 100644
index 0000000000..5f341f0f5b
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java
@@ -0,0 +1,94 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package com.example.android.tflitecamerademo;
+
+import android.app.Activity;
+import java.io.IOException;
+
+/** This classifier works with the quantized MobileNet model. */
+public class ImageClassifierQuantizedMobileNet extends ImageClassifier {
+
+ /**
+ * An array to hold inference results, to be feed into Tensorflow Lite as outputs. This isn't part
+ * of the super class, because we need a primitive array here.
+ */
+ private byte[][] labelProbArray = null;
+
+ /**
+ * Initializes an {@code ImageClassifier}.
+ *
+ * @param activity
+ */
+ ImageClassifierQuantizedMobileNet(Activity activity) throws IOException {
+ super(activity);
+ labelProbArray = new byte[1][getNumLabels()];
+ }
+
+ @Override
+ protected String getModelPath() {
+ // you can download this file from
+ // https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip
+ return "mobilenet_quant_v1_224.tflite";
+ }
+
+ @Override
+ protected String getLabelPath() {
+ return "labels_mobilenet_quant_v1_224.txt";
+ }
+
+ @Override
+ protected int getImageSizeX() {
+ return 224;
+ }
+
+ @Override
+ protected int getImageSizeY() {
+ return 224;
+ }
+
+ @Override
+ protected int getNumBytesPerChannel() {
+ // the quantized model uses a single byte only
+ return 1;
+ }
+
+ @Override
+ protected void addPixelValue(int pixelValue) {
+ imgData.put((byte) ((pixelValue >> 16) & 0xFF));
+ imgData.put((byte) ((pixelValue >> 8) & 0xFF));
+ imgData.put((byte) (pixelValue & 0xFF));
+ }
+
+ @Override
+ protected float getProbability(int labelIndex) {
+ return labelProbArray[0][labelIndex];
+ }
+
+ @Override
+ protected void setProbability(int labelIndex, Number value) {
+ labelProbArray[0][labelIndex] = value.byteValue();
+ }
+
+ @Override
+ protected float getNormalizedProbability(int labelIndex) {
+ return (labelProbArray[0][labelIndex] & 0xff) / 255.0f;
+ }
+
+ @Override
+ protected void runInference() {
+ tflite.run(imgData, labelProbArray);
+ }
+}
diff --git a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py
index c3a57ba51b..2b9eee4ef7 100644
--- a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py
+++ b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py
@@ -50,16 +50,12 @@ def pairwise_distance(feature, squared=False):
pairwise_distances: 2-D Tensor of size [number of data, number of data].
"""
pairwise_distances_squared = math_ops.add(
+ math_ops.reduce_sum(math_ops.square(feature), axis=[1], keepdims=True),
math_ops.reduce_sum(
- math_ops.square(feature),
- axis=[1],
- keep_dims=True),
- math_ops.reduce_sum(
- math_ops.square(
- array_ops.transpose(feature)),
+ math_ops.square(array_ops.transpose(feature)),
axis=[0],
- keep_dims=True)) - 2.0 * math_ops.matmul(
- feature, array_ops.transpose(feature))
+ keepdims=True)) - 2.0 * math_ops.matmul(feature,
+ array_ops.transpose(feature))
# Deal with numerical inaccuracies. Set small negatives to zero.
pairwise_distances_squared = math_ops.maximum(pairwise_distances_squared, 0.0)
@@ -132,10 +128,10 @@ def masked_maximum(data, mask, dim=1):
masked_maximums: N-D `Tensor`.
The maximized dimension is of size 1 after the operation.
"""
- axis_minimums = math_ops.reduce_min(data, dim, keep_dims=True)
+ axis_minimums = math_ops.reduce_min(data, dim, keepdims=True)
masked_maximums = math_ops.reduce_max(
- math_ops.multiply(
- data - axis_minimums, mask), dim, keep_dims=True) + axis_minimums
+ math_ops.multiply(data - axis_minimums, mask), dim,
+ keepdims=True) + axis_minimums
return masked_maximums
@@ -151,10 +147,10 @@ def masked_minimum(data, mask, dim=1):
masked_minimums: N-D `Tensor`.
The minimized dimension is of size 1 after the operation.
"""
- axis_maximums = math_ops.reduce_max(data, dim, keep_dims=True)
+ axis_maximums = math_ops.reduce_max(data, dim, keepdims=True)
masked_minimums = math_ops.reduce_min(
- math_ops.multiply(
- data - axis_maximums, mask), dim, keep_dims=True) + axis_maximums
+ math_ops.multiply(data - axis_maximums, mask), dim,
+ keepdims=True) + axis_maximums
return masked_minimums
@@ -202,8 +198,7 @@ def triplet_semihard_loss(labels, embeddings, margin=1.0):
mask_final = array_ops.reshape(
math_ops.greater(
math_ops.reduce_sum(
- math_ops.cast(
- mask, dtype=dtypes.float32), 1, keep_dims=True),
+ math_ops.cast(mask, dtype=dtypes.float32), 1, keepdims=True),
0.0), [batch_size, batch_size])
mask_final = array_ops.transpose(mask_final)
@@ -290,7 +285,7 @@ def npairs_loss(labels, embeddings_anchor, embeddings_positive,
labels_remapped = math_ops.to_float(
math_ops.equal(labels, array_ops.transpose(labels)))
- labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keep_dims=True)
+ labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keepdims=True)
# Add the softmax loss.
xent_loss = nn.softmax_cross_entropy_with_logits(
@@ -395,7 +390,7 @@ def npairs_loss_multilabel(sparse_labels, embeddings_anchor,
multilabel_adjacency_matrix = _build_multilabel_adjacency(sparse_labels)
labels_remapped = math_ops.to_float(multilabel_adjacency_matrix)
- labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keep_dims=True)
+ labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keepdims=True)
# Add the softmax loss.
xent_loss = nn.softmax_cross_entropy_with_logits(
@@ -448,10 +443,10 @@ def lifted_struct_loss(labels, embeddings, margin=1.0):
# Safe maximum: Temporarily shift negative distances
# above zero before taking max.
# this is to take the max only among negatives.
- row_minimums = math_ops.reduce_min(diff, 1, keep_dims=True)
+ row_minimums = math_ops.reduce_min(diff, 1, keepdims=True)
row_negative_maximums = math_ops.reduce_max(
- math_ops.multiply(
- diff - row_minimums, mask), 1, keep_dims=True) + row_minimums
+ math_ops.multiply(diff - row_minimums, mask), 1,
+ keepdims=True) + row_minimums
# Compute the loss.
# Keep track of matrix of maximums where M_ij = max(m_i, m_j)
@@ -467,10 +462,11 @@ def lifted_struct_loss(labels, embeddings, margin=1.0):
array_ops.transpose(max_elements), [-1, 1])
loss_exp_left = array_ops.reshape(
- math_ops.reduce_sum(math_ops.multiply(
- math_ops.exp(
- diff_tiled - max_elements_vect),
- mask_tiled), 1, keep_dims=True), [batch_size, batch_size])
+ math_ops.reduce_sum(
+ math_ops.multiply(
+ math_ops.exp(diff_tiled - max_elements_vect), mask_tiled),
+ 1,
+ keepdims=True), [batch_size, batch_size])
loss_mat = max_elements + math_ops.log(
loss_exp_left + array_ops.transpose(loss_exp_left))
@@ -686,7 +682,7 @@ def _find_loss_augmented_facility_idx(pairwise_distances, labels, chosen_ids,
array_ops.reshape(pairwise_distances_candidate, [1, -1])
], 0),
axis=0,
- keep_dims=True), [num_candidates, -1]),
+ keepdims=True), [num_candidates, -1]),
axis=1)
nmi_scores = array_ops.zeros([num_candidates])
diff --git a/tensorflow/contrib/makefile/README.md b/tensorflow/contrib/makefile/README.md
index 6959ca344f..995230dfa8 100644
--- a/tensorflow/contrib/makefile/README.md
+++ b/tensorflow/contrib/makefile/README.md
@@ -130,6 +130,105 @@ adb shell '/data/local/tmp/benchmark \
For more details, see the [benchmark documentation](../../tools/benchmark).
+## CUDA support for Tegra devices running Android (Nvidia Shield TV, etc)
+
+With the release of TF 1.6 and JetPack for Android 3.2 (currently pending), you can now build a version of TensorFlow for compatible devices according to the following instructions which will receive the full benefits of GPU acceleration.
+
+#### Environment setup:
+
+First, download and install JetPack for Android version 3.2 or greater from [Nvidia](https://developers.nvidia.com). Note that as of the TF 1.6 release the JetPack for Android 3.2 release is still pending, and regular JetPack for L4T will not work.
+
+```bash
+git clone https://github.com/tensorflow/tensorflow.git
+cd tensorflow
+JETPACK=$HOME/JetPack_Android_3.2
+TEGRA_LIBS="$JETPACK/cuDNN/aarch64/cuda/lib64/libcudnn.so $JETPACK/cuda-9.0/extras/CUPTI/lib64/libcupti.so $JETPACK/cuda/targets/aarch64-linux-androideabi/lib64/libcufft.so"
+```
+
+#### Building all CUDA-enabled native binaries:
+This will build CUDA-enabled versions of libtensorflow_inference.so and the benchmark binary. (libtensorflow_demo.so will also be built incidentally, but it does not support CUDA)
+
+```bash
+NDK_ROOT=$JETPACK/android-ndk-r13b
+CC_PREFIX=ccache tensorflow/contrib/makefile/build_all_android.sh -s tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in -t "libtensorflow_inference.so libtensorflow_demo.so all" -a tegra
+```
+(add -T on subsequent builds to skip protobuf downloading/building)
+
+
+#### Testing the CUDA-enabled benchmark via adb:
+Build binaries first as above, then run:
+
+```bash
+adb shell mkdir -p /data/local/tmp/lib64
+adb push $TEGRA_LIBS /data/local/tmp/lib64
+adb push tensorflow/contrib/makefile/gen/bin/android_arm64-v8a/benchmark /data/local/tmp
+wget https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk
+unzip tensorflow_demo.apk -d /tmp/tensorflow_demo
+adb push /tmp/tensorflow_demo/assets/*.pb /data/local/tmp
+adb shell "LD_LIBRARY_PATH=/data/local/tmp/lib64 /data/local/tmp/benchmark --graph=/data/local/tmp/tensorflow_inception_graph.pb"
+```
+
+#### Building the CUDA-enabled TensorFlow AAR with Bazel:
+Build the native binaries first as above. Then, build the aar and package the native libs by executing the following:
+```bash
+mkdir -p /tmp/tf/jni/arm64-v8a
+cp tensorflow/contrib/makefile/gen/lib/android_tegra/libtensorflow_*.so /tmp/tf/jni/arm64-v8a/
+cp $TEGRA_LIBS /tmp/tf/jni/arm64-v8a
+bazel build //tensorflow/contrib/android:android_tensorflow_inference_java.aar
+cp bazel-bin/tensorflow/contrib/android/android_tensorflow_inference_java.aar /tmp/tf/tensorflow.aar
+cd /tmp/tf
+chmod +w tensorflow.aar
+zip -ur tensorflow.aar $(find jni -name *.so)
+```
+
+#### Building the CUDA-enabled TensorFlow Android demo with Bazel:
+Build binaries first as above, then edit tensorflow/examples/android/BUILD and replace:
+```
+ srcs = [
+ ":libtensorflow_demo.so",
+ "//tensorflow/contrib/android:libtensorflow_inference.so",
+ ],
+```
+with:
+```
+srcs = glob(["libs/arm64-v8a/*.so"]),
+```
+
+Then run:
+```bash
+# Create dir for native libs
+mkdir -p tensorflow/examples/android/libs/arm64-v8a
+
+# Copy JetPack libs
+cp $TEGRA_LIBS tensorflow/examples/android/libs/arm64-v8a
+
+# Copy native TensorFlow libraries
+cp tensorflow/contrib/makefile/gen/lib/android_arm64-v8a/libtensorflow_*.so tensorflow/examples/android/libs/arm64-v8a/
+
+# Build APK
+bazel build -c opt --fat_apk_cpu=arm64-v8a tensorflow/android:tensorflow_demo
+
+# Install
+adb install -r -f bazel-bin/tensorflow/examples/android/tensorflow_demo.apk
+```
+
+#### Building the CUDA-enabled Android demo with gradle/Android Studio:
+
+Add tensorflow/examples/android as an Android project in Android Studio as normal.
+
+Edit build.gradle and:
+* set nativeBuildSystem = 'makefile'
+* set cpuType = 'arm64-v8a'
+* in "buildNativeMake", replace cpuType with 'tegra' (optional speedups like -T and ccache also work)
+* set the environment "NDK_ROOT" var to $JETPACK/android-ndk-r13b
+
+Click "build apk" to build.
+
+Install:
+```bash
+adb install -r -f tensorflow/examples/android/gradleBuild/outputs/apk/debug/android-debug.apk
+```
+
## iOS
_Note: To use this library in an iOS application, see related instructions in
diff --git a/tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh b/tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh
index 203ff4f890..421ddd210f 100755
--- a/tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh
+++ b/tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh
@@ -36,7 +36,7 @@ while getopts "bc:Eps" opt_name; do
b) BUILD_ONLY="true";;
c) TEST_COUNT="${OPTARG}";;
E) ENABLE_EXPERIMENTAL_HEXNN_OPS="true";;
- p) USE_PREBUILT_HEXAOGON_BINARIES="true";;
+ p) USE_PREBUILT_HEXAGON_BINARIES="true";;
s) SKIP_DOWNLOAD_IF_EXIST="true";;
*) usage;;
esac
@@ -49,7 +49,7 @@ if [[ -z "${NDK_ROOT}" ]]; then
exit 1
fi
-if [[ "${USE_PREBUILT_HEXAOGON_BINARIES}" != "true" &&
+if [[ "${USE_PREBUILT_HEXAGON_BINARIES}" != "true" &&
-z "${QUALCOMM_SDK}" ]]; then
echo "QUALCOMM_SDK is empty" 1>&2
usage
@@ -84,7 +84,7 @@ rm -rf "${GEN_DIR}"
mkdir -p "${GEN_LIBS_DIR}"
mkdir -p "${GEN_DOWNLOAD_DIR}"
-if [[ "${USE_PREBUILT_HEXAOGON_BINARIES}" == "true" ]]; then
+if [[ "${USE_PREBUILT_HEXAGON_BINARIES}" == "true" ]]; then
echo "Download prebuilt hexagon binaries"
if [[ "${BUILD_ONLY}" != "true" ]]; then
CONTROLLER_PUSH_DEST="/data/local/tmp"
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
index f700717394..4eb4fbcd92 100644
--- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py
+++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
@@ -572,9 +572,8 @@ class LSTMBlockWrapper(base_layer.Layer):
def _gather_states(self, data, indices, batch_size):
"""Produce `out`, s.t. out(i, j) = data(indices(i), i, j)."""
- mod_indices = indices * batch_size + math_ops.range(batch_size)
- return array_ops.gather(
- array_ops.reshape(data, [-1, self.num_units]), mod_indices)
+ return array_ops.gather_nd(
+ data, array_ops.stack([indices, math_ops.range(batch_size)], axis=1))
class LSTMBlockFusedCell(LSTMBlockWrapper):
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index dce71c393a..a6c2d9cdbb 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -424,8 +424,9 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
"W_O_diag", shape=[self._num_units], dtype=dtype)
# initialize the first freq state to be zero
- m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]), self._num_units],
- dtype)
+ m_prev_freq = array_ops.zeros(
+ [inputs.shape[0].value or inputs.get_shape()[0], self._num_units],
+ dtype)
for fq in range(len(freq_inputs)):
c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units],
[-1, self._num_units])
diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
index 64973ccccd..dfa12e873a 100644
--- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
+++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
@@ -80,12 +80,12 @@ class GatherTreeOp : public OpKernel {
max_sequence_lengths.shape().DebugString()));
Tensor* beams;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams));
- typename TTypes<T, 3>::ConstTensor step_ids_t = step_ids.tensor<T, 3>();
- typename TTypes<T, 3>::ConstTensor parent_ids_t = parent_ids.tensor<T, 3>();
+ typename TTypes<T, 3>::ConstTensor step_ids_t(step_ids.tensor<T, 3>());
+ typename TTypes<T, 3>::ConstTensor parent_ids_t(parent_ids.tensor<T, 3>());
typename TTypes<int32>::ConstVec max_seq_lens_t =
max_sequence_lengths.vec<int32>();
- typename TTypes<T>::ConstScalar end_token_t = end_token.scalar<T>();
- typename TTypes<T, 3>::Tensor beams_t = beams->tensor<T, 3>();
+ typename TTypes<T>::ConstScalar end_token_t(end_token.scalar<T>());
+ typename TTypes<T, 3>::Tensor beams_t(beams->tensor<T, 3>());
const T end_token_value = end_token_t();
functor::GatherTree<Device, T>()(ctx, device, step_ids_t, parent_ids_t,
max_seq_lens_t, end_token_value, beams_t);
diff --git a/tensorflow/contrib/signal/python/ops/spectral_ops.py b/tensorflow/contrib/signal/python/ops/spectral_ops.py
index bca2e01d7b..a8b5deff6c 100644
--- a/tensorflow/contrib/signal/python/ops/spectral_ops.py
+++ b/tensorflow/contrib/signal/python/ops/spectral_ops.py
@@ -144,7 +144,7 @@ def inverse_stft_window_fn(frame_step,
overlaps = -(-frame_length // frame_step) # Ceiling division.
denom = array_ops.pad(denom, [(0, overlaps * frame_step - frame_length)])
denom = array_ops.reshape(denom, [overlaps, frame_step])
- denom = math_ops.reduce_sum(denom, 0, keep_dims=True)
+ denom = math_ops.reduce_sum(denom, 0, keepdims=True)
denom = array_ops.tile(denom, [overlaps, 1])
denom = array_ops.reshape(denom, [overlaps * frame_step])
diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py
index 7ab6805fac..c24bd04851 100644
--- a/tensorflow/contrib/slim/python/slim/evaluation_test.py
+++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py
@@ -29,6 +29,7 @@ from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.metrics.python.ops import metric_ops
from tensorflow.contrib.slim.python.slim import evaluation
from tensorflow.contrib.training.python.training import evaluation as evaluation_lib
+from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.wrappers import hooks
from tensorflow.python.framework import constant_op
@@ -235,7 +236,7 @@ class SingleEvaluationTest(test.TestCase):
def _prepareCheckpoint(self, checkpoint_path):
init_op = control_flow_ops.group(variables.global_variables_initializer(),
variables.local_variables_initializer())
- saver = saver_lib.Saver()
+ saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1)
with self.test_session() as sess:
sess.run(init_op)
saver.save(sess, checkpoint_path)
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
index 58a7fa095d..1e4cc3f095 100644
--- a/tensorflow/contrib/tensor_forest/BUILD
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -497,6 +497,7 @@ py_library(
":tensor_forest_v4_ops_py",
"//tensorflow/contrib/decision_trees/proto:generic_tree_model_py",
"//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/contrib/tensor_forest/proto:fertile_stats_proto_py",
"//tensorflow/contrib/tensor_forest/proto:tensor_forest_params_proto_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 28f571e1f0..65a0e903a7 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -1,5 +1,6 @@
# Description:
-# Wrap NVIDIA TensorRT (http://developer.nvidia.com/tensorrt) with tensorflow.
+# Wrap NVIDIA TensorRT (http://developer.nvidia.com/tensorrt) with tensorflow
+# and provide TensorRT operators and converter package.
# APIs are meant to change over time.
package(default_visibility = ["//tensorflow:__subpackages__"])
@@ -8,7 +9,19 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+ "tf_copts",
+ "tf_cuda_library",
+ "tf_custom_op_library",
+ "tf_custom_op_library_additional_deps",
+ "tf_gen_op_libs",
+ "tf_gen_op_wrapper_py",
+)
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
+load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
load(
"@local_config_tensorrt//:build_defs.bzl",
"if_tensorrt",
@@ -32,6 +45,195 @@ tf_cuda_cc_test(
]),
)
+tf_custom_op_library(
+ name = "python/ops/_trt_engine_op.so",
+ srcs = ["ops/trt_engine_op.cc"],
+ deps = [
+ ":trt_engine_op_kernel",
+ ":trt_shape_function",
+ "//tensorflow/core:lib_proto_parsing",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]),
+)
+
+tf_cuda_library(
+ name = "trt_shape_function",
+ srcs = ["shape_fn/trt_shfn.cc"],
+ hdrs = ["shape_fn/trt_shfn.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":trt_logging",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]) + tf_custom_op_library_additional_deps(),
+)
+
+cc_library(
+ name = "trt_engine_op_kernel",
+ srcs = ["kernels/trt_engine_op.cc"],
+ hdrs = ["kernels/trt_engine_op.h"],
+ copts = tf_copts(),
+ deps = [
+ ":trt_logging",
+ "//tensorflow/core:gpu_headers_lib",
+ "//tensorflow/core:lib_proto_parsing",
+ "//tensorflow/core:stream_executor_headers_lib",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]) + tf_custom_op_library_additional_deps(),
+ # TODO(laigd)
+ alwayslink = 1, # buildozer: disable=alwayslink-with-hdrs
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["trt_engine_op"],
+ deps = if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]),
+)
+
+tf_cuda_library(
+ name = "trt_logging",
+ srcs = ["log/trt_logger.cc"],
+ hdrs = ["log/trt_logger.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:lib_proto_parsing",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]),
+)
+
+tf_gen_op_wrapper_py(
+ name = "trt_engine_op",
+ deps = [
+ ":trt_engine_op_op_lib",
+ ":trt_logging",
+ ":trt_shape_function",
+ ],
+)
+
+tf_custom_op_py_library(
+ name = "trt_engine_op_loader",
+ srcs = ["python/ops/trt_engine_op.py"],
+ dso = [
+ ":python/ops/_trt_engine_op.so",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]),
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:resources",
+ ],
+)
+
+py_library(
+ name = "init_py",
+ srcs = [
+ "__init__.py",
+ "python/__init__.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":trt_convert_py",
+ ":trt_ops_py",
+ ],
+)
+
+py_library(
+ name = "trt_ops_py",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":trt_engine_op",
+ ":trt_engine_op_loader",
+ ],
+)
+
+py_library(
+ name = "trt_convert_py",
+ srcs = ["python/trt_convert.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":wrap_conversion",
+ ],
+)
+
+tf_py_wrap_cc(
+ name = "wrap_conversion",
+ srcs = ["trt_conversion.i"],
+ copts = tf_copts(),
+ deps = [
+ ":trt_conversion",
+ "//tensorflow/core:framework_lite",
+ "//util/python:python_headers",
+ ],
+)
+
+# Library for the node-level conversion portion of TensorRT operation creation
+tf_cuda_library(
+ name = "trt_conversion",
+ srcs = [
+ "convert/convert_graph.cc",
+ "convert/convert_nodes.cc",
+ ],
+ hdrs = [
+ "convert/convert_graph.h",
+ "convert/convert_nodes.h",
+ ],
+ deps = [
+ ":segment",
+ ":trt_logging",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_lite",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:devices",
+ "//tensorflow/core/grappler/clusters:virtual_cluster",
+ "//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/optimizers:constant_folding",
+ "//tensorflow/core/grappler/optimizers:layout_optimizer",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]) + tf_custom_op_library_additional_deps(),
+)
+
+# Library for the segmenting portion of TensorRT operation creation
+cc_library(
+ name = "segment",
+ srcs = ["segment/segment.cc"],
+ hdrs = [
+ "segment/segment.h",
+ "segment/union_find.h",
+ ],
+ linkstatic = 1,
+ deps = [
+ "//tensorflow/core:graph",
+ "//tensorflow/core:lib_proto_parsing",
+ "//tensorflow/core:protos_all_cc",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+)
+
+tf_cc_test(
+ name = "segment_test",
+ size = "small",
+ srcs = ["segment/segment_test.cc"],
+ deps = [
+ ":segment",
+ "//tensorflow/c:c_api",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/tensorrt/README.md b/tensorflow/contrib/tensorrt/README.md
new file mode 100644
index 0000000000..dfcce0fd00
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/README.md
@@ -0,0 +1,40 @@
+Using TensorRT in TensorFlow
+============================
+
+This module provides necessary bindings and introduces TRT_engine_op
+operator that wraps a subgraph in TensorRT.
+
+Compilation
+-----------
+
+In order to compile the module, you need to have a local TensorRT
+installation (libnvinfer.so and respective include files). During the
+configuration step, TensorRT should be enabled and installation path
+should be set. If installed through package managers (deb,rpm),
+configure script should find the necessary components from the system
+automatically. If installed from tar packages, user has to set path to
+location where the library is installed during configuration.
+
+
+```
+bazel build --config=cuda --config=opt //tensorflow/tools/pip_package:build_pip_package
+bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/
+```
+
+After the installation of tensorflow package, TensorRT transformation
+will be available. An example use is shown below.
+
+```python
+import tensorflow as tf
+import tensorflow.contrib.tensorrt as trt
+#... create and train or load model
+gdef = sess.graph.as_graph_def()
+trt_gdef = trt.create_inference_graph(
+ gdef, #original graph_def
+ ["output"], #name of output node(s)
+ max_batch_size, #maximum batch size to run the inference
+ max_workspace_size_bytes) # max memory for TensorRT to use
+tf.reset_default_graph()
+tf.import_graph_def(graph_def=trt_gdef)
+#...... run inference
+```
diff --git a/tensorflow/contrib/tensorrt/__init__.py b/tensorflow/contrib/tensorrt/__init__.py
new file mode 100644
index 0000000000..fd551d70b4
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Exposes the python wrapper for TensorRT graph transforms."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.tensorrt.python import *
+# pylint: enable=unused-import,wildcard-import
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
new file mode 100644
index 0000000000..970f810473
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -0,0 +1,273 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
+
+#include <map>
+#include <set>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
+#include "tensorflow/contrib/tensorrt/segment/segment.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/devices.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/constant_folding.h"
+#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/device_properties.pb.h" // NOLINT
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "tensorrt/include/NvInfer.h"
+
+namespace tensorflow {
+namespace tensorrt {
+namespace convert {
+namespace {
+
+static bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) {
+ // LINT.IfChange
+ // TODO(jie): Segmentation shouldn't associated with op name.
+ // Split it into a registration for each kernel.
+ static const std::set<string> candidate_ops = {
+ "Identity", "Const", "Conv2D", "MaxPool", "BiasAdd", "Relu",
+ "Add", "Mul", "Sub", "Rsqrt", "Pad" // "Placeholder" ,"Mean"
+ };
+ // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h)
+ return candidate_ops.count(node_def.op());
+}
+
+void GetSubGraphIncomingEdges(const tensorflow::Graph& graph,
+ const std::set<int>& subgraph_node_ids,
+ tensorflow::EdgeSet* incoming_edges) {
+ for (int node_id : subgraph_node_ids) {
+ const tensorflow::Node* node = graph.FindNodeId(node_id);
+ for (const tensorflow::Edge* edge : node->in_edges()) {
+ if (!subgraph_node_ids.count(edge->src()->id()) &&
+ !edge->src()->IsSource()) {
+ incoming_edges->insert(edge);
+ }
+ }
+ }
+}
+
+void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph,
+ const std::set<int>& subgraph_node_ids,
+ tensorflow::EdgeSet* outgoing_edges) {
+ for (int node_id : subgraph_node_ids) {
+ const tensorflow::Node* node = graph.FindNodeId(node_id);
+ for (const tensorflow::Edge* edge : node->out_edges()) {
+ if (!subgraph_node_ids.count(edge->dst()->id()) &&
+ !edge->dst()->IsSink()) {
+ outgoing_edges->insert(edge);
+ }
+ }
+ }
+}
+
+std::pair<string, int> ParseTensorName(string name, int default_idx = 0) {
+ int idx = default_idx;
+ size_t sep = name.find_last_of(':');
+ if (sep != string::npos) {
+ name = name.substr(0, sep);
+ idx = std::stoi(name.substr(sep + 1));
+ }
+ return std::make_pair(name, idx);
+}
+
+std::unordered_map<string, std::vector<int>> BuildTensorNameMap(
+ const std::vector<string>& tensor_names) {
+ std::unordered_map<string, std::vector<int>> result;
+ for (string const& tensor_name : tensor_names) {
+ string node_name;
+ int index;
+ std::tie(node_name, index) = ParseTensorName(tensor_name);
+ result[node_name].push_back(index);
+ }
+ return result;
+}
+
+tensorflow::Status ConvertSubGraphToTensorRT(
+ const std::vector<string>& output_names,
+ const std::set<int>& subgraph_node_ids,
+ size_t max_batch_size, // Max batch size that engine will be created for
+ // Max amount of memory that engine will be allowed to consume, in bytes
+ size_t max_workspace_size_bytes,
+ const tensorflow::grappler::GraphProperties& graph_properties,
+ tensorflow::Graph* graph) {
+ tensorflow::EdgeSet subgraph_incoming_edges;
+ GetSubGraphIncomingEdges(*graph, subgraph_node_ids, &subgraph_incoming_edges);
+
+ std::vector<std::pair<int, int>> subgraph_inputs;
+
+ // Collect inputs by looking for incoming edges
+ for (const tensorflow::Edge* edge : subgraph_incoming_edges) {
+ subgraph_inputs.push_back({edge->src()->id(), edge->src_output()});
+ }
+ std::set<std::pair<int, int>> subgraph_outputs_set;
+ // Collect outputs referenced from output_names
+ auto output_name_to_index_map = BuildTensorNameMap(output_names);
+ for (int node_id : subgraph_node_ids) {
+ tensorflow::Node* node = graph->FindNodeId(node_id);
+ if (output_name_to_index_map.count(node->name())) {
+ for (int index : output_name_to_index_map.at(node->name())) {
+ subgraph_outputs_set.insert({node_id, index});
+ }
+ }
+ }
+ // Collect outputs referenced from outgoing edges
+ tensorflow::EdgeSet subgraph_outgoing_edges;
+ GetSubGraphOutgoingEdges(*graph, subgraph_node_ids, &subgraph_outgoing_edges);
+ for (const tensorflow::Edge* edge : subgraph_outgoing_edges) {
+ subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()});
+ }
+ // Impose an ordering on the outputs
+ std::vector<std::pair<int, int>> subgraph_outputs(
+ subgraph_outputs_set.begin(), subgraph_outputs_set.end());
+ // Build TensorRT node and add it to the graph
+ tensorflow::NodeDef trt_node_def;
+ TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(
+ *graph, subgraph_node_ids, subgraph_inputs, subgraph_outputs,
+ max_batch_size, max_workspace_size_bytes, graph_properties,
+ &trt_node_def));
+ tensorflow::Status status;
+ tensorflow::Node* trt_node = graph->AddNode(trt_node_def, &status);
+ TF_RETURN_IF_ERROR(status);
+
+ // Re-map outgoing edges to use the new TRT node instead of the orig subgraph
+ std::map<std::pair<int, int>, int> subgraph_edge_to_output_map;
+ for (size_t i = 0; i < subgraph_outputs.size(); ++i) {
+ subgraph_edge_to_output_map.insert({subgraph_outputs.at(i), i});
+ }
+ TF_RETURN_IF_ERROR(status);
+ for (const tensorflow::Edge* edge : subgraph_outgoing_edges) {
+ std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
+ int new_src_output = subgraph_edge_to_output_map.at(old_src);
+ TF_RETURN_IF_ERROR(graph->UpdateEdge(trt_node, new_src_output, edge->dst(),
+ edge->dst_input()));
+ }
+ // Remove the original subgraph
+ for (int node_id : subgraph_node_ids) {
+ tensorflow::Node* node = graph->FindNodeId(node_id);
+ // Don't remove the input placeholders
+ if (node->type_string() == "Placeholder") {
+ continue;
+ }
+ graph->RemoveNode(node);
+ }
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status BuildNodeMap(
+ const tensorflow::Graph& graph,
+ std::unordered_map<string, tensorflow::Node*>* node_map) {
+ for (auto* node : graph.op_nodes()) {
+ if (!node_map->insert({node->name(), node}).second) {
+ return tensorflow::errors::AlreadyExists(
+ "Node name is not unique in graph: " + node->name());
+ }
+ }
+ return tensorflow::Status::OK();
+}
+
+} // namespace
+
+tensorflow::Status ConvertGraphDefToTensorRT(
+ const tensorflow::GraphDef& graph_def,
+ const std::vector<string>& output_names, size_t max_batch_size,
+ size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def) {
+ // Optimization pass
+ tensorflow::grappler::GrapplerItem item;
+ item.fetch = output_names;
+ tensorflow::GraphDef gdef;
+
+ // Layout optimization
+ item.graph = graph_def;
+ tensorflow::grappler::LayoutOptimizer optimizer;
+ tensorflow::grappler::Cluster* cluster;
+
+ // Virtual cluster
+ tensorflow::DeviceProperties device_properties;
+ device_properties.set_type("GPU");
+ device_properties.mutable_environment()->insert({"architecture", "6"});
+ cluster =
+ new tensorflow::grappler::VirtualCluster({{"/GPU:0", device_properties}});
+
+ TF_RETURN_IF_ERROR(optimizer.Optimize(cluster, item, &gdef));
+
+ // Constant folding
+ item.graph = gdef;
+ tensorflow::grappler::ConstantFolding fold(nullptr);
+ TF_RETURN_IF_ERROR(fold.Optimize(nullptr, item, &gdef));
+
+ // AJ refactoring shape inference through grappler/GraphProperties.
+ tensorflow::grappler::GraphProperties static_graph_properties(item);
+ TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(false));
+
+ // Build full graph
+ tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
+ gdef.library());
+ tensorflow::Graph graph(flib);
+ TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
+ tensorflow::GraphConstructorOptions(), gdef, &graph));
+
+ // Segment the graph into subgraphs that can be converted to TensorRT
+ tensorflow::tensorrt::segment::SegmentOptions segment_options;
+
+ // TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
+ for (auto node : output_names) {
+ segment_options.exclude_node_list.insert(node);
+ }
+
+ // TODO(sami): this should be passed as a knob!!!!
+ segment_options.minimum_segment_size = 2;
+ tensorflow::tensorrt::segment::SegmentNodesVector segments;
+ TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
+ gdef, IsTensorRTCandidate, segment_options, &segments));
+ if (segments.size() > 1) {
+ VLOG(0) << "MULTIPLE tensorrt candidate conversion: " << segments.size();
+ }
+ std::unordered_map<string, tensorflow::Node*> node_map;
+ TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
+ for (const std::set<string>& subgraph_node_names : segments) {
+ std::set<int> subgraph_node_ids;
+ for (const string& node_name : subgraph_node_names) {
+ subgraph_node_ids.insert(node_map.at(node_name)->id());
+ }
+ TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRT(
+ output_names, subgraph_node_ids, max_batch_size,
+ max_workspace_size_bytes, static_graph_properties, &graph));
+ }
+ graph.ToGraphDef(new_graph_def);
+ return tensorflow::Status::OK();
+}
+
+} // namespace convert
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h
new file mode 100644
index 0000000000..154ad3f2e8
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
+
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
+namespace tensorflow {
+namespace tensorrt {
+namespace convert {
+
+// max_batch_size: maximum batch size which can be used for inference for
+// optimization targets inference run with max batch size.
+// max_workspace_size_bytes: The upper bound of memory allowence for
+// engine building.
+tensorflow::Status ConvertGraphDefToTensorRT(
+ const tensorflow::GraphDef& graph_def,
+ const std::vector<string>& output_names, size_t max_batch_size,
+ size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def);
+
+} // namespace convert
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_GRAPH_H_
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
new file mode 100644
index 0000000000..4003ba056d
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -0,0 +1,1601 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
+
+#include <algorithm>
+#include <list>
+#include <map>
+#include <memory>
+#include <set>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h" // NOLINT
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/tensor_coding.h"
+#include "tensorflow/core/platform/types.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+#include "tensorrt/include/NvInfer.h"
+
+// Check if the types are equal. Cast to int first so that failure log message
+// would work!
+#define CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2)
+
+namespace tensorflow {
+namespace tensorrt {
+namespace convert {
+
+namespace {
+
+inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype,
+ nvinfer1::DataType* trt_dtype) {
+ switch (tf_dtype) {
+ case tensorflow::DataType::DT_FLOAT:
+ *trt_dtype = nvinfer1::DataType::kFLOAT;
+ break;
+ case tensorflow::DataType::DT_INT8:
+ *trt_dtype = nvinfer1::DataType::kINT8;
+ break;
+ case tensorflow::DataType::DT_HALF:
+ *trt_dtype = nvinfer1::DataType::kHALF;
+ break;
+ default:
+ return tensorflow::errors::InvalidArgument("Unsupported data type");
+ }
+ return tensorflow::Status::OK();
+}
+
+inline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) {
+ nvinfer1::Dims dims;
+ dims.nbDims = tensor.dims();
+ for (int i = 0; i < dims.nbDims; i++) {
+ dims.d[i] = tensor.dim_size(i);
+ }
+ return dims;
+}
+
+inline int64_t GetShapeSize(nvinfer1::Dims shape) {
+ // Returns total number of elements in shape
+ int64_t count = 1;
+ for (int d = 0; d < shape.nbDims; ++d) {
+ count *= shape.d[d];
+ }
+ return count;
+}
+
+static std::vector<std::pair<int, int>> CreateSamePadding(
+ const nvinfer1::DimsHW& stride, const nvinfer1::DimsHW& kernel,
+ const std::vector<int64_t>& input_dims) {
+ std::vector<std::pair<int, int>> padding(input_dims.size());
+ CHECK_EQ((size_t)stride.nbDims, input_dims.size()); // TODO(jie): N+C? NC+?
+
+ for (size_t i = 0; i < input_dims.size(); ++i) {
+ // Formula to calculate the padding
+ int p = ((input_dims[i] - 1) / stride.d[i]) * stride.d[i] + kernel.d[i] -
+ input_dims[i];
+ p = (p > 0) ? p : 0;
+
+ // Right precedence padding, like in TensorFlow
+ int left = p / 2;
+ int right = p - left;
+
+ VLOG(2) << "PADDING_" << i << " pre: " << left << ", post: " << right
+ << "paras: " << input_dims[i] << ", " << stride.d[i] << ", "
+ << "kernel: " << kernel.d[i];
+ padding[i] = {left, right};
+ }
+ return padding;
+}
+
+class TRT_ShapedWeights {
+ public:
+ TRT_ShapedWeights(tensorflow::DataType type, const void* values,
+ nvinfer1::Dims shape)
+ : shape_(shape), type_(type), values_(values), empty_weight_flag_(false) {
+ // Note: this->shape.type[] is not used
+ }
+
+ explicit TRT_ShapedWeights(tensorflow::DataType type)
+ : shape_(), type_(type), values_(nullptr), empty_weight_flag_(true) {}
+
+ TRT_ShapedWeights(const TRT_ShapedWeights& rhs)
+ : shape_(rhs.shape_),
+ type_(rhs.type_),
+ values_(rhs.values_),
+ empty_weight_flag_(rhs.empty_weight_flag_) {}
+
+ int64_t count() const {
+ int64_t c = 1;
+ for (int i = 0; i < shape_.nbDims; i++) c *= shape_.d[i];
+ return c;
+ }
+
+ nvinfer1::Weights GetWeightsForTRT() const {
+ nvinfer1::DataType trt_type(nvinfer1::DataType::kFLOAT);
+ TF_CHECK_OK(ConvertDType(type_, &trt_type));
+ if (empty_weight_flag_) return nvinfer1::Weights{trt_type, nullptr, 0};
+
+ // Note: this->shape.type[] is not used
+ return nvinfer1::Weights{trt_type, GetValues(), GetShapeSize(shape_)};
+ }
+
+ const void* GetValues() const { return values_; }
+
+ void SetValues(const void* values) { values_ = values; }
+
+ size_t size_bytes() const {
+ int type_size = tensorflow::DataTypeSize(this->type_);
+ return this->count() * type_size;
+ }
+
+ // Default converter
+ operator nvinfer1::Weights() const { return GetWeightsForTRT(); }
+
+ nvinfer1::Dims shape_;
+ tensorflow::DataType type_;
+
+ private:
+ const void* values_;
+ bool empty_weight_flag_;
+};
+
+class TRT_TensorOrWeights {
+ public:
+ explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor)
+ : tensor_(tensor), weights_(DT_FLOAT), variant_(TRT_NODE_TENSOR) {}
+ explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights)
+ : tensor_(nullptr), weights_(weights), variant_(TRT_NODE_WEIGHTS) {}
+ TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs)
+ : tensor_(rhs.tensor_), weights_(rhs.weights_), variant_(rhs.variant_) {}
+ ~TRT_TensorOrWeights() {}
+
+ bool is_tensor() const { return variant_ == TRT_NODE_TENSOR; }
+ bool is_weights() const { return variant_ == TRT_NODE_WEIGHTS; }
+
+ nvinfer1::ITensor* tensor() {
+ CHECK_EQ(is_tensor(), true);
+ return tensor_;
+ }
+ const nvinfer1::ITensor* tensor() const {
+ CHECK_EQ(is_tensor(), true);
+ return tensor_;
+ }
+ TRT_ShapedWeights& weights() {
+ CHECK_EQ(is_weights(), true);
+ return weights_;
+ }
+ const TRT_ShapedWeights& weights() const {
+ CHECK_EQ(is_weights(), true);
+ return weights_;
+ }
+ nvinfer1::Dims shape() const {
+ if (is_tensor()) {
+ return tensor()->getDimensions();
+ } else {
+ return weights().shape_;
+ }
+ }
+
+ private:
+ nvinfer1::ITensor* tensor_;
+ TRT_ShapedWeights weights_;
+ enum { TRT_NODE_TENSOR, TRT_NODE_WEIGHTS } variant_;
+};
+
+class TFAttrs {
+ public:
+ explicit TFAttrs(const tensorflow::NodeDef& tf_node) {
+ for (const auto& attr : tf_node.attr()) {
+ attrs_.insert({attr.first, &attr.second});
+ }
+ }
+ bool count(string key) const { return attrs_.count(key); }
+ tensorflow::AttrValue const* at(string key) const {
+ if (!attrs_.count(key)) {
+ LOG(FATAL) << "Attribute not found: " << key;
+ }
+ return attrs_.at(key);
+ }
+ template <typename T>
+ T get(string key) const;
+ template <typename T>
+ T get(string key, const T& default_value) const {
+ return attrs_.count(key) ? this->get<T>(key) : default_value;
+ }
+
+ private:
+ typedef std::map<string, tensorflow::AttrValue const*> AttrMap;
+ AttrMap attrs_;
+};
+
+template <>
+string TFAttrs::get<string>(string key) const {
+ return this->at(key)->s();
+}
+
+template <>
+std::vector<int> TFAttrs::get<std::vector<int>>(string key) const {
+ auto attr = this->at(key)->list().i();
+ return std::vector<int>(attr.begin(), attr.end());
+}
+
+template <>
+nvinfer1::Dims TFAttrs::get<nvinfer1::Dims>(string key) const {
+ auto values = this->get<std::vector<int>>(key);
+ nvinfer1::Dims dims;
+ dims.nbDims = values.size();
+ std::copy(values.begin(), values.end(), dims.d);
+ // Note: No dimension type information is included
+ return dims;
+}
+
+template <>
+nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(string key) const {
+ nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
+ TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype));
+ return trt_dtype;
+}
+
+template <>
+tensorflow::DataType TFAttrs::get<tensorflow::DataType>(string key) const {
+ return this->at(key)->type();
+}
+
+template <typename T>
+void Reorder4(nvinfer1::DimsNCHW shape, const T* idata,
+ nvinfer1::DimsNCHW istrides, T* odata,
+ nvinfer1::DimsNCHW ostrides) {
+ for (int n = 0; n < shape.n(); ++n) {
+ for (int c = 0; c < shape.c(); ++c) {
+ for (int h = 0; h < shape.h(); ++h) {
+ for (int w = 0; w < shape.w(); ++w) {
+ odata[n * ostrides.n() + c * ostrides.c() + h * ostrides.h() +
+ w * ostrides.w()] = idata[n * istrides.n() + c * istrides.c() +
+ h * istrides.h() + w * istrides.w()];
+ }
+ }
+ }
+ }
+}
+
+void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
+ TRT_ShapedWeights* oweights) {
+ CHECK_EQ(iweights.type_, oweights->type_);
+ CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
+ int r = iweights.shape_.d[0];
+ int s = iweights.shape_.d[1];
+ int c = iweights.shape_.d[2];
+ int k = iweights.shape_.d[3];
+ oweights->shape_.d[0] = k;
+ oweights->shape_.d[1] = c;
+ oweights->shape_.d[2] = r;
+ oweights->shape_.d[3] = s;
+ nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k};
+ nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1};
+ switch (iweights.type_) {
+ case tensorflow::DataType::DT_FLOAT:
+ Reorder4({k, c, r, s}, static_cast<float const*>(iweights.GetValues()),
+ istrides,
+ static_cast<float*>(const_cast<void*>(oweights->GetValues())),
+ ostrides);
+ break;
+ default:
+ LOG(FATAL) << "!!!!!!!!!!!!!!!!!!!!!!!!broke!!!!!!!!!!!!";
+ }
+}
+
+struct InferDeleter {
+ template <typename T>
+ void operator()(T* obj) const {
+ if (obj) {
+ obj->destroy();
+ }
+ }
+};
+
+template <typename T>
+inline std::shared_ptr<T> infer_object(T* obj) {
+ return std::shared_ptr<T>(obj, InferDeleter());
+}
+
+// Logger for GIE info/warning/errors
+class Converter;
+
+using OpConverter =
+ std::function<tensorflow::Status(Converter&, const tensorflow::NodeDef&,
+ std::vector<TRT_TensorOrWeights> const&,
+ std::vector<TRT_TensorOrWeights>*)>;
+
+class Converter {
+ std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;
+ std::unordered_map<string, OpConverter> op_registry_;
+ nvinfer1::INetworkDefinition* trt_network_;
+ std::list<std::vector<uint8_t>> temp_bufs_;
+
+ void register_op_converters();
+
+ std::vector<TRT_TensorOrWeights> get_inputs(
+ const tensorflow::NodeDef& node_def) {
+ std::vector<TRT_TensorOrWeights> inputs;
+ for (const auto& input_name : node_def.input()) {
+ VLOG(2) << "Retrieve input: " << input_name;
+ inputs.push_back(trt_tensors_.at(input_name));
+ }
+ return inputs;
+ }
+
+ public:
+ explicit Converter(nvinfer1::INetworkDefinition* trt_network)
+ : trt_network_(trt_network) {
+ this->register_op_converters();
+ }
+
+ TRT_ShapedWeights get_temp_weights(tensorflow::DataType type,
+ nvinfer1::Dims shape) {
+ TRT_ShapedWeights weights(type, nullptr, shape);
+ // TODO(jie): check weights size_bytes. 0 means type error
+ temp_bufs_.push_back(std::vector<uint8_t>(weights.size_bytes()));
+ weights.SetValues(temp_bufs_.back().data());
+ return weights;
+ }
+
+ TRT_ShapedWeights get_temp_weights_like(const TRT_ShapedWeights& weights) {
+ return this->get_temp_weights(weights.type_, weights.shape_);
+ }
+
+ tensorflow::Status convert_node(const tensorflow::NodeDef& node_def) {
+ std::vector<TRT_TensorOrWeights> inputs = this->get_inputs(node_def);
+ string op = node_def.op();
+ if (!op_registry_.count(op)) {
+ return tensorflow::errors::Unimplemented(
+ "No converter registered for op: " + op);
+ }
+ OpConverter op_converter = op_registry_.at(op);
+ std::vector<TRT_TensorOrWeights> outputs;
+ TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs));
+ for (size_t i = 0; i < outputs.size(); ++i) {
+ TRT_TensorOrWeights output = outputs.at(i);
+ // TODO(jie): tf protobuf seems to be omitting the :0 suffix
+ string output_name = node_def.name();
+ if (i != 0) output_name = output_name + ":" + std::to_string(i);
+ if (output.is_tensor()) {
+ output.tensor()->setName(output_name.c_str());
+ }
+ VLOG(2) << "Write out tensor: " << output_name;
+ if (!trt_tensors_.insert({output_name, output}).second) {
+ return tensorflow::errors::AlreadyExists(
+ "Output tensor already exists for op: " + op);
+ }
+ }
+ return tensorflow::Status::OK();
+ }
+
+ nvinfer1::INetworkDefinition* network() { return trt_network_; }
+
+ TRT_TensorOrWeights get_tensor(string name) {
+ if (!trt_tensors_.count(name)) {
+ return TRT_TensorOrWeights(nullptr);
+ }
+ return trt_tensors_.at(name);
+ }
+
+ bool insert_input_tensor(string name, nvinfer1::ITensor* tensor) {
+ return trt_tensors_.insert({name, TRT_TensorOrWeights(tensor)}).second;
+ }
+
+ nvinfer1::ITensor* TransposeTensor(nvinfer1::ITensor* input_tensor,
+ std::vector<int> order) {
+ auto dims = input_tensor->getDimensions();
+
+ // TODO(jie): change the return to status and properly exit
+ if (order.size() - 1 != size_t(dims.nbDims))
+ LOG(ERROR) << "Dimension does not match, fail gracefully";
+
+ nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor);
+ nvinfer1::Permutation permutation;
+ for (int32_t i = 0; i < dims.nbDims; ++i) {
+ permutation.order[i] = order[i + 1] - 1;
+ }
+ layer->setFirstTranspose(permutation);
+
+ nvinfer1::Dims reshape_dims;
+ reshape_dims.nbDims = dims.nbDims;
+ for (int32_t i = 0; i < reshape_dims.nbDims; ++i) {
+ reshape_dims.d[i] = 0;
+ reshape_dims.type[i] = dims.type[i];
+ }
+ layer->setReshapeDimensions(reshape_dims);
+ return layer->getOutput(0);
+ }
+};
+
+// ****************************************************************************
+// Constant folding functions
+// TODO(jie): once optimizer kicks in, we should have done constant folding
+// there.
+//*****************************************************************************/
+struct LambdaFactory {
+ enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB };
+ OP_CATEGORY op;
+
+ template <typename T>
+ std::function<T(T)> unary() {
+ switch (op) {
+ case OP_CATEGORY::RSQRT: {
+ VLOG(2) << "RSQRT GETS DONE";
+ return [](T t) -> T { return 1.0 / std::sqrt(t); };
+ }
+ case OP_CATEGORY::NEG:
+ return [](T t) -> T { return -t; };
+ default:
+ VLOG(2) << "Not supported op for unary: " << static_cast<int>(op);
+ return nullptr;
+ }
+ }
+
+ template <typename T>
+ std::function<T(T, T)> binary() {
+ switch (op) {
+ case OP_CATEGORY::ADD:
+ return [](T l, T r) -> T { return l + r; };
+ case OP_CATEGORY::SUB:
+ return [](T l, T r) -> T { return l - r; };
+ case OP_CATEGORY::MUL:
+ return [](T l, T r) -> T { return l * r; };
+ default:
+ LOG(WARNING) << "Not supported op for binary: " << static_cast<int>(op);
+ }
+ return [](T l, T r) -> T {
+ LOG(FATAL) << "Unsupported op type ";
+ return l;
+ };
+ }
+
+ template <typename T>
+ std::function<T(T)> broadcast_r(T val) {
+ VLOG(2) << "LAMBDA VAL : " << val;
+ switch (op) {
+ case OP_CATEGORY::ADD:
+ return [val](T l) -> T {
+ VLOG(2) << "LAMBDA VAL : " << val;
+ return l + val;
+ };
+ // Return [val](T l)-> T {return l+val;};
+ case OP_CATEGORY::SUB:
+ return [val](T l) -> T {
+ VLOG(2) << "LAMBDA VAL : " << val;
+ return l - val;
+ };
+ case OP_CATEGORY::MUL:
+ return [val](T l) -> T {
+ VLOG(2) << "LAMBDA VAL : " << val;
+ return l * val;
+ };
+ default:
+ LOG(WARNING) << "Not supported op for binary: " << static_cast<int>(op);
+ }
+ return [val](T l) -> T {
+ LOG(FATAL) << "Unsupported op type ";
+ return l;
+ };
+ }
+
+ template <typename T>
+ std::function<T(T)> broadcast_l(T val) {
+ VLOG(2) << "LAMBDA VAL : " << val;
+ switch (op) {
+ case OP_CATEGORY::ADD:
+ return [val](T l) -> T {
+ VLOG(2) << "LAMBDA VAL : " << val;
+ return val + l;
+ };
+ case OP_CATEGORY::SUB:
+ return [val](T l) -> T {
+ VLOG(2) << "LAMBDA VAL : " << val;
+ return val - l;
+ };
+ case OP_CATEGORY::MUL:
+ return [val](T l) -> T {
+ VLOG(2) << "LAMBDA VAL : " << val;
+ return val * l;
+ };
+ default:
+ LOG(ERROR) << "Not supported op for binary: " << static_cast<int>(op);
+ }
+ return [val](T l) -> T {
+ LOG(FATAL) << "Unsupported op type ";
+ return l;
+ };
+ }
+};
+
+tensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights,
+ TRT_ShapedWeights* oweights,
+ LambdaFactory unary_op) {
+ CHECK_EQ(iweights.type_, oweights->type_);
+ switch (iweights.type_) {
+ case tensorflow::DataType::DT_FLOAT: {
+ auto inp = static_cast<float const*>(iweights.GetValues());
+ auto oup = static_cast<float*>(const_cast<void*>(oweights->GetValues()));
+ std::transform(inp, inp + iweights.count(), oup, unary_op.unary<float>());
+ break;
+ }
+ default:
+ return tensorflow::errors::Unimplemented(
+ "Data type not supported: " +
+ tensorflow::DataTypeString(iweights.type_));
+ }
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l,
+ const TRT_ShapedWeights& iweights_r,
+ TRT_ShapedWeights* oweights,
+ LambdaFactory binary_op) {
+ // Assume iweights_l.type == iweight_r.type
+ CHECK_EQ(iweights_l.type_, oweights->type_);
+ CHECK_EQ(iweights_r.type_, oweights->type_);
+ VLOG(2) << "SANITY CHECK!";
+
+ switch (iweights_l.type_) {
+ case tensorflow::DataType::DT_FLOAT: {
+ auto inp_l = static_cast<const float*>(iweights_l.GetValues());
+ auto inp_r = static_cast<const float*>(iweights_r.GetValues());
+ auto oup = static_cast<float*>(const_cast<void*>(oweights->GetValues()));
+
+ if (iweights_l.count() != iweights_r.count()) {
+ // We only supports broadcast of RankZero
+ if (iweights_l.count() == 1) {
+ VLOG(2) << "I bet it is not working!" << (*inp_l);
+ std::transform(inp_r, inp_r + iweights_r.count(), oup,
+ binary_op.broadcast_l<float>(*inp_l));
+ } else if (iweights_r.count() == 1) {
+ VLOG(2) << "I bet it is not working!" << (*inp_r);
+ std::transform(inp_l, inp_l + iweights_l.count(), oup,
+ binary_op.broadcast_r<float>(*inp_r));
+ } else {
+ return tensorflow::errors::Unimplemented(
+ "Binary op with non-rankZero broadcast not supported");
+ }
+ } else {
+ std::transform(inp_l, inp_l + iweights_l.count(), inp_r, oup,
+ binary_op.binary<float>());
+ }
+ break;
+ }
+ default:
+ return tensorflow::errors::Unimplemented(
+ "Data type not supported: " +
+ tensorflow::DataTypeString(iweights_l.type_));
+ }
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConstantFoldUnary(
+ Converter& ctx, const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ TRT_ShapedWeights weights_input = inputs.at(0).weights();
+
+ // Allocate output weights
+ TRT_ShapedWeights weights_output = ctx.get_temp_weights_like(weights_input);
+
+ // FIXME assume type matches input weights
+ // Get trt type & shape
+ // Maybe this part has to be moved into the block of rsqrt later
+ // Check type consistency
+ CHECK_EQ(weights_input.type_,
+ TFAttrs(node_def).get<tensorflow::DataType>("T"));
+
+ // Maybe I should do a switch
+ LambdaFactory unary_op;
+ if (node_def.op() == "Rsqrt") {
+ // Compute rsqrt
+ unary_op.op = LambdaFactory::OP_CATEGORY::RSQRT;
+ auto ret = UnaryCompute(weights_input, &weights_output, unary_op);
+ // PAss the output
+ if (ret == tensorflow::Status::OK()) {
+ outputs->push_back(TRT_TensorOrWeights(weights_output));
+ }
+ return ret;
+ } else {
+ return tensorflow::errors::Unimplemented("Binary op not supported: " +
+ node_def.op());
+ }
+}
+
+// TODO(jie,ben) broadcast is needed yet not implemented
+// Let's get the simple stuff working first. Maybe we should fall bakc to TF
+// approach for constant folding
+tensorflow::Status ConstantFoldBinary(
+ Converter& ctx, const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ TRT_ShapedWeights weights_input_l = inputs.at(0).weights();
+ TRT_ShapedWeights weights_input_r = inputs.at(1).weights();
+
+ // Check type consistency
+ CHECK_EQ(weights_input_l.type_, weights_input_r.type_);
+
+ if (weights_input_l.shape_.nbDims != weights_input_r.shape_.nbDims)
+ return tensorflow::errors::Unimplemented(
+ "Binary op implicit broadcast not supported: " + node_def.op());
+
+ // TODO(jie): constant fold should really fall back to TF.
+ int nb_dims = weights_input_l.shape_.nbDims;
+ nvinfer1::Dims output_shape;
+ output_shape.nbDims = nb_dims;
+ VLOG(2) << "nb_dims: " << nb_dims
+ << ", the other: " << weights_input_r.shape_.nbDims;
+ for (int i = 0; i < nb_dims; i++) {
+ if (weights_input_l.shape_.d[i] == weights_input_r.shape_.d[i]) {
+ output_shape.d[i] = weights_input_l.shape_.d[i];
+ } else if (weights_input_l.shape_.d[i] == 1 ||
+ weights_input_r.shape_.d[i] == 1) {
+ output_shape.d[i] =
+ std::max(weights_input_l.shape_.d[i], weights_input_r.shape_.d[i]);
+ } else {
+ return tensorflow::errors::Unimplemented(
+ "Binary op with incompatible shape at, " + node_def.op());
+ }
+ VLOG(2) << "left: " << weights_input_l.shape_.d[i]
+ << "right: " << weights_input_r.shape_.d[i]
+ << "output: " << output_shape.d[i];
+ }
+
+ // FIXME assume type matches input weights
+ // Get trt type & shape
+ TFAttrs attrs(node_def);
+ // Maybe this part has to be moved into the block of rsqrt later
+ tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("T");
+
+ // Allocate output weights
+ TRT_ShapedWeights weights_output = ctx.get_temp_weights(dtype, output_shape);
+
+ // Maybe I should do a switch
+ LambdaFactory binary_op;
+ if (node_def.op() == "Sub") {
+ binary_op.op = LambdaFactory::OP_CATEGORY::SUB;
+ } else if (node_def.op() == "Mul") {
+ binary_op.op = LambdaFactory::OP_CATEGORY::MUL;
+ } else if (node_def.op() == "Add") {
+ binary_op.op = LambdaFactory::OP_CATEGORY::ADD;
+ } else {
+ return tensorflow::errors::Unimplemented("Binary op not supported: " +
+ node_def.op());
+ }
+ auto ret = BinaryCompute(weights_input_l, weights_input_r, &weights_output,
+ binary_op);
+
+ // Pass the output
+ if (ret == tensorflow::Status::OK()) {
+ outputs->push_back(TRT_TensorOrWeights(weights_output));
+ }
+
+ return ret;
+}
+
+// TODO(jie): broadcast is needed yet not implemented.
+// Only implemented channel wise for the time being
+tensorflow::Status BinaryTensorOpWeight(
+ Converter& ctx, const tensorflow::NodeDef& node_def,
+ const nvinfer1::ITensor* tensor, TRT_ShapedWeights weights,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ // FIXME assume type matches input weights
+ // Get trt type & shape
+ // Maybe this part has to be moved into the block of rsqrt later
+
+ // Check type consistency
+ auto dtype = TFAttrs(node_def).get<nvinfer1::DataType>("T");
+ CHECK_EQ_TYPE(tensor->getType(), dtype); // Cast to int for error messages
+ nvinfer1::DataType ttype;
+ TF_CHECK_OK(ConvertDType(weights.type_, &ttype));
+ CHECK_EQ_TYPE(ttype, dtype); // Cast to int for error message
+
+ // Check scale mode
+ auto dims_w = weights.shape_;
+ auto dims_t = tensor->getDimensions();
+
+ // Default to channel-wise
+ auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
+
+ if (weights.count() == 1) {
+ VLOG(2) << "UNIFORM";
+ scale_mode = nvinfer1::ScaleMode::kUNIFORM;
+ } else {
+ // No broadcasting on Batch dimension;
+ assert(dims_w.d[0] == 1);
+
+ // Broadcasting on Channel dimension only allowed in kUNIFORM
+ assert(dims_w.d[1] == dims_t.d[0]);
+ assert(dims_w.nbDims == dims_t.nbDims);
+
+ // Default is element;
+ for (int i = 2; i < dims_w.nbDims; i++) {
+ if (dims_w.d[i] != dims_t.d[i - 1]) {
+ scale_mode = nvinfer1::ScaleMode::kCHANNEL;
+ break;
+ }
+ }
+ if (scale_mode == nvinfer1::ScaleMode::kELEMENTWISE) {
+ scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
+ for (int i = 2; i < dims_w.nbDims; i++) {
+ if (dims_w.d[i] != 1)
+ return tensorflow::errors::InvalidArgument(
+ "Weight shape not compatible at, " + node_def.name());
+ }
+ }
+ }
+
+ // Prepare weights
+ TRT_ShapedWeights shift_weights(weights.type_);
+ TRT_ShapedWeights scale_weights(weights.type_);
+ TRT_ShapedWeights power_weights(weights.type_);
+
+ // Maybe I should do a switch
+ if (node_def.op() == "Sub") {
+ TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights);
+ LambdaFactory unary_op;
+ unary_op.op = LambdaFactory::OP_CATEGORY::NEG;
+ TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op));
+ shift_weights = neg_weights;
+ } else if (node_def.op() == "Mul") {
+ scale_weights = weights;
+ } else if (node_def.op() == "Add") {
+ shift_weights = weights;
+ } else {
+ return tensorflow::errors::Unimplemented("Binary op not supported: " +
+ node_def.op());
+ }
+
+ nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
+ *const_cast<nvinfer1::ITensor*>(tensor), scale_mode, shift_weights,
+ scale_weights, power_weights);
+
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+
+ // Pass the output
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status BinaryTensorOpTensor(
+ Converter& ctx, const tensorflow::NodeDef& node_def,
+ const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
+ {"Add", nvinfer1::ElementWiseOperation::kSUM},
+ {"Mul", nvinfer1::ElementWiseOperation::kPROD},
+ // {"max", nvinfer1::ElementWiseOperation::kMAX},
+ // {"min", nvinfer1::ElementWiseOperation::kMIN},
+ {"Sub", nvinfer1::ElementWiseOperation::kSUB},
+ {"Div", nvinfer1::ElementWiseOperation::kDIV},
+ };
+
+ // FIXME assume type matches input weights
+ // Get trt type & shape
+ TFAttrs attrs(node_def);
+ // Maybe this part has to be moved into the block of rsqrt later
+ nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");
+
+ // Check type consistency
+ CHECK_EQ_TYPE(tensor_l->getType(), dtype);
+ CHECK_EQ_TYPE(tensor_r->getType(), dtype);
+ auto op_pair = ops.find(node_def.op());
+ if (op_pair == ops.end())
+ return tensorflow::errors::Unimplemented(
+ "binary op: " + node_def.op() +
+ " not supported at: " + node_def.name());
+
+ nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise(
+ *const_cast<nvinfer1::ITensor*>(tensor_l),
+ *const_cast<nvinfer1::ITensor*>(tensor_r), op_pair->second);
+
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+
+ // Pass the output
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertPlaceholder(
+ Converter& ctx, const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ VLOG(2) << "Placeholder should have been replace already";
+ return tensorflow::errors::Unimplemented(", cannot convert Placeholder op");
+ // OK this make sense since we are supposed to replace it with input
+ TFAttrs attrs(node_def);
+ nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("dtype");
+ nvinfer1::Dims dims = attrs.get<nvinfer1::Dims>("shape");
+
+ dims.nbDims--;
+ for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1];
+
+ nvinfer1::ITensor* output =
+ ctx.network()->addInput(node_def.name().c_str(), dtype, dims);
+ if (!output) {
+ return tensorflow::errors::InvalidArgument("Failed to create Input layer");
+ }
+ outputs->push_back(TRT_TensorOrWeights(output));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertConv2D(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ const std::vector<TRT_TensorOrWeights>& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+ // TODO(jie): handle NHWC/NCHW transpose;
+ TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
+ TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck);
+ ReorderRSCKToKCRS(weights_rsck, &weights);
+ TRT_ShapedWeights biases(weights.type_);
+ int noutput = weights.shape_.d[0];
+ nvinfer1::DimsHW kernel_size;
+ kernel_size.h() = weights.shape_.d[2];
+ kernel_size.w() = weights.shape_.d[3];
+ TFAttrs attrs(node_def);
+
+ int h_index = 2;
+ int w_index = 3;
+ auto data_format = attrs.get<string>("data_format");
+ if (data_format == "NHWC") {
+ tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+ {0, 3, 1, 2});
+ h_index = 1;
+ w_index = 2;
+ // TODO(jie): transpose it
+ }
+
+ // TODO(jie): stride. (NHWC/NCHW)
+ auto tf_stride = attrs.get<std::vector<int>>("strides");
+ nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
+
+ auto tensor_dim = tensor->getDimensions();
+ std::vector<std::pair<int, int>> padding;
+ // TODO(jie): padding.
+ if (attrs.get<string>("padding") == "SAME") {
+ // This is NCHW tensor with no batch dimension.
+ // 1 -> h
+ // 2 -> w
+ padding = CreateSamePadding(
+ stride, kernel_size,
+ {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
+ } else {
+ padding = {{0, 0}, {0, 0}};
+ }
+
+ if (padding[0].first != padding[0].second ||
+ padding[1].first != padding[1].second) {
+ // TODO(jie): handle asymmetric padding
+ VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second
+ << padding[1].first << padding[1].second;
+
+ auto dim_before = tensor->getDimensions();
+ VLOG(2) << "TENSOR before: " << dim_before.d[0] << ", " << dim_before.d[1]
+ << dim_before.d[2] << ", " << dim_before.d[3];
+ auto pad_layer = ctx.network()->addPadding(
+ *const_cast<nvinfer1::ITensor*>(tensor),
+ nvinfer1::DimsHW(padding[0].first, padding[1].first),
+ nvinfer1::DimsHW(padding[0].second, padding[1].second));
+ padding = {{0, 0}, {0, 0}};
+ tensor = pad_layer->getOutput(0);
+ auto dim_after = tensor->getDimensions();
+ VLOG(2) << "TENSOR after: " << dim_after.d[0] << ", " << dim_after.d[1]
+ << dim_after.d[2] << ", " << dim_after.d[3];
+ }
+
+ nvinfer1::IConvolutionLayer* layer =
+ ctx.network()->addConvolution(*const_cast<nvinfer1::ITensor*>(tensor),
+ noutput, kernel_size, weights, biases);
+
+ layer->setStride(stride);
+ layer->setPadding({padding[0].first, padding[1].first});
+ layer->setName(node_def.name().c_str());
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+
+ auto dim_after = output_tensor->getDimensions();
+ VLOG(2) << "TENSOR out: " << dim_after.d[0] << ", " << dim_after.d[1]
+ << dim_after.d[2] << ", " << dim_after.d[3];
+
+ if (data_format == "NHWC") {
+ // TODO(jie): transpose it back!
+ output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
+ } else {
+ VLOG(2) << "NCHW !!!!";
+ }
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertPool(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+ TFAttrs attrs(node_def);
+
+ int h_index = 2;
+ int w_index = 3;
+ auto data_format = attrs.get<string>("data_format");
+ if (data_format == "NHWC") {
+ h_index = 1;
+ w_index = 2;
+ tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+ {0, 3, 1, 2});
+ } else {
+ VLOG(2) << "NCHW !!!!";
+ }
+ nvinfer1::PoolingType type;
+ // TODO(jie): support other pooling type
+ if (node_def.op() == "MaxPool")
+ type = nvinfer1::PoolingType::kMAX;
+ else
+ return tensorflow::errors::Unimplemented("Only supports Max pool");
+
+ // TODO(jie): NCHW
+ auto tf_stride = attrs.get<std::vector<int>>("strides");
+ nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
+
+ auto tf_kernel = attrs.get<std::vector<int>>("ksize");
+ nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]);
+
+ auto tensor_dim = tensor->getDimensions();
+ std::vector<std::pair<int, int>> padding;
+ // TODO(jie): padding.
+ if (attrs.get<string>("padding") == "SAME") {
+ // This is NCHW tensor with no batch dimension.
+ // 1 -> h
+ // 2 -> w
+ padding = CreateSamePadding(
+ stride, ksize,
+ {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
+ } else if (attrs.get<string>("padding") == "VALID") {
+ // No padding for valid padding here
+ VLOG(2) << "No padding added for VALID padding in pool" << node_def.name();
+ padding = {{0, 0}, {0, 0}};
+ } else {
+ return tensorflow::errors::Unimplemented(
+ "Current MaxPool cannot support padding other than SAME");
+ }
+
+ if (padding[0].first != padding[0].second ||
+ padding[1].first != padding[1].second) {
+ // TODO(jie): handle asymmetric padding
+ VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second
+ << padding[1].first << padding[1].second;
+ auto pad_layer = ctx.network()->addPadding(
+ *const_cast<nvinfer1::ITensor*>(tensor),
+ nvinfer1::DimsHW(padding[0].first, padding[1].first),
+ nvinfer1::DimsHW(padding[0].second, padding[1].second));
+ padding = {{0, 0}, {0, 0}};
+ tensor = pad_layer->getOutput(0);
+ }
+
+ nvinfer1::IPoolingLayer* layer = ctx.network()->addPooling(
+ *const_cast<nvinfer1::ITensor*>(tensor), type, ksize);
+
+ layer->setStride(stride);
+ layer->setPadding({padding[0].first, padding[1].first});
+ layer->setName(node_def.name().c_str());
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+
+ if (data_format == "NHWC") {
+ // TODO(jie): transpose it back!
+ output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
+ } else {
+ VLOG(2) << "NCHW !!!!";
+ }
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertActivation(
+ Converter& ctx, const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+ nvinfer1::IActivationLayer* layer = ctx.network()->addActivation(
+ *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ActivationType::kRELU);
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertScale(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
+ !inputs.at(1).is_weights())
+ return tensorflow::errors::Unimplemented(
+ "Only supports tensor op weight for now, at " + node_def.name());
+ // Implement tensor binaryOp weight [channel wise] for now;
+ nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+
+ // TODO(jie): handle NHWC/NCHW transpose;
+ TRT_ShapedWeights weights = inputs.at(1).weights();
+ TRT_ShapedWeights empty_weights(weights.type_);
+
+ TFAttrs attrs(node_def);
+
+ // Transpose NHWC
+ auto data_format = attrs.get<string>("data_format");
+ if (data_format == "NHWC") {
+ tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+ {0, 3, 1, 2});
+ // TODO(jie): transpose it
+ } else {
+ VLOG(2) << "NCHW !!!!";
+ }
+ nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
+ *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ScaleMode::kCHANNEL,
+ weights, empty_weights, empty_weights);
+
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+ if (data_format == "NHWC") {
+ // TODO(jie): transpose it back!
+ output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
+ } else {
+ VLOG(2) << "NCHW !!!!";
+ }
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertConst(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ const auto& weights_tensor = node_def.attr().at("value").tensor();
+
+ // Get trt type & shape
+ TFAttrs attrs(node_def);
+ const tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("dtype");
+
+ // Create shaped weights as output
+ tensorflow::Tensor tensor;
+ if (!tensor.FromProto(weights_tensor))
+ return tensorflow::errors::Internal("Cannot parse weight tensor proto: " +
+ node_def.name());
+
+ TRT_ShapedWeights weights(dtype);
+ if (!weights_tensor.float_val().empty()) {
+ VLOG(2) << "SCALAR!!!" << node_def.name();
+ nvinfer1::Dims scalar_shape;
+ if (tensor.dims() > 0) {
+ VLOG(2) << "Dimensions: " << tensor.dims();
+ weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
+ GetTensorShape(tensor));
+ } else {
+ VLOG(2) << "Dimensions: " << tensor.dims();
+ scalar_shape.nbDims = 1;
+ scalar_shape.d[0] = 1;
+ scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
+ for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) {
+ scalar_shape.d[i] = 0;
+ scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL;
+ }
+ weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
+ scalar_shape);
+ }
+ } else if (!weights_tensor.tensor_content().empty()) {
+ VLOG(2) << "TENSOR!!!" << node_def.name();
+ const auto& content = weights_tensor.tensor_content();
+
+ weights = ctx.get_temp_weights(dtype, GetTensorShape(tensor));
+ if (content.size() > 0) {
+ const int dtype_size = tensorflow::DataTypeSize(dtype);
+ CHECK_EQ(0, content.size() % dtype_size)
+ << "Tensor content size (" << content.size()
+ << ") is not a multiple of " << dtype_size;
+ port::CopyToArray(
+ content, static_cast<char*>(const_cast<void*>(weights.GetValues())));
+ }
+ } else {
+ return tensorflow::errors::Unimplemented(
+ "Not supported constant type, at " + node_def.name());
+ }
+ // Pass the output
+ outputs->push_back(TRT_TensorOrWeights(weights));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertIdentity(
+ Converter& ctx, const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ outputs->push_back(inputs.at(0));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertBinary(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ if (inputs.size() != 2)
+ return tensorflow::errors::FailedPrecondition(
+ "Binary ops require two tensor input, at " + node_def.name());
+
+ if (inputs.at(0).is_weights() && inputs.at(1).is_weights())
+ return ConstantFoldBinary(ctx, node_def, inputs, outputs);
+
+ if (inputs.at(0).is_tensor() && inputs.at(1).is_weights())
+ return BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(),
+ inputs.at(1).weights(), outputs);
+
+ if (inputs.at(0).is_weights() && inputs.at(1).is_tensor())
+ return BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(),
+ inputs.at(0).weights(), outputs);
+
+ if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor())
+ return BinaryTensorOpTensor(ctx, node_def, inputs.at(0).tensor(),
+ inputs.at(1).tensor(), outputs);
+
+ return tensorflow::errors::Unknown("Binary op input error, at " +
+ node_def.name());
+}
+
+tensorflow::Status ConvertUnary(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ if (inputs.size() != 1)
+ return tensorflow::errors::FailedPrecondition(
+ "Unary ops require single tensor input, at " + node_def.name());
+
+ if (inputs.at(0).is_weights())
+ return ConstantFoldUnary(ctx, node_def, inputs, outputs);
+ else if (inputs.at(0).is_tensor())
+ return tensorflow::errors::Unimplemented(
+ "Unary op for tensor not supported, at " + node_def.name());
+
+ return tensorflow::errors::Unknown("Binary op input error, at " +
+ node_def.name());
+}
+
+tensorflow::Status ConvertReduce(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
+ !inputs.at(1).is_weights())
+ return tensorflow::errors::InvalidArgument(
+ "Input expects tensor and weights, at" + node_def.name());
+
+ // Implement tensor binaryOp weight [channel wise] for now;
+ nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+ auto dims = tensor->getDimensions();
+ // Restore implicit batch dimension
+ int nb_dims = dims.nbDims + 1;
+
+ TRT_ShapedWeights index_list = inputs.at(1).weights();
+
+ TFAttrs attrs(node_def);
+ // TODO(jie): handle data type.
+ // Index type here is done through TF type, so I can leverage their
+ // EnumToDataType for my cast
+ auto index_type = attrs.get<tensorflow::DataType>("Tidx");
+
+ // Only expect to handle INT32 as attributes for now
+ if (index_type != tensorflow::DataType::DT_INT32)
+ return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32");
+ auto index_list_data =
+ static_cast<int*>(const_cast<void*>(index_list.GetValues()));
+
+ // Hack warning: have to fall back to pool layer since reduce is not in public
+ // TRT yet.
+ if (nb_dims != 4)
+ return tensorflow::errors::InvalidArgument(
+ "TRT only support reduce on 4 dimensional tensors, at" +
+ node_def.name());
+ if (index_list.count() > 2)
+ return tensorflow::errors::InvalidArgument(
+ "TRT cannot support reduce on more than 2 dimensions, at" +
+ node_def.name());
+
+ std::set<int> idx_set;
+ // We cannot operate on Channel. permutation flag used to transpose tensor
+ int permuted_index = -1;
+ for (int i = 0; i < index_list.count(); i++) {
+ if (index_list_data[i] == 0)
+ return tensorflow::errors::InvalidArgument("TRT cannot reduce at 0, at" +
+ node_def.name());
+ if (index_list_data[i] == 1) permuted_index = 1;
+ idx_set.emplace(index_list_data[i]);
+ }
+
+ std::vector<int> permutation_order(nb_dims);
+ nvinfer1::DimsHW pool_kernel;
+ if (permuted_index == 1) {
+ for (int i = 2; i < nb_dims; i++) {
+ if (idx_set.count(i)) {
+ permuted_index = i;
+ break;
+ }
+ }
+ for (int i = 0; i < nb_dims; i++) permutation_order[i] = i;
+
+ permutation_order[permuted_index] = 1;
+ permutation_order[1] = permuted_index;
+
+ // Apply permutation before extracting dimension for pool_kernel
+ tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+ permutation_order);
+ }
+
+ // Apply permutation before extracting dimension for pool_kernel
+ pool_kernel.d[0] = (idx_set.count(2) || permuted_index == 2) ? dims.d[1] : 1;
+ pool_kernel.d[1] = (idx_set.count(3) || permuted_index == 3) ? dims.d[2] : 1;
+
+ nvinfer1::ITensor* output_tensor;
+
+ if (node_def.op() == "Mean") {
+ nvinfer1::IPoolingLayer* layer =
+ ctx.network()->addPooling(*const_cast<nvinfer1::ITensor*>(tensor),
+ nvinfer1::PoolingType::kAVERAGE, pool_kernel);
+ output_tensor = layer->getOutput(0);
+ } else {
+ return tensorflow::errors::Unimplemented(
+ "Op not supported " + node_def.op() + " , at " + node_def.name());
+ }
+ if (permuted_index != -1) {
+ // Apply permutation before extracting dimension for pool_kernel
+ output_tensor = ctx.TransposeTensor(
+ const_cast<nvinfer1::ITensor*>(output_tensor), permutation_order);
+ }
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertPad(Converter& ctx,
+ const tensorflow::NodeDef& node_def,
+ std::vector<TRT_TensorOrWeights> const& inputs,
+ std::vector<TRT_TensorOrWeights>* outputs) {
+ if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
+ !inputs.at(1).is_weights())
+ return tensorflow::errors::InvalidArgument(
+ "Input expects tensor and weights, at" + node_def.name());
+
+ // Implement tensor binaryOp weight [channel wise] for now;
+ nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
+ auto dims = tensor->getDimensions();
+ // Restore implicit batch dimension
+ int nb_dims = dims.nbDims + 1;
+
+ TRT_ShapedWeights pads = inputs.at(1).weights();
+
+ TFAttrs attrs(node_def);
+ // Padding type here is done through TF type
+ // so I can leverage their EnumToDataType for my cast
+ auto padding_type = attrs.get<tensorflow::DataType>("Tpaddings");
+ // TODO(jie): handle data type conversion for TRT?
+
+ if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2)
+ return tensorflow::errors::InvalidArgument(
+ "Pad only supports explicit padding on 4 dimensional tensor, at " +
+ node_def.name());
+
+ // Only expect to handle INT32 as attributes for now
+ if (padding_type != tensorflow::DataType::DT_INT32)
+ return tensorflow::errors::Unimplemented(
+ "Tpaddings supports only DT_INT32");
+ auto pad_data = static_cast<int*>(const_cast<void*>(pads.GetValues()));
+
+ std::vector<int32_t> pad_index;
+ for (int i = 0; i < nb_dims; i++) {
+ if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0)
+ pad_index.push_back(i);
+ }
+
+ // No padding at all, we should exit
+ if (pad_index.size() == 0) {
+ outputs->push_back(inputs.at(0));
+ return tensorflow::Status::OK();
+ }
+
+ // Only supports padding on less than 2 axis GIE-2579
+ if (pad_index.size() > 2)
+ return tensorflow::errors::InvalidArgument(
+ "Padding layer does not support padding on > 2");
+
+ // Padding on batch dimension is not supported
+ if (pad_index[0] == 0)
+ return tensorflow::errors::InvalidArgument(
+ "Padding layer does not support padding on batch dimension");
+
+ // Not doing the legit thing here. ignoring padding on dim 1 and 3;
+ // TODO(jie): implement pad as uff parser
+ if (pad_index.size() == 2 && pad_index[0] == 0 && pad_index[1] == 3)
+ return tensorflow::errors::Unimplemented(
+ "Padding layer does not support padding on dimension 1 and 3 yet");
+
+ bool legit_pad = true;
+ nvinfer1::DimsHW pre_padding(0, 0);
+ nvinfer1::DimsHW post_padding(0, 0);
+
+ std::vector<int32_t> permuted_pad_index(pad_index);
+ if (pad_index[0] == 1) {
+ legit_pad = false;
+ tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
+ {0, 3, 2, 1});
+ permuted_pad_index[0] = 3;
+ }
+
+ for (size_t i = 0; i < pad_index.size(); i++) {
+ int index = pad_index[i];
+ if (permuted_pad_index[i] == 2) {
+ pre_padding.h() = pad_data[index * 2];
+ post_padding.h() = pad_data[index * 2 + 1];
+ } else if (permuted_pad_index[i] == 3) {
+ pre_padding.w() = pad_data[index * 2];
+ post_padding.w() = pad_data[index * 2 + 1];
+ }
+ }
+
+ nvinfer1::IPaddingLayer* layer = ctx.network()->addPadding(
+ *const_cast<nvinfer1::ITensor*>(tensor), pre_padding, post_padding);
+ nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+
+ if (!legit_pad)
+ output_tensor = ctx.TransposeTensor(
+ const_cast<nvinfer1::ITensor*>(output_tensor), {0, 3, 2, 1});
+
+ outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return tensorflow::Status::OK();
+}
+
+void Converter::register_op_converters() {
+ // vgg_16 slim implementation
+ op_registry_["Placeholder"] = ConvertPlaceholder;
+ op_registry_["Conv2D"] = ConvertConv2D;
+ op_registry_["Relu"] = ConvertActivation;
+ op_registry_["MaxPool"] = ConvertPool;
+ // This could be really handled as ConvertBinary
+ op_registry_["BiasAdd"] = ConvertScale;
+ op_registry_["Const"] = ConvertConst;
+ // op_registry_["MatMul"] = ConvertFullyConnected; // Not used in vgg
+ // TODO(ben,jie): this is a temp hack.
+ op_registry_["Identity"] = ConvertIdentity; // Identity should be removed
+ // op_registry_["AvgPool"] = ConvertPool;
+
+ // resnet_50_v1 slim implementation
+ op_registry_["Add"] = ConvertBinary;
+ op_registry_["Mul"] = ConvertBinary;
+ op_registry_["Sub"] = ConvertBinary;
+ op_registry_["Rsqrt"] = ConvertUnary;
+ op_registry_["Mean"] = ConvertReduce;
+ op_registry_["Pad"] = ConvertPad;
+ // TODO(ben,jie): Add more ops
+}
+
+} // namespace
+
+tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
+ const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
+ const std::vector<std::pair<int, int>>& input_inds,
+ const std::vector<std::pair<int, int>>& output_inds, size_t max_batch_size,
+ size_t max_workspace_size_bytes,
+ const tensorflow::grappler::GraphProperties& graph_properties,
+ tensorflow::NodeDef* trt_node) {
+ // Visit nodes in reverse topological order and construct the TRT network.
+
+ // Toposort
+ std::vector<tensorflow::Node*> order_vec;
+ tensorflow::GetPostOrder(graph, &order_vec);
+ // Select just the subgraph
+ std::list<tensorflow::Node*> order;
+ for (tensorflow::Node* node : order_vec) {
+ if (subgraph_node_ids.count(node->id())) {
+ // We want topological order to contstruct the
+ // network layer by layer
+ order.push_front(node);
+ }
+ }
+ // Topological order is needed to build TRT network
+
+ tensorflow::tensorrt::Logger trt_logger;
+
+ auto trt_builder = infer_object(nvinfer1::createInferBuilder(trt_logger));
+ if (!trt_builder) {
+ return tensorflow::errors::Internal(
+ "Failed to create TensorRT builder object");
+ }
+
+ auto trt_network = infer_object(trt_builder->createNetwork());
+ if (!trt_network) {
+ return tensorflow::errors::Internal(
+ "Failed to create TensorRT network object");
+ }
+
+ // Build the network
+ Converter converter(trt_network.get());
+
+ std::vector<string> input_names;
+ std::vector<tensorflow::DataType> input_dtypes;
+ for (std::pair<int, int> const& input : input_inds) {
+ int node_id = input.first;
+ int output_idx = input.second;
+ tensorflow::Node* node = graph.FindNodeId(node_id);
+ auto node_name = node->name();
+ input_names.push_back(node_name); // Insert original node name without port
+ // TODO(jie): alternative :)
+ if (!graph_properties.HasOutputProperties(node_name))
+ return tensorflow::errors::Internal("Failed to find input node: " +
+ node_name);
+
+ auto op_info_vec = graph_properties.GetOutputProperties(node_name);
+ if (static_cast<int>(op_info_vec.size()) < output_idx)
+ return tensorflow::errors::Internal(
+ "Accessing output index of: " + std::to_string(output_idx) +
+ ", at node: " + node_name + " with output entry from shape_map: " +
+ std::to_string(op_info_vec.size()));
+
+ auto op_info = op_info_vec.at(output_idx);
+
+ tensorflow::DataType tf_dtype = op_info.dtype();
+ input_dtypes.push_back(tf_dtype);
+
+ nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
+ TF_CHECK_OK(ConvertDType(tf_dtype, &dtype));
+
+ VLOG(2) << "Accessing output index of: " << std::to_string(output_idx)
+ << ", at node: " << node_name
+ << " with output entry from shape_map: "
+ << std::to_string(op_info_vec.size());
+
+ // TODO(ben,jie): update TRT input format/dimension
+ nvinfer1::DimsCHW input_dim_pseudo_chw;
+ for (int i = 0; i < 3; i++) input_dim_pseudo_chw.d[i] = 1;
+
+ for (int i = 1; i < op_info.shape().dim_size(); i++) {
+ VLOG(2) << "dimension: " << i
+ << " , size: " << op_info.shape().dim(i).size();
+ input_dim_pseudo_chw.d[i - 1] = op_info.shape().dim(i).size();
+ }
+
+ // TODO(ben,jie): proper way to restore input tensor name?
+ auto input_tensor_name = node_name;
+ if (output_idx != 0)
+ input_tensor_name = node_name + ":" + std::to_string(output_idx);
+
+ nvinfer1::ITensor* input_tensor = converter.network()->addInput(
+ input_tensor_name.c_str(), dtype, input_dim_pseudo_chw);
+
+ if (!input_tensor)
+ return tensorflow::errors::InvalidArgument(
+ "Failed to create Input layer");
+ VLOG(2) << "Input tensor name :" << input_tensor_name;
+
+ if (!converter.insert_input_tensor(input_tensor_name, input_tensor))
+ return tensorflow::errors::AlreadyExists(
+ "Output tensor already exists for op: " + input_tensor_name);
+ }
+
+ VLOG(2) << "Finished sorting";
+
+ for (const tensorflow::Node* node : order) {
+ const tensorflow::NodeDef& node_def = node->def();
+ VLOG(2) << "Converting node: " << node_def.name() << " , " << node_def.op();
+ TF_RETURN_IF_ERROR(converter.convert_node(node_def));
+ }
+
+ VLOG(2) << "Finished conversion";
+
+ // Gather output metadata
+ std::vector<string> output_names;
+ std::vector<tensorflow::DataType> output_dtypes;
+ for (std::pair<int, int> const& output : output_inds) {
+ int node_id = output.first;
+ int output_idx = output.second;
+ tensorflow::Node* node = graph.FindNodeId(node_id);
+ string op_name = node->name();
+ string tensor_name = op_name;
+ if (output_idx != 0)
+ tensor_name = tensor_name + ":" + std::to_string(output_idx);
+ VLOG(2) << "Output tensor name: " << tensor_name;
+ output_names.push_back(tensor_name);
+ auto tensor_or_weights = converter.get_tensor(tensor_name);
+ if (!tensor_or_weights.is_tensor()) {
+ return tensorflow::errors::InvalidArgument(
+ "Output node is weights not tensor");
+ }
+ nvinfer1::ITensor* tensor = tensor_or_weights.tensor();
+ if (!tensor) {
+ return tensorflow::errors::NotFound("Output tensor not found: " +
+ tensor_name);
+ }
+ converter.network()->markOutput(*tensor);
+ tensorflow::DataType tf_dtype = node->output_type(output_idx);
+ output_dtypes.push_back(tf_dtype);
+ nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT;
+ TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype));
+ tensor->setType(trt_dtype);
+ }
+
+ VLOG(2) << "Finished output";
+ // TODO(jie): static_id is not thread safe.
+ static int static_id = 0;
+
+ // Build the engine
+ trt_builder->setMaxBatchSize(max_batch_size);
+ trt_builder->setMaxWorkspaceSize(max_workspace_size_bytes);
+ VLOG(0) << "Starting build engine " << static_id;
+ // TODO(ben,jie): half2 and int8 mode support
+ string engine_plan_string;
+ {
+ auto trt_engine =
+ infer_object(trt_builder->buildCudaEngine(*converter.network()));
+ VLOG(0) << "Built network";
+ auto engine_plan = infer_object(trt_engine->serialize());
+ VLOG(0) << "Serialized engine";
+ const char* engine_plan_data =
+ static_cast<const char*>(engine_plan->data());
+ engine_plan_string =
+ string(engine_plan_data, engine_plan_data + engine_plan->size());
+ }
+
+ VLOG(0) << "Finished engine";
+
+ // Build the TRT op
+ // TODO(sami,ben,jie): proper naming!
+ tensorflow::NodeDefBuilder op_builder(
+ tensorflow::strings::StrCat("my_trt_op", static_id++), "TRTEngineOp");
+ std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges;
+ for (size_t i = 0; i < input_names.size(); ++i) {
+ int output_idx = input_inds.at(i).second;
+ // We wired up the input here already, it is redundant to do it again in
+ // ConvertSubGraphToTensorRT(convert_graph.cc)
+ auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut(
+ input_names.at(i), output_idx, input_dtypes.at(i));
+ income_edges.push_back(incoming_edge);
+ }
+ tensorflow::gtl::ArraySlice<tensorflow::NodeDefBuilder::NodeOut> input_list(
+ income_edges);
+ op_builder.Input(input_list);
+
+ VLOG(0) << "Finished op preparation";
+
+ auto status = op_builder.Attr("serialized_engine", engine_plan_string)
+ .Attr("input_nodes", input_names)
+ .Attr("output_nodes", output_names)
+ .Attr("OutT", output_dtypes)
+ .Finalize(trt_node);
+
+ VLOG(0) << status.ToString() << " finished op building";
+
+ return tensorflow::Status::OK();
+}
+
+} // namespace convert
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
new file mode 100644
index 0000000000..2e7fd19566
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -0,0 +1,52 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
+
+#include <set>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/lib/core/status.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
+namespace tensorflow {
+namespace tensorrt {
+namespace convert {
+
+tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
+ const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
+ const std::vector<std::pair<int, int>>&
+ input_inds, // {node_id, output_idx}
+ const std::vector<std::pair<int, int>>&
+ output_inds, // {node_id, output_idx}
+ size_t max_batch_size, size_t max_workspace_size_bytes,
+ const tensorflow::grappler::GraphProperties& graph_prop,
+ tensorflow::NodeDef* trt_node);
+
+} // namespace convert
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_CONVERT_NODES_H_
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
new file mode 100644
index 0000000000..8efdf63ebe
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -0,0 +1,140 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h"
+
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor.h"
+#include "tensorflow/core/platform/types.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "cuda/include/cuda_runtime_api.h"
+
+namespace tensorflow {
+namespace tensorrt {
+static ::tensorflow::tensorrt::Logger logger;
+
+TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
+ // read serialized_engine
+ string serialized_engine;
+ OP_REQUIRES_OK(context,
+ context->GetAttr("serialized_engine", &serialized_engine));
+
+ // register input output node name in trt_sub_graph
+ OP_REQUIRES_OK(context, context->GetAttr("input_nodes", &input_nodes_));
+ OP_REQUIRES_OK(context, context->GetAttr("output_nodes", &output_nodes_));
+
+ // TODO(samikama) runtime should be taken from a resourcemanager as well.
+ // Only engine should be in the op and context and runtime should be taken
+ // from resourcemanager
+ nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
+ trt_engine_ptr_.reset(infer->deserializeCudaEngine(
+ serialized_engine.c_str(), serialized_engine.size(), nullptr));
+
+ trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
+ // Runtime is safe to delete after engine creation
+ infer->destroy();
+}
+
+void TRTEngineOp::Compute(OpKernelContext* context) {
+ int num_binding = context->num_inputs() + context->num_outputs();
+ std::vector<void*> buffers(num_binding);
+
+ size_t binding_index;
+ int num_batch = 0;
+ bool valid = true;
+ for (int i = 0; i < context->num_inputs(); i++) {
+ // Grab the input tensor
+ binding_index = trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str());
+
+ const Tensor& input_tensor = context->input(i);
+ const TensorShape& input_shape = input_tensor.shape();
+ if (i == 0) {
+ num_batch = input_shape.dim_size(0);
+ } else if (num_batch != input_shape.dim_size(0)) {
+ valid = false;
+ break;
+ }
+ switch (trt_engine_ptr_->getBindingDataType(binding_index)) {
+ case nvinfer1::DataType::kFLOAT:
+ buffers[binding_index] = (void*)(input_tensor.flat<float>().data());
+ break;
+ case nvinfer1::DataType::kHALF:
+ LOG(FATAL) << "half size is not supported yet!";
+ break;
+ case nvinfer1::DataType::kINT8:
+ LOG(FATAL) << "int8 is not supported yet!";
+ break;
+ }
+ }
+
+ // Might want a different way to inform the user of batch size inconsistency
+ if (!valid) LOG(WARNING) << "input data inconsistent batch size";
+
+ for (int i = 0; i < static_cast<int>(output_nodes_.size()); i++) {
+ // This is bad that we have to reallocate output buffer every run.
+ // Create an output tensor
+ binding_index = trt_engine_ptr_->getBindingIndex(output_nodes_[i].c_str());
+ Tensor* output_tensor = nullptr;
+
+ TensorShape output_shape;
+ if (binding_index != -1) {
+ auto dims = trt_engine_ptr_->getBindingDimensions(binding_index);
+ std::vector<int> trt_shape(dims.nbDims + 1);
+ trt_shape[0] = num_batch;
+ for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j];
+ OP_REQUIRES_OK(context,
+ TensorShapeUtils::MakeShape(
+ trt_shape.data(), trt_shape.size(), &output_shape));
+ } else {
+ LOG(FATAL) << "output node not found, at " << output_nodes_[i];
+ break;
+ }
+
+ OP_REQUIRES_OK(context,
+ context->allocate_output(i, output_shape, &output_tensor));
+ switch (trt_engine_ptr_->getBindingDataType(binding_index)) {
+ case nvinfer1::DataType::kFLOAT:
+ buffers[binding_index] =
+ reinterpret_cast<void*>(output_tensor->flat<float>().data());
+ break;
+ case nvinfer1::DataType::kHALF:
+ LOG(FATAL) << "half size is not supported yet!";
+ break;
+ case nvinfer1::DataType::kINT8:
+ LOG(FATAL) << "int8 is not supported yet!";
+ break;
+ }
+ }
+ // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
+ const cudaStream_t* stream = CHECK_NOTNULL(
+ reinterpret_cast<const cudaStream_t*>(context->op_device_context()
+ ->stream()
+ ->implementation()
+ ->CudaStreamMemberHack()));
+
+ // execution handled by TF since we are getting stream from TF.
+ // it is safe for CPU pointer array (buffers) to go out of scope after enqueue
+ trt_execution_context_ptr_->enqueue(num_batch, &buffers[0], *stream, nullptr);
+}
+
+REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
new file mode 100644
index 0000000000..0964b4b18a
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "cuda/include/cuda_runtime_api.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorrt/include/NvInfer.h"
+
+namespace tensorflow {
+namespace tensorrt {
+class Logger;
+
+class TRTEngineOp : public OpKernel {
+ public:
+ explicit TRTEngineOp(OpKernelConstruction* context);
+
+ void Compute(OpKernelContext* context) override;
+
+ private:
+ template <typename T>
+ struct Destroyer {
+ void operator()(T* d) { d->destroy(); }
+ };
+
+ template <typename T>
+ using destroyed_ptr = std::unique_ptr<T, Destroyer<T>>;
+ destroyed_ptr<nvinfer1::ICudaEngine> trt_engine_ptr_;
+ // TODO(samikama): context should go to a resource manager!
+ destroyed_ptr<nvinfer1::IExecutionContext> trt_execution_context_ptr_;
+
+ std::vector<string> input_nodes_;
+ std::vector<string> output_nodes_;
+};
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_
diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.cc b/tensorflow/contrib/tensorrt/log/trt_logger.cc
new file mode 100644
index 0000000000..7add8cb8b3
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/log/trt_logger.cc
@@ -0,0 +1,57 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace tensorrt {
+
+// Use TF logging for TensorRT informations
+void Logger::log(Severity severity, const char* msg) {
+ // Suppress info-level messages
+ switch (severity) {
+ case Severity::kINFO: { // Mark TRT info messages as debug!
+ VLOG(2) << msg;
+ break;
+ }
+ case Severity::kWARNING: {
+ LOG(WARNING) << msg;
+ break;
+ }
+ case Severity::kERROR: {
+ LOG(ERROR) << msg;
+ break;
+ }
+ case Severity::kINTERNAL_ERROR: {
+ LOG(FATAL) << msg;
+ break;
+ }
+ // This is useless for now. But would catch it in future if enum changes. It
+ // is always good to have default case!
+ default: {
+ LOG(FATAL) << name_ << "Got unknown severity level from TRT " << msg;
+ break;
+ }
+ }
+}
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
+#endif // GOOGLE_TENSORRT
diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/contrib/tensorrt/log/trt_logger.h
new file mode 100644
index 0000000000..d71f66b933
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/log/trt_logger.h
@@ -0,0 +1,42 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_
+
+#include "tensorflow/core/platform/types.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "tensorrt/include/NvInfer.h"
+
+namespace tensorflow {
+namespace tensorrt {
+
+// Logger for GIE info/warning/errors
+class Logger : public nvinfer1::ILogger {
+ private:
+ void log(nvinfer1::ILogger::Severity severity, const char* msg) override;
+
+ string name_;
+};
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_
diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
new file mode 100644
index 0000000000..079d73f7be
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
@@ -0,0 +1,43 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+
+namespace shape_inference {
+extern Status TRTEngineOpShapeInference(InferenceContext* c);
+}
+
+REGISTER_OP("TRTEngineOp")
+ .Attr("serialized_engine: string")
+ .Attr("input_nodes: list(string)")
+ .Attr("output_nodes: list(string)")
+ .Attr("InT: list({float32})")
+ .Attr("OutT: list({float32})")
+ .Input("in_tensor: InT")
+ .Output("out_tensor: OutT")
+ .SetShapeFn(shape_inference::TRTEngineOpShapeInference);
+
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py
new file mode 100644
index 0000000000..7e050a768c
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/python/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Exposes the python wrapper for TensorRT graph transforms."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,line-too-long
+from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
+from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph
+# pylint: enable=unused-import,line-too-long
diff --git a/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py b/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py
new file mode 100644
index 0000000000..31a313182b
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/python/ops/trt_engine_op.py
@@ -0,0 +1,34 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Exposes the Python wrapper of TRTEngineOp."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import platform
+
+if platform.system() != "Windows":
+ # pylint: disable=wildcard-import,unused-import,g-import-not-at-top
+ from tensorflow.contrib.tensorrt.ops.gen_trt_engine_op import *
+
+ from tensorflow.contrib.util import loader
+ from tensorflow.python.platform import resource_loader
+ # pylint: enable=wildcard-import,unused-import,g-import-not-at-top
+
+ _trt_engine_op = loader.load_op_library(
+ resource_loader.get_path_to_datafile("_trt_engine_op.so"))
+else:
+ raise RuntimeError("Windows platforms are not supported")
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py
new file mode 100644
index 0000000000..9454862f85
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/python/trt_convert.py
@@ -0,0 +1,103 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Exposes the Python wrapper conversion to trt_graph."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,line-too-long
+import six as _six
+from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import errors_impl as _impl
+from tensorflow.python.framework import ops
+
+
+# TODO(skama): get outputs from session when implemented as c++
+# optimization pass
+def create_inference_graph(input_graph_def,
+ outputs,
+ max_batch_size=1,
+ max_workspace_size_bytes=2 << 20):
+ """Python wrapper for the TRT transormation.
+
+
+ Args:
+ input_graph_def: GraphDef object containing a model to be transformed.
+ outputs: List of tensors or node names for the model outputs.
+ max_batch_size: max size for the input batch
+ max_workspace_size_bytes: parameter to control memory allocation (in Bytes)
+
+ Returns:
+ New GraphDef with TRTEngineOps placed in graph replacing subgraphs.
+
+ Raises:
+ RuntimeError: if the returned status message is malformed.
+ """
+
+ def py2bytes(inp):
+ return inp
+
+ def py3bytes(inp):
+ return inp.encode("utf-8", errors="surrogateescape")
+
+ def py2string(inp):
+ return inp
+
+ def py3string(inp):
+ return inp.decode("utf-8")
+
+ if _six.PY2:
+ to_bytes = py2bytes
+ to_string = py2string
+ else:
+ to_bytes = py3bytes
+ to_string = py3string
+
+ out_names = []
+ for i in outputs:
+ if isinstance(i, ops.Tensor):
+ out_names.append(to_bytes(i.name))
+ else:
+ out_names.append(to_bytes(i))
+
+ input_graph_def_str = input_graph_def.SerializeToString()
+
+ # TODO(sami): Fix this when we can return status from C++ library
+ # There is a problem with the TF internal library setup that doesn't
+ # allow us to return a status object from C++. Thus we return a
+ # pair or strings where first one is encoded status and the second
+ # one is the transformed graphs protobuf string.
+ out = trt_convert(input_graph_def_str, out_names, max_batch_size,
+ max_workspace_size_bytes)
+ status = to_string(out[0])
+ output_graph_def_string = out[1]
+ del input_graph_def_str # Save some memory
+ if len(status) < 2:
+ raise _impl.UnknownError(None, None, status)
+ if status[:2] != "OK":
+ msg = status.split(";")
+ if len(msg) == 1:
+ raise RuntimeError("Status message is malformed {}".format(status))
+ # pylint: disable=protected-access
+ raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
+ int(msg[0]))
+ # pylint: enable=protected-access
+ output_graph_def = graph_pb2.GraphDef()
+ output_graph_def.ParseFromString(output_graph_def_string)
+ del output_graph_def_string # Save some memory
+ return output_graph_def
diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc
new file mode 100644
index 0000000000..6193f0b0a1
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/segment/segment.cc
@@ -0,0 +1,253 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/segment/segment.h"
+
+#include <set>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/tensorrt/segment/union_find.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace tensorrt {
+namespace segment {
+
+namespace {
+
+bool CanContractEdge(const tensorflow::Edge* edge,
+ const tensorflow::Graph& graph) {
+ const tensorflow::Node* src = edge->src();
+ const tensorflow::Node* dst = edge->dst();
+
+ // Can't contract edge if doing so would cause a cycle in the
+ // graph. So, if there is a directed path from 'src' to 'dst', other
+ // than 'edge' (or any other direct edge from 'src' to 'dst'), then
+ // combining 'src' and 'dst' will cause a cycle along that path.
+ //
+ // In practice, to avoid modifying the graph and to take advantage
+ // of existing graph functions, we perform an equivalent.
+ // 1. Get all nodes incoming to 'dst', excluding 'src'
+ // 2. Reverse DFS from those nodes
+ // 3. If reverse DFS reaches 'src' then we have a cycle
+ std::vector<tensorflow::Node*> dfs_start_nodes;
+ for (tensorflow::Node* node : dst->in_nodes()) {
+ if (node != src) {
+ dfs_start_nodes.push_back(node);
+ }
+ }
+
+ bool is_cycle = false;
+ if (!dfs_start_nodes.empty()) {
+ tensorflow::ReverseDFSFrom(graph, dfs_start_nodes, {},
+ [&is_cycle, src](tensorflow::Node* node) {
+ if (node == src) {
+ is_cycle = true;
+ }
+ });
+ }
+
+ return !is_cycle;
+}
+
+void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph,
+ std::vector<const tensorflow::Edge*>* remove_edges) {
+ // Transfer all inputs and outputs of 'dst' to 'src' except edges
+ // connecting the two.
+ tensorflow::Node* src = edge->src();
+ tensorflow::Node* dst = edge->dst();
+
+ // We can use '0' for input/output index because we don't need them
+ // to be accurate for the way we are using the graph.
+ std::vector<const tensorflow::Edge*> in_edges(dst->in_edges().begin(),
+ dst->in_edges().end());
+ for (const tensorflow::Edge* in_edge : in_edges) {
+ if (in_edge->src() != src) {
+ tensorflow::Edge* e = const_cast<tensorflow::Edge*>(in_edge);
+ if (e->src() == graph->source_node()) {
+ graph->AddEdge(e->src(), e->src_output(), src,
+ tensorflow::Graph::kControlSlot);
+ } else {
+ graph->AddEdge(e->src(), e->src_output(), src, 0 /* input index */);
+ }
+ }
+ }
+
+ std::vector<const tensorflow::Edge*> out_edges(dst->out_edges().begin(),
+ dst->out_edges().end());
+ for (const tensorflow::Edge* out_edge : out_edges) {
+ tensorflow::Edge* e = const_cast<tensorflow::Edge*>(out_edge);
+ if (e->dst() == graph->sink_node()) {
+ graph->AddEdge(src, tensorflow::Graph::kControlSlot, e->dst(),
+ e->dst_input());
+ } else {
+ graph->AddEdge(src, 0 /* output index */, e->dst(), e->dst_input());
+ }
+ }
+
+ // Return the edges that must be removed to disconnect 'dst' from
+ // the graph. We don't actually remove 'dst' since the caller holds
+ // references to all the nodes.
+ for (const auto& in_edge : dst->in_edges()) {
+ remove_edges->push_back(in_edge);
+ }
+ for (const auto& out_edge : dst->out_edges()) {
+ remove_edges->push_back(out_edge);
+ }
+}
+
+} // namespace
+
+tensorflow::Status SegmentGraph(
+ const tensorflow::GraphDef& gdef,
+ const std::function<bool(const tensorflow::NodeDef&)>& candidate_fn,
+ const SegmentOptions& options, SegmentNodesVector* segments) {
+ // Create a Graph representation of the GraphDef.
+ tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
+ gdef.library());
+ tensorflow::Graph graph(flib);
+ TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
+ tensorflow::GraphConstructorOptions(), gdef, &graph));
+
+ // tensorflow::DumpGraph("Pre-Segment", &graph);
+
+ // Use a union-find to collect the nodes that belong to the same
+ // segment. A node value of nullptr indicates that the node is not a
+ // candidate for TRT.
+ std::vector<UnionFind<tensorflow::Node*>> node_segments;
+ for (int i = 0; i < graph.num_node_ids(); ++i) {
+ tensorflow::Node* node = graph.FindNodeId(i);
+ if (options.exclude_node_list.count(node->name()) != 0 ||
+ !candidate_fn(node->def())) {
+ node = nullptr;
+ }
+ node_segments.emplace_back(node);
+ }
+
+ // The segmentation algorithm below visits nodes in reverse
+ // topological order and attempts to merge nodes along output
+ // edges. That means that subgraphs grow from the output-side of the
+ // network towards the inputs. In general this is not guaranteed to
+ // produce a globally optimal segmentation. In the future if we have
+ // a measure of how beneficial it is to include a given node in a
+ // TRT subgraph then we can revisit this algorithm to take advantage
+ // of that information.
+ std::vector<tensorflow::Node*> order;
+ tensorflow::GetPostOrder(graph, &order);
+
+ for (const tensorflow::Node* node : order) {
+ // All output nodes of 'node' have been visited...
+ VLOG(2) << "Trying node " << node->name();
+
+ // 'node' must be a TRT candidate...
+ if (node_segments[node->id()].Value() == nullptr) {
+ VLOG(2) << "... not a TRT candidate";
+ continue;
+ }
+
+ // Contract output edges to combine 'node' with output
+ // nodes. Iterate since combining two nodes may unblock other
+ // combining.
+ while (true) {
+ std::set<const tensorflow::Edge*> contract_edges;
+ for (const tensorflow::Edge* out_edge : node->out_edges()) {
+ VLOG(2) << "... out node " << out_edge->dst()->name();
+
+ // Out node must be TRT candidate...
+ if (node_segments[out_edge->dst()->id()].Value() == nullptr) {
+ VLOG(2) << "... ... not a TRT candidate";
+ continue;
+ }
+
+ if (CanContractEdge(out_edge, graph)) {
+ VLOG(2) << "... ... can contract";
+ contract_edges.insert(out_edge);
+ } else {
+ VLOG(2) << "... ... cannot contract, would form cycle";
+ }
+ }
+
+ if (contract_edges.empty()) {
+ break;
+ }
+
+ // Contract edges and collect the adjacent nodes into the same
+ // segment/subgraph.
+ while (!contract_edges.empty()) {
+ const tensorflow::Edge* contract_edge = *contract_edges.begin();
+ const tensorflow::Node* src = contract_edge->src();
+ const tensorflow::Node* dst = contract_edge->dst();
+
+ VLOG(2) << "Merge " << src->name() << " <- " << dst->name();
+ node_segments[src->id()].Merge(&node_segments[dst->id()]);
+
+ // Contracting the edge leaves disconnected graph edges.
+ // Remove these from the graph and from 'contract_edges' so we
+ // don't visit them again.
+ tensorflow::Edge* e = const_cast<tensorflow::Edge*>(contract_edge);
+ std::vector<const tensorflow::Edge*> remove_edges;
+ ContractEdge(e, &graph, &remove_edges);
+
+ for (const tensorflow::Edge* r : remove_edges) {
+ contract_edges.erase(r);
+ graph.RemoveEdge(r);
+ }
+ }
+ }
+ }
+
+ // Collect the segments/subgraphs. Each subgraph is represented by a
+ // set of the names of the nodes in that subgraph.
+ std::unordered_map<string, std::set<string>> sg_map;
+ for (auto& u : node_segments) {
+ if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) {
+ sg_map[u.ParentValue()->name()].insert(u.Value()->name());
+ }
+ }
+
+ // Convert the segments into the expected return format
+ for (const auto& itr : sg_map) {
+ const auto& segment_node_names = itr.second;
+ if (VLOG_IS_ON(1)) {
+ string s;
+ for (const auto& name : segment_node_names) {
+ s += " " + name;
+ }
+ VLOG(1) << "Segment " << segments->size() << ":" << s;
+ }
+
+ // Don't use small segments.
+ if (static_cast<int>(segment_node_names.size()) <
+ options.minimum_segment_size) {
+ VLOG(1) << "Segment " << segments->size() << " has only "
+ << segment_node_names.size() << " nodes, dropping";
+ continue;
+ }
+
+ segments->emplace_back(segment_node_names);
+ }
+
+ return tensorflow::Status::OK();
+}
+
+} // namespace segment
+} // namespace tensorrt
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h
new file mode 100644
index 0000000000..ee6e2b3ed2
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/segment/segment.h
@@ -0,0 +1,56 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_
+
+#include <set>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace tensorrt {
+namespace segment {
+
+using SegmentNodesVector = std::vector<std::set<string>>;
+
+struct SegmentOptions {
+ // Segment must contain at least this many nodes.
+ int minimum_segment_size = 2;
+ std::set<string> exclude_node_list;
+};
+
+// Get the subgraphs of a graph that can be handled by TensorRT.
+//
+// @param gdef The GraphDef describing the network
+// @param candidate_fn A function that returns true for a NodeDef if
+// that node can be handled by TensorRT.
+// @param segments Returns the TensorRT segments/subgraphs. Each entry
+// in the vector describes a subgraph by giving a set of the names of
+// all the NodeDefs in that subgraph.
+// @return the status.
+tensorflow::Status SegmentGraph(
+ const tensorflow::GraphDef& gdef,
+ const std::function<bool(const tensorflow::NodeDef&)>& candidate_fn,
+ const SegmentOptions& options, SegmentNodesVector* segments);
+
+} // namespace segment
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_SEGMENT_H_
diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc
new file mode 100644
index 0000000000..74cbc5f2b3
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc
@@ -0,0 +1,367 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/segment/segment.h"
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace tensorrt {
+namespace segment {
+namespace test {
+
+class SegmentTest : public ::testing::Test {
+ public:
+ bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def);
+
+ TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name);
+ TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+ TF_Status* s, const char* name);
+
+ std::function<bool(const NodeDef&)> MakeCandidateFn(
+ const std::set<string>& node_names);
+
+ protected:
+ void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
+ TF_Operation** op);
+ void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+ TF_Status* s, const char* name, TF_Operation** op, bool check);
+
+ SegmentOptions default_options_;
+};
+
+bool SegmentTest::GetGraphDef(TF_Graph* graph,
+ tensorflow::GraphDef* graph_def) {
+ TF_Status* s = TF_NewStatus();
+ TF_Buffer* buffer = TF_NewBuffer();
+ TF_GraphToGraphDef(graph, buffer, s);
+ bool ret = TF_GetCode(s) == TF_OK;
+ EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length);
+ TF_DeleteBuffer(buffer);
+ TF_DeleteStatus(s);
+ return ret;
+}
+
+std::function<bool(const NodeDef&)> SegmentTest::MakeCandidateFn(
+ const std::set<string>& node_names) {
+ return [node_names](const NodeDef& node) -> bool {
+ return node_names.find(node.name()) != node_names.end();
+ };
+}
+
+void SegmentTest::PlaceholderHelper(TF_Graph* graph, TF_Status* s,
+ const char* name, TF_Operation** op) {
+ TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
+ TF_SetAttrType(desc, "dtype", TF_INT32);
+ *op = TF_FinishOperation(desc, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_NE(*op, nullptr);
+}
+
+TF_Operation* SegmentTest::Placeholder(TF_Graph* graph, TF_Status* s,
+ const char* name) {
+ TF_Operation* op;
+ PlaceholderHelper(graph, s, name, &op);
+ return op;
+}
+
+void SegmentTest::AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+ TF_Status* s, const char* name, TF_Operation** op,
+ bool check) {
+ TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
+ TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
+ TF_AddInputList(desc, add_inputs, 2);
+ *op = TF_FinishOperation(desc, s);
+ if (check) {
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_NE(*op, nullptr);
+ }
+}
+
+TF_Operation* SegmentTest::Add(TF_Operation* l, TF_Operation* r,
+ TF_Graph* graph, TF_Status* s,
+ const char* name) {
+ TF_Operation* op;
+ AddHelper(l, r, graph, s, name, &op, true);
+ return op;
+}
+
+TEST_F(SegmentTest, Empty) {
+ TF_Graph* graph = TF_NewGraph();
+
+ GraphDef graph_def;
+ ASSERT_TRUE(GetGraphDef(graph, &graph_def));
+
+ SegmentNodesVector segments;
+ ASSERT_EQ(
+ SegmentGraph(graph_def, MakeCandidateFn({}), default_options_, &segments),
+ tensorflow::Status::OK());
+
+ // Expect no segments/subgraphs.
+ EXPECT_TRUE(segments.empty());
+ TF_DeleteGraph(graph);
+}
+
+TEST_F(SegmentTest, Simple) {
+ TF_Status* s = TF_NewStatus();
+ TF_Graph* graph = TF_NewGraph();
+
+ // feed
+ // // ||
+ // add0 add1
+ // | | /
+ // | add2
+ // | / ||
+ // add3 add4
+ // | /
+ // <sink>
+ //
+ TF_Operation* feed = Placeholder(graph, s, "feed");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
+
+ TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
+ TF_Operation* add4 = Add(add2, add2, graph, s, "add4");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
+
+ GraphDef graph_def;
+ ASSERT_TRUE(GetGraphDef(graph, &graph_def));
+
+ SegmentNodesVector segments;
+ ASSERT_EQ(
+ SegmentGraph(graph_def,
+ MakeCandidateFn({"add0", "add1", "add2", "add3", "add4"}),
+ default_options_, &segments),
+ tensorflow::Status::OK());
+
+ // Expect all Add operations to be collapsed into a single segment
+ ASSERT_EQ(segments.size(), 1);
+ std::vector<string> expected{"add0", "add1", "add2", "add3", "add4"};
+ for (const auto& ex : expected) {
+ EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
+ << "Missing expected node " << ex;
+ }
+ TF_DeleteGraph(graph);
+ TF_DeleteStatus(s);
+}
+
+TEST_F(SegmentTest, AvoidCycle) {
+ TF_Status* s = TF_NewStatus();
+ TF_Graph* graph = TF_NewGraph();
+
+ // add2 is not a TRT candidate so add0/add3 cannot be formed as a
+ // subgraph
+ //
+ // feed
+ // // ||
+ // add0 add1
+ // | | /
+ // | add2
+ // | / ||
+ // add3 add4
+ // | /
+ // <sink>
+ //
+ TF_Operation* feed = Placeholder(graph, s, "feed");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
+
+ TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
+ TF_Operation* add4 = Add(add2, add2, graph, s, "add4");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
+
+ GraphDef graph_def;
+ ASSERT_TRUE(GetGraphDef(graph, &graph_def));
+
+ SegmentNodesVector segments;
+ ASSERT_EQ(
+ SegmentGraph(graph_def, MakeCandidateFn({"add0", "add1", "add3", "add4"}),
+ default_options_, &segments),
+ tensorflow::Status::OK());
+
+ // Expect no subgraphs
+ EXPECT_EQ(segments.size(), 0);
+ TF_DeleteGraph(graph);
+ TF_DeleteStatus(s);
+}
+
+TEST_F(SegmentTest, Multiple) {
+ TF_Status* s = TF_NewStatus();
+ TF_Graph* graph = TF_NewGraph();
+
+ // add5 is not a TRT candidate so two subgraphs should be formed
+ //
+ // feed
+ // // || ||
+ // add0 add1 add7
+ // | | / / ||
+ // | add2-----add5 add8
+ // | / | | | |
+ // add3 add4 add6
+ // | | /
+ // <sink>
+ //
+ TF_Operation* feed = Placeholder(graph, s, "feed");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
+
+ TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add7 = Add(feed, feed, graph, s, "add7");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add5 = Add(add2, add7, graph, s, "add5");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add8 = Add(add7, add7, graph, s, "add8");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
+ TF_Operation* add4 = Add(add2, add5, graph, s, "add4");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
+ TF_Operation* add6 = Add(add5, add8, graph, s, "add6");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add6"), string(TF_OperationName(add6)));
+
+ GraphDef graph_def;
+ ASSERT_TRUE(GetGraphDef(graph, &graph_def));
+
+ SegmentNodesVector segments;
+ ASSERT_EQ(SegmentGraph(graph_def,
+ MakeCandidateFn({"add0", "add1", "add2", "add3",
+ "add4", "add6", "add7", "add8"}),
+ default_options_, &segments),
+ tensorflow::Status::OK());
+
+ // Expect two subgraphs
+ EXPECT_EQ(segments.size(), 2);
+
+ std::vector<string> expected0{"add0", "add1", "add2", "add3"};
+ for (const auto& ex : expected0) {
+ EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
+ << "Missing expected node " << ex;
+ }
+
+ std::vector<string> expected1{"add6", "add8"};
+ for (const auto& ex : expected1) {
+ EXPECT_TRUE(segments[1].find(ex) != segments[1].end())
+ << "Missing expected node " << ex;
+ }
+ TF_DeleteGraph(graph);
+ TF_DeleteStatus(s);
+}
+
+TEST_F(SegmentTest, BigIfElse) {
+ TF_Status* s = TF_NewStatus();
+ TF_Graph* graph = TF_NewGraph();
+
+ // add2 is not a TRT candidate
+ //
+ // feed
+ // ||
+ // add0
+ // // ||
+ // add1 add4
+ // || ||
+ // add2 add5
+ // || ||
+ // add3 add6
+ // || //
+ // add7
+ // ||
+ // <sink>
+ //
+ TF_Operation* feed = Placeholder(graph, s, "feed");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
+
+ TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add1 = Add(add0, add0, graph, s, "add1");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add2 = Add(add1, add1, graph, s, "add2");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add3 = Add(add2, add2, graph, s, "add3");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add4 = Add(add0, add0, graph, s, "add4");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add5 = Add(add4, add4, graph, s, "add5");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add6 = Add(add5, add5, graph, s, "add6");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_Operation* add7 = Add(add3, add6, graph, s, "add7");
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ EXPECT_EQ(string("add7"), string(TF_OperationName(add7)));
+
+ GraphDef graph_def;
+ ASSERT_TRUE(GetGraphDef(graph, &graph_def));
+
+ SegmentNodesVector segments;
+ ASSERT_EQ(SegmentGraph(graph_def,
+ MakeCandidateFn({"add0", "add1", "add3", "add4",
+ "add5", "add6", "add7"}),
+ default_options_, &segments),
+ tensorflow::Status::OK());
+
+ // Expect 2 subgraphs
+ EXPECT_EQ(segments.size(), 2);
+
+ std::vector<string> expected0{"add3", "add4", "add5", "add6", "add7"};
+ for (const auto& ex : expected0) {
+ EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
+ << "Missing expected node " << ex;
+ }
+
+ std::vector<string> expected1{"add0", "add1"};
+ for (const auto& ex : expected1) {
+ EXPECT_TRUE(segments[1].find(ex) != segments[1].end())
+ << "Missing expected node " << ex;
+ }
+ TF_DeleteGraph(graph);
+ TF_DeleteStatus(s);
+}
+
+} // namespace test
+} // namespace segment
+} // namespace tensorrt
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/segment/union_find.h b/tensorflow/contrib/tensorrt/segment/union_find.h
new file mode 100644
index 0000000000..1c64ebbb0a
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/segment/union_find.h
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_
+
+namespace tensorflow {
+namespace tensorrt {
+namespace segment {
+
+// Union-Find data structure.
+// Each cluster has an associated value; when merging clusters we can control
+// which value becomes the representative of the merged clusters. Values must be
+// copyable.
+template <typename T>
+class UnionFind {
+ public:
+ UnionFind() : size_(1), parent_(nullptr) {}
+ explicit UnionFind(const T& v) : size_(1), parent_(nullptr), value_(v) {}
+
+ // Returns the number of elements in a cluster.
+ int Size() { return FindRoot()->size_; }
+
+ // Merges this cluster with 'other'. This cluster's value becomes
+ // the value of the merged cluster; the value of 'other' is ignored.
+ void Merge(UnionFind* other);
+
+ // Each cluster has an associated value. Retrieves the value associated
+ // with this cluster.
+ T& ParentValue() { return FindRoot()->value_; }
+
+ // Get the original value of this node.
+ T& Value() { return value_; }
+
+ private:
+ // Finds the root element of the cluster. Performs path compression.
+ UnionFind* FindRoot();
+
+ int size_;
+ UnionFind* parent_;
+ T value_;
+};
+
+template <typename T>
+void UnionFind<T>::Merge(UnionFind* other) {
+ UnionFind<T>* a = FindRoot();
+ UnionFind<T>* b = other->FindRoot();
+ if (a == b) return;
+
+ b->parent_ = a;
+ a->size_ += b->size_;
+}
+
+template <typename T>
+UnionFind<T>* UnionFind<T>::FindRoot() {
+ if (!parent_) return this;
+ // Path compression: update intermediate nodes to point to the root of the
+ // equivalence class.
+ parent_ = parent_->FindRoot();
+ return parent_;
+}
+
+} // namespace segment
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_SEGMENT_UNION_FIND_H_
diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
new file mode 100644
index 0000000000..8b475177bc
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
@@ -0,0 +1,89 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h"
+
+#include <string>
+#include <vector>
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorrt/include/NvInfer.h"
+
+namespace tensorflow {
+namespace shape_inference {
+
+tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) {
+ tensorflow::tensorrt::Logger logger;
+ string serialized_engine;
+ TF_RETURN_IF_ERROR(context->GetAttr("serialized_engine", &serialized_engine));
+ nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
+ nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine(
+ serialized_engine.c_str(), serialized_engine.size(), nullptr);
+
+ int num_batch = -1;
+ std::vector<::tensorflow::DataType> input_type;
+ TF_RETURN_IF_ERROR(context->GetAttr("InT", &input_type));
+ for (size_t i = 0; i < context->num_inputs(); i++) {
+ // Check if input shape is legit
+ auto input_shape = context->input(i);
+ for (int j = 0; j < context->Rank(input_shape); j++) {
+ auto dim_handler = context->Dim(input_shape, j);
+ if (j == 0) {
+ if (i == 0) {
+ num_batch = context->Value(dim_handler);
+ } else if (num_batch != context->Value(dim_handler)) {
+ // TODO(jie): TensorRT engine requires consistent batch between inputs
+ // tensors. Segmenter should be aware of this.
+ LOG(FATAL) << "TensorRT engine requires consistent batch size";
+ }
+ }
+ }
+ }
+
+ // Arrange input here
+ std::vector<string> input_nodes;
+ TF_RETURN_IF_ERROR(context->GetAttr("input_nodes", &input_nodes));
+
+ // Arrange output here
+ std::vector<string> output_nodes;
+ TF_RETURN_IF_ERROR(context->GetAttr("output_nodes", &output_nodes));
+ for (size_t i = 0; i < output_nodes.size(); i++) {
+ int binding_index = trt_engine->getBindingIndex(output_nodes[i].c_str());
+ ShapeHandle output_shape;
+ std::vector<DimensionHandle> dim_vec;
+ dim_vec.emplace_back(context->MakeDim(num_batch));
+ if (binding_index != -1) {
+ auto dims = trt_engine->getBindingDimensions(binding_index);
+ for (int j = 0; j < dims.nbDims; j++) {
+ dim_vec.emplace_back(context->MakeDim(dims.d[j]));
+ }
+ } else {
+ LOG(FATAL) << "TensorRT engine cannot find binding: " << output_nodes[i];
+ }
+ output_shape = context->MakeShape(dim_vec);
+ context->set_output(i, output_shape);
+ }
+
+ return Status::OK();
+}
+
+} // namespace shape_inference
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h
new file mode 100644
index 0000000000..4b50f66699
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h
@@ -0,0 +1,33 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace shape_inference {
+Status TRTEngineOpShapeInference(InferenceContext* c);
+} // namespace shape_inference
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_SHAPE_FN_TRT_SHFN_H_
diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py
new file mode 100644
index 0000000000..c78f6f2224
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py
@@ -0,0 +1,88 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script to test TF-TensorRT integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+# normally we should do import tensorflow as tf and then
+# tf.placeholder, tf.constant, tf.nn.conv2d etc but
+# it looks like internal builds don't like it so
+# importing every module individually
+
+from tensorflow.contrib import tensorrt as trt
+from tensorflow.core.protobuf import config_pb2 as cpb2
+from tensorflow.python.client import session as csess
+from tensorflow.python.framework import constant_op as cop
+from tensorflow.python.framework import dtypes as dtypes
+from tensorflow.python.framework import importer as importer
+from tensorflow.python.framework import ops as ops
+from tensorflow.python.ops import array_ops as aops
+from tensorflow.python.ops import nn as nn
+from tensorflow.python.ops import nn_ops as nn_ops
+
+
+def get_simple_graph_def():
+ """Create a simple graph and return its graph_def."""
+ g = ops.Graph()
+ with g.as_default():
+ a = aops.placeholder(
+ dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input")
+ e = cop.constant(
+ [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
+ name="weights",
+ dtype=dtypes.float32)
+ conv = nn.conv2d(
+ input=a, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv")
+ b = cop.constant(
+ [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtypes.float32)
+ t = nn.bias_add(conv, b, name="biasAdd")
+ relu = nn.relu(t, "relu")
+ idty = aops.identity(relu, "ID")
+ v = nn_ops.max_pool(
+ idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
+ aops.squeeze(v, name="output")
+ return g.as_graph_def()
+
+
+def run_graph(gdef, dumm_inp):
+ gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
+ ops.reset_default_graph()
+ g = ops.Graph()
+ with g.as_default():
+ inp, out = importer.import_graph_def(
+ graph_def=gdef, return_elements=["input", "output"])
+ inp = inp.outputs[0]
+ out = out.outputs[0]
+ with csess.Session(
+ config=cpb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess:
+ val = sess.run(out, {inp: dumm_inp})
+ return val
+
+
+if "__main__" in __name__:
+ inp_dims = (100, 24, 24, 2)
+ dummy_input = np.random.random_sample(inp_dims)
+ gdef = get_simple_graph_def()
+ # Get optimized graph
+ trt_graph = trt.create_inference_graph(gdef, ["output"], inp_dims[0])
+ o1 = run_graph(gdef, dummy_input)
+ o2 = run_graph(trt_graph, dummy_input)
+ o3 = run_graph(trt_graph, dummy_input)
+ assert np.array_equal(o1, o2)
+ assert np.array_equal(o3, o2) # sanity check
+ print("Pass")
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i
new file mode 100644
index 0000000000..d679945d56
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/trt_conversion.i
@@ -0,0 +1,131 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/* Wrap trt_conversion */
+%{
+#define SWIG_FILE_WITH_INIT
+%}
+%include "std_pair.i"
+%include "tensorflow/python/platform/base.i"
+
+%{
+PyObject* pair_helper(std::pair<string, string>* in) {
+ PyObject *first(nullptr), *second(nullptr), *tuple(nullptr);
+ first = PyBytes_FromStringAndSize(in->first.data(), in->first.length());
+ if (!first) {
+ if (!PyErr_Occurred()) {
+ PyErr_SetString(PyExc_TypeError, "Pair conversion first argument failed");
+ }
+ return NULL;
+ }
+ second = PyBytes_FromStringAndSize(in->second.data(), in->second.length());
+ if (!second) {
+ if (!PyErr_Occurred()) {
+ PyErr_SetString(PyExc_TypeError,
+ "Pair conversion second argument failed");
+ }
+ return NULL;
+ }
+ tuple = Py_BuildValue("(OO)", first, second);
+ if (!tuple) {
+ if (!PyErr_Occurred()) {
+ PyErr_SetString(PyExc_TypeError,
+ "Tuple creation from pair<string,string> failed!");
+ }
+ return NULL;
+ }
+ return tuple;
+}
+%}
+%typemap(out) std::pair<string, string> {
+ PyObject *tuple = pair_helper(&$1);
+ if (!tuple) SWIG_fail;
+ $result = tuple;
+}
+%{
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/stat_summarizer.h"
+#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
+%}
+
+%ignoreall
+%unignore tensorflow;
+%unignore trt_convert;
+
+%{
+std::pair<string, string> trt_convert(
+ string graph_def_string, // The serialized GraphDef string.
+ std::vector<string> output_names,
+ size_t max_batch_size,
+ size_t max_workspace_size_bytes
+ // Unfortunately we can't use TF_Status here since it
+ // is in c/c_api and brings in a lot of other libraries
+ // which in turn declare ops. These ops are included
+ // statically in our library and cause an abort when
+ // module is loaded due to double registration
+ // until Tensorflow properly exposes these headers
+ // we have to work around this by returning a string
+ // and converting it to exception on python side.
+ //,TF_Status* out_status) {
+) {
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+ string out_status;
+
+ tensorflow::GraphDef graph_def;
+ if (!graph_def.ParseFromString(graph_def_string)) {
+ out_status = "InvalidArgument;Couldn't interpret input as a GraphDef";
+ return std::pair<string, string>{out_status, ""};
+ }
+
+ if (!output_names.size()) {
+ out_status = "InvalidArgument;Size of the output_names vector is 0";
+ return std::pair<string, string>{out_status, ""};
+ // return "";
+ }
+ tensorflow::GraphDef outGraph;
+ tensorflow::Status conversion_status =
+ tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT(
+ graph_def, output_names, max_batch_size, max_workspace_size_bytes,
+ &outGraph);
+ if (!conversion_status.ok()) {
+ auto retCode = (int)conversion_status.code();
+ char buff[2000];
+ snprintf(buff, 2000, "%d;%s", retCode,
+ conversion_status.error_message().c_str());
+ out_status = buff;
+ return std::pair<string, string>{out_status, ""};
+ }
+ string result;
+ if (!outGraph.SerializeToString(&result)) {
+ out_status = "InvalidArgument;Couldn't serialize output as a GraphDef";
+ return std::pair<string, string>{out_status, ""};
+ }
+ out_status = "OK;All good!";
+ return std::pair<string, string>{out_status, result};
+#else
+ // Returns FAILED_PRECONDITION.
+ return std::pair<string, string>{"9;TensorRT is not enabled!", ""};
+#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
+}
+%}
+
+std::pair<string, string> trt_convert(string graph_def_string,
+ std::vector<string> output_names,
+ size_t max_batch_size,
+ size_t max_workspace_size_bytes);
+
+
+%unignoreall
diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
index 76f1dd2a56..8d99835b64 100644
--- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py
+++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
@@ -20,7 +20,7 @@ from __future__ import print_function
from setuptools import setup
-_VERSION = '1.6.0-rc0'
+_VERSION = '1.6.0-rc1'
CONSOLE_SCRIPTS = [
'capture_tpu_profile=cloud_tpu_profiler.main:run_main',
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id.h b/tensorflow/core/common_runtime/gpu/gpu_id.h
index 4e9c4abce1..2a6caea296 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id.h
@@ -40,7 +40,7 @@ namespace tensorflow {
// a BaseGPUDevice. Note that the configuration allows us to create multiple
// BaseGPUDevice per GPU hardware in order to use multi CUDA streams on the
// hardware, so the mapping between TF GPU id and CUDA GPU id is not a 1:1
-// mappping, see the example below.
+// mapping, see the example below.
//
// For example, assuming that in the machine we have GPU device with index 0, 1,
// 2 and 3 (physical GPU id). Setting "CUDA_VISIBLE_DEVICES=1,2,3" will create
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
index 77eeb56b19..fb092424bf 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
@@ -21,7 +21,6 @@ limitations under the License.
#ifdef INTEL_MKL
-#include <unistd.h>
#include <cstdlib>
#include <string>
#include "tensorflow/core/common_runtime/bfc_allocator.h"
diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h
index adb41b81c6..fe2ba375aa 100644
--- a/tensorflow/core/framework/tensor_shape.h
+++ b/tensorflow/core/framework/tensor_shape.h
@@ -191,9 +191,6 @@ class TensorShapeBase : public TensorShapeRep {
/// Appends all the dimensions from `shape`.
void AppendShape(const TensorShapeBase& shape);
- // Maximum number of dimensions in a tensor.
- static constexpr int MaxDimensions() { return 254; }
-
/// \brief Insert a dimension somewhere in the `TensorShape`.
/// REQUIRES: `0 <= d <= dims()`
/// REQUIRES: `size >= 0`
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index 5343e6802d..e9ced4d2b6 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -222,7 +222,7 @@ Status MklToTfConversionPass::InsertInputConversionNode(
BaseType(n->input_type(0)));
// Check ordering of edges
- for (uint i = 0; i < 4; i++) {
+ for (uint32 i = 0; i < 4; i++) {
CHECK_EQ((edges[i]->dst_input() == i), true);
}
diff --git a/tensorflow/core/kernels/colorspace_op.cc b/tensorflow/core/kernels/colorspace_op.cc
index 9cc2e67bbe..f4402a245d 100644
--- a/tensorflow/core/kernels/colorspace_op.cc
+++ b/tensorflow/core/kernels/colorspace_op.cc
@@ -71,7 +71,7 @@ class RGBToHSVOp : public OpKernel {
TensorShape({input_data.dimension(0)}),
&trange));
- typename TTypes<T, 1>::Tensor range = trange.tensor<T, 1>();
+ typename TTypes<T, 1>::Tensor range(trange.tensor<T, 1>());
functor::RGBToHSV<Device, T>()(context->eigen_device<Device>(), input_data,
range, output_data);
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 2b3b7184dc..94989089ec 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -24,12 +24,12 @@ limitations under the License.
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/tensor_format.h"
-#if !defined(_MSC_VER)
-#define UNROLL _Pragma("unroll")
-#define NOUNROLL _Pragma("nounroll")
-#else
+#if defined(_MSC_VER) && !defined(__clang__)
#define UNROLL
#define NOUNROLL
+#else
+#define UNROLL _Pragma("unroll")
+#define NOUNROLL _Pragma("nounroll")
#endif
namespace tensorflow {
diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
index d9713075be..723b445a75 100644
--- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
@@ -29,7 +29,6 @@ limitations under the License.
#include <vector>
#include "mkl_cblas.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -41,9 +40,6 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
-#define MKL_Complex8 tensorflow::complex64
-#define MKL_Complex16 tensorflow::complex128
-
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -180,16 +176,16 @@ class BatchMatMulMkl : public OpKernel {
void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
const bool TransB, const MKL_INT *M_Array,
const MKL_INT *N_Array, const MKL_INT *K_Array,
- const MKL_Complex8 **A_Array, const MKL_INT *lda_Array,
- const MKL_Complex8 **B_Array, const MKL_INT *ldb_Array,
- MKL_Complex8 **C_Array, const MKL_INT *ldc_Array,
+ const complex64 **A_Array, const MKL_INT *lda_Array,
+ const complex64 **B_Array, const MKL_INT *ldb_Array,
+ complex64 **C_Array, const MKL_INT *ldc_Array,
const MKL_INT group_count, const MKL_INT *group_size) {
std::vector<CBLAS_TRANSPOSE> TransA_array(
group_size[0], TransA ? CblasConjTrans : CblasNoTrans);
std::vector<CBLAS_TRANSPOSE> TransB_array(
group_size[0], TransB ? CblasConjTrans : CblasNoTrans);
- std::vector<MKL_Complex8> alpha_Array(group_size[0], {1.0f, 0.0f});
- std::vector<MKL_Complex8> beta_Array(group_size[0], {0.0f, 0.0f});
+ std::vector<complex64> alpha_Array(group_size[0], {1.0f, 0.0f});
+ std::vector<complex64> beta_Array(group_size[0], {0.0f, 0.0f});
cblas_cgemm_batch(
Layout, &TransA_array[0], &TransB_array[0], M_Array, N_Array, K_Array,
static_cast<const void *>(&alpha_Array[0]),
@@ -202,18 +198,16 @@ class BatchMatMulMkl : public OpKernel {
void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
const bool TransB, const MKL_INT *M_Array,
const MKL_INT *N_Array, const MKL_INT *K_Array,
- const MKL_Complex16 **A_Array,
- const MKL_INT *lda_Array,
- const MKL_Complex16 **B_Array,
- const MKL_INT *ldb_Array, MKL_Complex16 **C_Array,
- const MKL_INT *ldc_Array, const MKL_INT group_count,
- const MKL_INT *group_size) {
+ const complex128 **A_Array, const MKL_INT *lda_Array,
+ const complex128 **B_Array, const MKL_INT *ldb_Array,
+ complex128 **C_Array, const MKL_INT *ldc_Array,
+ const MKL_INT group_count, const MKL_INT *group_size) {
std::vector<CBLAS_TRANSPOSE> TransA_array(
group_size[0], TransA ? CblasConjTrans : CblasNoTrans);
std::vector<CBLAS_TRANSPOSE> TransB_array(
group_size[0], TransB ? CblasConjTrans : CblasNoTrans);
- std::vector<MKL_Complex16> alpha_Array(group_size[0], {1.0f, 0.0f});
- std::vector<MKL_Complex16> beta_Array(group_size[0], {0.0f, 0.0f});
+ std::vector<complex128> alpha_Array(group_size[0], {1.0f, 0.0f});
+ std::vector<complex128> beta_Array(group_size[0], {0.0f, 0.0f});
cblas_zgemm_batch(
Layout, &TransA_array[0], &TransB_array[0], M_Array, N_Array, K_Array,
static_cast<const void *>(&alpha_Array[0]),
diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc
index 5a8799ae93..e9a2376b54 100644
--- a/tensorflow/core/kernels/mkl_input_conversion_op.cc
+++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc
@@ -145,8 +145,8 @@ class MklInputConversionOp : public OpKernel {
const MklShape* mkl_shape;
const Tensor* tf_tensor;
MklShape* tf_mkl_shape;
- uint mkl_tensor_index;
- uint tf_tensor_index;
+ uint32 mkl_tensor_index;
+ uint32 tf_tensor_index;
if (input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
mkl_tensor = &input_tensor_0;
mkl_shape = &input_shape_0;
diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc
index 47598f443f..dfa6cecc9b 100644
--- a/tensorflow/core/kernels/mkl_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_matmul_op.cc
@@ -170,32 +170,32 @@ class MklMatMulOp : public OpKernel {
// Matrix-Matrix Multiplication with Complex64 (std::complex<float>) tensors.
// For detailed info about parameters, look at FP32 function description.
void MklBlasGemm(bool transa, bool transb, const int m, const int n,
- const int k, const std::complex<float>* a, const int lda,
- const std::complex<float>* b, const int ldb,
- std::complex<float>* c, int const ldc) {
+ const int k, const complex64* a, const int lda,
+ const complex64* b, const int ldb, complex64* c,
+ int const ldc) {
const MKL_Complex8 alpha = {1.0f, 0.0f};
const MKL_Complex8 beta = {0.0f, 0.0f};
cblas_cgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
- transb ? CblasTrans : CblasNoTrans, m, n, k,
- static_cast<const void*>(&alpha), static_cast<const void*>(a),
- lda, static_cast<const void*>(b), ldb,
- static_cast<const void*>(&beta), static_cast<void*>(c), ldc);
+ transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha,
+ reinterpret_cast<const MKL_Complex8*>(a), lda,
+ reinterpret_cast<const MKL_Complex8*>(b), ldb, &beta,
+ reinterpret_cast<MKL_Complex8*>(c), ldc);
}
// Matrix-Matrix Multiplication with Complex128 (std::complex<double>)
// tensors. For detailed info about parameters, look at FP32 function
// description.
void MklBlasGemm(bool transa, bool transb, const int m, const int n,
- const int k, const std::complex<double>* a, const int lda,
- const std::complex<double>* b, const int ldb,
- std::complex<double>* c, const int ldc) {
+ const int k, const complex128* a, const int lda,
+ const complex128* b, const int ldb, complex128* c,
+ const int ldc) {
const MKL_Complex16 alpha = {1.0, 0.0};
const MKL_Complex16 beta = {0.0, 0.0};
cblas_zgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
- transb ? CblasTrans : CblasNoTrans, m, n, k,
- static_cast<const void*>(&alpha), static_cast<const void*>(a),
- lda, static_cast<const void*>(b), ldb,
- static_cast<const void*>(&beta), static_cast<void*>(c), ldc);
+ transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha,
+ reinterpret_cast<const MKL_Complex16*>(a), lda,
+ reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
+ reinterpret_cast<MKL_Complex16*>(c), ldc);
}
};
diff --git a/tensorflow/core/kernels/mkl_tfconv_op.h b/tensorflow/core/kernels/mkl_tfconv_op.h
index 5fafa14b5d..ddea9e281b 100644
--- a/tensorflow/core/kernels/mkl_tfconv_op.h
+++ b/tensorflow/core/kernels/mkl_tfconv_op.h
@@ -128,7 +128,7 @@ class MklToTfOp : public OpKernel {
#else
static void ConvertMklToTf(OpKernel* op_kernel, OpKernelContext* context,
string data_format_str, DataType op_data_type,
- bool has_avx512f, uint input_number) {
+ bool has_avx512f, uint32 input_number) {
// Check that input tensor is in MKL format.
const Tensor& input_tensor = MklGetInput(context, input_number);
MklShape input_shape;
diff --git a/tensorflow/core/kernels/mkl_transpose_op.cc b/tensorflow/core/kernels/mkl_transpose_op.cc
index 764d4c9400..3f07b317c4 100644
--- a/tensorflow/core/kernels/mkl_transpose_op.cc
+++ b/tensorflow/core/kernels/mkl_transpose_op.cc
@@ -18,9 +18,6 @@ limitations under the License.
#ifdef INTEL_MKL
#define EIGEN_USE_THREADS
-#include "tensorflow/core/framework/numeric_types.h"
-#define MKL_Complex8 tensorflow::complex64
-#define MKL_Complex16 tensorflow::complex128
#include "mkl_trans.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/kernels/transpose_op.h"
@@ -62,10 +59,37 @@ Status MKLTranspose2D(const char trans, const Tensor& in, Tensor* out);
INSTANTIATE(float, s)
INSTANTIATE(double, d)
-INSTANTIATE(complex64, c)
-INSTANTIATE(complex128, z)
+
#undef INSTANTIATE
+template <>
+Status MKLTranspose2D<complex64>(const char trans, const Tensor& in,
+ Tensor* out) {
+ const MKL_Complex8 alpha = {1.0f, 0.0f};
+ mkl_comatcopy(
+ 'R', trans, in.dim_size(0), in.dim_size(1), alpha,
+ reinterpret_cast<const MKL_Complex8*>(in.flat<complex64>().data()),
+ in.dim_size(1),
+ reinterpret_cast<MKL_Complex8*>(
+ const_cast<complex64*>(out->flat<complex64>().data())),
+ in.dim_size(0));
+ return Status::OK();
+}
+
+template <>
+Status MKLTranspose2D<complex128>(const char trans, const Tensor& in,
+ Tensor* out) {
+ const MKL_Complex16 alpha = {1.0, 0.0};
+ mkl_zomatcopy(
+ 'R', trans, in.dim_size(0), in.dim_size(1), alpha,
+ reinterpret_cast<const MKL_Complex16*>(in.flat<complex128>().data()),
+ in.dim_size(1),
+ reinterpret_cast<MKL_Complex16*>(
+ const_cast<complex128*>(out->flat<complex128>().data())),
+ in.dim_size(0));
+ return Status::OK();
+}
+
static const char kMKLTranspose = 'T';
static const char kMKLConjugateTranspose = 'C';
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index 5d28b87e6b..903b898d0a 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -105,7 +105,7 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes,
}
const int output_size = std::min(max_output_size.scalar<int>()(), num_boxes);
- typename TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>();
+ TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>();
std::vector<float> scores_data(num_boxes);
std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin());
@@ -138,8 +138,7 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes,
Tensor* output = nullptr;
TensorShape output_shape({static_cast<int>(selected.size())});
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
- typename TTypes<int, 1>::Tensor selected_indices_data =
- output->tensor<int, 1>();
+ TTypes<int, 1>::Tensor selected_indices_data = output->tensor<int, 1>();
std::copy_n(selected.begin(), selected.size(), selected_indices_data.data());
}
diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc
index ddfeb1bb79..661d47d925 100644
--- a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc
@@ -29,7 +29,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
-#ifdef COMPILER_MSVC
+#if defined(_MSC_VER) && !defined(__clang__)
// msvc does not support unroll. One could try the loop pragma but we need to
// take a closer look if this generates better code in this case. For now let
// the compiler take care of it.
diff --git a/tensorflow/core/kernels/quantized_resize_bilinear_op.cc b/tensorflow/core/kernels/quantized_resize_bilinear_op.cc
index fb2faede2f..9a1dcd0d49 100644
--- a/tensorflow/core/kernels/quantized_resize_bilinear_op.cc
+++ b/tensorflow/core/kernels/quantized_resize_bilinear_op.cc
@@ -697,8 +697,8 @@ class QuantizedResizeBilinearOp : public OpKernel {
// Return if the output is empty.
if (st.output->NumElements() == 0) return;
- typename TTypes<T, 4>::ConstTensor image_data = input.tensor<T, 4>();
- typename TTypes<T, 4>::Tensor output_data = st.output->tensor<T, 4>();
+ typename TTypes<T, 4>::ConstTensor image_data(input.tensor<T, 4>());
+ typename TTypes<T, 4>::Tensor output_data(st.output->tensor<T, 4>());
ResizeBilinear<T>(image_data, st.height_scale, st.width_scale, in_min,
in_max, &output_data);
diff --git a/tensorflow/core/kernels/random_crop_op.cc b/tensorflow/core/kernels/random_crop_op.cc
index 554909760a..b89bda4769 100644
--- a/tensorflow/core/kernels/random_crop_op.cc
+++ b/tensorflow/core/kernels/random_crop_op.cc
@@ -92,8 +92,8 @@ class RandomCropOp : public OpKernel {
// TODO(shlens): Do this more efficiently with memcpy once padding is
// available for smaller images.
- typename TTypes<T, 3>::ConstTensor input_data = input.tensor<T, 3>();
- typename TTypes<T, 3>::Tensor output_data = output->tensor<T, 3>();
+ typename TTypes<T, 3>::ConstTensor input_data(input.tensor<T, 3>());
+ typename TTypes<T, 3>::Tensor output_data(output->tensor<T, 3>());
for (int y = 0; y < target_height; ++y) {
for (int x = 0; x < target_width; ++x) {
diff --git a/tensorflow/core/kernels/resize_area_op.cc b/tensorflow/core/kernels/resize_area_op.cc
index ada50dfb70..98b8a0df28 100644
--- a/tensorflow/core/kernels/resize_area_op.cc
+++ b/tensorflow/core/kernels/resize_area_op.cc
@@ -149,7 +149,7 @@ class ResizeAreaOp : public OpKernel {
if (!context->status().ok()) return;
- typename TTypes<T, 4>::ConstTensor input_data = input.tensor<T, 4>();
+ typename TTypes<T, 4>::ConstTensor input_data(input.tensor<T, 4>());
// Precompute values used when iterating over x coordinates within a row.
// Note that it may be useful to cache x_interps for a given
@@ -190,8 +190,7 @@ class ResizeAreaOp : public OpKernel {
void ComputeLoop(const ImageResizerState& st,
const std::vector<CachedInterpolation>& x_interps,
typename TTypes<T, 4>::ConstTensor input_data) {
- typename TTypes<float, 4>::Tensor output_data =
- st.output->tensor<float, 4>();
+ TTypes<float, 4>::Tensor output_data = st.output->tensor<float, 4>();
// When using this algorithm for downsizing, the target pixel value is the
// weighted average of all the source pixels. The weight is determined by
diff --git a/tensorflow/core/kernels/resize_bicubic_op.cc b/tensorflow/core/kernels/resize_bicubic_op.cc
index 86e61bbcef..65014b6c44 100644
--- a/tensorflow/core/kernels/resize_bicubic_op.cc
+++ b/tensorflow/core/kernels/resize_bicubic_op.cc
@@ -480,9 +480,8 @@ class ResizeBicubicOp : public OpKernel {
if (!context->status().ok()) return;
- typename TTypes<T, 4>::ConstTensor input_data = input.tensor<T, 4>();
- typename TTypes<float, 4>::Tensor output_data =
- st.output->tensor<float, 4>();
+ typename TTypes<T, 4>::ConstTensor input_data(input.tensor<T, 4>());
+ TTypes<float, 4>::Tensor output_data = st.output->tensor<float, 4>();
interpolate_with_caching<T>(input_data, st, output_data);
}
@@ -510,9 +509,8 @@ class ResizeBicubicOpGrad : public OpKernel {
if (!context->status().ok()) return;
- typename TTypes<float, 4>::ConstTensor input_grad =
- input.tensor<float, 4>();
- typename TTypes<T, 4>::Tensor output_grad = st.output->tensor<T, 4>();
+ TTypes<float, 4>::ConstTensor input_grad = input.tensor<float, 4>();
+ typename TTypes<T, 4>::Tensor output_grad(st.output->tensor<T, 4>());
ResizeBicubicGrad<T>(input_grad, st, output_grad);
}
diff --git a/tensorflow/core/kernels/resize_bilinear_op.cc b/tensorflow/core/kernels/resize_bilinear_op.cc
index d9cb993a4b..dde59e8e74 100644
--- a/tensorflow/core/kernels/resize_bilinear_op.cc
+++ b/tensorflow/core/kernels/resize_bilinear_op.cc
@@ -51,9 +51,8 @@ class ResizeBilinearOp : public OpKernel {
// Return if the output is empty.
if (st.output->NumElements() == 0) return;
- typename TTypes<T, 4>::ConstTensor image_data = input.tensor<T, 4>();
- typename TTypes<float, 4>::Tensor output_data =
- st.output->tensor<float, 4>();
+ typename TTypes<T, 4>::ConstTensor image_data(input.tensor<T, 4>());
+ TTypes<float, 4>::Tensor output_data = st.output->tensor<float, 4>();
functor::ResizeBilinear<Device, T>()(context->eigen_device<Device>(),
image_data, st.height_scale,
@@ -258,9 +257,8 @@ class ResizeBilinearOpGrad : public OpKernel {
if (!context->status().ok()) return;
- typename TTypes<float, 4>::ConstTensor input_grad =
- input.tensor<float, 4>();
- typename TTypes<T, 4>::Tensor output_grad = st.output->tensor<T, 4>();
+ TTypes<float, 4>::ConstTensor input_grad = input.tensor<float, 4>();
+ typename TTypes<T, 4>::Tensor output_grad(st.output->tensor<T, 4>());
functor::ResizeBilinearGrad<Device, T>()(context->eigen_device<Device>(),
input_grad, st.height_scale,
diff --git a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
index bfd29b7ec8..8ec526c2b2 100644
--- a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
+++ b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
@@ -56,8 +56,8 @@ class ResizeNearestNeighborOp : public OpKernel {
// Return if the output is empty.
if (st.output->NumElements() == 0) return;
- typename TTypes<T, 4>::ConstTensor input_data = input.tensor<T, 4>();
- typename TTypes<T, 4>::Tensor output_data = st.output->tensor<T, 4>();
+ typename TTypes<T, 4>::ConstTensor input_data(input.tensor<T, 4>());
+ typename TTypes<T, 4>::Tensor output_data(st.output->tensor<T, 4>());
bool status;
if (align_corners_) {
@@ -162,8 +162,8 @@ class ResizeNearestNeighborOpGrad : public OpKernel {
// Return if the output is empty.
if (output->NumElements() == 0) return;
- typename TTypes<T, 4>::ConstTensor input_data = input.tensor<T, 4>();
- typename TTypes<T, 4>::Tensor output_data = output->tensor<T, 4>();
+ typename TTypes<T, 4>::ConstTensor input_data(input.tensor<T, 4>());
+ typename TTypes<T, 4>::Tensor output_data(output->tensor<T, 4>());
const float height_scale =
CalculateResizeScale(out_height, in_height, align_corners_);
diff --git a/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc b/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc
index 44a817a5c7..c0fde8042e 100644
--- a/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc
+++ b/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc
@@ -387,9 +387,9 @@ class SampleDistortedBoundingBoxV2Op : public OpKernel {
OP_REQUIRES_OK(
context, context->allocate_output(2, TensorShape({1, 1, 4}), &bboxes));
- typename TTypes<T, 1>::Tensor begin_data = begin->tensor<T, 1>();
- typename TTypes<T, 1>::Tensor size_data = size->tensor<T, 1>();
- typename TTypes<float, 3>::Tensor bboxes_data = bboxes->tensor<float, 3>();
+ typename TTypes<T, 1>::Tensor begin_data(begin->tensor<T, 1>());
+ typename TTypes<T, 1>::Tensor size_data(size->tensor<T, 1>());
+ TTypes<float, 3>::Tensor bboxes_data = bboxes->tensor<float, 3>();
begin_data(0) = T(offset_height);
size_data(0) = T(target_height);
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index 79369fd4a9..77594479cb 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -358,11 +358,11 @@ class MklSliceOp : public OpKernel {
/* data format = NCHW */
#pragma omp parallel for
- for (size_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
+ for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
T* ip = in_buf + (d0 * in_strides[0]);
T* op = op_buf + ((d0 - begin[0]) * out_strides[0]);
#pragma omp parallel for
- for (size_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
+ for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
T* ip1 = ip + (d1 * in_strides[1]);
T* op1 = op + ((d1 - begin[1]) * out_strides[1]);
// For NCHW, H and W will be contiguous. So we can copy
@@ -376,15 +376,15 @@ class MklSliceOp : public OpKernel {
/* data_format = NHWC */
#pragma omp parallel for
- for (size_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
+ for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
T* ip = in_buf + (d0 * in_strides[0]);
T* op = op_buf + ((d0 - begin[0]) * out_strides[0]);
#pragma omp parallel for
- for (size_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
+ for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
T* ip1 = ip + (d1 * in_strides[1]);
T* op1 = op + ((d1 - begin[1]) * out_strides[1]);
#pragma omp parallel for
- for (size_t d2 = begin[2]; d2 < begin[2] + size[2]; d2++) {
+ for (ssize_t d2 = begin[2]; d2 < begin[2] + size[2]; d2++) {
T* ip2 = ip1 + (d2 * in_strides[2]);
T* ip3 = ip2 + begin[3];
T* op2 = op1 + ((d2 - begin[2]) * out_strides[2]);
diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc
index e29f67297f..22e45918a0 100644
--- a/tensorflow/core/kernels/substr_op.cc
+++ b/tensorflow/core/kernels/substr_op.cc
@@ -115,7 +115,7 @@ class SubstrOp : public OpKernel {
Tensor input_buffer;
OP_REQUIRES_OK(context, context->allocate_temp(
DT_STRING, output_shape, &input_buffer));
- typename TTypes<string, 1>::Tensor input_bcast =
+ TTypes<string, 1>::Tensor input_bcast =
input_buffer.shaped<string, 1>(bcast.result_shape());
input_bcast =
input.broadcast(BCast::ToIndexArray<1>(bcast.x_bcast()));
@@ -125,8 +125,8 @@ class SubstrOp : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_temp(DataTypeToEnum<T>::v(),
output_shape, &pos_buffer));
- typename TTypes<T, 1>::Tensor pos_bcast =
- pos_buffer.shaped<T, 1>(bcast.result_shape());
+ typename TTypes<T, 1>::Tensor pos_bcast(
+ pos_buffer.shaped<T, 1>(bcast.result_shape()));
pos_bcast =
pos_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast()));
@@ -135,8 +135,8 @@ class SubstrOp : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_temp(DataTypeToEnum<T>::v(),
output_shape, &len_buffer));
- typename TTypes<T, 1>::Tensor len_bcast =
- len_buffer.shaped<T, 1>(bcast.result_shape());
+ typename TTypes<T, 1>::Tensor len_bcast(
+ len_buffer.shaped<T, 1>(bcast.result_shape()));
len_bcast =
len_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast()));
@@ -164,7 +164,7 @@ class SubstrOp : public OpKernel {
Tensor input_buffer;
OP_REQUIRES_OK(context, context->allocate_temp(
DT_STRING, output_shape, &input_buffer));
- typename TTypes<string, 2>::Tensor input_bcast =
+ TTypes<string, 2>::Tensor input_bcast =
input_buffer.shaped<string, 2>(bcast.result_shape());
input_bcast =
input.broadcast(BCast::ToIndexArray<2>(bcast.x_bcast()));
@@ -174,8 +174,8 @@ class SubstrOp : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_temp(DataTypeToEnum<T>::v(),
output_shape, &pos_buffer));
- typename TTypes<T, 2>::Tensor pos_bcast =
- pos_buffer.shaped<T, 2>(bcast.result_shape());
+ typename TTypes<T, 2>::Tensor pos_bcast(
+ pos_buffer.shaped<T, 2>(bcast.result_shape()));
pos_bcast =
pos_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast()));
@@ -184,8 +184,8 @@ class SubstrOp : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_temp(DataTypeToEnum<T>::v(),
output_shape, &len_buffer));
- typename TTypes<T, 2>::Tensor len_bcast =
- len_buffer.shaped<T, 2>(bcast.result_shape());
+ typename TTypes<T, 2>::Tensor len_bcast(
+ len_buffer.shaped<T, 2>(bcast.result_shape()));
len_bcast =
len_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast()));
diff --git a/tensorflow/core/kernels/xsmm_conv2d.cc b/tensorflow/core/kernels/xsmm_conv2d.cc
index 601704c8a7..ba03357cc6 100644
--- a/tensorflow/core/kernels/xsmm_conv2d.cc
+++ b/tensorflow/core/kernels/xsmm_conv2d.cc
@@ -27,9 +27,6 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty();
#include <stdlib.h>
#include <cstring>
-#if 0
-#include <omp.h>
-#endif
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
@@ -360,7 +357,6 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
l_tick6 = libxsmm_timer_tick();
#endif
-#if 1
BlockingCounter counter(num_threads);
for (int i = 0; i < num_threads; ++i) {
@@ -371,14 +367,6 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
});
}
counter.Wait();
-#else
-#pragma omp parallel
- {
- chk_libxsmm_err(
- libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, omp_get_thread_num()),
- "Worker");
- }
-#endif
#if defined(LIBXSMM_DETAILED_TIMING)
l_tick7 = libxsmm_timer_tick();
diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc
index 3657243c5d..ebc5648269 100644
--- a/tensorflow/core/lib/io/record_writer.cc
+++ b/tensorflow/core/lib/io/record_writer.cc
@@ -49,7 +49,7 @@ RecordWriterOptions RecordWriterOptions::CreateRecordWriterOptions(
#endif // IS_SLIM_BUILD
} else if (compression_type != compression::kNone) {
LOG(ERROR) << "Unsupported compression_type:" << compression_type
- << ". No comprression will be used.";
+ << ". No compression will be used.";
}
return options;
}
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 4c05b274fe..c3b08e067a 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -619,6 +619,10 @@ REGISTER_OP("NonMaxSuppression")
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
// The boxes is a 2-D float Tensor of shape [num_boxes, 4].
DimensionHandle unused;
+ // The boxes[0] and scores[0] are both num_boxes.
+ TF_RETURN_IF_ERROR(
+ c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
+ // The boxes[1] is 4.
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
c->set_output(0, c->Vector(c->UnknownDim()));
@@ -643,6 +647,10 @@ REGISTER_OP("NonMaxSuppressionV2")
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
// The boxes is a 2-D float Tensor of shape [num_boxes, 4].
DimensionHandle unused;
+ // The boxes[0] and scores[0] are both num_boxes.
+ TF_RETURN_IF_ERROR(
+ c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
+ // The boxes[1] is 4.
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
c->set_output(0, c->Vector(c->UnknownDim()));
diff --git a/tensorflow/core/platform/platform.h b/tensorflow/core/platform/platform.h
index 12120c4ab9..0481b36871 100644
--- a/tensorflow/core/platform/platform.h
+++ b/tensorflow/core/platform/platform.h
@@ -43,10 +43,11 @@ limitations under the License.
#elif defined(__arm__)
#define PLATFORM_POSIX
-// Require an outside macro to tell us if we're building for Raspberry Pi.
-#if !defined(RASPBERRY_PI)
+// Require an outside macro to tell us if we're building for Raspberry Pi or
+// another ARM device that's not a mobile platform.
+#if !defined(RASPBERRY_PI) && !defined(ARM_NON_MOBILE)
#define IS_MOBILE_PLATFORM
-#endif // !defined(RASPBERRY_PI)
+#endif // !defined(RASPBERRY_PI) && !defined(ARM_NON_MOBILE)
#else
// If no platform specified, use:
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index ccab69b9c0..3606c5f127 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -387,7 +387,7 @@ message RunOptions {
// EXPERIMENTAL. Options used to initialize DebuggerState, if enabled.
DebugOptions debug_options = 6;
- // When enabled, causes tensor alllocation information to be included in
+ // When enabled, causes tensor allocation information to be included in
// the error message when the Run() call fails because the allocator ran
// out of memory (OOM).
//
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 50bfa91267..7405e01e14 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -24,7 +24,7 @@ limitations under the License.
// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
// "-beta", "-rc", "-rc.1")
-#define TF_VERSION_SUFFIX "-rc0"
+#define TF_VERSION_SUFFIX "-rc1"
#define TF_STR_HELPER(x) #x
#define TF_STR(x) TF_STR_HELPER(x)
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index db4c5c35e3..34db96075d 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -1112,9 +1112,11 @@ inline void ForwardMklTensorInToOutWithMklShape(OpKernelContext* context,
// Forward the MKL shape ONLY (used in elementwise and other ops where
// we call the eigen implementation and MKL shape is not used)
inline void ForwardMklMetaDataInToOut(OpKernelContext* context,
- uint idx_data_in, uint idx_data_out) {
- uint idx_meta_in = GetTensorMetaDataIndex(idx_data_in, context->num_inputs());
- uint idx_meta_out =
+ uint32 idx_data_in,
+ uint32_t idx_data_out) {
+ uint32 idx_meta_in =
+ GetTensorMetaDataIndex(idx_data_in, context->num_inputs());
+ uint32 idx_meta_out =
GetTensorMetaDataIndex(idx_data_out, context->num_outputs());
if (IsRefType(context->input_dtype(idx_data_in))) {
@@ -1126,7 +1128,7 @@ inline void ForwardMklMetaDataInToOut(OpKernelContext* context,
// Set a dummy MKL shape (called when the output is in TF format)
inline void SetDummyMklShapeOutput(OpKernelContext* context,
- uint idx_data_out) {
+ uint32 idx_data_out) {
MklShape mkl_shape_output;
mkl_shape_output.SetMklTensor(false);
AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
diff --git a/tensorflow/docs_src/about/roadmap.md b/tensorflow/docs_src/about/roadmap.md
index 3ee825ed40..1f934acab6 100644
--- a/tensorflow/docs_src/about/roadmap.md
+++ b/tensorflow/docs_src/about/roadmap.md
@@ -1,37 +1,86 @@
# Roadmap
-**Last updated: January 23, 2017**
+**Last updated: Feb 15, 2018**
-TensorFlow is a fast moving project. In order for the community to better
-understand what the near future will bring, this document shares what we are
-working on internally. Many of these features were requested by the community,
-and we welcome
-[contributions](https://github.com/tensorflow/tensorflow/labels/stat%3Acontributions%20welcome).
+TensorFlow is a rapidly moving, community supported project. This document is intended
+to provide guidance about priorities and focus areas of the core set of TensorFlow
+developers and about functionality that can be expected in the upcoming releases of
+TensorFlow. Many of these areas are driven by community use cases, and we welcome
+further
+[contributions](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md)
+to TensorFlow.
-The features on this list are targeted for the next few months. At this point,
-we do not have timelines for these features.
+The features below do not have concrete release dates. However, the majority can be
+expected in the next one to two releases.
-### Improve non-Python language support
+### APIs
+#### High Level APIs:
+* Easy multi-GPU utilization with Estimators
+* Easy-to-use high-level pre-made estimators for Gradient Boosted Trees, Time Series, and other models
-* Support for adding gradient computation for graphs constructed in other
- languages (C++, Java, Go etc.)
+#### Eager Execution:
+* Efficient utilization of multiple GPUs
+* Distributed training (multi-machine)
+* Performance improvements
+* Simpler export to a GraphDef/SavedModel
-### Making TensorFlow easier to use
-* High-level APIs
-* Well-maintained models showing best practices
+#### Keras API:
+* Better integration with tf.data (ability to call `model.fit` with data tensors)
+* Full support for Eager Execution (both Eager support for the regular Keras API, and ability
+to create Keras models Eager- style via Model subclassing)
+* Better distribution/multi-GPU support and TPU support (including a smoother model-to-estimator workflow)
-### Performance
-* Speed and memory benchmarks
-* Distributed full model benchmarks
-* Performance and memory usage improvements
+#### Official Models:
+* A set of
+[reference models](https://github.com/tensorflow/models/tree/master/official)
+across image recognition, speech, object detection, and
+ translation that demonstrate best practices and serve as a starting point for
+ high-performance model development.
+
+#### Contrib:
+* Deprecation notices added to parts of tf.contrib where preferred implementations exist outside of tf.contrib.
+* As much as possible, large projects inside tf.contrib moved to separate repositories.
+* The tf.contrib module will eventually be discontinued in its current form, experimental development will in future happen in other repositories.
-### Core Features
-* Automatic op placement ([#2126](https://github.com/tensorflow/tensorflow/issues/2126))
-* Support for graph-level functions
+
+#### Probabilistic Reasoning and Statistical Analysis:
+* Rich set of tools for probabilistic and statistical analysis in tf.distributions
+ and tf.probability. These include new samplers, layers, optimizers, losses, and structured models
+* Statistical tools for hypothesis testing, convergence diagnostics, and sample statistics
+* Edward 2.0: High-level API for probabilistic programming
### Platforms
-* OpenCL support ([#22](https://github.com/tensorflow/tensorflow/issues/22))
+#### TensorFlow Lite:
+* Increased coverage of supported ops in TensorFlow Lite
+* Easier conversion of a trained TensorFlow graph for use on TensorFlow Lite
+* Support for GPU acceleration in TensorFlow Lite (iOS and Android)
+* Support for hardware accelerators via Android NeuralNets API
+* Improved CPU performance by quantization and other network optimizations (eg. pruning, distillation)
+* Increased support for devices beyond Android and iOS (eg. RPi, Cortex-M)
+
+### Performance
+#### Distributed TensorFlow:
+* Multi-GPU support optimized for a variety of GPU topologies
+* Improved mechanisms for distributing computations on several machines
+
+#### Optimizations:
+* Mixed precision training support with initial example model and guide
+* Native TensorRT support
+* Int8 support for SkyLake via MKL
+* Dynamic loading of SIMD-optimized kernels
+
+### Documentation and Usability:
+* Updated documentation, tutorials and Getting Started guides
+* Process to enable external contributions to tutorials, documentation, and blogs showcasing best practice use-cases of TensorFlow and high-impact applications
+
+### Community and Partner Engagement
+#### Special Interest Groups:
+* Mobilizing the community to work together in focused domains
+* [tf-distribute](https://groups.google.com/a/tensorflow.org/forum/#!forum/tf-distribute)
+: build and packaging of TensorFlow
+* More to be identified and launched
-### Community
-* More educational resources
-* Better integration of TensorFlow into the opensource big data ecosystem (e.g.
-[#2655](https://github.com/tensorflow/tensorflow/issues/2655))
+#### Community:
+* Incorporate public feedback on significant design decisions via a Request-for-Comment (RFC) process
+* Formalize process for external contributions to land in TensorFlow and associated projects
+* Grow global TensorFlow communities and user groups
+* Collaborate with partners to co-develop and publish research papers
diff --git a/tensorflow/docs_src/about/uses.md b/tensorflow/docs_src/about/uses.md
index 8818177a28..d646880bd3 100644
--- a/tensorflow/docs_src/about/uses.md
+++ b/tensorflow/docs_src/about/uses.md
@@ -22,6 +22,14 @@ This section describes some of the current uses of the TensorFlow system.
> TensorFlow, or even better, send us a pull request to add an entry to this
> file.
+* **Deep Speech**
+<ul>
+ <li>**Organization**: Mozilla</li>
+ <li> **Domain**: Speech Recognition</li>
+ <li> **Description**: A TensorFlow implementation motivated by Baidu's Deep Speech architecture.</li>
+ <li> **More info**: [GitHub Repo](https://github.com/mozilla/deepspeech)</li>
+</ul>
+
* **RankBrain**
<ul>
<li>**Organization**: Google</li>
diff --git a/tensorflow/docs_src/deploy/index.md b/tensorflow/docs_src/deploy/index.md
index 5831960b4f..07b1bc9257 100644
--- a/tensorflow/docs_src/deploy/index.md
+++ b/tensorflow/docs_src/deploy/index.md
@@ -7,6 +7,8 @@ the following documents:
a cluster of TensorFlow servers.
* @{$hadoop$How to run TensorFlow on Hadoop}, which has a highly
self-explanatory title.
+ * @{$s3$How to run TensorFlow with the S3 filesystem}, which explains how
+ to run TensorFlow with the S3 file system.
* The entire document set for [TensorFlow serving](/serving), an open-source,
flexible, high-performance serving system for machine-learned models
designed for production environments. TensorFlow Serving provides
diff --git a/tensorflow/docs_src/deploy/leftnav_files b/tensorflow/docs_src/deploy/leftnav_files
index f8f8d578e6..c682e7add1 100644
--- a/tensorflow/docs_src/deploy/leftnav_files
+++ b/tensorflow/docs_src/deploy/leftnav_files
@@ -1,3 +1,4 @@
index.md
distributed.md
hadoop.md
+s3.md
diff --git a/tensorflow/docs_src/deploy/s3.md b/tensorflow/docs_src/deploy/s3.md
new file mode 100644
index 0000000000..38f8428634
--- /dev/null
+++ b/tensorflow/docs_src/deploy/s3.md
@@ -0,0 +1,40 @@
+# How to run TensorFlow on S3
+
+This document describes how to run TensorFlow on S3 file system.
+
+## S3
+
+We assume that you are familiar with @{$reading_data$reading data}.
+
+To use S3 with TensorFlow, change the file paths you use to read and write
+data to an S3 path. For example:
+
+```python
+filenames = ["s3://bucketname/path/to/file1.tfrecord",
+ "s3://bucketname/path/to/file2.tfrecord"]
+dataset = tf.data.TFRecordDataset(filenames)
+```
+
+When reading or writing data on S3 with your TensorFlow program, the behavior
+could be controlled by various environmental variables:
+
+* **AWS_REGION**: By default, regional endpoint is used for S3, with region
+ controlled by `AWS_REGION`. If `AWS_REGION` is not specified, then
+ `us-east-1` is used.
+* **S3_ENDPOINT**: The endpoint could be overridden explicitly with
+ `S3_ENDPOINT` specified.
+* **S3_USE_HTTPS**: HTTPS is used to access S3 by default, unless
+ `S3_USE_HTTPS=0`.
+* **S3_VERIFY_SSL**: If HTTPS is used, SSL verification could be disabled
+ with `S3_VERIFY_SSL=0`.
+
+To read or write objects in a bucket that is no publicly accessible,
+AWS credentials must be provided through one of the following methods:
+
+* Set credentials in the AWS credentials profile file on the local system,
+ located at: `~/.aws/credentials` on Linux, macOS, or Unix, or
+ `C:\Users\USERNAME\.aws\credentials` on Windows.
+* Set the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment
+ variables.
+* If TensorFlow is deployed on an EC2 instance, specify an IAM role and then
+ give the EC2 instance access to that role.
diff --git a/tensorflow/docs_src/extend/add_filesys.md b/tensorflow/docs_src/extend/add_filesys.md
index f0591b7b7d..06f11de4eb 100644
--- a/tensorflow/docs_src/extend/add_filesys.md
+++ b/tensorflow/docs_src/extend/add_filesys.md
@@ -81,6 +81,8 @@ filesystem implementations call their existing libraries. Examples include:
plugin](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/hadoop/hadoop_file_system.h)
* [GCS
plugin](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/cloud/gcs_file_system.h)
+* [S3
+ plugin](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/platform/s3/s3_file_system.h)
#### The File interfaces
diff --git a/tensorflow/docs_src/install/install_c.md b/tensorflow/docs_src/install/install_c.md
index 9563eb5017..818798555a 100644
--- a/tensorflow/docs_src/install/install_c.md
+++ b/tensorflow/docs_src/install/install_c.md
@@ -38,7 +38,7 @@ enable TensorFlow for C:
OS="linux" # Change to "darwin" for macOS
TARGET_DIRECTORY="/usr/local"
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.6.0-rc0.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.6.0-rc1.tar.gz" |
sudo tar -C $TARGET_DIRECTORY -xz
The `tar` command extracts the TensorFlow C library into the `lib`
diff --git a/tensorflow/docs_src/install/install_go.md b/tensorflow/docs_src/install/install_go.md
index f4207debe0..4c6dfa8daf 100644
--- a/tensorflow/docs_src/install/install_go.md
+++ b/tensorflow/docs_src/install/install_go.md
@@ -38,7 +38,7 @@ steps to install this library and enable TensorFlow for Go:
TF_TYPE="cpu" # Change to "gpu" for GPU support
TARGET_DIRECTORY='/usr/local'
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.6.0-rc0.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.6.0-rc1.tar.gz" |
sudo tar -C $TARGET_DIRECTORY -xz
The `tar` command extracts the TensorFlow C library into the `lib`
diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md
index 9a80c18aa5..527884863e 100644
--- a/tensorflow/docs_src/install/install_java.md
+++ b/tensorflow/docs_src/install/install_java.md
@@ -36,7 +36,7 @@ following to the project's `pom.xml` to use the TensorFlow Java APIs:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
- <version>1.6.0-rc0</version>
+ <version>1.6.0-rc1</version>
</dependency>
```
@@ -65,7 +65,7 @@ As an example, these steps will create a Maven project that uses TensorFlow:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
- <version>1.6.0-rc0</version>
+ <version>1.6.0-rc1</version>
</dependency>
</dependencies>
</project>
@@ -123,12 +123,12 @@ instead:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>libtensorflow</artifactId>
- <version>1.6.0-rc0</version>
+ <version>1.6.0-rc1</version>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>libtensorflow_jni_gpu</artifactId>
- <version>1.6.0-rc0</version>
+ <version>1.6.0-rc1</version>
</dependency>
```
@@ -147,7 +147,7 @@ refer to the simpler instructions above instead.
Take the following steps to install TensorFlow for Java on Linux or macOS:
1. Download
- [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.6.0-rc0.jar),
+ [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.6.0-rc1.jar),
which is the TensorFlow Java Archive (JAR).
2. Decide whether you will run TensorFlow for Java on CPU(s) only or with
@@ -166,7 +166,7 @@ Take the following steps to install TensorFlow for Java on Linux or macOS:
OS=$(uname -s | tr '[:upper:]' '[:lower:]')
mkdir -p ./jni
curl -L \
- "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.6.0-rc0.tar.gz" |
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.6.0-rc1.tar.gz" |
tar -xz -C ./jni
### Install on Windows
@@ -174,10 +174,10 @@ Take the following steps to install TensorFlow for Java on Linux or macOS:
Take the following steps to install TensorFlow for Java on Windows:
1. Download
- [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.6.0-rc0.jar),
+ [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.6.0-rc1.jar),
which is the TensorFlow Java Archive (JAR).
2. Download the following Java Native Interface (JNI) file appropriate for
- [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.6.0-rc0.zip).
+ [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.6.0-rc1.zip).
3. Extract this .zip file.
@@ -225,7 +225,7 @@ must be part of your `classpath`. For example, you can include the
downloaded `.jar` in your `classpath` by using the `-cp` compilation flag
as follows:
-<pre><b>javac -cp libtensorflow-1.6.0-rc0.jar HelloTF.java</b></pre>
+<pre><b>javac -cp libtensorflow-1.6.0-rc1.jar HelloTF.java</b></pre>
### Running
@@ -239,11 +239,11 @@ two files are available to the JVM:
For example, the following command line executes the `HelloTF` program on Linux
and macOS X:
-<pre><b>java -cp libtensorflow-1.6.0-rc0.jar:. -Djava.library.path=./jni HelloTF</b></pre>
+<pre><b>java -cp libtensorflow-1.6.0-rc1.jar:. -Djava.library.path=./jni HelloTF</b></pre>
And the following command line executes the `HelloTF` program on Windows:
-<pre><b>java -cp libtensorflow-1.6.0-rc0.jar;. -Djava.library.path=jni HelloTF</b></pre>
+<pre><b>java -cp libtensorflow-1.6.0-rc1.jar;. -Djava.library.path=jni HelloTF</b></pre>d
If the program prints <tt>Hello from <i>version</i></tt>, you've successfully
installed TensorFlow for Java and are ready to use the API. If the program
diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md
index 105b225177..e3e115d9f6 100644
--- a/tensorflow/docs_src/install/install_linux.md
+++ b/tensorflow/docs_src/install/install_linux.md
@@ -188,7 +188,7 @@ Take the following steps to install TensorFlow with Virtualenv:
Virtualenv environment:
<pre>(tensorflow)$ <b>pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp34-cp34m-linux_x86_64.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp34-cp34m-linux_x86_64.whl</b></pre>
If you encounter installation problems, see
[Common Installation Problems](#common_installation_problems).
@@ -293,7 +293,7 @@ take the following steps:
<pre>
$ <b>sudo pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp34-cp34m-linux_x86_64.whl</b>
+ https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp34-cp34m-linux_x86_64.whl</b>
</pre>
If this step fails, see
@@ -480,8 +480,7 @@ Take the following steps to install TensorFlow in an Anaconda environment:
<pre>
(tensorflow)$ <b>pip install --ignore-installed --upgrade \
- https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp34-cp34m-linux_x86_64.whl</b></pre>
-
+ https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp34-cp34m-linux_x86_64.whl</b></pre>
<a name="ValidateYourInstallation"></a>
## Validate your installation
@@ -648,14 +647,14 @@ This section documents the relevant values for Linux installations.
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp27-none-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc0-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc1-cp27-none-linux_x86_64.whl
</pre>
Note that GPU support requires the NVIDIA hardware and software described in
@@ -667,14 +666,14 @@ Note that GPU support requires the NVIDIA hardware and software described in
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp34-cp34m-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc0-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc1-cp34-cp34m-linux_x86_64.whl
</pre>
Note that GPU support requires the NVIDIA hardware and software described in
@@ -686,14 +685,14 @@ Note that GPU support requires the NVIDIA hardware and software described in
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp35-cp35m-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc0-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc1-cp35-cp35m-linux_x86_64.whl
</pre>
@@ -705,14 +704,14 @@ Note that GPU support requires the NVIDIA hardware and software described in
CPU only:
<pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp36-cp36m-linux_x86_64.whl
</pre>
GPU support:
<pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc0-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc1-cp36-cp36m-linux_x86_64.whl
</pre>
diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md
index d6df27f8c8..5be38ae1ef 100644
--- a/tensorflow/docs_src/install/install_mac.md
+++ b/tensorflow/docs_src/install/install_mac.md
@@ -119,7 +119,7 @@ Take the following steps to install TensorFlow with Virtualenv:
TensorFlow in the active Virtualenv is as follows:
<pre> $ <b>pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc0-py3-none-any.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py3-none-any.whl</b></pre>
If you encounter installation problems, see
[Common Installation Problems](#common-installation-problems).
@@ -242,7 +242,7 @@ take the following steps:
issue the following command:
<pre> $ <b>sudo pip3 install --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc0-py3-none-any.whl</b> </pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py3-none-any.whl</b> </pre>
If the preceding command fails, see
[installation problems](#common-installation-problems).
@@ -351,7 +351,7 @@ Take the following steps to install TensorFlow in an Anaconda environment:
TensorFlow for Python 2.7:
<pre> (<i>targetDirectory</i>)$ <b>pip install --ignore-installed --upgrade \
- https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc0-py2-none-any.whl</b></pre>
+ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py2-none-any.whl</b></pre>
<a name="ValidateYourInstallation"></a>
@@ -524,7 +524,7 @@ This section documents the relevant values for Mac OS installations.
<pre>
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc0-py2-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py2-none-any.whl
</pre>
@@ -532,5 +532,5 @@ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc0-py2-none-a
<pre>
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc0-py3-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py3-none-any.whl
</pre>
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md
index 90031b4b5e..8d83e9f119 100644
--- a/tensorflow/docs_src/install/install_sources.md
+++ b/tensorflow/docs_src/install/install_sources.md
@@ -359,10 +359,10 @@ Invoke `pip install` to install that pip package.
The filename of the `.whl` file depends on your platform.
For example, the following command will install the pip package
-for TensorFlow 1.6.0rc0 on Linux:
+for TensorFlow 1.6.0rc1 on Linux:
<pre>
-$ <b>sudo pip install /tmp/tensorflow_pkg/tensorflow-1.6.0rc0-py2-none-any.whl</b>
+$ <b>sudo pip install /tmp/tensorflow_pkg/tensorflow-1.6.0rc1-py2-none-any.whl</b>
</pre>
## Validate your installation
@@ -393,7 +393,7 @@ TensorFlow programs:
<pre>Hello, TensorFlow!</pre>
-If you are new to TensorFlow, see @{$get_started$Getting Started with
+If you are new to TensorFlow, see @{$get_started/premade_estimators$Getting Started with
TensorFlow}.
If the system outputs an error message instead of a greeting, see [Common
@@ -460,8 +460,8 @@ Stack Overflow and specify the `tensorflow` tag.
**Linux**
<table>
<tr><th>Version:</th><th>CPU/GPU:</th><th>Python Version:</th><th>Compiler:</th><th>Build Tools:</th><th>cuDNN:</th><th>CUDA:</th></tr>
-<tr><td>tensorflow-1.6.0rc0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.9.0</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.6.0rc0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.9.0</td><td>7</td><td>9</td></tr>
+<tr><td>tensorflow-1.6.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.9.0</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.6.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.9.0</td><td>7</td><td>9</td></tr>
<tr><td>tensorflow-1.5.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.8.0</td><td>N/A</td><td>N/A</td></tr>
<tr><td>tensorflow_gpu-1.5.0</td><td>GPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.8.0</td><td>7</td><td>9</td></tr>
<tr><td>tensorflow-1.4.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>GCC 4.8</td><td>Bazel 0.5.4</td><td>N/A</td><td>N/A</td></tr>
@@ -479,7 +479,7 @@ Stack Overflow and specify the `tensorflow` tag.
**Mac**
<table>
<tr><th>Version:</th><th>CPU/GPU:</th><th>Python Version:</th><th>Compiler:</th><th>Build Tools:</th><th>cuDNN:</th><th>CUDA:</th></tr>
-<tr><td>tensorflow-1.6.0rc0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.8.1</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow-1.6.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.8.1</td><td>N/A</td><td>N/A</td></tr>
<tr><td>tensorflow-1.5.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.8.1</td><td>N/A</td><td>N/A</td></tr>
<tr><td>tensorflow-1.4.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.5.4</td><td>N/A</td><td>N/A</td></tr>
<tr><td>tensorflow-1.3.0</td><td>CPU</td><td>2.7, 3.3-3.6</td><td>Clang from xcode</td><td>Bazel 0.4.5</td><td>N/A</td><td>N/A</td></tr>
@@ -493,8 +493,8 @@ Stack Overflow and specify the `tensorflow` tag.
**Windows**
<table>
<tr><th>Version:</th><th>CPU/GPU:</th><th>Python Version:</th><th>Compiler:</th><th>Build Tools:</th><th>cuDNN:</th><th>CUDA:</th></tr>
-<tr><td>tensorflow-1.6.0rc0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
-<tr><td>tensorflow_gpu-1.6.0rc0</td><td>GPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>7</td><td>9</td></tr>
+<tr><td>tensorflow-1.6.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
+<tr><td>tensorflow_gpu-1.6.0</td><td>GPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>7</td><td>9</td></tr>
<tr><td>tensorflow-1.5.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
<tr><td>tensorflow_gpu-1.5.0</td><td>GPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>7</td><td>9</td></tr>
<tr><td>tensorflow-1.4.0</td><td>CPU</td><td>3.5-3.6</td><td>MSVC 2015 update 3</td><td>Cmake v3.6.3</td><td>N/A</td><td>N/A</td></tr>
diff --git a/tensorflow/docs_src/install/install_windows.md b/tensorflow/docs_src/install/install_windows.md
index e020451c04..dedf485f93 100644
--- a/tensorflow/docs_src/install/install_windows.md
+++ b/tensorflow/docs_src/install/install_windows.md
@@ -47,7 +47,7 @@ installed on your system:
If you have a different version of one of the preceding packages, please
change to the specified versions. In particular, the cuDNN version
-must match exactly: TensorFlow will not load if it cannot find `cudnn64_7.dll`.
+must match exactly: TensorFlow will not load if it cannot find `cuDNN64_7.dll`.
To use a different version of cuDNN, you must build from source.
## Determine how to install TensorFlow
@@ -153,7 +153,7 @@ TensorFlow programs:
<pre>Hello, TensorFlow!</pre>
-If you are new to TensorFlow, see @{$get_started$Getting Started with
+If you are new to TensorFlow, see @{$get_started/premade_estimators$Getting Started with
TensorFlow}.
If the system outputs an error message instead of a greeting, see [Common
diff --git a/tensorflow/docs_src/mobile/mobile_intro.md b/tensorflow/docs_src/mobile/mobile_intro.md
index 17dbf1c3e6..69b63ae7d2 100644
--- a/tensorflow/docs_src/mobile/mobile_intro.md
+++ b/tensorflow/docs_src/mobile/mobile_intro.md
@@ -235,7 +235,7 @@ TensorFlow [on Github](https://github.com/tensorflow/models) that you can look
through. Lean towards the simplest model you can find, and try to get started as
soon as you have even a small amount of labelled data, since you’ll get the best
results when you’re able to iterate quickly. The shorter the time it takes to
-try training a model and running it in s real application, the better overall
+try training a model and running it in its real application, the better overall
results you’ll see. It’s common for an algorithm to get great training accuracy
numbers but then fail to be useful within a real application because there’s a
mismatch between the dataset and real usage. Prototype end-to-end usage as soon
diff --git a/tensorflow/docs_src/programmers_guide/low_level_intro.md b/tensorflow/docs_src/programmers_guide/low_level_intro.md
index a8cc0feae3..05709ad10a 100644
--- a/tensorflow/docs_src/programmers_guide/low_level_intro.md
+++ b/tensorflow/docs_src/programmers_guide/low_level_intro.md
@@ -312,7 +312,7 @@ the same input. @{tf.layers$Layers} are the preferred way to add trainable
parameters to a graph.
Layers package together both the variables and the operations that act
-on them, . For example a
+on them. For example a
[densely-connected layer](https://developers.google.com/machine-learning/glossary/#fully_connected_layer)
performs a weighted sum across all inputs
for each output and applies an optional
@@ -495,7 +495,7 @@ good. Here's what we got; your own output will almost certainly differ:
[ 0.10527515]]
```
-### loss
+### Loss
To optimize a model, you first need to define the loss. We'll use the mean
square error, a standard loss for regression problems.
@@ -521,7 +521,7 @@ TensorFlow provides
[**optimizers**](https://developers.google.com/machine-learning/glossary/#optimizer)
implementing standard optimization algorithms. These are implemented as
sub-classes of @{tf.train.Optimizer}. They incrementally change each
-variable in order to minimizethe loss. The simplest optimization algorithm is
+variable in order to minimize the loss. The simplest optimization algorithm is
[**gradient descent**](https://developers.google.com/machine-learning/glossary/#gradient_descent),
implemented by @{tf.train.GradientDescentOptimizer}. It modifies each
variable according to the magnitude of the derivative of loss with respect to
diff --git a/tensorflow/docs_src/tutorials/layers.md b/tensorflow/docs_src/tutorials/layers.md
index b898cbe29c..5111b16247 100644
--- a/tensorflow/docs_src/tutorials/layers.md
+++ b/tensorflow/docs_src/tutorials/layers.md
@@ -635,7 +635,7 @@ should be logged after every 50 steps of training.
### Train the Model
Now we're ready to train our model, which we can do by creating `train_input_fn`
-ans calling `train()` on `mnist_classifier`. Add the following to `main()`:
+and calling `train()` on `mnist_classifier`. Add the following to `main()`:
```python
# Train the model
diff --git a/tensorflow/examples/android/res/animator/color_animation.xml b/tensorflow/examples/android/res/animator/color_animation.xml
new file mode 100644
index 0000000000..891d8cc1d4
--- /dev/null
+++ b/tensorflow/examples/android/res/animator/color_animation.xml
@@ -0,0 +1,30 @@
+<?xml version="1.0" encoding="utf-8"?><!--
+ Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<set xmlns:android="http://schemas.android.com/apk/res/android"
+ android:ordering="sequentially">
+ <objectAnimator
+ android:propertyName="backgroundColor"
+ android:duration="375"
+ android:valueFrom="0x00b3ccff"
+ android:valueTo="0xffb3ccff"
+ android:valueType="colorType"/>
+ <objectAnimator
+ android:propertyName="backgroundColor"
+ android:duration="375"
+ android:valueFrom="0xffb3ccff"
+ android:valueTo="0x00b3ccff"
+ android:valueType="colorType"/>
+</set>
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java
index 184df1bdb4..1cddf3dc55 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java
@@ -31,7 +31,8 @@ the RecognizeCommands helper class.
package org.tensorflow.demo;
-import android.animation.ValueAnimator;
+import android.animation.AnimatorInflater;
+import android.animation.AnimatorSet;
import android.app.Activity;
import android.content.pm.PackageManager;
import android.media.AudioFormat;
@@ -329,17 +330,13 @@ public class SpeechActivity extends Activity {
labelIndex = i;
}
}
- final View labelView = (View) labelsListView.getChildAt(labelIndex - 2);
- ValueAnimator colorAnimation =
- ValueAnimator.ofArgb(0x00b3ccff, 0xffb3ccff, 0x00b3ccff);
- colorAnimation.setDuration(750);
- colorAnimation.addUpdateListener(
- new ValueAnimator.AnimatorUpdateListener() {
- @Override
- public void onAnimationUpdate(ValueAnimator animator) {
- labelView.setBackgroundColor((int) animator.getAnimatedValue());
- }
- });
+ final View labelView = labelsListView.getChildAt(labelIndex - 2);
+
+ AnimatorSet colorAnimation =
+ (AnimatorSet)
+ AnimatorInflater.loadAnimator(
+ SpeechActivity.this, R.animator.color_animation);
+ colorAnimation.setTarget(labelView);
colorAnimation.start();
}
}
diff --git a/tensorflow/examples/get_started/regression/imports85.py b/tensorflow/examples/get_started/regression/imports85.py
index 6bee556eb8..4fdaceea9a 100644
--- a/tensorflow/examples/get_started/regression/imports85.py
+++ b/tensorflow/examples/get_started/regression/imports85.py
@@ -131,11 +131,12 @@ def dataset(y_name="price", train_fraction=0.7):
# booleans but we are dealing with symbolic tensors.
return ~in_training_set(line)
- base_dataset = (tf.contrib.data
- # Get the lines from the file.
- .TextLineDataset(path)
- # drop lines with question marks.
- .filter(has_no_question_marks))
+ base_dataset = (
+ tf.data
+ # Get the lines from the file.
+ .TextLineDataset(path)
+ # drop lines with question marks.
+ .filter(has_no_question_marks))
train = (base_dataset
# Take only the training-set lines.
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index 868310cbc0..25e09fecbf 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -41,7 +41,6 @@ The subfolder names are important, since they define what label is applied to
each image, but the filenames themselves don't matter. Once your images are
prepared, you can run the training with a command like this:
-
```bash
bazel build tensorflow/examples/image_retraining:retrain && \
bazel-bin/tensorflow/examples/image_retraining/retrain \
@@ -70,12 +69,14 @@ on resource-limited platforms, you can try the `--architecture` flag with a
Mobilenet model. For example:
Run floating-point version of mobilenet:
+
```bash
python tensorflow/examples/image_retraining/retrain.py \
--image_dir ~/flower_photos --architecture mobilenet_1.0_224
```
Run quantized version of mobilenet:
+
```bash
python tensorflow/examples/image_retraining/retrain.py \
--image_dir ~/flower_photos/ --architecture mobilenet_1.0_224_quantized
@@ -96,6 +97,12 @@ Visualize the summaries with this command:
tensorboard --logdir /tmp/retrain_logs
+To use with Tensorflow Serving:
+
+```bash
+tensorflow_model_server --port=9000 --model_name=inception \
+ --model_base_path=/tmp/saved_models/
+```
"""
from __future__ import absolute_import
from __future__ import division
@@ -1004,6 +1011,45 @@ def add_jpeg_decoding(input_width, input_height, input_depth, input_mean,
return jpeg_data, mul_image
+def export_model(sess, architecture, saved_model_dir):
+ """Exports model for serving.
+
+ Args:
+ sess: Current active TensorFlow Session.
+ architecture: Model architecture.
+ saved_model_dir: Directory in which to save exported model and variables.
+ """
+ if architecture == 'inception_v3':
+ input_tensor = 'DecodeJpeg/contents:0'
+ elif architecture.startswith('mobilenet_'):
+ input_tensor = 'input:0'
+ else:
+ raise ValueError('Unknown architecture', architecture)
+ in_image = sess.graph.get_tensor_by_name(input_tensor)
+ inputs = {'image': tf.saved_model.utils.build_tensor_info(in_image)}
+
+ out_classes = sess.graph.get_tensor_by_name('final_result:0')
+ outputs = {'prediction': tf.saved_model.utils.build_tensor_info(out_classes)}
+
+ signature = tf.saved_model.signature_def_utils.build_signature_def(
+ inputs=inputs,
+ outputs=outputs,
+ method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
+
+ legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
+
+ # Save out the SavedModel.
+ builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
+ builder.add_meta_graph_and_variables(
+ sess, [tf.saved_model.tag_constants.SERVING],
+ signature_def_map={
+ tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+ signature
+ },
+ legacy_init_op=legacy_init_op)
+ builder.save()
+
+
def main(_):
# Needed to make sure the logging output is visible.
# See https://github.com/tensorflow/tensorflow/issues/3047
@@ -1179,6 +1225,8 @@ def main(_):
with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
f.write('\n'.join(image_lists.keys()) + '\n')
+ export_model(sess, FLAGS.architecture, FLAGS.saved_model_dir)
+
if __name__ == '__main__':
parser = argparse.ArgumentParser()
@@ -1362,5 +1410,10 @@ if __name__ == '__main__':
takes 128x128 images. See https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html
for more information on Mobilenet.\
""")
+ parser.add_argument(
+ '--saved_model_dir',
+ type=str,
+ default='/tmp/saved_models/1/',
+ help='Where to save the exported graph.')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/examples/speech_commands/label_wav_dir.py b/tensorflow/examples/speech_commands/label_wav_dir.py
new file mode 100644
index 0000000000..a34db512dd
--- /dev/null
+++ b/tensorflow/examples/speech_commands/label_wav_dir.py
@@ -0,0 +1,136 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Runs a trained audio graph against WAVE files and reports the results.
+
+The model, labels and .wav files specified in the arguments will be loaded, and
+then the predictions from running the model against the audio data will be
+printed to the console. This is a useful script for sanity checking trained
+models, and as an example of how to use an audio model from Python.
+
+Here's an example of running it:
+
+python tensorflow/examples/speech_commands/label_wav_dir.py \
+--graph=/tmp/my_frozen_graph.pb \
+--labels=/tmp/speech_commands_train/conv_labels.txt \
+--wav_dir=/tmp/speech_dataset/left
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import glob
+import sys
+
+import tensorflow as tf
+
+# pylint: disable=unused-import
+from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
+# pylint: enable=unused-import
+
+FLAGS = None
+
+
+def load_graph(filename):
+ """Unpersists graph from file as default graph."""
+ with tf.gfile.FastGFile(filename, 'rb') as f:
+ graph_def = tf.GraphDef()
+ graph_def.ParseFromString(f.read())
+ tf.import_graph_def(graph_def, name='')
+
+
+def load_labels(filename):
+ """Read in labels, one label per line."""
+ return [line.rstrip() for line in tf.gfile.GFile(filename)]
+
+
+def run_graph(wav_dir, labels, input_layer_name, output_layer_name,
+ num_top_predictions):
+ """Runs the audio data through the graph and prints predictions."""
+ with tf.Session() as sess:
+ # Feed the audio data as input to the graph.
+ # predictions will contain a two-dimensional array, where one
+ # dimension represents the input image count, and the other has
+ # predictions per class
+ for wav_path in glob.glob(wav_dir + '/*.wav'):
+ if not wav_path or not tf.gfile.Exists(wav_path):
+ tf.logging.fatal('Audio file does not exist %s', wav_path)
+
+ with open(wav_path, 'rb') as wav_file:
+ wav_data = wav_file.read()
+
+ softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name)
+ predictions, = sess.run(softmax_tensor, {input_layer_name: wav_data})
+
+ # Sort to show labels in order of confidence
+ print('\n%s' % (wav_path.split('/')[-1]))
+ top_k = predictions.argsort()[-num_top_predictions:][::-1]
+ for node_id in top_k:
+ human_string = labels[node_id]
+ score = predictions[node_id]
+ print('%s (score = %.5f)' % (human_string, score))
+
+ return 0
+
+
+def label_wav(wav_dir, labels, graph, input_name, output_name, how_many_labels):
+ """Loads the model and labels, and runs the inference to print predictions."""
+ if not labels or not tf.gfile.Exists(labels):
+ tf.logging.fatal('Labels file does not exist %s', labels)
+
+ if not graph or not tf.gfile.Exists(graph):
+ tf.logging.fatal('Graph file does not exist %s', graph)
+
+ labels_list = load_labels(labels)
+
+ # load graph, which is stored in the default session
+ load_graph(graph)
+
+ run_graph(wav_dir, labels_list, input_name, output_name, how_many_labels)
+
+
+def main(_):
+ """Entry point for script, converts flags to arguments."""
+ label_wav(FLAGS.wav_dir, FLAGS.labels, FLAGS.graph, FLAGS.input_name,
+ FLAGS.output_name, FLAGS.how_many_labels)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--wav_dir', type=str, default='', help='Audio file to be identified.')
+ parser.add_argument(
+ '--graph', type=str, default='', help='Model to use for identification.')
+ parser.add_argument(
+ '--labels', type=str, default='', help='Path to file containing labels.')
+ parser.add_argument(
+ '--input_name',
+ type=str,
+ default='wav_data:0',
+ help='Name of WAVE data input node in model.')
+ parser.add_argument(
+ '--output_name',
+ type=str,
+ default='labels_softmax:0',
+ help='Name of node outputting a prediction in the model.')
+ parser.add_argument(
+ '--how_many_labels',
+ type=int,
+ default=3,
+ help='Number of results to show.')
+
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/examples/speech_commands/train.py b/tensorflow/examples/speech_commands/train.py
index a4e80041f8..07c1919347 100644
--- a/tensorflow/examples/speech_commands/train.py
+++ b/tensorflow/examples/speech_commands/train.py
@@ -357,12 +357,14 @@ if __name__ == '__main__':
'--window_size_ms',
type=float,
default=30.0,
- help='How long each spectrogram timeslice is',)
+ help='How long each spectrogram timeslice is.',
+ )
parser.add_argument(
'--window_stride_ms',
type=float,
default=10.0,
- help='How long each spectrogram timeslice is',)
+ help='How far to move in time between spectogram timeslices.',
+ )
parser.add_argument(
'--dct_coefficient_count',
type=int,
diff --git a/tensorflow/examples/udacity/5_word2vec.ipynb b/tensorflow/examples/udacity/5_word2vec.ipynb
index 18c456cad7..3b43d1fb55 100644
--- a/tensorflow/examples/udacity/5_word2vec.ipynb
+++ b/tensorflow/examples/udacity/5_word2vec.ipynb
@@ -455,7 +455,7 @@
" \n",
" # Compute the similarity between minibatch examples and all embeddings.\n",
" # We use the cosine distance:\n",
- " norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))\n",
+ " norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keepdims=True))\n",
" normalized_embeddings = embeddings / norm\n",
" valid_embeddings = tf.nn.embedding_lookup(\n",
" normalized_embeddings, valid_dataset)\n",
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index e269b71f2e..1167b3834e 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -1114,7 +1114,7 @@ def _write_dict_to_summary(output_dir,
isinstance(dictionary[key], np.int32) or
isinstance(dictionary[key], int)):
summary_proto.value.add(tag=key, simple_value=int(dictionary[key]))
- elif isinstance(dictionary[key], six.string_types):
+ elif isinstance(dictionary[key], six.binary_type):
try:
summ = summary_pb2.Summary.FromString(dictionary[key])
for i, _ in enumerate(summ.value):
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 7c7d913c32..7a0745b1d0 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -80,18 +80,18 @@ def dummy_model_fn(features, labels, params):
_, _, _ = features, labels, params
-def check_eventfile_for_keyword(keyword, est):
+def check_eventfile_for_keyword(keyword, dir_):
"""Checks event files for the keyword."""
writer_cache.FileWriterCache.clear()
# Get last Event written.
- event_paths = glob.glob(os.path.join(est.model_dir, 'events*'))
+ event_paths = glob.glob(os.path.join(dir_, 'events*'))
last_event = None
for last_event in summary_iterator.summary_iterator(event_paths[-1]):
if last_event.summary is not None:
- if last_event.summary.value:
- if keyword in last_event.summary.value[0].tag:
+ for value in last_event.summary.value:
+ if keyword in value.tag:
return True
return False
@@ -610,7 +610,7 @@ class EstimatorTrainTest(test.TestCase):
# Make sure nothing is stuck in limbo.
writer_cache.FileWriterCache.clear()
- if check_eventfile_for_keyword('loss', est):
+ if check_eventfile_for_keyword('loss', est.model_dir):
return
self.fail('{} should be part of reported summaries.'.format('loss'))
@@ -1290,8 +1290,9 @@ class EstimatorEvaluateTest(test.TestCase):
# Make sure nothing is stuck in limbo.
writer_cache.FileWriterCache.clear()
- # Get last Event written.
- if check_eventfile_for_keyword('image', est):
+ # Get last evaluation Event written.
+ if check_eventfile_for_keyword('image', os.path.join(est.model_dir,
+ 'eval')):
return
self.fail('{} should be part of reported summaries.'.format('image'))
diff --git a/tensorflow/python/framework/common_shapes.py b/tensorflow/python/framework/common_shapes.py
index 3b1092f923..3c5aebbce8 100644
--- a/tensorflow/python/framework/common_shapes.py
+++ b/tensorflow/python/framework/common_shapes.py
@@ -34,7 +34,7 @@ def scalar_shape(unused_op):
def unchanged_shape(op):
- """Shape function for ops that output an tensor like their first input."""
+ """Shape function for ops that output a tensor like their first input."""
return [op.inputs[0].get_shape()]
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index b35cee0111..301a7f682d 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -1458,7 +1458,7 @@ class FunctionInlineControlTest(test.TestCase):
def Cell(v):
# If v is a vector [n, 1], x is a big square matrix.
x = math_ops.tanh(v + array_ops.transpose(v, [1, 0]))
- return math_ops.reduce_sum(x, 1, keep_dims=True)
+ return math_ops.reduce_sum(x, 1, keepdims=True)
@function.Defun(dtype)
def Forward(x):
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py
index 4231a79b2d..d306d1b8d6 100644
--- a/tensorflow/python/kernel_tests/reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/reduction_ops_test.py
@@ -110,10 +110,10 @@ class ReductionUnknownShape(test.TestCase):
class BaseReductionTest(test.TestCase):
- def _tf_reduce(self, x, reduction_axes, keep_dims):
+ def _tf_reduce(self, x, reduction_axes, keepdims):
raise NotImplementedError()
- def _np_reduce(self, x, reduction_axes, keep_dims):
+ def _np_reduce(self, x, reduction_axes, keepdims):
raise NotImplementedError()
def _makeIncremental(self, shape, dtype):
@@ -128,10 +128,10 @@ class BaseReductionTest(test.TestCase):
data -= 2j * data
return data
- def _compare(self, x, reduction_axes, keep_dims, feed_dict=None):
- np_ans = self._np_reduce(x, reduction_axes, keep_dims)
+ def _compare(self, x, reduction_axes, keepdims, feed_dict=None):
+ np_ans = self._np_reduce(x, reduction_axes, keepdims)
with self.test_session(use_gpu=True) as sess:
- tf_ans = self._tf_reduce(x, reduction_axes, keep_dims)
+ tf_ans = self._tf_reduce(x, reduction_axes, keepdims)
out = sess.run(tf_ans, feed_dict)
self.assertAllClose(np_ans, out)
self.assertShapeEqual(np_ans, tf_ans)
@@ -140,8 +140,8 @@ class BaseReductionTest(test.TestCase):
if reduction_axes is not None and np.shape(reduction_axes) == (1,):
# Test scalar reduction_axes argument
self._compareAll(x, reduction_axes[0])
- self._compare(x, reduction_axes, keep_dims=False, feed_dict=feed_dict)
- self._compare(x, reduction_axes, keep_dims=True, feed_dict=feed_dict)
+ self._compare(x, reduction_axes, keepdims=False, feed_dict=feed_dict)
+ self._compare(x, reduction_axes, keepdims=True, feed_dict=feed_dict)
def _compareAllAxes(self, x, feed_dict=None):
self._compareAll(x, None)
@@ -171,14 +171,14 @@ class BaseReductionTest(test.TestCase):
class SumReductionTest(BaseReductionTest):
- def _tf_reduce(self, x, reduction_axes, keep_dims):
- return math_ops.reduce_sum(x, reduction_axes, keep_dims)
+ def _tf_reduce(self, x, reduction_axes, keepdims):
+ return math_ops.reduce_sum(x, reduction_axes, keepdims)
- def _np_reduce(self, x, reduction_axes, keep_dims):
+ def _np_reduce(self, x, reduction_axes, keepdims):
if isinstance(reduction_axes, list) or isinstance(reduction_axes,
np.ndarray):
reduction_axes = tuple(reduction_axes)
- return np.sum(x, axis=reduction_axes, keepdims=keep_dims)
+ return np.sum(x, axis=reduction_axes, keepdims=keepdims)
def testAxesType(self):
for dtype in [dtypes.int64, dtypes.int32]:
@@ -298,7 +298,7 @@ class SumReductionTest(BaseReductionTest):
c_known_rank = array_ops.placeholder(dtypes.float32)
c_known_rank.set_shape(tensor_shape.unknown_shape(ndims=3))
s_known_rank = math_ops.reduce_sum(
- c_known_rank, reduction_axes, keep_dims=True)
+ c_known_rank, reduction_axes, keepdims=True)
self.assertEqual(3, s_known_rank.get_shape().ndims)
np_input = np.random.randn(3, 3, 3)
@@ -308,11 +308,11 @@ class SumReductionTest(BaseReductionTest):
unknown_indices = array_ops.placeholder(dtypes.int32)
c_unknown_indices = constant_op.constant([[10.0], [20.0]])
s_unknown_indices = math_ops.reduce_sum(
- c_unknown_indices, unknown_indices, keep_dims=False)
+ c_unknown_indices, unknown_indices, keepdims=False)
self.assertEqual(tensor_shape.unknown_shape(),
s_unknown_indices.get_shape())
s_unknown_indices_keep = math_ops.reduce_sum(
- c_unknown_indices, unknown_indices, keep_dims=True)
+ c_unknown_indices, unknown_indices, keepdims=True)
self.assertEqual(2, s_unknown_indices_keep.get_shape().ndims)
def testWrongShapeForReductionIndices(self):
@@ -372,10 +372,10 @@ class SumReductionTest(BaseReductionTest):
class MeanReductionTest(BaseReductionTest):
- def _tf_reduce(self, x, reduction_axes, keep_dims):
- return math_ops.reduce_mean(x, reduction_axes, keep_dims)
+ def _tf_reduce(self, x, reduction_axes, keepdims):
+ return math_ops.reduce_mean(x, reduction_axes, keepdims)
- def _np_reduce(self, x, reduction_axes, keep_dims):
+ def _np_reduce(self, x, reduction_axes, keepdims):
if isinstance(reduction_axes, list) or isinstance(reduction_axes,
np.ndarray):
reduction_axes = tuple(reduction_axes)
@@ -389,7 +389,7 @@ class MeanReductionTest(BaseReductionTest):
# np.mean automatically converts integer inputs to float, while TensorFlow's
# reduce_mean does not. For integer inputs, we emulate TensorFlow's behavior
# using np.sum and truncating division.
- np_sum = np.sum(x, axis=reduction_axes, keepdims=keep_dims)
+ np_sum = np.sum(x, axis=reduction_axes, keepdims=keepdims)
if np.issubdtype(x.dtype, np.integer):
return np_sum // count
return np_sum / count
@@ -458,14 +458,14 @@ class MeanReductionTest(BaseReductionTest):
class ProdReductionTest(BaseReductionTest):
- def _tf_reduce(self, x, reduction_axes, keep_dims):
- return math_ops.reduce_prod(x, reduction_axes, keep_dims)
+ def _tf_reduce(self, x, reduction_axes, keepdims):
+ return math_ops.reduce_prod(x, reduction_axes, keepdims)
- def _np_reduce(self, x, reduction_axes, keep_dims):
+ def _np_reduce(self, x, reduction_axes, keepdims):
if isinstance(reduction_axes, list) or isinstance(reduction_axes,
np.ndarray):
reduction_axes = tuple(reduction_axes)
- return np.prod(x, axis=reduction_axes, keepdims=keep_dims)
+ return np.prod(x, axis=reduction_axes, keepdims=keepdims)
def testAxesType(self):
for dtype in [dtypes.int64, dtypes.int32]:
@@ -549,17 +549,17 @@ class ProdReductionTest(BaseReductionTest):
class MinReductionTest(test.TestCase):
- def _compare(self, x, reduction_axes, keep_dims, use_gpu=False):
+ def _compare(self, x, reduction_axes, keepdims, use_gpu=False):
np_ans = x
if reduction_axes is None:
- np_ans = np.amin(np_ans, keepdims=keep_dims)
+ np_ans = np.amin(np_ans, keepdims=keepdims)
else:
for ra in reduction_axes[::-1]:
- np_ans = np.amin(np_ans, axis=ra, keepdims=keep_dims)
+ np_ans = np.amin(np_ans, axis=ra, keepdims=keepdims)
with self.test_session(use_gpu=use_gpu):
if reduction_axes is not None:
reduction_axes = np.array(reduction_axes).astype(np.int32)
- tf_ans = math_ops.reduce_min(x, reduction_axes, keep_dims)
+ tf_ans = math_ops.reduce_min(x, reduction_axes, keepdims)
out = tf_ans.eval()
self.assertAllClose(np_ans, out)
self.assertShapeEqual(np_ans, tf_ans)
@@ -662,17 +662,17 @@ class MinReductionTest(test.TestCase):
class MaxReductionTest(test.TestCase):
- def _compare(self, x, reduction_axes, keep_dims, use_gpu=False):
+ def _compare(self, x, reduction_axes, keepdims, use_gpu=False):
np_ans = x
if reduction_axes is None:
- np_ans = np.amax(np_ans, keepdims=keep_dims)
+ np_ans = np.amax(np_ans, keepdims=keepdims)
else:
for ra in reduction_axes[::-1]:
- np_ans = np.amax(np_ans, axis=ra, keepdims=keep_dims)
+ np_ans = np.amax(np_ans, axis=ra, keepdims=keepdims)
with self.test_session(use_gpu=use_gpu):
if reduction_axes is not None:
reduction_axes = np.array(reduction_axes).astype(np.int32)
- tf_ans = math_ops.reduce_max(x, reduction_axes, keep_dims)
+ tf_ans = math_ops.reduce_max(x, reduction_axes, keepdims)
out = tf_ans.eval()
self.assertAllClose(np_ans, out)
self.assertShapeEqual(np_ans, tf_ans)
@@ -789,17 +789,17 @@ class MaxReductionTest(test.TestCase):
class AllReductionTest(test.TestCase):
- def _compare(self, x, reduction_axes, keep_dims, use_gpu=False):
+ def _compare(self, x, reduction_axes, keepdims, use_gpu=False):
np_ans = x
if reduction_axes is None:
- np_ans = np.all(np_ans, keepdims=keep_dims)
+ np_ans = np.all(np_ans, keepdims=keepdims)
else:
for ra in reduction_axes[::-1]:
- np_ans = np.all(np_ans, axis=ra, keepdims=keep_dims)
+ np_ans = np.all(np_ans, axis=ra, keepdims=keepdims)
with self.test_session(use_gpu=use_gpu):
if reduction_axes is not None:
reduction_axes = np.array(reduction_axes).astype(np.int32)
- tf_ans = math_ops.reduce_all(x, reduction_axes, keep_dims)
+ tf_ans = math_ops.reduce_all(x, reduction_axes, keepdims)
out = tf_ans.eval()
self.assertAllEqual(np_ans, out)
self.assertShapeEqual(np_ans, tf_ans)
@@ -838,17 +838,17 @@ class AllReductionTest(test.TestCase):
class AnyReductionTest(test.TestCase):
- def _compare(self, x, reduction_axes, keep_dims, use_gpu=False):
+ def _compare(self, x, reduction_axes, keepdims, use_gpu=False):
np_ans = x
if reduction_axes is None:
- np_ans = np.any(np_ans, keepdims=keep_dims)
+ np_ans = np.any(np_ans, keepdims=keepdims)
else:
for ra in reduction_axes[::-1]:
- np_ans = np.any(np_ans, axis=ra, keepdims=keep_dims)
+ np_ans = np.any(np_ans, axis=ra, keepdims=keepdims)
with self.test_session(use_gpu=use_gpu):
if reduction_axes is not None:
reduction_axes = np.array(reduction_axes).astype(np.int32)
- tf_ans = math_ops.reduce_any(x, reduction_axes, keep_dims)
+ tf_ans = math_ops.reduce_any(x, reduction_axes, keepdims)
out = tf_ans.eval()
self.assertAllEqual(np_ans, out)
self.assertShapeEqual(np_ans, tf_ans)
@@ -887,21 +887,17 @@ class AnyReductionTest(test.TestCase):
class CountNonzeroReductionTest(test.TestCase):
- def _compare(self,
- x,
- reduction_axes,
- keep_dims,
- use_gpu=False,
+ def _compare(self, x, reduction_axes, keepdims, use_gpu=False,
feed_dict=None):
np_ans = (x != 0).astype(np.int32)
if reduction_axes is None:
- np_ans = np.sum(np_ans, keepdims=keep_dims)
+ np_ans = np.sum(np_ans, keepdims=keepdims)
else:
reduction_axes = np.array(reduction_axes).astype(np.int32)
for ra in reduction_axes.ravel()[::-1]:
- np_ans = np.sum(np_ans, axis=ra, keepdims=keep_dims)
+ np_ans = np.sum(np_ans, axis=ra, keepdims=keepdims)
with self.test_session(use_gpu=use_gpu) as sess:
- tf_ans = math_ops.count_nonzero(x, reduction_axes, keep_dims)
+ tf_ans = math_ops.count_nonzero(x, reduction_axes, keepdims)
out = sess.run(tf_ans, feed_dict)
self.assertAllClose(np_ans, out)
self.assertShapeEqual(np_ans, tf_ans)
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test_big.py b/tensorflow/python/kernel_tests/reduction_ops_test_big.py
index 0959adb026..d70360775a 100644
--- a/tensorflow/python/kernel_tests/reduction_ops_test_big.py
+++ b/tensorflow/python/kernel_tests/reduction_ops_test_big.py
@@ -27,24 +27,24 @@ from tensorflow.python.platform import test
class BaseReductionTest(test.TestCase):
- def _tf_reduce(self, x, reduction_axes, keep_dims):
+ def _tf_reduce(self, x, reduction_axes, keepdims):
raise NotImplementedError()
class BigReductionTest(BaseReductionTest):
"""Test reductions for sum and boolean all over a wide range of shapes."""
- def _tf_reduce_max(self, x, reduction_axes, keep_dims):
- return math_ops.reduce_max(x, reduction_axes, keep_dims)
+ def _tf_reduce_max(self, x, reduction_axes, keepdims):
+ return math_ops.reduce_max(x, reduction_axes, keepdims)
- def _tf_reduce_all(self, x, reduction_axes, keep_dims):
- return math_ops.reduce_all(x, reduction_axes, keep_dims)
+ def _tf_reduce_all(self, x, reduction_axes, keepdims):
+ return math_ops.reduce_all(x, reduction_axes, keepdims)
- def _tf_reduce_mean(self, x, reduction_axes, keep_dims):
- return math_ops.reduce_mean(x, reduction_axes, keep_dims)
+ def _tf_reduce_mean(self, x, reduction_axes, keepdims):
+ return math_ops.reduce_mean(x, reduction_axes, keepdims)
- def _tf_reduce_sum(self, x, reduction_axes, keep_dims):
- return math_ops.reduce_sum(x, reduction_axes, keep_dims)
+ def _tf_reduce_sum(self, x, reduction_axes, keepdims):
+ return math_ops.reduce_sum(x, reduction_axes, keepdims)
def testFloat32Sum(self):
# make sure we test all possible kernel invocations
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py
index ec4fca78f0..6970bf9234 100644
--- a/tensorflow/python/layers/core.py
+++ b/tensorflow/python/layers/core.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import standard_ops
from tensorflow.python.util.tf_export import tf_export
@@ -291,13 +292,7 @@ class Dropout(base.Layer):
# shapes with dynamically sized inputs.
if self.noise_shape is None:
return self.noise_shape
-
- symbolic_shape = array_ops.shape(inputs)
- noise_shape = [
- symbolic_shape[axis] if shape is None else shape
- for axis, shape in enumerate(self.noise_shape)
- ]
- return noise_shape
+ return nn_ops._get_noise_shape(inputs, self.noise_shape)
def call(self, inputs, training=False):
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 323a9f8ee3..d83292b809 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -94,8 +94,8 @@ class BatchNormalization(base.Layer):
and should be neither too small (which would add noise) nor too large
(which would give stale estimates). Note that `momentum` is still applied
to get the means and variances for inference.
- fused: if `True`, use a faster, fused implementation if possible.
- If `None`, use the system recommended implementation.
+ fused: if `None` or `True`, use a faster, fused implementation if possible.
+ If `False`, use the system recommended implementation.
trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
virtual_batch_size: An `int`. By default, `virtual_batch_size` is `None`,
@@ -729,8 +729,8 @@ def batch_normalization(inputs,
and should be neither too small (which would add noise) nor too large
(which would give stale estimates). Note that `momentum` is still applied
to get the means and variances for inference.
- fused: if `True`, use a faster, fused implementation if possible.
- If `None`, use the system recommended implementation.
+ fused: if `None` or `True`, use a faster, fused implementation if possible.
+ If `False`, use the system recommended implementation.
virtual_batch_size: An `int`. By default, `virtual_batch_size` is `None`,
which means batch normalization is performed across the whole batch. When
`virtual_batch_size` is not `None`, instead perform "Ghost Batch
diff --git a/tensorflow/python/layers/utils.py b/tensorflow/python/layers/utils.py
index 1195284024..484c6fc466 100644
--- a/tensorflow/python/layers/utils.py
+++ b/tensorflow/python/layers/utils.py
@@ -179,73 +179,56 @@ def deconv_output_length(input_length, filter_size, padding, stride):
return input_length
-def smart_cond(pred, fn1, fn2, name=None):
- """Return either `fn1()` or `fn2()` based on the boolean predicate `pred`.
+def smart_cond(pred, true_fn=None, false_fn=None, name=None):
+ """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.
- If `pred` is a bool or has a constant value, we return either `fn1()`
- or `fn2()`, otherwise we use `tf.cond` to dynamically route to both.
+ If `pred` is a bool or has a constant value, we return either `true_fn()`
+ or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.
Arguments:
- pred: A scalar determining whether to return the result of `fn1` or `fn2`.
- fn1: The callable to be performed if pred is true.
- fn2: The callable to be performed if pred is false.
+ pred: A scalar determining whether to return the result of `true_fn` or
+ `false_fn`.
+ true_fn: The callable to be performed if pred is true.
+ false_fn: The callable to be performed if pred is false.
name: Optional name prefix when using `tf.cond`.
Returns:
- Tensors returned by the call to either `fn1` or `fn2`.
+ Tensors returned by the call to either `true_fn` or `false_fn`.
Raises:
- TypeError: If `fn1` or `fn2` is not callable.
+ TypeError: If `true_fn` or `false_fn` is not callable.
"""
- if not callable(fn1):
- raise TypeError('`fn1` must be callable.')
- if not callable(fn2):
- raise TypeError('`fn2` must be callable.')
-
- if context.in_eager_mode():
- if pred:
- return fn1()
- else:
- return fn2()
-
- pred_value = constant_value(pred)
- if pred_value is not None:
- if pred_value:
- return fn1()
- else:
- return fn2()
- else:
- return control_flow_ops.cond(pred, true_fn=fn1, false_fn=fn2, name=name)
+ if isinstance(pred, variables.Variable):
+ return control_flow_ops.cond(
+ pred, true_fn=true_fn, false_fn=false_fn, name=name)
+ return control_flow_ops.smart_cond(
+ pred, true_fn=true_fn, false_fn=false_fn, name=name)
def constant_value(pred):
"""Return the bool value for `pred`, or None if `pred` had a dynamic value.
- Arguments:
- pred: A scalar, either a Python bool or a TensorFlow boolean variable
- or tensor, or the Python integer 1 or 0.
+ Arguments:
+ pred: A scalar, either a Python bool or a TensorFlow boolean variable
+ or tensor, or the Python integer 1 or 0.
- Returns:
- True or False if `pred` has a constant boolean value, None otherwise.
+ Returns:
+ True or False if `pred` has a constant boolean value, None otherwise.
- Raises:
- TypeError: If `pred` is not a Variable, Tensor or bool.
- """
+ Raises:
+ TypeError: If `pred` is not a Variable, Tensor or bool, or Python
+ interger 1 or 0.
+ """
# Allow integer booleans.
- if pred == 0:
- pred = False
- elif pred == 1:
- pred = True
-
- if isinstance(pred, bool):
- pred_value = pred
- elif isinstance(pred, variables.Variable):
- pred_value = None
- elif isinstance(pred, ops.Tensor):
- pred_value = tensor_util.constant_value(pred)
- else:
- raise TypeError('`pred` must be a Tensor, a Variable, or a Python bool.')
- return pred_value
+ if isinstance(pred, int):
+ if pred == 1:
+ pred = True
+ elif pred == 0:
+ pred = False
+
+ if isinstance(pred, variables.Variable):
+ return None
+ return control_flow_ops.smart_constant_value(pred)
def object_list_uid(object_list):
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
index dd8c33247c..49f8c66531 100644
--- a/tensorflow/python/ops/clip_ops.py
+++ b/tensorflow/python/ops/clip_ops.py
@@ -110,7 +110,7 @@ def clip_by_norm(t, clip_norm, axes=None, name=None):
t = ops.convert_to_tensor(t, name="t")
# Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
- l2norm = math_ops.sqrt(math_ops.reduce_sum(t * t, axes, keep_dims=True))
+ l2norm = math_ops.sqrt(math_ops.reduce_sum(t * t, axes, keepdims=True))
intermediate = t * clip_norm
# Assert that the shape is compatible with the initial shape,
# to prevent unintentional broadcasting.
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index a2d605532a..b4bfc0fe47 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -23,6 +23,7 @@ See the @{$python/control_flow_ops} guide.
@@no_op
@@count_up_to
@@cond
+@@smart_cond
@@case
@@while_loop
@@logical_and
@@ -2129,6 +2130,61 @@ def cond(pred,
# pylint: enable=redefined-outer-name
+def smart_cond(pred, true_fn=None, false_fn=None, name=None):
+ """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.
+
+ If `pred` is a bool or has a constant value, we return either `true_fn()`
+ or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.
+
+ Arguments:
+ pred: A scalar determining whether to return the result of `true_fn` or
+ `false_fn`.
+ true_fn: The callable to be performed if pred is true.
+ false_fn: The callable to be performed if pred is false.
+ name: Optional name prefix when using `tf.cond`.
+
+ Returns:
+ Tensors returned by the call to either `true_fn` or `false_fn`.
+
+ Raises:
+ TypeError: If `true_fn` or `false_fn` is not callable.
+ """
+ if not callable(true_fn):
+ raise TypeError("`true_fn` must be callable.")
+ if not callable(false_fn):
+ raise TypeError("`false_fn` must be callable.")
+
+ pred_value = smart_constant_value(pred)
+ if pred_value is not None:
+ if pred_value:
+ return true_fn()
+ else:
+ return false_fn()
+ else:
+ return cond(pred, true_fn=true_fn, false_fn=false_fn, name=name)
+
+
+def smart_constant_value(pred):
+ """Return the bool value for `pred`, or None if `pred` had a dynamic value.
+
+ Arguments:
+ pred: A scalar, either a Python bool or tensor.
+
+ Returns:
+ True or False if `pred` has a constant boolean value, None otherwise.
+
+ Raises:
+ TypeError: If `pred` is not a Tensor or bool.
+ """
+ if isinstance(pred, bool):
+ pred_value = pred
+ elif isinstance(pred, ops.Tensor):
+ pred_value = tensor_util.constant_value(pred)
+ else:
+ raise TypeError("`pred` must be a Tensor or a Python bool.")
+ return pred_value
+
+
def _resource_safe_shape(t):
"""Returns the shape of t or the variable it points to."""
if t.dtype == dtypes.resource:
@@ -3126,6 +3182,43 @@ def while_loop(cond,
shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
```
+ Example which demonstrates non-strict semantics: In the following
+ example, the final value of the counter `i` does not depend on `x`. So
+ the `while_loop` can increment the counter parallel to updates of `x`.
+ However, because the loop counter at one loop iteration depends
+ on the value at the previous iteration, the loop counter itself cannot
+ be incremented in parallel. Hence if we just want the final value of the
+ counter (which we print on the line `print(sess.run(i))`), then
+ `x` will never be incremented, but the counter will be updated on a
+ single thread. Conversely, if we want the value of the output (which we
+ print on the line `print(sess.run(out).shape)`), then the counter may be
+ incremented on its own thread, while `x` can be incremented in
+ parallel on a separate thread. In the extreme case, it is conceivable
+ that the thread incrementing the counter runs until completion before
+ `x` is incremented even a single time. The only thing that can never
+ happen is that the thread updating `x` can never get ahead of the
+ counter thread because the thread incrementing `x` depends on the value
+ of the counter.
+ ```python
+ import tensorflow as tf
+
+ n = 10000
+ x = tf.constant(list(range(n)))
+ c = lambda i, x: i < n
+ b = lambda i, x: (tf.Print(i + 1, [i]), tf.Print(x + 1, [i], "x:"))
+ i, out = tf.while_loop(c, b, (0, x))
+ with tf.Session() as sess:
+ print(sess.run(i)) # prints [0] ... [9999]
+
+ # The following line may increment the counter and x in parallel.
+ # The counter thread may get ahead of the other thread, but not the
+ # other way around. So you may see things like
+ # [9996] x:[9987]
+ # meaning that the counter thread is on iteration 9996,
+ # while the other thread is on iteration 9987
+ print(sess.run(out).shape)
+ ```
+
"""
with ops.name_scope(name, "while", loop_vars):
if not loop_vars:
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index f22f3059d1..adc8c51e11 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -350,6 +350,42 @@ class SwitchTestCase(test_util.TensorFlowTestCase):
@test_util.with_c_api
+class SmartCondTest(test_util.TensorFlowTestCase):
+
+ def testSmartCondTrue(self):
+ with ops.Graph().as_default():
+ with session.Session():
+ x = constant_op.constant(2)
+ y = constant_op.constant(5)
+ z = control_flow_ops.smart_cond(True, lambda: math_ops.multiply(x, 16),
+ lambda: math_ops.multiply(y, 5))
+ self.assertEqual(z.eval(), 32)
+
+ def testSmartCondFalse(self):
+ with ops.Graph().as_default():
+ with session.Session():
+ x = constant_op.constant(4)
+ y = constant_op.constant(3)
+ z = control_flow_ops.smart_cond(False, lambda: math_ops.multiply(x, 16),
+ lambda: math_ops.multiply(y, 3))
+ self.assertEqual(z.eval(), 9)
+
+ def testSmartCondMissingArg1(self):
+ with ops.Graph().as_default():
+ with session.Session():
+ x = constant_op.constant(1)
+ with self.assertRaises(TypeError):
+ control_flow_ops.smart_cond(True, false_fn=lambda: x)
+
+ def testSmartCondMissingArg2(self):
+ with ops.Graph().as_default():
+ with session.Session():
+ x = constant_op.constant(1)
+ with self.assertRaises(TypeError):
+ control_flow_ops.smart_cond(True, lambda: x)
+
+
+@test_util.with_c_api
class CondTest(test_util.TensorFlowTestCase):
def testCondTrue(self):
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 95e45bff06..03ed537cfc 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -474,7 +474,7 @@ class QueueBase(object):
name: A name for the operation (optional).
Returns:
- The tuple of concatenated tensors that was dequeued.
+ The list of concatenated tensors that was dequeued.
"""
if name is None:
name = "%s_DequeueMany" % self._name
diff --git a/tensorflow/python/ops/distributions/multinomial.py b/tensorflow/python/ops/distributions/multinomial.py
index 26b5c5aef9..4ae67a009b 100644
--- a/tensorflow/python/ops/distributions/multinomial.py
+++ b/tensorflow/python/ops/distributions/multinomial.py
@@ -238,7 +238,7 @@ class Multinomial(distribution.Distribution):
n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
k = self.event_shape_tensor()[0]
- # boardcast the total_count and logits to same shape
+ # broadcast the total_count and logits to same shape
n_draws = array_ops.ones_like(
self.logits[..., 0], dtype=n_draws.dtype) * n_draws
logits = array_ops.ones_like(
diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py
index 5bc25128a8..0fe6aa30f9 100644
--- a/tensorflow/python/ops/distributions/util.py
+++ b/tensorflow/python/ops/distributions/util.py
@@ -1041,14 +1041,14 @@ def reduce_weighted_logsumexp(
with ops.name_scope(name, "reduce_weighted_logsumexp", [logx, w]):
logx = ops.convert_to_tensor(logx, name="logx")
if w is None:
- lswe = math_ops.reduce_logsumexp(logx, axis=axis, keep_dims=keep_dims)
+ lswe = math_ops.reduce_logsumexp(logx, axis=axis, keepdims=keep_dims)
if return_sign:
sgn = array_ops.ones_like(lswe)
return lswe, sgn
return lswe
w = ops.convert_to_tensor(w, dtype=logx.dtype, name="w")
log_absw_x = logx + math_ops.log(math_ops.abs(w))
- max_log_absw_x = math_ops.reduce_max(log_absw_x, axis=axis, keep_dims=True)
+ max_log_absw_x = math_ops.reduce_max(log_absw_x, axis=axis, keepdims=True)
# If the largest element is `-inf` or `inf` then we don't bother subtracting
# off the max. We do this because otherwise we'd get `inf - inf = NaN`. That
# this is ok follows from the fact that we're actually free to subtract any
@@ -1060,9 +1060,7 @@ def reduce_weighted_logsumexp(
wx_over_max_absw_x = (
math_ops.sign(w) * math_ops.exp(log_absw_x - max_log_absw_x))
sum_wx_over_max_absw_x = math_ops.reduce_sum(
- wx_over_max_absw_x,
- axis=axis,
- keep_dims=keep_dims)
+ wx_over_max_absw_x, axis=axis, keepdims=keep_dims)
if not keep_dims:
max_log_absw_x = array_ops.squeeze(max_log_absw_x, axis)
sgn = math_ops.sign(sum_wx_over_max_absw_x)
@@ -1180,8 +1178,7 @@ def process_quadrature_grid_and_probs(
grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype)
probs = ops.convert_to_tensor(probs, name="unnormalized_probs",
dtype=dtype)
- probs /= linalg_ops.norm(probs, ord=1, axis=-1, keep_dims=True,
- name="probs")
+ probs /= linalg_ops.norm(probs, ord=1, axis=-1, keepdims=True, name="probs")
def _static_event_size(x):
"""Returns the static size of a specific dimension or `None`."""
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index bcd9e5683a..53bd108c44 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -167,6 +167,28 @@ def _Assert3DImage(image):
_Check3DImage(image, require_static=False), image)
+def _AssertAtLeast3DImage(image):
+ """Assert that we are working with a properly shaped image.
+
+ Performs the check statically if possible (i.e. if the shape
+ is statically known). Otherwise adds a control dependency
+ to an assert op that checks the dynamic shape.
+
+ Args:
+ image: >= 3-D Tensor of size [*, height, width, depth]
+
+ Raises:
+ ValueError: if image.shape is not a [>= 3] vector.
+
+ Returns:
+ If the shape of `image` could be verified statically, `image` is
+ returned unchanged, otherwise there will be a control dependency
+ added that asserts the correct dynamic shape.
+ """
+ return control_flow_ops.with_dependencies(
+ _CheckAtLeast3DImage(image, require_static=False), image)
+
+
def _CheckAtLeast3DImage(image, require_static=True):
"""Assert that we are working with properly shaped image.
@@ -292,108 +314,187 @@ def random_flip_left_right(image, seed=None):
def flip_left_right(image):
"""Flip an image horizontally (left to right).
- Outputs the contents of `image` flipped along the second dimension, which is
- `width`.
+ Outputs the contents of `image` flipped along the width dimension.
See also `reverse()`.
Args:
- image: A 3-D tensor of shape `[height, width, channels].`
+ image: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
Returns:
- A 3-D tensor of the same type and shape as `image`.
+ A tensor of the same type and shape as `image`.
Raises:
ValueError: if the shape of `image` not supported.
"""
- with ops.name_scope(None, 'flip_left_right', [image]) as scope:
+ with ops.name_scope(None, 'flip_left_right', [image]):
image = ops.convert_to_tensor(image, name='image')
- image = _Assert3DImage(image)
- return fix_image_flip_shape(image, array_ops.reverse(
- image, [1], name=scope))
+ image = _AssertAtLeast3DImage(image)
+ shape = image.get_shape()
+ if shape.ndims == 3 or shape.ndims is None:
+ return fix_image_flip_shape(image, array_ops.reverse(image, [1]))
+ elif shape.ndims == 4:
+ return array_ops.reverse(image, [2])
+ else:
+ raise ValueError('\'image\' must have either 3 or 4 dimensions.')
@tf_export('image.flip_up_down')
def flip_up_down(image):
"""Flip an image vertically (upside down).
- Outputs the contents of `image` flipped along the first dimension, which is
- `height`.
+ Outputs the contents of `image` flipped along the height dimension.
See also `reverse()`.
Args:
- image: A 3-D tensor of shape `[height, width, channels].`
+ image: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
Returns:
- A 3-D tensor of the same type and shape as `image`.
+ A tensor of the same type and shape as `image`.
Raises:
ValueError: if the shape of `image` not supported.
"""
- with ops.name_scope(None, 'flip_up_down', [image]) as scope:
+ with ops.name_scope(None, 'flip_up_down', [image]):
image = ops.convert_to_tensor(image, name='image')
- image = _Assert3DImage(image)
- return fix_image_flip_shape(image, array_ops.reverse(
- image, [0], name=scope))
+ image = _AssertAtLeast3DImage(image)
+ shape = image.get_shape()
+ if shape.ndims == 3 or shape.ndims is None:
+ return fix_image_flip_shape(image, array_ops.reverse(image, [0]))
+ elif shape.ndims == 4:
+ return array_ops.reverse(image, [1])
+ else:
+ raise ValueError('\'image\' must have either 3 or 4 dimensions.')
@tf_export('image.rot90')
def rot90(image, k=1, name=None):
- """Rotate an image counter-clockwise by 90 degrees.
+ """Rotate image(s) counter-clockwise by 90 degrees.
Args:
- image: A 3-D tensor of shape `[height, width, channels]`.
+ image: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
k: A scalar integer. The number of times the image is rotated by 90 degrees.
name: A name for this operation (optional).
Returns:
- A rotated 3-D tensor of the same type and shape as `image`.
+ A rotated tensor of the same type and shape as `image`.
+
+ Raises:
+ ValueError: if the shape of `image` not supported.
"""
with ops.name_scope(name, 'rot90', [image, k]) as scope:
image = ops.convert_to_tensor(image, name='image')
- image = _Assert3DImage(image)
+ image = _AssertAtLeast3DImage(image)
k = ops.convert_to_tensor(k, dtype=dtypes.int32, name='k')
k.get_shape().assert_has_rank(0)
k = math_ops.mod(k, 4)
- def _rot90():
- return array_ops.transpose(array_ops.reverse_v2(image, [1]), [1, 0, 2])
+ shape = image.get_shape()
+ if shape.ndims == 3 or shape.ndims is None:
+ return _rot90_3D(image, k, scope)
+ elif shape.ndims == 4:
+ return _rot90_4D(image, k, scope)
+ else:
+ raise ValueError('\'image\' must have either 3 or 4 dimensions.')
+
+
+def _rot90_3D(image, k, name_scope):
+ """Rotate image counter-clockwise by 90 degrees `k` times.
+
+ Args:
+ image: 3-D Tensor of shape `[height, width, channels]`.
+ k: A scalar integer. The number of times the image is rotated by 90 degrees.
+ name_scope: A valid TensorFlow name scope.
+
+ Returns:
+ A 3-D tensor of the same type and shape as `image`.
+
+ """
+
+ def _rot90():
+ return array_ops.transpose(array_ops.reverse_v2(image, [1]), [1, 0, 2])
+
+ def _rot180():
+ return array_ops.reverse_v2(image, [0, 1])
+
+ def _rot270():
+ return array_ops.reverse_v2(array_ops.transpose(image, [1, 0, 2]), [1])
+
+ cases = [(math_ops.equal(k, 1), _rot90), (math_ops.equal(k, 2), _rot180),
+ (math_ops.equal(k, 3), _rot270)]
+
+ result = control_flow_ops.case(
+ cases, default=lambda: image, exclusive=True, name=name_scope)
+ result.set_shape([None, None, image.get_shape()[2]])
+ return result
+
+
+def _rot90_4D(images, k, name_scope):
+ """Rotate batch of images counter-clockwise by 90 degrees `k` times.
+
+ Args:
+ images: 4-D Tensor of shape `[height, width, channels]`.
+ k: A scalar integer. The number of times the images are rotated by 90
+ degrees.
+ name_scope: A valid TensorFlow name scope.
+
+ Returns:
+ A 4-D tensor of the same type and shape as `images`.
+
+ """
- def _rot180():
- return array_ops.reverse_v2(image, [0, 1])
+ def _rot90():
+ return array_ops.transpose(array_ops.reverse_v2(images, [2]), [0, 2, 1, 3])
- def _rot270():
- return array_ops.reverse_v2(array_ops.transpose(image, [1, 0, 2]), [1])
+ def _rot180():
+ return array_ops.reverse_v2(images, [1, 2])
- cases = [(math_ops.equal(k, 1), _rot90), (math_ops.equal(k, 2), _rot180),
- (math_ops.equal(k, 3), _rot270)]
+ def _rot270():
+ return array_ops.reverse_v2(array_ops.transpose(images, [0, 2, 1, 3]), [2])
- ret = control_flow_ops.case(
- cases, default=lambda: image, exclusive=True, name=scope)
- ret.set_shape([None, None, image.get_shape()[2]])
- return ret
+ cases = [(math_ops.equal(k, 1), _rot90), (math_ops.equal(k, 2), _rot180),
+ (math_ops.equal(k, 3), _rot270)]
+
+ result = control_flow_ops.case(
+ cases, default=lambda: images, exclusive=True, name=name_scope)
+ shape = result.get_shape()
+ result.set_shape([shape[0], None, None, shape[3]])
+ return result
@tf_export('image.transpose_image')
def transpose_image(image):
- """Transpose an image by swapping the first and second dimension.
+ """Transpose image(s) by swapping the height and width dimension.
See also `transpose()`.
Args:
- image: 3-D tensor of shape `[height, width, channels]`
+ image: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
Returns:
- A 3-D tensor of shape `[width, height, channels]`
+ If `image` was 4-D, a 4-D float Tensor of shape
+ `[batch, width, height, channels]`
+ If `image` was 3-D, a 3-D float Tensor of shape
+ `[width, height, channels]`
Raises:
ValueError: if the shape of `image` not supported.
"""
- with ops.name_scope(None, 'transpose_image', [image]) as scope:
+ with ops.name_scope(None, 'transpose_image', [image]):
image = ops.convert_to_tensor(image, name='image')
- image = _Assert3DImage(image)
- return array_ops.transpose(image, [1, 0, 2], name=scope)
+ image = _AssertAtLeast3DImage(image)
+ shape = image.get_shape()
+ if shape.ndims == 3 or shape.ndims is None:
+ return array_ops.transpose(image, [1, 0, 2], name='transpose_image')
+ elif shape.ndims == 4:
+ return array_ops.transpose(image, [0, 2, 1, 3], name='transpose_image')
+ else:
+ raise ValueError('\'image\' must have either 3 or 4 dimensions.')
@tf_export('image.central_crop')
@@ -1026,9 +1127,9 @@ def adjust_contrast(images, contrast_factor):
def adjust_gamma(image, gamma=1, gain=1):
"""Performs Gamma Correction on the input image.
- Also known as Power Law Transform. This function transforms the
- input image pixelwise according to the equation Out = In**gamma
- after scaling each pixel to the range 0 to 1.
+ Also known as Power Law Transform. This function transforms the
+ input image pixelwise according to the equation `Out = In**gamma`
+ after scaling each pixel to the range 0 to 1.
Args:
image : A Tensor.
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 18625293e0..b67e7cc558 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -934,7 +934,7 @@ class AdjustSaturationTest(test_util.TensorFlowTestCase):
class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
- def testIdempotentLeftRight(self):
+ def testInvolutionLeftRight(self):
x_np = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
with self.test_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
@@ -942,6 +942,16 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
y_tf = y.eval()
self.assertAllEqual(y_tf, x_np)
+ def testInvolutionLeftRightWithBatch(self):
+ x_np = np.array(
+ [[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]],
+ dtype=np.uint8).reshape([2, 2, 3, 1])
+ with self.test_session(use_gpu=True):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.flip_left_right(image_ops.flip_left_right(x_tf))
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, x_np)
+
def testLeftRight(self):
x_np = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
y_np = np.array([[3, 2, 1], [3, 2, 1]], dtype=np.uint8).reshape([2, 3, 1])
@@ -953,9 +963,24 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
y_tf = y.eval()
self.assertAllEqual(y_tf, y_np)
+ def testLeftRightWithBatch(self):
+ x_np = np.array(
+ [[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]],
+ dtype=np.uint8).reshape([2, 2, 3, 1])
+ y_np = np.array(
+ [[[3, 2, 1], [3, 2, 1]], [[3, 2, 1], [3, 2, 1]]],
+ dtype=np.uint8).reshape([2, 2, 3, 1])
+
+ with self.test_session(use_gpu=True):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.flip_left_right(x_tf)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, y_np)
+
def testRandomFlipLeftRight(self):
x_np = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
y_np = np.array([[3, 2, 1], [3, 2, 1]], dtype=np.uint8).reshape([2, 3, 1])
+ seed = 42
with self.test_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
@@ -964,7 +989,7 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
count_flipped = 0
count_unflipped = 0
- for _ in range(50):
+ for _ in range(100):
y_tf = y.eval()
if y_tf[0][0] == 1:
self.assertAllEqual(y_tf, x_np)
@@ -972,10 +997,15 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
else:
self.assertAllEqual(y_tf, y_np)
count_flipped += 1
- self.assertGreaterEqual(count_flipped, 1)
- self.assertGreaterEqual(count_unflipped, 1)
- def testIdempotentUpDown(self):
+ # 100 trials
+ # Mean: 50
+ # Std Dev: ~5
+ # Six Sigma: 50 - (5 * 6) = 20
+ self.assertGreaterEqual(count_flipped, 20)
+ self.assertGreaterEqual(count_unflipped, 20)
+
+ def testInvolutionUpDown(self):
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
with self.test_session(use_gpu=True):
@@ -984,6 +1014,17 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
y_tf = y.eval()
self.assertAllEqual(y_tf, x_np)
+ def testInvolutionUpDownWithBatch(self):
+ x_np = np.array(
+ [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
+ dtype=np.uint8).reshape([2, 2, 3, 1])
+
+ with self.test_session(use_gpu=True):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.flip_up_down(image_ops.flip_up_down(x_tf))
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, x_np)
+
def testUpDown(self):
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
y_np = np.array([[4, 5, 6], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
@@ -995,17 +1036,31 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
y_tf = y.eval()
self.assertAllEqual(y_tf, y_np)
+ def testUpDownWithBatch(self):
+ x_np = np.array(
+ [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
+ dtype=np.uint8).reshape([2, 2, 3, 1])
+ y_np = np.array(
+ [[[4, 5, 6], [1, 2, 3]], [[10, 11, 12], [7, 8, 9]]],
+ dtype=np.uint8).reshape([2, 2, 3, 1])
+
+ with self.test_session(use_gpu=True):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.flip_up_down(x_tf)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, y_np)
+
def testRandomFlipUpDown(self):
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
y_np = np.array([[4, 5, 6], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
with self.test_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
- y = image_ops.random_flip_up_down(x_tf)
+ y = image_ops.random_flip_up_down(x_tf, seed=42)
self.assertTrue(y.op.name.startswith("random_flip_up_down"))
count_flipped = 0
count_unflipped = 0
- for _ in range(50):
+ for _ in range(100):
y_tf = y.eval()
if y_tf[0][0] == 1:
self.assertAllEqual(y_tf, x_np)
@@ -1013,10 +1068,15 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
else:
self.assertAllEqual(y_tf, y_np)
count_flipped += 1
- self.assertGreaterEqual(count_flipped, 1)
- self.assertGreaterEqual(count_unflipped, 1)
- def testIdempotentTranspose(self):
+ # 100 trials
+ # Mean: 50
+ # Std Dev: ~5
+ # Six Sigma: 50 - (5 * 6) = 20
+ self.assertGreaterEqual(count_flipped, 20)
+ self.assertGreaterEqual(count_unflipped, 20)
+
+ def testInvolutionTranspose(self):
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
with self.test_session(use_gpu=True):
@@ -1025,6 +1085,17 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
y_tf = y.eval()
self.assertAllEqual(y_tf, x_np)
+ def testInvolutionTransposeWithBatch(self):
+ x_np = np.array(
+ [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
+ dtype=np.uint8).reshape([2, 2, 3, 1])
+
+ with self.test_session(use_gpu=True):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.transpose_image(image_ops.transpose_image(x_tf))
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, x_np)
+
def testTranspose(self):
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
y_np = np.array([[1, 4], [2, 5], [3, 6]], dtype=np.uint8).reshape([3, 2, 1])
@@ -1036,15 +1107,34 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
y_tf = y.eval()
self.assertAllEqual(y_tf, y_np)
+ def testTransposeWithBatch(self):
+ x_np = np.array(
+ [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
+ dtype=np.uint8).reshape([2, 2, 3, 1])
+
+ y_np = np.array(
+ [[[1, 4], [2, 5], [3, 6]], [[7, 10], [8, 11], [9, 12]]],
+ dtype=np.uint8).reshape([2, 3, 2, 1])
+
+ with self.test_session(use_gpu=True):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.transpose_image(x_tf)
+ y_tf = y.eval()
+ self.assertAllEqual(y_tf, y_np)
+
def testPartialShapes(self):
p_unknown_rank = array_ops.placeholder(dtypes.uint8)
- p_unknown_dims = array_ops.placeholder(
+ p_unknown_dims_3 = array_ops.placeholder(
dtypes.uint8, shape=[None, None, None])
+ p_unknown_dims_4 = array_ops.placeholder(
+ dtypes.uint8, shape=[None, None, None, None])
p_unknown_width = array_ops.placeholder(dtypes.uint8, shape=[64, None, 3])
-
+ p_unknown_batch = array_ops.placeholder(
+ dtypes.uint8, shape=[None, 64, 64, 3])
p_wrong_rank = array_ops.placeholder(dtypes.uint8, shape=[None, None])
p_zero_dim = array_ops.placeholder(dtypes.uint8, shape=[64, 0, 3])
+ #Ops that support 3D input
for op in [
image_ops.flip_left_right, image_ops.flip_up_down,
image_ops.random_flip_left_right, image_ops.random_flip_up_down,
@@ -1052,16 +1142,34 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
]:
transformed_unknown_rank = op(p_unknown_rank)
self.assertEqual(3, transformed_unknown_rank.get_shape().ndims)
- transformed_unknown_dims = op(p_unknown_dims)
- self.assertEqual(3, transformed_unknown_dims.get_shape().ndims)
+ transformed_unknown_dims_3 = op(p_unknown_dims_3)
+ self.assertEqual(3, transformed_unknown_dims_3.get_shape().ndims)
transformed_unknown_width = op(p_unknown_width)
self.assertEqual(3, transformed_unknown_width.get_shape().ndims)
- with self.assertRaisesRegexp(ValueError, "must be three-dimensional"):
- op(p_wrong_rank)
with self.assertRaisesRegexp(ValueError, "must be > 0"):
op(p_zero_dim)
+ #Ops that support 4D input
+ for op in [
+ image_ops.flip_left_right, image_ops.flip_up_down,
+ image_ops.transpose_image, image_ops.rot90
+ ]:
+ transformed_unknown_dims_4 = op(p_unknown_dims_4)
+ self.assertEqual(4, transformed_unknown_dims_4.get_shape().ndims)
+ transformed_unknown_batch = op(p_unknown_batch)
+ self.assertEqual(4, transformed_unknown_batch.get_shape().ndims)
+ with self.assertRaisesRegexp(ValueError,
+ "must be at least three-dimensional"):
+ op(p_wrong_rank)
+
+ for op in [
+ image_ops.random_flip_left_right,
+ image_ops.random_flip_up_down,
+ ]:
+ with self.assertRaisesRegexp(ValueError, "must be three-dimensional"):
+ op(p_wrong_rank)
+
def testRot90GroupOrder(self):
image = np.arange(24, dtype=np.uint8).reshape([2, 4, 3])
with self.test_session(use_gpu=True):
@@ -1070,6 +1178,14 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
rotated = image_ops.rot90(rotated)
self.assertAllEqual(image, rotated.eval())
+ def testRot90GroupOrderWithBatch(self):
+ image = np.arange(48, dtype=np.uint8).reshape([2, 2, 4, 3])
+ with self.test_session(use_gpu=True):
+ rotated = image
+ for _ in xrange(4):
+ rotated = image_ops.rot90(rotated)
+ self.assertAllEqual(image, rotated.eval())
+
def testRot90NumpyEquivalence(self):
image = np.arange(24, dtype=np.uint8).reshape([2, 4, 3])
with self.test_session(use_gpu=True):
@@ -1079,6 +1195,15 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
y_np = np.rot90(image, k=k)
self.assertAllEqual(y_np, y_tf.eval({k_placeholder: k}))
+ def testRot90NumpyEquivalenceWithBatch(self):
+ image = np.arange(48, dtype=np.uint8).reshape([2, 2, 4, 3])
+ with self.test_session(use_gpu=True):
+ k_placeholder = array_ops.placeholder(dtypes.int32, shape=[])
+ y_tf = image_ops.rot90(image, k_placeholder)
+ for k in xrange(4):
+ y_np = np.rot90(image, k=k, axes=(1, 2))
+ self.assertAllEqual(y_np, y_tf.eval({k_placeholder: k}))
+
class RandomFlipTest(test_util.TensorFlowTestCase):
@@ -3173,6 +3298,14 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
scores = constant_op.constant([0.9])
image_ops.non_max_suppression(boxes, scores, 3, 0.5)
+ # The boxes is of shape [num_boxes, 4], and the scores is
+ # of shape [num_boxes]. So an error will thrown.
+ with self.assertRaisesRegexp(ValueError,
+ "Dimensions must be equal, but are 1 and 2"):
+ boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]])
+ scores = constant_op.constant([0.9, 0.75])
+ selected_indices = image_ops.non_max_suppression(boxes, scores, 3, 0.5)
+
# The scores should be 1D of shape [num_boxes].
with self.assertRaisesRegexp(ValueError,
"Shape must be rank 1 but is rank 2"):
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index c86cc92321..a39417139e 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -156,8 +156,10 @@ def _num_present(losses, weights, per_batch=False):
present = weights_broadcast_ops.broadcast_weights(present, losses)
if per_batch:
return math_ops.reduce_sum(
- present, axis=math_ops.range(1, array_ops.rank(present)),
- keep_dims=True, name=scope)
+ present,
+ axis=math_ops.range(1, array_ops.rank(present)),
+ keepdims=True,
+ name=scope)
return math_ops.reduce_sum(present, name=scope)
@@ -324,7 +326,7 @@ def cosine_distance(
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
radial_diffs = math_ops.multiply(predictions, labels)
- losses = 1 - math_ops.reduce_sum(radial_diffs, axis=(axis,), keep_dims=True)
+ losses = 1 - math_ops.reduce_sum(radial_diffs, axis=(axis,), keepdims=True)
return compute_weighted_loss(
losses, weights, scope, loss_collection, reduction=reduction)
@@ -390,7 +392,7 @@ def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None,
`weights` acts as a coefficient for the loss. If a scalar is provided, then
the loss is simply scaled by the given value. If `weights` is a tensor of size
- [batch_size], then the total loss for each sample of the batch is rescaled
+ `[batch_size]`, then the total loss for each sample of the batch is rescaled
by the corresponding element in the `weights` vector. If the shape of
`weights` matches the shape of `predictions`, then the loss of each
measurable element of `predictions` is scaled by the corresponding value of
@@ -452,7 +454,7 @@ def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None,
`weights` acts as a coefficient for the loss. If a scalar is provided, then
the loss is simply scaled by the given value. If `weights` is a tensor of size
- [batch_size], then the total loss for each sample of the batch is rescaled
+ `[batch_size]`, then the total loss for each sample of the batch is rescaled
by the corresponding element in the `weights` vector. If the shape of
`weights` matches the shape of `predictions`, then the loss of each
measurable element of `predictions` is scaled by the corresponding value of
@@ -519,7 +521,7 @@ def mean_pairwise_squared_error(
`weights` acts as a coefficient for the loss. If a scalar is provided, then
the loss is simply scaled by the given value. If `weights` is a tensor of size
- [batch_size], then the total loss for each sample of the batch is rescaled
+ `[batch_size]`, then the total loss for each sample of the batch is rescaled
by the corresponding element in the `weights` vector.
Args:
@@ -559,15 +561,16 @@ def mean_pairwise_squared_error(
reduction_indices = math_ops.range(1, array_ops.rank(diffs))
sum_squares_diff_per_batch = math_ops.reduce_sum(
- math_ops.square(diffs), reduction_indices=reduction_indices,
- keep_dims=True)
+ math_ops.square(diffs),
+ reduction_indices=reduction_indices,
+ keepdims=True)
num_present_per_batch = _num_present(diffs, weights, per_batch=True)
term1 = 2.0 * _safe_div(sum_squares_diff_per_batch,
num_present_per_batch - 1)
sum_diff = math_ops.reduce_sum(
- diffs, reduction_indices=reduction_indices, keep_dims=True)
+ diffs, reduction_indices=reduction_indices, keepdims=True)
term2 = 2.0 * _safe_div(
math_ops.square(sum_diff),
math_ops.multiply(num_present_per_batch, num_present_per_batch - 1))
@@ -593,7 +596,7 @@ def mean_squared_error(
`weights` acts as a coefficient for the loss. If a scalar is provided, then
the loss is simply scaled by the given value. If `weights` is a tensor of size
- [batch_size], then the total loss for each sample of the batch is rescaled
+ `[batch_size]`, then the total loss for each sample of the batch is rescaled
by the corresponding element in the `weights` vector. If the shape of
`weights` matches the shape of `predictions`, then the loss of each
measurable element of `predictions` is scaled by the corresponding value of
@@ -812,7 +815,7 @@ def sparse_softmax_cross_entropy(
`weights` acts as a coefficient for the loss. If a scalar is provided,
then the loss is simply scaled by the given value. If `weights` is a
- tensor of shape [`batch_size`], then the loss weights apply to each
+ tensor of shape `[batch_size]`, then the loss weights apply to each
corresponding sample.
Args:
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index bd26ff6696..d314124ccd 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -105,7 +105,7 @@ class LogSumExpTest(test_util.TensorFlowTestCase):
for dtype in [np.float16, np.float32, np.double]:
x_np = np.random.rand(5, 5).astype(dtype)
with self.test_session(use_gpu=True):
- y_tf_np = math_ops.reduce_logsumexp(x_np, keep_dims=True).eval()
+ y_tf_np = math_ops.reduce_logsumexp(x_np, keepdims=True).eval()
self.assertEqual(y_tf_np.ndim, x_np.ndim)
y_np = log(np.sum(exp(x_np), keepdims=True))
self.assertAllClose(y_tf_np, y_np)
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 2a883eb0d5..dc24b821a5 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -863,27 +863,27 @@ def _BatchNormGrad(grad_y,
grad_y = math_ops.cast(grad_y, dtypes.float32)
if is_training:
if data_format == b"NHWC":
- keep_dims = False
+ keepdims = False
reduce_axis = [0, 1, 2]
else:
- keep_dims = True
+ keepdims = True
reduce_axis = [0, 2, 3]
shape = [1, array_ops.size(scale), 1, 1]
scale = array_ops.reshape(scale, shape)
- mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keep_dims=keep_dims)
- mean_x = math_ops.reduce_mean(x, reduce_axis, keep_dims=keep_dims)
+ mean_grad_y = math_ops.reduce_mean(grad_y, reduce_axis, keepdims=keepdims)
+ mean_x = math_ops.reduce_mean(x, reduce_axis, keepdims=keepdims)
var_x = math_ops.reduce_mean(
math_ops.squared_difference(x, array_ops.stop_gradient(mean_x)),
reduce_axis,
- keep_dims=keep_dims)
+ keepdims=keepdims)
grad_y_offset = grad_y - mean_grad_y
x_offset = x - mean_x
mean = math_ops.reduce_mean(
- grad_y * x_offset, axis=reduce_axis, keep_dims=keep_dims)
+ grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
grad_x = scale * math_ops.rsqrt(var_x + epsilon) * (
grad_y_offset - math_ops.reciprocal(var_x + epsilon) * mean * x_offset)
grad_scale = math_ops.rsqrt(var_x + epsilon) * math_ops.reduce_sum(
- grad_y * x_offset, axis=reduce_axis, keep_dims=keep_dims)
+ grad_y * x_offset, axis=reduce_axis, keepdims=keepdims)
if data_format == b"NCHW":
grad_scale = array_ops.squeeze(grad_scale)
grad_offset = math_ops.reduce_sum(grad_y, axis=reduce_axis)
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 47f48a7e16..8fbe698914 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -2215,6 +2215,31 @@ def xw_plus_b_v1(x, weights, biases, name=None): # pylint: disable=invalid-name
return bias_add_v1(mm, biases, name=name)
+def _get_noise_shape(x, noise_shape):
+ # If noise_shape is none return immediately.
+ if noise_shape is None:
+ return array_ops.shape(x)
+
+ try:
+ # Best effort to figure out the intended shape.
+ # If not possible, let the op to handle it.
+ # In eager mode exception will show up.
+ noise_shape_ = tensor_shape.as_shape(noise_shape)
+ except (TypeError, ValueError):
+ return noise_shape
+
+ if x.shape.dims is not None and len(x.shape.dims) == len(noise_shape_.dims):
+ new_dims = []
+ for i, dim in enumerate(x.shape.dims):
+ if noise_shape_.dims[i].value is None and dim.value is not None:
+ new_dims.append(dim.value)
+ else:
+ new_dims.append(noise_shape_.dims[i].value)
+ return tensor_shape.TensorShape(new_dims)
+
+ return noise_shape
+
+
@tf_export("nn.dropout")
def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: disable=invalid-name
"""Computes dropout.
@@ -2265,7 +2290,8 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: di
if tensor_util.constant_value(keep_prob) == 1:
return x
- noise_shape = noise_shape if noise_shape is not None else array_ops.shape(x)
+ noise_shape = _get_noise_shape(x, noise_shape)
+
# uniform [keep_prob, 1.0 + keep_prob)
random_tensor = keep_prob
random_tensor += random_ops.random_uniform(
@@ -2380,7 +2406,7 @@ def conv1d(value,
Args:
value: A 3D `Tensor`. Must be of type `float16` or `float32`.
- filters: A 3D `Tensor`. Must have the same type as `input`.
+ filters: A 3D `Tensor`. Must have the same type as `value`.
stride: An `integer`. The number of entries by which
the filter is moved right at each step.
padding: 'SAME' or 'VALID'
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 5a45bdc1e5..21eea3db25 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -383,6 +383,31 @@ class DropoutTest(test_lib.TestCase):
x, keep_prob, noise_shape=array_ops.placeholder(dtypes.int32))
self.assertEqual(x.get_shape(), dropout_x.get_shape())
+ def testPartialShapedDropout(self):
+ x_dim = 40 * 30
+ y_dim = 3
+ num_iter = 10
+ for keep_prob in [0.1, 0.5, 0.8]:
+ with self.test_session():
+ t = constant_op.constant(
+ 1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
+ # Set noise_shape=[None, 1] which means [x_dim, 1].
+ dropout = nn_ops.dropout(t, keep_prob, noise_shape=[None, 1])
+ self.assertEqual([x_dim, y_dim], dropout.get_shape())
+ final_count = 0
+ for _ in xrange(0, num_iter):
+ value = dropout.eval()
+ final_count += np.count_nonzero(value)
+ # Verifies that there are only two values: 0 and 1/keep_prob.
+ sorted_value = np.unique(np.sort(value))
+ self.assertEqual(0, sorted_value[0])
+ self.assertAllClose(1 / keep_prob, sorted_value[1])
+ # Check that we are in the 15% error range
+ expected_count = x_dim * y_dim * keep_prob * num_iter
+ rel_error = math.fabs(final_count - expected_count) / expected_count
+ print(rel_error)
+ self.assertTrue(rel_error < 0.15)
+
def testInvalidKeepProb(self):
x_dim = 40
y_dim = 30
diff --git a/tensorflow/python/profiler/option_builder.py b/tensorflow/python/profiler/option_builder.py
index 957ebe6ddd..2ad7adf769 100644
--- a/tensorflow/python/profiler/option_builder.py
+++ b/tensorflow/python/profiler/option_builder.py
@@ -300,7 +300,7 @@ class ProfileOptionBuilder(object):
# pylint: disable=line-too-long
"""Only show profiler nodes consuming no less than 'min_float_ops'.
- Please see https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/profilerg3doc/profile_model_architecture.md
+ Please see https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/profiler/g3doc/profile_model_architecture.md
on the caveats of calculating float operations.
Args:
diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py
index 074b8e7132..a52f325ddb 100644
--- a/tensorflow/python/tools/freeze_graph.py
+++ b/tensorflow/python/tools/freeze_graph.py
@@ -109,7 +109,7 @@ def freeze_graph_with_def_protos(input_graph_def,
input_meta_graph_def, clear_devices=True)
restorer.restore(sess, input_checkpoint)
if initializer_nodes:
- sess.run(initializer_nodes.split(","))
+ sess.run(initializer_nodes.replace(" ", "").split(","))
elif input_saved_model_dir:
if saved_model_tags is None:
saved_model_tags = []
@@ -130,25 +130,27 @@ def freeze_graph_with_def_protos(input_graph_def,
var_list=var_list, write_version=checkpoint_version)
saver.restore(sess, input_checkpoint)
if initializer_nodes:
- sess.run(initializer_nodes.split(","))
+ sess.run(initializer_nodes.replace(" ", "").split(","))
- variable_names_whitelist = (variable_names_whitelist.split(",")
- if variable_names_whitelist else None)
- variable_names_blacklist = (variable_names_blacklist.split(",")
- if variable_names_blacklist else None)
+ variable_names_whitelist = (
+ variable_names_whitelist.replace(" ", "").split(",")
+ if variable_names_whitelist else None)
+ variable_names_blacklist = (
+ variable_names_blacklist.replace(" ", "").split(",")
+ if variable_names_blacklist else None)
if input_meta_graph_def:
output_graph_def = graph_util.convert_variables_to_constants(
sess,
input_meta_graph_def.graph_def,
- output_node_names.split(","),
+ output_node_names.replace(" ", "").split(","),
variable_names_whitelist=variable_names_whitelist,
variable_names_blacklist=variable_names_blacklist)
else:
output_graph_def = graph_util.convert_variables_to_constants(
sess,
input_graph_def,
- output_node_names.split(","),
+ output_node_names.replace(" ", "").split(","),
variable_names_whitelist=variable_names_whitelist,
variable_names_blacklist=variable_names_blacklist)
@@ -250,7 +252,7 @@ def freeze_graph(input_graph,
variable_names_blacklist,
input_meta_graph_def,
input_saved_model_dir,
- saved_model_tags.split(","),
+ saved_model_tags.replace(" ", "").split(","),
checkpoint_version=checkpoint_version)
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 0c1c8e664b..3888e9bba4 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -1597,9 +1597,9 @@ class Saver(object):
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
Returns:
- A string: path prefix used for the checkpoint files. If checkpoint
- format is V1 and the saver is sharded, this string ends with:
- '-?????-of-nnnnn' where 'nnnnn' is the number of shards created.
+ A string: path prefix used for the checkpoint files. If the saver is
+ sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
+ is the number of shards created.
If the saver is empty, returns None.
Raises:
@@ -1749,12 +1749,6 @@ class Saver(object):
return
if save_path is None:
raise ValueError("Can't load save_path when it is None.")
- if (os.path.isfile(save_path) and
- self._write_version not in (
- saver_pb2.SaverDef.V1, saver_pb2.SaverDef.LEGACY)):
- raise ValueError("The specified path: %s is a file."
- " Please specify only the path prefix"
- " to the checkpoint files." % save_path)
logging.info("Restoring parameters from %s", save_path)
if context.in_graph_mode():
sess.run(self.saver_def.restore_op_name,
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 82142fa21d..818d67f7b5 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -618,7 +618,7 @@ def tf_cc_test(name,
srcs=srcs + tf_binary_additional_srcs(),
copts=tf_copts() + extra_copts,
linkopts=select({
- "//tensorflow:android": [
+ clean_dep("//tensorflow:android"): [
"-pie",
],
clean_dep("//tensorflow:windows"): [],
@@ -1312,6 +1312,46 @@ def tf_extension_linkopts():
def tf_extension_copts():
return [] # No extension c opts
+# In tf_py_wrap_cc generated libraries
+# module init functions are not exported unless
+# they contain one of the keywords in the version file
+# this prevents custom python modules.
+# This function attempts to append init_module_name to list of
+# exported functions in version script
+def _append_init_to_versionscript_impl(ctx):
+ mod_name = ctx.attr.module_name
+ if ctx.attr.is_version_script:
+ ctx.actions.expand_template(
+ template=ctx.file.template_file,
+ output=ctx.outputs.versionscript,
+ substitutions={
+ "global:":"global:\n init_%s;\n PyInit_*;"%(mod_name),
+ },
+ is_executable=False,
+ )
+ else:
+ ctx.actions.expand_template(
+ template=ctx.file.template_file,
+ output=ctx.outputs.versionscript,
+ substitutions={
+ "*tensorflow*":"*tensorflow*\ninit_%s\nPyInit_*\n"%(mod_name),
+ },
+ is_executable=False,
+ )
+
+
+_append_init_to_versionscript= rule(
+ implementation=_append_init_to_versionscript_impl,
+ attrs={
+ "module_name":attr.string(mandatory=True),
+ "template_file":attr.label(allow_files=True,single_file=True,mandatory=True),
+ "is_version_script":attr.bool(default=True,
+ doc='whether target is a ld version script or exported symbol list',
+ mandatory=False),
+ },
+ outputs={"versionscript":"%{name}.lds"},
+)
+
def tf_py_wrap_cc(name,
srcs,
swig_includes=[],
@@ -1333,26 +1373,39 @@ def tf_py_wrap_cc(name,
toolchain_deps=["//tools/defaults:crosstool"],
module_name=module_name,
py_module_name=name)
+ vscriptname=name+"_versionscript"
+ _append_init_to_versionscript(
+ name=vscriptname,
+ module_name=module_name,
+ is_version_script=select({
+ "@local_config_cuda//cuda:darwin":False,
+ "//conditions:default":True,
+ }),
+ template_file=select({
+ "@local_config_cuda//cuda:darwin":clean_dep("//tensorflow:tf_exported_symbols.lds"),
+ "//conditions:default":clean_dep("//tensorflow:tf_version_script.lds")
+ })
+ )
extra_linkopts = select({
"@local_config_cuda//cuda:darwin": [
"-Wl,-exported_symbols_list",
- clean_dep("//tensorflow:tf_exported_symbols.lds")
+ "%s.lds"%vscriptname,
],
clean_dep("//tensorflow:windows"): [],
clean_dep("//tensorflow:windows_msvc"): [],
"//conditions:default": [
"-Wl,--version-script",
- clean_dep("//tensorflow:tf_version_script.lds")
+ "%s.lds"%vscriptname,
]
})
extra_deps += select({
"@local_config_cuda//cuda:darwin": [
- clean_dep("//tensorflow:tf_exported_symbols.lds")
+ "%s.lds"%vscriptname,
],
clean_dep("//tensorflow:windows"): [],
clean_dep("//tensorflow:windows_msvc"): [],
"//conditions:default": [
- clean_dep("//tensorflow:tf_version_script.lds")
+ "%s.lds"%vscriptname,
]
})
diff --git a/tensorflow/tools/ci_build/install/install_bazel.sh b/tensorflow/tools/ci_build/install/install_bazel.sh
index cf8737c2d8..1df6a84d7c 100755
--- a/tensorflow/tools/ci_build/install/install_bazel.sh
+++ b/tensorflow/tools/ci_build/install/install_bazel.sh
@@ -15,7 +15,7 @@
# ==============================================================================
# Select bazel version.
-BAZEL_VERSION="0.8.0"
+BAZEL_VERSION="0.10.0"
set +e
local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}')
diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD
index 8601b3d0f1..ad3668fa02 100644
--- a/tensorflow/tools/graph_transforms/BUILD
+++ b/tensorflow/tools/graph_transforms/BUILD
@@ -103,6 +103,7 @@ cc_library(
"quantize_nodes.cc",
"quantize_weights.cc",
"remove_attribute.cc",
+ "remove_control_dependencies.cc",
"remove_device.cc",
"remove_ema.cc",
"remove_nodes.cc",
diff --git a/tensorflow/tools/graph_transforms/README.md b/tensorflow/tools/graph_transforms/README.md
index 345d9eadb8..67badb4869 100644
--- a/tensorflow/tools/graph_transforms/README.md
+++ b/tensorflow/tools/graph_transforms/README.md
@@ -639,6 +639,13 @@ specified devices may not be available. In order to work with graphs like these,
you can run this transform to wipe the slate clean and delete the device
specifier from all ops.
+### remove_control_dependencies
+
+Args: None \
+Prerequisites: None
+
+Removes all control dependencies from the graph.
+
### remove_nodes
Args:
diff --git a/tensorflow/tools/graph_transforms/remove_control_dependencies.cc b/tensorflow/tools/graph_transforms/remove_control_dependencies.cc
new file mode 100644
index 0000000000..a900ee65b0
--- /dev/null
+++ b/tensorflow/tools/graph_transforms/remove_control_dependencies.cc
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace graph_transforms {
+
+// Remove control depdencies in preparation for inference.
+// In the tensorflow graph, control dependencies are represented as extra
+// inputs which are referenced with "^tensor_name".
+// See node_def.proto for more details.
+Status RemoveControlDependencies(const GraphDef& input_graph_def,
+ const TransformFuncContext& context,
+ GraphDef* output_graph_def) {
+ output_graph_def->Clear();
+ for (const NodeDef& node : input_graph_def.node()) {
+ NodeDef* new_node = output_graph_def->mutable_node()->Add();
+ *new_node = node;
+ new_node->clear_input();
+ for (const auto& input : node.input()) {
+ if (input[0] != '^') {
+ new_node->add_input(input);
+ }
+ }
+ }
+ return Status::OK();
+}
+
+REGISTER_GRAPH_TRANSFORM("remove_control_dependencies",
+ RemoveControlDependencies);
+
+} // namespace graph_transforms
+} // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/remove_nodes.cc b/tensorflow/tools/graph_transforms/remove_nodes.cc
index 119b44d6a4..05f036a86a 100644
--- a/tensorflow/tools/graph_transforms/remove_nodes.cc
+++ b/tensorflow/tools/graph_transforms/remove_nodes.cc
@@ -81,7 +81,17 @@ Status RemoveNodes(const GraphDef& input_graph_def,
return Status::OK();
}
const NodeDef& input_node = match.inputs[0].node;
- inputs_to_rename[replace_node.name()] = input_node.name();
+ string target_name = input_node.name();
+ for (const string& input : replace_node.input()) {
+ if (!input.compare(0, target_name.size(), target_name)) {
+ if (input.size() == target_name.size() ||
+ input[target_name.size()] == ':') {
+ target_name = input;
+ break;
+ }
+ }
+ }
+ inputs_to_rename[replace_node.name()] = target_name;
inputs_to_rename["^" + replace_node.name()] =
"^" + input_node.name();
new_nodes->push_back(input_node);
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 791016e8b7..fb6eaa4faa 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -11,6 +11,7 @@ load(
)
load("//third_party/mkl:build_defs.bzl", "if_mkl")
load("//tensorflow:tensorflow.bzl", "if_cuda")
+load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt")
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps")
# This returns a list of headers of all public header libraries (e.g.,
@@ -191,7 +192,9 @@ sh_binary(
"//tensorflow/python:test_ops",
"//tensorflow/tools/dist_test/server:grpc_tensorflow_server",
],
- }) + if_mkl(["//third_party/mkl:intel_binary_blob"]),
+ }) + if_mkl(["//third_party/mkl:intel_binary_blob"]) + if_tensorrt([
+ "//tensorflow/contrib/tensorrt:init_py",
+ ]),
)
# A genrule for generating a marker file for the pip package on Windows
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 0e6b32bb49..4b6f123daa 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -29,17 +29,17 @@ from setuptools.dist import Distribution
# This version string is semver compatible, but incompatible with pip.
# For pip, we will remove all '-' characters from this string, and use the
# result for pip.
-_VERSION = '1.6.0-rc0'
+_VERSION = '1.6.0-rc1'
REQUIRED_PACKAGES = [
'absl-py >= 0.1.6',
'astor >= 0.6.0',
'gast >= 0.2.0',
'grpcio >= 1.8.6',
- 'numpy >= 1.12.1',
+ 'numpy >= 1.13.3',
'six >= 1.10.0',
'protobuf >= 3.4.0',
- 'tensorflow-tensorboard >= 1.5.0, < 1.6.0',
+ 'tensorboard >= 1.6.0, < 1.7.0',
'termcolor >= 1.1.0',
]
@@ -62,7 +62,7 @@ else:
if 'tf_nightly' in project_name:
for i, pkg in enumerate(REQUIRED_PACKAGES):
if 'tensorboard' in pkg:
- REQUIRED_PACKAGES[i] = 'tb-nightly >= 1.5.0a0, < 1.6.0a0'
+ REQUIRED_PACKAGES[i] = 'tb-nightly >= 1.7.0a0, < 1.8.0a0'
break
# weakref.finalize and enum were introduced in Python 3.4
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index 255ae01190..b7c47a19dd 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -367,11 +367,20 @@ def find_cuda_define(repository_ctx, header_dir, header_file, define):
if result.stdout.find(define) == -1:
auto_configure_fail("Cannot find line containing '%s' in %s" %
(define, h_path))
- version = result.stdout
- # Remove the new line and '\' character if any.
- version = version.replace("\\", " ")
- version = version.replace("\n", " ")
- version = version.replace(define, "").lstrip()
+ # Split results to lines
+ lines = result.stdout.split('\n')
+ num_lines = len(lines)
+ for l in range(num_lines):
+ line = lines[l]
+ if define in line: # Find the line with define
+ version = line
+ if l != num_lines-1 and line[-1] == '\\': # Add next line, if multiline
+ version = version[:-1] + lines[l+1]
+ break
+ # Remove any comments
+ version = version.split("//")[0]
+ # Remove define name
+ version = version.replace(define, "").strip()
# Remove the code after the version number.
version_end = version.find(" ")
if version_end != -1:
diff --git a/third_party/tensorrt/BUILD.tpl b/third_party/tensorrt/BUILD.tpl
index feaeb0bea6..57682e8735 100644
--- a/third_party/tensorrt/BUILD.tpl
+++ b/third_party/tensorrt/BUILD.tpl
@@ -3,6 +3,8 @@
licenses(["notice"])
+exports_files(["LICENSE"])
+
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts")
package(default_visibility = ["//visibility:public"])
@@ -32,36 +34,6 @@ cc_library(
visibility = ["//visibility:public"],
)
-cc_library(
- name = "nv_infer_plugin",
- srcs = [%{nv_infer_plugin}],
- data = [%{nv_infer_plugin}],
- includes = [
- "include",
- ],
- copts= cuda_default_copts(),
- deps = [
- "@local_config_cuda//cuda:cuda",
- ":nv_infer",
- ":tensorrt_headers",
- ],
- linkstatic = 1,
- visibility = ["//visibility:public"],
-)
-
-cc_library(
- name = "nv_parsers",
- srcs = [%{nv_parsers}],
- data = [%{nv_parsers}],
- includes = [
- "include",
- ],
- copts= cuda_default_copts(),
- deps = [
- ":tensorrt_headers",
- ],
- linkstatic = 1,
- visibility = ["//visibility:public"],
-)
%{tensorrt_genrules}
+
diff --git a/third_party/tensorrt/LICENSE b/third_party/tensorrt/LICENSE
new file mode 100644
index 0000000000..146d9b765c
--- /dev/null
+++ b/third_party/tensorrt/LICENSE
@@ -0,0 +1,203 @@
+Copyright 2018 The TensorFlow Authors. All rights reserved.
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2018, The TensorFlow Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/third_party/tensorrt/tensorrt_configure.bzl b/third_party/tensorrt/tensorrt_configure.bzl
index 8aa0f28f39..8e76e5d02a 100644
--- a/third_party/tensorrt/tensorrt_configure.bzl
+++ b/third_party/tensorrt/tensorrt_configure.bzl
@@ -19,11 +19,8 @@ load(
_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
_TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"
-_TF_TENSORRT_LIBS = ["nvinfer", "nvinfer_plugin", "nvparsers"]
-_TF_TENSORRT_HEADERS = [
- "NvInfer.h", "NvInferPlugin.h", "NvCaffeParser.h", "NvUffParser.h",
- "NvUtils.h"
-]
+_TF_TENSORRT_LIBS = ["nvinfer"]
+_TF_TENSORRT_HEADERS = ["NvInfer.h", "NvUtils.h"]
_DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"