author     Qianli Scott Zhu <scottzhu@google.com>   2018-04-09 21:25:15 -0700
committer  GitHub <noreply@github.com>              2018-04-09 21:25:15 -0700
commit     3e0fd55ccec1f8ac5ca7d11f9999a16871a9198c (patch)
tree       501b3932d36997bbe0f4c5b6e4d8bbb6047cf852
parent     2770cf7805b7cffda830e6e09e1230cf94173937 (diff)
parent     7f2d9ad2dae2bd653820a8b4e3191d24ef5f1f12 (diff)
Merge pull request #18366 from qlzh727/branch_192210794
Branch 192210794
-rw-r--r--  README.md | 2
-rw-r--r--  tensorflow/BUILD | 30
-rw-r--r--  tensorflow/__init__.py | 7
-rw-r--r--  tensorflow/compiler/tests/BUILD | 5
-rw-r--r--  tensorflow/compiler/tests/build_defs.bzl | 2
-rw-r--r--  tensorflow/compiler/tests/spacetobatch_op_test.py | 3
-rw-r--r--  tensorflow/compiler/tests/unary_ops_test.py | 3
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_control_flow.cc | 8
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla_test.cc | 5
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla_util.cc | 11
-rw-r--r--  tensorflow/compiler/tf2xla/xla_helpers.cc | 3
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/xla_builder.cc | 90
-rw-r--r--  tensorflow/compiler/xla/literal_util.cc | 120
-rw-r--r--  tensorflow/compiler/xla/literal_util.h | 52
-rw-r--r--  tensorflow/compiler/xla/literal_util_test.cc | 62
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc | 1
-rw-r--r--  tensorflow/compiler/xla/service/dfs_hlo_visitor.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h | 3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc | 30
-rw-r--r--  tensorflow/compiler/xla/service/hlo_constant_folding.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.cc | 51
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc | 19
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_opcode.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc | 27
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/instruction_fusion.cc | 7
-rw-r--r--  tensorflow/compiler/xla/shape_util.h | 103
-rw-r--r--  tensorflow/compiler/xla/shape_util_test.cc | 18
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD | 8
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.cc | 19
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.h | 9
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_test_base.cc | 8
-rw-r--r--  tensorflow/compiler/xla/tests/pad_test.cc | 54
-rw-r--r--  tensorflow/compiler/xla/tests/reshape_test.cc | 162
-rw-r--r--  tensorflow/compiler/xla/tests/xla_internal_test_main.cc | 8
-rw-r--r--  tensorflow/compiler/xla/tools/parser/hlo_parser.cc | 9
-rw-r--r--  tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc | 12
-rw-r--r--  tensorflow/contrib/BUILD | 1
-rw-r--r--  tensorflow/contrib/__init__.py | 1
-rw-r--r--  tensorflow/contrib/autograph/operators/control_flow.py | 19
-rw-r--r--  tensorflow/contrib/autograph/pyct/inspect_utils.py | 12
-rw-r--r--  tensorflow/contrib/autograph/pyct/inspect_utils_test.py | 24
-rw-r--r--  tensorflow/contrib/cmake/python_modules.txt | 12
-rw-r--r--  tensorflow/contrib/cmake/tf_core_ops.cmake | 3
-rwxr-xr-x  tensorflow/contrib/cmake/tf_python.cmake | 94
-rw-r--r--  tensorflow/contrib/eager/python/BUILD | 3
-rw-r--r--  tensorflow/contrib/eager/python/checkpointable_utils_test.py | 12
-rw-r--r--  tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py | 57
-rw-r--r--  tensorflow/contrib/estimator/BUILD | 55
-rw-r--r--  tensorflow/contrib/estimator/__init__.py | 1
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/multi_head.py | 2
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/multi_head_test.py | 16
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/rnn.py | 481
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/rnn_test.py | 1131
-rw-r--r--  tensorflow/contrib/graph_editor/select.py | 26
-rw-r--r--  tensorflow/contrib/graph_editor/tests/select_test.py | 155
-rw-r--r--  tensorflow/contrib/layers/python/layers/rev_block_lib.py | 5
-rw-r--r--  tensorflow/contrib/layers/python/layers/rev_block_lib_test.py | 14
-rw-r--r--  tensorflow/contrib/lite/build_def.bzl | 3
-rw-r--r--  tensorflow/contrib/lite/interpreter.h | 2
-rw-r--r--  tensorflow/contrib/lite/java/BUILD | 1
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h | 293
-rw-r--r--  tensorflow/contrib/lite/testing/generate_examples.py | 106
-rw-r--r--  tensorflow/contrib/lite/toco/export_tensorflow.cc | 12
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 1
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc | 4
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.cc | 14
-rw-r--r--  tensorflow/contrib/lite/toco/model.h | 12
-rw-r--r--  tensorflow/contrib/lite/toco/python/toco_from_protos_test.py | 1
-rw-r--r--  tensorflow/contrib/lite/toco/python/toco_python_api.cc | 3
-rw-r--r--  tensorflow/contrib/lite/toco/tooling_util.cc | 1
-rw-r--r--  tensorflow/contrib/makefile/tf_op_files.txt | 3
-rw-r--r--  tensorflow/contrib/nccl/kernels/nccl_manager.cc | 2
-rw-r--r--  tensorflow/contrib/nccl/kernels/nccl_manager_test.cc | 214
-rw-r--r--  tensorflow/contrib/nccl/ops/nccl_ops.cc | 14
-rw-r--r--  tensorflow/contrib/nccl/python/ops/nccl_ops_test.py | 2
-rw-r--r--  tensorflow/contrib/proto/BUILD | 16
-rw-r--r--  tensorflow/contrib/proto/__init__.py | 28
-rw-r--r--  tensorflow/contrib/proto/python/ops/BUILD | 44
-rw-r--r--  tensorflow/contrib/proto/python/ops/decode_proto_op.py | 25
-rw-r--r--  tensorflow/contrib/proto/python/ops/encode_proto_op.py | 25
-rw-r--r--  tensorflow/contrib/quantize/README.md | 21
-rw-r--r--  tensorflow/contrib/quantize/python/quant_ops.py | 10
-rw-r--r--  tensorflow/contrib/quantize/python/quantize.py | 70
-rw-r--r--  tensorflow/contrib/quantize/python/quantize_graph.py | 26
-rw-r--r--  tensorflow/contrib/quantize/python/quantize_graph_test.py | 110
-rw-r--r--  tensorflow/contrib/quantize/python/quantize_test.py | 30
-rw-r--r--  tensorflow/contrib/recurrent/BUILD | 106
-rw-r--r--  tensorflow/contrib/recurrent/README.md | 13
-rw-r--r--  tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py | 163
-rw-r--r--  tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py | 192
-rw-r--r--  tensorflow/contrib/recurrent/python/ops/functional_rnn.py | 396
-rw-r--r--  tensorflow/contrib/recurrent/python/ops/recurrent.py | 720
-rw-r--r--  tensorflow/contrib/recurrent/python/recurrent_api.py (renamed from tensorflow/experimental_api.py) | 23
-rw-r--r--  tensorflow/contrib/rpc/BUILD | 13
-rw-r--r--  tensorflow/contrib/rpc/__init__.py | 28
-rw-r--r--  tensorflow/contrib/rpc/python/ops/BUILD | 24
-rw-r--r--  tensorflow/contrib/rpc/python/ops/rpc_op.py | 26
-rw-r--r--  tensorflow/contrib/tpu/BUILD | 1
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu.py | 211
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 34
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_test.py | 2
-rw-r--r--  tensorflow/contrib/training/python/training/hparam.py | 3
-rw-r--r--  tensorflow/core/BUILD | 158
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_CollectiveBcastRecv.pbtxt | 5
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_CollectiveBcastSend.pbtxt | 5
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_CollectiveReduce.pbtxt | 5
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt | 116
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_DeepCopy.pbtxt | 15
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_Empty.pbtxt | 23
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_EncodeProto.pbtxt | 81
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_InplaceAdd.pbtxt | 28
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_InplaceSub.pbtxt | 28
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_InplaceUpdate.pbtxt | 28
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_Rpc.pbtxt | 108
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_TryRpc.pbtxt | 123
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_DeepCopy.pbtxt | 4
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_Empty.pbtxt | 4
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_InplaceAdd.pbtxt | 4
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_InplaceSub.pbtxt | 4
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_InplaceUpdate.pbtxt | 4
-rw-r--r--  tensorflow/core/api_def/python_api/api_def_SlideDataset.pbtxt | 4
-rw-r--r--  tensorflow/core/common_runtime/bfc_allocator.h | 2
-rw-r--r--  tensorflow/core/common_runtime/direct_session_test.cc | 210
-rw-r--r--  tensorflow/core/common_runtime/eigen_thread_pool.h | 2
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h | 2
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc | 33
-rw-r--r--  tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h | 2
-rw-r--r--  tensorflow/core/common_runtime/gpu/pool_allocator.h | 2
-rw-r--r--  tensorflow/core/common_runtime/graph_execution_state.cc | 105
-rw-r--r--  tensorflow/core/common_runtime/graph_runner.cc | 21
-rw-r--r--  tensorflow/core/common_runtime/mkl_cpu_allocator.h | 2
-rw-r--r--  tensorflow/core/common_runtime/shape_refiner.cc | 26
-rw-r--r--  tensorflow/core/common_runtime/single_threaded_cpu_device.h | 82
-rw-r--r--  tensorflow/core/common_runtime/visitable_allocator.h (renamed from tensorflow/core/framework/visitable_allocator.h) | 6
-rw-r--r--  tensorflow/core/distributed_runtime/local_master.cc | 41
-rw-r--r--  tensorflow/core/distributed_runtime/local_master.h | 10
-rw-r--r--  tensorflow/core/distributed_runtime/master.cc | 51
-rw-r--r--  tensorflow/core/distributed_runtime/master.h | 7
-rw-r--r--  tensorflow/core/distributed_runtime/master_interface.h | 10
-rw-r--r--  tensorflow/core/distributed_runtime/master_session.cc | 499
-rw-r--r--  tensorflow/core/distributed_runtime/master_session.h | 28
-rw-r--r--  tensorflow/core/distributed_runtime/message_wrappers.cc | 26
-rw-r--r--  tensorflow/core/distributed_runtime/message_wrappers.h | 9
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/BUILD | 30
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc | 46
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc | 35
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h | 45
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc | 22
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc | 213
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h | 59
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory_registration.cc | 34
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_session.cc | 78
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_session.h | 27
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc | 43
-rw-r--r--  tensorflow/core/framework/allocator.cc | 63
-rw-r--r--  tensorflow/core/framework/collective.h | 3
-rw-r--r--  tensorflow/core/framework/op_kernel.cc | 2
-rw-r--r--  tensorflow/core/framework/shape_inference.cc | 78
-rw-r--r--  tensorflow/core/framework/shape_inference.h | 11
-rw-r--r--  tensorflow/core/framework/shape_inference_test.cc | 13
-rw-r--r--  tensorflow/core/grappler/costs/op_level_cost_estimator.cc | 165
-rw-r--r--  tensorflow/core/grappler/costs/op_level_cost_estimator.h | 26
-rw-r--r--  tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc | 64
-rw-r--r--  tensorflow/core/grappler/op_types.cc | 2
-rw-r--r--  tensorflow/core/grappler/op_types.h | 1
-rw-r--r--  tensorflow/core/grappler/optimizers/BUILD | 1
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 563
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.h | 5
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 161
-rw-r--r--  tensorflow/core/grappler/optimizers/debug_stripper.cc | 17
-rw-r--r--  tensorflow/core/grappler/optimizers/debug_stripper_test.cc | 36
-rw-r--r--  tensorflow/core/grappler/optimizers/graph_optimizer_stage.h | 12
-rw-r--r--  tensorflow/core/grappler/optimizers/layout_optimizer.cc | 47
-rw-r--r--  tensorflow/core/grappler/optimizers/layout_optimizer_test.cc | 24
-rw-r--r--  tensorflow/core/grappler/optimizers/loop_optimizer.cc | 266
-rw-r--r--  tensorflow/core/grappler/optimizers/loop_optimizer.h | 47
-rw-r--r--  tensorflow/core/grappler/optimizers/loop_optimizer_test.cc | 43
-rw-r--r--  tensorflow/core/grappler/utils.cc | 26
-rw-r--r--  tensorflow/core/kernels/BUILD | 59
-rw-r--r--  tensorflow/core/kernels/collective_ops.cc | 266
-rw-r--r--  tensorflow/core/kernels/concat_lib_gpu.cc | 1
-rw-r--r--  tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc | 4
-rw-r--r--  tensorflow/core/kernels/concat_op.cc | 1
-rw-r--r--  tensorflow/core/kernels/concat_op_test.cc | 4
-rw-r--r--  tensorflow/core/kernels/constant_op.cc | 4
-rw-r--r--  tensorflow/core/kernels/cudnn_rnn_ops.cc | 106
-rw-r--r--  tensorflow/core/kernels/decode_proto_op.cc | 1011
-rw-r--r--  tensorflow/core/kernels/dense_update_functor.cc | 56
-rw-r--r--  tensorflow/core/kernels/dense_update_functor.h | 14
-rw-r--r--  tensorflow/core/kernels/encode_proto_op.cc | 591
-rw-r--r--  tensorflow/core/kernels/fill_functor.cu.cc | 2
-rw-r--r--  tensorflow/core/kernels/gather_functor.h | 13
-rw-r--r--  tensorflow/core/kernels/inplace_ops.cc | 296
-rw-r--r--  tensorflow/core/kernels/inplace_ops_functor.h | 17
-rw-r--r--  tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc | 97
-rw-r--r--  tensorflow/core/kernels/lookup_util.cc | 5
-rw-r--r--  tensorflow/core/kernels/resource_variable_ops.cc | 118
-rw-r--r--  tensorflow/core/kernels/rpc_op.cc | 129
-rw-r--r--  tensorflow/core/kernels/scatter_functor.h | 118
-rw-r--r--  tensorflow/core/kernels/training_op_helpers.h | 30
-rw-r--r--  tensorflow/core/lib/gtl/flatmap_test.cc | 2
-rw-r--r--  tensorflow/core/ops/array_ops.cc | 76
-rw-r--r--  tensorflow/core/ops/collective_ops.cc | 55
-rw-r--r--  tensorflow/core/ops/compat/ops_history.v1.pbtxt | 2572
-rw-r--r--  tensorflow/core/ops/decode_proto_ops.cc | 67
-rw-r--r--  tensorflow/core/ops/encode_proto_ops.cc | 49
-rw-r--r--  tensorflow/core/ops/list_ops.cc | 4
-rw-r--r--  tensorflow/core/ops/math_ops.cc | 52
-rw-r--r--  tensorflow/core/ops/ops.pbtxt | 558
-rw-r--r--  tensorflow/core/ops/rpc_ops.cc | 81
-rw-r--r--  tensorflow/core/platform/default/build_config/BUILD | 13
-rw-r--r--  tensorflow/core/platform/s3/s3_file_system.cc | 2
-rw-r--r--  tensorflow/core/protobuf/config.proto | 18
-rw-r--r--  tensorflow/core/protobuf/master.proto | 68
-rw-r--r--  tensorflow/core/protobuf/master_service.proto | 9
-rw-r--r--  tensorflow/core/util/proto/BUILD | 62
-rw-r--r--  tensorflow/core/util/proto/decode.h | 592
-rw-r--r--  tensorflow/core/util/proto/descriptor_pool_registry.cc | 45
-rw-r--r--  tensorflow/core/util/proto/descriptor_pool_registry.h | 76
-rw-r--r--  tensorflow/core/util/proto/descriptor_pool_registry_test.cc | 43
-rw-r--r--  tensorflow/core/util/proto/descriptors.cc | 85
-rw-r--r--  tensorflow/core/util/proto/descriptors.h | 42
-rw-r--r--  tensorflow/core/util/proto/local_descriptor_pool_registration.cc | 39
-rw-r--r--  tensorflow/core/util/reporter.cc | 12
-rw-r--r--  tensorflow/core/util/reporter.h | 10
-rw-r--r--  tensorflow/core/util/reporter_test.cc | 23
-rw-r--r--  tensorflow/core/util/rpc/BUILD | 48
-rw-r--r--  tensorflow/core/util/rpc/call_container.h | 90
-rw-r--r--  tensorflow/core/util/rpc/rpc_factory.cc | 53
-rw-r--r--  tensorflow/core/util/rpc/rpc_factory.h | 70
-rw-r--r--  tensorflow/core/util/rpc/rpc_factory_registry.cc | 44
-rw-r--r--  tensorflow/core/util/rpc/rpc_factory_registry.h | 72
-rw-r--r--  tensorflow/core/util/rpc/rpc_factory_registry_test.cc | 41
-rw-r--r--  tensorflow/docs_src/extend/new_data_formats.md | 395
-rw-r--r--  tensorflow/docs_src/programmers_guide/eager.md | 42
-rw-r--r--  tensorflow/go/op/wrappers.go | 2534
-rw-r--r--  tensorflow/python/BUILD | 6
-rw-r--r--  tensorflow/python/data/ops/dataset_ops.py | 2
-rw-r--r--  tensorflow/python/eager/function.py | 10
-rw-r--r--  tensorflow/python/estimator/canned/head.py | 8
-rw-r--r--  tensorflow/python/framework/dtypes.py | 2
-rw-r--r--  tensorflow/python/framework/ops.py | 50
-rw-r--r--  tensorflow/python/framework/ops_test.py | 7
-rw-r--r--  tensorflow/python/framework/tensor_util.py | 19
-rw-r--r--  tensorflow/python/framework/versions.py | 12
-rw-r--r--  tensorflow/python/grappler/cluster.i | 3
-rwxr-xr-x  tensorflow/python/keras/BUILD | 1
-rw-r--r--  tensorflow/python/keras/_impl/keras/applications/resnet50.py | 3
-rw-r--r--  tensorflow/python/keras/_impl/keras/estimator.py | 13
-rw-r--r--  tensorflow/python/keras/_impl/keras/estimator_test.py | 20
-rw-r--r--  tensorflow/python/kernel_tests/BUILD | 16
-rw-r--r--  tensorflow/python/kernel_tests/constant_op_test.py | 8
-rw-r--r--  tensorflow/python/kernel_tests/inplace_ops_test.py | 198
-rw-r--r--  tensorflow/python/kernel_tests/list_ops_test.py | 113
-rw-r--r--  tensorflow/python/ops/batch_norm_benchmark.py | 3
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py | 18
-rw-r--r--  tensorflow/python/ops/control_flow_ops_test.py | 23
-rw-r--r--  tensorflow/python/ops/gradients_impl.py | 48
-rw-r--r--  tensorflow/python/ops/inplace_ops.py | 227
-rw-r--r--  tensorflow/python/ops/list_ops.py | 11
-rw-r--r--  tensorflow/python/ops/rnn_cell_impl.py | 11
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.cc | 8
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.h | 1
-rw-r--r--  tensorflow/stream_executor/dnn.h | 7
-rw-r--r--  tensorflow/tools/api/generator/BUILD | 2
-rw-r--r--  tensorflow/tools/api/generator/create_python_api.py | 124
-rw-r--r--  tensorflow/tools/api/generator/create_python_api_test.py | 6
-rw-r--r--  tensorflow/tools/api/tests/BUILD | 1
-rw-r--r--  tensorflow/tools/api/tests/api_compatibility_test.py | 56
-rw-r--r--  tensorflow/tools/ci_build/windows/cpu/cmake/run_py.bat | 6
-rw-r--r--  tensorflow/tools/pip_package/setup.py | 7
276 files changed, 19248 insertions, 3159 deletions
diff --git a/README.md b/README.md
index c66f7e3f3f..99f4a253d9 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@
| **`Documentation`** | **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** |
|-----------------|---------------------|------------------|-------------------|---------------|---------------|
-| [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) [ ![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg) ](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
+| [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) | ![Build Status](https://storage.cloud.google.com/tensorflow-kokoro-build-badges/ubuntu-cc.png) | ![Build Status](https://storage.cloud.google.com/tensorflow-kokoro-build-badges/ubuntu-gpu-cc.png) | ![Build Status](https://storage.cloud.google.com/tensorflow-kokoro-build-badges/macos-py2-cc.png) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) [ ![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg) ](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
**TensorFlow** is an open source software library for numerical computation using
data flow graphs. The graph nodes represent mathematical operations, while
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 823393ebdf..cfafffdd13 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -401,25 +401,6 @@ package_group(
],
)
-py_library(
- name = "tensorflow_py",
- srcs = ["__init__.py"],
- srcs_version = "PY2AND3",
- visibility = ["//visibility:public"],
- deps = ["//tensorflow/python"],
-)
-
-py_library(
- name = "experimental_tensorflow_py",
- srcs = ["experimental_api.py"],
- srcs_version = "PY2AND3",
- visibility = ["//tensorflow/tools/api/tests:__subpackages__"],
- deps = [
- "//tensorflow/python",
- "//tensorflow/tools/api/generator:python_api",
- ],
-)
-
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
@@ -553,3 +534,14 @@ exports_files(
"tf_exported_symbols.lds",
],
)
+
+py_library(
+ name = "tensorflow_py",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/python",
+ "//tensorflow/tools/api/generator:python_api",
+ ],
+)
diff --git a/tensorflow/__init__.py b/tensorflow/__init__.py
index 78ad6aec19..c8683e3976 100644
--- a/tensorflow/__init__.py
+++ b/tensorflow/__init__.py
@@ -20,14 +20,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+# pylint: disable=g-bad-import-order
+from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
# pylint: disable=wildcard-import
-from tensorflow.python import * # pylint: disable=redefined-builtin
+from tensorflow.tools.api.generator.api import * # pylint: disable=redefined-builtin
# pylint: enable=wildcard-import
from tensorflow.python.util.lazy_loader import LazyLoader
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
del LazyLoader
+from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
+app.flags = flags # pylint: disable=undefined-variable
+
del absolute_import
del division
del print_function
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index e345c1266a..a7a8d2d1ff 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -124,6 +124,7 @@ tf_xla_py_test(
name = "categorical_op_test",
size = "small",
srcs = ["categorical_op_test.py"],
+ tags = ["optonly"],
deps = [
":xla_test",
"//tensorflow/python:framework_for_generated_wrappers",
@@ -199,6 +200,10 @@ tf_xla_py_test(
"cpu",
"cpu_ondemand",
],
+ tags = [
+ # Allocates very large amounts of memory and does not work under TSAN.
+ "notsan",
+ ],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl
index a9db1c173d..45b6a6eb86 100644
--- a/tensorflow/compiler/tests/build_defs.bzl
+++ b/tensorflow/compiler/tests/build_defs.bzl
@@ -51,7 +51,7 @@ def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None,
if backend == "cpu":
backend_args += [
"--test_device=XLA_CPU",
- "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64"
+ "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64"
]
elif backend == "gpu":
backend_args += [
diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py
index 6083981493..ef47187477 100644
--- a/tensorflow/compiler/tests/spacetobatch_op_test.py
+++ b/tensorflow/compiler/tests/spacetobatch_op_test.py
@@ -163,6 +163,9 @@ class SpaceToBatchNDTest(XLATestCase):
# error.
if dtype == dtypes.bfloat16.as_numpy_dtype:
continue
+ # TODO(b/77694432): Half test failed on CPU, last ran on 04-06-2018.
+ if dtype == np.float16 and self.device == "XLA_CPU":
+ continue
placeholder = array_ops.placeholder(dtype)
# outputs = space_to_batch(inputs)
x_tf = array_ops.space_to_batch_nd(placeholder, block_shape, paddings)
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 17149aa1c8..ba79f393a8 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -154,6 +154,9 @@ class UnaryOpsTest(XLATestCase):
def testFloatOps(self):
for dtype in self.float_types:
+ # TODO(b/77694432): Half test failed on CPU, last ran on 04-06-2018.
+ if dtype == np.float16 and self.device == "XLA_CPU":
+ continue
x = np.arange(-0.90, 0.90, 0.25)
self._assertOpOutputMatchesExpected(
math_ops.acos,
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 8b7beef83e..16b9142cbf 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -901,6 +901,14 @@ FunctionalizeCond::DeterminePredicateSwitchOrder() {
int src_depth = switch_depth[src_id];
if (!e->IsControlEdge() || new_switch_depth == src_depth) {
if (src_depth != new_switch_depth) {
+ // TODO(b/77601805) remove this when outside_compilation supports
+ // control flow.
+ if (str_util::StrContains(src->name(), "outside_compilation") ||
+ str_util::StrContains(n->name(), "outside_compilation")) {
+ return errors::InvalidArgument(
+ "outside_compilation is not yet supported within TensorFlow "
+ "control flow constructs b/77601805");
+ }
return errors::InvalidArgument(
"Unable to functionalize control flow in graph: Operand ('",
src->name(), "') and operator ('", n->name(),
diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc
index a9978e697b..b813668a9e 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_test.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc
@@ -90,6 +90,11 @@ TEST(ConvertGraphDefToXla, Sum) {
TF_EXPECT_OK(result_or.status());
std::unique_ptr<xla::Literal> result = std::move(result_or.ValueOrDie());
EXPECT_EQ("(s32[]) (\n42\n)", result->ToString());
+
+ config.mutable_feed(0)->mutable_id()->set_output_index(
+ 123); /* invalid output_index */
+ EXPECT_TRUE(errors::IsInvalidArgument(
+ ConvertGraphDefToXla(graph_def, config, client, &computation)));
}
} // namespace
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc
index f428a19432..2fc77cc4bc 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc
@@ -151,8 +151,15 @@ Status AddPlaceholdersForFeeds(
Status status;
Node* feed_node = g.AddNode(gd.node(0), &status);
TF_RETURN_IF_ERROR(status);
- info.data_type =
- BaseType(feed_node->output_type(info.feed->id().output_index()));
+
+ if (info.feed->id().output_index() < feed_node->num_outputs()) {
+ info.data_type =
+ BaseType(feed_node->output_type(info.feed->id().output_index()));
+ } else {
+ return errors::InvalidArgument(
+ "Invalid output_index ", info.feed->id().output_index(),
+ " for feed node ", info.feed->id().node_name());
+ }
}
}
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index 3b0b2f06eb..62a5114837 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -122,6 +122,9 @@ xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b,
xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b,
DataType data_type) {
switch (data_type) {
+ case DT_HALF:
+ return b->ConstantR0<Eigen::half>(
+ static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon()));
case DT_BFLOAT16:
return b->ConstantR0<bfloat16>(bfloat16::epsilon());
case DT_FLOAT:
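
For context on the DT_HALF case added above: Epsilon() returns the machine epsilon of the requested type, i.e. the gap between 1.0 and the next representable value. Below is a minimal standalone sketch of those values in plain C++ (no Eigen or XLA); the 2^-10 figure for half precision follows from its 10-bit mantissa and is stated as background, not taken from this diff.

#include <iostream>
#include <limits>

int main() {
  // Machine epsilon for the types standard C++ knows about.
  std::cout << "float epsilon:  " << std::numeric_limits<float>::epsilon() << "\n";
  std::cout << "double epsilon: " << std::numeric_limits<double>::epsilon() << "\n";
  // IEEE half has a 10-bit mantissa, so its epsilon is 2^-10; this is the
  // value the new DT_HALF case obtains from Eigen::NumTraits<Eigen::half>.
  std::cout << "half epsilon:   " << 1.0 / 1024.0 << "\n";
  return 0;
}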
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index 2d587cc3b9..ed9f994d39 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -548,7 +548,22 @@ XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands,
XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value,
const PaddingConfig& padding_config) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ TF_ASSIGN_OR_RETURN(const Shape& padding_value_shape,
+ GetShape(padding_value));
+ TF_ASSIGN_OR_RETURN(
+ *instr.mutable_shape(),
+ ShapeInference::InferPadShape(operand_shape, padding_value_shape,
+ padding_config));
+
+ *instr.mutable_padding_config() = padding_config;
+
+ return AddInstruction(std::move(instr), HloOpcode::kPad,
+ {operand, padding_value});
+ });
}
XlaOp XlaBuilder::Reshape(const XlaOp& operand,
@@ -578,7 +593,45 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand,
XlaOp XlaBuilder::Collapse(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> dimensions) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ if (dimensions.size() <= 1) {
+ // Not collapsing anything, trivially we can return the operand versus
+ // enqueueing a trivial reshape.
+ return operand;
+ }
+
+ // Out-of-order collapse is not supported.
+ // Checks that the collapsed dimensions are in order and consecutive.
+ for (tensorflow::gtl::ArraySlice<int64>::size_type i = 1;
+ i < dimensions.size(); ++i) {
+ if (dimensions[i] - 1 != dimensions[i - 1]) {
+ return InvalidArgument(
+ "Collapsed dimensions are not in consecutive order.");
+ }
+ }
+
+ // Create a new sizes vector from the old shape, replacing the collapsed
+ // dimensions by the product of their sizes.
+ TF_ASSIGN_OR_RETURN(const Shape& original_shape, GetShape(operand));
+
+ VLOG(3) << "original shape: " << ShapeUtil::HumanString(original_shape);
+ VLOG(3) << "dims to collapse: "
+ << tensorflow::str_util::Join(dimensions, ",");
+
+ std::vector<int64> new_sizes;
+ for (int i = 0; i < ShapeUtil::Rank(original_shape); ++i) {
+ if (i <= dimensions.front() || i > dimensions.back()) {
+ new_sizes.push_back(original_shape.dimensions(i));
+ } else {
+ new_sizes.back() *= original_shape.dimensions(i);
+ }
+ }
+
+ VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",")
+ << "]";
+
+ return Reshape(operand, new_sizes);
+ });
}
void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
@@ -728,12 +781,41 @@ XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type,
}
XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ if (!LayoutUtil::HasLayout(shape)) {
+ return InvalidArgument("Given shape to Infeed must have a layout");
+ }
+ *instr.mutable_shape() = shape;
+ instr.set_infeed_config(config);
+ return AddInstruction(std::move(instr), HloOpcode::kInfeed);
+ });
}
void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
const string& outfeed_config) {
- UnimplementedOp();
+ NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ *instr.mutable_shape() = ShapeUtil::MakeNil();
+
+ // Check and set outfeed shape.
+ if (!LayoutUtil::HasLayout(shape_with_layout)) {
+ return InvalidArgument("Given shape to Outfeed must have a layout");
+ }
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) {
+ return InvalidArgument(
+ "Outfeed shape %s must be compatible with operand shape %s",
+ ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(),
+ ShapeUtil::HumanStringWithLayout(operand_shape).c_str());
+ }
+ *instr.mutable_outfeed_shape() = shape_with_layout;
+
+ instr.set_outfeed_config(outfeed_config);
+
+ return AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand});
+ });
}
XlaOp XlaBuilder::CustomCall(const string& call_target_name,
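
The new Collapse() above folds a run of consecutive dimensions into one dimension by multiplying their sizes, then delegates to Reshape(). Below is a minimal standalone sketch of just that size arithmetic in plain C++ (no XLA types), collapsing dimensions {1, 2} of a [2, 3, 4] shape into [2, 12]; it mirrors the loop added above but is not XLA code.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> dims = {2, 3, 4};
  std::vector<int64_t> to_collapse = {1, 2};  // must be consecutive and in order
  std::vector<int64_t> new_sizes;
  for (int64_t i = 0; i < static_cast<int64_t>(dims.size()); ++i) {
    if (i <= to_collapse.front() || i > to_collapse.back()) {
      // Dimensions kept as-is: everything up to and including the first
      // collapsed dimension, and everything after the last one.
      new_sizes.push_back(dims[i]);
    } else {
      // Remaining collapsed dimensions fold into the previous entry.
      new_sizes.back() *= dims[i];
    }
  }
  for (int64_t d : new_sizes) std::cout << d << " ";  // prints: 2 12
  std::cout << "\n";
  return 0;
}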
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 13675b7d00..c2950c1faa 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -1409,6 +1409,28 @@ std::unique_ptr<Literal> ConvertBetweenNativeTypes(const Literal& src_literal) {
src_literal, converter);
}
+template <typename NativeSrcT, typename NativeDestT>
+typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)),
+ std::unique_ptr<Literal>>::type
+BitcastBetweenNativeTypes(const Literal& src_literal) {
+ auto converter = [](NativeSrcT src) {
+ return tensorflow::bit_cast<NativeDestT>(src);
+ };
+ return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
+ src_literal, converter);
+}
+
+// This template specialization is here to make the compiler happy. bit_cast has
+// a static check that the types are the same size. This specialization should
+// never be used because the source and destination types are checked for
+// identical sizes higher up.
+template <typename NativeSrcT, typename NativeDestT>
+typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
+ std::unique_ptr<Literal>>::type
+BitcastBetweenNativeTypes(const Literal& src_literal) {
+ LOG(FATAL) << "Invalid bitcast between types of different sizes.";
+}
+
template <PrimitiveType primitive_src_type>
std::unique_ptr<Literal> ConvertToC64(const Literal& src_literal) {
CHECK(ShapeUtil::IsArray(src_literal.shape()));
@@ -1428,21 +1450,33 @@ std::unique_ptr<Literal> ConvertToC64(const Literal& src_literal) {
}
template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
-std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal) {
+std::unique_ptr<Literal> ConvertIfTypesMatch(const Literal& src_literal,
+ bool bitcast) {
CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
- return ConvertBetweenNativeTypes<
- typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type,
- typename primitive_util::PrimitiveTypeToNative<
- primitive_dest_type>::type>(src_literal);
+ if (bitcast) {
+ return BitcastBetweenNativeTypes<
+ typename primitive_util::PrimitiveTypeToNative<
+ primitive_src_type>::type,
+ typename primitive_util::PrimitiveTypeToNative<
+ primitive_dest_type>::type>(src_literal);
+ } else {
+ return ConvertBetweenNativeTypes<
+ typename primitive_util::PrimitiveTypeToNative<
+ primitive_src_type>::type,
+ typename primitive_util::PrimitiveTypeToNative<
+ primitive_dest_type>::type>(src_literal);
+ }
}
template <PrimitiveType primitive_src_type>
StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
- const Literal& src_literal, PrimitiveType primitive_dest_type) {
+ const Literal& src_literal, PrimitiveType primitive_dest_type,
+ bool bitcast) {
switch (primitive_dest_type) {
-#define CONVERT_IF_TYPES_MATCH(type) \
- case (type): \
- return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal);
+#define CONVERT_IF_TYPES_MATCH(type) \
+ case (type): \
+ return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal, \
+ bitcast);
CONVERT_IF_TYPES_MATCH(PRED)
CONVERT_IF_TYPES_MATCH(S8)
CONVERT_IF_TYPES_MATCH(S32)
@@ -1456,28 +1490,31 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
CONVERT_IF_TYPES_MATCH(BF16)
#undef CONVERT_IF_TYPES_MATCH
case C64:
- return ConvertToC64<primitive_src_type>(src_literal);
+ if (!bitcast) {
+ return ConvertToC64<primitive_src_type>(src_literal);
+ }
+ break;
// Other types are not yet supported.
default:
- return Unimplemented(
- "Converting from type %s to type %s is not implemented.",
- PrimitiveType_Name(src_literal.shape().element_type()).c_str(),
- PrimitiveType_Name(primitive_dest_type).c_str());
- }
-}
-
-} // namespace
-
-StatusOr<std::unique_ptr<Literal>> Literal::Convert(
- PrimitiveType primitive_dest_type) const {
- TF_RET_CHECK(ShapeUtil::IsArray(shape()));
- if (shape().element_type() == primitive_dest_type) {
- return CloneToUnique();
+ break;
}
- switch (shape().element_type()) {
-#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
- case (type): \
- return ConvertIfDestTypeMatches<(type)>(*this, primitive_dest_type);
+ return Unimplemented(
+ "Converting from type %s to type %s is not implemented.",
+ PrimitiveType_Name(src_literal.shape().element_type()).c_str(),
+ PrimitiveType_Name(primitive_dest_type).c_str());
+}
+
+StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
+ const Literal& literal, PrimitiveType primitive_dest_type, bool bitcast) {
+ TF_RET_CHECK(ShapeUtil::IsArray(literal.shape()));
+ if (literal.shape().element_type() == primitive_dest_type) {
+ return literal.CloneToUnique();
+ }
+ switch (literal.shape().element_type()) {
+#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
+ case (type): \
+ return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \
+ bitcast);
CONVERT_IF_DEST_TYPE_MATCHES(PRED)
CONVERT_IF_DEST_TYPE_MATCHES(S8)
CONVERT_IF_DEST_TYPE_MATCHES(S32)
@@ -1493,12 +1530,35 @@ StatusOr<std::unique_ptr<Literal>> Literal::Convert(
// Other types are not yet supported.
default:
return Unimplemented(
- "Converting from type %s to type %s is not implemented.",
- PrimitiveType_Name(shape().element_type()).c_str(),
+ "%s from type %s to type %s is not implemented.",
+ (bitcast ? "Bitcast converting" : "Converting"),
+ PrimitiveType_Name(literal.shape().element_type()).c_str(),
PrimitiveType_Name(primitive_dest_type).c_str());
}
}
+} // namespace
+
+StatusOr<std::unique_ptr<Literal>> Literal::Convert(
+ PrimitiveType primitive_dest_type) const {
+ return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
+}
+
+StatusOr<std::unique_ptr<Literal>> Literal::BitcastConvert(
+ PrimitiveType primitive_dest_type) const {
+ if (primitive_util::BitWidth(shape().element_type()) !=
+ primitive_util::BitWidth(primitive_dest_type)) {
+ return InvalidArgument(
+ "Cannot bitcast convert from %s to %s, bit widths are different: %d != "
+ "%d",
+ PrimitiveType_Name(shape().element_type()).c_str(),
+ PrimitiveType_Name(primitive_dest_type).c_str(),
+ primitive_util::BitWidth(shape().element_type()),
+ primitive_util::BitWidth(primitive_dest_type));
+ }
+ return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
+}
+
StatusOr<std::unique_ptr<Literal>> Literal::ConvertToShape(
const Shape& dest_shape, bool round_f32_to_bf16) const {
if (!ShapeUtil::IsTuple(dest_shape)) {
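
The BitcastBetweenNativeTypes() path added above reinterprets each element's bits in the destination type rather than converting its value, which is why Literal::BitcastConvert() requires matching bit widths. Below is a standalone sketch of that per-element operation in plain C++, with std::memcpy standing in for tensorflow::bit_cast.

#include <cstdint>
#include <cstring>
#include <iostream>

// Reinterpret the bits of a 32-bit integer as an IEEE-754 float, which is
// what BitcastConvert does per element when converting a U32 literal to F32.
float BitCastToFloat(uint32_t bits) {
  static_assert(sizeof(float) == sizeof(uint32_t), "bit widths must match");
  float value;
  std::memcpy(&value, &bits, sizeof(value));
  return value;
}

int main() {
  uint32_t bits = 0x40200000u;                // the bit pattern of 2.5f
  std::cout << BitCastToFloat(bits) << "\n";  // prints 2.5
  return 0;
}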
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index a96a76fbb4..66ff39ecbb 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -333,11 +333,19 @@ class Literal {
template <typename NativeT>
std::unique_ptr<Literal> Replicate(int64 times) const;
- // Converts this literal to another primitive type. Returns an error if the
- // conversion is not possible. This literal must be array-shaped.
+ // Converts this literal to another primitive type using
+ // static_cast<>. Returns an error if the conversion is not possible. This
+ // literal must be array-shaped.
StatusOr<std::unique_ptr<Literal>> Convert(
PrimitiveType primitive_dest_type) const;
+ // Converts this literal to another primitive type using a bitcast
+ // conversion. The to and from primitive types must have the same bit
+ // width. Returns an error if the conversion is not possible. This literal
+ // must be array-shaped.
+ StatusOr<std::unique_ptr<Literal>> BitcastConvert(
+ PrimitiveType primitive_dest_type) const;
+
// Converts this literal to the given shape. Returns an error is the
// conversion is not possible.
//
@@ -587,6 +595,12 @@ class Literal {
template <typename NativeT, typename FnType>
Status Populate(const FnType& generator);
+ // A parallel version of Populate(). This can be used if the generator is
+ // thread-safe and the values for the shape's different elements are
+ // independent.
+ template <typename NativeT, typename FnType>
+ Status PopulateParallel(const FnType& generator);
+
// Fills this literal with the given value.
template <typename NativeT>
void PopulateWithValue(NativeT value);
@@ -785,6 +799,10 @@ class Literal {
// buffer).
void DeallocateBuffers();
+ // Implementation details shared between Populate() and PopulateParallel()
+ template <typename NativeT, typename FnType>
+ Status PopulateInternal(const FnType& generator, bool parallel);
+
Shape shape_;
ShapeTree<Piece> pieces_;
@@ -1276,7 +1294,7 @@ void Literal::PopulateSparse(SparseIndexArray indices,
}
template <typename NativeT, typename FnType>
-Status Literal::Populate(const FnType& generator) {
+Status Literal::PopulateInternal(const FnType& generator, bool parallel) {
const Shape& this_shape = shape();
const int64 rank = ShapeUtil::Rank(this_shape);
TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
@@ -1286,11 +1304,11 @@ Status Literal::Populate(const FnType& generator) {
if (rank > 0) {
StrideConfig stride_config(this_shape, this_shape,
AsInt64Slice(this_shape.dimensions()));
- DimensionVector minor_scan_indexes(rank, 0);
int64 minor_dimension_size =
ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
auto init_function = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
+ DimensionVector minor_scan_indexes(rank, 0);
const int64 index =
IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
@@ -1298,17 +1316,35 @@ Status Literal::Populate(const FnType& generator) {
minor_scan_indexes[stride_config.minor_dimension] = i;
literal_data.at(index + i) = generator(minor_scan_indexes);
}
- return true;
};
- ShapeUtil::ForEachIndex(this_shape, stride_config.base,
- stride_config.dimensions, stride_config.step,
- init_function);
+ if (parallel) {
+ ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base,
+ stride_config.dimensions,
+ stride_config.step, init_function);
+ } else {
+ ShapeUtil::ForEachIndex(
+ this_shape, stride_config.base, stride_config.dimensions,
+ stride_config.step,
+ [&init_function](tensorflow::gtl::ArraySlice<int64> indexes) {
+ init_function(indexes);
+ return true;
+ });
+ }
} else {
// For scalars.
literal_data.at(0) = generator({});
}
return Status::OK();
}
+template <typename NativeT, typename FnType>
+Status Literal::Populate(const FnType& generator) {
+ return PopulateInternal<NativeT>(generator, /*parallel=*/false);
+}
+
+template <typename NativeT, typename FnType>
+Status Literal::PopulateParallel(const FnType& generator) {
+ return PopulateInternal<NativeT>(generator, /*parallel=*/true);
+}
template <typename NativeT>
void Literal::PopulateWithValue(NativeT value) {
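
As the new header comment notes, PopulateParallel() is only valid when the generator is thread-safe and each element depends solely on its own index (the parallel iteration itself comes from ShapeUtil::ForEachIndexParallel). Below is a standalone analogy of that contract using plain C++ threads; the generator and helper names are illustrative only, not XLA API.

#include <cstdint>
#include <iostream>
#include <thread>
#include <vector>

int main() {
  // A pure function of the element index, like the generator passed to
  // Populate()/PopulateParallel(); the "+ 17" mirrors the new test below.
  auto generator = [](int64_t linear_index) -> uint32_t {
    return static_cast<uint32_t>(linear_index) + 17;
  };
  std::vector<uint32_t> data(16);
  auto fill_range = [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) data[i] = generator(i);
  };
  // Disjoint index ranges can be filled concurrently without coordination.
  std::thread first_half(fill_range, 0, 8);
  std::thread second_half(fill_range, 8, 16);
  first_half.join();
  second_half.join();
  std::cout << data.front() << " ... " << data.back() << "\n";  // 17 ... 32
  return 0;
}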
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index 7627762074..be4f2bc5ce 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_util_test.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -1090,6 +1091,48 @@ TEST_F(LiteralUtilTest, Populate) {
}
}
+TEST_F(LiteralUtilTest, PopulateParallel) {
+ struct PopulateData {
+ std::vector<int64> dimensions;
+ std::vector<int64> layout;
+ } populate_data[] = {
+ {{}, {}},
+ {{0}, {0}},
+ {{16}, {0}},
+ {{2, 0}, {1, 0}},
+ {{4, 16}, {1, 0}},
+ {{21, 12}, {0, 1}},
+ {{6, 11, 17}, {2, 0, 1}},
+ {{6, 11, 5, 17}, {3, 2, 0, 1}},
+ };
+ for (const auto& data : populate_data) {
+ Shape shape = ShapeUtil::MakeShapeWithLayout(
+ primitive_util::NativeToPrimitiveType<uint32>(), data.dimensions,
+ data.layout);
+ auto literal = Literal::CreateFromShape(shape);
+ auto generator = [&](ArraySlice<int64> indexes) -> uint32 {
+ // Offsets from linear index just to avoid R0 literals to be initialized
+ // with zero.
+ return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
+ indexes) +
+ 17;
+ };
+ TF_EXPECT_OK(literal->PopulateParallel<uint32>(generator));
+
+ std::vector<int64> zero_base(data.dimensions.size(), 0);
+ std::vector<int64> step(data.dimensions.size(), 1);
+ bool matched = true;
+ auto check_function = [&](ArraySlice<int64> indexes) {
+ auto value = literal->Get<uint32>(indexes);
+ matched = matched && (value == generator(indexes));
+ return matched;
+ };
+ ShapeUtil::ForEachIndex(literal->shape(), zero_base, data.dimensions, step,
+ check_function);
+ EXPECT_TRUE(matched);
+ }
+}
+
TEST_F(LiteralUtilTest, ConvertR4) {
// clang-format off
auto original = Literal::CreateR4WithLayout<int8>({{
@@ -1243,6 +1286,25 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
tensorflow::error::UNIMPLEMENTED);
}
+TEST_F(LiteralUtilTest, BitcastConvert) {
+ auto original =
+ Literal::CreateR1<uint32>({tensorflow::bit_cast<uint32>(2.5f),
+ tensorflow::bit_cast<uint32>(-42.25f),
+ tensorflow::bit_cast<uint32>(100.f), 0xbeef});
+ auto expected = Literal::CreateR1<float>(
+ {2.5f, -42.25f, 100.0f, tensorflow::bit_cast<float>(0xbeef)});
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted,
+ original->BitcastConvert(F32));
+}
+
+TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) {
+ auto literal = Literal::CreateR0<uint32>(1234);
+ Status status = literal->BitcastConvert(F64).status();
+ EXPECT_NE(Status::OK(), status);
+ EXPECT_TRUE(tensorflow::str_util::StrContains(status.error_message(),
+ "bit widths are different"));
+}
+
TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
LiteralProto p;
p.mutable_shape()->set_element_type(PRED);
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 0e4624fd69..6cb1bd5669 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1424,6 +1424,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
return Status::OK();
}
+// TODO(b/74536353): do this simplification for BroadcastDimOne as well.
StatusOr<bool> AlgebraicSimplifierVisitor::
TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(
HloInstruction* reshape_or_broadcast) {
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 56723e7650..3f7089d6ca 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -199,6 +199,7 @@ class DfsHloVisitorBase {
virtual Status HandleReduce(HloInstructionPtr hlo) = 0;
virtual Status HandleBitcast(HloInstructionPtr hlo) = 0;
virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0;
+ virtual Status HandleBroadcastDimOne(HloInstructionPtr hlo) = 0;
virtual Status HandleReshape(HloInstructionPtr hlo) = 0;
virtual Status HandleTranspose(HloInstructionPtr hlo) = 0;
virtual Status HandleParameter(HloInstructionPtr hlo) = 0;
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index 240faebe62..e6680ee9b8 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -158,6 +158,9 @@ class DfsHloVisitorWithDefaultBase
Status HandleBroadcast(HloInstructionPtr broadcast) override {
return DefaultAction(broadcast);
}
+ Status HandleBroadcastDimOne(HloInstructionPtr broadcastDimOne) override {
+ return DefaultAction(broadcastDimOne);
+ }
Status HandlePad(HloInstructionPtr pad) override {
return DefaultAction(pad);
}
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index 1792893ae4..d6b457a91b 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -94,11 +94,17 @@ se::port::StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
// Determines whether we can safely perform a winograd non-fused convolution for
// the given input and output shapes. This works around b/68264959, an integer
// overflow in cuDNNv5 and cuDNNv6.
-//
-// TODO(jlebar): We shouldn't need this check for cuDNNv7.
-bool ShouldIncludeWinogradNonfusedAlgo(
- const Shape& input_shape, const Shape& output_shape,
- const ConvolutionDimensionNumbers& dnums) {
+bool ShouldIncludeWinogradNonfusedAlgo(const Shape& input_shape,
+ const Shape& output_shape,
+ const ConvolutionDimensionNumbers& dnums,
+ se::StreamExecutor* stream_exec) {
+ // Skip this check for cudnn7 and newer.
+ se::port::StatusOr<std::tuple<int, int, int>> version =
+ stream_exec->AsDnn()->GetVersion();
+ if (version.ok() && std::get<0>(version.ValueOrDie()) >= 7) {
+ return true;
+ }
+
int64 batch = input_shape.dimensions(dnums.input_batch_dimension());
int64 in_depths = input_shape.dimensions(dnums.input_feature_dimension());
int64 in_rows = input_shape.dimensions(dnums.input_spatial_dimensions(0));
@@ -118,20 +124,20 @@ bool ShouldIncludeWinogradNonfusedAlgo(
std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
bool with_winograd_nonfused,
- se::StreamExecutor* stream_exec_) {
+ se::StreamExecutor* stream_exec) {
std::vector<AlgorithmDesc> algorithms;
switch (kind) {
case CudnnConvKind::kBackwardFilter:
- CHECK(stream_exec_->GetConvolveBackwardFilterAlgorithms(
+ CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms(
with_winograd_nonfused, &algorithms));
break;
case CudnnConvKind::kBackwardInput:
- CHECK(stream_exec_->GetConvolveBackwardDataAlgorithms(
+ CHECK(stream_exec->GetConvolveBackwardDataAlgorithms(
with_winograd_nonfused, &algorithms));
break;
case CudnnConvKind::kForward:
- CHECK(stream_exec_->GetConvolveAlgorithms(with_winograd_nonfused,
- &algorithms));
+ CHECK(stream_exec->GetConvolveAlgorithms(with_winograd_nonfused,
+ &algorithms));
break;
}
@@ -209,8 +215,8 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
return nullopt;
}
- const bool use_winograd_nonfused =
- ShouldIncludeWinogradNonfusedAlgo(input_shape, output_shape, dnums);
+ const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo(
+ input_shape, output_shape, dnums, stream_exec_);
se::dnn::ProfileResult best_result;
int64 best_result_bytes_used = 0;
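
The change above gates the winograd-nonfused workaround on the cuDNN version reported through the stream executor: the b/68264959 overflow only affects cuDNN 5 and 6, so on major version 7 or newer the shape heuristic is skipped entirely. Below is a standalone sketch of that version test in plain C++; the function name is hypothetical, and only the major component of the (major, minor, patch) tuple matters.

#include <iostream>
#include <tuple>

// Returns true when the pre-cuDNN-7 workaround still needs to run.
bool NeedsWinogradNonfusedWorkaround(const std::tuple<int, int, int>& cudnn_version) {
  return std::get<0>(cudnn_version) < 7;
}

int main() {
  std::cout << NeedsWinogradNonfusedWorkaround({6, 0, 21}) << "\n";  // 1: run the shape check
  std::cout << NeedsWinogradNonfusedWorkaround({7, 1, 2}) << "\n";   // 0: skip it
  return 0;
}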
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index 35ecd4428d..7aa38c6b79 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -69,7 +69,8 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
// Broadcasts dramatically increase the size of constants, which is often
// detrimental to performance and memory capacity, so do not fold
// broadcasts.
- if (instruction->opcode() == HloOpcode::kBroadcast) {
+ if (instruction->opcode() == HloOpcode::kBroadcast ||
+ instruction->opcode() == HloOpcode::kBroadcastDimOne) {
continue;
}
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 44e4f75f75..ea4dd62fdb 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -336,6 +336,11 @@ Status HloCostAnalysis::HandleBroadcast(const HloInstruction*) {
return Status::OK();
}
+Status HloCostAnalysis::HandleBroadcastDimOne(
+ const HloInstruction* broadcastDimOne) {
+ return Status::OK();
+}
+
Status HloCostAnalysis::HandlePad(const HloInstruction*) {
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index d17678d20f..a9f6845747 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -95,6 +95,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleSelectAndScatter(const HloInstruction* instruction) override;
Status HandleBitcast(const HloInstruction* bitcast) override;
Status HandleBroadcast(const HloInstruction* broadcast) override;
+ Status HandleBroadcastDimOne(const HloInstruction* broadcastDimOne) override;
Status HandlePad(const HloInstruction* pad) override;
Status HandleReshape(const HloInstruction* reshape) override;
Status HandleTranspose(const HloInstruction* transpose) override;
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 53ad8909c5..b4f9a9db9c 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -399,6 +399,22 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
+ Status HandleBitcastConvert(HloInstruction* convert) override {
+ const HloInstruction* operand = convert->operand(0);
+ TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape()));
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
+ parent_->GetEvaluatedLiteralFor(operand).BitcastConvert(
+ convert->shape().element_type()));
+
+ if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) {
+ parent_->evaluated_[convert] = std::move(result);
+ } else {
+ parent_->evaluated_[convert] =
+ result->Relayout(convert->shape().layout());
+ }
+ return Status::OK();
+ }
+
Status HandleExp(HloInstruction* exp) override {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp],
ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) {
@@ -998,18 +1014,6 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
- // Dimension number applicable for input (lhs).
- const int64 input_batch_dim = dnums.input_batch_dimension();
- const int64 input_z_dim = dnums.input_feature_dimension();
- // Dimension number applicable for kernel (rhs).
- const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension();
- const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension();
- // Dimension number applicable for output.
- const int64 output_batch_dim = dnums.output_batch_dimension();
- const int64 output_z_dim = dnums.output_feature_dimension();
-
- const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim);
-
std::vector<int64> window_dimension_sizes;
for (auto i : dnums.kernel_spatial_dimensions()) {
window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i));
@@ -1021,14 +1025,27 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
DimensionVector lhs_dim_multipliers = MakeDimMultipliers(lhs_shape);
DimensionVector rhs_dim_multipliers = MakeDimMultipliers(rhs_shape);
- DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size());
-
auto lhs_literal_data = lhs_literal.data<ReturnT>();
auto rhs_literal_data = rhs_literal.data<ReturnT>();
- auto func = [&](ArraySlice<int64> out_index) {
+ auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window,
+ &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data,
+ rhs_literal_data](ArraySlice<int64> out_index) {
+ // Dimension number applicable for input (lhs).
+ const int64 input_batch_dim = dnums.input_batch_dimension();
+ const int64 input_z_dim = dnums.input_feature_dimension();
+ // Dimension number applicable for kernel (rhs).
+ const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension();
+ const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension();
+ // Dimension number applicable for output.
+ const int64 output_batch_dim = dnums.output_batch_dimension();
+ const int64 output_z_dim = dnums.output_feature_dimension();
+
+ const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim);
+
ElementwiseT result_val = static_cast<ElementwiseT>(0);
- std::fill(rhs_spatial_index.begin(), rhs_spatial_index.end(), 0);
+ DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(),
+ 0);
// Convolve input feature with kernel.
do {
@@ -1100,7 +1117,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
};
auto result = Literal::CreateFromShape(result_shape);
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
+ TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func));
parent_->evaluated_[conv] = std::move(result);
return Status::OK();
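
Two of the evaluator changes above go together: the convolution output is now filled via PopulateParallel(), so the per-element lambda must not share mutable state across invocations. That is why rhs_spatial_index became a local declared inside the lambda and the captures were made explicit. Below is a standalone sketch of that pattern using plain C++ threads; the worker here is illustrative only.

#include <iostream>
#include <thread>
#include <vector>

int main() {
  std::vector<int> results(8, 0);
  // Each invocation builds its own scratch buffer, mirroring how the
  // evaluator's lambda now creates rhs_spatial_index per call instead of
  // reusing a vector captured by reference (which would race under
  // parallel population).
  auto fill_one = [&results](int index) {
    std::vector<int> scratch(4, 0);  // private to this invocation
    scratch[0] = index * index;
    results[index] = scratch[0];
  };
  std::vector<std::thread> workers;
  for (int i = 0; i < 8; ++i) workers.emplace_back(fill_one, i);
  for (std::thread& t : workers) t.join();
  for (int value : results) std::cout << value << " ";  // 0 1 4 9 16 25 36 49
  std::cout << "\n";
  return 0;
}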
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 25702dc65e..c35783c456 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -956,6 +956,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kTuple:
return kWhite;
case HloOpcode::kBroadcast:
+ case HloOpcode::kBroadcastDimOne:
// De-emphasize nodes which broadcast a scalar within a fusion node --
// these are essentially free.
if (instr->IsFused() &&
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index fcf9ebf5f7..8149e47cb5 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -684,6 +684,15 @@ HloInstruction::CreateSelectAndScatter(
}
/* static */ std::unique_ptr<HloInstruction>
+HloInstruction::CreateBroadcastDimOne(const Shape& shape,
+ HloInstruction* operand) {
+ auto instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kBroadcastDimOne, shape));
+ instruction->AppendOperand(operand);
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBroadcastSequence(
const Shape& output_shape, HloInstruction* operand,
const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
@@ -1275,6 +1284,10 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CHECK_EQ(new_operands.size(), 1);
clone = CreateBroadcast(shape, new_operands[0], dimensions_);
break;
+ case HloOpcode::kBroadcastDimOne:
+ CHECK_EQ(new_operands.size(), 1);
+ clone = CreateBroadcastDimOne(shape, new_operands[0]);
+ break;
case HloOpcode::kCall:
clone = CreateCall(shape, new_operands, to_apply());
break;
@@ -1826,6 +1839,8 @@ bool HloInstruction::IdenticalSlowPath(
// Remaining instructions with special values.
case HloOpcode::kBitcast:
+ case HloOpcode::kBroadcastDimOne:
+ case HloOpcode::kDynamicUpdateSlice:
return eq_shapes(shape(), other.shape());
case HloOpcode::kBroadcast:
return eq_shapes(shape(), other.shape()) &&
@@ -1844,8 +1859,6 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kDynamicSlice:
return eq_shapes(shape(), other.shape()) &&
dynamic_slice_sizes_ == other.dynamic_slice_sizes_;
- case HloOpcode::kDynamicUpdateSlice:
- return eq_shapes(shape(), other.shape());
case HloOpcode::kCall:
case HloOpcode::kMap:
return eq_computations(to_apply(), other.to_apply());
@@ -2646,6 +2659,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleBitcast(this);
case HloOpcode::kBroadcast:
return visitor->HandleBroadcast(this);
+ case HloOpcode::kBroadcastDimOne:
+ return visitor->HandleBroadcastDimOne(this);
case HloOpcode::kPad:
return visitor->HandlePad(this);
case HloOpcode::kReshape:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 80f8408244..a6cb19f331 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -401,6 +401,10 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ // Creates a broadcast-size-one-dimensions instruction.
+ static std::unique_ptr<HloInstruction> CreateBroadcastDimOne(
+ const Shape& shape, HloInstruction* operand);
+
// Creates a sequence of instructions that performs an explicit broadcast of
// the operand to the target shape.
//
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index af24604c39..dddc72480f 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -54,6 +54,7 @@ namespace xla {
V(kBitcast, "bitcast") \
V(kBitcastConvert, "bitcast-convert") \
V(kBroadcast, "broadcast") \
+ V(kBroadcastDimOne, "broadcast-dim-one") \
V(kCall, "call", kHloOpcodeIsVariadic) \
V(kCeil, "ceil") \
V(kClamp, "clamp") \
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 8c875698eb..63ec5964eb 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -174,17 +174,34 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
TF_RETURN_IF_ERROR(CheckShape(broadcast, broadcast->shape()));
TF_RET_CHECK(ShapeUtil::Rank(operand_shape) ==
broadcast->dimensions().size());
- for (int64 operand_dimension = 0;
- operand_dimension < ShapeUtil::Rank(operand_shape);
- ++operand_dimension) {
- int64 output_dimension = broadcast->dimensions()[operand_dimension];
+ for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
+ int64 output_dimension = broadcast->dimensions()[i];
TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) ==
- operand_shape.dimensions(operand_dimension))
+ operand_shape.dimensions(i))
<< broadcast->ToString() << " operand shape " << operand_shape;
}
return tensorflow::Status::OK();
}
+Status ShapeVerifier::HandleBroadcastDimOne(HloInstruction* broadcastDimOne) {
+ const Shape& operand_shape = broadcastDimOne->operand(0)->shape();
+ int64 operand_rank = ShapeUtil::Rank(operand_shape);
+ const Shape& output_shape = broadcastDimOne->shape();
+ // Check for mixed precision.
+ TF_RETURN_IF_ERROR(CheckShape(broadcastDimOne, output_shape));
+ TF_RET_CHECK(operand_rank == ShapeUtil::Rank(output_shape));
+ for (int64 i = 0; i < operand_rank; ++i) {
+ int64 operand_dimension = operand_shape.dimensions(i);
+ int64 output_dimension = output_shape.dimensions(i);
+ TF_RET_CHECK(operand_dimension == 1 ||
+ operand_dimension == output_dimension)
+ << "Dimension " << i << " of broadcastDimOne "
+ << broadcastDimOne->ToString() << " is " << operand_dimension
+ << ", expected 1 or " << output_dimension;
+ }
+ return tensorflow::Status::OK();
+}
+
Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
// Check for mixed precision.
TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape()));
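The new verifier clause reduces to a simple per-dimension rule: operand and output must have the same rank, and each operand dimension is either 1 or equal to the matching output dimension. A self-contained restatement of that rule, with a hypothetical helper name that is not part of the XLA API:

// Standalone restatement of the shape rule HandleBroadcastDimOne enforces.
// (Hypothetical helper for illustration only.)
#include <cstdint>
#include <vector>

bool BroadcastDimOneShapesCompatible(const std::vector<int64_t>& operand_dims,
                                     const std::vector<int64_t>& output_dims) {
  if (operand_dims.size() != output_dims.size()) return false;  // same rank
  for (size_t i = 0; i < operand_dims.size(); ++i) {
    // Each operand dimension is either size 1 (broadcast) or an exact match.
    if (operand_dims[i] != 1 && operand_dims[i] != output_dims[i]) return false;
  }
  return true;
}

// Example: f32[1,2] -> f32[2,2] passes; f32[3,2] -> f32[2,2] fails.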
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 1dd7ec3c51..a4dff977ba 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -54,6 +54,7 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleReduce(HloInstruction* reduce) override;
Status HandleBitcast(HloInstruction* bitcast) override;
Status HandleBroadcast(HloInstruction* broadcast) override;
+ Status HandleBroadcastDimOne(HloInstruction* broadcastDimOne) override;
Status HandleReshape(HloInstruction* reshape) override;
Status HandleTranspose(HloInstruction* transpose) override;
Status HandleParameter(HloInstruction*) override;
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index d69ad80bdb..3f4dbf897d 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -37,6 +37,7 @@ namespace xla {
case HloOpcode::kBitcast:
case HloOpcode::kBitcastConvert:
case HloOpcode::kBroadcast:
+ case HloOpcode::kBroadcastDimOne:
case HloOpcode::kCeil:
case HloOpcode::kClamp:
case HloOpcode::kComplex:
@@ -142,7 +143,8 @@ bool InstructionFusion::EffectivelyUnary(HloInstruction* hlo) {
});
return std::count_if(hlo->operands().begin(), hlo->operands().end(),
[output_rank](HloInstruction* operand) {
- if (operand->opcode() == HloOpcode::kBroadcast) {
+ if (operand->opcode() == HloOpcode::kBroadcast ||
+ operand->opcode() == HloOpcode::kBroadcastDimOne) {
return false;
}
if (operand->opcode() == HloOpcode::kConstant &&
@@ -247,7 +249,8 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
auto reachability = computation->ComputeReachability();
auto cheap_to_duplicate = [this](HloInstruction* producer) {
- if (producer->opcode() == HloOpcode::kBroadcast) {
+ if (producer->opcode() == HloOpcode::kBroadcast ||
+ producer->opcode() == HloOpcode::kBroadcastDimOne) {
return true;
}
if (producer->opcode() == HloOpcode::kConstant &&
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 3e130a02e2..1375f981a8 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -28,8 +28,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/optional.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -583,34 +585,7 @@ class ShapeUtil {
tensorflow::gtl::ArraySlice<int64> count,
tensorflow::gtl::ArraySlice<int64> incr,
const FnType& visitor_function) {
- if (ShapeUtil::HasZeroElements(shape)) {
- return Status::OK();
- }
- CHECK_EQ(Rank(shape), base.size());
- CHECK_EQ(incr.size(), base.size());
- CHECK_EQ(count.size(), base.size());
- const int64 rank = LayoutUtil::MinorToMajor(shape).size();
- // Allows handling R0 arrays, such that the visitor function will be called
- // once with the proper empty indexes.
- int64 n = -1;
- std::vector<int64> indexes(base.begin(), base.end());
- while (n < rank) {
- TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes));
- if (!should_continue) {
- break;
- }
- // Increments dimensions in minor to major order.
- for (n = 0; n < rank; ++n) {
- int64 dim = LayoutUtil::Minor(shape.layout(), n);
- indexes[dim] += incr[dim];
- if (indexes[dim] < base[dim] + count[dim]) {
- break;
- }
- indexes[dim] = base[dim];
- }
- }
-
- return Status::OK();
+ return ForEachIndexInternal(shape, base, count, incr, visitor_function);
}
// Simple ergonomic wrapper around ShapeUtil::ForEachIndexWithStatus.
@@ -642,11 +617,83 @@ class ShapeUtil {
.IgnoreError();
}
+ // A parallel version of ForEachIndex(WithStatus). This can only be used if
+ // the visitor_function is thread-safe and the order of iteration does not
+ // matter.
+ //
+ // visitor_function must be a callable of type
+ // void(ArraySlice<int64>) or compatible.
+ template <typename FnType>
+ static void ForEachIndexParallel(const Shape& shape,
+ tensorflow::gtl::ArraySlice<int64> base,
+ tensorflow::gtl::ArraySlice<int64> count,
+ tensorflow::gtl::ArraySlice<int64> incr,
+ const FnType& visitor_function) {
+ // The parallel version of ForEachIndexInternal can never fail.
+ CHECK(ForEachIndexInternal(
+ shape, base, count, incr,
+ [&visitor_function](tensorflow::gtl::ArraySlice<int64> indexes)
+ -> StatusOr<bool> {
+ visitor_function(indexes);
+ return true;
+ },
+ /*parallel=*/true)
+ .ok());
+ }
+
private:
// Validates all of the non-layout properties of the shape -- this is a helper
  // used by both the layout-optional and layout-required public methods.
static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape);
+ template <typename FnType>
+ static Status ForEachIndexInternal(const Shape& shape,
+ tensorflow::gtl::ArraySlice<int64> base,
+ tensorflow::gtl::ArraySlice<int64> count,
+ tensorflow::gtl::ArraySlice<int64> incr,
+ const FnType& visitor_function,
+ bool parallel = false) {
+ if (ShapeUtil::HasZeroElements(shape)) {
+ return Status::OK();
+ }
+ CHECK_EQ(Rank(shape), base.size());
+ CHECK_EQ(incr.size(), base.size());
+ CHECK_EQ(count.size(), base.size());
+ const int64 rank = LayoutUtil::MinorToMajor(shape).size();
+ // Allows handling R0 arrays, such that the visitor function will be called
+ // once with the proper empty indexes.
+ int64 n = -1;
+ std::vector<int64> indexes(base.begin(), base.end());
+ const int kNumThreads = tensorflow::port::NumSchedulableCPUs();
+ tensorflow::gtl::optional<tensorflow::thread::ThreadPool> pool;
+ if (parallel) {
+ pool.emplace(tensorflow::Env::Default(), "foreach", kNumThreads);
+ }
+
+ while (n < rank) {
+ if (pool != tensorflow::gtl::nullopt) {
+ pool->Schedule(
+ [indexes, &visitor_function] { visitor_function(indexes); });
+ } else {
+ TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes));
+ if (!should_continue) {
+ break;
+ }
+ }
+ // Increments dimensions in minor to major order.
+ for (n = 0; n < rank; ++n) {
+ int64 dim = LayoutUtil::Minor(shape.layout(), n);
+ indexes[dim] += incr[dim];
+ if (indexes[dim] < base[dim] + count[dim]) {
+ break;
+ }
+ indexes[dim] = base[dim];
+ }
+ }
+
+ return Status::OK();
+ }
+
TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil);
};
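ForEachIndexInternal walks the index vector like an odometer, incrementing the most-minor dimension first and carrying into the next dimension on wrap-around; in the parallel path each scheduled task captures its own copy of indexes by value so the driving loop can keep advancing. A minimal sketch of just the increment step, assuming the minor-to-major order is simply {0, 1, ..., rank-1} rather than being read from the shape's layout:

// Odometer-style index advance in "minor dimension first" order.
#include <cstdint>
#include <vector>

// Advances `indexes` by `incr` within [base, base + count); returns false
// once every index combination has been visited.
bool NextIndex(std::vector<int64_t>* indexes,
               const std::vector<int64_t>& base,
               const std::vector<int64_t>& count,
               const std::vector<int64_t>& incr) {
  for (size_t dim = 0; dim < indexes->size(); ++dim) {  // minor dims first
    (*indexes)[dim] += incr[dim];
    if ((*indexes)[dim] < base[dim] + count[dim]) return true;
    (*indexes)[dim] = base[dim];  // wrap this dimension, carry into the next
  }
  return false;  // carried past the most-major dimension: iteration done
}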
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index 424cfe37ea..13582a2a26 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -624,6 +624,24 @@ TEST(ShapeUtilTest, ForEachIndexWithStatus) {
EXPECT_EQ(invocations, 5);
}
+TEST(ShapeUtilTest, ForEachIndexParallel) {
+ Shape shape = ShapeUtil::MakeShape(F32, {10, 10});
+ int64 output[10][10];
+ int init = 5;
+ auto set_func = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
+ output[indexes[0]][indexes[1]] = init + indexes[0] + indexes[1];
+ };
+
+ ShapeUtil::ForEachIndexParallel(shape, /*base=*/{0, 0}, /*count=*/{10, 10},
+ /*incr=*/{1, 1}, set_func);
+
+ for (int i = 0; i < 10; ++i) {
+ for (int j = 0; j < 10; ++j) {
+ EXPECT_EQ(output[i][j], init + i + j);
+ }
+ }
+}
+
TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) {
// All output dimensions should be unmodified. One of the input dimensions is
// modified because the input rank is larger by one.
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 6f58c20f34..218345772f 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1266,9 +1266,9 @@ xla_test(
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1372,11 +1372,10 @@ xla_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1477,6 +1476,9 @@ xla_test(
xla_test(
name = "bitcast_convert_test",
srcs = ["bitcast_convert_test.cc"],
+ tags = [
+ "enable_for_xla_interpreter",
+ ],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 17c6a83c1a..c2e3cd2350 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -74,9 +74,9 @@ string ClientLibraryTestBase::TestName() const {
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
}
+template <typename BuilderT>
StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
- ComputationBuilder* builder,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ BuilderT* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
// Build the computation, as a convenience.
TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
return client_->Execute(computation, arguments, &execution_options_);
@@ -595,6 +595,14 @@ ComputationDataHandle ClientLibraryTestBase::AddParam(
return data_handle;
}
+XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
+ XlaBuilder* builder) {
+ XlaOp data_handle;
+ arguments_.push_back(CreateParameterAndTransferLiteral(
+ arguments_.size(), argument, "", builder, &data_handle));
+ return data_handle;
+}
+
ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral(
const Literal& literal, ComputationBuilder* builder) {
return builder->ConstantLiteral(
@@ -643,4 +651,11 @@ template void ClientLibraryTestBase::ComputeAndCompareTuple(
XlaBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error);
+template StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+
+template StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
+ XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+
} // namespace xla
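Because Execute is now a member function template defined in the .cc file, the builder instantiations the tests actually use have to be emitted explicitly at the end of the translation unit, as done above for ComputationBuilder and XlaBuilder. A minimal single-file sketch of that pattern, with hypothetical Runner/BuilderA/BuilderB names standing in for the real classes:

// Sketch: member function template defined out of line, plus the explicit
// instantiations other translation units need in order to link.
#include <string>

struct BuilderA { std::string Build() { return "A"; } };
struct BuilderB { std::string Build() { return "B"; } };

class Runner {
 public:
  template <typename BuilderT>
  std::string Execute(BuilderT* builder);  // declared here, defined below
};

template <typename BuilderT>
std::string Runner::Execute(BuilderT* builder) {
  return builder->Build();
}

// The analogue of the `template StatusOr<...> ClientLibraryTestBase::Execute(...)`
// lines at the end of the .cc above.
template std::string Runner::Execute(BuilderA* builder);
template std::string Runner::Execute(BuilderB* builder);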
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 52f31b0669..0572acff88 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -92,9 +92,9 @@ class ClientLibraryTestBase : public ::testing::Test {
// Convenience methods for building and running a computation with the member
// execution options. Modify execution_options_ in your test if you want to
// customize the options.
+ template <typename BuilderT>
StatusOr<std::unique_ptr<GlobalData>> Execute(
- ComputationBuilder* builder,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ BuilderT* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments);
// TODO(b/74197823): Remove the template type 'BuilderT' in all methods once
// the migration to XlaBuilder is complete.
@@ -300,12 +300,17 @@ class ClientLibraryTestBase : public ::testing::Test {
// set exactly once. The first added parameter gets index 0, then 1 and so on.
ComputationDataHandle AddParam(const Literal& argument,
ComputationBuilder* builder);
+ XlaOp AddParam(const Literal& argument, XlaBuilder* builder);
template <class T>
ComputationDataHandle AddParam(const Array<T>& argument,
ComputationBuilder* builder) {
return AddParam(*Literal::CreateFromArray(argument), builder);
}
+ template <class T>
+ XlaOp AddParam(const Array<T>& argument, XlaBuilder* builder) {
+ return AddParam(*Literal::CreateFromArray(argument), builder);
+ }
// Creates a constant instruction with the given literal. When the
// use_bfloat16 flag is set but the literal has F32 elements, the elements
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index e574644dea..21f71fc91b 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -91,7 +91,7 @@ HloTestBase::HloTestBase()
HloTestBase::HloTestBase(se::Platform* test_platform,
se::Platform* reference_platform)
: test_runner_(test_platform), reference_runner_(reference_platform) {
- hlo_verifier_ = MakeUnique<HloVerifier>();
+ hlo_verifier_ = MakeUnique<HloVerifier>(/*allow_mixed_precision=*/true);
}
/* static */
@@ -142,8 +142,7 @@ StatusOr<std::unique_ptr<HloModule>> HloTestBase::MakeReferenceModule(
"reference preprocessor must not modify the program shape");
}
}
- TF_RETURN_IF_ERROR(VerifyHloModule(*reference_runner_.backend().platform(),
- reference_module.get()));
+ TF_RETURN_IF_ERROR(hlo_verifier_->Run(reference_module.get()).status());
return std::move(reference_module);
}
@@ -151,8 +150,7 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
std::unique_ptr<HloModule> module, const ArraySlice<Literal*> arguments,
const optional<ErrorSpec>& error, bool run_hlo_passes,
const std::function<void(HloModule*)>& reference_preprocessor) {
- TF_RETURN_IF_ERROR(
- VerifyHloModule(*test_runner_.backend().platform(), module.get()));
+ TF_RETURN_IF_ERROR(hlo_verifier_->Run(module.get()).status());
TF_ASSIGN_OR_RETURN(auto reference_module,
MakeReferenceModule(*module, reference_preprocessor));
diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc
index 8cef8dd34d..ce295b832d 100644
--- a/tensorflow/compiler/xla/tests/pad_test.cc
+++ b/tensorflow/compiler/xla/tests/pad_test.cc
@@ -18,9 +18,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -85,7 +85,7 @@ class PadTestFloat : public PadTest,
// Tests a Pad() with a zero-element input and output.
XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
// Set up the padding configuration {low: 0, high: 0, interior: 0}.
PaddingConfig padding_config;
auto dimension = padding_config.add_dimensions();
@@ -100,7 +100,7 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) {
// Tests a Pad() with a zero-element input but a non-zero-element output.
XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
// Set up the padding configuration {low: 3, high: 0, interior: 1}.
PaddingConfig padding_config;
auto dimension = padding_config.add_dimensions();
@@ -115,7 +115,7 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) {
}
XLA_TEST_P(PadTestFloat, Pad1DS3Array) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
// Set up the padding configuration {low: 3, high: 0, interior: 1}.
PaddingConfig padding_config;
auto dimension = padding_config.add_dimensions();
@@ -130,7 +130,7 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) {
}
XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
b.Pad(AddParam(Array4D<float>(2, 0, 3, 2), &b),
AddParam(*Literal::CreateR0<float>(1.5), &b), r4_padding_on_dim0_dim1_);
ComputeAndCompareR4<float>(&b, Array4D<float>(5, 2, 3, 2, 1.5f), {},
@@ -138,7 +138,7 @@ XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) {
}
TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
auto input = MakeUnique<Array4D<float>>(1, 1, 3, 2);
Array2D<float> input_xy({
{1.0f, 2.0f}, // row 0
@@ -162,7 +162,7 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) {
}
TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
const float pad_value = 1.5f;
Array4D<float> input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
@@ -181,7 +181,7 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) {
}
TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
PaddingConfig padding_config;
auto dimension0 = padding_config.add_dimensions();
@@ -223,7 +223,7 @@ TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) {
}
XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
PaddingConfig padding_config;
auto dimension0 = padding_config.add_dimensions();
@@ -266,7 +266,7 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) {
}
XLA_TEST_F(PadTest, Pad4DU8Array) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
auto input = MakeUnique<Array4D<uint8>>(1, 1, 3, 2);
Array2D<uint8> input_xy({
{1, 2}, // row 0
@@ -290,7 +290,7 @@ XLA_TEST_F(PadTest, Pad4DU8Array) {
}
XLA_TEST_F(PadTest, Pad4DPredArray) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
// Since bool is currently not well supported, use Broadcast operation to
// create the operand for Pad.
@@ -317,7 +317,7 @@ XLA_TEST_F(PadTest, Pad4DPredArray) {
}
XLA_TEST_P(PadTestFloat, Large2DPad) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
auto ones = MakeUnique<Array2D<float>>(4, 4);
ones->Fill(1.0f);
@@ -329,15 +329,14 @@ XLA_TEST_P(PadTestFloat, Large2DPad) {
padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 +
100 * dim);
}
- auto padded = b.Pad(input, AddParam(*Literal::CreateR0<float>(0.0f), &b),
- padding_config);
+ b.Pad(input, AddParam(*Literal::CreateR0<float>(0.0f), &b), padding_config);
auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f);
ComputeAndCompareR2<float>(&b, *expected, {}, DefaultErrorSpec());
}
XLA_TEST_P(PadTestFloat, AllTypes2DPad) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
constexpr int64 in_rows = 35;
constexpr int64 in_cols = 35;
@@ -352,15 +351,14 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) {
padding_config.mutable_dimensions(1)->set_edge_padding_low(6);
padding_config.mutable_dimensions(1)->set_edge_padding_high(4);
padding_config.mutable_dimensions(1)->set_interior_padding(2);
- auto padded = b.Pad(input, AddParam(*Literal::CreateR0<float>(3.14f), &b),
- padding_config);
+ b.Pad(input, AddParam(*Literal::CreateR0<float>(3.14f), &b), padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f);
ComputeAndCompareR2<float>(&b, *expected, {}, DefaultErrorSpec());
}
XLA_TEST_P(PadTestFloat, High2DPad) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
constexpr int64 in_rows = 129;
constexpr int64 in_cols = 129;
@@ -378,8 +376,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) {
padding_config.mutable_dimensions(dim)->set_interior_padding(
interior_padding);
}
- auto padded = b.Pad(input, AddParam(*Literal::CreateR0<float>(2.718f), &b),
- padding_config);
+ b.Pad(input, AddParam(*Literal::CreateR0<float>(2.718f), &b), padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -387,7 +384,7 @@ XLA_TEST_P(PadTestFloat, High2DPad) {
}
XLA_TEST_P(PadTestFloat, NegativePadding2D) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
constexpr int64 in_rows = 129;
constexpr int64 in_cols = 129;
@@ -406,8 +403,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) {
padding_config.mutable_dimensions(dim)->set_interior_padding(
interior_padding);
}
- auto padded = b.Pad(input, AddParam(*Literal::CreateR0<float>(2.718f), &b),
- padding_config);
+ b.Pad(input, AddParam(*Literal::CreateR0<float>(2.718f), &b), padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -415,7 +411,7 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) {
}
XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
constexpr int64 in_rows = 8;
constexpr int64 in_cols = 11;
@@ -434,8 +430,7 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) {
padding_config.mutable_dimensions(dim)->set_interior_padding(
interior_padding[dim]);
}
- auto padded = b.Pad(input, AddParam(*Literal::CreateR0<float>(2.718f), &b),
- padding_config);
+ b.Pad(input, AddParam(*Literal::CreateR0<float>(2.718f), &b), padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -444,20 +439,19 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) {
// Regression test for b/31827337.
XLA_TEST_P(PadTestFloat, ReducePad) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
auto ones = MakeUnique<Array4D<float>>(2, 2, 2, 2);
ones->Fill(1.0);
auto input = AddParam(*ones, &b);
- Computation add = CreateScalarAddComputation(FloatType(), &b);
+ XlaComputation add = CreateScalarAddComputation(FloatType(), &b);
auto reduce =
b.Reduce(input, AddParam(*Literal::CreateR0<float>(0.0), &b), add, {0});
PaddingConfig padding_config = MakeNoPaddingConfig(3);
padding_config.mutable_dimensions(0)->set_edge_padding_low(1);
padding_config.mutable_dimensions(0)->set_edge_padding_high(1);
- auto padded = b.Pad(reduce, AddParam(*Literal::CreateR0<float>(0.0f), &b),
- padding_config);
+ b.Pad(reduce, AddParam(*Literal::CreateR0<float>(0.0f), &b), padding_config);
Array3D<float> expected({{{0.0, 0.0}, {0.0, 0.0}},
{{2.0, 2.0}, {2.0, 2.0}},
diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc
index 02272d6017..d7462d581b 100644
--- a/tensorflow/compiler/xla/tests/reshape_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_test.cc
@@ -20,11 +20,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
@@ -53,11 +52,11 @@ class ReshapeTest : public ::testing::WithParamInterface<bool>,
// Collapses 2-dimensional pseudo-scalar (single-element array) to 1 dimension.
XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array2D<float> input_array(1, 1);
input_array.Fill(1.0f);
auto input_literal = Literal::CreateR2FromArray2D(input_array);
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
&builder, &parameter);
builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
@@ -68,9 +67,9 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) {
}
XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input_literal = Literal::CreateR1<float>({1.0f});
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
&builder, &parameter);
builder.Collapse(/*operand=*/parameter, /*dimensions=*/{});
@@ -81,9 +80,9 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) {
}
XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input_literal = Literal::CreateR1<float>({1.0f});
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
&builder, &parameter);
builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0});
@@ -95,11 +94,11 @@ XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) {
// Collapses 2-dimensional pseudo-scalar (single-element array) to scalar.
XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array2D<float> input_array(1, 1);
input_array.Fill(1.0f);
auto input_literal = Literal::CreateR2FromArray2D(input_array);
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
&builder, &parameter);
auto reshape = builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
@@ -112,15 +111,14 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) {
}
XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(1.0f);
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0",
&builder, &parameter);
auto a = builder.Neg(parameter);
- auto reshape =
- builder.Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1});
+ builder.Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1});
auto expected_literal = Literal::CreateR1<float>({-1.0f});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
@@ -131,10 +129,10 @@ XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) {
// does not handle zero-sized shapes correctly. Failed last on 2017-11-30
// with an incorrect result rank.
XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3)) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array2D<float> input_array(0, 3);
auto input_literal = Literal::CreateR2FromArray2D(input_array);
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
@@ -147,11 +145,11 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3)) {
// does not handle zero-sized shapes correctly. Failed last on 2017-05-15
// with an incorrect result rank.
XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
Literal::CreateR2FromArray2D<float>(Array2D<float>(0, 3));
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0",
&builder, &parameter);
builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
@@ -164,10 +162,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) {
// does not handle zero-sized shapes correctly. Failed last on 2017-11-30
// with an incorrect result rank.
XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial3x0)) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array2D<float> input_array(3, 0);
auto input_literal = Literal::CreateR2FromArray2D(input_array);
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
@@ -178,9 +176,9 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial3x0)) {
// Collapses a 2-dimensional row vector to 1 dimension.
XLA_TEST_P(ReshapeTest, Trivial1x3) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input_literal = Literal::CreateR2<float>({{1.0f, 2.0f, 3.0f}});
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
@@ -191,9 +189,9 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) {
// Collapses a 2-dimensional column vector to 1 dimension.
XLA_TEST_P(ReshapeTest, Trivial3x1) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input_literal = Literal::CreateR2<float>({{1.0f}, {2.0f}, {3.0f}});
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
@@ -344,9 +342,9 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) {
// does not handle zero-sized shapes correctly. Failed last on 2017-11-30
// with an incorrect result rank.
XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeR4ToR2ZeroElements)) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input_literal = Literal::CreateFromArray(Array4D<float>(2, 3, 4, 0));
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3},
@@ -359,10 +357,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeR4ToR2ZeroElements)) {
// Reshapes a 2-dimensional array with dimensions that are not just a
// rearrangement of the originals (split), but no reordering (no shuffle).
XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
auto input_literal = Literal::CreateFromArray(*a4x3);
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
@@ -379,9 +377,9 @@ XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) {
// with an incorrect result rank.
//
XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitAndShuffleZeroElements)) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input_literal = Literal::CreateFromArray(Array2D<float>(0, 6));
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
@@ -394,10 +392,10 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitAndShuffleZeroElements)) {
// Reshapes a 2-dimensional array with dimensions that are not just a
// rearrangement of the originals (split), and reorder the input (shuffle).
XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
auto input_literal = Literal::CreateFromArray(*a4x3);
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
@@ -421,9 +419,9 @@ static Array3D<float> ArrayForDocR3Tests() {
}
XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests());
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2},
@@ -436,9 +434,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) {
}
XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests());
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2},
@@ -456,9 +454,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) {
}
XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests());
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
@@ -471,9 +469,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) {
}
XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests());
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
@@ -491,9 +489,9 @@ XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) {
}
XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests());
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
@@ -521,12 +519,12 @@ XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) {
//
// 1 2 3 4 5 6 1 2 3 4 5 6
XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> t2x2x2x3(2, 2, 2, 3);
auto filler2x3 = MakeLinspaceArray2D(1.0f, 6.0f, 2, 3);
t2x2x2x3.FillWithYX(*filler2x3);
auto input_literal = Literal::CreateFromArray(t2x2x2x3);
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3});
@@ -540,7 +538,7 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) {
// As above, but uses reshape directly.
XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> t(2, 1, 2, 2);
t(0, 0, 0, 0) = 0;
t(0, 0, 0, 1) = 1;
@@ -551,7 +549,7 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) {
t(1, 0, 1, 0) = 6;
t(1, 0, 1, 1) = 7;
auto input_literal = Literal::CreateFromArray(t);
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3},
@@ -566,7 +564,7 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) {
// Reshape various ranks to a scalar.
XLA_TEST_P(ReshapeTest, ToScalar) {
for (int rank = 0; rank < 8; ++rank) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
std::vector<int64> ones(rank, 1); // this is {1, ..., 1}.
std::vector<int64> dimensions(rank);
std::iota(dimensions.begin(), dimensions.end(), 0);
@@ -574,7 +572,7 @@ XLA_TEST_P(ReshapeTest, ToScalar) {
std::vector<int64> zeros(rank, 0); // this is {0, ..., 0}.
input_literal.Set<float>(zeros, 83.0f);
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&b, &parameter);
b.Reshape(parameter, dimensions, {});
@@ -586,9 +584,9 @@ XLA_TEST_P(ReshapeTest, ToScalar) {
}
XLA_TEST_P(ReshapeTest, BadDimensions) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
auto input_literal = Literal::CreateR1<float>({1.0f});
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b,
&parameter);
b.Reshape(parameter, {}, {});
@@ -598,9 +596,9 @@ XLA_TEST_P(ReshapeTest, BadDimensions) {
}
XLA_TEST_P(ReshapeTest, BadNewSizes) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
auto input_literal = Literal::CreateR1<float>({1.0f, 2.0f});
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b,
&parameter);
b.Reshape(parameter, {1}, {});
@@ -609,7 +607,7 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) {
}
XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// clang-format off
auto input_literal = Literal::CreateR4FromArray4DWithLayout(Array4D<float>{
{
@@ -635,7 +633,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
},
LayoutUtil::MakeLayout({0, 1, 2, 3}));
// clang-format on
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
@@ -646,7 +644,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
{222, 333, 444, 555, 666, 777, 888, 999},
});
- Computation computation = builder.Build().ConsumeValueOrDie();
+ XlaComputation computation = builder.Build().ConsumeValueOrDie();
ExecutionOptions execution_options = execution_options_;
*execution_options.mutable_shape_with_output_layout() =
ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {2, 8},
@@ -664,13 +662,13 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
}
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::unique_ptr<Literal> input_literal = Literal::CreateR2<float>({
{0, 1, 2, 3, 4, 5, 6, 7},
{100, 101, 102, 103, 104, 105, 106, 107},
{200, 201, 202, 203, 204, 205, 206, 207},
});
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4});
@@ -691,13 +689,13 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
// Tests R2->R4 reshape with the reshape dimensions {1, 0}.
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::unique_ptr<Literal> input_literal = Literal::CreateR2<float>({
{0, 1, 2, 3, 4, 5, 6, 7},
{100, 101, 102, 103, 104, 105, 106, 107},
{200, 201, 202, 203, 204, 205, 206, 207},
});
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
builder.Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4});
@@ -717,7 +715,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) {
}
XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::mt19937 rng;
std::uniform_real_distribution<float> distribution;
Array4D<float> input(2, 1, 1, 1);
@@ -727,7 +725,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) {
std::unique_ptr<Literal> input_literal =
Literal::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
0, *input_literal, "input", &builder, &parameter);
builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1});
@@ -739,7 +737,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) {
}
XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::mt19937 rng;
std::uniform_real_distribution<float> distribution;
Array4D<float> input(2, 1, 4, 1);
@@ -749,7 +747,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) {
std::unique_ptr<Literal> input_literal =
Literal::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
0, *input_literal, "input", &builder, &parameter);
builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2});
@@ -762,7 +760,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) {
// Tests R4->R2 reshape with the reshape dimensions {0, 2, 1, 3}.
XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::mt19937 rng;
std::uniform_real_distribution<float> distribution;
Array4D<float> input(5, 10, 2, 3);
@@ -772,7 +770,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) {
std::unique_ptr<Literal> input_literal =
Literal::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
0, *input_literal, "input", &builder, &parameter);
builder.Reshape(parameter, /*dimensions=*/{0, 2, 1, 3},
@@ -789,7 +787,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) {
}
XLA_TEST_P(ReshapeTest, NoopReshape) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::mt19937 rng;
std::uniform_real_distribution<float> distribution;
Array4D<float> input_array(2, 3, 5, 7);
@@ -799,12 +797,12 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
std::unique_ptr<Literal> input_literal =
Literal::CreateR4FromArray4DWithLayout(
input_array, LayoutUtil::MakeLayout({1, 2, 3, 0}));
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
0, *input_literal, "input", &builder, &parameter);
builder.Reshape(parameter, /*dimensions=*/{3, 0, 1, 2},
/*new_sizes=*/{7, 2, 3, 5});
- Computation computation = builder.Build().ConsumeValueOrDie();
+ XlaComputation computation = builder.Build().ConsumeValueOrDie();
ExecutionOptions execution_options = execution_options_;
*execution_options.mutable_shape_with_output_layout() =
@@ -827,12 +825,12 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
}
XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto literal_1x2x3x4 = Literal::CreateR4<float>(
{{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
- ComputationDataHandle parameter;
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input",
&builder, &parameter);
builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3},
@@ -846,8 +844,8 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) {
{{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
- ComputationBuilder builder(client_, TestName());
- ComputationDataHandle parameter;
+ XlaBuilder builder(TestName());
+ XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input",
&builder, &parameter);
builder.Reshape(parameter, /*dimensions=*/{1, 3, 2, 0},
@@ -880,8 +878,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) {
std::unique_ptr<Literal> input_literal =
Literal::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
- ComputationBuilder builder(client_, TestName());
- ComputationDataHandle parameter;
+ XlaBuilder builder(TestName());
+ XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
0, *input_literal, "input", &builder, &parameter);
builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
@@ -909,8 +907,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) {
std::unique_ptr<Literal> input_literal =
Literal::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
- ComputationBuilder builder(client_, TestName());
- ComputationDataHandle parameter;
+ XlaBuilder builder(TestName());
+ XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
0, *input_literal, "input", &builder, &parameter);
builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
@@ -938,8 +936,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) {
std::unique_ptr<Literal> input_literal =
Literal::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
- ComputationBuilder builder(client_, TestName());
- ComputationDataHandle parameter;
+ XlaBuilder builder(TestName());
+ XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
0, *input_literal, "input", &builder, &parameter);
builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
@@ -968,8 +966,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) {
std::unique_ptr<Literal> input_literal =
Literal::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
- ComputationBuilder builder(client_, TestName());
- ComputationDataHandle parameter;
+ XlaBuilder builder(TestName());
+ XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
0, *input_literal, "input", &builder, &parameter);
builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
@@ -997,8 +995,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
std::unique_ptr<Literal> input_literal =
Literal::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({0, 1, 2, 3}));
- ComputationBuilder builder(client_, TestName());
- ComputationDataHandle parameter;
+ XlaBuilder builder(TestName());
+ XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
0, *input_literal, "input", &builder, &parameter);
builder.Reshape(parameter, /*dimensions=*/{1, 0, 2, 3},
diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc
index 0af40bc15a..a9f2915b45 100644
--- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc
+++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -32,14 +33,15 @@ GTEST_API_ int main(int argc, char** argv) {
// tests.
for (int i = 1; i < argc; i++) {
tensorflow::StringPiece arg(argv[i]);
- if (arg == "--benchmarks" || arg.starts_with("--benchmarks=")) {
+ if (arg == "--benchmarks" ||
+ tensorflow::str_util::StartsWith(arg, "--benchmarks=")) {
const char* pattern = nullptr;
- if (arg.starts_with("--benchmarks=")) {
+ if (tensorflow::str_util::StartsWith(arg, "--benchmarks=")) {
pattern = argv[i] + strlen("--benchmarks=");
} else {
// Handle flag of the form '--benchmarks foo' (no '=').
if (i + 1 >= argc ||
- tensorflow::StringPiece(argv[i + 1]).starts_with("--")) {
+ tensorflow::str_util::StartsWith(argv[i + 1], "--")) {
LOG(ERROR) << "--benchmarks flag requires an argument.";
return 2;
}
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index e60a5a4919..b2f122982a 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -724,6 +724,15 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
shape, operands[0], *broadcast_dimensions));
break;
}
+ case HloOpcode::kBroadcastDimOne: {
+ if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateBroadcastDimOne(shape, operands[0]));
+ break;
+ }
case HloOpcode::kConcatenate: {
optional<std::vector<int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index adc8b1d620..57684b5834 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -59,6 +59,18 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
)"
},
+// broadcast size-one dimensions
+{
+"BroadcastDimOne",
+R"(HloModule broadcast_dim_one_module
+
+ENTRY %broadcast-dim-one () -> f32[2,2] {
+ %constant = f32[1,2]{1,0} constant(f32[1,2] { { 1.1, 2.2 } })
+ ROOT %broadcast-dim-one = f32[2,2]{1,0} broadcast-dim-one(f32[1,2]{1,0} %constant)
+}
+
+)"
+},
// pred constant
{
"ConstantPred",
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index bf69144ad8..9bef0d8b61 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -81,6 +81,7 @@ py_library(
"//tensorflow/contrib/quantize:quantize_graph",
"//tensorflow/contrib/autograph",
"//tensorflow/contrib/receptive_field:receptive_field_py",
+ "//tensorflow/contrib/recurrent:recurrent_py",
"//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py",
"//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py",
"//tensorflow/contrib/resampler:resampler_py",
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index 1c5b00f92e..aaddb06fa0 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -66,6 +66,7 @@ from tensorflow.contrib import periodic_resample
from tensorflow.contrib import predictor
from tensorflow.contrib import quantization
from tensorflow.contrib import quantize
+from tensorflow.contrib import recurrent
from tensorflow.contrib import reduce_slice_ops
from tensorflow.contrib import resampler
from tensorflow.contrib import rnn
diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py
index 5b8cb2d63c..81ae64f110 100644
--- a/tensorflow/contrib/autograph/operators/control_flow.py
+++ b/tensorflow/contrib/autograph/operators/control_flow.py
@@ -83,7 +83,8 @@ def _known_len_for_loop(iterated, extra_cond, loop_body, init_state):
while_cond,
while_body,
init_state=(0,) + init_state,
- extra_deps=(iterated,))
+ extra_deps=(iterated,),
+ opts=dict(maximum_iterations=n))
# Dropping the iteration index because it's not syntactically visible.
results = results[1:]
@@ -136,7 +137,7 @@ def _dataset_for_loop(ds, extra_cond, loop_body, init_state):
return results
-def while_loop(loop_cond, loop_body, init_state, extra_deps):
+def while_loop(loop_cond, loop_body, init_state, extra_deps, opts=None):
"""Functional form of a while statement.
The loop operates on a so-called state, which includes all symbols that are
@@ -153,6 +154,7 @@ def while_loop(loop_cond, loop_body, init_state, extra_deps):
extra_deps: Tuple containing additional entities on which the loop may
depend, such as loop invariants referenced by loop_cond. Used
exclusively for dispatch control.
+ opts: Optional dict of extra loop parameters.
Returns:
Tuple containing the final state.
@@ -161,18 +163,21 @@ def while_loop(loop_cond, loop_body, init_state, extra_deps):
# That could be something as simple as a collection of dispatch rules, with
# some prioritization.
if any(tensor_util.is_tensor(v) for v in init_state + extra_deps):
- return _tf_while_loop(loop_cond, loop_body, init_state)
+ return _tf_while_loop(loop_cond, loop_body, init_state, opts)
else:
- return _py_while_loop(loop_cond, loop_body, init_state)
+ return _py_while_loop(loop_cond, loop_body, init_state, opts)
-def _tf_while_loop(loop_cond, loop_body, init_state):
+def _tf_while_loop(loop_cond, loop_body, init_state, opts):
"""Overload of while_loop that stages a TF while_loop."""
- return control_flow_ops.while_loop(loop_cond, loop_body, init_state)
+ if opts is None:
+ opts = {}
+ return control_flow_ops.while_loop(loop_cond, loop_body, init_state, **opts)
-def _py_while_loop(loop_cond, loop_body, init_state):
+def _py_while_loop(loop_cond, loop_body, init_state, opts):
"""Overload of while_loop that executes a Python while loop."""
+ del opts
state = init_state
while loop_cond(*state):
state = loop_body(*state)
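
To make the new `opts` plumbing above concrete, here is a minimal sketch of forwarding an options dict into `tf.while_loop` via `**opts`, in the spirit of the `maximum_iterations=n` case added in `_known_len_for_loop`. The standalone helper and the toy loop below are illustrative, not part of this change.

```python
import tensorflow as tf

def while_loop_with_opts(loop_cond, loop_body, init_state, opts=None):
    # Extra keyword options (e.g. maximum_iterations) are passed straight
    # through to tf.while_loop, mirroring the patched _tf_while_loop.
    if opts is None:
        opts = {}
    return tf.while_loop(loop_cond, loop_body, init_state, **opts)

# Illustrative use: sum 0..n-1 with an explicit iteration cap.
n = tf.constant(5)
_, total = while_loop_with_opts(
    lambda i, acc: i < n,
    lambda i, acc: (i + 1, acc + i),
    init_state=(tf.constant(0), tf.constant(0)),
    opts=dict(maximum_iterations=n))
```
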
diff --git a/tensorflow/contrib/autograph/pyct/inspect_utils.py b/tensorflow/contrib/autograph/pyct/inspect_utils.py
index 30a5961821..386a6d21ec 100644
--- a/tensorflow/contrib/autograph/pyct/inspect_utils.py
+++ b/tensorflow/contrib/autograph/pyct/inspect_utils.py
@@ -50,6 +50,18 @@ def getnamespace(f):
return namespace
+def getdefiningclass(m, owner_class):
+ """Resolves the class (e.g. one of the superclasses) that defined a method."""
+ m = six.get_unbound_function(m)
+ last_defining = owner_class
+ for superclass in tf_inspect.getmro(owner_class):
+ if hasattr(superclass, m.__name__):
+ superclass_m = getattr(superclass, m.__name__)
+ if six.get_unbound_function(superclass_m) == m:
+ last_defining = superclass
+ return last_defining
+
+
def getmethodclass(m):
"""Resolves a function's owner, e.g. a method's class.
diff --git a/tensorflow/contrib/autograph/pyct/inspect_utils_test.py b/tensorflow/contrib/autograph/pyct/inspect_utils_test.py
index eda3fc13fd..58f827b79a 100644
--- a/tensorflow/contrib/autograph/pyct/inspect_utils_test.py
+++ b/tensorflow/contrib/autograph/pyct/inspect_utils_test.py
@@ -234,6 +234,30 @@ class InspectUtilsTest(test.TestCase):
c = TestCallable()
self.assertEqual(inspect_utils.getmethodclass(c), TestCallable)
+ def test_getdefiningclass(self):
+ class Superclass(object):
+
+ def foo(self):
+ pass
+
+ def bar(self):
+ pass
+
+ class Subclass(Superclass):
+
+ def foo(self):
+ pass
+
+ def baz(self):
+ pass
+
+ self.assertTrue(
+ inspect_utils.getdefiningclass(Subclass.foo, Subclass) is Subclass)
+ self.assertTrue(
+ inspect_utils.getdefiningclass(Subclass.bar, Subclass) is Superclass)
+ self.assertTrue(
+ inspect_utils.getdefiningclass(Subclass.baz, Subclass) is Subclass)
+
if __name__ == '__main__':
test.main()
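
The new `getdefiningclass` helper exercised by the tests above walks the MRO and returns the last class whose attribute of that name is still the same function object, i.e. the class that originally defined the method. A rough Python-3-only approximation, using `inspect.getmro` instead of `tf_inspect`/`six` (so this is a sketch of the idea rather than a copy of the helper), is:

```python
import inspect

def getdefiningclass_sketch(method, owner_class):
    """Return the class in owner_class's MRO that originally defined method.

    Python-3-only approximation: it relies on unbound methods being plain
    functions, so no six.get_unbound_function is needed.
    """
    last_defining = owner_class
    for superclass in inspect.getmro(owner_class):
        candidate = getattr(superclass, method.__name__, None)
        if candidate is method:
            last_defining = superclass
    return last_defining

class Base(object):
    def inherited(self):
        pass

class Child(Base):
    def own(self):
        pass

assert getdefiningclass_sketch(Child.own, Child) is Child
assert getdefiningclass_sketch(Child.inherited, Child) is Base
```
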
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index f273c7e550..de84af866b 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -104,6 +104,8 @@ tensorflow/python/user_ops
tensorflow/python/util
tensorflow/python/util/protobuf
tensorflow/tools
+tensorflow/tools/api
+tensorflow/tools/api/generator
tensorflow/tools/graph_transforms
tensorflow/contrib
tensorflow/contrib/all_reduce
@@ -355,6 +357,9 @@ tensorflow/contrib/periodic_resample
tensorflow/contrib/periodic_resample/python
tensorflow/contrib/periodic_resample/python/ops
tensorflow/contrib/predictor
+tensorflow/contrib/proto
+tensorflow/contrib/proto/python
+tensorflow/contrib/proto/python/ops
tensorflow/contrib/quantization
tensorflow/contrib/quantization/python
tensorflow/contrib/quantize
@@ -363,6 +368,10 @@ tensorflow/contrib/receptive_field
tensorflow/contrib/receptive_field/python
tensorflow/contrib/receptive_field/python/util
tensorflow/contrib/receptive_field/python/util/examples
+tensorflow/contrib/recurrent
+tensorflow/contrib/recurrent/python
+tensorflow/contrib/recurrent/python/ops
+tensorflow/contrib/recurrent/python/kernel_tests
tensorflow/contrib/reduce_slice_ops
tensorflow/contrib/reduce_slice_ops/kernels
tensorflow/contrib/reduce_slice_ops/ops
@@ -383,6 +392,9 @@ tensorflow/contrib/rnn/ops
tensorflow/contrib/rnn/python
tensorflow/contrib/rnn/python/kernel_tests
tensorflow/contrib/rnn/python/ops
+tensorflow/contrib/rpc
+tensorflow/contrib/rpc/python
+tensorflow/contrib/rpc/python/ops
tensorflow/contrib/saved_model
tensorflow/contrib/saved_model/python
tensorflow/contrib/saved_model/python/saved_model
diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake
index 092a48bc6b..e558691de4 100644
--- a/tensorflow/contrib/cmake/tf_core_ops.cmake
+++ b/tensorflow/contrib/cmake/tf_core_ops.cmake
@@ -25,6 +25,8 @@ set(tf_op_lib_names
"cudnn_rnn_ops"
"data_flow_ops"
"dataset_ops"
+ "decode_proto_ops"
+ "encode_proto_ops"
"functional_ops"
"image_ops"
"io_ops"
@@ -40,6 +42,7 @@ set(tf_op_lib_names
"random_ops"
"remote_fused_graph_ops"
"resource_variable_ops"
+ "rpc_ops"
"script_ops"
"sdca_ops"
"set_ops"
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index fae45ead5c..ded15b4b66 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -330,6 +330,8 @@ GENERATE_PYTHON_OP_LIB("ctc_ops")
GENERATE_PYTHON_OP_LIB("cudnn_rnn_ops")
GENERATE_PYTHON_OP_LIB("data_flow_ops")
GENERATE_PYTHON_OP_LIB("dataset_ops")
+GENERATE_PYTHON_OP_LIB("decode_proto_ops")
+GENERATE_PYTHON_OP_LIB("encode_proto_ops")
GENERATE_PYTHON_OP_LIB("image_ops")
GENERATE_PYTHON_OP_LIB("io_ops")
GENERATE_PYTHON_OP_LIB("linalg_ops")
@@ -343,6 +345,7 @@ GENERATE_PYTHON_OP_LIB("random_ops")
GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/remote_fused_graph/pylib/python/ops/gen_remote_fused_graph_ops.py)
GENERATE_PYTHON_OP_LIB("resource_variable_ops")
+GENERATE_PYTHON_OP_LIB("rpc_ops")
GENERATE_PYTHON_OP_LIB("script_ops")
GENERATE_PYTHON_OP_LIB("sdca_ops")
GENERATE_PYTHON_OP_LIB("set_ops")
@@ -686,6 +689,77 @@ AddUserOps(TARGET _beam_search_ops
DEPENDS pywrap_tensorflow_internal tf_python_ops
DISTCOPY ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/seq2seq/python/ops/)
+if(WIN32)
+ if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
+ add_custom_command(TARGET pywrap_tensorflow_internal POST_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.dll
+ ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.pyd
+ COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.lib
+ ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/)
+ else()
+ add_custom_command(TARGET pywrap_tensorflow_internal POST_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.dll
+ ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.pyd
+ COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.lib
+ ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/)
+ endif()
+else()
+ add_custom_command(TARGET pywrap_tensorflow_internal POST_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal.so
+ ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.so)
+endif()
+
+
+########################################################
+# Generate API __init__.py files.
+########################################################
+
+# Parse tensorflow/tools/api/generator/BUILD to get list of generated files.
+FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/BUILD api_generator_BUILD_text)
+STRING(REGEX MATCH "# BEGIN GENERATED FILES.*# END GENERATED FILES" api_init_files_text ${api_generator_BUILD_text})
+string(REPLACE "# BEGIN GENERATED FILES" "" api_init_files_text ${api_init_files_text})
+string(REPLACE "# END GENERATED FILES" "" api_init_files_text ${api_init_files_text})
+string(REPLACE "," ";" api_init_files_list ${api_init_files_text})
+
+set(api_init_files "")
+foreach(api_init_file ${api_init_files_list})
+ string(STRIP "${api_init_file}" api_init_file)
+ if(api_init_file)
+ string(REPLACE "\"" "" api_init_file "${api_init_file}") # Remove quotes
+ list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/${api_init_file}")
+ endif()
+endforeach(api_init_file)
+set(api_init_list_file "${tensorflow_source_dir}/api_init_files_list.txt")
+file(WRITE "${api_init_list_file}" "${api_init_files}")
+
+# Run create_python_api.py to generate __init__.py files.
+add_custom_command(
+ OUTPUT ${api_init_files}
+ DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops
+
+ # tensorflow/__init__.py depends on files generated in this step. So, remove it while
+ # this step is running since the files aren't there yet.
+ COMMAND ${CMAKE_COMMAND} -E rename ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
+ ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/final.__init__.py
+ COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
+
+ # Run create_python_api.py to generate API init files.
+ COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE}
+ "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" "${api_init_list_file}"
+
+ # Re-add tensorflow/__init__.py back.
+ COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
+ COMMAND ${CMAKE_COMMAND} -E rename ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/final.__init__.py
+ ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
+
+ COMMENT "Generating __init__.py files for Python API."
+ WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python"
+)
+
+add_custom_target(tf_python_api SOURCES ${api_init_files})
+add_dependencies(tf_python_api tf_python_ops)
+
+
############################################################
# Build a PIP package containing the TensorFlow runtime.
############################################################
@@ -695,6 +769,7 @@ add_dependencies(tf_python_build_pip_package
tf_python_copy_scripts_to_destination
tf_python_touchup_modules
tf_python_ops
+ tf_python_api
tf_extension_ops)
# Fix-up Python files that were not included by the add_python_module() macros.
@@ -707,25 +782,6 @@ add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD
COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/tensorflow/contrib/testing/python/framework/util_test.py
${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/testing/python/framework/)
-if(WIN32)
- if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
- add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
- COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.dll
- ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.pyd
- COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.lib
- ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/)
- else()
- add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
- COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.dll
- ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.pyd
- COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.lib
- ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/)
- endif()
-else()
- add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
- COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal.so
- ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.so)
-endif()
add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/tensorflow/tools/pip_package/README
${CMAKE_CURRENT_BINARY_DIR}/tf_python/)
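
The API-generation block added above recovers the list of generated `__init__.py` files by slicing `tensorflow/tools/api/generator/BUILD` between the `# BEGIN GENERATED FILES` and `# END GENERATED FILES` markers. For readers who find the CMake string surgery hard to follow, a rough Python equivalent (the marker names come from the CMake code; the function name and sample input are illustrative) is:

```python
import re

def parse_generated_api_files(build_text):
    # Take the region between the markers, then split the comma-separated,
    # quoted entries and strip quotes/whitespace -- roughly what the
    # STRING(REGEX MATCH)/string(REPLACE) sequence above does.
    match = re.search(r"# BEGIN GENERATED FILES(.*)# END GENERATED FILES",
                      build_text, flags=re.DOTALL)
    if not match:
        return []
    entries = match.group(1).split(",")
    return [e.strip().strip('"') for e in entries if e.strip()]

sample = '''
# BEGIN GENERATED FILES
    "api/__init__.py",
    "api/nn/__init__.py",
# END GENERATED FILES
'''
print(parse_generated_api_files(sample))
# ['api/__init__.py', 'api/nn/__init__.py']
```
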
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD
index edb9130266..4e088503bf 100644
--- a/tensorflow/contrib/eager/python/BUILD
+++ b/tensorflow/contrib/eager/python/BUILD
@@ -272,8 +272,7 @@ cuda_py_test(
"//tensorflow/python/keras",
],
tags = [
- "no_oss", # b/74395663
"no_windows", # TODO: needs investigation on Windows
- "notsan",
+ "notsan", # b/74395663
],
)
diff --git a/tensorflow/contrib/eager/python/checkpointable_utils_test.py b/tensorflow/contrib/eager/python/checkpointable_utils_test.py
index 891c093a0f..e6498ddb06 100644
--- a/tensorflow/contrib/eager/python/checkpointable_utils_test.py
+++ b/tensorflow/contrib/eager/python/checkpointable_utils_test.py
@@ -714,7 +714,7 @@ class CheckpointingTests(test.TestCase):
status.run_restore_ops()
self.assertEqual(-14., self.evaluate(loaded_dep_after_var.dep.var))
- @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+ @test_util.run_in_graph_and_eager_modes()
def testDeferredSlotRestoration(self):
checkpoint_directory = self.get_temp_dir()
@@ -779,7 +779,7 @@ class CheckpointingTests(test.TestCase):
self.evaluate(train_op)
slot_status.assert_consumed()
- @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+ @test_util.run_in_graph_and_eager_modes()
def testOverlappingRestores(self):
checkpoint_directory = self.get_temp_dir()
save_root = checkpointable.Checkpointable()
@@ -830,7 +830,7 @@ class CheckpointingTests(test.TestCase):
second_status.run_restore_ops()
self.assertEqual(12., self.evaluate(load_dep.var))
- @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+ @test_util.run_in_graph_and_eager_modes()
def testAmbiguousLoad(self):
# Not OK to split one checkpoint object into two
checkpoint_directory = self.get_temp_dir()
@@ -853,7 +853,7 @@ class CheckpointingTests(test.TestCase):
"resolved to different objects"):
load_root.dep_two.dep_three = checkpointable.Checkpointable()
- @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+ @test_util.run_in_graph_and_eager_modes()
def testObjectsCombined(self):
# Currently fine to load two checkpoint objects into one Python object
checkpoint_directory = self.get_temp_dir()
@@ -1154,7 +1154,7 @@ class CheckpointingTests(test.TestCase):
class TemplateTests(test.TestCase):
- @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+ @test_util.run_in_graph_and_eager_modes()
def test_checkpointable_save_restore(self):
def _templated():
@@ -1185,7 +1185,7 @@ class TemplateTests(test.TestCase):
self.assertAllEqual([13.], self.evaluate(var_plus_one))
self.assertAllEqual([14.], self.evaluate(var2))
- @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+ @test_util.run_in_graph_and_eager_modes()
def test_checkpointable_save_restore_nested(self):
def _inner_template():
diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
index a90048d813..be5d60449d 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
+++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
@@ -315,32 +315,37 @@ def main(_):
have_gpu = tfe.num_gpus() > 0
use_cudnn_rnn = not FLAGS.no_use_cudnn_rnn and have_gpu
- with tfe.restore_variables_on_create(
- tf.train.latest_checkpoint(FLAGS.logdir)):
- with tf.device("/device:GPU:0" if have_gpu else None):
- # Make learning_rate a Variable so it can be included in the checkpoint
- # and we can resume training with the last saved learning_rate.
- learning_rate = tfe.Variable(20.0, name="learning_rate")
- sys.stderr.write("learning_rate=%f\n" % learning_rate.numpy())
- model = PTBModel(corpus.vocab_size(), FLAGS.embedding_dim,
- FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout,
- use_cudnn_rnn)
- optimizer = tf.train.GradientDescentOptimizer(learning_rate)
-
- best_loss = None
- for _ in range(FLAGS.epoch):
- train(model, optimizer, train_data, FLAGS.seq_len, FLAGS.clip)
- eval_loss = evaluate(model, eval_data)
- if not best_loss or eval_loss < best_loss:
- if FLAGS.logdir:
- tfe.Saver(model.trainable_weights + [learning_rate]).save(
- os.path.join(FLAGS.logdir, "ckpt"))
- best_loss = eval_loss
- else:
- learning_rate.assign(learning_rate / 4.0)
- sys.stderr.write("eval_loss did not reduce in this epoch, "
- "changing learning rate to %f for the next epoch\n" %
- learning_rate.numpy())
+ with tf.device("/device:GPU:0" if have_gpu else None):
+ # Make learning_rate a Variable so it can be included in the checkpoint
+ # and we can resume training with the last saved learning_rate.
+ learning_rate = tfe.Variable(20.0, name="learning_rate")
+ model = PTBModel(corpus.vocab_size(), FLAGS.embedding_dim,
+ FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout,
+ use_cudnn_rnn)
+ optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+ checkpoint = tfe.Checkpoint(
+ learning_rate=learning_rate, model=model,
+ # GradientDescentOptimizer has no state to checkpoint, but noting it
+ # here lets us swap in an optimizer that does.
+ optimizer=optimizer)
+ # Restore existing variables now (learning_rate), and restore new variables
+ # on creation if a checkpoint exists.
+ checkpoint.restore(tf.train.latest_checkpoint(FLAGS.logdir))
+ sys.stderr.write("learning_rate=%f\n" % learning_rate.numpy())
+
+ best_loss = None
+ for _ in range(FLAGS.epoch):
+ train(model, optimizer, train_data, FLAGS.seq_len, FLAGS.clip)
+ eval_loss = evaluate(model, eval_data)
+ if not best_loss or eval_loss < best_loss:
+ if FLAGS.logdir:
+ checkpoint.save(os.path.join(FLAGS.logdir, "ckpt"))
+ best_loss = eval_loss
+ else:
+ learning_rate.assign(learning_rate / 4.0)
+ sys.stderr.write("eval_loss did not reduce in this epoch, "
+ "changing learning rate to %f for the next epoch\n" %
+ learning_rate.numpy())
if __name__ == "__main__":
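
The rnn_ptb change above replaces `tfe.restore_variables_on_create` with object-based checkpointing. A stripped-down sketch of that save/restore pattern, with a toy variable and an illustrative log directory standing in for the example's model, is:

```python
import os
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()

logdir = "/tmp/ckpt_sketch"  # illustrative path
learning_rate = tfe.Variable(20.0, name="learning_rate")
optimizer = tf.train.GradientDescentOptimizer(learning_rate)

# Name the objects to track; restore() matches variables by object path and
# is a no-op when no checkpoint exists yet.
checkpoint = tfe.Checkpoint(learning_rate=learning_rate, optimizer=optimizer)
checkpoint.restore(tf.train.latest_checkpoint(logdir))

learning_rate.assign(learning_rate / 4.0)  # stand-in for training updates
checkpoint.save(os.path.join(logdir, "ckpt"))
```
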
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index bec0329ebb..9f4cd44afb 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -23,6 +23,7 @@ py_library(
":logit_fns",
":multi_head",
":replicate_model_fn",
+ ":rnn",
"//tensorflow/python:util",
],
)
@@ -412,3 +413,57 @@ cuda_py_test(
"notap",
],
)
+
+py_library(
+ name = "rnn",
+ srcs = ["python/estimator/rnn.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":extenders",
+ "//tensorflow/contrib/feature_column:feature_column_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:partitioned_variables",
+ "//tensorflow/python:rnn",
+ "//tensorflow/python:rnn_cell",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:head",
+ "//tensorflow/python/estimator:optimizers",
+ "//tensorflow/python/feature_column",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "rnn_test",
+ size = "medium",
+ srcs = ["python/estimator/rnn_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ "notsan",
+ ],
+ deps = [
+ ":rnn",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/estimator:numpy_io",
+ "//tensorflow/python/feature_column",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index d2fc2c4bfa..9a87fa915d 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -52,6 +52,7 @@ _allowed_symbols = [
'linear_logit_fn_builder',
'replicate_model_fn',
'TowerOptimizer',
+ 'RNNClassifier',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py
index bbbc19cc4d..ce75899214 100644
--- a/tensorflow/contrib/estimator/python/estimator/multi_head.py
+++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py
@@ -345,7 +345,7 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access
if k == _DEFAULT_SERVING_KEY:
key = head_name
else:
- key = '%s/%s' % (k, head_name)
+ key = '%s/%s' % (head_name, k)
export_outputs[key] = v
if (k == head_lib._PREDICT_SERVING_KEY and # pylint:disable=protected-access
isinstance(v, export_output_lib.PredictOutput)):
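
The `_MultiHead` change above flips the export output key from `<serving_key>/<head_name>` to `<head_name>/<serving_key>`, which the test updates below mirror. A tiny sketch of how keys come out under the new ordering (the head and serving-key names are taken from the tests; the helper itself is illustrative):

```python
_DEFAULT_SERVING_KEY = 'serving_default'  # assumed value of the default key

def export_output_keys(head_name, serving_keys):
    keys = []
    for k in serving_keys:
        if k == _DEFAULT_SERVING_KEY:
            key = head_name          # default output keeps just the head name
        else:
            key = '%s/%s' % (head_name, k)  # new order: head name first
        keys.append(key)
    return keys

print(export_output_keys('head1',
                         [_DEFAULT_SERVING_KEY, 'classification', 'predict']))
# ['head1', 'head1/classification', 'head1/predict']
```
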
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
index d9e5aca295..3d6fccb118 100644
--- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
@@ -127,8 +127,8 @@ class MultiHeadTest(test.TestCase):
logits=logits)
self.assertItemsEqual(
- (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'classification/head1',
- 'predict/head1', 'head2', 'classification/head2', 'predict/head2'),
+ (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/classification',
+ 'head1/predict', 'head2', 'head2/classification', 'head2/predict'),
spec.export_outputs.keys())
# Assert predictions and export_outputs.
@@ -169,11 +169,11 @@ class MultiHeadTest(test.TestCase):
self.assertAllClose(
expected_probabilities['head1'],
sess.run(
- spec.export_outputs['predict/head1'].outputs['probabilities']))
+ spec.export_outputs['head1/predict'].outputs['probabilities']))
self.assertAllClose(
expected_probabilities['head2'],
sess.run(
- spec.export_outputs['predict/head2'].outputs['probabilities']))
+ spec.export_outputs['head2/predict'].outputs['probabilities']))
def test_predict_two_heads_logits_tensor(self):
"""Tests predict with logits as Tensor."""
@@ -197,8 +197,8 @@ class MultiHeadTest(test.TestCase):
logits=logits)
self.assertItemsEqual(
- (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'classification/head1',
- 'predict/head1', 'head2', 'classification/head2', 'predict/head2'),
+ (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/classification',
+ 'head1/predict', 'head2', 'head2/classification', 'head2/predict'),
spec.export_outputs.keys())
# Assert predictions and export_outputs.
@@ -254,8 +254,8 @@ class MultiHeadTest(test.TestCase):
logits=logits)
self.assertItemsEqual(
- (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'regression/head1',
- 'predict/head1', 'head2', 'regression/head2', 'predict/head2'),
+ (_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/regression',
+ 'head1/predict', 'head2', 'head2/regression', 'head2/predict'),
spec.export_outputs.keys())
# Assert predictions and export_outputs.
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py
new file mode 100644
index 0000000000..b475c12f5a
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/rnn.py
@@ -0,0 +1,481 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Recurrent Neural Network estimators."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from tensorflow.contrib.estimator.python.estimator import extenders
+from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as seq_fc
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator.canned import head as head_lib
+from tensorflow.python.estimator.canned import optimizers
+from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.layers import core as core_layers
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.summary import summary
+from tensorflow.python.training import optimizer as optimizer_lib
+from tensorflow.python.training import training_util
+
+
+# The defaults are historical artifacts of the initial implementation, but seem
+# reasonable choices.
+_DEFAULT_LEARNING_RATE = 0.05
+_DEFAULT_CLIP_NORM = 5.0
+
+_CELL_TYPES = {'basic_rnn': rnn_cell.BasicRNNCell,
+ 'lstm': rnn_cell.BasicLSTMCell,
+ 'gru': rnn_cell.GRUCell}
+
+# Indicates no value was provided by the user to a kwarg.
+USE_DEFAULT = object()
+
+
+def _single_rnn_cell(num_units, cell_type):
+ cell_type = _CELL_TYPES.get(cell_type, cell_type)
+ if not cell_type or not issubclass(cell_type, rnn_cell.RNNCell):
+ raise ValueError('Supported cell types are {}; got {}'.format(
+ list(_CELL_TYPES.keys()), cell_type))
+ return cell_type(num_units=num_units)
+
+
+def _make_rnn_cell_fn(num_units, cell_type='basic_rnn'):
+ """Convenience function to create `rnn_cell_fn` for canned RNN Estimators.
+
+ Args:
+ num_units: Iterable of integer number of hidden units per RNN layer.
+ cell_type: A subclass of `tf.nn.rnn_cell.RNNCell` or a string specifying
+ the cell type. Supported strings are: `'basic_rnn'`, `'lstm'`, and
+ `'gru'`.
+
+ Returns:
+ A function that takes a single argument, an instance of
+ `tf.estimator.ModeKeys`, and returns an instance derived from
+ `tf.nn.rnn_cell.RNNCell`.
+
+ Raises:
+ ValueError: If cell_type is not supported.
+ """
+ def rnn_cell_fn(mode):
+ # Unused. Part of the rnn_cell_fn interface since user specified functions
+ # may need different behavior across modes (e.g. dropout).
+ del mode
+ cells = [_single_rnn_cell(n, cell_type) for n in num_units]
+ if len(cells) == 1:
+ return cells[0]
+ return rnn_cell.MultiRNNCell(cells)
+ return rnn_cell_fn
+
+
+def _concatenate_context_input(sequence_input, context_input):
+ """Replicates `context_input` across all timesteps of `sequence_input`.
+
+ Expands dimension 1 of `context_input` then tiles it `sequence_length` times.
+ This value is appended to `sequence_input` on dimension 2 and the result is
+ returned.
+
+ Args:
+ sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
+ padded_length, d0]`.
+ context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`.
+
+ Returns:
+ A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
+ d0 + d1]`.
+
+ Raises:
+ ValueError: If `sequence_input` does not have rank 3 or `context_input` does
+ not have rank 2.
+ """
+ seq_rank_check = check_ops.assert_rank(
+ sequence_input,
+ 3,
+ message='sequence_input must have rank 3',
+ data=[array_ops.shape(sequence_input)])
+ seq_type_check = check_ops.assert_type(
+ sequence_input,
+ dtypes.float32,
+ message='sequence_input must have dtype float32; got {}.'.format(
+ sequence_input.dtype))
+ ctx_rank_check = check_ops.assert_rank(
+ context_input,
+ 2,
+ message='context_input must have rank 2',
+ data=[array_ops.shape(context_input)])
+ ctx_type_check = check_ops.assert_type(
+ context_input,
+ dtypes.float32,
+ message='context_input must have dtype float32; got {}.'.format(
+ context_input.dtype))
+ with ops.control_dependencies(
+ [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]):
+ padded_length = array_ops.shape(sequence_input)[1]
+ tiled_context_input = array_ops.tile(
+ array_ops.expand_dims(context_input, 1),
+ array_ops.concat([[1], [padded_length], [1]], 0))
+ return array_ops.concat([sequence_input, tiled_context_input], 2)
+
+
+def _select_last_activations(activations, sequence_lengths):
+ """Selects the nth set of activations for each n in `sequence_lengths`.
+
+ Returns a `Tensor` of shape `[batch_size, k]`. If `sequence_lengths` is not
+ `None`, then `output[i, :] = activations[i, sequence_lengths[i] - 1, :]`. If
+ `sequence_lengths` is `None`, then `output[i, :] = activations[i, -1, :]`.
+
+ Args:
+ activations: A `Tensor` with shape `[batch_size, padded_length, k]`.
+ sequence_lengths: A `Tensor` with shape `[batch_size]` or `None`.
+ Returns:
+ A `Tensor` of shape `[batch_size, k]`.
+ """
+ with ops.name_scope(
+ 'select_last_activations', values=[activations, sequence_lengths]):
+ activations_shape = array_ops.shape(activations)
+ batch_size = activations_shape[0]
+ padded_length = activations_shape[1]
+ output_units = activations_shape[2]
+ if sequence_lengths is None:
+ sequence_lengths = padded_length
+ start_indices = math_ops.to_int64(
+ math_ops.range(batch_size) * padded_length)
+ last_indices = start_indices + sequence_lengths - 1
+ reshaped_activations = array_ops.reshape(
+ activations, [batch_size * padded_length, output_units])
+
+ last_activations = array_ops.gather(reshaped_activations, last_indices)
+ last_activations.set_shape([activations.shape[0], activations.shape[2]])
+ return last_activations
+
+
+def _rnn_logit_fn_builder(output_units, rnn_cell_fn, sequence_feature_columns,
+ context_feature_columns, input_layer_partitioner):
+ """Function builder for an RNN logit_fn.
+
+ Args:
+ output_units: An int indicating the dimension of the logit layer.
+ rnn_cell_fn: A function with one argument, a `tf.estimator.ModeKeys`, and
+ returns an object of type `tf.nn.rnn_cell.RNNCell`.
+ sequence_feature_columns: An iterable containing the `FeatureColumn`s
+ that represent sequential input.
+ context_feature_columns: An iterable containing the `FeatureColumn`s
+ that represent contextual input.
+ input_layer_partitioner: Partitioner for input layer.
+
+ Returns:
+ A logit_fn (see below).
+
+ Raises:
+ ValueError: If output_units is not an int.
+ """
+ if not isinstance(output_units, int):
+ raise ValueError('output_units must be an int. Given type: {}'.format(
+ type(output_units)))
+
+ def rnn_logit_fn(features, mode):
+ """Recurrent Neural Network logit_fn.
+
+ Args:
+ features: This is the first item returned from the `input_fn`
+ passed to `train`, `evaluate`, and `predict`. This should be a
+ single `Tensor` or `dict` of same.
+ mode: Optional. Specifies if this is training, evaluation or prediction. See
+ `ModeKeys`.
+
+ Returns:
+ A `Tensor` representing the logits.
+ """
+ with variable_scope.variable_scope(
+ 'sequence_input_layer',
+ values=tuple(six.itervalues(features)),
+ partitioner=input_layer_partitioner):
+ sequence_input, sequence_length = seq_fc.sequence_input_layer(
+ features=features, feature_columns=sequence_feature_columns)
+ summary.histogram('sequence_length', sequence_length)
+
+ if context_feature_columns:
+ context_input = feature_column_lib.input_layer(
+ features=features,
+ feature_columns=context_feature_columns)
+ sequence_input = _concatenate_context_input(sequence_input,
+ context_input)
+
+ cell = rnn_cell_fn(mode)
+ # Ignore output state.
+ rnn_outputs, _ = rnn.dynamic_rnn(
+ cell=cell,
+ inputs=sequence_input,
+ dtype=dtypes.float32,
+ time_major=False)
+ last_activations = _select_last_activations(rnn_outputs, sequence_length)
+
+ with variable_scope.variable_scope('logits', values=(rnn_outputs,)):
+ logits = core_layers.dense(
+ last_activations,
+ units=output_units,
+ activation=None,
+ kernel_initializer=init_ops.glorot_uniform_initializer())
+ return logits
+
+ return rnn_logit_fn
+
+
+def _rnn_model_fn(features,
+ labels,
+ mode,
+ head,
+ rnn_cell_fn,
+ sequence_feature_columns,
+ context_feature_columns,
+ optimizer='Adagrad',
+ input_layer_partitioner=None,
+ config=None):
+ """Recurrent Neural Net model_fn.
+
+ Args:
+ features: dict of `Tensor` and `SparseTensor` objects returned from
+ `input_fn`.
+ labels: `Tensor` of shape [batch_size, 1] or [batch_size] with labels.
+ mode: Defines whether this is training, evaluation or prediction.
+ See `ModeKeys`.
+ head: A `head_lib._Head` instance.
+ rnn_cell_fn: A function with one argument, a `tf.estimator.ModeKeys`, and
+ returns an object of type `tf.nn.rnn_cell.RNNCell`.
+ sequence_feature_columns: Iterable containing `FeatureColumn`s that
+ represent sequential model inputs.
+ context_feature_columns: Iterable containing `FeatureColumn`s that
+ represent model inputs not associated with a specific timestep.
+ optimizer: String, `tf.Optimizer` object, or callable that creates the
+ optimizer to use for training. If not specified, will use the Adagrad
+ optimizer with a default learning rate of 0.05 and gradient clip norm of
+ 5.0.
+ input_layer_partitioner: Partitioner for input layer. Defaults
+ to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
+ config: `RunConfig` object to configure the runtime settings.
+
+ Returns:
+ An `EstimatorSpec` instance.
+
+ Raises:
+ ValueError: If mode or optimizer is invalid, or features has the wrong type.
+ """
+ if not isinstance(features, dict):
+ raise ValueError('features should be a dictionary of `Tensor`s. '
+ 'Given type: {}'.format(type(features)))
+
+ # If user does not provide an optimizer instance, use the optimizer specified
+ # by the string with default learning rate and gradient clipping.
+ if not isinstance(optimizer, optimizer_lib.Optimizer):
+ optimizer = optimizers.get_optimizer_instance(
+ optimizer, learning_rate=_DEFAULT_LEARNING_RATE)
+ optimizer = extenders.clip_gradients_by_norm(optimizer, _DEFAULT_CLIP_NORM)
+
+ num_ps_replicas = config.num_ps_replicas if config else 0
+ partitioner = partitioned_variables.min_max_variable_partitioner(
+ max_partitions=num_ps_replicas)
+ with variable_scope.variable_scope(
+ 'rnn',
+ values=tuple(six.itervalues(features)),
+ partitioner=partitioner):
+ input_layer_partitioner = input_layer_partitioner or (
+ partitioned_variables.min_max_variable_partitioner(
+ max_partitions=num_ps_replicas,
+ min_slice_size=64 << 20))
+
+ logit_fn = _rnn_logit_fn_builder(
+ output_units=head.logits_dimension,
+ rnn_cell_fn=rnn_cell_fn,
+ sequence_feature_columns=sequence_feature_columns,
+ context_feature_columns=context_feature_columns,
+ input_layer_partitioner=input_layer_partitioner)
+ logits = logit_fn(features=features, mode=mode)
+
+ def _train_op_fn(loss):
+ """Returns the op to optimize the loss."""
+ return optimizer.minimize(
+ loss,
+ global_step=training_util.get_global_step())
+
+ return head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+
+
+class RNNClassifier(estimator.Estimator):
+ """A classifier for TensorFlow RNN models.
+
+ Trains a recurrent neural network model to classify instances into one of
+ multiple classes.
+
+ Example:
+
+ ```python
+ token_sequence = sequence_categorical_column_with_hash_bucket(...)
+ token_emb = embedding_column(categorical_column=token_sequence, ...)
+
+ estimator = RNNClassifier(
+ num_units=[32, 16], cell_type='lstm',
+ sequence_feature_columns=[token_emb])
+
+ # Input builders
+ def input_fn_train(): # returns x, y
+ pass
+ estimator.train(input_fn=input_fn_train, steps=100)
+
+ def input_fn_eval(): # returns x, y
+ pass
+ metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)
+ def input_fn_predict(): # returns x, None
+ pass
+ predictions = estimator.predict(input_fn=input_fn_predict)
+ ```
+
+ Input of `train` and `evaluate` should have following features,
+ otherwise there will be a `KeyError`:
+
+ * if `weight_column` is not `None`, a feature with
+ `key=weight_column` whose value is a `Tensor`.
+ * for each `column` in `sequence_feature_columns`:
+ - a feature with `key=column.name` whose `value` is a `SparseTensor`.
+ * for each `column` in `context_feature_columns`:
+ - if `column` is a `_CategoricalColumn`, a feature with `key=column.name`
+ whose `value` is a `SparseTensor`.
+ - if `column` is a `_WeightedCategoricalColumn`, two features: the first
+ with `key` the id column name, the second with `key` the weight column
+ name. Both features' `value` must be a `SparseTensor`.
+ - if `column` is a `_DenseColumn`, a feature with `key=column.name`
+ whose `value` is a `Tensor`.
+
+ Loss is calculated by using softmax cross entropy.
+
+ @compatibility(eager)
+ Estimators are not compatible with eager execution.
+ @end_compatibility
+ """
+
+ def __init__(self,
+ sequence_feature_columns,
+ context_feature_columns=None,
+ num_units=None,
+ cell_type=USE_DEFAULT,
+ rnn_cell_fn=None,
+ model_dir=None,
+ n_classes=2,
+ weight_column=None,
+ label_vocabulary=None,
+ optimizer='Adagrad',
+ input_layer_partitioner=None,
+ config=None):
+ """Initializes a `RNNClassifier` instance.
+
+ Args:
+ sequence_feature_columns: An iterable containing the `FeatureColumn`s
+ that represent sequential input. All items in the set should either be
+ sequence columns (e.g. `sequence_numeric_column`) or constructed from
+ one (e.g. `embedding_column` with `sequence_categorical_column_*` as
+ input).
+ context_feature_columns: An iterable containing the `FeatureColumn`s
+ for contextual input. The data represented by these columns will be
+ replicated and given to the RNN at each timestep. These columns must be
+ instances of classes derived from `_DenseColumn` such as
+ `numeric_column`, not the sequential variants.
+ num_units: Iterable of integer number of hidden units per RNN layer. If
+ set, `cell_type` must also be specified and `rnn_cell_fn` must be
+ `None`.
+ cell_type: A subclass of `tf.nn.rnn_cell.RNNCell` or a string specifying
+ the cell type. Supported strings are: `'basic_rnn'`, `'lstm'`, and
+ `'gru'`. If set, `num_units` must also be specified and `rnn_cell_fn`
+ must be `None`.
+ rnn_cell_fn: A function with one argument, a `tf.estimator.ModeKeys`, and
+ returns an object of type `tf.nn.rnn_cell.RNNCell` that will be used to
+ construct the RNN. If set, `num_units` and `cell_type` cannot be set.
+ This is for advanced users who need additional customization beyond
+ `num_units` and `cell_type`. Note that `tf.nn.rnn_cell.MultiRNNCell` is
+ needed for stacked RNNs.
+ model_dir: Directory to save model parameters, graph, etc. This can
+ also be used to load checkpoints from the directory into an estimator to
+ continue training a previously saved model.
+ n_classes: Number of label classes. Defaults to 2, namely binary
+ classification. Must be > 1.
+ weight_column: A string or a `_NumericColumn` created by
+ `tf.feature_column.numeric_column` defining the feature column representing
+ weights. It is used to down-weight or boost examples during training and
+ is multiplied by the loss of the example. If it is a string, it is used
+ as a key to fetch the weight tensor from `features`. If it is a
+ `_NumericColumn`, the raw tensor is fetched by key `weight_column.key`,
+ then `weight_column.normalizer_fn` is applied to it to get the weight tensor.
+ label_vocabulary: A list of strings representing possible label values. If
+ given, labels must be of string type and take values from
+ `label_vocabulary`. If it is not given, labels must already be encoded as
+ an integer or a float within [0, 1] for `n_classes=2`, or as integer
+ values in {0, 1, ..., n_classes-1} for `n_classes` > 2. An error is raised
+ if a vocabulary is not provided and the labels are strings.
+ optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
+ to Adagrad optimizer.
+ input_layer_partitioner: Optional. Partitioner for input layer. Defaults
+ to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
+ config: `RunConfig` object to configure the runtime settings.
+
+ Raises:
+ ValueError: If `num_units`, `cell_type`, and `rnn_cell_fn` are not
+ compatible.
+ """
+ if rnn_cell_fn and (num_units or cell_type != USE_DEFAULT):
+ raise ValueError(
+ 'num_units and cell_type must not be specified when using rnn_cell_fn'
+ )
+ if not rnn_cell_fn:
+ if cell_type == USE_DEFAULT:
+ cell_type = 'basic_rnn'
+ rnn_cell_fn = _make_rnn_cell_fn(num_units, cell_type)
+
+ if n_classes == 2:
+ head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
+ weight_column=weight_column,
+ label_vocabulary=label_vocabulary)
+ else:
+ head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
+ n_classes, weight_column=weight_column,
+ label_vocabulary=label_vocabulary)
+ def _model_fn(features, labels, mode, config):
+ return _rnn_model_fn(
+ features=features,
+ labels=labels,
+ mode=mode,
+ head=head,
+ rnn_cell_fn=rnn_cell_fn,
+ sequence_feature_columns=tuple(sequence_feature_columns or []),
+ context_feature_columns=tuple(context_feature_columns or []),
+ optimizer=optimizer,
+ input_layer_partitioner=input_layer_partitioner,
+ config=config)
+ super(RNNClassifier, self).__init__(
+ model_fn=_model_fn, model_dir=model_dir, config=config)
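
Beyond the docstring example in the new estimator above, the `rnn_cell_fn` hook lets callers supply their own cell stack (with `MultiRNNCell` for stacking, as the docstring notes). A minimal sketch of that path follows; the feature column, cell sizes, and dropout choice are illustrative, and the module path assumes the same layout used by the test file below.

```python
import tensorflow as tf
from tensorflow.contrib.estimator.python.estimator import rnn
from tensorflow.contrib.feature_column.python.feature_column import (
    sequence_feature_column as seq_fc)

def my_rnn_cell_fn(mode):
    # rnn_cell_fn receives a tf.estimator.ModeKeys value, so dropout can be
    # applied only during training.
    keep_prob = 0.5 if mode == tf.estimator.ModeKeys.TRAIN else 1.0
    cells = [
        tf.nn.rnn_cell.DropoutWrapper(
            tf.nn.rnn_cell.GRUCell(num_units=n), output_keep_prob=keep_prob)
        for n in (32, 16)
    ]
    return tf.nn.rnn_cell.MultiRNNCell(cells)

price = seq_fc.sequence_numeric_column('price', shape=(1,))
classifier = rnn.RNNClassifier(
    sequence_feature_columns=[price],
    rnn_cell_fn=my_rnn_cell_fn,  # num_units/cell_type must be left unset
    n_classes=2)
```
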
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
new file mode 100644
index 0000000000..393f94f5c7
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
@@ -0,0 +1,1131 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for rnn.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+import shutil
+import tempfile
+
+import numpy as np
+import six
+
+from tensorflow.contrib.estimator.python.estimator import rnn
+from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as seq_fc
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.estimator import model_fn
+from tensorflow.python.estimator.canned import metric_keys
+from tensorflow.python.estimator.canned import prediction_keys
+from tensorflow.python.estimator.export import export
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.feature_column import feature_column as fc
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import rnn_cell
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variables_lib
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.training import input as input_lib
+from tensorflow.python.training import monitored_session
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import training_util
+
+
+# Names of variables created by BasicRNNCell model.
+TOKEN_EMBEDDING_NAME = 'rnn/sequence_input_layer/input_layer/tokens_sequential_embedding/embedding_weights'
+CELL_WEIGHTS_NAME = 'rnn/rnn/basic_rnn_cell/kernel'
+CELL_BIAS_NAME = 'rnn/rnn/basic_rnn_cell/bias'
+MULTI_CELL_WEIGHTS_NAME_PATTERN = 'rnn/rnn/multi_rnn_cell/cell_%d/basic_rnn_cell/kernel'
+MULTI_CELL_BIAS_NAME_PATTERN = 'rnn/rnn/multi_rnn_cell/cell_%d/basic_rnn_cell/bias'
+LOGITS_WEIGHTS_NAME = 'rnn/logits/dense/kernel'
+LOGITS_BIAS_NAME = 'rnn/logits/dense/bias'
+
+
+def _assert_close(expected, actual, rtol=1e-04, name='assert_close'):
+ with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope:
+ expected = ops.convert_to_tensor(expected, name='expected')
+ actual = ops.convert_to_tensor(actual, name='actual')
+ rdiff = math_ops.abs(expected - actual, 'diff') / math_ops.abs(expected)
+ rtol = ops.convert_to_tensor(rtol, name='rtol')
+ return check_ops.assert_less(
+ rdiff,
+ rtol,
+ data=('Condition expected =~ actual did not hold element-wise:'
+ 'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff,
+ 'rtol = ', rtol,),
+ name=scope)
+
+
+def create_checkpoint(rnn_weights, rnn_biases, logits_weights, logits_biases,
+ global_step, model_dir):
+ """Create checkpoint file with provided model weights.
+
+ Args:
+ rnn_weights: Iterable of values of weights for the RNN cell.
+ rnn_biases: Iterable of values of biases for the RNN cell.
+ logits_weights: Iterable of values for matrix connecting RNN output to
+ logits.
+ logits_biases: Iterable of values for logits bias term.
+ global_step: Initial global step to save in checkpoint.
+ model_dir: Directory into which checkpoint is saved.
+ """
+ model_weights = {}
+ model_weights[CELL_WEIGHTS_NAME] = rnn_weights
+ model_weights[CELL_BIAS_NAME] = rnn_biases
+ model_weights[LOGITS_WEIGHTS_NAME] = logits_weights
+ model_weights[LOGITS_BIAS_NAME] = logits_biases
+
+ with ops.Graph().as_default():
+ # Create model variables.
+ for k, v in six.iteritems(model_weights):
+ variables_lib.Variable(v, name=k, dtype=dtypes.float32)
+
+ # Create non-model variables.
+ global_step_var = training_util.create_global_step()
+ assign_op = global_step_var.assign(global_step)
+
+ # Initialize vars and save checkpoint.
+ with monitored_session.MonitoredTrainingSession(
+ checkpoint_dir=model_dir) as sess:
+ sess.run(assign_op)
+
+
+class RNNLogitFnTest(test.TestCase):
+ """Tests correctness of logits calculated from _rnn_logit_fn_builder."""
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def _test_logits(self, mode, rnn_units, logits_dimension, features_fn,
+ sequence_feature_columns, context_feature_columns,
+ expected_logits):
+ """Tests that the expected logits are calculated."""
+ with ops.Graph().as_default():
+ # Global step needed for MonitoredSession, which is in turn used to
+ # explicitly set variable weights through a checkpoint.
+ training_util.create_global_step()
+ # Use a variable scope here with 'rnn', emulating the rnn model_fn, so
+ # the checkpoint naming is shared.
+ with variable_scope.variable_scope('rnn'):
+ input_layer_partitioner = (
+ partitioned_variables.min_max_variable_partitioner(
+ max_partitions=0, min_slice_size=64 << 20))
+ logit_fn = rnn._rnn_logit_fn_builder(
+ output_units=logits_dimension,
+ rnn_cell_fn=rnn._make_rnn_cell_fn(rnn_units),
+ sequence_feature_columns=sequence_feature_columns,
+ context_feature_columns=context_feature_columns,
+ input_layer_partitioner=input_layer_partitioner)
+ # Features are constructed within this function, otherwise the Tensors
+ # containing the features would be defined outside this graph.
+ logits = logit_fn(features=features_fn(), mode=mode)
+ with monitored_session.MonitoredTrainingSession(
+ checkpoint_dir=self._model_dir) as sess:
+ self.assertAllClose(expected_logits, sess.run(logits), atol=1e-4)
+
+ def testOneDimLogits(self):
+ """Tests one-dimensional logits.
+
+ Intermediate values are rounded for ease in reading.
+ input_layer = [[[10]], [[5]]]
+ initial_state = [0, 0]
+ rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2),
+ tanh(-.2*10 - .3*0 - .4*0 +.5)]]
+ = [[0.83, -0.91]]
+ rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2),
+ tanh(-.2*5 - .3*.83 + .4*.91 +.5)]]
+ = [[0.53, -0.37]]
+ logits = [[-1*0.53 - 1*0.37 + 0.3]] = [[-0.6033]]
+ """
+ base_global_step = 100
+ create_checkpoint(
+ rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]],
+ rnn_biases=[.2, .5],
+ logits_weights=[[-1.], [1.]],
+ logits_biases=[0.3],
+ global_step=base_global_step,
+ model_dir=self._model_dir)
+
+ def features_fn():
+ return {
+ 'price':
+ sparse_tensor.SparseTensor(
+ values=[10., 5.],
+ indices=[[0, 0], [0, 1]],
+ dense_shape=[1, 2]),
+ }
+
+ sequence_feature_columns = [
+ seq_fc.sequence_numeric_column('price', shape=(1,))]
+ context_feature_columns = []
+ for mode in [
+ model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
+ model_fn.ModeKeys.PREDICT
+ ]:
+ self._test_logits(
+ mode,
+ rnn_units=[2],
+ logits_dimension=1,
+ features_fn=features_fn,
+ sequence_feature_columns=sequence_feature_columns,
+ context_feature_columns=context_feature_columns,
+ expected_logits=[[-0.6033]])
+
+ def testMultiDimLogits(self):
+ """Tests multi-dimensional logits.
+
+ Intermediate values are rounded for ease in reading.
+ input_layer = [[[10]], [[5]]]
+ initial_state = [0, 0]
+ rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2),
+ tanh(-.2*10 - .3*0 - .4*0 +.5)]]
+ = [[0.83, -0.91]]
+ rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2),
+ tanh(-.2*5 - .3*.83 + .4*.91 +.5)]]
+ = [[0.53, -0.37]]
+ logits = [[-1*0.53 - 1*0.37 + 0.3,
+ 0.5*0.53 + 0.3*0.37 + 0.4,
+ 0.2*0.53 - 0.1*0.37 + 0.5]]
+ = [[-0.6033, 0.7777, 0.5698]]
+ """
+ base_global_step = 100
+ create_checkpoint(
+ rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]],
+ rnn_biases=[.2, .5],
+ logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]],
+ logits_biases=[0.3, 0.4, 0.5],
+ global_step=base_global_step,
+ model_dir=self._model_dir)
+
+ def features_fn():
+ return {
+ 'price':
+ sparse_tensor.SparseTensor(
+ values=[10., 5.],
+ indices=[[0, 0], [0, 1]],
+ dense_shape=[1, 2]),
+ }
+
+ sequence_feature_columns = [
+ seq_fc.sequence_numeric_column('price', shape=(1,))]
+ context_feature_columns = []
+
+ for mode in [
+ model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
+ model_fn.ModeKeys.PREDICT
+ ]:
+ self._test_logits(
+ mode,
+ rnn_units=[2],
+ logits_dimension=3,
+ features_fn=features_fn,
+ sequence_feature_columns=sequence_feature_columns,
+ context_feature_columns=context_feature_columns,
+ expected_logits=[[-0.6033, 0.7777, 0.5698]])
+
+ def testMultiExampleMultiDim(self):
+ """Tests multiple examples and multi-dimensional logits.
+
+ Intermediate values are rounded for ease in reading.
+ input_layer = [[[10], [5]], [[2], [7]]]
+ initial_state = [[0, 0], [0, 0]]
+ rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2),
+ tanh(-.2*10 - .3*0 - .4*0 +.5)],
+ [tanh(.1*2 + .2*0 + .3*0 +.2),
+ tanh(-.2*2 - .3*0 - .4*0 +.5)]]
+ = [[0.83, -0.91], [0.38, 0.10]]
+ rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2),
+ tanh(-.2*5 - .3*.83 + .4*.91 +.5)],
+ [tanh(.1*7 + .2*.38 + .3*.10 +.2),
+ tanh(-.2*7 - .3*.38 - .4*.10 +.5)]]
+ = [[0.53, -0.37], [0.76, -0.78]]
+ logits = [[-1*0.53 - 1*0.37 + 0.3,
+ 0.5*0.53 + 0.3*0.37 + 0.4,
+ 0.2*0.53 - 0.1*0.37 + 0.5],
+ [-1*0.76 - 1*0.78 + 0.3,
+ 0.5*0.76 +0.3*0.78 + 0.4,
+ 0.2*0.76 -0.1*0.78 + 0.5]]
+ = [[-0.6033, 0.7777, 0.5698], [-1.2473, 1.0170, 0.5745]]
+ """
+ base_global_step = 100
+ create_checkpoint(
+ rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]],
+ rnn_biases=[.2, .5],
+ logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]],
+ logits_biases=[0.3, 0.4, 0.5],
+ global_step=base_global_step,
+ model_dir=self._model_dir)
+
+ def features_fn():
+ return {
+ 'price':
+ sparse_tensor.SparseTensor(
+ values=[10., 5., 2., 7.],
+ indices=[[0, 0], [0, 1], [1, 0], [1, 1]],
+ dense_shape=[2, 2]),
+ }
+
+ sequence_feature_columns = [
+ seq_fc.sequence_numeric_column('price', shape=(1,))
+ ]
+ context_feature_columns = []
+
+ for mode in [
+ model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
+ model_fn.ModeKeys.PREDICT
+ ]:
+ self._test_logits(
+ mode,
+ rnn_units=[2],
+ logits_dimension=3,
+ features_fn=features_fn,
+ sequence_feature_columns=sequence_feature_columns,
+ context_feature_columns=context_feature_columns,
+ expected_logits=[[-0.6033, 0.7777, 0.5698],
+ [-1.2473, 1.0170, 0.5745]])
+
+ def testMultiExamplesDifferentLength(self):
+ """Tests multiple examples with different lengths.
+
+ Intermediate values are rounded for ease in reading.
+ input_layer = [[[10], [5]], [[2], [0]]]
+ initial_state = [[0, 0], [0, 0]]
+ rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2),
+ tanh(-.2*10 - .3*0 - .4*0 +.5)],
+ [tanh(.1*2 + .2*0 + .3*0 +.2),
+ tanh(-.2*2 - .3*0 - .4*0 +.5)]]
+ = [[0.83, -0.91], [0.38, 0.10]]
+ rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2),
+ tanh(-.2*5 - .3*.83 + .4*.91 +.5)],
+ [<ignored-padding>]]
+ = [[0.53, -0.37], [<ignored-padding>]]
+ logits = [[-1*0.53 - 1*0.37 + 0.3],
+ [-1*0.38 + 1*0.10 + 0.3]]
+ = [[-0.6033], [0.0197]]
+ """
+ base_global_step = 100
+ create_checkpoint(
+ rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]],
+ rnn_biases=[.2, .5],
+ logits_weights=[[-1.], [1.]],
+ logits_biases=[0.3],
+ global_step=base_global_step,
+ model_dir=self._model_dir)
+
+ def features_fn():
+ return {
+ 'price':
+ sparse_tensor.SparseTensor(
+ values=[10., 5., 2.],
+ indices=[[0, 0], [0, 1], [1, 0]],
+ dense_shape=[2, 2]),
+ }
+
+ sequence_feature_columns = [
+ seq_fc.sequence_numeric_column('price', shape=(1,))]
+ context_feature_columns = []
+
+ for mode in [
+ model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
+ model_fn.ModeKeys.PREDICT
+ ]:
+ self._test_logits(
+ mode,
+ rnn_units=[2],
+ logits_dimension=1,
+ features_fn=features_fn,
+ sequence_feature_columns=sequence_feature_columns,
+ context_feature_columns=context_feature_columns,
+ expected_logits=[[-0.6033], [0.0197]])
+
+ def testMultiExamplesWithContext(self):
+ """Tests multiple examples with context features.
+
+ Intermediate values are rounded for ease in reading.
+ input_layer = [[[10, -0.5], [5, -0.5]], [[2, 0.8], [0, 0]]]
+ initial_state = [[0, 0], [0, 0]]
+ rnn_output_timestep_1 = [[tanh(.1*10 - 1*.5 + .2*0 + .3*0 +.2),
+ tanh(-.2*10 - 0.9*.5 - .3*0 - .4*0 +.5)],
+ [tanh(.1*2 + 1*.8 + .2*0 + .3*0 +.2),
+ tanh(-.2*2 + .9*.8 - .3*0 - .4*0 +.5)]]
+ = [[0.60, -0.96], [0.83, 0.68]]
+ rnn_output_timestep_2 = [[tanh(.1*5 - 1*.5 + .2*.60 - .3*.96 +.2),
+ tanh(-.2*5 - .9*.5 - .3*.60 + .4*.96 +.5)],
+ [<ignored-padding>]]
+ = [[0.03, -0.63], [<ignored-padding>]]
+ logits = [[-1*0.03 - 1*0.63 + 0.3],
+ [-1*0.83 + 1*0.68 + 0.3]]
+ = [[-0.3662], [0.1414]]
+ """
+ base_global_step = 100
+ create_checkpoint(
+ # Context feature weights are inserted between input and state weights.
+ rnn_weights=[[.1, -.2], [1., 0.9], [.2, -.3], [.3, -.4]],
+ rnn_biases=[.2, .5],
+ logits_weights=[[-1.], [1.]],
+ logits_biases=[0.3],
+ global_step=base_global_step,
+ model_dir=self._model_dir)
+
+ def features_fn():
+ return {
+ 'price':
+ sparse_tensor.SparseTensor(
+ values=[10., 5., 2.],
+ indices=[[0, 0], [0, 1], [1, 0]],
+ dense_shape=[2, 2]),
+ 'context': [[-0.5], [0.8]],
+ }
+
+ sequence_feature_columns = [
+ seq_fc.sequence_numeric_column('price', shape=(1,))]
+ context_feature_columns = [fc.numeric_column('context', shape=(1,))]
+
+ for mode in [
+ model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
+ model_fn.ModeKeys.PREDICT
+ ]:
+ self._test_logits(
+ mode,
+ rnn_units=[2],
+ logits_dimension=1,
+ features_fn=features_fn,
+ sequence_feature_columns=sequence_feature_columns,
+ context_feature_columns=context_feature_columns,
+ expected_logits=[[-0.3662], [0.1414]])
+
+ def testMultiExamplesMultiFeatures(self):
+ """Tests examples with multiple sequential feature columns.
+
+ Intermediate values are rounded for ease in reading.
+ input_layer = [[[1, 0, 10], [0, 1, 5]], [[1, 0, 2], [0, 0, 0]]]
+ initial_state = [[0, 0], [0, 0]]
+ rnn_output_timestep_1 = [[tanh(.5*1 + 1*0 + .1*10 + .2*0 + .3*0 +.2),
+ tanh(-.5*1 - 1*0 - .2*10 - .3*0 - .4*0 +.5)],
+ [tanh(.5*1 + 1*0 + .1*2 + .2*0 + .3*0 +.2),
+ tanh(-.5*1 - 1*0 - .2*2 - .3*0 - .4*0 +.5)]]
+ = [[0.94, -0.96], [0.72, -0.38]]
+ rnn_output_timestep_2 = [[tanh(.5*0 + 1*1 + .1*5 + .2*.94 - .3*.96 +.2),
+ tanh(-.5*0 - 1*1 - .2*5 - .3*.94 + .4*.96 +.5)],
+ [<ignored-padding>]]
+ = [[0.92, -0.88], [<ignored-padding>]]
+ logits = [[-1*0.92 - 1*0.88 + 0.3],
+ [-1*0.72 - 1*0.38 + 0.3]]
+ = [[-1.5056], [-0.7962]]
+ """
+ base_global_step = 100
+ create_checkpoint(
+ # FeatureColumns are sorted alphabetically, so on_sale weights are
+ # inserted before price.
+ rnn_weights=[[.5, -.5], [1., -1.], [.1, -.2], [.2, -.3], [.3, -.4]],
+ rnn_biases=[.2, .5],
+ logits_weights=[[-1.], [1.]],
+ logits_biases=[0.3],
+ global_step=base_global_step,
+ model_dir=self._model_dir)
+
+ def features_fn():
+ return {
+ 'price':
+ sparse_tensor.SparseTensor(
+ values=[10., 5., 2.],
+ indices=[[0, 0], [0, 1], [1, 0]],
+ dense_shape=[2, 2]),
+ 'on_sale':
+ sparse_tensor.SparseTensor(
+ values=[0, 1, 0],
+ indices=[[0, 0], [0, 1], [1, 0]],
+ dense_shape=[2, 2]),
+ }
+
+ price_column = seq_fc.sequence_numeric_column('price', shape=(1,))
+ on_sale_column = fc.indicator_column(
+ seq_fc.sequence_categorical_column_with_identity(
+ 'on_sale', num_buckets=2))
+ sequence_feature_columns = [price_column, on_sale_column]
+ context_feature_columns = []
+
+ for mode in [
+ model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
+ model_fn.ModeKeys.PREDICT
+ ]:
+ self._test_logits(
+ mode,
+ rnn_units=[2],
+ logits_dimension=1,
+ features_fn=features_fn,
+ sequence_feature_columns=sequence_feature_columns,
+ context_feature_columns=context_feature_columns,
+ expected_logits=[[-1.5056], [-0.7962]])
+
+
+class RNNClassifierTrainingTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def _assert_checkpoint(
+ self, n_classes, input_units, cell_units, expected_global_step):
+
+ shapes = {
+ name: shape for (name, shape) in
+ checkpoint_utils.list_variables(self._model_dir)
+ }
+
+ self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP])
+ self.assertEqual(
+ expected_global_step,
+ checkpoint_utils.load_variable(
+ self._model_dir, ops.GraphKeys.GLOBAL_STEP))
+
+ # RNN Cell variables.
+ if len(cell_units) > 1:
+ for i, cell_unit in enumerate(cell_units):
+ self.assertEqual([input_units + cell_unit, cell_unit],
+ shapes[MULTI_CELL_WEIGHTS_NAME_PATTERN % i])
+ self.assertEqual([cell_unit],
+ shapes[MULTI_CELL_BIAS_NAME_PATTERN % i])
+ input_units = cell_unit
+    elif len(cell_units) == 1:
+      self.assertEqual([input_units + cell_units[0], cell_units[0]],
+                       shapes[CELL_WEIGHTS_NAME])
+      self.assertEqual([cell_units[0]], shapes[CELL_BIAS_NAME])
+
+ # Logits variables.
+ logits_dimension = n_classes if n_classes > 2 else 1
+ self.assertEqual([cell_units[-1], logits_dimension],
+ shapes[LOGITS_WEIGHTS_NAME])
+ self.assertEqual([logits_dimension], shapes[LOGITS_BIAS_NAME])
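For reference, the kernel shapes asserted above follow the convention of one kernel applied to the concatenated [inputs, state]; a throwaway sketch of the expected values for the configuration these tests use (input_units=2, cell_units=[4, 2]), not part of the patch:

input_units = 2
for cell_unit in [4, 2]:
  print([input_units + cell_unit, cell_unit], [cell_unit])  # [6, 4] [4], then [6, 2] [2]
  input_units = cell_unit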
+
+ def _mock_optimizer(self, expected_loss=None):
+ expected_var_names = [
+ '%s/part_0:0' % CELL_BIAS_NAME,
+ '%s/part_0:0' % CELL_WEIGHTS_NAME,
+ '%s/part_0:0' % LOGITS_BIAS_NAME,
+ '%s/part_0:0' % LOGITS_WEIGHTS_NAME,
+ ]
+
+ def _minimize(loss, global_step):
+ trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertItemsEqual(
+ expected_var_names,
+ [var.name for var in trainable_vars])
+
+ # Verify loss. We can't check the value directly, so we add an assert op.
+ self.assertEquals(0, loss.shape.ndims)
+ if expected_loss is None:
+ return state_ops.assign_add(global_step, 1).op
+ assert_loss = _assert_close(
+ math_ops.to_float(expected_loss, name='expected'),
+ loss,
+ name='assert_loss')
+ with ops.control_dependencies((assert_loss,)):
+ return state_ops.assign_add(global_step, 1).op
+
+ mock_optimizer = test.mock.NonCallableMock(
+ spec=optimizer.Optimizer,
+ wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer'))
+ mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize)
+
+ # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks.
+ # So, return mock_optimizer itself for deepcopy.
+ mock_optimizer.__deepcopy__ = lambda _: mock_optimizer
+ return mock_optimizer
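The __deepcopy__ override above is the usual workaround for objects that must come back unchanged from a deepcopy (here, Estimator's copy of params). A minimal standalone illustration using Python 3's unittest.mock, not part of the patch:

import copy
from unittest import mock

opt = mock.NonCallableMock()
# copy.deepcopy() calls the instance's __deepcopy__ if one is present, so
# returning the object itself lets the mock survive the copy untouched.
opt.__deepcopy__ = lambda memo: opt
assert copy.deepcopy({'optimizer': opt})['optimizer'] is opt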
+
+ def testConflictingRNNCellFn(self):
+ col = seq_fc.sequence_categorical_column_with_hash_bucket(
+ 'tokens', hash_bucket_size=10)
+ embed = fc.embedding_column(col, dimension=2)
+ cell_units = [4, 2]
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'num_units and cell_type must not be specified when using rnn_cell_fn'):
+ rnn.RNNClassifier(
+ sequence_feature_columns=[embed],
+ rnn_cell_fn=lambda x: x,
+ num_units=cell_units)
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'num_units and cell_type must not be specified when using rnn_cell_fn'):
+ rnn.RNNClassifier(
+ sequence_feature_columns=[embed],
+ rnn_cell_fn=lambda x: x,
+ cell_type='lstm')
+
+ def _testFromScratchWithDefaultOptimizer(self, n_classes):
+ def train_input_fn():
+ return {
+ 'tokens':
+ sparse_tensor.SparseTensor(
+ values=['the', 'cat', 'sat'],
+ indices=[[0, 0], [0, 1], [0, 2]],
+ dense_shape=[1, 3]),
+ }, [[1]]
+
+ col = seq_fc.sequence_categorical_column_with_hash_bucket(
+ 'tokens', hash_bucket_size=10)
+ embed = fc.embedding_column(col, dimension=2)
+ input_units = 2
+
+ cell_units = [4, 2]
+ est = rnn.RNNClassifier(
+ sequence_feature_columns=[embed],
+ num_units=cell_units,
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+
+ # Train for a few steps, and validate final checkpoint.
+ num_steps = 10
+ est.train(input_fn=train_input_fn, steps=num_steps)
+ self._assert_checkpoint(n_classes, input_units, cell_units, num_steps)
+
+ def testBinaryClassFromScratchWithDefaultOptimizer(self):
+ self._testFromScratchWithDefaultOptimizer(n_classes=2)
+
+ def testMultiClassFromScratchWithDefaultOptimizer(self):
+ self._testFromScratchWithDefaultOptimizer(n_classes=4)
+
+ def testFromScratchWithCustomRNNCellFn(self):
+ def train_input_fn():
+ return {
+ 'tokens':
+ sparse_tensor.SparseTensor(
+ values=['the', 'cat', 'sat'],
+ indices=[[0, 0], [0, 1], [0, 2]],
+ dense_shape=[1, 3]),
+ }, [[1]]
+
+ col = seq_fc.sequence_categorical_column_with_hash_bucket(
+ 'tokens', hash_bucket_size=10)
+ embed = fc.embedding_column(col, dimension=2)
+ input_units = 2
+ cell_units = [4, 2]
+ n_classes = 2
+
+ def rnn_cell_fn(mode):
+ del mode # unused
+ cells = [rnn_cell.BasicRNNCell(num_units=n) for n in cell_units]
+ return rnn_cell.MultiRNNCell(cells)
+
+ est = rnn.RNNClassifier(
+ sequence_feature_columns=[embed],
+ rnn_cell_fn=rnn_cell_fn,
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+
+ # Train for a few steps, and validate final checkpoint.
+ num_steps = 10
+ est.train(input_fn=train_input_fn, steps=num_steps)
+ self._assert_checkpoint(n_classes, input_units, cell_units, num_steps)
+
+ def _testExampleWeight(self, n_classes):
+ def train_input_fn():
+ return {
+ 'tokens':
+ sparse_tensor.SparseTensor(
+ values=['the', 'cat', 'sat', 'dog', 'barked'],
+ indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],
+ dense_shape=[2, 3]),
+ 'w': [[1], [2]],
+ }, [[1], [0]]
+
+ col = seq_fc.sequence_categorical_column_with_hash_bucket(
+ 'tokens', hash_bucket_size=10)
+ embed = fc.embedding_column(col, dimension=2)
+ input_units = 2
+
+ cell_units = [4, 2]
+ est = rnn.RNNClassifier(
+ num_units=cell_units,
+ sequence_feature_columns=[embed],
+ n_classes=n_classes,
+ weight_column='w',
+ model_dir=self._model_dir)
+
+ # Train for a few steps, and validate final checkpoint.
+ num_steps = 10
+ est.train(input_fn=train_input_fn, steps=num_steps)
+ self._assert_checkpoint(n_classes, input_units, cell_units, num_steps)
+
+ def testBinaryClassWithExampleWeight(self):
+ self._testExampleWeight(n_classes=2)
+
+ def testMultiClassWithExampleWeight(self):
+ self._testExampleWeight(n_classes=4)
+
+ def testBinaryClassFromCheckpoint(self):
+ initial_global_step = 100
+ create_checkpoint(
+ rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]],
+ rnn_biases=[.2, .5],
+ logits_weights=[[-1.], [1.]],
+ logits_biases=[0.3],
+ global_step=initial_global_step,
+ model_dir=self._model_dir)
+
+ def train_input_fn():
+ return {
+ 'price':
+ sparse_tensor.SparseTensor(
+ values=[10., 5., 2.],
+ indices=[[0, 0], [0, 1], [1, 0]],
+ dense_shape=[2, 2]),
+ }, [[0], [1]]
+
+ # Uses same checkpoint and examples as testBinaryClassEvaluationMetrics.
+ # See that test for loss calculation.
+ mock_optimizer = self._mock_optimizer(expected_loss=1.119661)
+
+ sequence_feature_columns = [
+ seq_fc.sequence_numeric_column('price', shape=(1,))]
+ est = rnn.RNNClassifier(
+ num_units=[2],
+ sequence_feature_columns=sequence_feature_columns,
+ n_classes=2,
+ optimizer=mock_optimizer,
+ model_dir=self._model_dir)
+ self.assertEqual(0, mock_optimizer.minimize.call_count)
+ est.train(input_fn=train_input_fn, steps=10)
+ self.assertEqual(1, mock_optimizer.minimize.call_count)
+
+ def testMultiClassFromCheckpoint(self):
+ initial_global_step = 100
+ create_checkpoint(
+ rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]],
+ rnn_biases=[.2, .5],
+ logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]],
+ logits_biases=[0.3, 0.4, 0.5],
+ global_step=initial_global_step,
+ model_dir=self._model_dir)
+
+ def train_input_fn():
+ return {
+ 'price':
+ sparse_tensor.SparseTensor(
+ values=[10., 5., 2., 7.],
+ indices=[[0, 0], [0, 1], [1, 0], [1, 1]],
+ dense_shape=[2, 2]),
+ }, [[0], [1]]
+
+ # Uses same checkpoint and examples as testMultiClassEvaluationMetrics.
+ # See that test for loss calculation.
+ mock_optimizer = self._mock_optimizer(expected_loss=2.662932)
+
+ sequence_feature_columns = [
+ seq_fc.sequence_numeric_column('price', shape=(1,))]
+ est = rnn.RNNClassifier(
+ num_units=[2],
+ sequence_feature_columns=sequence_feature_columns,
+ n_classes=3,
+ optimizer=mock_optimizer,
+ model_dir=self._model_dir)
+ self.assertEqual(0, mock_optimizer.minimize.call_count)
+ est.train(input_fn=train_input_fn, steps=10)
+ self.assertEqual(1, mock_optimizer.minimize.call_count)
+
+
+def sorted_key_dict(unsorted_dict):
+ return {k: unsorted_dict[k] for k in sorted(unsorted_dict)}
+
+
+class RNNClassifierEvaluationTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def testBinaryClassEvaluationMetrics(self):
+ global_step = 100
+ create_checkpoint(
+ rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]],
+ rnn_biases=[.2, .5],
+ logits_weights=[[-1.], [1.]],
+ logits_biases=[0.3],
+ global_step=global_step,
+ model_dir=self._model_dir)
+
+ def eval_input_fn():
+ return {
+ 'price':
+ sparse_tensor.SparseTensor(
+ values=[10., 5., 2.],
+ indices=[[0, 0], [0, 1], [1, 0]],
+ dense_shape=[2, 2]),
+ }, [[0], [1]]
+
+ sequence_feature_columns = [
+ seq_fc.sequence_numeric_column('price', shape=(1,))]
+
+ est = rnn.RNNClassifier(
+ num_units=[2],
+ sequence_feature_columns=sequence_feature_columns,
+ n_classes=2,
+ model_dir=self._model_dir)
+ eval_metrics = est.evaluate(eval_input_fn, steps=1)
+
+    # Uses identical numbers to testMultiExamplesDifferentLength.
+ # See that test for logits calculation.
+ # logits = [[-0.603282], [0.019719]]
+ # probability = exp(logits) / (1 + exp(logits)) = [[0.353593], [0.504930]]
+ # loss = -label * ln(p) - (1 - label) * ln(1 - p)
+ # = [[0.436326], [0.683335]]
+ expected_metrics = {
+ ops.GraphKeys.GLOBAL_STEP: global_step,
+ metric_keys.MetricKeys.LOSS: 1.119661,
+ metric_keys.MetricKeys.LOSS_MEAN: 0.559831,
+ metric_keys.MetricKeys.ACCURACY: 1.0,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 0.429262,
+ metric_keys.MetricKeys.LABEL_MEAN: 0.5,
+ metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,
+ # With default threshold of 0.5, the model is a perfect classifier.
+ metric_keys.MetricKeys.RECALL: 1.0,
+ metric_keys.MetricKeys.PRECISION: 1.0,
+ # Positive example is scored above negative, so AUC = 1.0.
+ metric_keys.MetricKeys.AUC: 1.0,
+ metric_keys.MetricKeys.AUC_PR: 1.0,
+ }
+ self.assertAllClose(
+ sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics))
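The LOSS and LOSS_MEAN expectations above follow directly from the logits quoted in the comment; a quick NumPy check, illustrative only:

import numpy as np

logits = np.array([-0.603282, 0.019719])
labels = np.array([0., 1.])
p = 1. / (1. + np.exp(-logits))                        # [0.353593, 0.504930]
loss = -labels * np.log(p) - (1. - labels) * np.log(1. - p)
print(loss.sum(), loss.mean())                         # ~1.119661, ~0.559831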
+
+ def testMultiClassEvaluationMetrics(self):
+ global_step = 100
+ create_checkpoint(
+ rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]],
+ rnn_biases=[.2, .5],
+ logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]],
+ logits_biases=[0.3, 0.4, 0.5],
+ global_step=global_step,
+ model_dir=self._model_dir)
+
+ def eval_input_fn():
+ return {
+ 'price':
+ sparse_tensor.SparseTensor(
+ values=[10., 5., 2., 7.],
+ indices=[[0, 0], [0, 1], [1, 0], [1, 1]],
+ dense_shape=[2, 2]),
+ }, [[0], [1]]
+
+ sequence_feature_columns = [
+ seq_fc.sequence_numeric_column('price', shape=(1,))]
+
+ est = rnn.RNNClassifier(
+ num_units=[2],
+ sequence_feature_columns=sequence_feature_columns,
+ n_classes=3,
+ model_dir=self._model_dir)
+ eval_metrics = est.evaluate(eval_input_fn, steps=1)
+
+ # Uses identical numbers to testMultiExampleMultiDim.
+ # See that test for logits calculation.
+ # logits = [[-0.603282, 0.777708, 0.569756],
+ # [-1.247356, 1.017018, 0.574481]]
+    # logits_exp = exp(logits)
+ # = [[0.547013, 2.176468, 1.767836],
+ # [0.287263, 2.764937, 1.776208]]
+ # softmax_probabilities = logits_exp / logits_exp.sum()
+ # = [[0.121793, 0.484596, 0.393611],
+ # [0.059494, 0.572639, 0.367866]]
+ # loss = -1. * log(softmax[label])
+ # = [[2.105432], [0.557500]]
+ expected_metrics = {
+ ops.GraphKeys.GLOBAL_STEP: global_step,
+ metric_keys.MetricKeys.LOSS: 2.662932,
+ metric_keys.MetricKeys.LOSS_MEAN: 1.331466,
+ metric_keys.MetricKeys.ACCURACY: 0.5,
+ }
+
+ self.assertAllClose(
+ sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics))
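Likewise, the multi-class LOSS, LOSS_MEAN and ACCURACY values can be reproduced from the quoted logits; an illustrative NumPy check only:

import numpy as np

logits = np.array([[-0.603282, 0.777708, 0.569756],
                   [-1.247356, 1.017018, 0.574481]])
labels = np.array([0, 1])
softmax = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
loss = -np.log(softmax[np.arange(2), labels])          # ~[2.105432, 0.557500]
print(loss.sum(), loss.mean())                         # ~2.662932, ~1.331466
print((softmax.argmax(axis=1) == labels).mean())       # 0.5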
+
+
+class RNNClassifierPredictionTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def testBinaryClassPredictions(self):
+ create_checkpoint(
+ rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]],
+ rnn_biases=[.2, .5],
+ logits_weights=[[-1.], [1.]],
+ logits_biases=[0.3],
+ global_step=0,
+ model_dir=self._model_dir)
+
+ def predict_input_fn():
+ return {
+ 'price':
+ sparse_tensor.SparseTensor(
+ values=[10., 5.],
+ indices=[[0, 0], [0, 1]],
+ dense_shape=[1, 2]),
+ }
+
+ sequence_feature_columns = [
+ seq_fc.sequence_numeric_column('price', shape=(1,))]
+ label_vocabulary = ['class_0', 'class_1']
+
+ est = rnn.RNNClassifier(
+ num_units=[2],
+ sequence_feature_columns=sequence_feature_columns,
+ n_classes=2,
+ label_vocabulary=label_vocabulary,
+ model_dir=self._model_dir)
+ # Uses identical numbers to testOneDimLogits.
+ # See that test for logits calculation.
+ # logits = [-0.603282]
+ # logistic = exp(-0.6033) / (1 + exp(-0.6033)) = [0.353593]
+ # probabilities = [0.646407, 0.353593]
+ # class_ids = argmax(probabilities) = [0]
+ predictions = next(est.predict(predict_input_fn))
+ self.assertAllClose([-0.603282],
+ predictions[prediction_keys.PredictionKeys.LOGITS])
+ self.assertAllClose([0.353593],
+ predictions[prediction_keys.PredictionKeys.LOGISTIC])
+ self.assertAllClose(
+ [0.646407, 0.353593],
+ predictions[prediction_keys.PredictionKeys.PROBABILITIES])
+ self.assertAllClose([0],
+ predictions[prediction_keys.PredictionKeys.CLASS_IDS])
+ self.assertEqual([b'class_0'],
+ predictions[prediction_keys.PredictionKeys.CLASSES])
+
+ def testMultiClassPredictions(self):
+ create_checkpoint(
+ rnn_weights=[[.1, -.2], [.2, -.3], [.3, -.4]],
+ rnn_biases=[.2, .5],
+ logits_weights=[[-1., 0.5, 0.2], [1., -0.3, 0.1]],
+ logits_biases=[0.3, 0.4, 0.5],
+ global_step=0,
+ model_dir=self._model_dir)
+
+ def predict_input_fn():
+ return {
+ 'price':
+ sparse_tensor.SparseTensor(
+ values=[10., 5.],
+ indices=[[0, 0], [0, 1]],
+ dense_shape=[1, 2]),
+ }
+
+ sequence_feature_columns = [
+ seq_fc.sequence_numeric_column('price', shape=(1,))]
+ label_vocabulary = ['class_0', 'class_1', 'class_2']
+
+ est = rnn.RNNClassifier(
+ num_units=[2],
+ sequence_feature_columns=sequence_feature_columns,
+ n_classes=3,
+ label_vocabulary=label_vocabulary,
+ model_dir=self._model_dir)
+ # Uses identical numbers to testMultiDimLogits.
+ # See that test for logits calculation.
+ # logits = [-0.603282, 0.777708, 0.569756]
+ # logits_exp = exp(logits) = [0.547013, 2.176468, 1.767836]
+ # softmax_probabilities = logits_exp / logits_exp.sum()
+ # = [0.121793, 0.484596, 0.393611]
+ # class_ids = argmax(probabilities) = [1]
+ predictions = next(est.predict(predict_input_fn))
+ self.assertAllClose([-0.603282, 0.777708, 0.569756],
+ predictions[prediction_keys.PredictionKeys.LOGITS])
+ self.assertAllClose(
+ [0.121793, 0.484596, 0.393611],
+ predictions[prediction_keys.PredictionKeys.PROBABILITIES])
+ self.assertAllClose([1],
+ predictions[prediction_keys.PredictionKeys.CLASS_IDS])
+ self.assertEqual([b'class_1'],
+ predictions[prediction_keys.PredictionKeys.CLASSES])
+
+
+class RNNClassifierIntegrationTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def _test_complete_flow(
+ self, train_input_fn, eval_input_fn, predict_input_fn, n_classes,
+ batch_size):
+ col = seq_fc.sequence_categorical_column_with_hash_bucket(
+ 'tokens', hash_bucket_size=10)
+ embed = fc.embedding_column(col, dimension=2)
+ feature_columns = [embed]
+
+ cell_units = [4, 2]
+ est = rnn.RNNClassifier(
+ num_units=cell_units,
+ sequence_feature_columns=feature_columns,
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+
+ # TRAIN
+ num_steps = 10
+ est.train(train_input_fn, steps=num_steps)
+
+ # EVALUATE
+ scores = est.evaluate(eval_input_fn)
+ self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
+ self.assertIn('loss', six.iterkeys(scores))
+
+ # PREDICT
+ predicted_proba = np.array([
+ x[prediction_keys.PredictionKeys.PROBABILITIES]
+ for x in est.predict(predict_input_fn)
+ ])
+ self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
+
+ # EXPORT
+ feature_spec = {
+ 'tokens': parsing_ops.VarLenFeature(dtypes.string),
+ 'label': parsing_ops.FixedLenFeature([1], dtypes.int64),
+ }
+ serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+ feature_spec)
+ export_dir = est.export_savedmodel(tempfile.mkdtemp(),
+ serving_input_receiver_fn)
+ self.assertTrue(gfile.Exists(export_dir))
+
+ def testNumpyInputFn(self):
+ """Tests complete flow with numpy_input_fn."""
+ n_classes = 3
+ batch_size = 10
+ words = ['dog', 'cat', 'bird', 'the', 'a', 'sat', 'flew', 'slept']
+    # Numpy only supports dense input, so all examples will have the same length.
+ # TODO(b/73160931): Update test when support for prepadded data exists.
+ sequence_length = 3
+
+ features = []
+ for _ in range(batch_size):
+ sentence = random.sample(words, sequence_length)
+ features.append(sentence)
+
+ x_data = np.array(features)
+ y_data = np.random.randint(n_classes, size=batch_size)
+
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'tokens': x_data},
+ y=y_data,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ eval_input_fn = numpy_io.numpy_input_fn(
+ x={'tokens': x_data},
+ y=y_data,
+ batch_size=batch_size,
+ shuffle=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'tokens': x_data},
+ batch_size=batch_size,
+ shuffle=False)
+
+ self._test_complete_flow(
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ n_classes=n_classes,
+ batch_size=batch_size)
+
+ def testParseExampleInputFn(self):
+ """Tests complete flow with input_fn constructed from parse_example."""
+ n_classes = 3
+ batch_size = 10
+ words = [b'dog', b'cat', b'bird', b'the', b'a', b'sat', b'flew', b'slept']
+
+ serialized_examples = []
+ for _ in range(batch_size):
+ sequence_length = random.randint(1, len(words))
+ sentence = random.sample(words, sequence_length)
+ label = random.randint(0, n_classes - 1)
+ example = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'tokens':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=sentence)),
+ 'label':
+ feature_pb2.Feature(int64_list=feature_pb2.Int64List(
+ value=[label])),
+ }))
+ serialized_examples.append(example.SerializeToString())
+
+ feature_spec = {
+ 'tokens': parsing_ops.VarLenFeature(dtypes.string),
+ 'label': parsing_ops.FixedLenFeature([1], dtypes.int64),
+ }
+ def _train_input_fn():
+ features = parsing_ops.parse_example(serialized_examples, feature_spec)
+ labels = features.pop('label')
+ return features, labels
+ def _eval_input_fn():
+ features = parsing_ops.parse_example(
+ input_lib.limit_epochs(serialized_examples, num_epochs=1),
+ feature_spec)
+ labels = features.pop('label')
+ return features, labels
+ def _predict_input_fn():
+ features = parsing_ops.parse_example(
+ input_lib.limit_epochs(serialized_examples, num_epochs=1),
+ feature_spec)
+ features.pop('label')
+ return features, None
+
+ self._test_complete_flow(
+ train_input_fn=_train_input_fn,
+ eval_input_fn=_eval_input_fn,
+ predict_input_fn=_predict_input_fn,
+ n_classes=n_classes,
+ batch_size=batch_size)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/graph_editor/select.py b/tensorflow/contrib/graph_editor/select.py
index 3ea6ff4d61..d700e6e1a7 100644
--- a/tensorflow/contrib/graph_editor/select.py
+++ b/tensorflow/contrib/graph_editor/select.py
@@ -383,6 +383,7 @@ def get_within_boundary_ops(ops,
def get_forward_walk_ops(seed_ops,
inclusive=True,
within_ops=None,
+ within_ops_fn=None,
stop_at_ts=(),
control_outputs=None):
"""Do a forward graph walk and return all the visited ops.
@@ -395,6 +396,9 @@ def get_forward_walk_ops(seed_ops,
within_ops: an iterable of `tf.Operation` within which the search is
restricted. If `within_ops` is `None`, the search is performed within
the whole graph.
+    within_ops_fn: if provided, a function on ops that should return True iff
+      the op is within the graph traversal. This can be used along with
+      within_ops, in which case an op is within if it is also in within_ops.
stop_at_ts: an iterable of tensors at which the graph walk stops.
control_outputs: a `util.ControlOutputs` instance or None.
If not `None`, it will be used while walking the graph forward.
@@ -423,7 +427,8 @@ def get_forward_walk_ops(seed_ops,
seed_ops &= within_ops
def is_within(op):
- return within_ops is None or op in within_ops
+ return (within_ops is None or op in within_ops) and (
+ within_ops_fn is None or within_ops_fn(op))
result = list(seed_ops)
wave = set(seed_ops)
@@ -450,6 +455,7 @@ def get_forward_walk_ops(seed_ops,
def get_backward_walk_ops(seed_ops,
inclusive=True,
within_ops=None,
+ within_ops_fn=None,
stop_at_ts=(),
control_inputs=False):
"""Do a backward graph walk and return all the visited ops.
@@ -462,6 +468,9 @@ def get_backward_walk_ops(seed_ops,
within_ops: an iterable of `tf.Operation` within which the search is
restricted. If `within_ops` is `None`, the search is performed within
the whole graph.
+    within_ops_fn: if provided, a function on ops that should return True iff
+      the op is within the graph traversal. This can be used along with
+      within_ops, in which case an op is within if it is also in within_ops.
stop_at_ts: an iterable of tensors at which the graph walk stops.
control_inputs: if True, control inputs will be used while moving backward.
Returns:
@@ -488,7 +497,8 @@ def get_backward_walk_ops(seed_ops,
seed_ops &= within_ops
def is_within(op):
- return within_ops is None or op in within_ops
+ return (within_ops is None or op in within_ops) and (
+ within_ops_fn is None or within_ops_fn(op))
result = list(seed_ops)
wave = set(seed_ops)
@@ -516,6 +526,7 @@ def get_walks_intersection_ops(forward_seed_ops,
forward_inclusive=True,
backward_inclusive=True,
within_ops=None,
+ within_ops_fn=None,
control_inputs=False,
control_outputs=None,
control_ios=None):
@@ -535,6 +546,9 @@ def get_walks_intersection_ops(forward_seed_ops,
within_ops: an iterable of tf.Operation within which the search is
restricted. If within_ops is None, the search is performed within
the whole graph.
+    within_ops_fn: if provided, a function on ops that should return True iff
+      the op is within the graph traversal. This can be used along with
+      within_ops, in which case an op is within if it is also in within_ops.
control_inputs: A boolean indicating whether control inputs are enabled.
control_outputs: An instance of util.ControlOutputs or None. If not None,
control outputs are enabled.
@@ -555,11 +569,13 @@ def get_walks_intersection_ops(forward_seed_ops,
forward_seed_ops,
inclusive=forward_inclusive,
within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
control_outputs=control_outputs)
backward_ops = get_backward_walk_ops(
backward_seed_ops,
inclusive=backward_inclusive,
within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
control_inputs=control_inputs)
return [op for op in forward_ops if op in backward_ops]
@@ -569,6 +585,7 @@ def get_walks_union_ops(forward_seed_ops,
forward_inclusive=True,
backward_inclusive=True,
within_ops=None,
+ within_ops_fn=None,
control_inputs=False,
control_outputs=None,
control_ios=None):
@@ -587,6 +604,9 @@ def get_walks_union_ops(forward_seed_ops,
resulting set.
within_ops: restrict the search within those operations. If within_ops is
None, the search is done within the whole graph.
+    within_ops_fn: if provided, a function on ops that should return True iff
+      the op is within the graph traversal. This can be used along with
+      within_ops, in which case an op is within if it is also in within_ops.
control_inputs: A boolean indicating whether control inputs are enabled.
control_outputs: An instance of util.ControlOutputs or None. If not None,
control outputs are enabled.
@@ -607,11 +627,13 @@ def get_walks_union_ops(forward_seed_ops,
forward_seed_ops,
inclusive=forward_inclusive,
within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
control_outputs=control_outputs)
backward_ops = get_backward_walk_ops(
backward_seed_ops,
inclusive=backward_inclusive,
within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
control_inputs=control_inputs)
return util.concatenate_unique(forward_ops, backward_ops)
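A short usage sketch of the new within_ops_fn argument (hypothetical four-op graph, TF 1.x contrib API as of this change); the predicate is applied in addition to within_ops, so an op must pass both filters to be visited:

import tensorflow as tf
from tensorflow.contrib import graph_editor as ge

g = tf.Graph()
with g.as_default():
  a = tf.constant(1., name='a')
  b = tf.constant(2., name='b')
  c = tf.add(a, b, name='c')
  d = tf.square(c, name='d')

# Forward walk from a, but only through ops the predicate accepts: d.op is
# filtered out, so the walk stops at c.
ops = ge.get_forward_walk_ops(
    [a.op], within_ops_fn=lambda op: op not in (d.op,))
print(sorted(op.name for op in ops))  # ['a', 'c']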
diff --git a/tensorflow/contrib/graph_editor/tests/select_test.py b/tensorflow/contrib/graph_editor/tests/select_test.py
index 82f999637d..d12c6d3cbd 100644
--- a/tensorflow/contrib/graph_editor/tests/select_test.py
+++ b/tensorflow/contrib/graph_editor/tests/select_test.py
@@ -77,12 +77,10 @@ class SelectTest(test.TestCase):
"""Test for ge.get_ops_ios."""
control_outputs = ge.util.ControlOutputs(self.graph)
self.assertEqual(
- len(ge.get_ops_ios(
- self.h.op, control_ios=control_outputs)), 3)
+ len(ge.get_ops_ios(self.h.op, control_ios=control_outputs)), 3)
self.assertEqual(len(ge.get_ops_ios(self.h.op)), 2)
self.assertEqual(
- len(ge.get_ops_ios(
- self.c.op, control_ios=control_outputs)), 6)
+ len(ge.get_ops_ios(self.c.op, control_ios=control_outputs)), 6)
self.assertEqual(len(ge.get_ops_ios(self.c.op)), 5)
def test_compute_boundary_ts_0(self):
@@ -135,16 +133,49 @@ class SelectTest(test.TestCase):
ops = ge.get_walks_intersection_ops([self.c.op], [self.g.op])
self.assertEqual(len(ops), 2)
+ ops = ge.get_walks_intersection_ops([self.a.op], [self.f.op])
+ self.assertEqual(len(ops), 3)
+ self.assertTrue(self.a.op in ops)
+ self.assertTrue(self.c.op in ops)
+ self.assertTrue(self.f.op in ops)
+
+ within_ops = [self.a.op, self.f.op]
+ ops = ge.get_walks_intersection_ops(
+ [self.a.op], [self.f.op], within_ops=within_ops)
+ self.assertEqual(len(ops), 0)
+
+ within_ops_fn = lambda op: op in [self.a.op, self.f.op]
+ ops = ge.get_walks_intersection_ops(
+ [self.a.op], [self.f.op], within_ops_fn=within_ops_fn)
+ self.assertEqual(len(ops), 0)
+
def test_get_walks_union(self):
"""Test for ge.get_walks_union_ops."""
ops = ge.get_walks_union_ops([self.f.op], [self.g.op])
self.assertEqual(len(ops), 6)
+ ops = ge.get_walks_union_ops([self.a.op], [self.f.op])
+ self.assertEqual(len(ops), 8)
+
+ within_ops = [self.a.op, self.c.op, self.d.op, self.f.op]
+ ops = ge.get_walks_union_ops([self.a.op], [self.f.op],
+ within_ops=within_ops)
+ self.assertEqual(len(ops), 4)
+ self.assertTrue(self.b.op not in ops)
+
+ within_ops_fn = lambda op: op in [self.a.op, self.c.op, self.f.op]
+ ops = ge.get_walks_union_ops([self.a.op], [self.f.op],
+ within_ops_fn=within_ops_fn)
+ self.assertEqual(len(ops), 3)
+ self.assertTrue(self.b.op not in ops)
+ self.assertTrue(self.d.op not in ops)
+
def test_select_ops(self):
parameters = (
(("^foo/",), 7),
(("^foo/bar/",), 4),
- (("^foo/bar/", "a"), 5),)
+ (("^foo/bar/", "a"), 5),
+ )
for param, length in parameters:
ops = ge.select_ops(*param, graph=self.graph)
self.assertEqual(len(ops), length)
@@ -152,7 +183,8 @@ class SelectTest(test.TestCase):
def test_select_ts(self):
parameters = (
(".*:0", 8),
- (r".*/bar/\w+:0", 4),)
+ (r".*/bar/\w+:0", 4),
+ )
for regex, length in parameters:
ts = ge.select_ts(regex, graph=self.graph)
self.assertEqual(len(ts), length)
@@ -160,12 +192,121 @@ class SelectTest(test.TestCase):
def test_select_ops_and_ts(self):
parameters = (
(("^foo/.*",), 7, 0),
- (("^foo/.*", "(?#ts)^foo/bar/.*"), 7, 4),)
+ (("^foo/.*", "(?#ts)^foo/bar/.*"), 7, 4),
+ )
for param, l0, l1 in parameters:
ops, ts = ge.select_ops_and_ts(*param, graph=self.graph)
self.assertEqual(len(ops), l0)
self.assertEqual(len(ts), l1)
+ def test_forward_walk_ops(self):
+ seed_ops = [self.a.op, self.d.op]
+ # Include all ops except for self.g.op
+ within_ops = [
+ x.op for x in [self.a, self.b, self.c, self.d, self.e, self.f, self.h]
+ ]
+ # For the fn, exclude self.e.op.
+ within_ops_fn = lambda op: op not in (self.e.op,)
+ stop_at_ts = (self.f,)
+
+ with self.graph.as_default():
+ # No b.op since it's an independent source node.
+ # No g.op from within_ops.
+ # No e.op from within_ops_fn.
+ # No h.op from stop_at_ts and within_ops.
+ ops = ge.select.get_forward_walk_ops(
+ seed_ops,
+ inclusive=True,
+ within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
+ stop_at_ts=stop_at_ts)
+ self.assertEqual(
+ set(ops), set([self.a.op, self.c.op, self.d.op, self.f.op]))
+
+ # Also no a.op and d.op when inclusive=False
+ ops = ge.select.get_forward_walk_ops(
+ seed_ops,
+ inclusive=False,
+ within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
+ stop_at_ts=stop_at_ts)
+ self.assertEqual(set(ops), set([self.c.op, self.f.op]))
+
+ # Not using within_ops_fn adds e.op.
+ ops = ge.select.get_forward_walk_ops(
+ seed_ops,
+ inclusive=False,
+ within_ops=within_ops,
+ stop_at_ts=stop_at_ts)
+ self.assertEqual(set(ops), set([self.c.op, self.e.op, self.f.op]))
+
+ # Not using stop_at_ts adds back h.op.
+ ops = ge.select.get_forward_walk_ops(
+ seed_ops, inclusive=False, within_ops=within_ops)
+ self.assertEqual(
+ set(ops), set([self.c.op, self.e.op, self.f.op, self.h.op]))
+
+      # Starting just from a (the tensor, not the op) omits a, b, d.
+ ops = ge.select.get_forward_walk_ops([self.a], inclusive=True)
+ self.assertEqual(
+ set(ops), set([self.c.op, self.e.op, self.f.op, self.g.op,
+ self.h.op]))
+
+ def test_backward_walk_ops(self):
+ seed_ops = [self.h.op]
+ # Include all ops except for self.g.op
+ within_ops = [
+ x.op for x in [self.a, self.b, self.c, self.d, self.e, self.f, self.h]
+ ]
+ # For the fn, exclude self.c.op.
+ within_ops_fn = lambda op: op not in (self.c.op,)
+ stop_at_ts = (self.f,)
+
+ with self.graph.as_default():
+ # Backward walk only includes h since we stop at f and g is not within.
+ ops = ge.select.get_backward_walk_ops(
+ seed_ops,
+ inclusive=True,
+ within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
+ stop_at_ts=stop_at_ts)
+ self.assertEqual(set(ops), set([self.h.op]))
+
+ # If we do inclusive=False, the result is empty.
+ ops = ge.select.get_backward_walk_ops(
+ seed_ops,
+ inclusive=False,
+ within_ops=within_ops,
+ within_ops_fn=within_ops_fn,
+ stop_at_ts=stop_at_ts)
+ self.assertEqual(set(ops), set())
+
+      # Removing stop_at_ts adds f.op, d.op.
+ ops = ge.select.get_backward_walk_ops(
+ seed_ops,
+ inclusive=True,
+ within_ops=within_ops,
+ within_ops_fn=within_ops_fn)
+ self.assertEqual(set(ops), set([self.d.op, self.f.op, self.h.op]))
+
+ # Not using within_ops_fn adds back ops for a, b, c.
+ ops = ge.select.get_backward_walk_ops(
+ seed_ops, inclusive=True, within_ops=within_ops)
+ self.assertEqual(
+ set(ops),
+ set([
+ self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.h.op
+ ]))
+
+      # Vanilla backward search via self.h.op includes everything except e.op.
+ ops = ge.select.get_backward_walk_ops(seed_ops, inclusive=True)
+ self.assertEqual(
+ set(ops),
+ set([
+ self.a.op, self.b.op, self.c.op, self.d.op, self.f.op, self.g.op,
+ self.h.op
+ ]))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index e49589ddf6..02d294c68f 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -247,9 +247,7 @@ class RevBlock(base.Layer):
f_vars_idxs = [[] for _ in range(self.num_layers)]
g_vars_idxs = [[] for _ in range(self.num_layers)]
- for i, t in enumerate(variables):
- ref = _underlying_variable_ref(t)
-
+ for i, ref in enumerate(variables):
# Use the name to identify the layer number and function (f or g)
regex = LAYER_RE.match(ref.name)
layer_no = int(regex.group(1))
@@ -604,6 +602,7 @@ def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False):
"""Custom grad fn applying grad_fn for identity Defun."""
fn_inputs, fn_vars, fn_outputs = nest.pack_sequence_as(
defun_inputs, list(op.inputs))
+ fn_vars = [_underlying_variable_ref(v) for v in fn_vars]
dys = list(dys)
assert len(fn_outputs) == len(outputs)
assert len(fn_outputs) == len(dys)
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
index d1ad4e8c98..392a490be1 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
@@ -304,6 +304,20 @@ class RecomputeTest(test.TestCase):
self.assertAllClose(current, g)
current = g
+ def testResourceVariable(self):
+ @rev_block_lib.recompute_grad(tupleize_grads=True)
+ def layer_with_recompute(inputs):
+ var = variable_scope.get_variable("var", ())
+ return var * inputs
+
+ inputs = array_ops.ones((), dtypes.float32)
+ with variable_scope.variable_scope("layer", use_resource=True):
+ outputs = layer_with_recompute(inputs)
+ loss = math_ops.square(outputs)
+ grads = gradients_impl.gradients(loss, variables.trainable_variables())
+ self.assertEqual(1, len(grads))
+ self.assertTrue(grads[0] is not None)
+
class FnWithCustomGradTest(test.TestCase):
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 2813d1c347..b8f6b7fd59 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -200,8 +200,7 @@ def gen_zipped_test_files(name, files):
native.genrule(
name = name + "_" + f + ".files",
cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco
- + " --zip_to_output " + f +
- " $(@D) zipped"),
+ + " --zip_to_output " + f + " $(@D)"),
outs = [out_file],
tools = [
":generate_examples",
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 77db178783..a6d582a813 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -208,7 +208,7 @@ class Interpreter {
// TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
// read/write access to structure
const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
- int node_index) {
+ int node_index) const {
if (node_index >= nodes_and_registration_.size() || node_index < 0)
return nullptr;
return &nodes_and_registration_[node_index];
diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD
index b14230acd7..1dda55b8ed 100644
--- a/tensorflow/contrib/lite/java/BUILD
+++ b/tensorflow/contrib/lite/java/BUILD
@@ -117,6 +117,7 @@ java_test(
"src/testdata/int64.bin",
"src/testdata/invalid_model.bin",
"src/testdata/uint8.bin",
+ "src/testdata/with_custom_op.lite",
],
javacopts = JAVACOPTS,
test_class = "org.tensorflow.lite.NativeInterpreterWrapperTest",
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 9a274612ad..5acf1eaede 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -554,88 +554,261 @@ inline void GEMVForLstmCellWithSymmetricRange(
// exercises it). We just guard our assumptions about size evenness with
// the following assertions.
TFLITE_DCHECK(!(output_size % 4));
- TFLITE_DCHECK(!(input_size % 8));
+ TFLITE_DCHECK(!(input_size % 64));
const int32* bias_ptr = bias_data;
int16* output_ptr = output_data;
const uint8x16_t signbit = vdupq_n_u8(0x80);
for (int in = 0; in < input_size; in += 32) {
optimized_ops_preload_l1_keep(input_data + in);
}
+ const int left_shift = accum_shift > 0 ? accum_shift : 0;
+ const int right_shift = accum_shift > 0 ? 0 : -accum_shift;
for (int out = 0; out < output_size; out += 4) {
- const uint8* weights_ptr_0 = weights_data + out * input_size;
- const uint8* weights_ptr_1 = weights_ptr_0 + 1 * input_size;
- const uint8* weights_ptr_2 = weights_ptr_0 + 2 * input_size;
- const uint8* weights_ptr_3 = weights_ptr_0 + 3 * input_size;
+ // Load the bias values
+ int32x4_t bias_vec = vld1q_s32(bias_ptr);
+ bias_ptr += 4;
- int32x4_t acc_0 = vdupq_n_s32(0);
- int32x4_t acc_1 = vdupq_n_s32(0);
- int32x4_t acc_2 = vdupq_n_s32(0);
- int32x4_t acc_3 = vdupq_n_s32(0);
- int in = 0;
- const int kReadAhead = 256;
- // Handle 16 levels of depth at a time.
- for (; in < input_size; in += 16) {
- int8x16_t weights_val_0 =
- vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(weights_ptr_0)));
- int8x16_t weights_val_1 =
- vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(weights_ptr_1)));
- int8x16_t weights_val_2 =
- vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(weights_ptr_2)));
- int8x16_t weights_val_3 =
- vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(weights_ptr_3)));
- int8x16_t input_val =
- vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(input_data + in)));
- int16x8_t acc16_0 =
- vmull_s8(vget_low_s8(weights_val_0), vget_low_s8(input_val));
- int16x8_t acc16_1 =
- vmull_s8(vget_low_s8(weights_val_1), vget_low_s8(input_val));
- int16x8_t acc16_2 =
- vmull_s8(vget_low_s8(weights_val_2), vget_low_s8(input_val));
- int16x8_t acc16_3 =
- vmull_s8(vget_low_s8(weights_val_3), vget_low_s8(input_val));
- acc16_0 = vmlal_s8(acc16_0, vget_high_s8(weights_val_0),
- vget_high_s8(input_val));
- acc16_1 = vmlal_s8(acc16_1, vget_high_s8(weights_val_1),
- vget_high_s8(input_val));
- acc16_2 = vmlal_s8(acc16_2, vget_high_s8(weights_val_2),
- vget_high_s8(input_val));
- acc16_3 = vmlal_s8(acc16_3, vget_high_s8(weights_val_3),
- vget_high_s8(input_val));
- acc_0 = vpadalq_s16(acc_0, acc16_0);
- acc_1 = vpadalq_s16(acc_1, acc16_1);
- acc_2 = vpadalq_s16(acc_2, acc16_2);
- acc_3 = vpadalq_s16(acc_3, acc16_3);
- weights_ptr_0 += 16;
- weights_ptr_1 += 16;
- weights_ptr_2 += 16;
- weights_ptr_3 += 16;
- optimized_ops_preload_l1_stream(weights_ptr_0 + kReadAhead);
- optimized_ops_preload_l1_stream(weights_ptr_1 + kReadAhead);
- optimized_ops_preload_l1_stream(weights_ptr_2 + kReadAhead);
- optimized_ops_preload_l1_stream(weights_ptr_3 + kReadAhead);
+ // Clear accumulators. We use 2 accumulator registers per row,
+ // for 4 rows. row_accumRN is the N-th accumulator for row R.
+ int32x4_t row_accum00 = vdupq_n_s32(0);
+ int32x4_t row_accum01 = vdupq_n_s32(0);
+ int32x4_t row_accum10 = vdupq_n_s32(0);
+ int32x4_t row_accum11 = vdupq_n_s32(0);
+ int32x4_t row_accum20 = vdupq_n_s32(0);
+ int32x4_t row_accum21 = vdupq_n_s32(0);
+ int32x4_t row_accum30 = vdupq_n_s32(0);
+ int32x4_t row_accum31 = vdupq_n_s32(0);
+
+ // kReadAhead parametrizes how far ahead we prefetch weights into L1 cache.
+ const int kReadAhead = 512;
+ // Prefetch the first weights values.
+ for (int k = 0; k < kReadAhead; k += 64) {
+ optimized_ops_preload_l1_stream(weights_data + (out + 0) * input_size +
+ k);
+ optimized_ops_preload_l1_stream(weights_data + (out + 1) * input_size +
+ k);
+ optimized_ops_preload_l1_stream(weights_data + (out + 2) * input_size +
+ k);
+ optimized_ops_preload_l1_stream(weights_data + (out + 3) * input_size +
+ k);
+ }
+ // Loop along the rows, handling 64 bytes per iteration because that's
+ // cache line size on most current ARM-architecture CPUs.
+ for (int in = 0; in < input_size; in += 64) {
+ // Prefetch some future weights values.
+ optimized_ops_preload_l1_stream(weights_data + (out + 0) * input_size +
+ in + kReadAhead);
+ optimized_ops_preload_l1_stream(weights_data + (out + 1) * input_size +
+ in + kReadAhead);
+ optimized_ops_preload_l1_stream(weights_data + (out + 2) * input_size +
+ in + kReadAhead);
+ optimized_ops_preload_l1_stream(weights_data + (out + 3) * input_size +
+ in + kReadAhead);
+
+ // We will use 2 local 16-bit accumulators per row, for 2 rows.
+ // See below (*) for the rationale of processing only 2 rows at a time.
+ // local_accumRN is the N-th local accumulator for row R.
+ int16x8_t local_accum00;
+ int16x8_t local_accum01;
+ int16x8_t local_accum10;
+ int16x8_t local_accum11;
+
+ // Load 64 bytes of input activations values. Convert to signed int8
+ // by flipping the sign bit (i.e. subtracting 128, the required
+ // zero_point value).
+ int8x16_t input0 = vreinterpretq_s8_u8(
+ veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 0)));
+ int8x16_t input1 = vreinterpretq_s8_u8(
+ veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 1)));
+ int8x16_t input2 = vreinterpretq_s8_u8(
+ veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 2)));
+ int8x16_t input3 = vreinterpretq_s8_u8(
+ veorq_u8(signbit, vld1q_u8(input_data + in + 16 * 3)));
+
+ // Beginning of the core accumulation. Notice how while we have 4
+ // rows to process, this code is taking care of only 2 rows at a time,
+ // thus being divided into two parts looking similar ("Rows 0 and 1" and
+ // "Rows 2 and 3").
+ //
+ // (*) The rationale for handling only 2 rows at a time is to avoid
+ // cache aliasing issues on 4-way set-associative L1-cache CPUs, such
+ // as Cortex-A53. With sufficiently large, power-of-two matrix dimensions,
+ // we may find ourselves in a situation where rows alias each other in
+ // the L1 cache, and moreover may also mutually alias with the input
+ // activations. If we try to load 4 rows at a time, together with the
+ // input activations, that may be 5 mutually-aliasing vectors, resulting
+ // in constant mutual eviction from L1 cache. Handling 2 rows at a time
+ // here largely mitigates these issues, and seems at least to be very
+ // effective on Cortex-A53:
+ // Before After
+ // big (Cortex-A73) 2.85 ms 2.85 ms
+ // little (Cortex-A53) 11.0 ms 5.16 ms
+
+ // Rows 0 and 1:
+ // Load 64 bytes of weights values from each row. Convert to signed int8
+ // by flipping the sign bit (i.e. subtracting 128, the required
+ // zero_point value).
+ int8x16_t weights00 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 0)));
+ int8x16_t weights01 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 1)));
+ int8x16_t weights02 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 2)));
+ int8x16_t weights03 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 0) * input_size + in + 16 * 3)));
+ int8x16_t weights10 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 0)));
+ int8x16_t weights11 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 1)));
+ int8x16_t weights12 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 2)));
+ int8x16_t weights13 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 1) * input_size + in + 16 * 3)));
+ // Multiply-accumulate into local 16-bit accumulators.
+ // We can accumulate two products without overflow because weights are
+ // required to never be -128, so each product is at most 127^2 in absolute
+ // value.
+ local_accum00 = vmull_s8(vget_low_s8(weights00), vget_low_s8(input0));
+ local_accum01 = vmull_s8(vget_low_s8(weights01), vget_low_s8(input1));
+ local_accum10 = vmull_s8(vget_low_s8(weights10), vget_low_s8(input0));
+ local_accum11 = vmull_s8(vget_low_s8(weights11), vget_low_s8(input1));
+ local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights00),
+ vget_high_s8(input0));
+ local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights01),
+ vget_high_s8(input1));
+ local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights10),
+ vget_high_s8(input0));
+ local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights11),
+ vget_high_s8(input1));
+ // Pairwise add and accumulate into 32-bit accumulators
+ row_accum00 = vpadalq_s16(row_accum00, local_accum00);
+ row_accum01 = vpadalq_s16(row_accum01, local_accum01);
+ row_accum10 = vpadalq_s16(row_accum10, local_accum10);
+ row_accum11 = vpadalq_s16(row_accum11, local_accum11);
+ // Multiply-accumulate into local 16-bit accumulators.
+ // We can accumulate two products without overflow because weights are
+ // required to never be -128, so each product is at most 127^2 in absolute
+ // value.
+ local_accum00 = vmull_s8(vget_low_s8(weights02), vget_low_s8(input2));
+ local_accum01 = vmull_s8(vget_low_s8(weights03), vget_low_s8(input3));
+ local_accum10 = vmull_s8(vget_low_s8(weights12), vget_low_s8(input2));
+ local_accum11 = vmull_s8(vget_low_s8(weights13), vget_low_s8(input3));
+ local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights02),
+ vget_high_s8(input2));
+ local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights03),
+ vget_high_s8(input3));
+ local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights12),
+ vget_high_s8(input2));
+ local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights13),
+ vget_high_s8(input3));
+ // Pairwise add and accumulate into 32-bit accumulators
+ row_accum00 = vpadalq_s16(row_accum00, local_accum00);
+ row_accum01 = vpadalq_s16(row_accum01, local_accum01);
+ row_accum10 = vpadalq_s16(row_accum10, local_accum10);
+ row_accum11 = vpadalq_s16(row_accum11, local_accum11);
+
+ // Rows 2 and 3:
+ // Load 64 bytes of weights values from each row. Convert to signed int8
+ // by flipping the sign bit (i.e. subtracting 128, the required
+ // zero_point value).
+ weights00 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 0)));
+ weights01 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 1)));
+ weights02 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 2)));
+ weights03 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 2) * input_size + in + 16 * 3)));
+ weights10 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 0)));
+ weights11 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 1)));
+ weights12 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 2)));
+ weights13 = vreinterpretq_s8_u8(veorq_u8(
+ signbit,
+ vld1q_u8(weights_data + (out + 3) * input_size + in + 16 * 3)));
+ // Multiply-accumulate into local 16-bit accumulators.
+ // We can accumulate two products without overflow because weights are
+ // required to never be -128, so each product is at most 127^2 in absolute
+ // value.
+ local_accum00 = vmull_s8(vget_low_s8(weights00), vget_low_s8(input0));
+ local_accum01 = vmull_s8(vget_low_s8(weights01), vget_low_s8(input1));
+ local_accum10 = vmull_s8(vget_low_s8(weights10), vget_low_s8(input0));
+ local_accum11 = vmull_s8(vget_low_s8(weights11), vget_low_s8(input1));
+ local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights00),
+ vget_high_s8(input0));
+ local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights01),
+ vget_high_s8(input1));
+ local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights10),
+ vget_high_s8(input0));
+ local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights11),
+ vget_high_s8(input1));
+ // Pairwise add and accumulate into 32-bit accumulators
+ row_accum20 = vpadalq_s16(row_accum20, local_accum00);
+ row_accum21 = vpadalq_s16(row_accum21, local_accum01);
+ row_accum30 = vpadalq_s16(row_accum30, local_accum10);
+ row_accum31 = vpadalq_s16(row_accum31, local_accum11);
+ // Multiply-accumulate into local 16-bit accumulators.
+ // We can accumulate two products without overflow because weights are
+ // required to never be -128, so each product is at most 127^2 in absolute
+ // value.
+ local_accum00 = vmull_s8(vget_low_s8(weights02), vget_low_s8(input2));
+ local_accum01 = vmull_s8(vget_low_s8(weights03), vget_low_s8(input3));
+ local_accum10 = vmull_s8(vget_low_s8(weights12), vget_low_s8(input2));
+ local_accum11 = vmull_s8(vget_low_s8(weights13), vget_low_s8(input3));
+ local_accum00 = vmlal_s8(local_accum00, vget_high_s8(weights02),
+ vget_high_s8(input2));
+ local_accum01 = vmlal_s8(local_accum01, vget_high_s8(weights03),
+ vget_high_s8(input3));
+ local_accum10 = vmlal_s8(local_accum10, vget_high_s8(weights12),
+ vget_high_s8(input2));
+ local_accum11 = vmlal_s8(local_accum11, vget_high_s8(weights13),
+ vget_high_s8(input3));
+ // Pairwise add and accumulate into 32-bit accumulators
+ row_accum20 = vpadalq_s16(row_accum20, local_accum00);
+ row_accum21 = vpadalq_s16(row_accum21, local_accum01);
+ row_accum30 = vpadalq_s16(row_accum30, local_accum10);
+ row_accum31 = vpadalq_s16(row_accum31, local_accum11);
}
+
+ row_accum00 = vaddq_s32(row_accum00, row_accum01);
+ row_accum10 = vaddq_s32(row_accum10, row_accum11);
+ row_accum20 = vaddq_s32(row_accum20, row_accum21);
+ row_accum30 = vaddq_s32(row_accum30, row_accum31);
// Horizontally reduce accumulators
int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
pairwise_reduced_acc_2, pairwise_reduced_acc_3;
pairwise_reduced_acc_0 =
- vpadd_s32(vget_low_s32(acc_0), vget_high_s32(acc_0));
+ vpadd_s32(vget_low_s32(row_accum00), vget_high_s32(row_accum00));
pairwise_reduced_acc_1 =
- vpadd_s32(vget_low_s32(acc_1), vget_high_s32(acc_1));
+ vpadd_s32(vget_low_s32(row_accum10), vget_high_s32(row_accum10));
pairwise_reduced_acc_2 =
- vpadd_s32(vget_low_s32(acc_2), vget_high_s32(acc_2));
+ vpadd_s32(vget_low_s32(row_accum20), vget_high_s32(row_accum20));
pairwise_reduced_acc_3 =
- vpadd_s32(vget_low_s32(acc_3), vget_high_s32(acc_3));
+ vpadd_s32(vget_low_s32(row_accum30), vget_high_s32(row_accum30));
const int32x2_t reduced_lo =
vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
const int32x2_t reduced_hi =
vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
// Add bias values.
- int32x4_t bias_vec = vld1q_s32(bias_ptr);
- bias_ptr += 4;
reduced = vaddq_s32(reduced, bias_vec);
- int left_shift = accum_shift > 0 ? accum_shift : 0;
- int right_shift = accum_shift > 0 ? 0 : -accum_shift;
reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
// Multiply by the fixed-point multiplier.
reduced = vqrdmulhq_n_s32(reduced, accum_multiplier);
@@ -962,7 +1135,7 @@ inline void FullyConnected(
#ifdef GEMMLOWP_NEON
if (batches == 1 && input_offset == -128 && output_activation_min == -32768 &&
output_activation_max == 32767) {
- if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 16)) {
+ if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 64)) {
GEMVForLstmCellWithSymmetricRange(input_data, input_dims, filter_data,
filter_dims, bias_data_int32, bias_dims,
output_multiplier, -output_shift,
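The overflow argument in the comments above (two int8 products accumulated into an int16 lane cannot overflow because weights never take the value -128) comes down to a one-line arithmetic check; a small Python sanity check, illustrative only:

INT16_MAX = 32767
# Weights lie in [-127, 127]; inputs, after the sign-bit flip, in [-128, 127].
assert 2 * 127 * 127 <= INT16_MAX  # bound quoted in the comments: 32258
assert 2 * 127 * 128 <= INT16_MAX  # even against an input of -128: 32512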
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 8045052452..f919517e93 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -17,10 +17,9 @@
Usage:
-generate_examples <output directory> zipped
+generate_examples <output directory>
bazel run //tensorflow/contrib/lite/testing:generate_examples
- third_party/tensorflow/contrib/lite/testing/generated_examples zipped
"""
from __future__ import absolute_import
from __future__ import division
@@ -52,8 +51,6 @@ from tensorflow.python.ops import rnn
parser = argparse.ArgumentParser(description="Script to generate TFLite tests.")
parser.add_argument("output_path",
help="Directory where the outputs will be go.")
-# TODO(ahentz): remove this flag
-parser.add_argument("type", help="zipped")
parser.add_argument("--zip_to_output",
type=str,
help="Particular zip to output.",
@@ -543,6 +540,18 @@ def make_pool_tests(pool_op_in):
return f
+def make_l2_pool_tests(zip_path):
+ make_pool_tests(make_l2_pool)(zip_path)
+
+
+def make_avg_pool_tests(zip_path):
+ make_pool_tests(tf.nn.avg_pool)(zip_path)
+
+
+def make_max_pool_tests(zip_path):
+ make_pool_tests(tf.nn.max_pool)(zip_path)
+
+
def make_relu_tests(zip_path):
"""Make a set of tests to do relu."""
@@ -902,6 +911,22 @@ def make_binary_op_tests_func(binary_operator):
return lambda zip_path: make_binary_op_tests(zip_path, binary_operator)
+def make_add_tests(zip_path):
+ make_binary_op_tests(zip_path, tf.add)
+
+
+def make_div_tests(zip_path):
+ make_binary_op_tests(zip_path, tf.div)
+
+
+def make_sub_tests(zip_path):
+ make_binary_op_tests(zip_path, tf.subtract)
+
+
+def make_mul_tests(zip_path):
+ make_binary_op_tests(zip_path, tf.multiply)
+
+
def make_gather_tests(zip_path):
"""Make a set of tests to do gather."""
@@ -1169,7 +1194,7 @@ def make_split_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
-def make_concatenation_tests(zip_path):
+def make_concat_tests(zip_path):
"""Make a set of tests to do concatenation."""
test_parameters = [{
@@ -1966,69 +1991,26 @@ def main(unused_args):
if not os.path.isdir(x):
raise RuntimeError("Failed to create dir %r" % x)
- if FLAGS.type == "zipped":
- opstest_path = os.path.join(FLAGS.output_path)
- mkdir_if_not_exist(opstest_path)
- def _path(filename):
- return os.path.join(opstest_path, filename)
-
- dispatch = {
- "control_dep.zip": make_control_dep_tests,
- "add.zip": make_binary_op_tests_func(tf.add),
- "space_to_batch_nd.zip": make_space_to_batch_nd_tests,
- "div.zip": make_binary_op_tests_func(tf.div),
- "sub.zip": make_binary_op_tests_func(tf.subtract),
- "batch_to_space_nd.zip": make_batch_to_space_nd_tests,
- "conv.zip": make_conv_tests,
- "constant.zip": make_constant_tests,
- "depthwiseconv.zip": make_depthwiseconv_tests,
- "concat.zip": make_concatenation_tests,
- "fully_connected.zip": make_fully_connected_tests,
- "global_batch_norm.zip": make_global_batch_norm_tests,
- "gather.zip": make_gather_tests,
- "fused_batch_norm.zip": make_fused_batch_norm_tests,
- "l2norm.zip": make_l2norm_tests,
- "local_response_norm.zip": make_local_response_norm_tests,
- "mul.zip": make_binary_op_tests_func(tf.multiply),
- "relu.zip": make_relu_tests,
- "relu1.zip": make_relu1_tests,
- "relu6.zip": make_relu6_tests,
- "prelu.zip": make_prelu_tests,
- "l2_pool.zip": make_pool_tests(make_l2_pool),
- "avg_pool.zip": make_pool_tests(tf.nn.avg_pool),
- "max_pool.zip": make_pool_tests(tf.nn.max_pool),
- "pad.zip": make_pad_tests,
- "reshape.zip": make_reshape_tests,
- "resize_bilinear.zip": make_resize_bilinear_tests,
- "sigmoid.zip": make_sigmoid_tests,
- "softmax.zip": make_softmax_tests,
- "space_to_depth.zip": make_space_to_depth_tests,
- "topk.zip": make_topk_tests,
- "split.zip": make_split_tests,
- "transpose.zip": make_transpose_tests,
- "mean.zip": make_mean_tests,
- "squeeze.zip": make_squeeze_tests,
- "strided_slice.zip": make_strided_slice_tests,
- "exp.zip": make_exp_tests,
- "log_softmax.zip": make_log_softmax_tests,
- "lstm.zip": make_lstm_tests,
- "maximum.zip": make_maximum_tests,
- }
- out = FLAGS.zip_to_output
- bin_path = FLAGS.toco
- if out in dispatch:
- dispatch[out](_path(out))
- else:
- raise RuntimeError("Invalid zip to output %r" % out)
+ opstest_path = os.path.join(FLAGS.output_path)
+ mkdir_if_not_exist(opstest_path)
- else:
- raise RuntimeError("Invalid argument for type of generation.")
+ out = FLAGS.zip_to_output
+ bin_path = FLAGS.toco
+ test_function = ("make_%s_tests" % out.replace(".zip", ""))
+ if test_function not in globals():
+ raise RuntimeError("Can't find a test function to create %r. Tried %r" %
+ (out, test_function))
+
+ # TODO(ahentz): accessing globals() is not very elegant. We should either
+ # break this file into multiple tests or use decorator-based registration to
+ # avoid using globals().
+ globals()[test_function](os.path.join(opstest_path, out))
if __name__ == "__main__":
FLAGS, unparsed = parser.parse_known_args()
if unparsed:
- print("Usage: %s <path out> zipped <zip file to generate>")
+ print("Usage: %s <path out> <zip file to generate>")
else:
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
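
The TODO above points at decorator-based registration as a cleaner alternative to the globals() lookup. A minimal sketch of that idea, using hypothetical names (register_make_test, _MAKE_TEST_FNS) that are not part of this change:

_MAKE_TEST_FNS = {}

def register_make_test(name):
  """Registers a test generator under its zip-file stem, e.g. "relu"."""
  def decorator(fn):
    _MAKE_TEST_FNS[name] = fn
    return fn
  return decorator

@register_make_test("relu")
def make_relu_tests(zip_path):
  """Make a set of tests to do relu."""
  pass  # build the relu zip exactly as the existing make_relu_tests does

# main() would then dispatch without touching globals():
#   _MAKE_TEST_FNS[out.replace(".zip", "")](os.path.join(opstest_path, out))
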
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 4a77196aab..4a85f3c5a4 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -704,6 +704,15 @@ void ConvertRelu6Operator(const Relu6Operator& src_op,
(*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
}
+void ConvertLogOperator(const LogOperator& src_op, GraphDef* tensorflow_graph) {
+ auto* op = tensorflow_graph->add_node();
+ op->set_op("Log");
+ op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 1);
+ *op->add_input() = src_op.inputs[0];
+ (*op->mutable_attr())["T"].set_type(DT_FLOAT);
+}
+
void ConvertLogisticOperator(const LogisticOperator& src_op,
GraphDef* tensorflow_graph) {
auto* relu_op = tensorflow_graph->add_node();
@@ -1703,6 +1712,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kRelu6) {
ConvertRelu6Operator(static_cast<const Relu6Operator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kLog) {
+ ConvertLogOperator(static_cast<const LogOperator&>(src_op),
+ tensorflow_graph);
} else if (src_op.type == OperatorType::kLogistic) {
ConvertLogisticOperator(static_cast<const LogisticOperator&>(src_op),
tensorflow_graph);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 68d6f21cf8..a648b770f8 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1479,6 +1479,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kPRelu:
case OperatorType::kSoftmax:
case OperatorType::kLogSoftmax:
+ case OperatorType::kLog:
case OperatorType::kLogistic:
case OperatorType::kTanh:
case OperatorType::kLocalResponseNormalization:
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
index d4db6f1c00..f6c8f79d8d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
@@ -51,6 +51,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
// Test for unary ops of types that we know how to resolve.
switch (unary_op->type) {
case OperatorType::kCast:
+ case OperatorType::kLog:
case OperatorType::kNeg:
case OperatorType::kTensorFlowRsqrt:
case OperatorType::kTensorFlowSqrt:
@@ -218,6 +219,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
}
output_float_data[0] = max;
} else if (unary_op->type == OperatorType::kNeg ||
+ unary_op->type == OperatorType::kLog ||
unary_op->type == OperatorType::kTensorFlowRsqrt ||
unary_op->type == OperatorType::kTensorFlowSqrt ||
unary_op->type == OperatorType::kTensorFlowSquare) {
@@ -231,6 +233,8 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
float outval = 0.f;
if (unary_op->type == OperatorType::kNeg) {
outval = -val;
+ } else if (unary_op->type == OperatorType::kLog) {
+ outval = std::log(val);
} else if (unary_op->type == OperatorType::kTensorFlowRsqrt) {
outval = 1.0f / std::sqrt(val);
} else if (unary_op->type == OperatorType::kTensorFlowSqrt) {
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 876479079b..6b62eeb638 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -611,6 +611,18 @@ void ConvertRelu6Operator(const NodeDef& node,
model->operators.emplace_back(op);
}
+void ConvertLogOperator(const NodeDef& node,
+ const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "Log");
+ CheckInputsCount(node, tf_import_flags, 1);
+
+ auto op = absl::make_unique<LogOperator>();
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(std::move(op));
+}
+
void ConvertLogisticOperator(const NodeDef& node,
const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -2091,6 +2103,8 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
ConvertLRNOperator(node, tf_import_flags, model);
} else if (node.op() == "Softmax") {
ConvertSoftmaxOperator(node, tf_import_flags, model);
+ } else if (node.op() == "Log") {
+ ConvertLogOperator(node, tf_import_flags, model);
} else if (node.op() == "LogSoftmax") {
ConvertLogSoftmaxOperator(node, tf_import_flags, model);
} else if (node.op() == "All") {
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 9bd72e7de1..56ef9fe2a8 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -56,6 +56,7 @@ enum class OperatorType {
kL2Pool,
kLstmCell,
kLocalResponseNormalization,
+ kLog,
kLogistic,
kMaxPool,
kFakeQuant,
@@ -591,6 +592,17 @@ struct LogisticOperator : Operator {
LogisticOperator() : Operator(OperatorType::kLogistic) {}
};
+// Element-wise natural log operator:
+// x -> ln(x)
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: Log
+struct LogOperator : Operator {
+ LogOperator() : Operator(OperatorType::kLog) {}
+};
+
// Element-wise Tanh operator:
// x -> Tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
//
diff --git a/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py
index c35b6f9925..3761e0095e 100644
--- a/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py
+++ b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py
@@ -50,6 +50,7 @@ class TocoFromProtosTest(googletest.TestCase):
toco_flags.output_format = toco_flags_pb2.TFLITE
toco_flags.inference_input_type = types_pb2.FLOAT
toco_flags.inference_type = types_pb2.FLOAT
+ toco_flags.allow_custom_ops = True
model_flags = model_flags_pb2.ModelFlags()
input_array = model_flags.input_arrays.add()
input_array.name = TensorName(in_tensor)
diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.cc b/tensorflow/contrib/lite/toco/python/toco_python_api.cc
index 8a5e483f3f..153c117d17 100644
--- a/tensorflow/contrib/lite/toco/python/toco_python_api.cc
+++ b/tensorflow/contrib/lite/toco/python/toco_python_api.cc
@@ -75,7 +75,8 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
toco::Import(toco_flags, model_flags, input_contents_txt);
toco::Transform(toco_flags, model.get());
string output_file_contents_txt;
- Export(toco_flags, *model, &output_file_contents_txt);
+ Export(toco_flags, *model, toco_flags.allow_custom_ops(),
+ &output_file_contents_txt);
// Convert arguments back to byte (py3) or str (py2)
return TOCO_FROM_CPPSTRING_TO_PY(output_file_contents_txt.data(),
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index b72f5fa2a7..bd2d5f7df0 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -291,6 +291,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(Dequantize)
HANDLE_OPERATORTYPENAME_CASE(L2Normalization)
HANDLE_OPERATORTYPENAME_CASE(LocalResponseNormalization)
+ HANDLE_OPERATORTYPENAME_CASE(Log)
HANDLE_OPERATORTYPENAME_CASE(Logistic)
HANDLE_OPERATORTYPENAME_CASE(LstmCell)
HANDLE_OPERATORTYPENAME_CASE(MaxPool)
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index b6acf71b9d..d4c3f2eda8 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -151,6 +151,7 @@ tensorflow/core/kernels/decode_bmp_op.cc
tensorflow/core/kernels/depthtospace_op.cc
tensorflow/core/kernels/data_format_ops.cc
tensorflow/core/kernels/spacetodepth_op.cc
+tensorflow/core/kernels/dense_update_functor.cc
tensorflow/core/kernels/dense_update_ops.cc
tensorflow/core/kernels/deep_conv2d.cc
tensorflow/core/kernels/decode_wav_op.cc
@@ -301,3 +302,5 @@ tensorflow/core/kernels/warn_about_ints.cc
tensorflow/core/kernels/segment_reduction_ops.cc
tensorflow/core/kernels/batch_util.cc
tensorflow/core/ops/audio_ops.cc
+tensorflow/core/kernels/decode_proto_op.cc
+tensorflow/core/kernels/encode_proto_op.cc
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.cc b/tensorflow/contrib/nccl/kernels/nccl_manager.cc
index 913935b382..b9b482a698 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager.cc
@@ -76,6 +76,8 @@ struct NcclManager::Communicator {
namespace {
ncclDataType_t ToNcclType(DataType t) {
switch (t) {
+ case DT_HALF:
+ return ncclHalf;
case DT_FLOAT:
return ncclFloat;
case DT_DOUBLE:
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc
index 985b2bae25..06ca65e33a 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc
@@ -48,35 +48,9 @@ static std::vector<BaseGPUDevice*> GetGPUDevices() {
return gpus;
}
+template <typename Scalar>
class NcclManagerTest : public ::testing::Test {
- protected:
- static void SetUpTestCase() {
- setenv("NCCL_DEBUG", "INFO", 1 /* replace */);
- devices = new std::vector<BaseGPUDevice*>(GetGPUDevices());
- CHECK(!devices->empty());
- LOG(ERROR) << "Running test with " << devices->size() << " gpus";
- }
- static void TearDownTestCase() {
- for (auto device : *devices) delete device;
- delete devices;
- }
-
- static Allocator* gpu_allocator(BaseGPUDevice* device) {
- return device->GetStepAllocator(AllocatorAttributes(),
- nullptr /* step_resource_manager */);
- }
-
- static std::vector<BaseGPUDevice*>* devices;
-
- template <typename Scalar>
- perftools::gputools::DeviceMemory<Scalar> AsDeviceMemory(
- const Scalar* cuda_memory) {
- perftools::gputools::DeviceMemoryBase wrapped(
- const_cast<Scalar*>(cuda_memory));
- perftools::gputools::DeviceMemory<Scalar> typed(wrapped);
- return typed;
- }
-
+ public:
// A single all-reduce to apply.
struct TestCase {
string key;
@@ -89,42 +63,52 @@ class NcclManagerTest : public ::testing::Test {
int num_completed = 0;
};
+ static void SetUpTestCase() {
+ setenv("NCCL_DEBUG", "INFO", 1 /* replace */);
+ devices_ = new std::vector<BaseGPUDevice*>(GetGPUDevices());
+ CHECK(!devices_->empty());
+ LOG(ERROR) << "Running test with " << devices_->size() << " gpus";
+ }
+
+ static void TearDownTestCase() {
+ for (auto device : *devices_) delete device;
+ delete devices_;
+ }
+
TestCase* MakeTestCase(int num_ranks, ncclRedOp_t reduction_op,
TensorShape shape, float value_offset) {
TestCase* test_case = new TestCase();
- test_case->expected = Tensor(DT_FLOAT, shape);
+ test_case->expected = Tensor(data_type_, shape);
if (reduction_op == ncclProd) {
- test::FillFn<float>(&test_case->expected, [](int) { return 1; });
+ test::FillFn<Scalar>(&test_case->expected,
+ [](int) { return static_cast<Scalar>(1); });
} else if (reduction_op == ncclSum) {
- test::FillFn<float>(&test_case->expected, [](int) { return 0; });
+ test::FillFn<Scalar>(&test_case->expected,
+ [](int) { return static_cast<Scalar>(0); });
} else if (reduction_op == ncclMax) {
- test::FillFn<float>(&test_case->expected, [](int) {
- return -1 * std::numeric_limits<float>::max();
- });
+ test::FillFn<Scalar>(&test_case->expected, [](int) { return -max_; });
} else if (reduction_op == ncclMin) {
- test::FillFn<float>(&test_case->expected, [](int) {
- return std::numeric_limits<float>::max();
- });
+ test::FillFn<Scalar>(&test_case->expected, [](int) { return max_; });
} else {
LOG(FATAL) << "Invalid reduction_op " << reduction_op;
}
- int mult = 1;
- for (int i = 0; i < num_ranks; ++i) {
- auto* device = devices->at(i % devices->size());
+ float value_scale = 0.01; // Small scale to avoid fp16 overflow.
+ for (int rank = 0; rank < num_ranks; ++rank) {
+ auto* device = GetDevice(rank);
auto* stream = device->tensorflow_gpu_device_info()->stream;
- Tensor in_cpu(DT_FLOAT, shape);
- test::FillFn<float>(&in_cpu, [mult, value_offset](int index) {
- return value_offset + (index + 1) * mult;
+ Tensor in_cpu(data_type_, shape);
+ test::FillFn<Scalar>(&in_cpu, [&](int index) {
+ return static_cast<Scalar>((index + 1) * value_scale + value_offset);
});
for (int j = 0; j < shape.num_elements(); ++j) {
- auto in_val = in_cpu.flat<float>()(j);
- auto out_expr = test_case->expected.flat<float>();
+ auto in_val = in_cpu.flat<Scalar>()(j);
+ auto out_expr = test_case->expected.template flat<Scalar>();
if (reduction_op == ncclProd) {
- out_expr(j) *= in_val;
+ out_expr(j) = out_expr(j) * in_val;
} else if (reduction_op == ncclSum) {
- out_expr(j) += in_val;
+ out_expr(j) = out_expr(j) + in_val;
} else if (reduction_op == ncclMax) {
if (in_val > out_expr(j)) {
out_expr(j) = in_val;
@@ -136,26 +120,18 @@ class NcclManagerTest : public ::testing::Test {
}
}
- mult *= 10;
- test_case->ins.emplace_back(gpu_allocator(device), DT_FLOAT, shape);
- test_case->outs.emplace_back(gpu_allocator(device), DT_FLOAT, shape);
+ value_scale *= 10;
+ test_case->ins.emplace_back(GpuAllocator(device), data_type_, shape);
+ test_case->outs.emplace_back(GpuAllocator(device), data_type_, shape);
const Tensor& in_gpu = test_case->ins.back();
- auto in_gpu_mem = AsDeviceMemory(in_gpu.flat<float>().data());
- stream->ThenMemcpy(&in_gpu_mem, in_cpu.flat<float>().data(),
+ auto in_gpu_mem = AsDeviceMemory(in_gpu.flat<Scalar>().data());
+ stream->ThenMemcpy(&in_gpu_mem, in_cpu.flat<Scalar>().data(),
in_cpu.TotalBytes());
}
return test_case;
}
- NcclManager::DoneCallback CreateDoneCallback(TestCase* test_case) {
- return [this, test_case](Status s) {
- mutex_lock l(test_case->mu);
- ++test_case->num_completed;
- test_case->final_status.Update(s);
- };
- }
-
void VerifyResults(const string& case_label, TestCase* test_case) {
// Wait for the done callback to be called.
{
@@ -168,41 +144,84 @@ class NcclManagerTest : public ::testing::Test {
test_case->mu.unlock();
}
// Copy memory to host and verify.
- for (int i = 0; i < test_case->outs.size(); ++i) {
- auto* device = devices->at(i % devices->size());
+ for (int rank = 0; rank < test_case->outs.size(); ++rank) {
+ auto* device = GetDevice(rank);
auto* stream = device->tensorflow_gpu_device_info()->stream;
- const Tensor& out_gpu = test_case->outs[i];
- Tensor out_cpu(DT_FLOAT, out_gpu.shape());
- auto out_gpu_mem = AsDeviceMemory(out_gpu.flat<float>().data());
- stream->ThenMemcpy(out_cpu.flat<float>().data(), out_gpu_mem,
+ const Tensor& out_gpu = test_case->outs[rank];
+ Tensor out_cpu(data_type_, out_gpu.shape());
+ auto out_gpu_mem = AsDeviceMemory(out_gpu.flat<Scalar>().data());
+ stream->ThenMemcpy(out_cpu.flat<Scalar>().data(), out_gpu_mem,
out_cpu.TotalBytes());
SE_ASSERT_OK(stream->BlockHostUntilDone());
- test::ExpectTensorEqual<float>(test_case->expected, out_cpu);
+ test::ExpectTensorNear<Scalar>(test_case->expected, out_cpu, 0.01);
}
}
+
+ NcclManager::DoneCallback CreateDoneCallback(TestCase* test_case) {
+ return [this, test_case](Status s) {
+ mutex_lock l(test_case->mu);
+ ++test_case->num_completed;
+ test_case->final_status.Update(s);
+ };
+ }
+
+ static BaseGPUDevice* GetDevice(size_t rank) {
+ return devices_->at(rank % devices_->size());
+ }
+
+ private:
+ static Allocator* GpuAllocator(BaseGPUDevice* device) {
+ return device->GetStepAllocator(AllocatorAttributes(),
+ nullptr /* step_resource_manager */);
+ }
+
+ static perftools::gputools::DeviceMemory<Scalar> AsDeviceMemory(
+ const Scalar* cuda_memory) {
+ perftools::gputools::DeviceMemoryBase wrapped(
+ const_cast<Scalar*>(cuda_memory));
+ perftools::gputools::DeviceMemory<Scalar> typed(wrapped);
+ return typed;
+ }
+
+ private:
+ static std::vector<BaseGPUDevice*>* devices_;
+ static const DataType data_type_;
+ static const Scalar max_;
};
-std::vector<BaseGPUDevice*>* NcclManagerTest::devices = nullptr;
+
+template <typename Scalar>
+std::vector<BaseGPUDevice*>* NcclManagerTest<Scalar>::devices_ = nullptr;
+template <typename Scalar>
+const DataType NcclManagerTest<Scalar>::data_type_ =
+ DataTypeToEnum<Scalar>::value;
+template <typename Scalar>
+const Scalar NcclManagerTest<Scalar>::max_ =
+ Eigen::NumTraits<Scalar>::highest();
+
+// Instantiate tests for float and half.
+using TypeList = ::testing::Types<float, Eigen::half>;
+TYPED_TEST_CASE(NcclManagerTest, TypeList);
// Test basic sum reduction.
-TEST_F(NcclManagerTest, BasicSumReduction) {
+TYPED_TEST(NcclManagerTest, BasicSumReduction) {
const int num_ranks = 3;
for (int op = 0; op < 4; ++op) {
ncclRedOp_t reduction_op = static_cast<ncclRedOp_t>(op);
- std::unique_ptr<TestCase> test_case(
- MakeTestCase(num_ranks, reduction_op, TensorShape({2, 3}), 0));
- for (int device_num = 0; device_num < num_ranks; ++device_num) {
- auto* device = devices->at(device_num % devices->size());
+ std::unique_ptr<typename TestFixture::TestCase> test_case(
+ this->MakeTestCase(num_ranks, reduction_op, TensorShape({2, 3}), 0.0f));
+ for (int rank = 0; rank < num_ranks; ++rank) {
+ auto* device = this->GetDevice(rank);
auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
auto* stream = device->tensorflow_gpu_device_info()->stream;
NcclManager::instance()->AddToAllReduce(
num_ranks, "allreduce", reduction_op, device->executor(),
- device->gpu_id(), event_mgr, stream, &test_case->ins[device_num],
- &test_case->outs[device_num], CreateDoneCallback(test_case.get()));
+ device->gpu_id(), event_mgr, stream, &test_case->ins[rank],
+ &test_case->outs[rank], this->CreateDoneCallback(test_case.get()));
}
LOG(ERROR) << "Verifying results";
- VerifyResults("test_case", test_case.get());
+ this->VerifyResults("test_case", test_case.get());
}
}
@@ -213,7 +232,7 @@ TEST_F(NcclManagerTest, BasicSumReduction) {
// with num_ranks > devices->size(), for some GPUs (e.g. K20m).
// To test the higher settings, increase num_ranks,
// num_collectives_per_iteration and time_limit_micros.
-TEST_F(NcclManagerTest, MultipleCallers) {
+TYPED_TEST(NcclManagerTest, MultipleCallers) {
const int num_ranks = 1; // 2;
const int num_collectives_per_iteration = 1; // 1000;
const int num_threads = 3;
@@ -223,49 +242,49 @@ TEST_F(NcclManagerTest, MultipleCallers) {
srand(Env::Default()->NowMicros());
for (;;) {
- std::vector<std::pair<int, int>> case_and_device_num;
- std::vector<std::unique_ptr<TestCase>> test_cases;
+ std::vector<std::pair<int, int>> case_and_rank;
+ std::vector<std::unique_ptr<typename TestFixture::TestCase>> test_cases;
for (int i = 0; i < num_collectives_per_iteration; ++i) {
- test_cases.emplace_back(
- MakeTestCase(num_ranks, ncclSum,
- TensorShape({100, i % 5 + 1, i % 3 + 1}), i + 0.1 * i));
+ test_cases.emplace_back(this->MakeTestCase(
+ num_ranks, ncclSum, TensorShape({100, i % 5 + 1, i % 3 + 1}),
+ 1.1f * i));
for (int j = 0; j < num_ranks; ++j) {
- case_and_device_num.emplace_back(i, j);
+ case_and_rank.emplace_back(i, j);
}
}
- for (int i = 0; i < num_ranks; ++i) {
- auto* device = devices->at(i % devices->size());
+ for (int rank = 0; rank < num_ranks; ++rank) {
+ auto* device = this->GetDevice(rank);
auto* stream = device->tensorflow_gpu_device_info()->stream;
SE_ASSERT_OK(stream->BlockHostUntilDone());
}
- std::shuffle(case_and_device_num.begin(), case_and_device_num.end(),
+ std::shuffle(case_and_rank.begin(), case_and_rank.end(),
std::mt19937(std::random_device()()));
- mutex mu; // guards case_and_device_num.
+ mutex mu; // guards case_and_rank.
std::unique_ptr<thread::ThreadPool> pool(
new thread::ThreadPool(Env::Default(), "test", num_threads));
- const int to_schedule = case_and_device_num.size();
+ const int to_schedule = case_and_rank.size();
for (int i = 0; i < to_schedule; ++i) {
auto fn = [&]() {
- int device_num;
+ int rank;
int test_num;
{
mutex_lock l(mu);
- test_num = case_and_device_num.back().first;
- device_num = case_and_device_num.back().second;
- case_and_device_num.pop_back();
+ test_num = case_and_rank.back().first;
+ rank = case_and_rank.back().second;
+ case_and_rank.pop_back();
}
- auto* device = devices->at(device_num % devices->size());
+ auto* device = this->GetDevice(rank);
auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
auto* stream = device->tensorflow_gpu_device_info()->stream;
- TestCase* test_case = test_cases[test_num].get();
+ typename TestFixture::TestCase* test_case = test_cases[test_num].get();
NcclManager::instance()->AddToAllReduce(
num_ranks, strings::StrCat("allreduce", test_num), ncclSum,
device->executor(), device->gpu_id(), event_mgr, stream,
- &test_case->ins[device_num], &test_case->outs[device_num],
- CreateDoneCallback(test_case));
+ &test_case->ins[rank], &test_case->outs[rank],
+ this->CreateDoneCallback(test_case));
};
pool->Schedule(fn);
}
@@ -274,7 +293,8 @@ TEST_F(NcclManagerTest, MultipleCallers) {
LOG(ERROR) << "Verifying results for " << num_collectives_per_iteration
<< " collectives";
for (int i = 0; i < test_cases.size(); ++i) {
- VerifyResults(strings::StrCat("collective", i), test_cases[i].get());
+ this->VerifyResults(strings::StrCat("collective", i),
+ test_cases[i].get());
}
int64 delta = Env::Default()->NowMicros() - start;
diff --git a/tensorflow/contrib/nccl/ops/nccl_ops.cc b/tensorflow/contrib/nccl/ops/nccl_ops.cc
index 8eb804c2e9..a353a34b80 100644
--- a/tensorflow/contrib/nccl/ops/nccl_ops.cc
+++ b/tensorflow/contrib/nccl/ops/nccl_ops.cc
@@ -25,7 +25,7 @@ REGISTER_OP("NcclAllReduce")
.Input("input: T")
.Output("data: T")
.Attr("reduction: {'min', 'max', 'prod', 'sum'}")
- .Attr("T: {float, float64, int32, int64}")
+ .Attr("T: {half, float, float64, int32, int64}")
.Attr("num_devices: int")
.Attr("shared_name: string")
.SetIsStateful()
@@ -51,7 +51,7 @@ REGISTER_OP("NcclReduce")
.Input("input: num_devices * T")
.Output("data: T")
.Attr("reduction: {'min', 'max', 'prod', 'sum'}")
- .Attr("T: {float, float64, int32, int64}")
+ .Attr("T: {half, float, float64, int32, int64}")
.Attr("num_devices: int")
.SetIsStateful()
.SetShapeFn(shape_inference::UnchangedShape)
@@ -69,7 +69,7 @@ reduction: the reduction operation to perform.
REGISTER_OP("_NcclReduceSend")
.Input("input: T")
.Attr("reduction: {'min', 'max', 'prod', 'sum'}")
- .Attr("T: {float, float64, int32, int64}")
+ .Attr("T: {half, float, float64, int32, int64}")
.Attr("num_devices: int")
.Attr("shared_name: string")
.SetIsStateful()
@@ -92,7 +92,7 @@ REGISTER_OP("_NcclReduceRecv")
.Input("input: T")
.Output("data: T")
.Attr("reduction: {'min', 'max', 'prod', 'sum'}")
- .Attr("T: {float, float64, int32, int64}")
+ .Attr("T: {half, float, float64, int32, int64}")
.Attr("num_devices: int")
.Attr("shared_name: string")
.SetIsStateful()
@@ -118,7 +118,7 @@ shared_name: Identifier that is shared between ops of the same reduce.
REGISTER_OP("NcclBroadcast")
.Input("input: T")
.Output("output: T")
- .Attr("T: {float, float64, int32, int64}")
+ .Attr("T: {half, float, float64, int32, int64}")
.Attr("shape: shape")
.SetIsStateful()
.SetShapeFn(shape_inference::UnchangedShape)
@@ -135,7 +135,7 @@ shape: The shape of the input tensor.
REGISTER_OP("_NcclBroadcastSend")
.Input("input: T")
- .Attr("T: {float, float64, int32, int64}")
+ .Attr("T: {half, float, float64, int32, int64}")
.Attr("num_devices: int")
.Attr("shared_name: string")
.SetIsStateful()
@@ -157,7 +157,7 @@ shared_name: Identifier that is shared between ops of the same broadcast.
REGISTER_OP("_NcclBroadcastRecv")
.Input("shape: int32")
.Output("output: T")
- .Attr("T: {float, float64, int32, int64}")
+ .Attr("T: {half, float, float64, int32, int64}")
.Attr("num_devices: int")
.Attr("shared_name: string")
.SetIsStateful()
diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
index 98fe394c5b..423a8689ae 100644
--- a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
+++ b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
@@ -72,7 +72,7 @@ class NcclTestCase(test.TestCase):
two.
device_sets: Tuple of virtual devices to run test on.
"""
- for dtype in [np.float32, np.int32, np.int64, np.float64]:
+ for dtype in [np.float16, np.float32, np.int32, np.int64, np.float64]:
# Create session inside outer loop to test use of
# same communicator across multiple sessions.
with self.test_session(use_gpu=True) as sess:
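
With DT_HALF mapped to ncclHalf and half added to the op registrations above, float16 all-reduces become expressible from Python. A rough sketch, assuming the existing tf.contrib.nccl.all_sum helper and two visible GPUs:

import numpy as np
import tensorflow as tf
from tensorflow.contrib import nccl

# One float16 tensor per GPU; all_sum produces one reduced tensor per device.
tensors = []
for i in range(2):
  with tf.device('/gpu:%d' % i):
    tensors.append(tf.constant(np.ones((4, 4), dtype=np.float16)))
summed = nccl.all_sum(tensors)

with tf.Session() as sess:
  print(sess.run(summed))  # each entry is 2.0, reduced in half precision
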
diff --git a/tensorflow/contrib/proto/BUILD b/tensorflow/contrib/proto/BUILD
new file mode 100644
index 0000000000..046652cbc5
--- /dev/null
+++ b/tensorflow/contrib/proto/BUILD
@@ -0,0 +1,16 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "proto",
+ srcs = [
+ "__init__.py",
+ ],
+ deps = [
+ "//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
+ "//tensorflow/contrib/proto/python/ops:encode_proto_op_py",
+ ],
+)
diff --git a/tensorflow/contrib/proto/__init__.py b/tensorflow/contrib/proto/__init__.py
new file mode 100644
index 0000000000..bc5a49de78
--- /dev/null
+++ b/tensorflow/contrib/proto/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops and modules related to proto.
+
+@@decode_proto
+@@encode_proto
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.proto.python.ops.decode_proto_op import decode_proto
+from tensorflow.contrib.proto.python.ops.encode_proto_op import encode_proto
+
+from tensorflow.python.util.all_util import remove_undocumented
+remove_undocumented(__name__)
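
For orientation, using the new decode op would look roughly like the sketch below. The message type and field name are hypothetical, and the argument names assume the (bytes, message_type, field_names, output_types) signature of the generated wrapper, so treat this as illustrative rather than definitive:

import tensorflow as tf
from tensorflow.contrib.proto.python.ops import decode_proto_op

serialized = tf.placeholder(tf.string, shape=[None])  # serialized messages
# Hypothetical message type and field; the proto descriptor must be linked in.
sizes, values = decode_proto_op.decode_proto(
    bytes=serialized,
    message_type='example.MyMessage',
    field_names=['value'],
    output_types=[tf.float32])
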
diff --git a/tensorflow/contrib/proto/python/ops/BUILD b/tensorflow/contrib/proto/python/ops/BUILD
new file mode 100644
index 0000000000..f17065477e
--- /dev/null
+++ b/tensorflow/contrib/proto/python/ops/BUILD
@@ -0,0 +1,44 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_gen_op_wrapper_py",
+)
+
+py_library(
+ name = "decode_proto_op_py",
+ srcs = ["decode_proto_op.py"],
+ deps = [
+ ":gen_decode_proto_op_py",
+ "//tensorflow/python:framework_ops",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_decode_proto_op_py",
+ out = "gen_decode_proto_op.py",
+ deps = [
+ "//tensorflow/core:decode_proto_ops_op_lib",
+ ],
+)
+
+py_library(
+ name = "encode_proto_op_py",
+ srcs = ["encode_proto_op.py"],
+ deps = [
+ ":gen_encode_proto_op_py",
+ "//tensorflow/python:framework_ops",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_encode_proto_op_py",
+ out = "gen_encode_proto_op.py",
+ deps = [
+ "//tensorflow/core:encode_proto_ops_op_lib",
+ ],
+)
diff --git a/tensorflow/contrib/proto/python/ops/decode_proto_op.py b/tensorflow/contrib/proto/python/ops/decode_proto_op.py
new file mode 100644
index 0000000000..7dc000ebe4
--- /dev/null
+++ b/tensorflow/contrib/proto/python/ops/decode_proto_op.py
@@ -0,0 +1,25 @@
+# =============================================================================
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+# pylint: disable=wildcard-import,unused-import
+"""Protocol Buffer decoding from tensors."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.proto.python.ops.gen_decode_proto_op import decode_proto_v2 as decode_proto
+from tensorflow.python.framework import ops
+ops.NotDifferentiable("DecodeProtoV2")
diff --git a/tensorflow/contrib/proto/python/ops/encode_proto_op.py b/tensorflow/contrib/proto/python/ops/encode_proto_op.py
new file mode 100644
index 0000000000..ac12198b2e
--- /dev/null
+++ b/tensorflow/contrib/proto/python/ops/encode_proto_op.py
@@ -0,0 +1,25 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+# pylint: disable=wildcard-import,unused-import
+"""Protocol Buffer encoding from tensors."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.proto.python.ops.gen_encode_proto_op import encode_proto
+from tensorflow.python.framework import ops
+
+ops.NotDifferentiable("EncodeProto")
diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md
index 348c824a40..c83623ec94 100644
--- a/tensorflow/contrib/quantize/README.md
+++ b/tensorflow/contrib/quantize/README.md
@@ -2,14 +2,17 @@
tf.contrib.quantize provides tools for transforming graphs to include ops to
model quantization of weights, biases and activations during both training and
-inference. This is done using the
+inference. The details of the transformation implemented in this package are
+described in [1].
+
+This is done using the
[fake quantization op](https://www.tensorflow.org/versions/r0.12/api_docs/python/array_ops/fake_quantization).
-Recent literature has shown that fixed point networks provide comparable
-performance to floating point networks [1]. This is achieved by modeling the
-quantization operation during training in both the forward and backward passes.
+Literature has shown that fixed point networks provide comparable performance to
+floating point networks [2]. This is achieved by modeling the quantization
+operation during training in both the forward and backward passes.
The fake quantization operator achieves this by modeling the quantizer as a pass
-through estimator [2]. Note that during back propagation, the parameters are
+through estimator [3]. Note that during back propagation, the parameters are
updated at high precision as this is needed to ensure sufficient precision in
accumulating tiny adjustments to the parameters. However, for the forward pass,
the parameters and activations are quantized to the desired lower precision.
@@ -61,9 +64,11 @@ These rewrites are an active area of research and experimentation, so the
rewrites and quantized training will likely not work across all models, though
we hope to work towards generalizing these techniques.
+[1] B.Jacob et al., "Quantization and Training of Neural Networks for Efficient
+Integer-Arithmetic-Only Inference", https://arxiv.org/abs/1712.05877
-[1] P.Gysel, "HARDWARE-ORIENTED APPROXIMATION OF CONVOLUTIONAL
+[2] P.Gysel et al., "HARDWARE-ORIENTED APPROXIMATION OF CONVOLUTIONAL
NEURAL NETWORKS", https://arxiv.org/pdf/1604.03168.pdf
-[2] Y.Bengio, "Estimating or Propagating Gradients Through Stochastic Neurons
-for Conditional Computation", https://arxiv.org/abs/1308.3432
+[3] Y.Bengio et al., "Estimating or Propagating Gradients Through Stochastic
+Neurons for Conditional Computation", https://arxiv.org/abs/1308.3432
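
To make the fake-quantization idea above concrete: in the forward pass values are rounded to the nearest of 2^bits levels in [min; max], while gradients flow through unchanged. A minimal sketch using the fake quantization op referenced above:

import tensorflow as tf

x = tf.constant([-1.2, 0.0, 0.37, 5.8])
# Simulate 8-bit quantization into [-6, 6]; backprop treats this op as identity.
x_q = tf.fake_quant_with_min_max_args(x, min=-6.0, max=6.0, num_bits=8)

with tf.Session() as sess:
  print(sess.run(x_q))  # values snapped to the nearest of 256 levels
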
diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py
index a4f7b1b221..5c0e17dc86 100644
--- a/tensorflow/contrib/quantize/python/quant_ops.py
+++ b/tensorflow/contrib/quantize/python/quant_ops.py
@@ -51,7 +51,6 @@ def LastValueQuantize(inputs,
per_channel=False,
init_min=-6.0,
init_max=6.0,
- updates_collection=ops.GraphKeys.UPDATE_OPS,
vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
name_prefix='LastValueQuant',
reuse=None,
@@ -69,8 +68,6 @@ def LastValueQuantize(inputs,
quantization ranges per output channel.
init_min: a float scalar, the initial value for variable min.
init_max: a float scalar, the initial value for variable max.
- updates_collection: (Optional) collections to collect the update ops for
- computation.
vars_collection: (Optional) collection where to store variables for
quantization interval ends.
name_prefix: name_prefix for created nodes.
@@ -133,7 +130,6 @@ def LastValueQuantize(inputs,
# TFLite requires that 0.0 is always in the [min; max] range.
batch_min = math_ops.minimum(batch_min, 0.0)
assign_min = state_ops.assign(min_var, batch_min, name='AssignMinLast')
- ops.add_to_collection(updates_collection, assign_min.op)
if per_channel:
if input_dim >= 2:
@@ -146,7 +142,6 @@ def LastValueQuantize(inputs,
# TFLite requires that 0.0 is always in the [min; max] range.
batch_max = math_ops.maximum(batch_max, 0.0)
assign_max = state_ops.assign(max_var, batch_max, name='AssignMaxLast')
- ops.add_to_collection(updates_collection, assign_max.op)
return _FakeQuantWithMinMaxVars(
inputs,
@@ -163,7 +158,6 @@ def MovingAvgQuantize(inputs,
init_min=-6.0,
init_max=6.0,
ema_decay=0.999,
- updates_collection=ops.GraphKeys.UPDATE_OPS,
vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
name_prefix='MovingAvgQuantize',
reuse=None,
@@ -182,8 +176,6 @@ def MovingAvgQuantize(inputs,
init_min: a float scalar, the initial value for variable min.
init_max: a float scalar, the initial value for variable max.
ema_decay: EMA decay parameter.
- updates_collection: (Optional) collections to collect the update ops for
- computation.
vars_collection: (Optional) collection where to store variables for
quantization interval ends.
name_prefix: name_prefix for created nodes.
@@ -246,7 +238,6 @@ def MovingAvgQuantize(inputs,
batch_min = math_ops.minimum(batch_min, 0.0)
assign_min = moving_averages.assign_moving_average(
min_var, batch_min, ema_decay, name='AssignMinEma')
- ops.add_to_collection(updates_collection, assign_min.op)
if per_channel:
if input_dim >= 2:
@@ -260,7 +251,6 @@ def MovingAvgQuantize(inputs,
batch_max = math_ops.maximum(batch_max, 0.0)
assign_max = moving_averages.assign_moving_average(
max_var, batch_max, ema_decay, name='AssignMaxEma')
- ops.add_to_collection(updates_collection, assign_max.op)
return _FakeQuantWithMinMaxVars(
inputs,
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index d53d4d7b10..d2d0426d23 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -27,6 +27,7 @@ from tensorflow.contrib.quantize.python import quant_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import tf_logging as logging
# Quantizable operation types that are supported by the quantization rewrite.
_QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'}
@@ -41,9 +42,16 @@ def Quantize(graph,
activation_bits=8,
ema_decay=0.999,
quant_delay=None,
- vars_collection=ops.GraphKeys.GLOBAL_VARIABLES):
+ vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
+ scope=None):
"""Updates graph with quantization operations.
+ Currently we quantize the following tensors:
+ * Conv/MatMul: Quantize the weights if the layer matches a quantizable pattern.
+ * Activation: Quantize the output if the layer matches.
+ * Bypass/Post-activation Bypass: Quantize both the input and output if the
+   layer matches.
+
Args:
graph: Graph to modify.
is_training: Whether quantizing training graph or eval graph.
@@ -57,13 +65,21 @@ def Quantize(graph,
training.
vars_collection: (Optional) Collection where to store the variables for
quantization interval ends.
+ scope: The scope to be transformed. If it's not None, only the ops which
+ are in this scope will be transformed.
Raises:
ValueError: When quantization fails.
"""
+ if scope and not scope.endswith('/'):
+ scope += '/'
+
input_to_ops_map = input_to_ops.InputToOps(graph)
for layer_match in _FindLayersToQuantize(graph):
# Quantize the weights.
context = _GetContextFromOp(layer_match.layer_op)
+
+ # If `scope` is given, only quantize it if the consumer of weights
+ # (the layer op) is in the right scope.
_InsertQuantOp(
context,
'weights_quant',
@@ -74,7 +90,8 @@ def Quantize(graph,
quant_delay=quant_delay,
narrow_range=True,
vars_collection=vars_collection,
- bits=weight_bits)
+ bits=weight_bits,
+ consumer_scope=scope)
# Quantize the activations.
consumer_ops = input_to_ops_map.ConsumerOperations(
@@ -82,6 +99,9 @@ def Quantize(graph,
add_context = context
if layer_match.bypass_op:
add_context = re.search(r'^(.*)/([^/]+)', context).group(1)
+
+ # If `scope` is given, only quantize it if the producer of the activation
+ # (usually the layer op) is in the right scope.
_InsertQuantOp(
add_context,
'act_quant',
@@ -93,11 +113,14 @@ def Quantize(graph,
quant_delay=quant_delay,
vars_collection=vars_collection,
bits=activation_bits,
- init_min=0.0)
+ init_min=0.0,
+ producer_scope=scope)
# Quantize the inputs and output to the bypass (if it exists). The input to
# the bypass is the bias add, and the output is the activation.
if layer_match.bypass_op is not None:
+ # If `scope` is given, only quantize it if both the producer and the
+ # consumer are in the right scope.
_InsertQuantOp(
context,
'conv_quant',
@@ -107,7 +130,9 @@ def Quantize(graph,
ema_decay=ema_decay,
quant_delay=quant_delay,
vars_collection=vars_collection,
- bits=activation_bits)
+ bits=activation_bits,
+ producer_scope=scope,
+ consumer_scope=scope)
_InsertQuantOp(
add_context,
'add_quant',
@@ -118,12 +143,16 @@ def Quantize(graph,
ema_decay=ema_decay,
quant_delay=quant_delay,
vars_collection=vars_collection,
- bits=activation_bits)
+ bits=activation_bits,
+ producer_scope=scope,
+ consumer_scope=scope)
# Quantize bypass ops that occur after the activation.
if layer_match.post_activation_bypass_op is not None:
post_activation_bypass_context = re.search(
r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name).group(1)
+ # If `scope` is given, only quantize it if the producer is in the right
+ # scope.
_InsertQuantOp(
post_activation_bypass_context,
'post_activation_bypass_quant',
@@ -135,7 +164,8 @@ def Quantize(graph,
ema_decay=ema_decay,
quant_delay=quant_delay,
vars_collection=vars_collection,
- bits=activation_bits)
+ bits=activation_bits,
+ producer_scope=scope)
def _FindLayersToQuantize(graph):
@@ -382,7 +412,9 @@ def _InsertQuantOp(context,
ema_decay=0.999,
quant_delay=None,
vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
- narrow_range=False):
+ narrow_range=False,
+ producer_scope=None,
+ consumer_scope=None):
"""Inserts a quant op between a producer op and (multiple) consumer ops.
Args:
@@ -407,10 +439,34 @@ def _InsertQuantOp(context,
quantization interval ends.
narrow_range: Whether to use the narrow quantization range
[1; 2^bits - 1] or wide range [0; 2^bits - 1].
+ producer_scope: The restriction of producer scope. If not None, the new op
+ will be inserted only when the producer is in this scope.
+ consumer_scope: The restriction of consumer scope. If not None, the new op
+ will be inserted only when all the consumers are in this scope.
Raises:
ValueError: When producer operation is not directly connected to the
consumer operation.
"""
+ if producer_scope and not producer.name.startswith(producer_scope):
+ logging.info(
+ '_InsertQuantOp ignores context="%s" name="%s" '
+ 'because producer "%s" is not in scope "%s"',
+ context, name, producer.name, producer_scope)
+ return
+
+ if consumer_scope:
+ consumers_in_scope = []
+ for consumer in consumers:
+ if consumer.name.startswith(consumer_scope):
+ consumers_in_scope.append(consumer)
+ else:
+ logging.info(
+ '_InsertQuantOp context="%s" name="%s" ignores '
+ 'consumer "%s" because it is not in scope "%s"',
+ context, name, consumer.name, consumer_scope)
+ return
+ consumers = consumers_in_scope
+
name_prefix = _AddContextToName(context, name)
# This is needed on TPU where name_scope == 'TPUReplicate/loop', and
# name_prefix starts with 'TPUReplicate/loop/'; without dropping it
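
The producer_scope/consumer_scope checks added above boil down to name-prefix tests on op names. An equivalent standalone sketch (the helper names are illustrative, not part of the change):

def _in_scope(op_name, scope):
  """True when no scope restriction is set or op_name lives under it."""
  return not scope or op_name.startswith(scope)

def _should_insert_quant_op(producer, consumers,
                            producer_scope=None, consumer_scope=None):
  # _InsertQuantOp skips insertion when the producer is out of scope, or
  # when any consumer falls outside consumer_scope.
  if not _in_scope(producer.name, producer_scope):
    return False
  return all(_in_scope(c.name, consumer_scope) for c in consumers)
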
diff --git a/tensorflow/contrib/quantize/python/quantize_graph.py b/tensorflow/contrib/quantize/python/quantize_graph.py
index 0b74b438ac..11d052d7f4 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph.py
@@ -28,7 +28,8 @@ def _create_graph(input_graph=None,
weight_bits=8,
activation_bits=8,
quant_delay=None,
- freeze_bn_delay=None):
+ freeze_bn_delay=None,
+ scope=None):
"""Rewrites an input_graph in place for simulated quantization.
The graph has fake quantization ops inserted to simulate the error
@@ -48,6 +49,8 @@ def _create_graph(input_graph=None,
frozen and used instead of batch statistics during training.
freeze_bn_delay should be greater than quant_delay and should correspond
to the number of steps when training has almost converged
+ scope: The scope to be transformed. If it's not None, only the ops which
+ are in this scope will be transformed.
Raises:
ValueError: If elements contains an element that isn't a tf.Tensor or
@@ -66,7 +69,8 @@ def _create_graph(input_graph=None,
is_training,
quant_delay=quant_delay,
weight_bits=weight_bits,
- activation_bits=activation_bits)
+ activation_bits=activation_bits,
+ scope=scope)
def create_training_graph(input_graph=None, quant_delay=0):
@@ -133,7 +137,8 @@ def experimental_create_training_graph(input_graph=None,
weight_bits=8,
activation_bits=8,
quant_delay=0,
- freeze_bn_delay=None):
+ freeze_bn_delay=None,
+ scope=None):
"""Rewrites a training input_graph in place for simulated quantization.
Variables added by the rewrite get added to the global variables collection.
@@ -165,6 +170,8 @@ def experimental_create_training_graph(input_graph=None,
frozen and used instead of batch statistics during training.
freeze_bn_delay should be greater than quant_delay and should correspond
to when training has almost converged
+ scope: The scope to be transformed. If it's not None, only the ops which
+ are in this scope will be transformed.
Raises:
ValueError: If elements contains an element that isn't a tf.Tensor or
@@ -177,12 +184,14 @@ def experimental_create_training_graph(input_graph=None,
weight_bits=weight_bits,
activation_bits=activation_bits,
quant_delay=quant_delay,
- freeze_bn_delay=freeze_bn_delay)
+ freeze_bn_delay=freeze_bn_delay,
+ scope=scope)
def experimental_create_eval_graph(input_graph=None,
weight_bits=8,
- activation_bits=8):
+ activation_bits=8,
+ scope=None):
"""Rewrites an eval input_graph in place for simulated quantization.
Variables added by the rewrite get added to the global variables collection.
@@ -200,8 +209,8 @@ def experimental_create_eval_graph(input_graph=None,
default graph.
weight_bits: Number of bits to use for quantizing weights.
activation_bits: Number of bits to use for quantizing activations.
-
-
+ scope: The scope to be transformed. If it's not None, only the ops which
+ are in this scope will be transformed.
Raises:
ValueError: If elements contains an element that isn't a tf.Tensor or
@@ -211,4 +220,5 @@ def experimental_create_eval_graph(input_graph=None,
input_graph=input_graph,
is_training=False,
weight_bits=weight_bits,
- activation_bits=activation_bits)
+ activation_bits=activation_bits,
+ scope=scope)
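
Putting the new scope argument to use, a rewrite restricted to a single name scope would look roughly like this (a sketch; 'tower0' is a hypothetical scope name and the layer bodies are elided):

import tensorflow as tf
from tensorflow.contrib.quantize.python import quantize_graph

g = tf.Graph()
with g.as_default():
  with tf.name_scope('tower0'):
    pass  # layers built here receive weight/activation fake-quant ops
  with tf.name_scope('tower1'):
    pass  # layers built here are left untouched by the rewrite

# Only ops whose names start with 'tower0/' are transformed.
quantize_graph.experimental_create_training_graph(input_graph=g, scope='tower0/')
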
diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py
index b9d03c1bc0..caf8ff28d5 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py
@@ -66,6 +66,20 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
for fn in rewrite_fns:
test_fn(fn)
+ def _RunTestOverExperimentalRewritesWithScope(self, test_fn, scope):
+ def with_absent_scope(fn):
+ def fn_with_absent_scope(*args):
+ fn(*args, scope=scope)
+ return fn_with_absent_scope
+ rewrite_fns = [
+ with_absent_scope(
+ quantize_graph.experimental_create_training_graph),
+ with_absent_scope(
+ quantize_graph.experimental_create_eval_graph),
+ ]
+ for fn in rewrite_fns:
+ test_fn(fn)
+
def testRewrite(self):
self._RunTestOverAllRewrites(self._TestRewrite)
@@ -99,6 +113,34 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
# Ensure that variables were added.
self.assertTrue(len(orig_variable_names) < len(q_variables))
+ def testWithPreActivationBypass(self):
+ self._RunTestOverAllRewrites(self._TestWithPreActivationBypass)
+
+ def _TestWithPreActivationBypass(self, rewrite_fn):
+ # Tests that the default graph is correctly used when no args are provided
+ # to rewrite_fn.
+ with ops.Graph().as_default() as g:
+ self._ConvLayer(pre_activation_bypass=True, scope='scope1')
+ rewrite_fn()
+
+ op_names = [op.name for op in g.get_operations()]
+ self.assertTrue(
+ any('scope1/add_quant/' in name for name in op_names))
+
+ def testWithPostActivationBypass(self):
+ self._RunTestOverAllRewrites(self._TestWithPostActivationBypass)
+
+ def _TestWithPostActivationBypass(self, rewrite_fn):
+ # Tests that the default graph is correctly used when no args are provided
+ # to rewrite_fn.
+ with ops.Graph().as_default() as g:
+ self._ConvLayer(post_activation_bypass=True, scope='scope1')
+ rewrite_fn()
+
+ op_names = [op.name for op in g.get_operations()]
+ self.assertTrue(any(
+ 'scope1/post_activation_bypass_quant/' in name for name in op_names))
+
def testQuantDelay(self):
self._RunTestOverTrainingRewrites(self._TestQuantDelay)
@@ -224,20 +266,66 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
graph_def_after = str(g.as_graph_def())
self.assertEqual(graph_def_before, graph_def_after)
- def _ConvLayer(self):
+ def testRewriteWithScope(self):
+ self._RunTestOverExperimentalRewritesWithScope(
+ self._TestRewriteWithScope, 'scope1')
+
+ def _TestRewriteWithScope(self, rewrite_fn):
+ graph = ops.Graph()
+ with graph.as_default():
+ scope1_output = self._ConvLayer(scope='scope1')
+ self._ConvLayer(input_tensor=scope1_output, scope='scope2')
+
+ rewrite_fn(graph)
+
+ op_names = [op.name for op in graph.get_operations()]
+ # The weights and activations of scope1 are quantized, but not scope2's.
+ self.assertTrue(
+ any('scope1/Conv/act_quant' in name for name in op_names))
+ self.assertTrue(
+ any('scope1/Conv/weights_quant' in name for name in op_names))
+ self.assertFalse(
+ any('scope2/Conv/act_quant' in name for name in op_names))
+ self.assertFalse(
+ any('scope2/Conv/weights_quant' in name for name in op_names))
+
+ def testRewriteWithNonMatchingScope(self):
+ self._RunTestOverExperimentalRewritesWithScope(
+ self._TestRewriteWithNonMatchingScope, 'NonExistingScope')
+
+ def _TestRewriteWithNonMatchingScope(self, rewrite_fn):
+ graph = ops.Graph()
+ with graph.as_default():
+ self._ConvLayer()
+
+ op_names_before_rewrite = set([op.name for op in graph.get_operations()])
+ rewrite_fn(graph)
+ op_names_after_rewrite = set([op.name for op in graph.get_operations()])
+
+ # No ops should be inserted or removed.
+ self.assertEqual(op_names_before_rewrite, op_names_after_rewrite)
+
+ def _ConvLayer(
+ self, input_tensor=None, scope='test', pre_activation_bypass=False,
+ post_activation_bypass=False):
"""Add a basic convolution layer to the default graph."""
batch_size, height, width, depth = 5, 128, 128, 3
- inputs = array_ops.zeros((batch_size, height, width, depth))
+ if input_tensor is None:
+ input_tensor = array_ops.zeros((batch_size, height, width, depth))
weight_init = init_ops.truncated_normal_initializer
- conv = layers.conv2d(
- inputs,
- 32, [5, 5],
- stride=2,
- padding='SAME',
- weights_initializer=weight_init(0.09),
- activation_fn=None,
- scope='test')
- _ = nn_ops.relu6(conv)
+ with ops.name_scope(scope):
+ output = layers.conv2d(
+ input_tensor,
+ depth, [5, 5],
+ padding='SAME',
+ weights_initializer=weight_init(0.09),
+ activation_fn=None)
+ if pre_activation_bypass:
+ output += input_tensor
+ output = nn_ops.relu6(output)
+ if post_activation_bypass:
+ output += input_tensor
+ return output
if __name__ == '__main__':
diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py
index 8d057d3710..d37c83d683 100644
--- a/tensorflow/contrib/quantize/python/quantize_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_test.py
@@ -254,12 +254,11 @@ class QuantizeTest(test_util.TensorFlowTestCase):
graph = ops.Graph()
with graph.as_default():
with graph.name_scope(None):
- batch_size, height, width, depth = 5, 128, 128, 3
+ batch_size, height, width, depth = 5, 128, 128, 32
input1 = array_ops.zeros((batch_size, height, width, depth))
_ = conv2d(
input1,
32, [5, 5],
- stride=2,
padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=None,
@@ -268,6 +267,33 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
# Passes if Quantize() does not crash.
+ def testWithNonMatchingNameScope(self):
+ self._RunTestOverParameters(self._testWithNonMatchingNameScope)
+
+ def _testWithNonMatchingNameScope(self, is_training):
+ graph = ops.Graph()
+ with graph.as_default():
+ with graph.name_scope('name_scope'):
+ batch_size, height, width, depth = 5, 128, 128, 3
+ input1 = array_ops.zeros((batch_size, height, width, depth))
+ _ = conv2d(
+ input1,
+ 32, [5, 5],
+ stride=2,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=None,
+ scope='test')
+
+ op_names_before_quantize = set([op.name for op in graph.get_operations()])
+ quantize.Quantize(
+ graph, is_training, weight_bits=8, activation_bits=8,
+ scope='NonExisting/')
+ op_names_after_quantize = set([op.name for op in graph.get_operations()])
+
+ # No ops should be inserted or removed.
+ self.assertEqual(op_names_before_quantize, op_names_after_quantize)
+
def _WeightInit(self, stddev):
"""Returns truncated normal variable initializer.
diff --git a/tensorflow/contrib/recurrent/BUILD b/tensorflow/contrib/recurrent/BUILD
new file mode 100644
index 0000000000..b3cb04ce26
--- /dev/null
+++ b/tensorflow/contrib/recurrent/BUILD
@@ -0,0 +1,106 @@
+# Recurrent library.
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
+
+py_library(
+ name = "recurrent_py",
+ srcs = ["python/recurrent_api.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":functional_rnn_ops_py",
+ ":recurrent_ops_py",
+ ],
+)
+
+py_library(
+ name = "recurrent_ops_py",
+ srcs = ["python/ops/recurrent.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ ],
+)
+
+py_library(
+ name = "functional_rnn_ops_py",
+ srcs = ["python/ops/functional_rnn.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":recurrent_ops_py",
+ "//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:function",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:standard_ops",
+ ],
+)
+
+cuda_py_tests(
+ name = "recurrent_ops_test",
+ size = "small",
+ srcs = ["python/kernel_tests/recurrent_test.py"],
+ additional_deps = [
+ ":recurrent_ops_py",
+ "//third_party/py/numpy",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python:variables",
+ ],
+ tags = ["nopip"],
+)
+
+cuda_py_tests(
+ name = "functional_rnn_ops_test",
+ size = "small",
+ srcs = ["python/kernel_tests/functional_rnn_test.py"],
+ additional_deps = [
+ ":functional_rnn_ops_py",
+ "//third_party/py/numpy",
+ "//tensorflow/contrib/layers:layers_py",
+ "//tensorflow/contrib/tpu:tpu",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:rnn",
+ "//tensorflow/python:rnn_cell",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ ],
+ tags = ["nopip"],
+)
diff --git a/tensorflow/contrib/recurrent/README.md b/tensorflow/contrib/recurrent/README.md
new file mode 100644
index 0000000000..86e10eee51
--- /dev/null
+++ b/tensorflow/contrib/recurrent/README.md
@@ -0,0 +1,13 @@
+# Recurrent computation library
+
+The recurrent computation library contains code to perform recurrent
+computations.
+
+Its chief application is implementing recurrent neural networks (RNNs, LSTMs,
+etc.), as done in `functional_rnn.py`. Similar techniques may be used to
+implement deep networks.
+
+The computation saves the activations in the forward pass, and computes the
+gradients in the backward pass using a single accumulator.
+
+The `functional_rnn` interface is compatible with the `dynamic_rnn` API.
diff --git a/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py b/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py
new file mode 100644
index 0000000000..0f19ac7dbe
--- /dev/null
+++ b/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py
@@ -0,0 +1,163 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Functional RNN."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+
+from tensorflow.contrib.recurrent.python.ops import functional_rnn
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import rnn as rnn_lib
+from tensorflow.python.ops import rnn_cell_impl
+from tensorflow.python.ops import variables
+import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
+import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import
+from tensorflow.python.platform import test as test_lib
+from tensorflow.python.platform import tf_logging as logging
+
+
+def _CreateStackedLstmCell(*cell_sizes):
+ subcells = [rnn_cell_impl.LSTMCell(cell_size) for cell_size in cell_sizes]
+ return rnn_cell_impl.MultiRNNCell(subcells)
+
+
+class FunctionalRnnTest(test_util.TensorFlowTestCase):
+
+ _BATCH_SIZE = 3
+ _TOTAL_TIME = 5
+ _INPUT_SIZE = 11
+ _NUM_UNITS = 7
+
+  # Set this to a file path to dump the LSTM GraphDef from testRunLstm.
+ _LSTM_GRAPH_DEF_FILEPATH = None
+
+ _CELLDEFS = {
+ 'gru': (rnn_cell_impl.GRUCell, [_NUM_UNITS]),
+ 'lstm': (rnn_cell_impl.LSTMCell, [_NUM_UNITS]),
+ 'stacked_lstm': (_CreateStackedLstmCell, [_NUM_UNITS] * 3)
+ }
+
+ def _CreateCell(self, celldef_name):
+ func, args = self._CELLDEFS[celldef_name]
+ return func(*args)
+
+ def _CreateInputs(self):
+ inputs = np.random.random([FunctionalRnnTest._BATCH_SIZE,
+ FunctionalRnnTest._TOTAL_TIME,
+ FunctionalRnnTest._INPUT_SIZE])
+ # Always leave one time slot empty, to check max_length behavior.
+ sequence_length = np.random.randint(
+ 0, high=FunctionalRnnTest._TOTAL_TIME - 1,
+ size=FunctionalRnnTest._BATCH_SIZE,
+ dtype=np.int)
+ return (inputs, sequence_length)
+
+ def _CreateRnnGraph(self, create_rnn_computation_func, cell, tf_inputs,
+ tf_sequence_length, initial_state=None,
+ time_major=None, scope=None):
+ tf_result = create_rnn_computation_func(cell=cell, inputs=tf_inputs,
+ sequence_length=tf_sequence_length,
+ initial_state=initial_state,
+ dtype=dtypes.float32,
+ time_major=time_major,
+ scope=scope)
+ grad = gradients_impl.gradients(tf_result, variables.trainable_variables())
+ return {'inference': tf_result, 'grad': grad}
+
+ def _MaybeResetVariables(self, variable_cache, sess, var_list):
+ """Possibly resets the variables to a previously seen value."""
+ reset_ops = []
+ fetches = []
+ for var in var_list:
+ if var.name in variable_cache:
+ reset_ops += [var.assign(variable_cache[var.name])]
+ else:
+ fetches += [(var.name, var)]
+ if reset_ops:
+ sess.run(reset_ops)
+ if fetches:
+ val = sess.run(dict(fetches))
+ for n, v in val.items():
+ assert n not in variable_cache
+ variable_cache[n] = v
+
+ def _RunRnn(self, numpy_inputs, numpy_slen, cell_name, variable_cache,
+ is_dynamic):
+ with ops.Graph().as_default() as graph:
+ tf_inputs = array_ops.placeholder(
+ dtypes.float32, shape=numpy_inputs.shape)
+ tf_slen = array_ops.placeholder(dtypes.int32)
+ feeds = {tf_inputs: numpy_inputs, tf_slen: numpy_slen}
+ cell = self._CreateCell(cell_name)
+ fn = rnn_lib.dynamic_rnn if is_dynamic else functional_rnn.functional_rnn
+ fetches = self._CreateRnnGraph(fn, cell, tf_inputs, tf_slen)
+ with self.test_session(graph=graph) as sess:
+ sess.run(variables.global_variables_initializer())
+      # Note that cell.trainable_variables is not always set.
+ self._MaybeResetVariables(variable_cache, sess,
+ variables.trainable_variables())
+ val = sess.run(fetches, feed_dict=feeds)
+ graph_def = graph.as_graph_def()
+ return graph_def, val
+
+ def testRunLstm(self):
+ """Runs a simple LSTM. Does not check output."""
+ np_inputs, np_slen = self._CreateInputs()
+ var_cache = {}
+ graphdef, _ = self._RunRnn(np_inputs, np_slen, 'lstm', var_cache, False)
+ logging.info('graphdef: %s', graphdef)
+ if self._LSTM_GRAPH_DEF_FILEPATH:
+ with open(self._LSTM_GRAPH_DEF_FILEPATH, 'w') as f:
+ f.write(str(graphdef))
+
+ def testLstm(self):
+ """Checks an LSTM against the reference implementation."""
+ np_inputs, np_slen = self._CreateInputs()
+ var_cache = {}
+ _, func_rnn = self._RunRnn(np_inputs, np_slen, 'lstm', var_cache, False)
+ _, dyn_rnn = self._RunRnn(np_inputs, np_slen, 'lstm', var_cache, True)
+ self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
+ self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])
+
+ def testGru(self):
+ """Checks a GRU cell against the reference implementation."""
+ np_inputs, np_slen = self._CreateInputs()
+ var_cache = {}
+ _, func_rnn = self._RunRnn(np_inputs, np_slen, 'gru', var_cache, False)
+ _, dyn_rnn = self._RunRnn(np_inputs, np_slen, 'gru', var_cache, True)
+ self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
+ self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])
+
+ def testStackedLstm(self):
+ """Checks a stacked LSTM cell against the reference implementation."""
+ np_inputs, np_slen = self._CreateInputs()
+ var_cache = {}
+ args = [np_inputs, np_slen, 'stacked_lstm', var_cache]
+ _, func_rnn = self._RunRnn(*(args + [False]))
+ _, dyn_rnn = self._RunRnn(*(args + [True]))
+ self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
+ self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])
+
+
+if __name__ == '__main__':
+ test_lib.main()
diff --git a/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py b/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py
new file mode 100644
index 0000000000..00fbd4fbb8
--- /dev/null
+++ b/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py
@@ -0,0 +1,192 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Recurrent ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.contrib.recurrent.python.ops import recurrent
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test as test_lib
+from tensorflow.python.platform import tf_logging as logging
+
+
+_ElmanState = collections.namedtuple('ElmanState', ('h'))
+_ElmanTheta = collections.namedtuple('ElmanTheta', ('w', 'b'))
+_ElmanInputs = collections.namedtuple('ElmanInputs', ('x'))
+
+
+# TODO(drpng): add test for max length computation.
+class RecurrentTest(test_util.TensorFlowTestCase):
+
+ def testBasic(self):
+ # pylint:disable=invalid-name
+ _PolyState = collections.namedtuple('PolyState', ('value', 'x_power'))
+ _PolyTheta = collections.namedtuple('PolyTheta', ('x'))
+ _PolyInputs = collections.namedtuple('PolyInputs', ('coeff'))
+ # pylint:enable=invalid-name
+
+ def Poly(theta, state, inputs):
+ next_state = _PolyState(
+ value=state.value + inputs.coeff * state.x_power,
+ x_power=state.x_power * theta.x)
+ return next_state, []
+
+ with self.test_session() as sess:
+ theta = _PolyTheta(x=array_ops.constant(2.0))
+ state = _PolyState(
+ value=array_ops.constant(0.0),
+ x_power=array_ops.constant(1.0))
+ inputs = _PolyInputs(coeff=array_ops.constant([1., 2., 3.]))
+
+ # x = 2
+ # 1 + 2*x + 3*x^2
+ ret = recurrent.Recurrent(theta, state, inputs, Poly)
+
+ acc, state = sess.run(ret)
+ self.assertAllClose(acc.value, [1., 5., 17.])
+ self.assertAllClose(acc.x_power, [2., 4., 8.])
+ self.assertAllClose(state.value, 17.)
+ self.assertAllClose(state.x_power, 8.)
+
+ y = ret[1].value
+ dx, d_coeff = gradients_impl.gradients(ys=[y], xs=[theta.x, inputs.coeff])
+ dx_val, d_coeff_val = sess.run([dx, d_coeff])
+
+ # 2 + 6*x
+ self.assertAllClose(dx_val, 14.)
+ self.assertAllClose(d_coeff_val, [1., 2., 4.])
+
+ # acc = [1, 1+2x, 1+2x+3x^2]
+ # sum(acc) = 3 + 4x + 3x^2
+ acc = ret[0].value
+ dx, d_coeff = gradients_impl.gradients(
+ ys=[math_ops.reduce_sum(acc)], xs=[theta.x, inputs.coeff])
+ dx_val, d_coeff_val = sess.run([dx, d_coeff])
+ # 4 + 6*x
+ self.assertAllClose(dx_val, 16.)
+ self.assertAllClose(d_coeff_val, [3., 4., 4.])
+
+ @staticmethod
+ def Rand(shape):
+ return random_ops.random_uniform(
+ shape, minval=-0.2, maxval=0.2, dtype=dtypes.float64)
+
+ @staticmethod
+ def Elman(theta, state0, inputs):
+ h0, w, b, x = state0.h, theta.w, theta.b, inputs.x
+ xw = math_ops.matmul(array_ops.concat([x, h0], axis=1), w)
+ h1 = math_ops.sigmoid(xw + b)
+ state1 = _ElmanState(h=h1)
+ return (state1, state1)
+
+ @staticmethod
+ def ElmanGrad(theta, state0, inputs, extras, dstate1):
+
+ @function.Defun()
+ def Grad(h0, w, b, x, h1, dh1):
+ del b
+ # We hand-roll the gradient for the 2nd half of the cell as a demo.
+ dxwb = (dh1 * (1 - h1) * h1)
+ dxw, db = dxwb, math_ops.reduce_sum(dxwb, axis=0)
+
+      # Uses tf.gradients for the 1st half of the cell as a demo.
+ xw = math_ops.matmul(array_ops.concat([x, h0], axis=1), w)
+ dh0, dx, dw = gradients_impl.gradients(
+ ys=[xw], xs=[h0, x, w], grad_ys=[dxw])
+
+ return dh0, dx, dw, db
+
+ dh0, dx, dw, db = Grad(state0.h, theta.w, theta.b, inputs.x,
+ extras.h, dstate1.h)
+ dstate0 = _ElmanState(h=dh0)
+ dinputs = _ElmanInputs(x=dx)
+ return (_ElmanTheta(w=dw, b=db), dstate0, dinputs)
+
+ @staticmethod
+ def ElmanOut(state1):
+ return _ElmanState(x=state1.h)
+
+ @staticmethod
+ def ElmanOutGrad(dout):
+ return _ElmanState(h=dout.x)
+
+ def testElman(self):
+ for seqlen, use_grad in [(1, False), (1, True), (7, False), (7, True)]:
+ logging.info('== Elman: seqlen=%s, use_grad=%s', seqlen, use_grad)
+ self._ParameterizedTestElman(seqlen, use_grad)
+
+ def _ParameterizedTestElman(self, seqlen, use_grad):
+
+ with self.test_session() as sess:
+ random_seed.set_random_seed(342462)
+
+ batch = 3
+ dims = 4
+ theta = _ElmanTheta(w=RecurrentTest.Rand([2 * dims, dims]),
+ b=RecurrentTest.Rand([dims]))
+ state0 = _ElmanState(h=RecurrentTest.Rand([batch, dims]))
+ inputs = _ElmanInputs(x=RecurrentTest.Rand([seqlen, batch, dims]))
+
+ # Statically unrolled.
+ s = state0
+ out = []
+ for i in xrange(seqlen):
+ inp = _ElmanInputs(x=inputs.x[i, :])
+ s, _ = RecurrentTest.Elman(theta, s, inp)
+ out += [s.h]
+ acc0, final0 = array_ops.stack(out), s.h
+ loss0 = math_ops.reduce_sum(acc0) + math_ops.reduce_sum(final0)
+ (dw0, db0, dh0, di0) = gradients_impl.gradients(
+ loss0, [theta.w, theta.b, state0.h, inputs.x])
+
+ acc1, final1 = recurrent.Recurrent(
+ theta=theta,
+ state0=state0,
+ inputs=inputs,
+ cell_fn=RecurrentTest.Elman,
+ cell_grad=RecurrentTest.ElmanGrad if use_grad else None)
+ assert isinstance(acc1, _ElmanState)
+ assert isinstance(final1, _ElmanState)
+ acc1, final1 = acc1.h, final1.h
+ loss1 = math_ops.reduce_sum(acc1) + math_ops.reduce_sum(final1)
+ (dw1, db1, dh1, di1) = gradients_impl.gradients(
+ loss1, [theta.w, theta.b, state0.h, inputs.x])
+
+      # Fetches a few values and compares them.
+ (acc0, acc1, final0, final1, dw0, dw1, db0, db1, dh0, dh1, di0,
+ di1) = sess.run(
+ [acc0, acc1, final0, final1, dw0, dw1, db0, db1, dh0, dh1, di0, di1])
+ self.assertAllClose(acc0, acc1)
+ self.assertAllClose(final0, final1)
+ self.assertAllClose(dw0, dw1)
+ self.assertAllClose(db0, db1)
+ self.assertAllClose(dh0, dh1)
+ self.assertAllClose(di0, di1)
+
+if __name__ == '__main__':
+ test_lib.main()
diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
new file mode 100644
index 0000000000..a085474c1b
--- /dev/null
+++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
@@ -0,0 +1,396 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A tf.nn.dynamic_rnn variant, built on the Recurrent class.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+
+from tensorflow.contrib.recurrent.python.ops import recurrent
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.util import nest
+
+
+def _GetDTypesFromStructure(struct):
+ dtypes_list = []
+ for x in nest.flatten(struct):
+ x = ops.convert_to_tensor(x)
+ dtypes_list.append(x.dtype)
+ return dtypes_list
+
+
+def _SetShapeFromTemplate(struct, struct_template):
+ as_list = nest.flatten(struct)
+ template_as_list = nest.flatten(struct_template)
+ for element, template in zip(as_list, template_as_list):
+ element.set_shape(template.shape)
+
+
+class _FunctionalRnnCell(object):
+ """Wrapper around RNNCell which separates state from computation.
+
+ This class accomplishes the following:
+ * Turn the cell's `__call__` function into a pure function. The global
+ side effects are separated as `theta`. They are the variables created
+ for the weights of the computation.
+ * Unless the output is aliased as part of the state, extend the state to
+ contain the output so that we store the history in `Recurrent`.
+ * Set static shapes as required.
+ """
+
+ def __init__(self, rnn_cell, seq_inputs, initial_state):
+ assert initial_state is not None
+
+ # TODO(drpng): Dtype needs to be configurable.
+ input_dtypes = [dtypes.float32] + _GetDTypesFromStructure(initial_state)
+    # See _Index in recurrent.py.
+ like_inputs_t = nest.map_structure(
+ lambda x: array_ops.stop_gradient(array_ops.gather(x, 0)), seq_inputs)
+ input_structure = (like_inputs_t, initial_state)
+
+ @function.Defun(*input_dtypes)
+ def FlatCellStep(*flat_inputs):
+ """The flattened version of `rnn_cell`."""
+ inputs_t, state0 = nest.pack_sequence_as(input_structure, flat_inputs)
+ _SetShapeFromTemplate(state0, initial_state)
+ _SetShapeFromTemplate(inputs_t, like_inputs_t)
+ outputs_t, state1 = rnn_cell(inputs_t, state0)
+ state_list = nest.flatten(state1)
+ self._output_shape = outputs_t.shape
+
+ if outputs_t in state_list:
+ output_index_in_state = state_list.index(outputs_t)
+ else:
+ output_index_in_state = None
+
+ if output_index_in_state is None:
+ self._prepend_output = True
+ self._output_state_idx = 0
+ return [outputs_t] + state_list
+ else:
+ self._output_state_idx = output_index_in_state
+ self._prepend_output = False
+        # To save memory, we don't return the output separately from the
+        # state list, since we know it's the same.
+ return state_list
+
+ def _ToPureFunction(func):
+      # NOTE: This forces the creation of the function.
+ if func.captured_inputs:
+ pure_func = copy.copy(func)
+ # pylint: disable=protected-access
+ pure_func._extra_inputs = []
+ return pure_func
+ return func
+
+ pure_flat_cell_step = _ToPureFunction(FlatCellStep)
+
+ def CellStep(theta, extended_state0, inputs_t):
+ """Performs one time steps on structured inputs.
+
+ The purpose of this function is to turn the parameters into flattened
+ versions, and to resolve the parameter order difference between
+ `Recurrent` and `RNNCell`.
+
+ In the event the cell returns a transformed output that is not aliased
+ within its state, the `extended_state0` also contains the output as its
+ first element.
+
+ Args:
+ theta: Weights required for the computation. A structure of tensors.
+ extended_state0: the state0, and possibly the output at the previous
+ time step. A structure of tensors.
+ inputs_t: the inputs at time t.
+
+ Returns:
+ A pair of the next state (inclusive of the output), and an empty list
+ (unused `extras`).
+ The next state is congruent to state0.
+ """
+ extended_state0_flat = nest.flatten(extended_state0)
+ state0_flat = self.MaybeRemoveOutputFromState(extended_state0_flat)
+ full_inputs = [inputs_t] + state0_flat + theta
+      # Note that the thetas are additional inputs appended as extra
+      # parameters.
+ cell_out = pure_flat_cell_step(*full_inputs)
+ return cell_out, []
+
+ self._cell_step = CellStep
+ self._theta = FlatCellStep.captured_inputs
+ self._zero_state = rnn_cell.zero_state
+ self._state_template = initial_state
+ self._output_size = rnn_cell.output_size
+
+ @property
+ def extended_initial_state(self):
+ if self._prepend_output:
+ return [array_ops.zeros(self._output_shape), self._state_template]
+ else:
+ # The base case, where the output is just the hidden state.
+ return self._state_template
+
+ @property
+ def cell_step(self):
+ return self._cell_step
+
+ @property
+ def theta(self):
+ return self._theta
+
+ @property
+ def state_template(self):
+ return self._state_template
+
+ @property
+ def output_shape(self):
+ return self._output_shape
+
+ def GetOutputFromState(self, state):
+ return nest.flatten(state)[self._output_state_idx]
+
+ def MaybeRemoveOutputFromState(self, flat_state):
+ if self._prepend_output:
+ return flat_state[1:]
+ return flat_state
+
+
+def _ApplyLengthsToBatch(sequence_lengths, tf_output):
+ # TODO(drpng): just use Update so that we don't carry over the gradients?
+ """Sets the output to be zero at the end of the sequence."""
+ # output is batch major.
+ batch_size, max_time, vector_size = tf_output.shape
+ output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size])
+ output_time = array_ops.reshape(output_time, [batch_size, max_time])
+ lengths = array_ops.tile(
+ array_ops.reshape(sequence_lengths, [-1, 1]), [1, max_time])
+ is_less = math_ops.cast(
+ math_ops.less(output_time, lengths), dtype=dtypes.float32)
+ keep_mask = array_ops.tile(
+ array_ops.expand_dims(is_less, -1),
+ [1, 1, vector_size])
+ final_output = keep_mask * tf_output
+ return final_output
+
+
+def _PickFinalStateFromHistory(acc_state, sequence_length):
+ """Implements acc_state[sequence_length - 1]."""
+ # This will work on all platforms, unlike the regular slice.
+ last_value = []
+ for state_var in nest.flatten(acc_state):
+ # We compute the following with matrix operations:
+ # last_var = state_var[sequence_length - 1]
+ shape = array_ops.shape(state_var)
+ max_time, batch_size = shape[0], shape[1]
+ output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size])
+ output_time = array_ops.reshape(output_time, [batch_size, max_time])
+ lengths = array_ops.tile(array_ops.reshape(sequence_length,
+ [-1, 1]), [1, max_time])
+ last_idx = math_ops.cast(math_ops.equal(output_time, lengths - 1),
+ dtype=dtypes.float32)
+ last_idx = array_ops.transpose(last_idx)
+ last_idx_for_bcast = array_ops.expand_dims(last_idx, -1)
+ sliced = math_ops.multiply(last_idx_for_bcast, state_var)
+ last_var = math_ops.reduce_sum(sliced, 0)
+ last_value += [last_var]
+ return nest.pack_sequence_as(acc_state, last_value)
+
+
+def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell,
+ total_time, inputs_lengths):
+ """Post-process output of recurrent.
+
+ This function takes the accumulated extended state and extracts the requested
+ state and output.
+
+  When `inputs_lengths` has been set, it extracts the output from the
+  accumulated state and zeroes out the outputs past each sequence's length.
+
+ It also sets the static shape information.
+
+ Args:
+ extended_acc_state: A structure containing the accumulated state at each
+ time. It may contain the output at each time as well.
+ extended_final_state: A structure containing the final state. It may
+ contain the output at the final time.
+ func_cell: The functional wrapper around the cell.
+ total_time: A scalar integer tensor.
+ inputs_lengths: An integer tensor with one entry per input.
+
+ Returns:
+ A tuple with the outputs at each time, and the final state.
+ """
+ if inputs_lengths is None:
+ flat_final_state = func_cell.MaybeRemoveOutputFromState(
+ nest.flatten(extended_final_state))
+ tf_state = nest.pack_sequence_as(func_cell.state_template, flat_final_state)
+ else:
+ # The accumulated state is over the entire sequence, so we pick it
+ # out from the acc_state sequence.
+ flat_acc_state = func_cell.MaybeRemoveOutputFromState(
+ nest.flatten(extended_acc_state))
+ acc_state = nest.pack_sequence_as(
+ func_cell.state_template, flat_acc_state)
+ tf_state = _PickFinalStateFromHistory(acc_state, inputs_lengths)
+
+ output_from_state = func_cell.GetOutputFromState(extended_acc_state)
+ tf_output = array_ops.transpose(output_from_state, [1, 0, 2])
+ tf_output.set_shape(
+ [func_cell.output_shape[0], total_time, func_cell.output_shape[1]])
+ if inputs_lengths is not None:
+    # Need to set the outputs past the sequence lengths to zero.
+ tf_output = _ApplyLengthsToBatch(inputs_lengths, tf_output)
+ _SetShapeFromTemplate(tf_state, func_cell.state_template)
+ return tf_output, tf_state
+
+
+# pylint: disable=invalid-name
+def functional_rnn(cell, inputs, sequence_length=None,
+ initial_state=None, dtype=None, time_major=False,
+ scope=None, use_tpu=False):
+ """Same interface as `tf.nn.dynamic_rnn`."""
+ with variable_scope.variable_scope(scope or 'rnn'):
+ if not time_major:
+ inputs = nest.map_structure(
+ lambda t: array_ops.transpose(t, [1, 0, 2]), inputs)
+ inputs_flat = nest.flatten(inputs)
+ batch_size = array_ops.shape(inputs_flat[0])[1]
+ if initial_state is None:
+ initial_state = cell.zero_state(batch_size, dtype)
+ func_cell = _FunctionalRnnCell(cell, inputs, initial_state)
+ extended_acc_state, extended_final_state = recurrent.Recurrent(
+ theta=func_cell.theta,
+ state0=func_cell.extended_initial_state,
+ inputs=inputs,
+ cell_fn=func_cell.cell_step,
+ use_tpu=use_tpu)
+ return _PostProcessOutput(extended_acc_state, extended_final_state,
+ func_cell, inputs_flat[0].shape[0], sequence_length)
+
+
+def bidirectional_functional_rnn(
+ cell_fw,
+ cell_bw,
+ inputs,
+ initial_state_fw=None,
+ initial_state_bw=None,
+ dtype=None,
+ sequence_length=None,
+ time_major=False,
+ use_tpu=False,
+ scope=None):
+ """Creates a bidirectional recurrent neural network.
+
+ Performs fully dynamic unrolling of inputs in both directions. Built to be API
+ compatible with `tf.nn.bidirectional_dynamic_rnn`, but implemented with
+ functional control flow for TPU compatibility.
+
+ Args:
+ cell_fw: An instance of `tf.contrib.rnn.RNNCell`.
+ cell_bw: An instance of `tf.contrib.rnn.RNNCell`.
+ inputs: The RNN inputs. If time_major == False (default), this must be a
+ Tensor (or hierarchical structure of Tensors) of shape
+ [batch_size, max_time, ...]. If time_major == True, this must be a Tensor
+ (or hierarchical structure of Tensors) of shape:
+ [max_time, batch_size, ...]. The first two dimensions must match across
+ all the inputs, but otherwise the ranks and other shape components may
+ differ.
+ initial_state_fw: An optional initial state for `cell_fw`. Should match
+ `cell_fw.zero_state` in structure and type.
+ initial_state_bw: An optional initial state for `cell_bw`. Should match
+ `cell_bw.zero_state` in structure and type.
+ dtype: (optional) The data type for the initial state and expected output.
+ Required if initial_states are not provided or RNN state has a
+ heterogeneous dtype.
+ sequence_length: An optional int32/int64 vector sized [batch_size]. Used to
+ copy-through state and zero-out outputs when past a batch element's
+ sequence length. So it's more for correctness than performance.
+ time_major: Whether the `inputs` tensor is in "time major" format.
+ use_tpu: Whether to enable TPU-compatible operation. If True, does not truly
+ reverse `inputs` in the backwards RNN. Once b/69305369 is fixed, we can
+ remove this flag.
+ scope: An optional scope name for the dynamic RNN.
+
+ Returns:
+ outputs: A tuple of `(output_fw, output_bw)`. The output of the forward and
+ backward RNN. If time_major == False (default), these will
+ be Tensors shaped: [batch_size, max_time, cell.output_size]. If
+ time_major == True, these will be Tensors shaped:
+ [max_time, batch_size, cell.output_size]. Note, if cell.output_size is a
+ (possibly nested) tuple of integers or TensorShape objects, then the
+ output for that direction will be a tuple having the same structure as
+ cell.output_size, containing Tensors having shapes corresponding to the
+ shape data in cell.output_size.
+ final_states: A tuple of `(final_state_fw, final_state_bw)`. A Tensor or
+ hierarchical structure of Tensors indicating the final cell state in each
+ direction. Must have the same structure and shape as cell.zero_state.
+
+ Raises:
+ ValueError: If `initial_state_fw` is None or `initial_state_bw` is None and
+ `dtype` is not provided.
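+
+  A minimal call sketch (the cells and shapes below are illustrative
+  assumptions only):
+
+    cell_fw = tf.nn.rnn_cell.GRUCell(16)
+    cell_bw = tf.nn.rnn_cell.GRUCell(16)
+    # Batch-major inputs: [batch_size, max_time, input_size].
+    inputs = tf.placeholder(tf.float32, shape=[4, 10, 8])
+    (out_fw, out_bw), (state_fw, state_bw) = bidirectional_functional_rnn(
+        cell_fw, cell_bw, inputs, dtype=tf.float32)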
+ """
+ # Keep this code in sync with tf.nn.dynamic_rnn for compatibility.
+ with variable_scope.variable_scope(scope or 'bidirectional_rnn'):
+ # Forward direction
+ with variable_scope.variable_scope('fw') as fw_scope:
+ output_fw, output_state_fw = functional_rnn(
+ cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
+ initial_state=initial_state_fw, dtype=dtype,
+ time_major=time_major, scope=fw_scope, use_tpu=use_tpu)
+ # Backward direction
+ if not time_major:
+ time_dim = 1
+ batch_dim = 0
+ else:
+ time_dim = 0
+ batch_dim = 1
+
+ def _reverse(input_, seq_lengths, seq_dim, batch_dim):
+ if seq_lengths is not None:
+ return array_ops.reverse_sequence(
+ input=input_, seq_lengths=seq_lengths,
+ seq_dim=seq_dim, batch_dim=batch_dim)
+ else:
+ # See b/69305369.
+ assert not use_tpu, (
+ 'Bidirectional with variable sequence lengths unsupported on TPU')
+ return array_ops.reverse(input_, axis=[seq_dim])
+
+ with variable_scope.variable_scope('bw') as bw_scope:
+ inputs_reverse = _reverse(
+ inputs, seq_lengths=sequence_length,
+ seq_dim=time_dim, batch_dim=batch_dim)
+ tmp, output_state_bw = functional_rnn(
+ cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
+ initial_state=initial_state_bw, dtype=dtype,
+ time_major=time_major, scope=bw_scope, use_tpu=use_tpu)
+
+ output_bw = _reverse(
+ tmp, seq_lengths=sequence_length,
+ seq_dim=time_dim, batch_dim=batch_dim)
+
+ outputs = (output_fw, output_bw)
+ output_states = (output_state_fw, output_state_bw)
+
+ return (outputs, output_states)
+# pylint: enable=invalid-name
diff --git a/tensorflow/contrib/recurrent/python/ops/recurrent.py b/tensorflow/contrib/recurrent/python/ops/recurrent.py
new file mode 100644
index 0000000000..fa16b82ab6
--- /dev/null
+++ b/tensorflow/contrib/recurrent/python/ops/recurrent.py
@@ -0,0 +1,720 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Recurrent computation.
+
+The main interface of this module is Recurrent().
+A recurrent computation describes an auto-regressive process, where outputs
+of one time step are fed as inputs to the next time step.
+
+This module uses:
+ theta: the "weights" each RNN uses.
+ state0: the initial state of each RNN.
+  cell_fn: A python function describing the RNN cell. It must have the
+    following signature:
+ cell_fn: (theta, state0, inputs) -> (state1, extras)
+ state1 is the next RNN state, extras are computed by cell_fn
+ and the library forwards extras to cell_fn's gradient function.
+ cell_grad: A python function describing the backprop gradient function
+    for the RNN cell. It must have the following signature:
+ cell_grad: (theta, state0, inputs, extras, dstate1) -> (
+ dtheta, dstate0, dinputs)
+ dstate1 is what the backprop algorithm provides representing
+ gradients of state1 w.r.t. the final loss.
+
+In this module, we handle structures of tensors for theta, state0, inputs,
+and extras. The structure is an arbitrarily nested python structure, such
+as a dictionary of named tuples.
+
+Because the computation is a left-to-right chain, a single in-place accumulator
+can be used rather than a stack. Thus a special gradient was written to reduce
+unnecessary memory usage.
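+
+A minimal sketch of these conventions (illustrative only; the namedtuples and
+the polynomial cell below are not part of this module, see recurrent_test.py
+for a runnable version):
+
+  PolyState = collections.namedtuple('PolyState', ('value', 'x_power'))
+  PolyTheta = collections.namedtuple('PolyTheta', ('x',))
+  PolyInputs = collections.namedtuple('PolyInputs', ('coeff',))
+
+  def Poly(theta, state0, inputs):
+    # One step: accumulate inputs.coeff * x**t and advance the power of x.
+    state1 = PolyState(
+        value=state0.value + inputs.coeff * state0.x_power,
+        x_power=state0.x_power * theta.x)
+    return state1, []  # (state1, extras)
+
+  # acc_state stacks the state at every time step; state is the final state.
+  acc_state, state = Recurrent(theta, state0, inputs, cell_fn=Poly)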
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import inplace_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.inplace_ops import alias_inplace_update
+from tensorflow.python.util import nest
+
+
+def _AssertIsCompatible(a, b):
+ """Checks that `a` and `b` are nested structures of the same type."""
+ # TODO(drpng): implement.
+ del a
+ del b
+
+
+def _Index(struct, index):
+ """Returns a structure with `x[index]` for each tensor `x` in the structure.
+
+ Args:
+ struct: A structure of tensors.
+ index: A scalar integer tensor. Performance is better if `index` is
+ on the host memory.
+
+ Returns:
+ A structure of tensors congruent to `struct`.
+ For each key in `ret`, `rets[key] = struct[key][index]`.
+ """
+ index = ops.convert_to_tensor(index)
+ index.get_shape().assert_has_rank(0)
+ return nest.map_structure(lambda x: x[index], struct)
+
+
+def _Update(struct_acc, struct_x, t):
+ """Updates t-th row in accumulators.
+
+ Args:
+ struct_acc: The accumulators. A structure of tensors.
+ struct_x: The new values. A structure of tensors congruent to `struct_acc`.
+ t: A scalar integer. Performance is better if `t` is on the device
+ memory.
+
+ Returns:
+ A structure of tensors. Say, ret is a returned dictionary. Then, for
+ each key, we have:
+ ret[key] = struct_acc[key];
+ ret[key][t, :] = struct_x[key]
+ """
+ to_skip_update = set()
+ acc_lst = nest.flatten(struct_acc)
+ x_lst = nest.flatten(struct_x)
+ t = math_ops.to_int32([t]) # tf.to_int32 casts on-device tensors.
+ lst = []
+ for acc, x in zip(acc_lst, x_lst):
+ if acc in to_skip_update:
+      # Until b/62105730 is fixed, we need to avoid inplace updates for
+      # tensors of rank 1. We could reshape to handle it, but we don't really
+      # need the values applied to these, so just skip their modification.
+ lst += [acc]
+ else:
+ lst += [alias_inplace_update(acc, t, array_ops.expand_dims(x, 0))]
+ return nest.pack_sequence_as(struct_acc, lst)
+
+
+def _SeqLenDim(struct):
+ """Returns the 0-th dim size of tensors in a structure of tensors.
+
+ This is the max sequence length according to the shape of the inputs.
+
+ Args:
+ struct: A structure of tensors. Every tensor's 0-th dim has the same size.
+
+ Returns:
+    A scalar tensor which is the size of the 0-th dim of every tensor in
+    struct.
+ """
+ xs = nest.flatten(struct)
+ assert xs
+ dim0 = array_ops.shape(xs[0])[0]
+ return dim0
+
+
+def _Flatten(struct):
+ """Flattens a structure."""
+ return nest.flatten(struct)
+
+
+def _Pack(elements, struct_template):
+ """Packs the list of tensors according to the structure.
+
+ In the event that `elements` should be a scalar, `struct_template` must
+ contain exactly one non-trivial element (for instance, `[[], {'x':elt}]`).
+
+ Args:
+ elements: Elements to be packed. A list of tensor, or a single tensor.
+ struct_template: The container structure in which to pack them.
+ Returns:
+ A python structure of the same type as `struct_template`, containing
+ `elements` as its contained elements.
+ """
+ if not nest.is_sequence(elements):
+ return nest.pack_sequence_as(struct_template, [elements])
+ return nest.pack_sequence_as(struct_template, elements)
+
+
+def _EmptyAcc(slen, struct_template):
+ """Creates a set of accumulators for tensors in structure.
+
+ Args:
+ slen: The sequence length. A scalar tensor.
+ struct_template: A structure of tensors.
+
+ Returns:
+ A structure congruent to `struct_template`. Say ret is a returned
+ dictionary. Then, `ret.key`, a tensor, has the same dtype as
+ `struct_template.key`. The tensor's shape has 1 more dimension
+ than the tensor `struct_template.key`. The extra 0-th dimension is of size
+ `slen`. E.g., if `slen=10` and `struct_template.key`'s shape is `[3, 5]`,
+ then, `ret.key`'s shape is `[10, 3, 5]`.
+ """
+
+ def _EmptyAccForTensor(tensor):
+ return inplace_ops.empty(
+ array_ops.concat([[slen], array_ops.shape(tensor)], axis=0),
+ tensor.dtype,
+ init=True)
+
+ return nest.map_structure(_EmptyAccForTensor, struct_template)
+
+
+def _EmptyLike(struct):
+ """Creates a set of empty initialized tensors.
+
+ Args:
+ struct: A structure of tensors.
+
+ Returns:
+ A struct of tensors. Each tensor has the same shape and dtype as
+ its corresponding tensor in `struct`. And each tensor is initialized.
+ """
+ return nest.map_structure(
+ lambda x: inplace_ops.empty_like(x, init=True), struct)
+
+
+def _Add(struct_x, struct_y):
+ """Adds tensors in `struct_x` with respective tensors in `struct_y`.
+
+ Args:
+ struct_x: A struct of tensors.
+ struct_y: A struct of tensors congruent to `struct_x`.
+
+ Returns:
+ A struct of tensors. Each element of the returned value
+ equals `x + y`, with corresponding values in `struct_x` and `struct_y`.
+ """
+ list_x = nest.flatten(struct_x)
+ list_y = nest.flatten(struct_y)
+ z = []
+ for x, y in zip(list_x, list_y):
+ z += [math_ops.add(x, y)]
+ return nest.pack_sequence_as(struct_x, z)
+
+
+def _Dtypes(struct):
+ """Returns all tensors' data types in a list."""
+ return [x.dtype for x in nest.flatten(struct)]
+
+
+def _ConvertNoneGradientToZeros(xs, dxs):
+ """Sanitize dxs so that None becomes zeros appropriately.
+
+ Args:
+ xs: A list of tensors.
+ dxs: A list of tensors. dxs[i] corresponds to xs[i]'s gradient.
+
+ Returns:
+ A structure same as `dxs` with `None` replaced by a zero tensor.
+ """
+ list_xs = nest.flatten(xs)
+ list_dxs = nest.flatten(dxs)
+
+ # If x does not get any backprop-ed gradient, propagate zeros.
+ rets = []
+ for (x, dx) in zip(list_xs, list_dxs):
+ if dx is None:
+ rets.append(array_ops.zeros_like(x))
+ else:
+ rets.append(dx)
+
+ return nest.pack_sequence_as(dxs, rets)
+
+
+# All structures are flattened for use internally. This is for simplicity
+# and also to use the Defun construct.
+# In the forward pass (inference), the computation is structured as follows.
+# Forward: [gradient = _Recurrent.Grad]
+# Flatten structures, create accumulators.
+# for t = 0..max_input_length:
+# Defun ForwardLoopBody:
+# Defun Fwd: flatten/pack around cell_fn
+# state1 = Fwd(inputs[t], state0)
+# acc_state += [state1]
+# Pack structures.
+# During the backward pass (backpropping the gradient from the last time
+# step to the first, through the structure), the computation is structured
+# as follows.
+# Grad:
+# Flatten structures.
+# Defun Backward:
+#      Create accumulated derivatives: d_theta, d_inputs, d_acc_state.
+# Regarding the note at the top of the file, there is only one accumulator
+# for d_theta accumulated over the whole sequence.
+# for t = max_input_length -1..0:
+# Defun BackwardLoopBody:
+# Retrieve acc_state[t] computed in the forward pass.
+#        Defun Bak: flatten/pack around cell_grad.
+#           d_state1 is d_state0 from the previous step (i.e., time t+1).
+# d_acc_state[dev_t] += d_state1
+#        d_theta_t, d_state0, d_inputs_t = Bak()
+# d_inputs[dev_t] += d_inputs
+# d_theta += d_theta_t
+# d_acc_state[t] += d_state1
+# Pack structures and return.
+class _Recurrent(object):
+ """A helper class to construct a recurrent neural net."""
+
+ def __init__(self, cell_fn, cell_grad, theta, state0, inputs,
+ max_input_length, extras, use_tpu):
+ """RNN helper class.
+
+ Args:
+ cell_fn: A python function, which computes:
+ state1, extras = cell_fn(theta, state0, inputs[t, :])
+ cell_grad: A python function which computes:
+ dtheta, dstate0, dinputs[t, :] = cell_grad(
+ theta, state0, inputs[t, :], extras, dstate1)
+ theta: weights. A structure of tensors.
+ state0: initial state. A structure of tensors.
+ inputs: inputs. A structure of tensors.
+ max_input_length: None, or the maximum effective length of the input over
+ all batches. A scalar tensor.
+ extras: A structure of tensors. The 2nd return value of every
+ invocation of cell_fn is a structure of tensors with matching keys
+ and shapes of this `extras`.
+      use_tpu: A boolean indicating whether the computation is meant to
+ run on a TPU.
+ """
+ self._theta = theta
+ self._state = state0
+ self._inputs = inputs
+ self._max_input_length = self._MaybeComputeMaxInputLength(
+ inputs, max_input_length)
+ self._cell_fn = cell_fn
+ self._cell_grad = cell_grad
+ self._extras = extras
+
+ # pylint: disable=unbalanced-tuple-unpacking
+
+ # NOTE: TF Function (Fwd, Bak, ForwardLoopBody, BackwardLoopBody,
+ # Forward and Backward defined below) simply takes a list of
+ # Tensors and returns a list of Tensors. When we pass in a
+ # structure (a list of structures of Tensors), we use _Flatten to
+ # convert the structure into a list of tensor. Conversely, the
+ # following code often uses _Pack to formulate a structure from a
+ # list of tensors based on a "template".
+
+ # Wraps cell_fn in a TF Function:
+ # state1 = cell_fn(theta, state0, inputs)
+ fwd_sig = [self._theta, self._state, self._inputs]
+
+ compiled = use_tpu
+ noinline = not compiled
+ dev_t_type = dtypes.int32 if use_tpu else dtypes.int64
+
+ @function.Defun(*_Dtypes(fwd_sig))
+ def Fwd(*args):
+ (theta, state0, inputs) = _Pack(args, fwd_sig)
+ state1, extras = self._cell_fn(theta, state0, inputs)
+ assert not function.get_extra_args(), (
+ 'cell_fn is not pure with extra args: %s.' %
+ (function.get_extra_args()))
+ _AssertIsCompatible(state1, self._state)
+ _AssertIsCompatible(extras, self._extras)
+ return _Flatten([state1, extras])
+
+ # Wraps cell_fn in a TF Function as a for-loop's body.
+ #
+ # The loop state is composed of:
+ # t: The loop variable. Timestep id.
+ # dev_t: The loop variable mirrored on the device.
+ # theta: the recurrent net's weights.
+ # state0: the previous recurrent state.
+ # inputs: inputs to the recurrent net. inputs[t, :] are for the timestep t.
+ # acc_state: Each timestep's computed new state is also stashed into
+ # acc_state.
+ # acc_extras: Each timestep's computed extras is stashed into acc_extras
+ fwdloop_sig = [
+ self._theta, self._state, self._inputs, self._state, self._extras
+ ]
+
+ @function.Defun(dtypes.int32, dev_t_type, *_Dtypes(fwdloop_sig))
+ def ForwardLoopBody(*args):
+ """The body of forward loop."""
+ t, dev_t = args[0], args[1]
+ (theta, state0, inputs, acc_state, acc_extras) = _Pack(
+ args[2:], fwdloop_sig)
+ inputs_t = _Index(inputs, t) # external input at time step t.
+ fwd = Fwd(*_Flatten([theta, state0, inputs_t]))
+ state1, extras = _Pack(fwd, [self._state, self._extras])
+ # Saves state1 and extras in their accumulators.
+ acc_state = _Update(acc_state, state1, dev_t)
+ acc_extras = _Update(acc_extras, extras, dev_t)
+
+ return [math_ops.add(dev_t, 1)] + _Flatten(
+ [theta, state1, inputs, acc_state, acc_extras])
+
+ def Grad(op, *args):
+ """The python grad function for the Forward function."""
+
+      # NOTE: tf.gradients backprops None for int32/int64 while zeros
+ # for float32/float64. For consistency, we always backprop
+ # zeros.
+ args = list(args)
+ for i, dy in enumerate(args):
+ if dy is None:
+ args[i] = array_ops.zeros_like(op.outputs[i])
+ # TODO(drpng): getting the extra state here?
+ op_inputs = [x for x in op.inputs]
+ op_struct = [
+ self._theta, self._state, self._inputs, self._max_input_length,
+ self._extras
+ ]
+ (theta, state0, inputs, max_input_length, _) = _Pack(op_inputs, op_struct)
+ # acc_state and acc_extras are computed by the Forward pass and
+ # needed by the Backward pass.
+ acc_state, _, acc_extras = _Pack([x for x in op.outputs],
+ [self._state, self._state, self._extras])
+
+ # Forward computes acc_state, the final state and
+ # acc_extras. tf.gradients gives us their gradients w.r.t. the
+ # final loss. Because acc_extras are not exposed by Compute(),
+ # it has no gradients w.r.t. the final loss (i.e., by
+ # construction, it must be zeros).
+ d_acc_state, d_state1, _ = _Pack(args,
+ [self._state, self._state, self._extras])
+ return Backward(*_Flatten([
+ theta, state0, inputs, max_input_length, acc_state, acc_extras,
+ d_acc_state, d_state1
+ ]))
+
+ # Forward calls ForwardLoopBody n times. Each time computes one
+ # time step of the recurrent net.
+ forward_sig = [
+ self._theta, self._state, self._inputs, self._max_input_length,
+ self._extras
+ ]
+
+ @function.Defun(
+ *_Dtypes(forward_sig), python_grad_func=Grad, noinline=noinline)
+ def Forward(*args):
+ """Forward pass of the recurrent net."""
+ theta, state0, inputs, max_input_length, extras = _Pack(args, forward_sig)
+
+ slen_dim = _SeqLenDim(inputs)
+
+ # Creates accumulators for state0 and extras.
+ acc_state = _EmptyAcc(slen_dim, state0)
+ acc_extras = _EmptyAcc(slen_dim, extras)
+
+ dev_t = array_ops.constant(0, dtype=dev_t_type)
+ run = functional_ops.For(
+ start=0,
+ limit=max_input_length,
+ delta=1,
+ inputs=[dev_t] + _Flatten(
+ [theta, state0, inputs, acc_state, acc_extras]),
+ body=ForwardLoopBody,
+ rewrite_with_while=compiled)
+ _, state1, _, acc_state, acc_extras = _Pack(
+ run[1:],
+ [self._theta, self._state, self._inputs, self._state, self._extras])
+
+ return _Flatten([acc_state, state1, acc_extras])
+
+ # The per-step backward computes:
+ # d_theta, d_state0, d_inputs = cell_grad(
+ # theta, state0, inputs, extras, d_state1)
+ # where d_state1 is the backprop-ed gradient for state1, and
+    # extras is computed by the forward step to facilitate the
+ # backward step.
+ bak_sig = [
+ self._theta, self._state, self._inputs, self._extras, self._state
+ ]
+
+ @function.Defun(*_Dtypes(bak_sig))
+ def Bak(*args):
+ """Backward step."""
+ (theta, state0, inputs, extras, d_state1) = _Pack(args, bak_sig)
+ (dtheta, dstate0, dinputs) = self._cell_grad(theta, state0, inputs,
+ extras, d_state1)
+ assert not function.get_extra_args(), (
+ 'cell_grad is not pure with extra args: %s.' %
+ (function.get_extra_args()))
+ _AssertIsCompatible(dtheta, self._theta)
+ _AssertIsCompatible(dstate0, self._state)
+ _AssertIsCompatible(dinputs, self._inputs)
+ return _Flatten(
+ _ConvertNoneGradientToZeros([theta, state0, inputs],
+ [dtheta, dstate0, dinputs]))
+
+ # Define defuns used by a functional_ops.If in BackwardLoopBody.
+ state_if_sig = [self._state, self._state]
+
+ @function.Defun(*_Dtypes(state_if_sig))
+ def ReturnOrigState0(*args):
+ """Returns original state0 from inputs."""
+ (_, orig_state0) = _Pack(args, state_if_sig)
+ return nest.flatten(orig_state0)
+
+ @function.Defun(*_Dtypes(state_if_sig))
+ def ReturnAccState(*args):
+ """Returns acc_state[t-1] from inputs."""
+ (acc_state, _) = _Pack(args, state_if_sig)
+ return nest.flatten(acc_state)
+
+ # Wraps cell_grad gradient function in a TF Function as a
+ # for-loop's body for the Backward pass.
+ #
+ # The loop state is composed of:
+ # t: The loop variable. Timestep id.
+ # state0: the initial state for the entire backward loop.
+ # dev_t: The loop variable mirrored on the device.
+ # theta: the recurrent net's weights.
+ # inputs: inputs to the recurrent net. inputs[t, :] are for the timestep t.
+ # acc_state: Each timestep's computed new state was stashed into
+ # acc_state by the Forward pass.
+ # acc_extras: Each timestep's computed extras was stashed into
+ # acc_extras by the Forward pass.
+ # d_theta: All timestep's gradient for theta is accumulated (added) into
+ # d_theta.
+    #  d_state1: The backprop-ed gradient for the new state computed by
+ # timestep t.
+ # d_inputs: d_inputs[t, :] is populated by the backward time step t.
+ # d_acc_state: The backprop-ed gradient for acc_state.
+ bakloop_sig = [
+ self._theta, self._state, self._inputs, self._state, self._extras,
+ self._theta, self._state, self._inputs, self._state
+ ]
+
+ @function.Defun(dtypes.int32, dev_t_type, *_Dtypes(bakloop_sig))
+ def BackwardLoopBody(*args):
+ """Backward loop body function."""
+ t, dev_t = args[0], args[1]
+ (theta, orig_state0, inputs, acc_state, acc_extras, d_theta, d_state1,
+ d_inputs, d_acc_state) = _Pack(args[2:], bakloop_sig)
+
+ # The input recurrent state for time step t is previous time step's
+ # output, or the original state0 when on time step 0.
+ state_from_acc = _Index(acc_state, math_ops.maximum(0, t - 1))
+ state0 = functional_ops.If(
+ math_ops.equal(t, array_ops.constant(0, dtypes.int32)),
+ _Flatten([state_from_acc, orig_state0]), ReturnOrigState0,
+ ReturnAccState)
+ state0 = nest.pack_sequence_as(orig_state0, state0)
+
+ # The external inputs for time step t.
+ inputs_t = _Index(inputs, t)
+ # The extras for time step t.
+ extras_t = _Index(acc_extras, t)
+
+ d_state1 = _Add(_Index(d_acc_state, t), d_state1)
+ (d_theta_t, d_state0, d_inputs_t) = _Pack(
+ Bak(*_Flatten([theta, state0, inputs_t, extras_t, d_state1])),
+ [self._theta, self._state, self._inputs])
+ d_theta = _Add(d_theta, d_theta_t)
+ d_inputs = _Update(d_inputs, d_inputs_t, dev_t)
+ return [math_ops.subtract(dev_t, 1)] + _Flatten([
+ theta, orig_state0, inputs, acc_state, acc_extras, d_theta, d_state0,
+ d_inputs, d_acc_state
+ ])
+
+ # Backward calls BackwardLoopBody n times. Each time computes the backprop
+ # for one time step of the recurrent net.
+ backward_sig = [
+ self._theta, self._state, self._inputs, self._max_input_length,
+ self._state, self._extras, self._state, self._state
+ ]
+
+ @function.Defun(*_Dtypes(backward_sig), noinline=noinline)
+ def Backward(*args):
+ """Backward pass for the recurrent net."""
+ # theta, state0, inputs are Forward's inputs.
+ # acc_state is the accumulated 1st output of Forward.
+ # acc_extras is the accumulated 2nd output of Forward.
+ # d_acc_state is the gradient for acc_state.
+ # d_state1 is the gradient for the final state computed by Forward.
+ (theta, state0, inputs, max_input_length, acc_state, acc_extras,
+ d_acc_state, d_state1) = _Pack(args, backward_sig)
+
+ # Accumulators for gradients.
+ d_theta = _EmptyLike(theta)
+ d_inputs = _EmptyLike(inputs)
+
+      # Loop backwards. Note the loop's limit is open-ended, so it goes
+      # through t=0.
+ t = max_input_length - 1
+ dev_t = math_ops.to_int32(t) if use_tpu else math_ops.to_int64(t)
+ run = functional_ops.For(
+ start=t,
+ limit=-1,
+ delta=-1,
+ inputs=[dev_t] + _Flatten([
+ theta, state0, inputs, acc_state, acc_extras, d_theta, d_state1,
+ d_inputs, d_acc_state
+ ]),
+ body=BackwardLoopBody,
+ rewrite_with_while=compiled)
+
+ (theta, state0, inputs, acc_state, acc_extras, d_theta, d_state0,
+ d_inputs, d_acc_state) = _Pack(run[1:], bakloop_sig)
+
+ d_max_input_length = array_ops.constant(0, dtype=max_input_length.dtype)
+ return _Flatten(
+ [d_theta, d_state0, d_inputs, d_max_input_length, acc_extras])
+
+ self._forward = Forward
+
+ def _MaybeComputeMaxInputLength(self, inputs, max_input_length):
+ if max_input_length is not None:
+ return max_input_length
+ return math_ops.reduce_max(array_ops.shape(nest.flatten(inputs)[0])[0])
+
+ def Compute(self):
+ return _Pack(
+ self._forward(*_Flatten([
+ self._theta, self._state, self._inputs, self._max_input_length,
+ self._extras
+ ])), [self._state, self._state, self._extras])[:2]
+
+
+def _GetCellGrad(cell_fn, cell_grad):
+ """Returns the gradient function for cell_fn.
+
+ Args:
+ cell_fn: The recurrent neural net's cell function.
+ cell_grad: If not None, cell_fn's gradient function.
+
+ Returns:
+    Returns cell_grad if not None. Otherwise, assumes cell_fn is a python
+ function representing the recurrent neural net's cell function, i.e.,
+ cell_fn: (theta, state0, inputs) -> (state1, extra)
+ returns its default gradient python function, i.e.,
+ cell_grad: (theta, state0, inputs, extras, dstate1) -> (
+ dtheta, dstate0, dinputs)
+ """
+
+ if cell_grad:
+ return cell_grad
+
+ def CellGrad(theta, state0, inputs, extras, dstate1):
+ """Default gradient function for cell_fn."""
+ # NOTE: The default grad function recomputes the forward
+ # function and does not take advantage of 'extras' returned by
+ # the forward function.
+ del extras
+ state1, extras = cell_fn(theta, state0, inputs)
+ ys = _Flatten([state1])
+ xs = _Flatten([theta, state0, inputs])
+ grad_ys = _Flatten([dstate1])
+ grads = gradients_impl.gradients(ys=ys, xs=xs, grad_ys=grad_ys)
+ return _ConvertNoneGradientToZeros([theta, state0, inputs],
+ _Pack(grads, [theta, state0, inputs]))
+
+ return CellGrad
+
+
+def _IsSingleTimeStep(inputs, max_input_length):
+ """Returns True only if the time dimension of inputs is 1."""
+ if not isinstance(max_input_length, ops.Tensor):
+ return max_input_length == 1
+ for x in nest.flatten(inputs):
+ if x.shape.dims is None or x.shape[0].value != 1:
+ return False
+ return True
+
+
+def Recurrent(theta,
+ state0,
+ inputs,
+ cell_fn,
+ cell_grad=None,
+ extras=None,
+ max_input_length=None,
+ use_tpu=False):
+ """Compute a recurrent neural net.
+
+ Roughly, Recurrent() computes the following:
+ state = state0
+ for t in inputs' sequence length:
+ state = cell_fn(theta, state, inputs[t, :])
+ accumulate_state[t, :] = state
+ return accumulate_state, state
+
+ theta, state, inputs are all structures of tensors.
+
+ inputs[t, :] means taking a slice out from every tensor in the inputs.
+
+ accumulate_state[t, :] = state means that we stash every tensor in
+ 'state' into a slice of the corresponding tensor in
+ accumulate_state.
+
+  cell_fn is a python callable that computes (builds up a TensorFlow
+  graph) one forward step of the recurrent neural network. Two calls of
+ cell_fn must describe two identical computations.
+
+ By construction, Recurrent()'s backward computation does not access
+ any intermediate values computed by cell_fn during forward
+ computation. We may extend Recurrent() to support that by taking a
+ customized backward function of cell_fn.
+
+ Args:
+ theta: weights. A structure of tensors.
+ state0: initial state. A structure of tensors.
+ inputs: inputs. A structure of tensors.
+ cell_fn: A python function, which computes:
+ state1, extras = cell_fn(theta, state0, inputs[t, :])
+ cell_grad: A python function which computes:
+ dtheta, dstate0, dinputs[t, :] = cell_grad(
+ theta, state0, inputs[t, :], extras, dstate1)
+ extras: A structure of tensors. The 2nd return value of every
+ invocation of cell_fn is a structure of tensors with matching keys
+ and shapes of this `extras`.
+ max_input_length: maximum length of effective input. This is used to
+ truncate the computation if the inputs have been allocated to a
+ larger size. A scalar tensor.
+ use_tpu: whether or not we are on TPU.
+
+ Returns:
+ accumulate_state and the final state.
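+
+  For example, a cell with a hand-written gradient can be run as follows (a
+  sketch; ElmanCell and ElmanCellGrad are hypothetical functions following
+  the cell_fn/cell_grad signatures above, see recurrent_test.py):
+
+    acc_state, final_state = Recurrent(
+        theta=theta, state0=state0, inputs=inputs,
+        cell_fn=ElmanCell, cell_grad=ElmanCellGrad)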
+ """
+ if cell_grad is None and _IsSingleTimeStep(inputs, max_input_length):
+    # The sequence length is statically known to be 1. Hence, we just need to
+ # call cell_fn once without putting it into a loop.
+ inputs = nest.map_structure(lambda x: array_ops.squeeze(x, axis=0), inputs)
+ state1, _ = cell_fn(theta, state0, inputs)
+ acc_state = nest.map_structure(lambda x: array_ops.expand_dims(x, axis=0),
+ state1)
+ return acc_state, state1
+
+ # If cell_grad is not given, derives the gradient function from
+ # cell_fn.
+ cell_grad = _GetCellGrad(cell_fn, cell_grad)
+
+ if extras is None:
+ # Derives 'extras' so that we can allocate extras' accumulator.
+ _, extras = cell_fn(theta, state0, _Index(inputs, 0))
+ extras = nest.map_structure(array_ops.zeros_like, extras)
+ else:
+ _, actual = cell_fn(theta, state0, _Index(inputs, 0))
+ _AssertIsCompatible(extras, actual)
+
+ return _Recurrent(
+ cell_fn=cell_fn,
+ cell_grad=cell_grad,
+ theta=theta,
+ state0=state0,
+ inputs=inputs,
+ max_input_length=max_input_length,
+ extras=extras,
+ use_tpu=use_tpu).Compute()
diff --git a/tensorflow/experimental_api.py b/tensorflow/contrib/recurrent/python/recurrent_api.py
index 63a8aa9cb1..ffe1dcf7dc 100644
--- a/tensorflow/experimental_api.py
+++ b/tensorflow/contrib/recurrent/python/recurrent_api.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,26 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
-# Bring in all of the public TensorFlow interface into this
-# module.
+"""Recurrent computations library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# pylint: disable=g-bad-import-order
-from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
-# pylint: disable=wildcard-import
-from tensorflow.tools.api.generator.api import * # pylint: disable=redefined-builtin
-# pylint: enable=wildcard-import
-
-from tensorflow.python.util.lazy_loader import LazyLoader
-contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
-del LazyLoader
-
-from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
-app.flags = flags # pylint: disable=undefined-variable
+# pylint: disable=unused-import
+from tensorflow.contrib.recurrent.python.ops.functional_rnn import bidirectional_functional_rnn
+from tensorflow.contrib.recurrent.python.ops.functional_rnn import functional_rnn
+from tensorflow.contrib.recurrent.python.ops.recurrent import Recurrent
+# pylint: enable=unused-import
del absolute_import
del division
diff --git a/tensorflow/contrib/rpc/BUILD b/tensorflow/contrib/rpc/BUILD
new file mode 100644
index 0000000000..597f18c771
--- /dev/null
+++ b/tensorflow/contrib/rpc/BUILD
@@ -0,0 +1,13 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "rpc",
+ srcs = [
+ "__init__.py",
+ ],
+ deps = ["//tensorflow/contrib/rpc/python/ops:rpc_op_py"],
+)
diff --git a/tensorflow/contrib/rpc/__init__.py b/tensorflow/contrib/rpc/__init__.py
new file mode 100644
index 0000000000..c65c1a05de
--- /dev/null
+++ b/tensorflow/contrib/rpc/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops and modules related to RPC.
+
+@@rpc
+@@try_rpc
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.rpc.python.ops.rpc_op import rpc
+from tensorflow.contrib.rpc.python.ops.rpc_op import try_rpc
+
+from tensorflow.python.util.all_util import remove_undocumented
+remove_undocumented(__name__)
diff --git a/tensorflow/contrib/rpc/python/ops/BUILD b/tensorflow/contrib/rpc/python/ops/BUILD
new file mode 100644
index 0000000000..84d2a1832f
--- /dev/null
+++ b/tensorflow/contrib/rpc/python/ops/BUILD
@@ -0,0 +1,24 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
+
+py_library(
+ name = "rpc_op_py",
+ srcs = ["rpc_op.py"],
+ deps = [
+ ":gen_rpc_op_py",
+ "//tensorflow/python:framework_ops",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_rpc_op_py",
+ out = "gen_rpc_op.py",
+ deps = [
+ "//tensorflow/core:rpc_ops_op_lib",
+ ],
+)
diff --git a/tensorflow/contrib/rpc/python/ops/rpc_op.py b/tensorflow/contrib/rpc/python/ops/rpc_op.py
new file mode 100644
index 0000000000..e1b6c41137
--- /dev/null
+++ b/tensorflow/contrib/rpc/python/ops/rpc_op.py
@@ -0,0 +1,26 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+# pylint: disable=wildcard-import,unused-import
+"""RPC communication."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.rpc.python.ops.gen_rpc_op import rpc
+from tensorflow.contrib.rpc.python.ops.gen_rpc_op import try_rpc
+from tensorflow.python.framework import ops
+ops.NotDifferentiable("Rpc")
+ops.NotDifferentiable("TryRpc")
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 4de09dd988..2f4a76720d 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -47,6 +47,7 @@ py_library(
":tpu_lib",
":tpu_py",
"//tensorflow/contrib/summary:summary_ops",
+ "//tensorflow/contrib/training:training_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 3f2db548ac..a1690dadff 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -25,6 +25,8 @@ from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.framework import device as pydev
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -56,6 +58,7 @@ _NOT_IMPLEMENTED_OPS = set([
_MAX_WARNING_LINES = 5
_TPU_REPLICATE_ATTR = "_tpu_replicate"
+_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation"
def _tpu_system_device_name(job):
@@ -121,8 +124,16 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
outside the replicated computation.
"""
- def __init__(self, name):
+ def __init__(self, name, num_replicas):
super(TPUReplicateContext, self).__init__()
+ self._num_replicas = num_replicas
+ self._outer_device_function_stack = None
+ self._oc_dev_fn_stack = None
+ self._outside_compilation_cluster = None
+ self._outside_compilation_counter = 0
+ self._in_gradient_colocation = None
+ self._gradient_colocation_stack = []
+ self._host_compute_core = []
self._name = name
self._unsupported_ops = []
@@ -136,6 +147,143 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
logging.warning("... and %d more" %
(len(self._unsupported_ops) - _MAX_WARNING_LINES))
+ def EnterGradientColocation(self, op, gradient_uid):
+ if op is not None:
+ self._gradient_colocation_stack.append(op)
+ if not self._outside_compilation_cluster:
+ try:
+ outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR)
+ if self._in_gradient_colocation:
+ raise NotImplementedError(
+ "Cannot nest gradient colocation operations outside compilation"
+ )
+ if gradient_uid == "__unsupported__":
+ raise NotImplementedError(
+ "No gradient_uid calling gradient within outside_compilation")
+ # When we take the gradient of an op X in an
+ # outside_compilation cluster C in a forward computation we
+ # would like to put the ops corresponding to the gradient of
+ # X into a new outside_compilation cluster C'. However, if
+ # we take the gradient of X twice, the second one should get
+ # yet another new outside_compilation cluster C''.
+ #
+ # The mechanism we adopt is to use a 'root_cluster' which is
+ # the cluster that X was in before we took gradients, and a
+ # 'gradient_uid' which is different for every invocation of
+ # gradients, and put the gradient of X in cluster
+ # 'root_cluster.gradient_uid'.
+ #
+ # When the gradient code adds multiple Ops, it asks them to
+ # be colocated either with the original Op X, or with one of
+ # the preceding Ops that was added to the gradient. In other
+ # words, we want to detect the case where we are colocating
+ # with an Op that is in cluster root_cluster.gradient_uid
+ # and put the new Op in that same cluster if the
+ # gradient_uid is the same (the case that we are in the same
+ # invocation of gradients, and just adding new Ops to the
+ # cluster); and in a different cluster if the gradient_uids
+ # are different (the case that we are in a new invocation of
+ # gradients, taking the gradient of a previously-computed
+ # gradient).
+ self._in_gradient_colocation = op
+ parts = outside_attr.split(".")
+ if len(parts) > 1:
+ uid = parts[-1]
+ if uid == gradient_uid:
+ # Keep using the same cluster
+ cluster = outside_attr
+ else:
+ # We're taking the gradient of a gradient so make a new
+ # cluster attr, adding a new '.uid' on the end to
+ # preserve the invariant that the gradient_uid is the
+ # suffix after the last '.' in the attr.
+ cluster = outside_attr + "." + gradient_uid
+ else:
+ # We're taking the gradient of an Op in the forward pass, so
+ # make a new cluster combining the Op's cluster and the
+ # gradient id.
+ cluster = outside_attr + "." + gradient_uid
+ self._EnterOutsideCompilationScope(cluster=cluster)
+ except ValueError:
+ # The attr was not present: do nothing.
+ pass
+
+ def ExitGradientColocation(self, op, gradient_uid):
+ if op is not None:
+ if not self._gradient_colocation_stack:
+ raise errors.InternalError(
+ op.node_def, op,
+ "Badly nested gradient colocation: empty stack when popping Op " +
+ op.name)
+ last_op = self._gradient_colocation_stack.pop()
+ if op is last_op:
+ if op is self._in_gradient_colocation:
+ self._in_gradient_colocation = None
+ self._ExitOutsideCompilationScope()
+ else:
+ raise errors.InternalError(
+ op.node_def, op, "Badly nested gradient colocation, expected " +
+            last_op.name + ", got " + op.name)
+
+ def _EnterOutsideCompilationScope(self, cluster=None):
+
+ class FakeOp(object):
+ """A helper class to determine the current device.
+
+ Supports only the device set/get methods needed to run the
+ graph's _apply_device_function method.
+ """
+
+ def __init__(self):
+ self._device = ""
+
+ @property
+ def device(self):
+ return self._device
+
+ def _set_device(self, device):
+ self._device = device.to_string()
+
+ if self._outside_compilation_cluster:
+ raise NotImplementedError("Cannot nest outside_compilation clusters")
+ if cluster:
+ self._outside_compilation_cluster = cluster
+ else:
+ self._outside_compilation_cluster = str(self._outside_compilation_counter)
+ self._outside_compilation_counter += 1
+ graph = ops.get_default_graph()
+ fake_op = FakeOp()
+ graph._apply_device_functions(fake_op) # pylint: disable=protected-access
+ device = pydev.DeviceSpec.from_string(fake_op.device)
+ if (device.device_type == "TPU_REPLICATED_CORE" and
+ device.device_index is not None):
+ self._host_compute_core.append(self._outside_compilation_cluster + ":" +
+ str(device.device_index))
+ self._oc_dev_fn_stack = graph._device_function_stack # pylint: disable=protected-access
+ graph._device_function_stack = self._outer_device_function_stack # pylint: disable=protected-access
+
+ def _ExitOutsideCompilationScope(self):
+ if not self._outside_compilation_cluster:
+ raise NotImplementedError(
+ "Attempted to exit outside_compilation scope when not in scope")
+ self._outside_compilation_cluster = None
+ graph = ops.get_default_graph()
+ graph._device_function_stack = self._oc_dev_fn_stack # pylint: disable=protected-access
+
+ def Enter(self):
+ if not self._outer_device_function_stack:
+ # Capture the device function stack at the time of first entry
+ # since that is the stack that will be used outside_compilation.
+ graph = ops.get_default_graph()
+ self._outer_device_function_stack = list(graph._device_function_stack) # pylint: disable=protected-access
+ super(TPUReplicateContext, self).Enter()
+
+ def Exit(self):
+ super(TPUReplicateContext, self).Exit()
+
+ def HostComputeCore(self):
+ return self._host_compute_core
+
def AddOp(self, op):
self._AddOpInternal(op)
@@ -157,9 +305,16 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
raise ValueError("TPU computations cannot be nested")
op._set_attr(_TPU_REPLICATE_ATTR,
attr_value_pb2.AttrValue(s=compat.as_bytes(self._name)))
- # pylint: enable=protected-access
- op.graph.prevent_feeding(op)
- op.graph.prevent_fetching(op)
+ if self._outside_compilation_cluster:
+ op._set_attr(
+ _OUTSIDE_COMPILATION_ATTR,
+ attr_value_pb2.AttrValue(
+ s=compat.as_bytes(self._outside_compilation_cluster)))
+ if self._num_replicas > 1 or not self._outside_compilation_cluster:
+ # Prevent feeding or fetching anything that is being compiled,
+ # and any replicated outside_compilation Op.
+ op.graph.prevent_feeding(op)
+ op.graph.prevent_fetching(op)
def AddValue(self, val):
result = val
@@ -181,6 +336,45 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
return None
+def outside_compilation(computation, args=None):
+ """Builds part of a computation outside any current TPU replicate scope.
+
+ Args:
+ computation: A Python function that builds the computation to
+ place on the host.
+ args: Inputs to pass to computation.
+ Returns:
+ The Tensors returned by computation.
+ """
+ graph = ops.get_default_graph()
+
+ # If we are in a TPUReplicateContext, signal that we are now
+ # outside_compilation
+ initial_context = graph._get_control_flow_context() # pylint: disable=protected-access
+ context = initial_context
+ while context:
+ if isinstance(context, TPUReplicateContext):
+ context._EnterOutsideCompilationScope() # pylint: disable=protected-access
+ context = context.outer_context
+
+  if args is None:
+    args = []
+  retval = computation(*args)
+
+ # If we are in a TPUReplicateContext, signal that we are no longer
+ # outside_compilation
+ final_context = graph._get_control_flow_context() # pylint: disable=protected-access
+ if initial_context is not final_context:
+ raise NotImplementedError(
+ "Control-flow context cannot be different at start and end of an "
+ "outside_compilation scope")
+ context = initial_context
+ while context:
+ if isinstance(context, TPUReplicateContext):
+ context._ExitOutsideCompilationScope() # pylint: disable=protected-access
+ context = context.outer_context
+
+ return retval
+
+
def replicate(computation,
inputs=None,
infeed_queue=None,
@@ -280,7 +474,8 @@ def replicate(computation,
computation_inputs.append(
tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))
- context = TPUReplicateContext(name=graph.unique_name("cluster"))
+ context = TPUReplicateContext(
+ name=graph.unique_name("cluster"), num_replicas=num_replicas)
try:
context.Enter()
@@ -361,6 +556,12 @@ def replicate(computation,
finally:
context.report_unsupported_operations()
context.Exit()
+ host_compute_core = context.HostComputeCore()
+
+ if host_compute_core:
+ attr_value = attr_value_pb2.AttrValue()
+ attr_value.list.s.extend([compat.as_bytes(x) for x in host_compute_core])
+ metadata._set_attr("host_compute_core", attr_value) # pylint: disable=protected-access
# Fan-out: Builds a TPUReplicatedOutput node for each output.
outputs = [tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas,
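To make the new scope handling concrete, here is a minimal sketch of driving `outside_compilation` from inside a replicated computation. It assumes a TPU runtime is available and uses `tf.Print` purely as a stand-in for a host-only op; the input values are placeholders for illustration.

```python
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu

def host_log(x):
  # Built while the TPUReplicateContext is in an outside_compilation scope,
  # so this op is tagged with _xla_outside_compilation and runs on the host.
  return tf.Print(x, [x], message='seen on host: ')

def computation(x):
  y = x * 2.0
  y = tpu.outside_compilation(host_log, args=(y,))
  return y + 1.0

# One replica, one per-replica input tensor.
outputs = tpu.replicate(computation, inputs=[[tf.constant([1.0, 2.0])]])
```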
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 6834600b79..1332108d04 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -38,6 +38,8 @@ from tensorflow.contrib.tpu.python.tpu import tpu_context
from tensorflow.contrib.tpu.python.tpu import tpu_feed
from tensorflow.contrib.tpu.python.tpu import training_loop
from tensorflow.contrib.tpu.python.tpu import util as util_lib
+from tensorflow.contrib.training.python.training import hparam
+from tensorflow.core.framework import variable_pb2
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
@@ -53,6 +55,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
@@ -73,6 +76,8 @@ _ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop'
_BATCH_SIZE_KEY = 'batch_size'
_CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'
_ONE_GIGABYTE = 1024 * 1024 * 1024
+_TPU_ENQUEUE_OPS = '_tpu_enqueue_ops'
+_TPU_TRAIN_OP = '_tpu_train_op'
_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY]
@@ -85,6 +90,13 @@ _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY]
_WRAP_INPUT_FN_INTO_WHILE_LOOP = False
+ops.register_proto_function(
+ '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR),
+ proto_type=variable_pb2.VariableDef,
+ to_proto=resource_variable_ops._to_proto_fn, # pylint: disable=protected-access
+ from_proto=resource_variable_ops._from_proto_fn) # pylint: disable=protected-access
+
+
def _create_global_step(graph):
graph = graph or ops.get_default_graph()
if training.get_global_step(graph) is not None:
@@ -1297,7 +1309,10 @@ class _ModelFnWrapper(object):
batch_size_for_model_fn = self._ctx.batch_size_for_model_fn
if batch_size_for_model_fn is not None:
- params[_BATCH_SIZE_KEY] = batch_size_for_model_fn
+ if isinstance(params, hparam.HParams):
+ params.add_hparam(_BATCH_SIZE_KEY, batch_size_for_model_fn)
+ else:
+ params[_BATCH_SIZE_KEY] = batch_size_for_model_fn
estimator_spec = self._model_fn(features=features, **kwargs)
if (self._ctx.is_running_on_cpu(is_export_mode) and
@@ -1936,7 +1951,10 @@ class TPUEstimator(estimator_lib.Estimator):
# input_fn for use_tpu=True/False.
batch_size_for_input_fn = ctx.batch_size_for_input_fn
if batch_size_for_input_fn is not None:
- kwargs['params'][_BATCH_SIZE_KEY] = batch_size_for_input_fn
+ if isinstance(kwargs['params'], hparam.HParams):
+ kwargs['params'].add_hparam(_BATCH_SIZE_KEY, batch_size_for_input_fn)
+ else:
+ kwargs['params'][_BATCH_SIZE_KEY] = batch_size_for_input_fn
# For export_savedmodel, input_fn is never passed to Estimator. So,
# `is_export_mode` must be False.
@@ -2006,6 +2024,13 @@ class TPUEstimator(estimator_lib.Estimator):
enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = (
input_holders.generate_infeed_enqueue_ops_and_dequeue_fn())
+ graph = ops.get_default_graph()
+ for enqueue_op in enqueue_ops:
+ if isinstance(enqueue_op, list):
+ graph.get_collection_ref(_TPU_ENQUEUE_OPS).extend(enqueue_op)
+ else:
+ graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op)
+
if mode == model_fn_lib.ModeKeys.TRAIN:
loss, host_call, scaffold = (
_train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))
@@ -2036,11 +2061,14 @@ class TPUEstimator(estimator_lib.Estimator):
# Validate the TPU training graph to catch basic errors
_validate_tpu_training_graph()
+ train_op = control_flow_ops.group(*update_ops)
+ graph.add_to_collection(_TPU_TRAIN_OP, train_op)
+
return model_fn_lib.EstimatorSpec(
mode,
loss=loss,
training_hooks=hooks,
- train_op=control_flow_ops.group(*update_ops),
+ train_op=train_op,
scaffold=scaffold)
if mode == model_fn_lib.ModeKeys.EVAL:
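The two hunks above make the batch-size injection work for both plain dicts and `HParams` objects. A small sketch of the receiving side, under the assumption that the caller only needs the injected value back out:

```python
# Sketch: reading the injected per-shard batch size regardless of how
# `params` was passed (plain dict or tf.contrib.training HParams).
from tensorflow.contrib.training.python.training import hparam

def read_batch_size(params):
  if isinstance(params, hparam.HParams):
    return params.batch_size      # injected via params.add_hparam(...)
  return params['batch_size']     # injected via dict assignment

print(read_batch_size({'batch_size': 128}))              # 128
print(read_batch_size(hparam.HParams(batch_size=128)))   # 128
```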
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_test.py
index 336d8260c3..c3882b8a27 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_test.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_test.py
@@ -37,7 +37,7 @@ class TPUContextTest(test.TestCase):
def testIsInContext(self):
"""Test that control_flow_util can check that we're in a TPU context."""
z1 = array_ops.identity(1)
- context = tpu.TPUReplicateContext(b"context")
+ context = tpu.TPUReplicateContext(b"context", 1)
context.Enter()
z2 = array_ops.identity(1)
context.Exit()
diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py
index 95e051e3b5..185f70a86d 100644
--- a/tensorflow/contrib/training/python/training/hparam.py
+++ b/tensorflow/contrib/training/python/training/hparam.py
@@ -630,6 +630,9 @@ class HParams(object):
def __str__(self):
return str(sorted(self.values().items()))
+ def __repr__(self):
+ return '%s(%s)' % (type(self).__name__, self.__str__())
+
@staticmethod
def _get_kind_name(param_type, is_list):
"""Returns the field name given parameter type and is_list.
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 7d5ae1c5b5..c5ca421ced 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -256,7 +256,7 @@ tf_proto_library(
visibility = ["//visibility:public"],
)
-# Minimal lib to detect plafrom
+# Minimal lib to detect platform
cc_library(
name = "lib_platform",
hdrs = [
@@ -264,6 +264,55 @@ cc_library(
],
)
+PLATFORM_BASE_HDRS = [
+ "platform/logging.h",
+ "platform/macros.h",
+ "platform/types.h",
+ "platform/cpu_info.h",
+]
+
+PLATFORM_OTHER_HDRS = [
+ "platform/abi.h",
+ "platform/stacktrace.h",
+ "platform/stacktrace_handler.h",
+ "platform/context.h",
+ "platform/cpu_feature_guard.h",
+ "platform/dynamic_annotations.h",
+ "platform/env.h",
+ "platform/env_time.h",
+ "platform/file_system.h",
+ "platform/file_system_helper.h",
+ "platform/fingerprint.h",
+ "platform/init_main.h",
+ "platform/mem.h",
+ "platform/mutex.h",
+ "platform/net.h",
+ "platform/notification.h",
+ "platform/null_file_system.h",
+ "platform/prefetch.h",
+ "platform/profile_utils/clock_cycle_profiler.h",
+ "platform/profile_utils/cpu_utils.h",
+ "platform/protobuf.h",
+ "platform/strong_hash.h",
+ "platform/subprocess.h",
+ "platform/thread_annotations.h",
+]
+
+# Smaller platform libraries that don't depend on "lib" or "lib_internal".
+cc_library(
+ name = "platform_base",
+ srcs = glob([
+ "platform/*/integral_types.h",
+ "platform/*/logging.h",
+ "platform/*/cpu_info.h",
+ ]),
+ hdrs = PLATFORM_BASE_HDRS,
+ deps = [
+ ":lib_platform",
+ "//tensorflow/core/platform/default/build_config:base",
+ ],
+)
+
# Minimal lib so that tools used for mobile compilation
# don't have to depend on lib/platformlib.
cc_library(
@@ -294,7 +343,8 @@ cc_library(
# tf_cc_test and tf_cc_binary will include the necessary symbols.
cc_library(
name = "lib",
- hdrs = [
+ hdrs = PLATFORM_BASE_HDRS +
+ PLATFORM_OTHER_HDRS + [
"lib/bfloat16/bfloat16.h",
"lib/core/arena.h",
"lib/core/bitmap.h",
@@ -341,34 +391,6 @@ cc_library(
"lib/strings/str_util.h",
"lib/strings/strcat.h",
"lib/strings/stringprintf.h",
- "platform/abi.h",
- "platform/context.h",
- "platform/cpu_feature_guard.h",
- "platform/cpu_info.h",
- "platform/dynamic_annotations.h",
- "platform/env.h",
- "platform/env_time.h",
- "platform/file_system.h",
- "platform/file_system_helper.h",
- "platform/fingerprint.h",
- "platform/init_main.h",
- "platform/logging.h",
- "platform/macros.h",
- "platform/mem.h",
- "platform/mutex.h",
- "platform/net.h",
- "platform/notification.h",
- "platform/null_file_system.h",
- "platform/prefetch.h",
- "platform/profile_utils/clock_cycle_profiler.h",
- "platform/profile_utils/cpu_utils.h",
- "platform/protobuf.h",
- "platform/stacktrace.h",
- "platform/strong_hash.h",
- "platform/subprocess.h",
- "platform/thread_annotations.h",
- "platform/types.h",
- "platform/windows/cpu_info.h",
],
visibility = ["//visibility:public"],
deps = [
@@ -415,6 +437,17 @@ cc_library(
],
)
+# Libraries that will eventually be moved into lib/core
+# Note that stringpiece_test can't be placed here yet, because we are
+# required to use tf_cc_test, and that rule will change / into _
+cc_library(
+ name = "core_stringpiece",
+ srcs = ["lib/core/stringpiece.cc"],
+ hdrs = ["lib/core/stringpiece.h"],
+ copts = tf_copts(),
+ deps = [":platform_base"],
+)
+
# Test support library needed for all tests
# This is currently public, but may be made internal in the
# future. Try to avoid depending on it.
@@ -442,6 +475,27 @@ cc_library(
] + tf_additional_test_deps(),
)
+# Testing libraries - lite versions that don't depend on all of "lib" or
+# "lib_internal". Instead, they only need a much smaller set of support
+# libraries such as ":platform_base" and ":core_stringpiece".
+cc_library(
+ name = "test_lite",
+ testonly = 1,
+ srcs = [
+ "platform/test.cc",
+ ],
+ hdrs = [
+ "platform/test.h",
+ "platform/test_benchmark.h",
+ ],
+ copts = tf_copts(),
+ deps = [
+ ":lib_platform",
+ ":platform_base",
+ "//tensorflow/core/platform/default/build_config:gtest",
+ ],
+)
+
# This build rule (along with :framework_internal, :lib, and :lib_internal)
# purposefully omits the definitions of many declared symbols, which are
# included in //tensorflow:libtensorflow_framework.so. Using tf_cc_test and tf_cc_binary
@@ -499,7 +553,6 @@ tf_cuda_library(
"framework/type_index.h",
"framework/type_traits.h",
"framework/types.h",
- "framework/visitable_allocator.h",
"public/version.h",
"util/activation_mode.h",
"util/bcast.h",
@@ -633,10 +686,13 @@ tf_gen_op_libs(
"boosted_trees_ops",
"candidate_sampling_ops",
"checkpoint_ops",
+ "collective_ops",
"control_flow_ops",
"ctc_ops",
"data_flow_ops",
"dataset_ops",
+ "decode_proto_ops",
+ "encode_proto_ops",
"function_ops",
"functional_ops",
"image_ops",
@@ -653,6 +709,7 @@ tf_gen_op_libs(
"random_ops",
"remote_fused_graph_ops",
"resource_variable_ops",
+ "rpc_ops",
"scoped_allocator_ops",
"sdca_ops",
"set_ops",
@@ -746,11 +803,14 @@ cc_library(
":boosted_trees_ops_op_lib",
":candidate_sampling_ops_op_lib",
":checkpoint_ops_op_lib",
+ ":collective_ops_op_lib",
":control_flow_ops_op_lib",
":ctc_ops_op_lib",
":cudnn_rnn_ops_op_lib",
":data_flow_ops_op_lib",
":dataset_ops_op_lib",
+ ":decode_proto_ops_op_lib",
+ ":encode_proto_ops_op_lib",
":function_ops_op_lib",
":functional_ops_op_lib",
":image_ops_op_lib",
@@ -767,6 +827,7 @@ cc_library(
":random_ops_op_lib",
":remote_fused_graph_ops_op_lib",
":resource_variable_ops_op_lib",
+ ":rpc_ops_op_lib",
":scoped_allocator_ops_op_lib",
":script_ops_op_lib",
":sdca_ops_op_lib",
@@ -888,11 +949,14 @@ cc_library(
"//tensorflow/core/kernels:boosted_trees_ops",
"//tensorflow/core/kernels:candidate_sampler_ops",
"//tensorflow/core/kernels:checkpoint_ops",
+ "//tensorflow/core/kernels:collective_ops",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:ctc_ops",
"//tensorflow/core/kernels:cudnn_rnn_kernels",
"//tensorflow/core/kernels:data_flow",
"//tensorflow/core/kernels:dataset_ops",
+ "//tensorflow/core/kernels:decode_proto_op",
+ "//tensorflow/core/kernels:encode_proto_op",
"//tensorflow/core/kernels:fake_quant_ops",
"//tensorflow/core/kernels:function_ops",
"//tensorflow/core/kernels:functional_ops",
@@ -914,6 +978,7 @@ cc_library(
"//tensorflow/core/kernels:remote_fused_graph_ops",
"//tensorflow/core/kernels:required",
"//tensorflow/core/kernels:resource_variable_ops",
+ "//tensorflow/core/kernels:rpc_op",
"//tensorflow/core/kernels:scoped_allocator_ops",
"//tensorflow/core/kernels:sdca_ops",
"//tensorflow/core/kernels:set_kernels",
@@ -1641,6 +1706,7 @@ cc_library(
exclude = [
"**/*test*",
"framework/variant.cc",
+ "lib/core/stringpiece.cc",
"lib/hash/crc32c_accelerate.cc",
"lib/gif/**/*",
"lib/jpeg/**/*",
@@ -1654,6 +1720,7 @@ cc_library(
) + tf_additional_lib_srcs(
exclude = [
"**/*test*",
+ "lib/core/stringpiece.cc",
"platform/**/cuda.h",
"platform/**/cuda_libdevice_path.cc",
"platform/**/stream_executor.h",
@@ -1674,6 +1741,7 @@ cc_library(
":lib_hash_crc32c_accelerate_internal",
":lib_proto_parsing",
":abi",
+ ":core_stringpiece",
"//third_party/eigen3",
"//tensorflow/core/platform/default/build_config:platformlib",
"@snappy",
@@ -1905,7 +1973,6 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [
"framework/tracking_allocator.h", # only needed for tests
"framework/unique_tensor_references.h",
"framework/variant.h",
- "framework/visitable_allocator.h",
"platform/variant_coding.h",
"util/command_line_flags.h",
"util/env_var.h",
@@ -2183,17 +2250,17 @@ tf_cuda_library(
CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/allocator_retry.h",
"common_runtime/bfc_allocator.h",
+ "common_runtime/buf_rendezvous.h",
+ "common_runtime/build_graph_options.h",
"common_runtime/collective_executor_mgr.h",
"common_runtime/collective_param_resolver_local.h",
"common_runtime/collective_rma_local.h",
- "common_runtime/device_resolver_local.h",
- "common_runtime/buf_rendezvous.h",
- "common_runtime/build_graph_options.h",
"common_runtime/constant_folding.h",
"common_runtime/copy_tensor.h",
"common_runtime/costmodel_manager.h",
"common_runtime/debugger_state_interface.h",
"common_runtime/device_factory.h",
+ "common_runtime/device_resolver_local.h",
"common_runtime/device_set.h",
"common_runtime/dma_helper.h",
"common_runtime/eigen_thread_pool.h",
@@ -2204,6 +2271,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/mkl_cpu_allocator.h",
"common_runtime/optimization_registry.h",
"common_runtime/pending_counts.h",
+ "common_runtime/placer.h",
"common_runtime/process_util.h",
"common_runtime/profile_handler.h",
"common_runtime/renamed_device.h",
@@ -2212,10 +2280,11 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/scoped_allocator.h",
"common_runtime/scoped_allocator_mgr.h",
"common_runtime/session_factory.h",
- "common_runtime/placer.h",
+ "common_runtime/single_threaded_cpu_device.h",
"common_runtime/stats_publisher_interface.h",
"common_runtime/step_stats_collector.h",
"common_runtime/threadpool_device.h",
+ "common_runtime/visitable_allocator.h",
"graph/gradients.h",
"graph/quantize_training.h",
] + if_mkl(["graph/mkl_graph_util.h"])
@@ -2617,6 +2686,23 @@ cc_library(
alwayslink = 1,
)
+# This is the lite version of a main() for tests. It does not include any
+# support for reporting benchmark results when running on TPUs.
+cc_library(
+ name = "test_lite_main",
+ testonly = 1,
+ srcs = ["platform/test_main.cc"],
+ copts = tf_copts(),
+ deps = [
+ ":core_stringpiece",
+ ":lib_platform",
+ ":stacktrace_handler",
+ ":test_lite",
+ "//tensorflow/core/platform/default/build_config:test_lite_main",
+ ],
+ alwayslink = 1,
+)
+
tf_cc_tests(
name = "low_level_library_tests",
size = "small",
diff --git a/tensorflow/core/api_def/base_api/api_def_CollectiveBcastRecv.pbtxt b/tensorflow/core/api_def/base_api/api_def_CollectiveBcastRecv.pbtxt
new file mode 100644
index 0000000000..88049bca36
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_CollectiveBcastRecv.pbtxt
@@ -0,0 +1,5 @@
+op {
+ graph_op_name: "CollectiveBcastRecv"
+ visibility: SKIP
+ summary: "Receives a tensor value broadcast from another device."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_CollectiveBcastSend.pbtxt b/tensorflow/core/api_def/base_api/api_def_CollectiveBcastSend.pbtxt
new file mode 100644
index 0000000000..7ff70f5b17
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_CollectiveBcastSend.pbtxt
@@ -0,0 +1,5 @@
+op {
+ graph_op_name: "CollectiveBcastSend"
+ visibility: SKIP
+ summary: "Broadcasts a tensor value to one or more other devices."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_CollectiveReduce.pbtxt b/tensorflow/core/api_def/base_api/api_def_CollectiveReduce.pbtxt
new file mode 100644
index 0000000000..10d9771d46
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_CollectiveReduce.pbtxt
@@ -0,0 +1,5 @@
+op {
+ graph_op_name: "CollectiveReduce"
+ visibility: SKIP
+ summary: "Mutually reduces multiple tensors of identical type and shape."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt
new file mode 100644
index 0000000000..c8152f53c4
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt
@@ -0,0 +1,116 @@
+op {
+ graph_op_name: "DecodeProtoV2"
+ in_arg {
+ name: "bytes"
+ description: <<END
+Tensor of serialized protos with shape `batch_shape`.
+END
+ }
+ out_arg {
+ name: "sizes"
+ description: <<END
+Tensor of int32 with shape `[batch_shape, len(field_names)]`.
+Each entry is the number of values found for the corresponding field.
+Optional fields may have 0 or 1 values.
+END
+ }
+ out_arg {
+ name: "values"
+ description: <<END
+List of tensors containing values for the corresponding field.
+`values[i]` has datatype `output_types[i]`
+and shape `[batch_shape, max(sizes[...,i])]`.
+END
+ }
+ attr {
+ name: "message_type"
+ description: <<END
+Name of the proto message type to decode.
+END
+ }
+ attr {
+ name: "field_names"
+ description: <<END
+List of strings containing proto field names.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+List of TF types to use for the respective field in field_names.
+END
+ }
+ attr {
+ name: "descriptor_source"
+ description: <<END
+Either the special value `local://` or a path to a file containing
+a serialized `FileDescriptorSet`.
+END
+ }
+ attr {
+ name: "message_format"
+ description: <<END
+Either `binary` or `text`.
+END
+ }
+ attr {
+ name: "sanitize"
+ description: <<END
+Whether to sanitize the result or not.
+END
+ }
+ summary: <<END
+The op extracts fields from a serialized protocol buffers message into tensors.
+END
+ description: <<END
+The `decode_proto` op extracts fields from a serialized protocol buffers
+message into tensors. The fields in `field_names` are decoded and converted
+to the corresponding `output_types` if possible.
+
+A `message_type` name must be provided to give context for the field
+names. The actual message descriptor can be looked up either in the
+linked-in descriptor pool or a filename provided by the caller using
+the `descriptor_source` attribute.
+
+Each output tensor is a dense tensor. This means that it is padded to
+hold the largest number of repeated elements seen in the input
+minibatch. (The shape is also padded by one to prevent zero-sized
+dimensions). The actual repeat counts for each example in the
+minibatch can be found in the `sizes` output. In many cases the output
+of `decode_proto` is fed immediately into tf.squeeze if missing values
+are not a concern. When using tf.squeeze, always pass the squeeze
+dimension explicitly to avoid surprises.
+
+For the most part, the mapping between Proto field types and
+TensorFlow dtypes is straightforward. However, there are a few
+special cases:
+
+- A proto field that contains a submessage or group can only be converted
+to `DT_STRING` (the serialized submessage). This is to reduce the
+complexity of the API. The resulting string can be used as input
+to another instance of the decode_proto op.
+
+- TensorFlow lacks support for unsigned integers. The ops represent uint64
+types as a `DT_INT64` with the same twos-complement bit pattern
+(the obvious way). Unsigned int32 values can be represented exactly by
+specifying type `DT_INT64`, or using twos-complement if the caller
+specifies `DT_INT32` in the `output_types` attribute.
+
+The `descriptor_source` attribute selects a source of protocol
+descriptors to consult when looking up `message_type`. This may be a
+filename containing a serialized `FileDescriptorSet` message,
+or the special value `local://`, in which case only descriptors linked
+into the code will be searched; the filename can be on any filesystem
+accessible to TensorFlow.
+
+You can build a `descriptor_source` file using the `--descriptor_set_out`
+and `--include_imports` options to the protocol compiler `protoc`.
+
+The `local://` database only covers descriptors linked into the
+code via C++ libraries, not Python imports. You can link in a proto descriptor
+by creating a cc_library target with alwayslink=1.
+
+Both binary and text proto serializations are supported, and can be
+chosen using the `message_format` attribute.
+END
+}
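A hedged sketch of the decode path described above, pulling a single int32 field out of a batch of serialized messages. The Python module path and the message/field names are assumptions made for illustration; only the attribute names come from the op definition:

```python
import tensorflow as tf
from tensorflow.contrib.proto.python.ops import decode_proto_op  # assumed path

serialized = tf.placeholder(tf.string, shape=[None])  # batch_shape = [None]
sizes, values = decode_proto_op.decode_proto(
    serialized,
    message_type='mypackage.MyMessage',   # hypothetical message name
    field_names=['my_int_field'],         # hypothetical field name
    output_types=[tf.int32])
# sizes: int32 [batch, 1]; values[0]: int32 [batch, max repeat count].
my_ints = tf.squeeze(values[0], axis=-1)  # pass the axis explicitly, as advised
```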
diff --git a/tensorflow/core/api_def/base_api/api_def_DeepCopy.pbtxt b/tensorflow/core/api_def/base_api/api_def_DeepCopy.pbtxt
new file mode 100644
index 0000000000..fe0fc3823f
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_DeepCopy.pbtxt
@@ -0,0 +1,15 @@
+op {
+ graph_op_name: "DeepCopy"
+ in_arg {
+ name: "x"
+ description: "The source tensor of type `T`."
+ }
+ out_arg {
+ name: "y"
+ description: <<END
+ y: A `Tensor` of type `T`. A copy of `x`. Guaranteed that `y`
+ is not an alias of `x`.
+END
+ }
+ summary: "Makes a copy of `x`."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_Empty.pbtxt b/tensorflow/core/api_def/base_api/api_def_Empty.pbtxt
new file mode 100644
index 0000000000..746f561e92
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_Empty.pbtxt
@@ -0,0 +1,23 @@
+op {
+ graph_op_name: "Empty"
+ in_arg {
+ name: "shape"
+ description: "1-D. Represents the shape of the output tensor."
+ }
+ attr {
+ name: "init"
+ description:
+ "If True, initialize the returned tensor with the default value "
+ "of dtype. Otherwise, the implementation is free not to initialize"
+ "the tensor's content."
+ }
+ out_arg {
+ name: "output"
+ description: "A `Tensor` of type `T`."
+ }
+ summary: <<END
+Creates a tensor with the given shape.
+
+This operation creates a tensor of `shape` and `dtype`.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_EncodeProto.pbtxt b/tensorflow/core/api_def/base_api/api_def_EncodeProto.pbtxt
new file mode 100644
index 0000000000..fdbe47f236
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_EncodeProto.pbtxt
@@ -0,0 +1,81 @@
+op {
+ graph_op_name: "EncodeProto"
+ in_arg {
+ name: "sizes"
+ description: <<END
+Tensor of int32 with shape `[batch_shape, len(field_names)]`.
+END
+ }
+ in_arg {
+ name: "values"
+ description: <<END
+List of tensors containing values for the corresponding field.
+END
+ }
+ out_arg {
+ name: "bytes"
+ description: <<END
+Tensor of serialized protos with shape `batch_shape`.
+END
+ }
+ attr {
+ name: "message_type"
+ description: <<END
+Name of the proto message type to decode.
+END
+ }
+ attr {
+ name: "field_names"
+ description: <<END
+List of strings containing proto field names.
+END
+ }
+ attr {
+ name: "Tinput_types"
+ description: <<END
+The input types.
+END
+ }
+ summary: <<END
+The op serializes protobuf messages provided in the input tensors.
+END
+ description: <<END
+The types of the tensors in `values` must match the schema for the
+fields specified in `field_names`. All the tensors in `values` must
+have a common shape prefix, *batch_shape*.
+
+The `sizes` tensor specifies repeat counts for each field. The repeat
+count (last dimension) of each tensor in `values` must be greater
+than or equal to corresponding repeat count in `sizes`.
+
+A `message_type` name must be provided to give context for the field
+names. The actual message descriptor can be looked up either in the
+linked-in descriptor pool or a filename provided by the caller using
+the `descriptor_source` attribute.
+
+The `descriptor_source` attribute selects a source of protocol
+descriptors to consult when looking up `message_type`. This may be a
+filename containing a serialized `FileDescriptorSet` message,
+or the special value `local://`, in which case only descriptors linked
+into the code will be searched; the filename can be on any filesystem
+accessible to TensorFlow.
+
+You can build a `descriptor_source` file using the `--descriptor_set_out`
+and `--include_imports` options to the protocol compiler `protoc`.
+
+The `local://` database only covers descriptors linked into the
+code via C++ libraries, not Python imports. You can link in a proto descriptor
+by creating a cc_library target with alwayslink=1.
+
+There are a few special cases in the value mapping:
+
+Submessage and group fields must be pre-serialized as TensorFlow strings.
+
+TensorFlow lacks support for unsigned int64s, so they must be
+represented as `tf.int64` with the same twos-complement bit pattern
+(the obvious way).
+
+Unsigned int32 values can be represented exactly with `tf.int64`, or
+with sign wrapping if the input is of type `tf.int32`.
+END
+}
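The inverse direction, again as a hedged sketch with the same assumed module path and hypothetical proto names; `sizes` gives the per-field repeat count for each batch entry:

```python
import tensorflow as tf
from tensorflow.contrib.proto.python.ops import encode_proto_op  # assumed path

ints = tf.constant([[3], [5]], dtype=tf.int32)    # batch_shape = [2]
sizes = tf.constant([[1], [1]], dtype=tf.int32)   # one value of the one field
serialized = encode_proto_op.encode_proto(
    sizes=sizes,
    values=[ints],
    field_names=['my_int_field'],                 # hypothetical field name
    message_type='mypackage.MyMessage')           # hypothetical message name
# `serialized` is a string tensor of shape [2] holding the encoded messages.
```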
diff --git a/tensorflow/core/api_def/base_api/api_def_InplaceAdd.pbtxt b/tensorflow/core/api_def/base_api/api_def_InplaceAdd.pbtxt
new file mode 100644
index 0000000000..3654286cc3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_InplaceAdd.pbtxt
@@ -0,0 +1,28 @@
+op {
+ graph_op_name: "InplaceAdd"
+ in_arg {
+ name: "x"
+ description: "A `Tensor` of type T."
+ }
+ in_arg {
+ name: "i"
+ description: "A vector. Indices into the left-most dimension of `x`."
+ }
+ in_arg {
+ name: "v"
+ description:
+ "A `Tensor` of type T. Same dimension sizes as x except "
+ "the first dimension, which must be the same as i's size."
+ }
+ out_arg {
+ name: "y"
+ description:
+ "A `Tensor` of type T. An alias of `x`. The content "
+ "of `y` is undefined if there are duplicates in `i`."
+ }
+ summary: <<END
+  Adds `v` to specified rows of `x`.
+
+ Computes y = x; y[i, :] += v; return y.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_InplaceSub.pbtxt b/tensorflow/core/api_def/base_api/api_def_InplaceSub.pbtxt
new file mode 100644
index 0000000000..a9480b4a38
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_InplaceSub.pbtxt
@@ -0,0 +1,28 @@
+op {
+ graph_op_name: "InplaceSub"
+ in_arg {
+ name: "x"
+ description: "A `Tensor` of type T."
+ }
+ in_arg {
+ name: "i"
+ description: "A vector. Indices into the left-most dimension of `x`."
+ }
+ in_arg {
+ name: "v"
+ description:
+ "A `Tensor` of type T. Same dimension sizes as x except "
+ "the first dimension, which must be the same as i's size."
+ }
+ out_arg {
+ name: "y"
+ description:
+ "A `Tensor` of type T. An alias of `x`. The content "
+ "of `y` is undefined if there are duplicates in `i`."
+ }
+ summary: <<END
+  Subtracts `v` from specified rows of `x`.
+
+ Computes y = x; y[i, :] -= v; return y.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_InplaceUpdate.pbtxt b/tensorflow/core/api_def/base_api/api_def_InplaceUpdate.pbtxt
new file mode 100644
index 0000000000..2fcd3659dc
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_InplaceUpdate.pbtxt
@@ -0,0 +1,28 @@
+op {
+ graph_op_name: "InplaceUpdate"
+ in_arg {
+ name: "x"
+ description: "A tensor of type `T`."
+ }
+ in_arg {
+ name: "i"
+ description: "A vector. Indices into the left-most dimension of `x`."
+ }
+ in_arg {
+ name: "v"
+ description:
+ "A `Tensor` of type T. Same dimension sizes as x except "
+ "the first dimension, which must be the same as i's size."
+ }
+ out_arg {
+ name: "y"
+ description:
+ "A `Tensor` of type T. An alias of `x`. The content "
+ "of `y` is undefined if there are duplicates in `i`."
+ }
+ summary: <<END
+ Updates specified rows with values in `v`.
+
+ Computes `x[i, :] = v; return x`.
+END
+}
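The three hidden ops above share one indexing scheme. The NumPy sketch below spells out the documented formulas only; unlike the ops, it copies `x` instead of aliasing it:

```python
import numpy as np

x = np.arange(12, dtype=np.float32).reshape(4, 3)
i = np.array([1, 3])                       # rows to touch
v = np.ones((2, 3), dtype=np.float32)      # same trailing shape as x

y_update = x.copy(); y_update[i, :] = v    # InplaceUpdate: x[i, :] = v
y_add    = x.copy(); y_add[i, :] += v      # InplaceAdd:    y[i, :] += v
y_sub    = x.copy(); y_sub[i, :] -= v      # InplaceSub:    y[i, :] -= v
```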
diff --git a/tensorflow/core/api_def/base_api/api_def_Rpc.pbtxt b/tensorflow/core/api_def/base_api/api_def_Rpc.pbtxt
new file mode 100644
index 0000000000..344ef191fd
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_Rpc.pbtxt
@@ -0,0 +1,108 @@
+op {
+ graph_op_name: "Rpc"
+ in_arg {
+ name: "address"
+ description: <<END
+`0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server.
+If this tensor has more than 1 element, then multiple parallel rpc requests
+are sent. This argument broadcasts with `method` and `request`.
+END
+ }
+ in_arg {
+ name: "method"
+ description: <<END
+`0-D` or `1-D`. The method address on the RPC server.
+If this tensor has more than 1 element, then multiple parallel rpc requests
+are sent. This argument broadcasts with `address` and `request`.
+END
+ }
+ in_arg {
+ name: "request"
+ description: <<END
+`0-D` or `1-D`. Serialized proto strings: the rpc request argument.
+If this tensor has more than 1 element, then multiple parallel rpc requests
+are sent. This argument broadcasts with `address` and `method`.
+END
+ }
+ out_arg {
+ name: "response"
+ description: <<END
+Same shape as `request`. Serialized proto strings: the rpc responses.
+END
+ }
+ attr {
+ name: "protocol"
+ description: <<END
+RPC protocol to use. Empty string means use the default protocol.
+Options include 'grpc'.
+END
+ }
+ attr {
+ name: "fail_fast"
+ description: <<END
+`boolean`. If `true` (default), then failures to connect
+(i.e., the server does not immediately respond) cause an RPC failure.
+END
+ }
+ attr {
+ name: "timeout_in_ms"
+ description: <<END
+`int`. If `0` (default), then the kernel will run the RPC
+request and only time out if the RPC deadline passes or the session times out.
+If this value is greater than `0`, then the op will raise an exception if
+the RPC takes longer than `timeout_in_ms`.
+END
+ }
+ summary: <<END
+Perform batches of RPC requests.
+END
+ description: <<END
+This op asynchronously performs either a single RPC request, or a batch
+of requests. RPC requests are defined by three main parameters:
+
+ - `address` (the host+port or BNS address of the request)
+ - `method` (the RPC method name for the request)
+ - `request` (the serialized proto string, or vector of strings,
+ of the RPC request argument).
+
+For example, if you have an RPC service running on port localhost:2345,
+and its interface is configured with the following proto declaration:
+
+```
+service MyService {
+ rpc MyMethod(MyRequestProto) returns (MyResponseProto) {
+ }
+};
+```
+
+then call this op with arguments:
+
+```
+address = "localhost:2345"
+method = "MyService/MyMethod"
+```
+
+The `request` tensor is a string tensor representing serialized `MyRequestProto`
+strings; and the output string tensor `response` will have the same shape
+and contain (upon successful completion) corresponding serialized
+`MyResponseProto` strings.
+
+For example, to send a single, empty, `MyRequestProto`, call
+this op with `request = ""`. To send 5 **parallel** empty requests,
+call this op with `request = ["", "", "", "", ""]`.
+
+More generally, one can create a batch of `MyRequestProto` serialized protos
+from regular batched tensors using the `encode_proto` op, and convert
+the response `MyResponseProto` serialized protos to batched tensors
+using the `decode_proto` op.
+
+**NOTE** Working with serialized proto strings is faster than instantiating
+actual proto objects in memory, so no performance degradation is expected
+compared to writing custom kernels for this workflow.
+
+If the connection fails or the remote worker returns an error
+status, the op reraises this exception locally.
+
+See the `TryRpc` op if you prefer to handle RPC failures manually in the graph.
+END
+}
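A hedged Python sketch of the batched call described above, using the `tf.contrib.rpc` package added in this change. The address, method and request contents are placeholders, and the keyword names are assumed to match the attrs listed in the op definition:

```python
import tensorflow as tf
from tensorflow.contrib import rpc as contrib_rpc

requests = tf.constant(['', '', ''])   # three empty serialized MyRequestProtos
responses = contrib_rpc.rpc(
    address='localhost:2345',          # placeholder server address
    method='MyService/MyMethod',       # placeholder method name
    request=requests,
    protocol='grpc',
    timeout_in_ms=5000)
# `responses` has the same shape as `request`; any failed call raises.
```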
diff --git a/tensorflow/core/api_def/base_api/api_def_TryRpc.pbtxt b/tensorflow/core/api_def/base_api/api_def_TryRpc.pbtxt
new file mode 100644
index 0000000000..bded00e83c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TryRpc.pbtxt
@@ -0,0 +1,123 @@
+op {
+ graph_op_name: "TryRpc"
+ in_arg {
+ name: "address"
+ description: <<END
+`0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server.
+If this tensor has more than 1 element, then multiple parallel rpc requests
+are sent. This argument broadcasts with `method` and `request`.
+END
+ }
+ in_arg {
+ name: "method"
+ description: <<END
+`0-D` or `1-D`. The method address on the RPC server.
+If this tensor has more than 1 element, then multiple parallel rpc requests
+are sent. This argument broadcasts with `address` and `request`.
+END
+ }
+ in_arg {
+ name: "request"
+ description: <<END
+`0-D` or `1-D`. Serialized proto strings: the rpc request argument.
+If this tensor has more than 1 element, then multiple parallel rpc requests
+are sent. This argument broadcasts with `address` and `method`.
+END
+ }
+ out_arg {
+ name: "response"
+ description: <<END
+Same shape as `request`. Serialized proto strings: the rpc responses.
+END
+ }
+ out_arg {
+ name: "status_code"
+ description: <<END
+Same shape as `request`. Values correspond to tensorflow Status enum codes.
+END
+ }
+ out_arg {
+ name: "status_message"
+ description: <<END
+Same shape as `request`. Values correspond to Status messages
+returned from the RPC calls.
+END
+ }
+ attr {
+ name: "protocol"
+ description: <<END
+RPC protocol to use. Empty string means use the default protocol.
+Options include 'grpc'.
+END
+ }
+ attr {
+ name: "fail_fast"
+ description: <<END
+`boolean`. If `true` (default), then failures to connect
+(i.e., the server does not immediately respond) cause an RPC failure.
+END
+ }
+ attr {
+ name: "timeout_in_ms"
+ description: <<END
+`int`. If `0` (default), then the kernel will run the RPC
+request and only time out if the RPC deadline passes or the session times out.
+If this value is greater than `0`, then the op will raise an exception if
+the RPC takes longer than `timeout_in_ms`.
+END
+ }
+ summary: <<END
+Perform batches of RPC requests.
+END
+ description: <<END
+This op asynchronously performs either a single RPC request, or a batch
+of requests. RPC requests are defined by three main parameters:
+
+ - `address` (the host+port or BNS address of the request)
+ - `method` (the method name for the request)
+ - `request` (the serialized proto string, or vector of strings,
+ of the RPC request argument).
+
+For example, if you have an RPC service running on port localhost:2345,
+and its interface is configured with the following proto declaration:
+
+```
+service MyService {
+ rpc MyMethod(MyRequestProto) returns (MyResponseProto) {
+ }
+};
+```
+
+then call this op with arguments:
+
+```
+address = "localhost:2345"
+method = "MyService/MyMethod"
+```
+
+The `request` tensor is a string tensor representing serialized `MyRequestProto`
+strings; and the output string tensor `response` will have the same shape
+and contain (upon successful completion) corresponding serialized
+`MyResponseProto` strings.
+
+For example, to send a single, empty, `MyRequestProto`, call
+this op with `request = ""`. To send 5 **parallel** empty requests,
+call this op with `request = ["", "", "", "", ""]`.
+
+More generally, one can create a batch of `MyRequestProto` serialized protos
+from regular batched tensors using the `encode_proto` op, and convert
+the response `MyResponseProto` serialized protos to batched tensors
+using the `decode_proto` op.
+
+**NOTE** Working with serialized proto strings is faster than instantiating
+actual proto objects in memory, so no performance degradation is expected
+compared to writing custom kernels for this workflow.
+
+Unlike the standard `Rpc` op, if the connection fails or the remote worker
+returns an error status, this op does **not** reraise the exception.
+Instead, the `status_code` and `status_message` entry for the corresponding RPC
+call is set with the error returned from the RPC call. The `response` tensor
+will contain valid response values for those minibatch entries whose RPCs did
+not fail; the rest of the entries will have empty strings.
+END
+}
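The non-raising counterpart, sketched under the same assumptions as the `Rpc` example above; failures surface per entry through the extra status outputs:

```python
import tensorflow as tf
from tensorflow.contrib import rpc as contrib_rpc

responses, status_codes, status_messages = contrib_rpc.try_rpc(
    address='localhost:2345',          # placeholder server address
    method='MyService/MyMethod',       # placeholder method name
    request=tf.constant(['', '']),
    protocol='grpc')
# Entries whose RPC failed hold an empty string in `responses` and a non-OK
# value in `status_codes`; check both before decoding.
```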
diff --git a/tensorflow/core/api_def/python_api/api_def_DeepCopy.pbtxt b/tensorflow/core/api_def/python_api/api_def_DeepCopy.pbtxt
new file mode 100644
index 0000000000..2d5ed2b432
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_DeepCopy.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "DeepCopy"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Empty.pbtxt b/tensorflow/core/api_def/python_api/api_def_Empty.pbtxt
new file mode 100644
index 0000000000..0b863520e9
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Empty.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "Empty"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_InplaceAdd.pbtxt b/tensorflow/core/api_def/python_api/api_def_InplaceAdd.pbtxt
new file mode 100644
index 0000000000..390e3bbf97
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_InplaceAdd.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "InplaceAdd"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_InplaceSub.pbtxt b/tensorflow/core/api_def/python_api/api_def_InplaceSub.pbtxt
new file mode 100644
index 0000000000..af9634f9b2
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_InplaceSub.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "InplaceSub"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_InplaceUpdate.pbtxt b/tensorflow/core/api_def/python_api/api_def_InplaceUpdate.pbtxt
new file mode 100644
index 0000000000..5fa9d778ea
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_InplaceUpdate.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "InplaceUpdate"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_SlideDataset.pbtxt b/tensorflow/core/api_def/python_api/api_def_SlideDataset.pbtxt
new file mode 100644
index 0000000000..867116c5da
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_SlideDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "SlideDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h
index e34945dd48..b8e773503c 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.h
+++ b/tensorflow/core/common_runtime/bfc_allocator.h
@@ -23,7 +23,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/allocator_retry.h"
-#include "tensorflow/core/framework/visitable_allocator.h"
+#include "tensorflow/core/common_runtime/visitable_allocator.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index f95cecfc66..8ddc9958b2 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -81,6 +81,7 @@ class DirectSessionMinusAXTest : public ::testing::Test {
test::FillValues<float>(&a_tensor, a_values);
Node* a = test::graph::Constant(&graph, a_tensor);
a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
+ a_ = a->name();
Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
test::FillValues<float>(&x_tensor, {1, 1});
@@ -97,12 +98,18 @@ class DirectSessionMinusAXTest : public ::testing::Test {
y_neg_ = y_neg->name();
y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
+ Node* z = test::graph::Unary(&graph, "Identity", y_neg);
+ z_ = z->name();
+ z->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
+
test::graph::ToGraphDef(&graph, &def_);
}
+ string a_;
string x_;
string y_;
string y_neg_;
+ string z_;
GraphDef def_;
};
@@ -133,7 +140,6 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_Callable) {
auto session = CreateSession();
ASSERT_TRUE(session != nullptr);
TF_ASSERT_OK(session->Create(def_));
- std::vector<std::pair<string, Tensor>> inputs;
// Run the test twice to ensure that the Make/Run/Release cycle is hermetic.
for (int i = 0; i < 2; ++i) {
@@ -175,6 +181,159 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_Callable) {
}
}
+TEST_F(DirectSessionMinusAXTest, TestTensorConnection) {
+ Initialize({3, 2, -1, 0});
+ auto session = CreateSession();
+ ASSERT_TRUE(session != nullptr);
+ TF_ASSERT_OK(session->Create(def_));
+
+ {
+ // Directly wire the output of node a to the output of node y, making the
+ // callable graph into "Neg(a);".
+ CallableOptions callable_options;
+ TensorConnection* c = callable_options.add_tensor_connection();
+ c->set_from_tensor(a_ + ":0");
+ c->set_to_tensor(y_ + ":0");
+ callable_options.add_fetch(y_neg_ + ":0");
+
+ Session::CallableHandle handle;
+ TF_ASSERT_OK(session->MakeCallable(callable_options, &handle));
+ std::vector<Tensor> outputs;
+ TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
+ ASSERT_EQ(1, outputs.size());
+ auto mat = outputs[0].matrix<float>();
+ ASSERT_TRUE(outputs[0].IsInitialized());
+ EXPECT_FLOAT_EQ(-3.0, mat(0, 0));
+ EXPECT_FLOAT_EQ(-2.0, mat(0, 1));
+ EXPECT_FLOAT_EQ(1.0, mat(1, 0));
+ EXPECT_FLOAT_EQ(0.0, mat(1, 1));
+ TF_ASSERT_OK(session->ReleaseCallable(handle));
+ }
+
+ {
+ // Directly wire the output of node a to the output of node y, making the
+ // callable graph into "Neg(a);"; also fetch the result of a.
+ CallableOptions callable_options;
+ TensorConnection* c = callable_options.add_tensor_connection();
+ c->set_from_tensor(a_ + ":0");
+ c->set_to_tensor(y_ + ":0");
+ callable_options.add_fetch(a_ + ":0");
+ callable_options.add_fetch(y_neg_ + ":0");
+
+ Session::CallableHandle handle;
+ TF_ASSERT_OK(session->MakeCallable(callable_options, &handle));
+ std::vector<Tensor> outputs;
+ TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
+ ASSERT_EQ(2, outputs.size());
+ auto mat_a = outputs[0].matrix<float>();
+ ASSERT_TRUE(outputs[0].IsInitialized());
+ EXPECT_FLOAT_EQ(3.0, mat_a(0, 0));
+ EXPECT_FLOAT_EQ(2.0, mat_a(0, 1));
+ EXPECT_FLOAT_EQ(-1.0, mat_a(1, 0));
+ EXPECT_FLOAT_EQ(0.0, mat_a(1, 1));
+
+ auto mat_y_neg = outputs[1].matrix<float>();
+ ASSERT_TRUE(outputs[1].IsInitialized());
+ EXPECT_FLOAT_EQ(-3.0, mat_y_neg(0, 0));
+ EXPECT_FLOAT_EQ(-2.0, mat_y_neg(0, 1));
+ EXPECT_FLOAT_EQ(1.0, mat_y_neg(1, 0));
+ EXPECT_FLOAT_EQ(0.0, mat_y_neg(1, 1));
+ TF_ASSERT_OK(session->ReleaseCallable(handle));
+ }
+
+ {
+ // Wire the output of "Neg(Matmul(a, x))" to the output of "a",
+ // creating an invalid cycle.
+ CallableOptions callable_options;
+ TensorConnection* c = callable_options.add_tensor_connection();
+ c->set_from_tensor(y_ + ":0");
+ c->set_to_tensor(a_ + ":0");
+ callable_options.add_fetch(y_ + ":0");
+
+ Session::CallableHandle handle;
+ Status s = session->MakeCallable(callable_options, &handle);
+ EXPECT_TRUE(errors::IsInvalidArgument(s));
+ EXPECT_TRUE(
+ str_util::StrContains(s.error_message(), "would create a cycle"));
+ }
+
+ {
+ // Attempt to wire a non-existent node to a node that does exist.
+ CallableOptions callable_options;
+ TensorConnection* c = callable_options.add_tensor_connection();
+ c->set_from_tensor("unknown_node:0");
+ c->set_to_tensor(y_ + ":0");
+ callable_options.add_fetch(y_ + ":0");
+
+ Session::CallableHandle handle;
+ Status s = session->MakeCallable(callable_options, &handle);
+ EXPECT_TRUE(errors::IsInvalidArgument(s));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(), "unknown node"));
+ }
+
+ {
+ // Attempt to wire a non-existent output from a node that does
+ // exist to another node.
+ CallableOptions callable_options;
+ TensorConnection* c = callable_options.add_tensor_connection();
+ c->set_from_tensor(a_ + ":17");
+ c->set_to_tensor(y_ + ":0");
+ callable_options.add_fetch(y_ + ":0");
+
+ Session::CallableHandle handle;
+ Status s = session->MakeCallable(callable_options, &handle);
+ EXPECT_TRUE(errors::IsInvalidArgument(s));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(), "unknown edge"));
+ }
+
+ {
+ // Attempt to wire a tensor to a node that doesn't exist.
+ CallableOptions callable_options;
+ TensorConnection* c = callable_options.add_tensor_connection();
+ c->set_from_tensor(a_ + ":0");
+ c->set_to_tensor("unknown_node:0");
+ callable_options.add_fetch(y_ + ":0");
+
+ Session::CallableHandle handle;
+ Status s = session->MakeCallable(callable_options, &handle);
+ EXPECT_TRUE(errors::IsNotFound(s));
+ EXPECT_TRUE(
+ str_util::StrContains(s.error_message(), "unable to find feed output"));
+ }
+
+ {
+ // Attempt to wire two tensors to the same tensor.
+ CallableOptions callable_options;
+ TensorConnection* c1 = callable_options.add_tensor_connection();
+ c1->set_from_tensor(a_ + ":0");
+ c1->set_to_tensor(y_neg_ + ":0");
+ TensorConnection* c2 = callable_options.add_tensor_connection();
+ c2->set_from_tensor(x_ + ":0");
+ c2->set_to_tensor(y_neg_ + ":0");
+ callable_options.add_fetch(z_ + ":0");
+
+ Session::CallableHandle handle;
+ Status s = session->MakeCallable(callable_options, &handle);
+ EXPECT_TRUE(errors::IsInvalidArgument(s));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(), "fed more than once"));
+ }
+
+ {
+ // Attempt to wire a tensor to a tensor that is also being fed.
+ CallableOptions callable_options;
+ TensorConnection* c = callable_options.add_tensor_connection();
+ c->set_from_tensor(a_ + ":0");
+ c->set_to_tensor(y_ + ":0");
+ callable_options.add_feed(y_ + ":0");
+ callable_options.add_fetch(y_neg_ + ":0");
+
+ Session::CallableHandle handle;
+ Status s = session->MakeCallable(callable_options, &handle);
+ EXPECT_TRUE(errors::IsInvalidArgument(s));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(), "fed more than once"));
+ }
+}
+
TEST_F(DirectSessionMinusAXTest, TestFeed) {
Initialize({1, 2, 3, 4});
auto session = CreateSession();
@@ -654,6 +813,55 @@ TEST(DirectSessionTest, MultipleFeedTest_Callable) {
EXPECT_TRUE(str_util::StrContains(s.error_message(), "fed more than once"));
}
+TEST(DirectSessionTest, TestTensorConnectionUseTwice) {
+ Graph graph(OpRegistry::Global());
+
+ Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
+ test::FillValues<float>(&a_tensor, {1.0, 2.0, 3.0, 4.0});
+ Node* a = test::graph::Constant(&graph, a_tensor);
+
+ Tensor dummy_tensor(DT_FLOAT, TensorShape({1}));
+ test::FillValues<float>(&dummy_tensor, {-1.0});
+
+ Node* left = test::graph::Constant(&graph, dummy_tensor);
+ Node* right = test::graph::Constant(&graph, dummy_tensor);
+
+ // y = left + right
+ Node* y = test::graph::Add(&graph, left, right);
+
+ GraphDef def;
+ test::graph::ToGraphDef(&graph, &def);
+
+ auto session = CreateSession();
+ ASSERT_TRUE(session != nullptr);
+ TF_ASSERT_OK(session->Create(def));
+
+ CallableOptions callable_options;
+ // Directly wire the output of node a to the outputs of nodes left
+ // and right, making the callable graph into "a + a;".
+ TensorConnection* c_left = callable_options.add_tensor_connection();
+ c_left->set_from_tensor(a->name() + ":0");
+ c_left->set_to_tensor(left->name() + ":0");
+ TensorConnection* c_right = callable_options.add_tensor_connection();
+ c_right->set_from_tensor(a->name() + ":0");
+ c_right->set_to_tensor(right->name() + ":0");
+
+ callable_options.add_fetch(y->name() + ":0");
+
+ Session::CallableHandle handle;
+ TF_ASSERT_OK(session->MakeCallable(callable_options, &handle));
+ std::vector<Tensor> outputs;
+ TF_ASSERT_OK(session->RunCallable(handle, {}, &outputs, nullptr));
+ ASSERT_EQ(1, outputs.size());
+ auto mat = outputs[0].matrix<float>();
+ ASSERT_TRUE(outputs[0].IsInitialized());
+ EXPECT_FLOAT_EQ(2.0, mat(0, 0));
+ EXPECT_FLOAT_EQ(4.0, mat(0, 1));
+ EXPECT_FLOAT_EQ(6.0, mat(1, 0));
+ EXPECT_FLOAT_EQ(8.0, mat(1, 1));
+ TF_ASSERT_OK(session->ReleaseCallable(handle));
+}
+
TEST(DirectSessionTest, FetchMultipleTimes) {
Graph g(OpRegistry::Global());
Tensor seven_tensor(DT_INT32, TensorShape());
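For context on the API these tests exercise: a tensor connection rewires a consumer endpoint ("to_tensor") to read directly from a producer endpoint ("from_tensor") inside a callable. A minimal client-side sketch, assuming a session that already holds a graph with float nodes named "a", "y", and "y_neg" (placeholder names mirroring the fixture above), and using only the Session calls shown in the tests:

  // Sketch only: rewire "y:0" to be fed from "a:0", then fetch a consumer of y.
  Status RunWithTensorConnection(Session* session, std::vector<Tensor>* outputs) {
    CallableOptions opts;
    TensorConnection* conn = opts.add_tensor_connection();
    conn->set_from_tensor("a:0");   // producer endpoint
    conn->set_to_tensor("y:0");     // consumer endpoint whose value is replaced
    opts.add_fetch("y_neg:0");      // downstream of y, so it now sees a's value

    Session::CallableHandle handle;
    TF_RETURN_IF_ERROR(session->MakeCallable(opts, &handle));
    TF_RETURN_IF_ERROR(session->RunCallable(handle, {}, outputs, nullptr));
    return session->ReleaseCallable(handle);
  }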
diff --git a/tensorflow/core/common_runtime/eigen_thread_pool.h b/tensorflow/core/common_runtime/eigen_thread_pool.h
index c6f13c6a11..ddd627fb20 100644
--- a/tensorflow/core/common_runtime/eigen_thread_pool.h
+++ b/tensorflow/core/common_runtime/eigen_thread_pool.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMMON_RUNTIME_EIGEN_THREAD_POOL_H_
#define TENSORFLOW_COMMON_RUNTIME_EIGEN_THREAD_POOL_H_
+#define EIGEN_USE_THREADS
+
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/lib/core/threadpool.h"
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
index 0a586344cc..208697361d 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
@@ -19,7 +19,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
-#include "tensorflow/core/framework/visitable_allocator.h"
+#include "tensorflow/core/common_runtime/visitable_allocator.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
index 63ed0b8be1..b0ca7e3109 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
@@ -85,8 +85,8 @@ GPUDebugAllocator::~GPUDebugAllocator() { delete base_allocator_; }
void* GPUDebugAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
num_bytes += (2 * MASK_BYTES);
-
void* allocated_ptr = base_allocator_->AllocateRaw(alignment, num_bytes);
+ if (allocated_ptr == nullptr) return allocated_ptr;
// Return the pointer after the header
void* rv = static_cast<char*>(allocated_ptr) + MASK_BYTES;
@@ -102,11 +102,13 @@ void* GPUDebugAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
return rv;
}
void GPUDebugAllocator::DeallocateRaw(void* ptr) {
- CHECK(CheckHeader(ptr)) << "before_mask has been overwritten";
- CHECK(CheckFooter(ptr)) << "after_mask has been overwritten";
+ if (ptr != nullptr) {
+ CHECK(CheckHeader(ptr)) << "before_mask has been overwritten";
+ CHECK(CheckFooter(ptr)) << "after_mask has been overwritten";
- // Backtrack to the beginning of the header.
- ptr = static_cast<void*>(static_cast<char*>(ptr) - MASK_BYTES);
+ // Backtrack to the beginning of the header.
+ ptr = static_cast<void*>(static_cast<char*>(ptr) - MASK_BYTES);
+ }
// Deallocate the memory
base_allocator_->DeallocateRaw(ptr);
}
@@ -168,10 +170,12 @@ GPUNanResetAllocator::~GPUNanResetAllocator() { delete base_allocator_; }
void* GPUNanResetAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
void* allocated_ptr = base_allocator_->AllocateRaw(alignment, num_bytes);
+ if (allocated_ptr == nullptr) return allocated_ptr;
// Initialize the buffer to Nans
size_t req_size = base_allocator_->RequestedSize(allocated_ptr);
- std::vector<float> nans(req_size / sizeof(float), std::nanf(""));
+ std::vector<float> nans((req_size + sizeof(float) - 1) / sizeof(float),
+ std::nanf(""));
gpu::DeviceMemory<float> nan_ptr{
gpu::DeviceMemoryBase{static_cast<float*>(allocated_ptr), req_size}};
@@ -182,13 +186,16 @@ void* GPUNanResetAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
return allocated_ptr;
}
void GPUNanResetAllocator::DeallocateRaw(void* ptr) {
- // Reset the buffer to Nans
- size_t req_size = base_allocator_->RequestedSize(ptr);
- std::vector<float> nans(req_size / sizeof(float), std::nanf(""));
- gpu::DeviceMemory<float> nan_ptr{
- gpu::DeviceMemoryBase{static_cast<float*>(ptr), req_size}};
- if (!stream_exec_->SynchronousMemcpy(&nan_ptr, &nans[0], req_size)) {
- LOG(ERROR) << "Could not initialize to NaNs";
+ if (ptr != nullptr) {
+ // Reset the buffer to Nans
+ size_t req_size = base_allocator_->RequestedSize(ptr);
+ std::vector<float> nans((req_size + sizeof(float) - 1) / sizeof(float),
+ std::nanf(""));
+ gpu::DeviceMemory<float> nan_ptr{
+ gpu::DeviceMemoryBase{static_cast<float*>(ptr), req_size}};
+ if (!stream_exec_->SynchronousMemcpy(&nan_ptr, &nans[0], req_size)) {
+ LOG(ERROR) << "Could not initialize to NaNs";
+ }
}
// Deallocate the memory
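The two hunks above apply the same guard: AllocateRaw forwards a null base allocation instead of offsetting past it, and DeallocateRaw skips the header/NaN bookkeeping when handed a null pointer. A stripped-down sketch of that pattern on a generic wrapping allocator (illustrative only; GuardedAllocator and kMaskBytes are made-up names, not the TensorFlow classes):

  class GuardedAllocator {
   public:
    explicit GuardedAllocator(Allocator* base) : base_(base) {}

    void* AllocateRaw(size_t alignment, size_t num_bytes) {
      void* p = base_->AllocateRaw(alignment, num_bytes + 2 * kMaskBytes);
      if (p == nullptr) return p;  // OOM: pass the failure through untouched
      return static_cast<char*>(p) + kMaskBytes;  // skip over the header mask
    }

    void DeallocateRaw(void* ptr) {
      if (ptr != nullptr) {
        // Only step back to the real allocation when there is one.
        ptr = static_cast<char*>(ptr) - kMaskBytes;
      }
      base_->DeallocateRaw(ptr);
    }

   private:
    static constexpr size_t kMaskBytes = 64;
    Allocator* base_;  // not owned
  };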
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
index 0db08dc975..adce3a8436 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
@@ -21,7 +21,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
-#include "tensorflow/core/framework/visitable_allocator.h"
+#include "tensorflow/core/common_runtime/visitable_allocator.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator.h b/tensorflow/core/common_runtime/gpu/pool_allocator.h
index 38d669ea07..91ce830df8 100644
--- a/tensorflow/core/common_runtime/gpu/pool_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/pool_allocator.h
@@ -24,7 +24,7 @@ limitations under the License.
#include <map>
#include <memory>
#include <vector>
-#include "tensorflow/core/framework/visitable_allocator.h"
+#include "tensorflow/core/common_runtime/visitable_allocator.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc
index 2f17af273f..6a3e6906a3 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/graph_execution_state.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <string>
#include <unordered_set>
+#include <utility>
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
@@ -27,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/subgraph.h"
@@ -237,6 +239,50 @@ void GraphExecutionState::RestoreStatefulNodes(Graph* graph) {
}
}
+namespace {
+
+class TensorConnectionPruneRewrite : public subgraph::PruneRewrite {
+ public:
+ TensorConnectionPruneRewrite(const string* endpoint_name,
+ NodeBuilder::NodeOut from_tensor)
+ : subgraph::PruneRewrite(endpoint_name, nullptr /* device_info */),
+ from_tensor_(std::move(from_tensor)) {}
+
+ Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor,
+ Node** out_node) override {
+ Status s;
+ auto check_no_cycle_fn = [this, feed_tensor, &s](Node* n) {
+ if (n == feed_tensor.node) {
+ s.Update(errors::InvalidArgument(
+ "Requested Tensor connection between nodes \"",
+ feed_tensor.node->name(), "\" and \"", from_tensor_.node->name(),
+ "\" would create a cycle."));
+ }
+ };
+ ReverseDFSFrom(*g, {from_tensor_.node}, std::move(check_no_cycle_fn),
+ nullptr);
+ TF_RETURN_IF_ERROR(s);
+
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(strings::StrCat("_identity_", feed_tensor.node->name(), "_",
+ feed_tensor.index),
+ "Identity")
+ .Input(from_tensor_)
+ .Attr("T",
+ BaseType(from_tensor_.node->output_type(from_tensor_.index)))
+ .Finalize(g, out_node));
+
+ (*out_node)->set_assigned_device_name(
+ feed_tensor.node->assigned_device_name());
+ return Status::OK();
+ }
+
+ private:
+ NodeBuilder::NodeOut from_tensor_;
+};
+
+} // namespace
+
Status GraphExecutionState::PruneGraph(
const BuildGraphOptions& options, Graph* graph,
subgraph::RewriteGraphMetadata* out_rewrite_metadata) {
@@ -265,12 +311,48 @@ Status GraphExecutionState::PruneGraph(
new subgraph::SendFetchRewrite(&fetch, device_info));
}
}
+
+ for (const TensorConnection& tensor_connection :
+ options.callable_options.tensor_connection()) {
+ Node* from_node = nullptr;
+ TensorId from_id(ParseTensorName(tensor_connection.from_tensor()));
+
+ for (Node* n : graph->nodes()) {
+ if (n->name() == from_id.first) {
+ from_node = n;
+ break;
+ }
+ }
+ if (from_node == nullptr) {
+ return errors::InvalidArgument(
+ "Requested tensor connection from unknown node: \"",
+ tensor_connection.from_tensor(), "\".");
+ }
+ if (from_id.second >= from_node->num_outputs()) {
+ return errors::InvalidArgument(
+ "Requested tensor connection from unknown edge: \"",
+ tensor_connection.from_tensor(),
+ "\" (actual number of outputs = ", from_node->num_outputs(), ").");
+ }
+
+ feed_rewrites.emplace_back(new TensorConnectionPruneRewrite(
+ &tensor_connection.to_tensor(), {from_node, from_id.second}));
+ }
+
std::vector<string> target_node_names(
options.callable_options.target().begin(),
options.callable_options.target().end());
- return subgraph::RewriteGraphForExecution(graph, feed_rewrites,
- fetch_rewrites, target_node_names,
- out_rewrite_metadata);
+ TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
+ graph, feed_rewrites, fetch_rewrites, target_node_names,
+ out_rewrite_metadata));
+
+ CHECK_EQ(out_rewrite_metadata->feed_types.size(),
+ options.callable_options.feed_size() +
+ options.callable_options.tensor_connection_size());
+ for (int i = 0; i < options.callable_options.tensor_connection_size(); ++i) {
+ out_rewrite_metadata->feed_types.pop_back();
+ }
+ return Status::OK();
}
Status GraphExecutionState::InitBaseGraph(const BuildGraphOptions& options) {
@@ -340,7 +422,13 @@ Status GraphExecutionState::OptimizeGraph(
options.callable_options.target().begin(),
options.callable_options.target().end());
- if (!options.callable_options.feed().empty()) {
+ for (const TensorConnection& tensor_connection :
+ options.callable_options.tensor_connection()) {
+ item.fetch.push_back(tensor_connection.from_tensor());
+ }
+
+ if (!(options.callable_options.feed().empty() &&
+ options.callable_options.tensor_connection().empty())) {
std::unordered_set<string> feeds;
for (const string& feed : options.callable_options.feed()) {
TensorId id = ParseTensorName(feed);
@@ -349,6 +437,15 @@ Status GraphExecutionState::OptimizeGraph(
}
feeds.insert(id.first.ToString());
}
+ for (const TensorConnection& tensor_connection :
+ options.callable_options.tensor_connection()) {
+ TensorId id = ParseTensorName(tensor_connection.to_tensor());
+ if (id.second != 0) {
+ return errors::InvalidArgument("Unsupported feed: ",
+ tensor_connection.to_tensor());
+ }
+ feeds.insert(id.first.ToString());
+ }
for (const NodeDef& node : original_graph_def_.node()) {
if (feeds.find(node.name()) == feeds.end()) {
continue;
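The checks above rely on ParseTensorName to split an endpoint string into a node name and an output index, which are then validated against the graph. A short sketch of that decomposition ("matmul" is a hypothetical node name):

  TensorId id = ParseTensorName("matmul:1");
  // id.first  == "matmul"  -> must match the name of a node in the graph
  // id.second == 1         -> must be < that node's num_outputs()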
diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc
index 1125d2a34a..790f2eaa1e 100644
--- a/tensorflow/core/common_runtime/graph_runner.cc
+++ b/tensorflow/core/common_runtime/graph_runner.cc
@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// TODO(skyewm): this is necessary to make the single_threaded_cpu_device.h
+// include work. Some other include must be including eigen without defining
+// this. Consider defining this in a BUILD rule.
+#define EIGEN_USE_THREADS
+
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/device_factory.h"
@@ -20,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/common_runtime/single_threaded_cpu_device.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_util.h"
@@ -36,18 +42,6 @@ namespace tensorflow {
namespace {
-std::unique_ptr<Device> GetCPUDevice(Env* env) {
- std::vector<Device*> devices;
- SessionOptions session_options;
- session_options.env = env;
- Status s = DeviceFactory::GetFactory(DEVICE_CPU)
- ->CreateDevices(session_options, "", &devices);
- if (s.ok() && !devices.empty()) {
- return std::unique_ptr<Device>(devices[0]);
- }
- return nullptr;
-}
-
// A simple rendezvous class.
// Assumes a single sender and a single receiver, no duplicate sends, and no
// sends of dead tensors.
@@ -98,7 +92,8 @@ class SimpleRendezvous : public Rendezvous {
} // namespace
GraphRunner::GraphRunner(Env* env)
- : device_deleter_(GetCPUDevice(env)), device_(device_deleter_.get()) {}
+ : device_deleter_(new SingleThreadedCpuDevice(env)),
+ device_(device_deleter_.get()) {}
GraphRunner::GraphRunner(Device* device) : device_(device) {}
GraphRunner::~GraphRunner() {}
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
index 55c8411ad0..b2ef51d10b 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
@@ -24,7 +24,7 @@ limitations under the License.
#include <cstdlib>
#include <string>
#include "tensorflow/core/common_runtime/bfc_allocator.h"
-#include "tensorflow/core/framework/visitable_allocator.h"
+#include "tensorflow/core/common_runtime/visitable_allocator.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mem.h"
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index 1b7e3138ee..06dbe04986 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -431,6 +431,32 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
InferenceContext* src_context = GetContext(input_edge->src());
if (src_context == nullptr) return errors::Internal("Missing src context");
ShapeHandle src_shape = src_context->output(input_edge->src_output());
+
+ if (src_context->Value(src_context->Rank(src_shape)) == 0) {
+ Tensor t;
+ bool evaluated = false;
+ TF_RETURN_IF_ERROR(
+ EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t));
+ if (!evaluated) {
+ return errors::InvalidArgument(
+ "Received a shape scalar with unknown static value. A static value "
+ "of '-1' is required to represent an unknown shape.");
+ }
+ if (t.dims() == 0) {
+ if (t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) {
+ *result = target_context->UnknownShape();
+ return Status::OK();
+ } else if (t.dtype() == DT_INT64 && t.scalar<int64>()() == -1) {
+ *result = target_context->UnknownShape();
+ return Status::OK();
+ }
+ }
+ return errors::InvalidArgument(
+ "Received an invalid shape scalar with a static value that is not "
+ "'-1': ",
+ t.DebugString());
+ }
+
TF_RETURN_IF_ERROR(src_context->WithRank(src_shape, 1, &src_shape));
const string& src_op = input_edge->src()->type_string();
diff --git a/tensorflow/core/common_runtime/single_threaded_cpu_device.h b/tensorflow/core/common_runtime/single_threaded_cpu_device.h
new file mode 100644
index 0000000000..04d5af9087
--- /dev/null
+++ b/tensorflow/core/common_runtime/single_threaded_cpu_device.h
@@ -0,0 +1,82 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SINGLE_THREADED_CPU_DEVICE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_SINGLE_THREADED_CPU_DEVICE_H_
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+
+namespace tensorflow {
+
+class Env;
+
+// A simple single-threaded CPU device. This can be used to run inexpensive
+// computations. In particular, using this avoids initializing the global thread
+// pools in LocalDevice.
+class SingleThreadedCpuDevice : public Device {
+ public:
+ SingleThreadedCpuDevice(Env* env)
+ : Device(env, Device::BuildDeviceAttributes("/device:CPU:0", DEVICE_CPU,
+ Bytes(256 << 20),
+ DeviceLocality())) {
+ eigen_worker_threads_.num_threads = 1;
+ eigen_worker_threads_.workers = new thread::ThreadPool(
+ env, "graph_runner", eigen_worker_threads_.num_threads);
+ eigen_threadpool_wrapper_.reset(
+ new EigenThreadPoolWrapper(eigen_worker_threads_.workers));
+ eigen_device_.reset(new Eigen::ThreadPoolDevice(
+ eigen_threadpool_wrapper_.get(), eigen_worker_threads_.num_threads));
+ set_tensorflow_cpu_worker_threads(&eigen_worker_threads_);
+ set_eigen_cpu_device(eigen_device_.get());
+ }
+
+ ~SingleThreadedCpuDevice() override {
+ eigen_threadpool_wrapper_.reset();
+ eigen_device_.reset();
+ delete eigen_worker_threads_.workers;
+ }
+
+ Status Sync() override { return Status::OK(); }
+
+ Status MakeTensorFromProto(const TensorProto& tensor_proto,
+ const AllocatorAttributes alloc_attrs,
+ Tensor* tensor) override {
+ Tensor parsed(tensor_proto.dtype());
+ if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
+ return errors::InvalidArgument("Cannot parse tensor from tensor_proto.");
+ }
+ *tensor = parsed;
+ return Status::OK();
+ }
+
+ Allocator* GetAllocator(AllocatorAttributes attr) override {
+ return cpu_allocator();
+ }
+
+ private:
+ DeviceBase::CpuWorkerThreads eigen_worker_threads_;
+ std::unique_ptr<Eigen::ThreadPoolInterface> eigen_threadpool_wrapper_;
+ std::unique_ptr<Eigen::ThreadPoolDevice> eigen_device_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SINGLE_THREADED_CPU_DEVICE_H_
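A brief usage sketch for the new device (illustrative only; it relies solely on the constructor and methods declared above):

  SingleThreadedCpuDevice device(Env::Default());
  Allocator* allocator = device.GetAllocator(AllocatorAttributes());  // cpu_allocator()
  Tensor scratch(allocator, DT_FLOAT, TensorShape({2, 2}));
  TF_CHECK_OK(device.Sync());  // Sync() is a no-op for this device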
diff --git a/tensorflow/core/framework/visitable_allocator.h b/tensorflow/core/common_runtime/visitable_allocator.h
index ed41b05531..8edf922d11 100644
--- a/tensorflow/core/framework/visitable_allocator.h
+++ b/tensorflow/core/common_runtime/visitable_allocator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_FRAMEWORK_VISITABLE_ALLOCATOR_H_
-#define TENSORFLOW_CORE_FRAMEWORK_VISITABLE_ALLOCATOR_H_
+#ifndef TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
+#define TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
#include <functional>
#include "tensorflow/core/framework/allocator.h"
@@ -76,4 +76,4 @@ class TrackingVisitableAllocator : public TrackingAllocator,
VisitableAllocator* allocator_;
};
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_FRAMEWORK_VISITABLE_ALLOCATOR_H_
+#endif // TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
diff --git a/tensorflow/core/distributed_runtime/local_master.cc b/tensorflow/core/distributed_runtime/local_master.cc
index aaa4cfa734..76315462a7 100644
--- a/tensorflow/core/distributed_runtime/local_master.cc
+++ b/tensorflow/core/distributed_runtime/local_master.cc
@@ -157,6 +157,47 @@ Status LocalMaster::Reset(CallOptions* call_options,
return ret;
}
+Status LocalMaster::MakeCallable(CallOptions* call_options,
+ const MakeCallableRequest* request,
+ MakeCallableResponse* response) {
+ Notification n;
+ Status ret;
+ master_impl_->MakeCallable(request, response, [&n, &ret](const Status& s) {
+ ret.Update(s);
+ n.Notify();
+ });
+ TF_RETURN_IF_ERROR(
+ WaitForNotification(call_options, default_timeout_in_ms_, &n));
+ return ret;
+}
+Status LocalMaster::RunCallable(CallOptions* call_options,
+ const RunCallableRequest* request,
+ RunCallableResponse* response) {
+ Notification n;
+ Status ret;
+ master_impl_->RunCallable(call_options, request, response,
+ [&n, &ret](const Status& s) {
+ ret.Update(s);
+ n.Notify();
+ });
+ TF_RETURN_IF_ERROR(
+ WaitForNotification(call_options, default_timeout_in_ms_, &n));
+ return ret;
+}
+Status LocalMaster::ReleaseCallable(CallOptions* call_options,
+ const ReleaseCallableRequest* request,
+ ReleaseCallableResponse* response) {
+ Notification n;
+ Status ret;
+ master_impl_->ReleaseCallable(request, response, [&n, &ret](const Status& s) {
+ ret.Update(s);
+ n.Notify();
+ });
+ TF_RETURN_IF_ERROR(
+ WaitForNotification(call_options, default_timeout_in_ms_, &n));
+ return ret;
+}
+
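All three methods above use the same blocking-adapter shape: the master's callback-style API is made synchronous by pairing a Notification with a Status accumulator. A generic sketch of that shape, omitting the call-options timeout handling that the real methods get from WaitForNotification (CallSynchronously and DoneCallback are illustrative names):

  using DoneCallback = std::function<void(const Status&)>;

  Status CallSynchronously(const std::function<void(DoneCallback)>& async_call) {
    Notification n;
    Status ret;
    async_call([&n, &ret](const Status& s) {
      ret.Update(s);
      n.Notify();
    });
    n.WaitForNotification();
    return ret;
  }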
namespace {
mutex* get_local_master_registry_lock() {
static mutex local_master_registry_lock(LINKER_INITIALIZED);
diff --git a/tensorflow/core/distributed_runtime/local_master.h b/tensorflow/core/distributed_runtime/local_master.h
index c20b40329a..cad6babad8 100644
--- a/tensorflow/core/distributed_runtime/local_master.h
+++ b/tensorflow/core/distributed_runtime/local_master.h
@@ -71,6 +71,16 @@ class LocalMaster : public MasterInterface {
Status Reset(CallOptions* call_options, const ResetRequest* request,
ResetResponse* response) override;
+ Status MakeCallable(CallOptions* call_options,
+ const MakeCallableRequest* request,
+ MakeCallableResponse* response) override;
+ Status RunCallable(CallOptions* call_options,
+ const RunCallableRequest* request,
+ RunCallableResponse* response) override;
+ Status ReleaseCallable(CallOptions* call_options,
+ const ReleaseCallableRequest* request,
+ ReleaseCallableResponse* response) override;
+
// Registers the mapping from the given `target` to the given `master`.
//
// WARNING: The `master` pointer remains owned by the caller. It is
diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc
index 1a488303ac..f47502e844 100644
--- a/tensorflow/core/distributed_runtime/master.cc
+++ b/tensorflow/core/distributed_runtime/master.cc
@@ -611,4 +611,55 @@ void Master::Reset(const ResetRequest* req, ResetResponse* resp,
});
}
+void Master::MakeCallable(const MakeCallableRequest* req,
+ MakeCallableResponse* resp, MyClosure done) {
+ auto session = FindMasterSession(req->session_handle());
+ if (session == nullptr) {
+ done(errors::Aborted("Session ", req->session_handle(), " is not found."));
+ return;
+ }
+
+ SchedClosure(std::bind(
+ [this, session, req, resp](MyClosure done) {
+ Status s = session->MakeCallable(*req, resp);
+ session->Unref();
+ done(s);
+ },
+ std::move(done)));
+}
+
+void Master::RunCallable(CallOptions* opts, const RunCallableRequest* req,
+ RunCallableResponse* resp, MyClosure done) {
+ auto session = FindMasterSession(req->session_handle());
+ if (session == nullptr) {
+ done(errors::Aborted("Session ", req->session_handle(), " is not found."));
+ return;
+ }
+
+ SchedClosure(std::bind(
+ [this, session, opts, req, resp](MyClosure done) {
+ Status s = session->RunCallable(opts, *req, resp);
+ session->Unref();
+ done(s);
+ },
+ std::move(done)));
+}
+
+void Master::ReleaseCallable(const ReleaseCallableRequest* req,
+ ReleaseCallableResponse* resp, MyClosure done) {
+ auto session = FindMasterSession(req->session_handle());
+ if (session == nullptr) {
+ done(errors::Aborted("Session ", req->session_handle(), " is not found."));
+ return;
+ }
+
+ SchedClosure(std::bind(
+ [this, session, req, resp](MyClosure done) {
+ Status s = session->ReleaseCallable(*req, resp);
+ session->Unref();
+ done(s);
+ },
+ std::move(done)));
+}
+
} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/master.h b/tensorflow/core/distributed_runtime/master.h
index 678fc46bd7..dbb337fd48 100644
--- a/tensorflow/core/distributed_runtime/master.h
+++ b/tensorflow/core/distributed_runtime/master.h
@@ -61,6 +61,13 @@ class Master {
// See tensorflow::Reset() and the comment on ResetRequest.
void Reset(const ResetRequest* req, ResetResponse* resp, MyClosure done);
+ void MakeCallable(const MakeCallableRequest* req, MakeCallableResponse* resp,
+ MyClosure done);
+ void RunCallable(CallOptions* opts, const RunCallableRequest* req,
+ RunCallableResponse* resp, MyClosure done);
+ void ReleaseCallable(const ReleaseCallableRequest* req,
+ ReleaseCallableResponse* resp, MyClosure done);
+
private:
typedef Master ME;
diff --git a/tensorflow/core/distributed_runtime/master_interface.h b/tensorflow/core/distributed_runtime/master_interface.h
index bf6a2db3e2..a8ae3cba3c 100644
--- a/tensorflow/core/distributed_runtime/master_interface.h
+++ b/tensorflow/core/distributed_runtime/master_interface.h
@@ -89,6 +89,16 @@ class MasterInterface {
virtual Status Reset(CallOptions* call_options, const ResetRequest* request,
ResetResponse* response) = 0;
+ virtual Status MakeCallable(CallOptions* call_options,
+ const MakeCallableRequest* request,
+ MakeCallableResponse* response) = 0;
+ virtual Status RunCallable(CallOptions* call_options,
+ const RunCallableRequest* request,
+ RunCallableResponse* response) = 0;
+ virtual Status ReleaseCallable(CallOptions* call_options,
+ const ReleaseCallableRequest* request,
+ ReleaseCallableResponse* response) = 0;
+
protected:
// NOTE: This should only be called by implementations of this
// interface whose CreateRunStepResponse() method returns a
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 64adf35c5e..e0a5bb4c53 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -72,7 +72,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
client_graph_(std::move(cg)),
session_opts_(session_opts),
is_partial_(is_partial),
- debug_opts_(bopts.callable_options.run_options().debug_options()),
+ callable_opts_(bopts.callable_options),
worker_cache_(worker_cache),
should_deregister_(should_deregister) {
VLOG(1) << "Created ReffedClientGraph for node with "
@@ -94,12 +94,18 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
const ClientGraph* client_graph() { return client_graph_.get(); }
+ const CallableOptions& callable_options() { return callable_opts_; }
+
std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step,
int64 execution_count,
const RunOptions& ropts) {
return stats_publisher_->GetProfileHandler(step, execution_count, ropts);
}
+ int64 get_and_increment_execution_count() {
+ return execution_count_.fetch_add(1);
+ }
+
// Turn RPC logging on or off, both at the WorkerCache used by this
// master process, and at each remote worker in use for the current
// partitions.
@@ -178,6 +184,10 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
CallOptions* opts, const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp,
CancellationManager* cm, const bool is_last_partial_run);
+ Status RunPartitions(const MasterEnv* env, int64 step_id,
+ int64 execution_count, PerStepState* pss,
+ CallOptions* call_opts, const RunCallableRequest& req,
+ RunCallableResponse* resp, CancellationManager* cm);
// Calls workers to cleanup states for the step "step_id". Calls
// `done` when all cleanup RPCs have completed.
@@ -211,10 +221,11 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
const std::unique_ptr<ClientGraph> client_graph_;
const SessionOptions session_opts_;
const bool is_partial_;
- const DebugOptions& debug_opts_;
+ const CallableOptions callable_opts_;
WorkerCacheInterface* const worker_cache_; // Not owned.
std::unordered_map<StringPiece, Node*, StringPieceHasher> name_to_node_;
const bool should_deregister_;
+ std::atomic<int64> execution_count_ = {0};
// Graph partitioned into per-location subgraphs.
struct Part {
@@ -269,6 +280,17 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
const PartitionOptions& popts,
std::unordered_map<string, GraphDef> graph_partitions);
+ // Prepares a number of calls to workers. One call per partition.
+ // This is a generic method that handles Run, PartialRun, and RunCallable.
+ template <class FetchListType, class ClientRequestType,
+ class ClientResponseType>
+ Status RunPartitionsHelper(
+ const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
+ const FetchListType& fetches, const MasterEnv* env, int64 step_id,
+ int64 execution_count, PerStepState* pss, CallOptions* call_opts,
+ const ClientRequestType& req, ClientResponseType* resp,
+ CancellationManager* cm, bool is_last_partial_run);
+
// Deregisters the partitions on the workers. Called in the
// destructor and does not wait for the rpc completion.
void DeregisterPartitions();
@@ -411,7 +433,8 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
c->req.set_session_handle(session_handle_);
c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
*c->req.mutable_graph_options() = session_opts_.config.graph_options();
- *c->req.mutable_debug_options() = debug_opts_;
+ *c->req.mutable_debug_options() =
+ callable_opts_.run_options().debug_options();
VLOG(2) << "Register " << c->req.graph_def().DebugString();
auto cb = [c, &done](const Status& s) {
c->status = s;
@@ -490,24 +513,46 @@ class RunManyGraphs {
TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
};
-Status MasterSession::ReffedClientGraph::RunPartitions(
- const MasterEnv* env, int64 step_id, int64 execution_count,
- PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req,
- MutableRunStepResponseWrapper* resp, CancellationManager* cm,
- const bool is_last_partial_run) {
- VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
- << execution_count;
- // Maps the names of fed tensors to their index in `req`.
- std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
+namespace {
+Status AddSendFromClientRequest(const RunStepRequestWrapper& client_req,
+ MutableRunGraphRequestWrapper* worker_req,
+ size_t index, const string& send_key) {
+ return worker_req->AddSendFromRunStepRequest(client_req, index, send_key);
+}
- for (size_t i = 0; i < req.num_feeds(); ++i) {
- if (!feeds.insert({req.feed_name(i), i}).second) {
- return errors::InvalidArgument("Duplicated feeds: ", req.feed_name(i));
- }
- }
+Status AddSendFromClientRequest(const RunCallableRequest& client_req,
+ MutableRunGraphRequestWrapper* worker_req,
+ size_t index, const string& send_key) {
+ return worker_req->AddSendFromRunCallableRequest(client_req, index, send_key);
+}
- // Prepares a number of calls to workers. One call per partition.
+// TODO(mrry): Add a full-fledged wrapper that avoids TensorProto copies for
+// in-process messages.
+struct RunCallableResponseWrapper {
+ RunCallableResponse* resp; // Not owned.
+ std::unordered_map<string, TensorProto> fetch_key_to_protos;
+
+ RunMetadata* mutable_metadata() { return resp->mutable_metadata(); }
+ Status AddTensorFromRunGraphResponse(
+ const string& tensor_name, MutableRunGraphResponseWrapper* worker_resp,
+ size_t index) {
+ // TODO(b/74355905): Add a specialized implementation that avoids
+ // copying the tensor into the RunCallableResponse when at least
+ // two of the {client, master, worker} are in the same process.
+ return worker_resp->RecvValue(index, &fetch_key_to_protos[tensor_name]);
+ }
+};
+} // namespace
+
+template <class FetchListType, class ClientRequestType,
+ class ClientResponseType>
+Status MasterSession::ReffedClientGraph::RunPartitionsHelper(
+ const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
+ const FetchListType& fetches, const MasterEnv* env, int64 step_id,
+ int64 execution_count, PerStepState* pss, CallOptions* call_opts,
+ const ClientRequestType& req, ClientResponseType* resp,
+ CancellationManager* cm, bool is_last_partial_run) {
// Collect execution cost stats on a smoothly decreasing frequency.
ExecutorOpts exec_opts;
if (pss->report_tensor_allocations_upon_oom) {
@@ -553,28 +598,19 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
// We keep these as separate paths for now, to ensure we aren't
// inadvertently slowing down the normal run path.
if (is_partial_) {
- for (size_t i = 0; i < req.num_feeds(); ++i) {
- const string& name = req.feed_name(i);
- const auto iter = part.feed_key.find(name);
+ for (const auto& name_index : feeds) {
+ const auto iter = part.feed_key.find(name_index.first.ToString());
if (iter == part.feed_key.end()) {
// The provided feed must be for a different partition.
continue;
}
const string& key = iter->second;
- auto feeds_iter = feeds.find(name);
- if (feeds_iter == feeds.end()) {
- return errors::InvalidArgument("No feed is provided for feed=", name,
- ", key=", key);
- } else if (feeds_iter->second != static_cast<size_t>(i)) {
- return errors::Internal("Cannot find feed named \"", name,
- " in request.");
- }
- TF_RETURN_IF_ERROR(c->req->AddSendFromRunStepRequest(req, i, key));
+ TF_RETURN_IF_ERROR(AddSendFromClientRequest(req, c->req.get(),
+ name_index.second, key));
}
// TODO(suharshs): Make a map from feed to fetch_key to make this faster.
// For now, we just iterate through partitions to find the matching key.
- for (int i = 0; static_cast<size_t>(i) < req.num_fetches(); ++i) {
- const string& req_fetch = req.fetch_name(i);
+ for (const string& req_fetch : fetches) {
for (const auto& key_fetch : part.key_fetch) {
if (key_fetch.second == req_fetch) {
c->req->add_recv_key(key_fetch.first);
@@ -586,9 +622,13 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
for (const auto& feed_key : part.feed_key) {
const string& feed = feed_key.first;
const string& key = feed_key.second;
- const int64 feed_index = feeds[feed];
+ auto iter = feeds.find(feed);
+ if (iter == feeds.end()) {
+ return errors::Internal("No feed index found for feed: ", feed);
+ }
+ const int64 feed_index = iter->second;
TF_RETURN_IF_ERROR(
- c->req->AddSendFromRunStepRequest(req, feed_index, key));
+ AddSendFromClientRequest(req, c->req.get(), feed_index, key));
}
for (const auto& key_fetch : part.key_fetch) {
const string& key = key_fetch.first;
@@ -622,50 +662,115 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
} else {
return errors::Cancelled("Step was cancelled");
}
+ TF_RETURN_IF_ERROR(calls.status());
- // Collects fetches.
- Status status = calls.status();
- if (status.ok()) {
- for (int i = 0; i < num; ++i) {
- const Part& part = partitions_[i];
- MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get();
- for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) {
- auto iter = part.key_fetch.find(run_graph_resp->recv_key(j));
- if (iter == part.key_fetch.end()) {
- status.Update(errors::Internal("Unexpected fetch key: ",
- run_graph_resp->recv_key(j)));
- break;
- }
- const string& fetch = iter->second;
- status.Update(
- resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j));
- if (!status.ok()) {
- break;
- }
+ // Collects fetches and metadata.
+ Status status;
+ for (int i = 0; i < num; ++i) {
+ const Part& part = partitions_[i];
+ MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get();
+ for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) {
+ auto iter = part.key_fetch.find(run_graph_resp->recv_key(j));
+ if (iter == part.key_fetch.end()) {
+ status.Update(errors::Internal("Unexpected fetch key: ",
+ run_graph_resp->recv_key(j)));
+ break;
}
- if (pss->collect_timeline) {
- pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats());
+ const string& fetch = iter->second;
+ status.Update(
+ resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j));
+ if (!status.ok()) {
+ break;
}
- if (pss->collect_costs) {
- CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph();
- for (int j = 0; j < cost_graph->node_size(); ++j) {
- resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
- cost_graph->mutable_node(j));
- }
+ }
+ if (pss->collect_timeline) {
+ pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats());
+ }
+ if (pss->collect_costs) {
+ CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph();
+ for (int j = 0; j < cost_graph->node_size(); ++j) {
+ resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
+ cost_graph->mutable_node(j));
}
- if (pss->collect_partition_graphs) {
- protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
- resp->mutable_metadata()->mutable_partition_graphs();
- for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) {
- partition_graph_defs->Add()->Swap(
- run_graph_resp->mutable_partition_graph(i));
- }
+ }
+ if (pss->collect_partition_graphs) {
+ protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
+ resp->mutable_metadata()->mutable_partition_graphs();
+ for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) {
+ partition_graph_defs->Add()->Swap(
+ run_graph_resp->mutable_partition_graph(i));
}
}
}
return status;
}
+Status MasterSession::ReffedClientGraph::RunPartitions(
+ const MasterEnv* env, int64 step_id, int64 execution_count,
+ PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req,
+ MutableRunStepResponseWrapper* resp, CancellationManager* cm,
+ const bool is_last_partial_run) {
+ VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
+ << execution_count;
+ // Maps the names of fed tensors to their index in `req`.
+ std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
+ for (size_t i = 0; i < req.num_feeds(); ++i) {
+ if (!feeds.insert({req.feed_name(i), i}).second) {
+ return errors::InvalidArgument("Duplicated feeds: ", req.feed_name(i));
+ }
+ }
+
+ std::vector<string> fetches;
+ fetches.reserve(req.num_fetches());
+ for (size_t i = 0; i < req.num_fetches(); ++i) {
+ fetches.push_back(req.fetch_name(i));
+ }
+
+ return RunPartitionsHelper(feeds, fetches, env, step_id, execution_count, pss,
+ call_opts, req, resp, cm, is_last_partial_run);
+}
+
+Status MasterSession::ReffedClientGraph::RunPartitions(
+ const MasterEnv* env, int64 step_id, int64 execution_count,
+ PerStepState* pss, CallOptions* call_opts, const RunCallableRequest& req,
+ RunCallableResponse* resp, CancellationManager* cm) {
+ VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
+ << execution_count;
+ // Maps the names of fed tensors to their index in `req`.
+ std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
+ for (size_t i = 0; i < callable_opts_.feed_size(); ++i) {
+ if (!feeds.insert({callable_opts_.feed(i), i}).second) {
+ // MakeCallable will fail if there are two feeds with the same name.
+ return errors::Internal("Duplicated feeds in callable: ",
+ callable_opts_.feed(i));
+ }
+ }
+
+ // Create a wrapped response object to collect the fetched values and
+ // rearrange them for the RunCallableResponse.
+ RunCallableResponseWrapper wrapped_resp;
+ wrapped_resp.resp = resp;
+
+ TF_RETURN_IF_ERROR(RunPartitionsHelper(
+ feeds, callable_opts_.fetch(), env, step_id, execution_count, pss,
+ call_opts, req, &wrapped_resp, cm, false /* is_last_partial_run */));
+
+ // Collects fetches.
+ // TODO(b/74355905): Add a specialized implementation that avoids
+ // copying the tensor into the RunCallableResponse when at least
+ // two of the {client, master, worker} are in the same process.
+ for (const string& fetch : callable_opts_.fetch()) {
+ TensorProto* fetch_proto = resp->mutable_fetch()->Add();
+ auto iter = wrapped_resp.fetch_key_to_protos.find(fetch);
+ if (iter == wrapped_resp.fetch_key_to_protos.end()) {
+ return errors::Internal("Worker did not return a value for fetch: ",
+ fetch);
+ }
+ fetch_proto->Swap(&iter->second);
+ }
+ return Status::OK();
+}
+
namespace {
class CleanupBroadcastHelper {
@@ -1266,15 +1371,11 @@ WorkerCacheInterface* MasterSession::get_worker_cache() const {
return env_->worker_cache;
}
-Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
- ReffedClientGraph** rcg, bool is_partial) {
+Status MasterSession::StartStep(const BuildGraphOptions& opts, bool is_partial,
+ ReffedClientGraph** out_rcg, int64* out_count) {
const uint64 hash = HashBuildGraphOptions(opts);
{
mutex_lock l(mu_);
- // Keep track of how many times this subgraph has been executed in
- // this session.
- int64* c = &subgraph_execution_counts_[hash];
- *count = (*c)++;
// TODO(suharshs): We cache partial run graphs and run graphs separately
// because there is preprocessing that needs to only be run for partial
// run calls.
@@ -1296,8 +1397,9 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
iter = m->insert({hash, entry}).first;
VLOG(1) << "Preparing to execute new graph";
}
- *rcg = iter->second;
- (*rcg)->Ref();
+ *out_rcg = iter->second;
+ (*out_rcg)->Ref();
+ *out_count = (*out_rcg)->get_and_increment_execution_count();
}
return Status::OK();
}
@@ -1316,6 +1418,12 @@ void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
rcg_map->clear();
}
+namespace {
+uint64 MakeStepId() {
+ return (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
+}
+} // namespace
+
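A small sketch of the step-id layout this helper produces: the low 56 bits come from random::New64() and bits 56..63 are forced to 0x01, so every id lies in [1ull << 56, 1ull << 57):

  const uint64 step_id = MakeStepId();
  CHECK_EQ(1, step_id >> 56);  // the reserved high byte is always 0x01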
Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
PartialRunSetupResponse* resp) {
std::vector<string> inputs, outputs, targets;
@@ -1332,15 +1440,15 @@ Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
string handle = std::to_string(partial_run_handle_counter_.fetch_add(1));
ReffedClientGraph* rcg = nullptr;
- int64 count = 0;
// Prepare.
BuildGraphOptions opts;
BuildBuildGraphOptions(*req, &opts);
- TF_RETURN_IF_ERROR(StartStep(opts, &count, &rcg, true));
+ int64 count;
+ TF_RETURN_IF_ERROR(StartStep(opts, true, &rcg, &count));
// Keeps the highest 8 bits 0x01: we reserve some bits of the
// step_id for future use.
- uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
+ const uint64 step_id = MakeStepId();
TRACEPRINTF("stepid %llu", step_id);
rcg->Ref();
@@ -1585,6 +1693,73 @@ Status MasterSession::CreateDebuggerState(
return Status::OK();
}
+void MasterSession::FillPerStepState(MasterSession::ReffedClientGraph* rcg,
+ const RunOptions& run_options,
+ uint64 step_id, int64 count,
+ PerStepState* out_pss,
+ std::unique_ptr<ProfileHandler>* out_ph) {
+ out_pss->collect_timeline =
+ run_options.trace_level() == RunOptions::FULL_TRACE;
+ out_pss->collect_rpcs = run_options.trace_level() == RunOptions::FULL_TRACE;
+ out_pss->report_tensor_allocations_upon_oom =
+ run_options.report_tensor_allocations_upon_oom();
+ // Build the cost model every 'build_cost_model_every' steps after skipping an
+ // initial 'build_cost_model_after' steps.
+ const int64 build_cost_model_after =
+ session_opts_.config.graph_options().build_cost_model_after();
+ const int64 build_cost_model_every =
+ session_opts_.config.graph_options().build_cost_model();
+ out_pss->collect_costs =
+ build_cost_model_every > 0 &&
+ ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
+ out_pss->collect_partition_graphs = run_options.output_partition_graphs();
+
+ *out_ph = rcg->GetProfileHandler(step_id, count, run_options);
+ if (*out_ph) {
+ out_pss->collect_timeline = true;
+ out_pss->collect_rpcs = (*out_ph)->should_collect_rpcs();
+ }
+}
+
+Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
+ uint64 step_id,
+ const RunOptions& run_options,
+ PerStepState* pss,
+ const std::unique_ptr<ProfileHandler>& ph,
+ const Status& run_status,
+ RunMetadata* out_run_metadata) {
+ Status s = run_status;
+ if (s.ok()) {
+ pss->end_micros = Env::Default()->NowMicros();
+
+ // Schedule post-processing and cleanup to be done asynchronously.
+ rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
+ } else if (errors::IsCancelled(s)) {
+ mutex_lock l(mu_);
+ if (closed_) {
+ if (garbage_collected_) {
+ s = errors::Cancelled(
+ "Step was cancelled because the session was garbage collected due "
+ "to inactivity.");
+ } else {
+ s = errors::Cancelled(
+ "Step was cancelled by an explicit call to `Session::Close()`.");
+ }
+ }
+ }
+ Ref();
+ rcg->Ref();
+ rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
+ if (!s.ok()) {
+ LOG(ERROR) << "Cleanup partition error: " << s;
+ }
+ rcg->Unref();
+ MarkRunCompletion();
+ Unref();
+ });
+ return s;
+}
+
Status MasterSession::DoRunWithLocalExecution(
CallOptions* opts, const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp) {
@@ -1597,8 +1772,8 @@ Status MasterSession::DoRunWithLocalExecution(
BuildGraphOptions bgopts;
BuildBuildGraphOptions(req, &bgopts);
ReffedClientGraph* rcg = nullptr;
- int64 count = 0;
- TF_RETURN_IF_ERROR(StartStep(bgopts, &count, &rcg, false));
+ int64 count;
+ TF_RETURN_IF_ERROR(StartStep(bgopts, false, &rcg, &count));
// Unref "rcg" when out of scope.
core::ScopedUnref unref(rcg);
@@ -1614,64 +1789,133 @@ Status MasterSession::DoRunWithLocalExecution(
// Keeps the highest 8 bits 0x01: we reserve some bits of the
// step_id for future use.
- const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
+ const uint64 step_id = MakeStepId();
TRACEPRINTF("stepid %llu", step_id);
- pss.collect_timeline = req.options().trace_level() == RunOptions::FULL_TRACE;
- pss.collect_rpcs = req.options().trace_level() == RunOptions::FULL_TRACE;
- pss.report_tensor_allocations_upon_oom =
- req.options().report_tensor_allocations_upon_oom();
- // Build the cost model every 'build_cost_model_every' steps after skipping an
- // initial 'build_cost_model_after' steps.
- const int64 build_cost_model_after =
- session_opts_.config.graph_options().build_cost_model_after();
- const int64 build_cost_model_every =
- session_opts_.config.graph_options().build_cost_model();
- pss.collect_costs =
- build_cost_model_every > 0 &&
- ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
- pss.collect_partition_graphs = req.options().output_partition_graphs();
+ std::unique_ptr<ProfileHandler> ph;
+ FillPerStepState(rcg, req.options(), step_id, count, &pss, &ph);
- std::unique_ptr<ProfileHandler> ph =
- rcg->GetProfileHandler(step_id, count, req.options());
- if (ph) {
- pss.collect_timeline = true;
- pss.collect_rpcs = ph->should_collect_rpcs();
+ Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
+ &cancellation_manager_, false);
+ cleanup.release(); // MarkRunCompletion called in PostRunCleanup().
+ return PostRunCleanup(rcg, step_id, req.options(), &pss, ph, s,
+ resp->mutable_metadata());
+}
+
+Status MasterSession::MakeCallable(const MakeCallableRequest& req,
+ MakeCallableResponse* resp) {
+ UpdateLastAccessTime();
+
+ BuildGraphOptions opts;
+ opts.callable_options = req.options();
+ opts.use_function_convention = false;
+
+ ReffedClientGraph* callable;
+
+ {
+ mutex_lock l(mu_);
+ if (closed_) {
+ return errors::FailedPrecondition("Session is closed.");
+ }
+ std::unique_ptr<ClientGraph> client_graph;
+ TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
+ callable = new ReffedClientGraph(handle_, opts, std::move(client_graph),
+ session_opts_, stats_publisher_factory_,
+ false /* is_partial */, get_worker_cache(),
+ !should_delete_worker_sessions_);
+ }
+
+ Status s = BuildAndRegisterPartitions(callable);
+ if (!s.ok()) {
+ callable->Unref();
+ return s;
}
+ uint64 handle;
+ {
+ mutex_lock l(mu_);
+ handle = next_callable_handle_++;
+ callables_[handle] = callable;
+ }
+
+ resp->set_handle(handle);
+ return Status::OK();
+}
+
+Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
+ const RunCallableRequest& req,
+ RunCallableResponse* resp) {
+ VLOG(2) << "DoRunCallable req: " << req.DebugString();
+ PerStepState pss;
+ pss.start_micros = Env::Default()->NowMicros();
+ auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
+
+ // Prepare.
+ int64 count = rcg->get_and_increment_execution_count();
+
+ // Keeps the highest 8 bits 0x01: we reserve some bits of the
+ // step_id for future use.
+ const uint64 step_id = MakeStepId();
+ TRACEPRINTF("stepid %llu", step_id);
+
+ const RunOptions& run_options = rcg->callable_options().run_options();
+
+ if (run_options.timeout_in_ms() != 0) {
+ opts->SetTimeout(run_options.timeout_in_ms());
+ }
+
+ std::unique_ptr<ProfileHandler> ph;
+ FillPerStepState(rcg, run_options, step_id, count, &pss, &ph);
Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
- &cancellation_manager_, false);
- if (s.ok()) {
- pss.end_micros = Env::Default()->NowMicros();
+ &cancellation_manager_);
+ cleanup.release(); // MarkRunCompletion called in PostRunCleanup().
+ return PostRunCleanup(rcg, step_id, run_options, &pss, ph, s,
+ resp->mutable_metadata());
+}
- // Schedule post-processing and cleanup to be done asynchronously.
- rcg->ProcessStats(step_id, &pss, ph.get(), req.options(),
- resp->mutable_metadata());
- } else if (errors::IsCancelled(s)) {
+Status MasterSession::RunCallable(CallOptions* opts,
+ const RunCallableRequest& req,
+ RunCallableResponse* resp) {
+ UpdateLastAccessTime();
+ ReffedClientGraph* callable;
+ {
mutex_lock l(mu_);
if (closed_) {
- if (garbage_collected_) {
- s = errors::Cancelled(
- "Step was cancelled because the session was garbage collected due "
- "to inactivity.");
- } else {
- s = errors::Cancelled(
- "Step was cancelled by an explicit call to `Session::Close()`.");
- }
+ return errors::FailedPrecondition("Session is closed.");
+ }
+ int64 handle = req.handle();
+ if (handle >= next_callable_handle_) {
+ return errors::InvalidArgument("No such callable handle: ", handle);
+ }
+ auto iter = callables_.find(req.handle());
+ if (iter == callables_.end()) {
+ return errors::InvalidArgument(
+ "Attempted to run callable after handle was released: ", handle);
}
+ callable = iter->second;
+ callable->Ref();
+ ++num_running_;
}
- Ref();
- rcg->Ref();
- cleanup.release(); // MarkRunCompletion called in done closure.
- rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
- if (!s.ok()) {
- LOG(ERROR) << "Cleanup partition error: " << s;
+ core::ScopedUnref unref_callable(callable);
+ return DoRunCallable(opts, callable, req, resp);
+}
+
+Status MasterSession::ReleaseCallable(const ReleaseCallableRequest& req,
+ ReleaseCallableResponse* resp) {
+ UpdateLastAccessTime();
+ ReffedClientGraph* to_unref = nullptr;
+ {
+ mutex_lock l(mu_);
+ auto iter = callables_.find(req.handle());
+ if (iter != callables_.end()) {
+ to_unref = iter->second;
+ callables_.erase(iter);
}
- rcg->Unref();
- MarkRunCompletion();
- Unref();
- });
- return s;
+ }
+ if (to_unref != nullptr) {
+ to_unref->Unref();
+ }
+ return Status::OK();
}
Status MasterSession::Close() {
@@ -1688,6 +1932,7 @@ Status MasterSession::Close() {
}
ClearRunsTable(&to_unref, &run_graphs_);
ClearRunsTable(&to_unref, &partial_run_graphs_);
+ ClearRunsTable(&to_unref, &callables_);
}
for (ReffedClientGraph* rcg : to_unref) rcg->Unref();
if (should_delete_worker_sessions_) {
diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h
index 4bd4e1367a..a05419904f 100644
--- a/tensorflow/core/distributed_runtime/master_session.h
+++ b/tensorflow/core/distributed_runtime/master_session.h
@@ -89,6 +89,15 @@ class MasterSession : public core::RefCounted {
Status ListDevices(ListDevicesResponse* resp) const;
+ Status MakeCallable(const MakeCallableRequest& req,
+ MakeCallableResponse* resp);
+
+ Status RunCallable(CallOptions* opts, const RunCallableRequest& req,
+ RunCallableResponse* resp);
+
+ Status ReleaseCallable(const ReleaseCallableRequest& req,
+ ReleaseCallableResponse* resp);
+
// Close this session and delete "*this". Returns OK if all known
// states are cleanup successfully.
//
@@ -140,6 +149,8 @@ class MasterSession : public core::RefCounted {
typedef std::unordered_map<uint64, ReffedClientGraph*> RCGMap;
RCGMap run_graphs_ GUARDED_BY(mu_);
RCGMap partial_run_graphs_ GUARDED_BY(mu_);
+ int64 next_callable_handle_ GUARDED_BY(mu_) = 0;
+ RCGMap callables_ GUARDED_BY(mu_);
struct PerStepState {
bool collect_costs = false;
@@ -205,15 +216,28 @@ class MasterSession : public core::RefCounted {
bool should_delete_worker_sessions_ = false;
Status DeleteWorkerSessions();
- Status StartStep(const BuildGraphOptions& opts, int64* count,
- ReffedClientGraph** graph, bool is_partial);
+ Status StartStep(const BuildGraphOptions& opts, bool is_partial,
+ ReffedClientGraph** out_rcg, int64* out_count);
void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
RCGMap* rcg_map) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ void FillPerStepState(MasterSession::ReffedClientGraph* rcg,
+ const RunOptions& run_options, uint64 step_id,
+ int64 count, PerStepState* out_pss,
+ std::unique_ptr<ProfileHandler>* out_ph);
Status DoRunWithLocalExecution(CallOptions* opts,
const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp);
Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp);
+ Status DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
+ const RunCallableRequest& req,
+ RunCallableResponse* resp);
+ Status PostRunCleanup(MasterSession::ReffedClientGraph* rcg, uint64 step_id,
+ const RunOptions& run_options, PerStepState* pss,
+ const std::unique_ptr<ProfileHandler>& ph,
+ const Status& run_status,
+ RunMetadata* out_run_metadata);
+
void MarkRunCompletion();
void UpdateLastAccessTime();
diff --git a/tensorflow/core/distributed_runtime/message_wrappers.cc b/tensorflow/core/distributed_runtime/message_wrappers.cc
index 66ebb3080a..18668b44d3 100644
--- a/tensorflow/core/distributed_runtime/message_wrappers.cc
+++ b/tensorflow/core/distributed_runtime/message_wrappers.cc
@@ -326,6 +326,20 @@ Status InMemoryRunGraphRequest::AddSendFromRunStepRequest(
return Status::OK();
}
+// TODO(b/74355905): Add a specialized implementation that avoids
+// copying the tensor when at least two of the {client, master,
+// worker} are in the same process.
+Status InMemoryRunGraphRequest::AddSendFromRunCallableRequest(
+ const RunCallableRequest& run_callable_request, size_t i,
+ const string& send_key) {
+ Tensor tensor;
+ if (!ParseTensorProtoToTensor(run_callable_request.feed(i), &tensor)) {
+ return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
+ }
+ sends_.emplace_back(send_key, std::move(tensor));
+ return Status::OK();
+}
+
size_t InMemoryRunGraphRequest::num_recvs() const { return recvs_.size(); }
const string& InMemoryRunGraphRequest::recv_key(size_t i) const {
@@ -439,6 +453,18 @@ Status MutableProtoRunGraphRequest::AddSendFromRunStepRequest(
return Status::OK();
}
+// TODO(b/74355905): Add a specialized implementation that avoids
+// copying the tensor when at least two of the {client, master,
+// worker} are in the same process.
+Status MutableProtoRunGraphRequest::AddSendFromRunCallableRequest(
+ const RunCallableRequest& run_callable_request, size_t i,
+ const string& send_key) {
+ NamedTensorProto* send = request_.add_send();
+ send->set_name(send_key);
+ *send->mutable_tensor() = run_callable_request.feed(i);
+ return Status::OK();
+}
+
size_t MutableProtoRunGraphRequest::num_recvs() const {
return request_.recv_key_size();
}
diff --git a/tensorflow/core/distributed_runtime/message_wrappers.h b/tensorflow/core/distributed_runtime/message_wrappers.h
index 79fa6f926e..1f7cdb98a4 100644
--- a/tensorflow/core/distributed_runtime/message_wrappers.h
+++ b/tensorflow/core/distributed_runtime/message_wrappers.h
@@ -302,6 +302,9 @@ class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
virtual Status AddSendFromRunStepRequest(
const RunStepRequestWrapper& run_step_request, size_t i,
const string& send_key) = 0;
+ virtual Status AddSendFromRunCallableRequest(
+ const RunCallableRequest& run_callable_request, size_t i,
+ const string& send_key) = 0;
virtual void add_recv_key(const string& recv_key) = 0;
virtual void set_is_partial(bool is_partial) = 0;
@@ -334,6 +337,9 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
Status AddSendFromRunStepRequest(
const RunStepRequestWrapper& run_step_request, size_t i,
const string& send_key) override;
+ Status AddSendFromRunCallableRequest(
+ const RunCallableRequest& run_callable_request, size_t i,
+ const string& send_key) override;
void add_recv_key(const string& recv_key) override;
void set_is_partial(bool is_partial) override;
void set_is_last_partial_run(bool is_last_partial_run) override;
@@ -385,6 +391,9 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
Status AddSendFromRunStepRequest(
const RunStepRequestWrapper& run_step_request, size_t i,
const string& send_key) override;
+ Status AddSendFromRunCallableRequest(
+ const RunCallableRequest& run_callable_request, size_t i,
+ const string& send_key) override;
void add_recv_key(const string& recv_key) override;
void set_is_partial(bool is_partial) override;
void set_is_last_partial_run(bool is_last_partial_run) override;
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index 9c655bfa31..d3478dfc38 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -499,3 +499,33 @@ tf_cuda_cc_test(
"//tensorflow/core/kernels:variable_ops",
],
)
+
+cc_library(
+ name = "grpc_rpc_factory",
+ srcs = [
+ "grpc_rpc_factory.cc",
+ ],
+ hdrs = ["grpc_rpc_factory.h"],
+ deps = [
+ ":grpc_state",
+ ":grpc_util",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/util/rpc:call_container",
+ "//tensorflow/core/util/rpc:rpc_factory",
+ ],
+)
+
+cc_library(
+ name = "grpc_rpc_factory_registration",
+ srcs = [
+ "grpc_rpc_factory_registration.cc",
+ ],
+ deps = [
+ ":grpc_rpc_factory",
+ "//tensorflow/core/util/rpc:rpc_factory",
+ "//tensorflow/core/util/rpc:rpc_factory_registry",
+ ],
+ alwayslink = 1,
+)
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
index 63745e8ebd..23968e24c8 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
@@ -111,6 +111,11 @@ class GrpcMasterService : public AsyncServiceInterface {
ENQUEUE_REQUEST(CloseSession, false);
ENQUEUE_REQUEST(ListDevices, false);
ENQUEUE_REQUEST(Reset, false);
+ ENQUEUE_REQUEST(MakeCallable, false);
+ for (int i = 0; i < 100; ++i) {
+ ENQUEUE_REQUEST(RunCallable, true);
+ }
+ ENQUEUE_REQUEST(ReleaseCallable, false);
void* tag;
bool ok;
@@ -236,6 +241,47 @@ class GrpcMasterService : public AsyncServiceInterface {
});
ENQUEUE_REQUEST(Reset, false);
}
+
+ // RPC handler for making a callable.
+ void MakeCallableHandler(
+ MasterCall<MakeCallableRequest, MakeCallableResponse>* call) {
+ master_impl_->MakeCallable(&call->request, &call->response,
+ [call](const Status& status) {
+ call->SendResponse(ToGrpcStatus(status));
+ });
+ ENQUEUE_REQUEST(MakeCallable, false);
+ }
+
+ // RPC handler for running a callable.
+ void RunCallableHandler(
+ MasterCall<RunCallableRequest, RunCallableResponse>* call) {
+ auto* trace = TraceRpc("RunCallable/Server", call->client_metadata());
+ CallOptions* call_opts = new CallOptions;
+ // The timeout may be overridden by a non-zero timeout in the
+ // callable's `RunOptions`; this overriding will happen inside the
+ // `MasterSession` implementation.
+ call_opts->SetTimeout(default_session_config_.operation_timeout_in_ms());
+ call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
+ master_impl_->RunCallable(call_opts, &call->request, &call->response,
+ [call, call_opts, trace](const Status& status) {
+ call->ClearCancelCallback();
+ delete call_opts;
+ delete trace;
+ call->SendResponse(ToGrpcStatus(status));
+ });
+ ENQUEUE_REQUEST(RunCallable, false);
+ }
+
+ // RPC handler for releasing a callable.
+ void ReleaseCallableHandler(
+ MasterCall<ReleaseCallableRequest, ReleaseCallableResponse>* call) {
+ master_impl_->ReleaseCallable(&call->request, &call->response,
+ [call](const Status& status) {
+ call->SendResponse(ToGrpcStatus(status));
+ });
+ ENQUEUE_REQUEST(ReleaseCallable, false);
+ }
+
#undef ENQUEUE_REQUEST
// Start tracing, including the ID attached to the RPC.
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
index e2016e824c..c832adbbbf 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
@@ -36,6 +36,9 @@ static const char* grpcMasterService_method_names[] = {
"/tensorflow.MasterService/CloseSession",
"/tensorflow.MasterService/ListDevices",
"/tensorflow.MasterService/Reset",
+ "/tensorflow.MasterService/MakeCallable",
+ "/tensorflow.MasterService/RunCallable",
+ "/tensorflow.MasterService/ReleaseCallable",
};
std::unique_ptr<MasterService::Stub> MasterService::NewStub(
@@ -64,7 +67,14 @@ MasterService::Stub::Stub(
rpcmethod_ListDevices_(grpcMasterService_method_names[5],
::grpc::internal::RpcMethod::NORMAL_RPC, channel),
rpcmethod_Reset_(grpcMasterService_method_names[6],
- ::grpc::internal::RpcMethod::NORMAL_RPC, channel) {}
+ ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
+ rpcmethod_MakeCallable_(grpcMasterService_method_names[7],
+ ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
+ rpcmethod_RunCallable_(grpcMasterService_method_names[8],
+ ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
+ rpcmethod_ReleaseCallable_(grpcMasterService_method_names[9],
+ ::grpc::internal::RpcMethod::NORMAL_RPC,
+ channel) {}
::grpc::Status MasterService::Stub::CreateSession(
::grpc::ClientContext* context, const CreateSessionRequest& request,
@@ -115,8 +125,29 @@ MasterService::Stub::Stub(
context, request, response);
}
+::grpc::Status MasterService::Stub::MakeCallable(
+ ::grpc::ClientContext* context, const MakeCallableRequest& request,
+ MakeCallableResponse* response) {
+ return ::grpc::internal::BlockingUnaryCall(
+ channel_.get(), rpcmethod_MakeCallable_, context, request, response);
+}
+
+::grpc::Status MasterService::Stub::RunCallable(
+ ::grpc::ClientContext* context, const RunCallableRequest& request,
+ RunCallableResponse* response) {
+ return ::grpc::internal::BlockingUnaryCall(
+ channel_.get(), rpcmethod_RunCallable_, context, request, response);
+}
+
+::grpc::Status MasterService::Stub::ReleaseCallable(
+ ::grpc::ClientContext* context, const ReleaseCallableRequest& request,
+ ReleaseCallableResponse* response) {
+ return ::grpc::internal::BlockingUnaryCall(
+ channel_.get(), rpcmethod_ReleaseCallable_, context, request, response);
+}
+
MasterService::AsyncService::AsyncService() {
- for (int i = 0; i < 7; ++i) {
+ for (int i = 0; i < 10; ++i) {
AddMethod(new ::grpc::internal::RpcServiceMethod(
grpcMasterService_method_names[i],
::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
index 6ae94b7441..3c382738c4 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
@@ -79,6 +79,15 @@ class MasterService final {
virtual ::grpc::Status Reset(::grpc::ClientContext* context,
const ResetRequest& request,
ResetResponse* response) = 0;
+ virtual ::grpc::Status MakeCallable(::grpc::ClientContext* context,
+ const MakeCallableRequest& request,
+ MakeCallableResponse* response) = 0;
+ virtual ::grpc::Status RunCallable(::grpc::ClientContext* context,
+ const RunCallableRequest& request,
+ RunCallableResponse* response) = 0;
+ virtual ::grpc::Status ReleaseCallable(
+ ::grpc::ClientContext* context, const ReleaseCallableRequest& request,
+ ReleaseCallableResponse* response) = 0;
};
class Stub final : public StubInterface {
public:
@@ -104,6 +113,15 @@ class MasterService final {
::grpc::Status Reset(::grpc::ClientContext* context,
const ResetRequest& request,
ResetResponse* response) override;
+ ::grpc::Status MakeCallable(::grpc::ClientContext* context,
+ const MakeCallableRequest& request,
+ MakeCallableResponse* response) override;
+ ::grpc::Status RunCallable(::grpc::ClientContext* context,
+ const RunCallableRequest& request,
+ RunCallableResponse* response) override;
+ ::grpc::Status ReleaseCallable(::grpc::ClientContext* context,
+ const ReleaseCallableRequest& request,
+ ReleaseCallableResponse* response) override;
private:
std::shared_ptr< ::grpc::ChannelInterface> channel_;
@@ -114,6 +132,9 @@ class MasterService final {
const ::grpc::internal::RpcMethod rpcmethod_CloseSession_;
const ::grpc::internal::RpcMethod rpcmethod_ListDevices_;
const ::grpc::internal::RpcMethod rpcmethod_Reset_;
+ const ::grpc::internal::RpcMethod rpcmethod_MakeCallable_;
+ const ::grpc::internal::RpcMethod rpcmethod_RunCallable_;
+ const ::grpc::internal::RpcMethod rpcmethod_ReleaseCallable_;
};
static std::unique_ptr<Stub> NewStub(
const std::shared_ptr< ::grpc::ChannelInterface>& channel,
@@ -179,6 +200,30 @@ class MasterService final {
::grpc::Service::RequestAsyncUnary(6, context, request, response,
new_call_cq, notification_cq, tag);
}
+ void RequestMakeCallable(
+ ::grpc::ServerContext* context, MakeCallableRequest* request,
+ ::grpc::ServerAsyncResponseWriter<MakeCallableResponse>* response,
+ ::grpc::CompletionQueue* new_call_cq,
+ ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+ ::grpc::Service::RequestAsyncUnary(7, context, request, response,
+ new_call_cq, notification_cq, tag);
+ }
+ void RequestRunCallable(
+ ::grpc::ServerContext* context, RunCallableRequest* request,
+ ::grpc::ServerAsyncResponseWriter<RunCallableResponse>* response,
+ ::grpc::CompletionQueue* new_call_cq,
+ ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+ ::grpc::Service::RequestAsyncUnary(8, context, request, response,
+ new_call_cq, notification_cq, tag);
+ }
+ void RequestReleaseCallable(
+ ::grpc::ServerContext* context, ReleaseCallableRequest* request,
+ ::grpc::ServerAsyncResponseWriter<ReleaseCallableResponse>* response,
+ ::grpc::CompletionQueue* new_call_cq,
+ ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+ ::grpc::Service::RequestAsyncUnary(9, context, request, response,
+ new_call_cq, notification_cq, tag);
+ }
};
};
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
index 1088e9be66..1b92a79a67 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
@@ -95,6 +95,28 @@ class GrpcRemoteMaster : public MasterInterface {
&MasterServiceStub::Reset);
}
+ Status MakeCallable(CallOptions* call_options,
+ const MakeCallableRequest* request,
+ MakeCallableResponse* response) override {
+ ::grpc::ClientContext ctx;
+ return Call(&ctx, call_options, request, response,
+ &MasterServiceStub::MakeCallable);
+ }
+ Status RunCallable(CallOptions* call_options,
+ const RunCallableRequest* request,
+ RunCallableResponse* response) override {
+ ::grpc::ClientContext ctx;
+ return Call(&ctx, call_options, request, response,
+ &MasterServiceStub::RunCallable);
+ }
+ Status ReleaseCallable(CallOptions* call_options,
+ const ReleaseCallableRequest* request,
+ ReleaseCallableResponse* response) override {
+ ::grpc::ClientContext ctx;
+ return Call(&ctx, call_options, request, response,
+ &MasterServiceStub::ReleaseCallable);
+ }
+
private:
// Start tracing, attaching a unique ID to both the trace and the RPC.
port::Tracing::TraceMe TraceRpc(StringPiece name,
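
The three overrides above all forward through the same private Call helper by passing a pointer to the generated stub method. A self-contained sketch of that dispatch idiom follows; FakeStub and Dispatch are illustrative stand-ins, and the real helper also threads CallOptions, tracing, and status conversion through.

    #include <iostream>
    #include <string>

    // Stand-in for the generated MasterService::Stub.
    struct FakeStub {
      std::string MakeCallable(const std::string& req) { return "handle-for-" + req; }
      std::string RunCallable(const std::string& req) { return "ran-" + req; }
    };

    // One template forwards every RPC through shared plumbing, selected by a
    // pointer to the stub member to call.
    template <typename Request, typename Response>
    Response Dispatch(FakeStub* stub,
                      Response (FakeStub::*method)(const Request&),
                      const Request& req) {
      // The real helper would also build a ClientContext, set deadlines, trace,
      // and convert the ::grpc::Status into a tensorflow::Status here.
      return (stub->*method)(req);
    }

    int main() {
      FakeStub stub;
      std::cout << Dispatch(&stub, &FakeStub::MakeCallable, std::string("graph_0"))
                << "\n"
                << Dispatch(&stub, &FakeStub::RunCallable, std::string("handle_7"))
                << "\n";
      return 0;
    }
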
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
new file mode 100644
index 0000000000..d004abd1c1
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
@@ -0,0 +1,213 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/util/rpc/call_container.h"
+#include "tensorflow/core/util/rpc/rpc_factory.h"
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h"
+
+namespace tensorflow {
+
+namespace {
+class GrpcCall {
+ public:
+ explicit GrpcCall(CallContainer<GrpcCall>* container, int index, bool try_rpc,
+ const string* request_msg, string* response_msg,
+ int32* status_code, string* status_message)
+ : container_(container),
+ index_(index),
+ try_rpc_(try_rpc),
+ request_msg_(request_msg),
+ response_msg_(response_msg),
+ status_code_(status_code),
+ status_message_(status_message) {}
+
+ void StartCancel() { call_opts_.StartCancel(); }
+
+ void Done(const Status& s) {
+ DCHECK(container_ != nullptr);
+ if (!s.ok() && try_rpc_) {
+ DCHECK(status_code_ != nullptr);
+ DCHECK(status_message_ != nullptr);
+ *status_code_ = s.code();
+ *status_message_ = s.error_message();
+ }
+ container_->Done(s, index_);
+ }
+
+ const string& request() const { return *request_msg_; }
+ string* response() const { return response_msg_; }
+ CallOptions* call_opts() { return &call_opts_; }
+
+ private:
+ CallContainer<GrpcCall>* const container_;
+ const int index_;
+ bool try_rpc_;
+ CallOptions call_opts_;
+ const string* request_msg_;
+ string* response_msg_;
+ int* status_code_;
+ string* status_message_;
+};
+
+} // namespace
+
+GrpcRPCFactory::GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
+ int64 timeout_in_ms)
+ : RPCFactory(), fail_fast_(fail_fast), timeout_in_ms_(timeout_in_ms) {
+ // TODO(ebrevdo): Investigate possible performance improvements by
+ // replacing this thread with a threadpool.
+ polling_thread_ =
+ ctx->env()->StartThread(ThreadOptions(), "rpc_op_grpc_factory", [this]() {
+ void* tag;
+ bool ok;
+ while (completion_queue_.Next(&tag, &ok)) {
+ GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
+ callback_tag->OnCompleted(ok);
+ }
+ });
+}
+
+GrpcRPCFactory::~GrpcRPCFactory() {
+ // The amount of time we wait depends on several parameters, including:
+ // - the value of the fail_fast attribute.
+ // - the timeout option of the rpc call in the proto declaration.
+ // - the network roundtrip time and service's execution time.
+ //
+ // If a connection is made but the service doesn't ever respond, and
+ // there is no timeout option set for this rpc call, then it is
+ // possible the RPC request will wait forever.
+ //
+ completion_queue_.Shutdown();
+ delete polling_thread_;
+}
+
+void GrpcRPCFactory::Call(OpKernelContext* ctx, int64 num_elements,
+ const Tensor& address_t, const Tensor& method_t,
+ const Tensor& request_t, const bool try_rpc,
+ Tensor* response_t, Tensor* status_code_t,
+ Tensor* status_message_t,
+ AsyncOpKernel::DoneCallback done) {
+ auto address = address_t.flat<string>();
+ auto method = method_t.flat<string>();
+ auto request = request_t.flat<string>();
+
+ // Stubs are maintained by the GrpcRPCFactory class and will be
+ // deleted when the class is destroyed.
+ ::grpc::GenericStub* singleton_stub = nullptr;
+ if (address.size() == 1) {
+ singleton_stub = GetOrCreateStubForAddress(address(0));
+ }
+ auto get_stub = [&address, this,
+ singleton_stub](int64 ix) -> ::grpc::GenericStub* {
+ return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix))
+ : singleton_stub;
+ };
+ auto get_method_ptr = [&method](int64 ix) -> const string* {
+ return (method.size() > 1) ? &(method(ix)) : &(method(0));
+ };
+ auto get_request_ptr = [&request](int64 ix) -> const string* {
+ return (request.size() > 1) ? &(request(ix)) : &(request(0));
+ };
+
+ if (try_rpc) {
+ // In this case status_code will never be set in the response,
+ // so we just set it to OK.
+ DCHECK(status_code_t != nullptr);
+ status_code_t->flat<int32>().setConstant(
+ static_cast<int>(errors::Code::OK));
+ }
+
+ CancellationManager* cm = ctx->cancellation_manager();
+ CancellationToken cancellation_token = cm->get_cancellation_token();
+
+ // This object will delete itself when done.
+ auto* container =
+ new CallContainer<GrpcCall>(ctx, num_elements, fail_fast_, try_rpc,
+ std::move(done), cancellation_token);
+
+ auto response = response_t->flat<string>();
+ int32* status_code_ptr = nullptr;
+ string* status_message_ptr = nullptr;
+ if (try_rpc) {
+ status_code_ptr = status_code_t->flat<int32>().data();
+ status_message_ptr = status_message_t->flat<string>().data();
+ }
+ for (int i = 0; i < num_elements; ++i) {
+ container->calls()->emplace_back(
+ container, i, try_rpc, get_request_ptr(i), &response(i),
+ (try_rpc) ? &status_code_ptr[i] : nullptr,
+ (try_rpc) ? &status_message_ptr[i] : nullptr);
+ }
+
+ int i = 0;
+ for (GrpcCall& call : *(container->calls())) {
+ // This object will delete itself when done.
+ new RPCState<string>(get_stub(i), &completion_queue_, *get_method_ptr(i),
+ call.request(), call.response(),
+ /*done=*/[&call](const Status& s) { call.Done(s); },
+ call.call_opts(), fail_fast_, timeout_in_ms_);
+ ++i;
+ }
+
+ // Need to register this callback after all the RPCs are in
+ // flight; otherwise we may try to cancel an RPC *before* it
+ // launches, which is a no-op, and then fall into a deadlock.
+ bool is_cancelled = !cm->RegisterCallback(
+ cancellation_token, [container]() { container->StartCancel(); });
+
+ if (is_cancelled) {
+ ctx->SetStatus(errors::Cancelled("Operation has been cancelled."));
+ // container's reference counter will take care of calling done().
+ container->StartCancel();
+ }
+}
+
+::grpc::GenericStub* GrpcRPCFactory::GetOrCreateStubForAddress(
+ const string& address) {
+ mutex_lock lock(mu_);
+
+ auto stub = stubs_.find(address);
+ if (stub != stubs_.end()) return stub->second.get();
+
+ ChannelPtr channel = CreateChannelForAddress(address);
+ auto* created = new ::grpc::GenericStub(channel);
+ stubs_[address].reset(created);
+ return created;
+}
+
+GrpcRPCFactory::ChannelPtr GrpcRPCFactory::CreateChannelForAddress(
+ const string& address) {
+ ::grpc::ChannelArguments args;
+ args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
+
+ // Set a standard backoff timeout of 1s instead of the
+ // (sometimes default) 20s.
+ args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 1000);
+ return ::grpc::CreateCustomChannel(
+ /*target=*/address, ::grpc::InsecureChannelCredentials(), args);
+}
+
+} // namespace tensorflow
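
GrpcRPCFactory::Call above accepts either broadcast (size-1) or per-element address, method, and request tensors and selects between them with the get_* lambdas. A standalone sketch of that selection rule follows; MakeSelector is an illustrative helper, not part of the factory.

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    // A size-1 column is broadcast to every element; a size-N column is indexed.
    std::function<const std::string*(int64_t)> MakeSelector(
        const std::vector<std::string>& values) {
      return [&values](int64_t ix) -> const std::string* {
        return (values.size() > 1) ? &values[ix] : &values[0];
      };
    }

    int main() {
      const std::vector<std::string> methods = {"/tensorflow.Svc/DoThing"};  // broadcast
      const std::vector<std::string> requests = {"a", "b", "c"};             // per element
      auto get_method = MakeSelector(methods);
      auto get_request = MakeSelector(requests);
      for (int64_t i = 0; i < 3; ++i) {
        std::cout << *get_method(i) << " <- request " << *get_request(i) << "\n";
      }
      return 0;
    }
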
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h
new file mode 100644
index 0000000000..34ec235aaf
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h
@@ -0,0 +1,59 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/util/rpc/rpc_factory.h"
+
+namespace tensorflow {
+
+class GrpcRPCFactory : public RPCFactory {
+ public:
+ explicit GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
+ int64 timeout_in_ms);
+
+ // Explicit destructor to control destruction order.
+ ~GrpcRPCFactory() override;
+
+ void Call(OpKernelContext* ctx, int64 num_elements, const Tensor& address_t,
+ const Tensor& method_t, const Tensor& request_t, const bool try_rpc,
+ Tensor* response_t, Tensor* status_code_t, Tensor* status_message_t,
+ AsyncOpKernel::DoneCallback done) override;
+
+ protected:
+ typedef std::shared_ptr<::grpc::Channel> ChannelPtr;
+ virtual ChannelPtr CreateChannelForAddress(const string& address);
+
+ private:
+ ::grpc::GenericStub* GetOrCreateStubForAddress(const string& address);
+
+ bool fail_fast_;
+ int64 timeout_in_ms_;
+ ::grpc::CompletionQueue completion_queue_;
+ Thread* polling_thread_; // Owned.
+
+ mutex mu_;
+ typedef std::unique_ptr<::grpc::GenericStub> StubPtr;
+ std::unordered_map<string, StubPtr> stubs_ GUARDED_BY(mu_);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory_registration.cc b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory_registration.cc
new file mode 100644
index 0000000000..b884489378
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory_registration.cc
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h"
+#include "tensorflow/core/util/rpc/rpc_factory.h"
+#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
+
+namespace tensorflow {
+namespace {
+
+// Used for adding the grpc factory to the RPC factory registry.
+struct Value {
+ static RPCFactory* Function(OpKernelConstruction* ctx, bool fail_fast,
+ int64 timeout_in_ms) {
+ return new GrpcRPCFactory(ctx, fail_fast, timeout_in_ms);
+ }
+};
+
+REGISTER_RPC_FACTORY("grpc", Value::Function);
+
+} // namespace
+} // namespace tensorflow
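
Because CreateChannelForAddress is a protected virtual hook and registration goes through REGISTER_RPC_FACTORY, a variant factory can plausibly be registered under its own protocol string. A hedged sketch under that assumption follows; "grpc+test", TestChannelFactory, and TestValue are illustrative names, not part of this change.

    #include "tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h"
    #include "tensorflow/core/util/rpc/rpc_factory.h"
    #include "tensorflow/core/util/rpc/rpc_factory_registry.h"

    namespace tensorflow {
    namespace {

    // A variant factory that only changes how channels are created, e.g. to use
    // a shorter reconnect backoff in tests.
    class TestChannelFactory : public GrpcRPCFactory {
     public:
      using GrpcRPCFactory::GrpcRPCFactory;

     protected:
      ChannelPtr CreateChannelForAddress(const string& address) override {
        ::grpc::ChannelArguments args;
        args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 100);
        return ::grpc::CreateCustomChannel(
            address, ::grpc::InsecureChannelCredentials(), args);
      }
    };

    struct TestValue {
      static RPCFactory* Function(OpKernelConstruction* ctx, bool fail_fast,
                                  int64 timeout_in_ms) {
        return new TestChannelFactory(ctx, fail_fast, timeout_in_ms);
      }
    };

    REGISTER_RPC_FACTORY("grpc+test", TestValue::Function);

    }  // namespace
    }  // namespace tensorflow
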
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
index 3e79a40683..fd1c150fa7 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
@@ -91,6 +91,15 @@ void ReEncodeConsts(GraphDef* gdef) {
}
} // namespace
+Status GrpcSession::Handle(string* out_handle) {
+ mutex_lock l(mu_);
+ if (handle_.empty()) {
+ return errors::InvalidArgument("A session is not created yet....");
+ }
+ *out_handle = handle_;
+ return Status::OK();
+}
+
Status GrpcSession::CreateImpl(CallOptions* call_options,
const GraphDef& graph) {
{
@@ -274,14 +283,9 @@ Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
Status GrpcSession::RunProto(CallOptions* call_options,
MutableRunStepRequestWrapper* req,
MutableRunStepResponseWrapper* resp) {
- {
- mutex_lock l(mu_);
- if (handle_.empty()) {
- return errors::InvalidArgument("A session is not created yet....");
- }
-
- req->set_session_handle(handle_);
- }
+ string handle;
+ TF_RETURN_IF_ERROR(Handle(&handle));
+ req->set_session_handle(handle);
return master_->RunStep(call_options, req, resp);
}
@@ -293,14 +297,7 @@ Status GrpcSession::PRunSetup(const std::vector<string>& input_names,
PartialRunSetupRequest req;
PartialRunSetupResponse resp;
CallOptions call_options;
- {
- mutex_lock l(mu_);
- if (handle_.empty()) {
- return errors::InvalidArgument("A session is not created yet....");
- }
-
- req.set_session_handle(handle_);
- }
+ TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
for (const string& feed : input_names) {
req.add_feed(feed);
}
@@ -400,6 +397,55 @@ Status GrpcSession::Reset(const SessionOptions& options,
return ret;
}
+Status GrpcSession::MakeCallable(const CallableOptions& callable_options,
+ CallableHandle* out_handle) {
+ MakeCallableRequest req;
+ TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
+ *req.mutable_options() = callable_options;
+ MakeCallableResponse resp;
+ CallOptions call_options;
+ call_options.SetTimeout(options_.config.operation_timeout_in_ms());
+ TF_RETURN_IF_ERROR(master_->MakeCallable(&call_options, &req, &resp));
+ *out_handle = resp.handle();
+ return Status::OK();
+}
+
+Status GrpcSession::RunCallable(CallableHandle handle,
+ const std::vector<Tensor>& feed_tensors,
+ std::vector<Tensor>* fetch_tensors,
+ RunMetadata* run_metadata) {
+ RunCallableRequest req;
+ TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
+ req.set_handle(handle);
+ for (const Tensor& feed : feed_tensors) {
+ feed.AsProtoTensorContent(req.mutable_feed()->Add());
+ }
+
+ RunCallableResponse resp;
+ CallOptions call_options;
+ call_options.SetTimeout(options_.config.operation_timeout_in_ms());
+ TF_RETURN_IF_ERROR(master_->RunCallable(&call_options, &req, &resp));
+ for (const TensorProto& fetch : resp.fetch()) {
+ Tensor fetch_tensor;
+ if (!fetch_tensor.FromProto(cpu_allocator(), fetch)) {
+ return errors::Internal(
+ "Could not parse fetched tensor data in response from master.");
+ }
+ fetch_tensors->push_back(std::move(fetch_tensor));
+ }
+ return Status::OK();
+}
+
+Status GrpcSession::ReleaseCallable(CallableHandle handle) {
+ ReleaseCallableRequest req;
+ TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
+ req.set_handle(handle);
+ ReleaseCallableResponse resp;
+ CallOptions call_options;
+ call_options.SetTimeout(options_.config.operation_timeout_in_ms());
+ return master_->ReleaseCallable(&call_options, &req, &resp);
+}
+
class GrpcSessionFactory : public SessionFactory {
public:
bool AcceptsOptions(const SessionOptions& options) override {
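
GrpcSession::RunCallable above serializes each feed with Tensor::AsProtoTensorContent and parses each fetch back with Tensor::FromProto against cpu_allocator(). A minimal round-trip sketch using only those calls follows; RoundTrip is an illustrative helper, not part of the change.

    #include "tensorflow/core/framework/allocator.h"
    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/framework/tensor.pb.h"

    namespace tensorflow {

    // Serializes a tensor the way a feed is sent, then parses it back the way a
    // fetch is received. Returns false if the proto cannot be parsed.
    bool RoundTrip(const Tensor& in, Tensor* out) {
      TensorProto proto;
      in.AsProtoTensorContent(&proto);
      return out->FromProto(cpu_allocator(), proto);
    }

    }  // namespace tensorflow
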
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.h b/tensorflow/core/distributed_runtime/rpc/grpc_session.h
index d87956a135..63795117f9 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.h
@@ -82,20 +82,27 @@ class GrpcSession : public Session {
Status Close() override;
// NOTE: This API is still experimental and may change.
- ::tensorflow::Status PRunSetup(const std::vector<string>& input_names,
- const std::vector<string>& output_names,
- const std::vector<string>& target_nodes,
- string* handle) override;
+ Status PRunSetup(const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ string* handle) override;
// NOTE: This API is still experimental and may change.
- ::tensorflow::Status PRun(
- const string& handle,
- const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_names,
- std::vector<Tensor>* outputs) override;
+ Status PRun(const string& handle,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_names,
+ std::vector<Tensor>* outputs) override;
Status ListDevices(std::vector<DeviceAttributes>* response) override;
+ Status MakeCallable(const CallableOptions& callable_options,
+ CallableHandle* out_handle) override;
+ Status RunCallable(CallableHandle handle,
+ const std::vector<Tensor>& feed_tensors,
+ std::vector<Tensor>* fetch_tensors,
+ RunMetadata* run_metadata) override;
+ Status ReleaseCallable(CallableHandle handle) override;
+
protected:
// Takes ownership of `*master`.
void SetRemoteMaster(std::unique_ptr<MasterInterface> master);
@@ -111,6 +118,8 @@ class GrpcSession : public Session {
// The current version of the graph.
int64 current_graph_version_ GUARDED_BY(mu_);
+ Status Handle(string* out_handle) LOCKS_EXCLUDED(mu_);
+
Status RunHelper(const RunOptions& run_options,
const std::vector<std::pair<string, Tensor> >& inputs,
const std::vector<string>& output_tensor_names,
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
index 335c3febe2..45b15a54a2 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
@@ -120,6 +120,49 @@ TEST(GrpcSessionTest, BasicNonProtoAPI) {
}
}
+TEST(GrpcSessionTest, BasicCallable) {
+ GraphDef graph;
+ string node_names[3];
+ // c = a * b
+ CreateGraphDef(&graph, node_names);
+
+ std::unique_ptr<test::TestCluster> cluster;
+ TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
+
+ std::unique_ptr<Session> session(
+ NewRemote(Options(cluster->targets()[0], 1)));
+ ASSERT_TRUE(session != nullptr);
+
+ for (int iters = 0; iters < 25; ++iters) {
+ TF_CHECK_OK(session->Create(graph));
+ {
+ // Just run to target node
+ CallableOptions opts;
+ opts.add_target(node_names[2]);
+ Session::CallableHandle handle;
+ TF_CHECK_OK(session->MakeCallable(opts, &handle));
+ TF_CHECK_OK(session->RunCallable(handle, {}, nullptr, nullptr));
+ TF_CHECK_OK(session->ReleaseCallable(handle));
+ }
+ {
+ // Run to a target node and a real tensor
+ CallableOptions opts;
+ opts.add_target(node_names[1]);
+ opts.add_fetch(node_names[2] + ":0");
+ Session::CallableHandle handle;
+ TF_CHECK_OK(session->MakeCallable(opts, &handle));
+ std::vector<Tensor> outputs;
+ TF_CHECK_OK(session->RunCallable(handle, {}, &outputs, nullptr));
+ ASSERT_EQ(1, outputs.size());
+ ASSERT_TRUE(outputs[0].IsInitialized());
+ ASSERT_EQ(4.0, outputs[0].flat<float>()(0));
+ TF_CHECK_OK(session->ReleaseCallable(handle));
+ }
+
+ TF_CHECK_OK(session->Close());
+ }
+}
+
TEST(GrpcSessionTest, BasicNonProtoAPIConsistentOrder) {
GraphDef graph;
string node_names[3];
diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc
index 6182f95f28..1a7e5219cd 100644
--- a/tensorflow/core/framework/allocator.cc
+++ b/tensorflow/core/framework/allocator.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/framework/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/framework/log_memory.h"
@@ -88,20 +88,15 @@ void EnableCPUAllocatorFullStats(bool enable) {
cpu_allocator_collect_full_stats = enable;
}
-class CPUAllocator : public VisitableAllocator {
+class CPUAllocator : public Allocator {
public:
- CPUAllocator()
- : total_allocation_warning_triggered_(false), allocation_begun_(false) {}
+ CPUAllocator() : total_allocation_warning_triggered_(false) {}
~CPUAllocator() override {}
string Name() override { return "cpu"; }
void* AllocateRaw(size_t alignment, size_t num_bytes) override {
- if (!allocation_begun_) {
- allocation_begun_ = true;
- }
-
if (num_bytes > LargeAllocationWarningBytes()) {
LOG(WARNING) << "Allocation of " << num_bytes << " exceeds "
<< 100 * kLargeAllocationWarningThreshold
@@ -127,38 +122,16 @@ class CPUAllocator : public VisitableAllocator {
total_allocation_warning_triggered_ = true;
}
}
-
- // visit each Visitor in alloc_visitors_
- if (p != nullptr) {
- for (const Visitor& v : alloc_visitors_) {
- v(p, num_bytes);
- }
- }
-
return p;
}
void DeallocateRaw(void* ptr) override {
- std::size_t alloc_size;
- bool init_alloc_size = false;
if (cpu_allocator_collect_stats) {
- alloc_size = port::MallocExtension_GetAllocatedSize(ptr);
- init_alloc_size = true;
+ const std::size_t alloc_size =
+ port::MallocExtension_GetAllocatedSize(ptr);
mutex_lock l(mu_);
stats_.bytes_in_use -= alloc_size;
}
-
- // visit each Visitor in free_visitors_
- if (ptr != nullptr) {
- if (!init_alloc_size) {
- alloc_size = port::MallocExtension_GetAllocatedSize(ptr);
- init_alloc_size = true;
- }
- for (const Visitor& v : free_visitors_) {
- v(ptr, alloc_size);
- }
- }
-
port::AlignedFree(ptr);
}
@@ -178,37 +151,11 @@ class CPUAllocator : public VisitableAllocator {
return port::MallocExtension_GetAllocatedSize(ptr);
}
- // REQUIRES: can only add visitors before the first Allocate call
-
- void AddAllocVisitor(Visitor visitor) override {
- mutex_lock lock(visitor_mutex_);
- CHECK(!allocation_begun_)
- << "AddAllocVisitor may not be called after allocation has begun.";
- alloc_visitors_.push_back(visitor);
- }
-
- void AddFreeVisitor(Visitor visitor) override {
- mutex_lock lock(visitor_mutex_);
- CHECK(!allocation_begun_)
- << "AddFreeVisitor may not be called after allocation has begun.";
- free_visitors_.push_back(visitor);
- }
-
private:
mutex mu_;
AllocatorStats stats_ GUARDED_BY(mu_);
bool total_allocation_warning_triggered_ GUARDED_BY(mu_);
- // visitor_mutex_ protects write access to alloc_visitors_ and free_visitors_.
- // While write access is mutually exclusive, reads may happen concurrently.
- // This is okay because we may only append to alloc_visitors_ and
- // free_visitors_ before first allocation, and subsequently we only read these
- // vectors.
- mutex visitor_mutex_;
- std::vector<Visitor> alloc_visitors_;
- std::vector<Visitor> free_visitors_;
- std::atomic<bool> allocation_begun_;
-
TF_DISALLOW_COPY_AND_ASSIGN(CPUAllocator);
};
diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h
index 362d345133..5810c7fa54 100644
--- a/tensorflow/core/framework/collective.h
+++ b/tensorflow/core/framework/collective.h
@@ -103,11 +103,8 @@ struct CollectiveParams {
// Rank of this device in each subdivision permutation.
std::vector<int> subdiv_rank;
std::vector<int> subdiv_source_rank;
- const Tensor* in_tensor; // kernel input
- Tensor* out_tensor; // kernel output
std::unique_ptr<OpKernel> merge_op; // reduction only
std::unique_ptr<OpKernel> final_op; // reduction only
- OpKernelContext* op_context;
string ToString() const;
};
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index cfde1e8ea3..05171006b0 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -96,7 +96,7 @@ OpKernel::OpKernel(OpKernelConstruction* context,
output_memory_types_(context->output_memory_types().begin(),
context->output_memory_types().end()),
graph_def_version_(context->graph_def_version()),
- is_internal_(StringPiece(type_string()).starts_with("_")),
+ is_internal_(str_util::StartsWith(type_string(), "_")),
input_name_map_(context->num_inputs()),
output_name_map_(context->num_outputs()) {
OP_REQUIRES_OK(context,
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 54ecaa5dd4..cc1ec47a83 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -726,6 +726,24 @@ ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1,
return MakeShape({dim1, dim2});
}
+Status InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
+ int input_idx, ShapeHandle* out) {
+ ShapeHandle input_shape;
+ TF_RETURN_IF_ERROR(WithRankAtMost(input(input_idx), 1, &input_shape));
+
+ requested_input_tensor_as_partial_shape_[input_idx] = true;
+ if (input_idx < input_tensors_as_shapes_.size() &&
+ input_tensors_as_shapes_[input_idx].IsSet() &&
+ RankKnown(input_tensors_as_shapes_[input_idx])) {
+ *out = input_tensors_as_shapes_[input_idx];
+ return Status::OK();
+ }
+
+ return InternalMakeShapeFromTensor(
+ true /* treat_unknown_scalar_tensor_as_unknown_shape */,
+ input_tensor(input_idx), input_shape, out);
+}
+
Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
ShapeHandle* out) {
ShapeHandle input_shape;
@@ -739,13 +757,31 @@ Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
return Status::OK();
}
- return MakeShapeFromTensor(input_tensor(input_idx), input_shape, out);
+ return InternalMakeShapeFromTensor(
+ false /* treat_unknown_scalar_tensor_as_unknown_shape */,
+ input_tensor(input_idx), input_shape, out);
}
Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
ShapeHandle tensor_shape,
ShapeHandle* out) {
+ return InternalMakeShapeFromTensor(
+ false /* treat_unknown_scalar_tensor_as_unknown_shape */, t, tensor_shape,
+ out);
+}
+
+Status InferenceContext::InternalMakeShapeFromTensor(
+ bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
+ ShapeHandle tensor_shape, ShapeHandle* out) {
+ // Callers that have not opted into treating an unknown scalar tensor as an
+ // unknown shape require the shape tensor itself to be rank 1.
+ if (!treat_unknown_scalar_tensor_as_unknown_shape) {
+ TF_RETURN_IF_ERROR(WithRank(tensor_shape, 1, &tensor_shape));
+ }
if (t == nullptr) {
+ // Rank 0 can only occur here when scalar tensors are treated as unknown
+ // shapes, because the WithRank check above is skipped in that case.
+ if (Rank(tensor_shape) == 0) {
+ return ReturnUnknownShape(out);
+ }
// Shape tensor is not known, but if the shape of the shape tensor is then
// the right number of unknown dims can be created.
DimensionHandle shape_dim = Dim(tensor_shape, 0);
@@ -759,10 +795,46 @@ Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
return ReturnCreatedShape(dims, out);
}
+ if (t->shape().dims() == 0) {
+ if (t->dtype() == DataType::DT_INT32) {
+ auto flat_t = t->scalar<int32>();
+ if (flat_t() != -1) {
+ *out = nullptr;
+ return errors::InvalidArgument(
+ "Input tensor must be rank 1, or if it is rank 0 it must have "
+ "value -1 (representing an unknown shape). Saw value: ",
+ flat_t());
+ }
+ return ReturnUnknownShape(out);
+ } else if (t->dtype() == DataType::DT_INT64) {
+ auto flat_t = t->scalar<int64>();
+ if (flat_t() != -1) {
+ *out = nullptr;
+ return errors::InvalidArgument(
+ "Input tensor must be rank 1, or if it is rank 0 it must have "
+ "value -1 (representing an unknown shape). Saw value: ",
+ flat_t());
+ }
+ return ReturnUnknownShape(out);
+ } else {
+ *out = nullptr;
+ return errors::InvalidArgument(
+ "Input tensor must be int32 or int64, but was ",
+ DataTypeString(t->dtype()));
+ }
+ }
+
if (t->shape().dims() != 1) {
*out = nullptr;
- return errors::InvalidArgument("Input tensor must be rank 1, but was rank ",
- t->shape().dims());
+ return errors::InvalidArgument(
+ "Input tensor must be rank 1, but was rank ", t->shape().dims(), ".",
+ ((t->shape().dims() == 0)
+ ? "If it is rank 0 it must have statically known value -1 "
+ "(representing an unknown shape). "
+ : " "),
+ "Saw tensor shape ", t->shape().DebugString());
}
std::vector<DimensionHandle> dims;
if (t->dtype() == DataType::DT_INT32) {
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index accc587000..cdb4bd79bb 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -463,6 +463,12 @@ class InferenceContext {
// the input tensor is NULL, then an unknown shape is returned.
Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out);
+ // Like the function above, but treats scalar values as unknown
+ // shapes. **NOTE** If the scalar is statically known, its value
+ // must be -1 or an error is returned.
+ Status MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,
+ ShapeHandle* out);
+
// Returns in <out> a new shape corresponding to <proto>.
Status MakeShapeFromShapeProto(const TensorShapeProto& proto,
ShapeHandle* out);
@@ -708,6 +714,11 @@ class InferenceContext {
merged_dims_.clear();
}
+ // Helper method for MakeShapeFromTensor and MakeShapeFromShapeTensor.
+ Status InternalMakeShapeFromTensor(
+ bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
+ ShapeHandle tensor_shape, ShapeHandle* out);
+
ShapeManager shape_manager_;
// inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
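
A hedged sketch of how an op's shape function might opt into the new scalar handling declared above; ExampleReshapeShapeFn and the choice of input index are illustrative, not part of this change.

    #include "tensorflow/core/framework/shape_inference.h"
    #include "tensorflow/core/lib/core/errors.h"

    namespace tensorflow {

    // Shape function for a hypothetical op whose input 1 is a shape tensor:
    // a scalar -1 now yields an unknown output shape instead of a rank error.
    Status ExampleReshapeShapeFn(shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle out;
      TF_RETURN_IF_ERROR(
          c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(1, &out));
      c->set_output(0, out);
      return Status::OK();
    }

    }  // namespace tensorflow
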
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index da103bfec9..586c38e43b 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -1081,17 +1081,26 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
t = ::tensorflow::test::AsTensor<int64>({});
EXPECT_EQ("[]", create(&t));
+ // Test negative scalar
+ t = ::tensorflow::test::AsScalar<int32>(-1);
+ EXPECT_EQ("?", create(&t));
+
t = ::tensorflow::test::AsTensor<float>({1, 2, 3});
EXPECT_TRUE(str_util::StrContains(
create(&t), "Input tensor must be int32 or int64, but was float"));
t = ::tensorflow::test::AsScalar<int32>(1);
+ auto s_scalar = create(&t);
EXPECT_TRUE(str_util::StrContains(
- create(&t), "Input tensor must be rank 1, but was rank 0"));
+ s_scalar,
+ "Input tensor must be rank 1, or if it is rank 0 it must have value -1"))
+ << s_scalar;
t = ::tensorflow::test::AsTensor<int32>({1, 2}, TensorShape{2, 1});
+ auto s_matrix = create(&t);
EXPECT_TRUE(str_util::StrContains(
- create(&t), "Input tensor must be rank 1, but was rank 2"));
+ s_matrix, "Input tensor must be rank 1, but was rank 2"))
+ << s_matrix;
// Test negative values for the dims.
t = ::tensorflow::test::AsTensor<int64>({3, -2, 1});
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 79735e6cc2..087190ad2a 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -30,6 +30,7 @@ constexpr char kConst[] = "Const";
constexpr char kConv2d[] = "Conv2D";
constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter";
constexpr char kConv2dBackpropInput[] = "Conv2DBackpropInput";
+constexpr char kFusedConv2dBiasActivation[] = "FusedConv2DBiasActivation";
constexpr char kMatMul[] = "MatMul";
constexpr char kSparseMatMul[] = "SparseMatMul";
constexpr char kPlaceholder[] = "Placeholder";
@@ -196,6 +197,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter)},
{kConv2dBackpropInput,
wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput)},
+ {kFusedConv2dBiasActivation,
+ wrap(&OpLevelCostEstimator::PredictFusedConv2DBiasActivation)},
{kMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
{kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
{kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)},
@@ -545,7 +548,6 @@ int64 OpLevelCostEstimator::CountConv2DOperations(
ops *= conv_dims.kx * conv_dims.ky;
ops *= conv_dims.iz * conv_dims.oz;
ops *= kOpsPerMac;
- VLOG(1) << "Operations for Conv2D " << ops;
if (conv_info != nullptr) {
*conv_info = conv_dims;
@@ -983,6 +985,91 @@ Costs OpLevelCostEstimator::PredictConv2DBackpropFilter(
return costs;
}
+Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
+ const OpContext& op_context) const {
+ // FusedConv2DBiasActivation computes a fused kernel which implements:
+ // 2D convolution, adds side input with separate scaling on convolution and
+ // side inputs, then adds bias, and finally applies the ReLU activation
+ // function to the result:
+ //
+ // Input -> Conv2D -> Add -> BiasAdd -> ReLU
+ // ^ ^ ^
+ // Filter Side Input Bias
+ //
+ // Note that when adding the side input, the operation scales the output of
+ // Conv2D by conv_input_scale (despite the name) and the side_input by
+ // side_input_scale.
+ //
+ // Note that in the special case that side_input_scale is 0, which we infer
+ // from side_input having dimensions [], we skip that addition operation.
+ //
+ // For more information, see
+ // contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+ auto& conv_input = op_context.op_info.inputs(0);
+ auto& filter = op_context.op_info.inputs(1);
+ auto& bias = op_context.op_info.inputs(2);
+ auto& side_input = op_context.op_info.inputs(3);
+ auto& conv_input_scale = op_context.op_info.inputs(4);
+ auto& side_input_scale = op_context.op_info.inputs(5);
+
+ // Manually compute our convolution dimensions.
+ bool found_unknown_shapes = false;
+ auto dims = ConvolutionDimensionsFromInputs(
+ conv_input.shape(), filter.shape(), op_context.op_info,
+ &found_unknown_shapes);
+
+ // Construct the shape of our output tensor from our convolution dimensions
+ // and format, as it may not be available yet.
+ //
+ // TODO(varomodt): should we centralize the Conv2D input/output shapes?
+ bool unknown_conv_format = false;
+ OpInfo::TensorProperties output;
+ switch (GetConvolutionFormat(op_context)) {
+ case NCHW:
+ output =
+ DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.ox, dims.oy});
+ break;
+ case NHWC:
+ output =
+ DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
+ break;
+ default:
+ // TODO(b/77722245): support cost estimation for NCHW_VECT_C.
+ LOG(WARNING) << "unsupported data format: "
+ << GetDataFormat(op_context.op_info)
+ << " Defaulting to NHWC.";
+ output =
+ DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
+ unknown_conv_format = true;
+ break;
+ }
+
+ // Add the operations the fused op always computes.
+ std::vector<OpContext> component_ops = {
+ FusedChildContext(op_context, "Conv2D", output, {conv_input, filter}),
+ FusedChildContext(op_context, "Mul", output, {output, conv_input_scale}),
+ FusedChildContext(op_context, "BiasAdd", output, {output, bias}),
+ FusedChildContext(op_context, "Relu", output, {output})};
+
+ // Add our side_input iff it's non-empty.
+ if (side_input.shape().dim_size() > 0) {
+ component_ops.push_back(FusedChildContext(op_context, "Mul", side_input,
+ {side_input, side_input_scale}));
+ component_ops.push_back(
+ FusedChildContext(op_context, "Add", output, {side_input, output}));
+ }
+
+ // Construct an op_context which definitely has our output shape.
+ auto op_context_with_output = op_context;
+ op_context_with_output.op_info.mutable_outputs()->Clear();
+ *op_context_with_output.op_info.mutable_outputs()->Add() = output;
+
+ // Construct component operations and run the cost computation.
+ auto costs = PredictFusedOp(op_context_with_output, component_ops);
+ costs.inaccurate |= found_unknown_shapes || unknown_conv_format;
+ return costs;
+}
+
Costs OpLevelCostEstimator::PredictMatMul(const OpContext& op_context) const {
const auto& op_features = op_context.op_info;
bool found_unknown_shapes = false;
@@ -1086,6 +1173,66 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice(
return costs;
}
+Costs OpLevelCostEstimator::PredictFusedOp(
+ const OpContext& op_context,
+ const std::vector<OpContext>& fused_op_contexts) const {
+ // Note that PredictOpCountBasedCost will get the correct memory_time from
+ // the node's inputs and outputs; but we don't want to have to re-implement
+ // the logic for computing the operation count of each of our component
+ // operations here; so we simply add the compute times of each component
+ // operation, then update the execution time.
+ Costs fused_cost = PredictOpCountBasedCost(0, op_context.op_info);
+ fused_cost.compute_time = 0;
+ fused_cost.inaccurate = false;
+ for (auto& fused_op : fused_op_contexts) {
+ auto op_cost = PredictCosts(fused_op);
+ fused_cost.compute_time += op_cost.compute_time;
+ fused_cost.inaccurate |= op_cost.inaccurate;
+ }
+
+ CombineCostsAndUpdateExecutionTime(&fused_cost);
+ return fused_cost;
+}
+
+/* static */
+OpContext OpLevelCostEstimator::FusedChildContext(
+ const OpContext& parent, const string& op_name,
+ const OpInfo::TensorProperties& output,
+ const std::vector<OpInfo::TensorProperties>& inputs) {
+ // Setup the base parameters of our new context.
+ OpContext new_context;
+ new_context.name = op_name;
+ new_context.device_name = parent.device_name;
+ new_context.op_info = parent.op_info;
+ new_context.op_info.set_op(op_name);
+
+ // Setup the inputs of our new context.
+ new_context.op_info.mutable_inputs()->Clear();
+ for (const auto& input : inputs) {
+ *new_context.op_info.mutable_inputs()->Add() = input;
+ }
+
+ // Setup the output of our new context.
+ new_context.op_info.mutable_outputs()->Clear();
+ *new_context.op_info.mutable_outputs()->Add() = output;
+
+ return new_context;
+}
+
+/* static */
+OpInfo::TensorProperties OpLevelCostEstimator::DescribeTensor(
+ DataType type, const std::vector<int64>& dims) {
+ OpInfo::TensorProperties ret;
+ ret.set_dtype(type);
+
+ auto shape = ret.mutable_shape();
+ for (const int dim : dims) {
+ shape->add_dim()->set_size(dim);
+ }
+
+ return ret;
+}
+
/* static */
OpLevelCostEstimator::ConvolutionDimensions
OpLevelCostEstimator::OpDimensionsFromInputs(
@@ -1371,6 +1518,21 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
return costs;
}
+/* static */
+OpLevelCostEstimator::ConvolutionFormat
+OpLevelCostEstimator::GetConvolutionFormat(const OpContext& op_context) {
+ auto data_format = GetDataFormat(op_context.op_info);
+ if (data_format == "NCHW") {
+ return NCHW;
+ } else if (data_format == "NHWC") {
+ return NHWC;
+ } else if (data_format == "NCHW_VECT_C") {
+ return NCHW_VECT_C;
+ }
+
+ return UNKNOWN_CONVOLUTION_FORMAT;
+}
+
void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
Costs* costs) const {
if (compute_memory_overlap_) {
@@ -1379,6 +1541,5 @@ void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
costs->execution_time = costs->compute_time + costs->memory_time;
}
}
-
} // end namespace grappler
} // end namespace tensorflow
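
The composition rule implemented by PredictFusedOp above charges memory time once, from the fused node's own inputs and outputs, sums the components' compute times, and then recombines them into an execution time. A tiny numeric sketch of that rule follows, with made-up values and a plain struct instead of the real Costs type.

    #include <cstdint>
    #include <iostream>
    #include <vector>

    struct CostsNs {
      int64_t compute;
      int64_t memory;
    };

    // Keep the parent's memory time, sum the components' compute times, then
    // recombine into an execution time (assuming no IO/compute overlap).
    CostsNs PredictFused(const CostsNs& parent,
                         const std::vector<CostsNs>& components) {
      CostsNs fused = {0, parent.memory};
      for (const CostsNs& c : components) fused.compute += c.compute;
      return fused;
    }

    int main() {
      const CostsNs parent = {0, 1000};  // memory time of the fused node itself
      const std::vector<CostsNs> parts = {{700, 0}, {200, 0}, {100, 0}};
      const CostsNs fused = PredictFused(parent, parts);
      std::cout << "execution_time_ns=" << (fused.compute + fused.memory) << "\n";
      return 0;
    }
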
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index 7080264698..35649f7ee9 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -82,6 +82,13 @@ class OpLevelCostEstimator {
int64 sy; // Stride y.
Padding padding; // SAME or VALID.
};
+ enum ConvolutionFormat {
+ UNKNOWN_CONVOLUTION_FORMAT,
+ NHWC,
+ NCHW,
+ NCHW_VECT_C,
+ NCHW_VECT_W,
+ };
int64 CountConv2DOperations(const OpInfo& op_features,
bool* found_unknown_shapes) const;
int64 CountConv2DOperations(const OpInfo& op_features,
@@ -138,6 +145,7 @@ class OpLevelCostEstimator {
Costs PredictCwiseOp(const OpContext& op_context) const;
Costs PredictConv2DBackpropInput(const OpContext& op_context) const;
Costs PredictConv2DBackpropFilter(const OpContext& op_context) const;
+ Costs PredictFusedConv2DBiasActivation(const OpContext& op_context) const;
Costs PredictMatMul(const OpContext& op_context) const;
Costs PredictNoOp(const OpContext& op_context) const;
Costs PredictIdentity(const OpContext& op_context) const;
@@ -152,6 +160,10 @@ class OpLevelCostEstimator {
Costs PredictFusedBatchNorm(const OpContext& op_context) const;
Costs PredictFusedBatchNormGrad(const OpContext& op_context) const;
+ // Generic cost prediction method for fused operations.
+ Costs PredictFusedOp(const OpContext& op_context,
+ const std::vector<OpContext>& fused_op_contexts) const;
+
// Utility function for safe division. Returns 0
// if rhs is 0 or negative.
static double SafeDiv(const double lhs, const double rhs) {
@@ -173,6 +185,20 @@ class OpLevelCostEstimator {
const TensorShapeProto& original_image_shape, const OpInfo& op_info,
bool* found_unknown_shapes);
+ // Helper to construct child operation contexts for the component operations
+ // of fused ops.
+ static OpContext FusedChildContext(
+ const OpContext& parent, const string& op_name,
+ const OpInfo::TensorProperties& output,
+ const std::vector<OpInfo::TensorProperties>& inputs);
+
+ // Helper to construct tensor shapes.
+ static OpInfo::TensorProperties DescribeTensor(
+ DataType type, const std::vector<int64>& dims);
+
+ // Returns the Conv2D format for this operation.
+ static ConvolutionFormat GetConvolutionFormat(const OpContext& op_context);
+
// This method calculates the execution time depending on whether IO can
// overlap with computation. It assumes the memory and the compute times have
// already been calculated.
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index d797a8a8c1..13ea43bed6 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -93,6 +93,14 @@ OpContext DescribeBatchMatMul(const std::vector<int>& dims_a,
return op_context;
}
+// Wrangles the minimum number of proto fields to set up a 1D Tensor for cost
+// estimation purposes.
+void DescribeTensor1D(int dim0, OpInfo::TensorProperties* tensor) {
+ auto shape = tensor->mutable_shape();
+ shape->add_dim()->set_size(dim0);
+ tensor->set_dtype(DT_FLOAT);
+}
+
// Wrangles the minimum number of proto fields to set up a 4D Tensor for cost
// estimation purposes.
void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3,
@@ -120,6 +128,38 @@ OpContext DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2,
return op_context;
}
+// DescribeFusedConv2DBiasActivation constructs an OpContext for a
+// FusedConv2DBiasActivation applied to a convolution input tensor with shape
+// (batch, ix, iy, iz1), a kernel tensor with shape (kx, ky, iz2, oz), a
+// bias tensor with shape (oz), a side input tensor with shape
+// (batch, ox, oy, oz) if has_side_input is set, and two scaling tensors with
+// shape (1).
+//
+// Note that this assumes the NHWC data format.
+OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1,
+ int iz2, int kx, int ky, int ox,
+ int oy, int oz,
+ bool has_side_input) {
+ OpContext op_context;
+ SetCpuDevice(&op_context.op_info);
+ op_context.op_info.set_op("FusedConv2DBiasActivation");
+ DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
+ DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());
+ DescribeTensor1D(oz, op_context.op_info.add_inputs());
+
+ // Add the side_input, if any.
+ auto side_input = op_context.op_info.add_inputs();
+ if (has_side_input) {
+ DescribeTensor4D(batch, ox, oy, oz, side_input);
+ }
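+ // Note that the (possibly empty) side-input slot is added unconditionally, so
+ // the two scaling tensors below always occupy input positions 4 and 5.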
+
+ // Add the scaling tensors.
+ DescribeTensor1D(1, op_context.op_info.add_inputs());
+ DescribeTensor1D(1, op_context.op_info.add_inputs());
+
+ return op_context;
+}
+
// DescribeUnaryOp constructs an OpContext for the given operation applied to
// a 4-tensor with shape (size1, 1, 1, 1).
OpContext DescribeUnaryOp(const string& op, int size1) {
@@ -162,12 +202,9 @@ OpContext DescribeBiasAdd(int size1, int size2) {
op_context.op_info.set_op("BiasAdd");
DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_inputs());
+ DescribeTensor1D(size1, op_context.op_info.add_inputs());
DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_outputs());
- auto bias = op_context.op_info.add_inputs();
- bias->mutable_shape()->add_dim()->set_size(size1);
- bias->set_dtype(DT_FLOAT);
-
return op_context;
}
@@ -486,6 +523,25 @@ TEST_F(OpLevelCostEstimatorTest, ExecutionTimeSumOrMax) {
SetComputeMemoryOverlap(false); // Set it back to default.
}
+TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationExecutionTime) {
+ auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true));
+ EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
+ EXPECT_FALSE(cost.inaccurate);
+}
+
+TEST_F(OpLevelCostEstimatorTest,
+ FusedConv2DBiasActivationNoSideInputExecutionTime) {
+ auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ false));
+ EXPECT_EQ(Costs::Duration(825345), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(355321038), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(356146383), cost.execution_time);
+ EXPECT_FALSE(cost.inaccurate);
+}
+
TEST_F(OpLevelCostEstimatorTest, MulExecutionTime) {
auto cost = PredictCosts(DescribeBinaryOp("Mul", 1000, 1));
EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index a24d2dbd9f..1fb1711f54 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -245,6 +245,8 @@ bool IsPolygamma(const NodeDef& node) { return node.op() == "Polygamma"; }
bool IsPow(const NodeDef& node) { return node.op() == "Pow"; }
+bool IsPrint(const NodeDef& node) { return node.op() == "Print"; }
+
bool IsProd(const NodeDef& node) { return node.op() == "Prod"; }
bool IsReal(const NodeDef& node) { return node.op() == "Real"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 8667f72c7e..d516baebf3 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -95,6 +95,7 @@ bool IsNoOp(const NodeDef& node);
bool IsNotEqual(const NodeDef& node);
bool IsPlaceholder(const NodeDef& node);
bool IsPolygamma(const NodeDef& node);
+bool IsPrint(const NodeDef& node);
bool IsProd(const NodeDef& node);
bool IsPow(const NodeDef& node);
bool IsReal(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 122fd48584..e4bc030885 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -480,6 +480,7 @@ tf_cuda_cc_test(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:virtual_cluster",
+ "//tensorflow/core/grappler/costs:virtual_placer",
],
)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 7bf264ba30..fa0f7c1c6e 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -279,6 +279,7 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
ctx_ext_(ctx_ext) {}
virtual ~ArithmeticOptimizerStage() = default;
+ protected:
// A simplification graph rewrite can create additional nodes that are inputs
// to the final simplified node; they can also be added to the arithmetic
// optimizer queue for further optimization.
@@ -304,10 +305,176 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
}
private:
- // extened context required for ArithmeticOptimizer
+ // Extended context required for ArithmeticOptimizer.
const ArithmeticOptimizerContext ctx_ext_;
};
+// Subtype of ArithmeticOptimizerStage that does optimization by rewriting a
+// group of nodes from the optimized graph.
+//
+// * AddOpsRewrite:
+// Rewrite a group of Add/AddN with compact Add/AddN tree
+//
+// * MinimizeBroadcasts:
+// Rewrite a group of binary associative ops, reordering
+// inputs, to minimize the cost of broadcast
+class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
+ public:
+ explicit ArithmeticNodesGroupOptimizerStage(
+ const string& name, const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext ctx_ext)
+ : ArithmeticOptimizerStage(name, ctx, ctx_ext), optimized_nodes_{} {}
+ ~ArithmeticNodesGroupOptimizerStage() override = default;
+
+ // Input name with a statically inferred shape from GraphProperties
+ struct InputAndShape {
+ InputAndShape(const string& input, const TensorShapeProto& shape)
+ : input(input), shape(shape) {}
+ string input;
+ TensorShapeProto shape;
+ };
+
+ // Subgraph (subtree) of nodes that we want to optimize in "one shot" (e.g.
+ // all the Add nodes that we plan to rewrite with a single AddN). Subgraph is
+ // obtained by graph traversal, starting from a root node.
+ struct OptimizedNodesGroup {
+ NodeDef* root_node;
+ TensorShapeProto root_shape;
+ // Optimized nodes that will be updated or removed by rewrite
+ std::vector<NodeDef*> optimized_nodes;
+ // Inputs to optimized nodes
+ std::vector<InputAndShape> inputs;
+ };
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
+
+ OptimizedNodesGroup group;
+ TF_RETURN_IF_ERROR(CreateOptimizedNodesGroup(node, &group));
+
+ if (!group.optimized_nodes.empty()) {
+ *simplified_node_name = RewriteOptimizedNodesGroup(group);
+ }
+
+ return Status::OK();
+ }
+
+ protected:
+ // Modify the optimized graph after nodes group was successfully identified
+ virtual string RewriteOptimizedNodesGroup(
+ const OptimizedNodesGroup& group) = 0;
+
+ // Check if input can become a part of current optimized nodes group.
+ virtual bool IsAbsorbableByOptimizedNodesGroup(
+ const OptimizedNodesGroup& group, const string& input) const = 0;
+
+ Status AbsorbInputByOptimizedNodesGroup(const string& input,
+ OptimizedNodesGroup* group) const {
+ NodeDef* node;
+ TF_RETURN_IF_ERROR(GetInputNode(input, &node));
+
+ if (IsAbsorbableByOptimizedNodesGroup(*group, input)) {
+ for (int i = 0; i < node->input_size(); ++i) {
+ const string& input_i = node->input(i);
+ if (!IsControlInput(input)) {
+ TF_RETURN_IF_ERROR(AbsorbInputByOptimizedNodesGroup(input_i, group));
+ }
+ }
+ group->optimized_nodes.push_back(node);
+ } else {
+ // If the node can't be absorbed, add it to the OptimizedNodesGroup inputs.
+ OpInfo::TensorProperties properties;
+ TF_RETURN_IF_ERROR(GetTensorProperties(input, &properties));
+ group->inputs.emplace_back(input, properties.shape());
+ }
+ return Status::OK();
+ }
+
+ Status CreateOptimizedNodesGroup(NodeDef* root_node,
+ OptimizedNodesGroup* group) const {
+ OpInfo::TensorProperties root_node_output_properties;
+ TF_RETURN_IF_ERROR(
+ GetTensorProperties(root_node->name(), &root_node_output_properties));
+
+ group->root_node = root_node;
+ group->root_shape = root_node_output_properties.shape();
+
+ group->optimized_nodes.reserve(root_node->input_size());
+ for (int i = 0; i < root_node->input_size(); ++i) {
+ const string& input_i = root_node->input(i);
+ if (!IsControlInput(input_i)) {
+ TF_RETURN_IF_ERROR(AbsorbInputByOptimizedNodesGroup(input_i, group));
+ }
+ }
+
+ return Status::OK();
+ }
+
+ // Check if all inputs can be broadcasted to the same shape
+ // TODO(ezhulenev): move to GraphOptimizerStage?
+ bool HasAllInputsBroadcastableToShape(
+ const NodeDef& node, const OpInfo::TensorProperties& properties) const {
+ auto is_broadcastable = [this, &properties](const string& input) {
+ OpInfo::TensorProperties input_props;
+ Status has_input_properties = GetTensorProperties(input, &input_props);
+ return has_input_properties.ok() &&
+ ShapesBroadcastable(properties, input_props);
+ };
+ return std::all_of(node.input().begin(), node.input().end(),
+ is_broadcastable);
+ }
+
+ // TODO(ezhulenev): move to GraphOptimizerStage?
+ bool IsDrivenByControlDependency(const NodeDef& node) const {
+ return std::any_of(node.input().begin(), node.input().end(),
+ IsControlInput);
+ }
+
+ // TODO(ezhulenev): move to GraphOptimizerStage?
+ bool DrivesControlDependency(const NodeDef& node) const {
+ int position;
+ for (const NodeDef* output : ctx_.node_map->GetOutputs(node.name())) {
+ for (int i = 0; i < output->input_size(); ++i) {
+ auto input = output->input(i);
+ string name = ParseNodeName(input, &position);
+ if (name == node.name() && /*control input*/ position < 0) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+ string ShapeSignature(const TensorShapeProto& shape) const {
+ string signature = strings::StrCat("rank:", shape.dim_size(), ":dim");
+ for (int i = 0; i < shape.dim_size(); ++i)
+ strings::StrAppend(&signature, ":", shape.dim(i).size());
+ return signature;
+ }
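+ // For example, a [32, 32] tensor produces the signature "rank:2:dim:32:32"
+ // and a scalar produces "rank:0:dim", so two inputs share a signature only
+ // when both their rank and every dimension match.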
+
+ void AddToOptimizedNodes(const NodeDef* node) {
+ optimized_nodes_.insert(node->name());
+ }
+
+ bool IsOnTheSameDevice(const OptimizedNodesGroup& group,
+ const NodeDef& node) const {
+ return group.root_node->device() == node.device();
+ }
+
+ bool IsInPreserveSet(const NodeDef& node) const {
+ return ctx_.nodes_to_preserve->find(node.name()) !=
+ ctx_.nodes_to_preserve->end();
+ }
+
+ bool IsAlreadyOptimized(const NodeDef& node) const {
+ return optimized_nodes_.find(node.name()) != optimized_nodes_.end();
+ }
+
+ private:
+ // set of nodes already processed by this optimizer stage
+ std::unordered_set<string> optimized_nodes_;
+};
+
// Rewrite a tree of Add/AddN with a single AddN operation, consuming all the
// original inputs of absorbed nodes.
//
@@ -335,110 +502,33 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
// x y w Add_3 AddN(x, y, q, e) z
// / \
// q e
-class AddOpsRewriteStage : public ArithmeticOptimizerStage {
+class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
public:
explicit AddOpsRewriteStage(const GraphOptimizerContext& ctx,
const ArithmeticOptimizerContext& ctx_ext)
- : ArithmeticOptimizerStage("AddOpsRewrite", ctx, ctx_ext),
- rewritten_nodes_() {}
-
+ : ArithmeticNodesGroupOptimizerStage("AddOpsRewrite", ctx, ctx_ext) {}
~AddOpsRewriteStage() override = default;
// Check if a node can become a root of AddOpsGroup
bool IsSupported(const NodeDef* node) const override {
- // check basic preconditions
- if (!IsRewritable(node)) {
- return false;
- }
+ if (!CanOptimize(node)) return false;
// shape must be symbolically defined and all inputs compatible with it
OpInfo::TensorProperties properties;
Status has_properties = GetTensorProperties(node->name(), &properties);
return has_properties.ok() && ShapeIsSymbolicallyDefined(properties) &&
- HasAllInputsOfBroadcastableShape(*node, properties);
- }
-
- Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- CHECK(IsSupported(node));
- AddOpsGroup group;
- TF_RETURN_IF_ERROR(CreateAddOpsGroup(node, &group));
-
- if (!group.absorbed_nodes.empty()) {
- *simplified_node_name = RewriteAddOpsGroup(group);
- }
-
- return Status::OK();
- }
-
- private:
- // Input name with a statically inferred shape from GraphProperties
- struct InputAndShape {
- InputAndShape(const string& input, const TensorShapeProto& shape)
- : input(input), shape(shape) {}
- string input;
- TensorShapeProto shape;
- };
-
- // Holds together an add ops subgraph that we want to rewrite together.
- //
- // For the graph above the AddOpsGroup will be:
- // root_node: AddN_1
- // absorbed_nodes: [Add_1, Add_2]
- // input_nodes: [x, y, z, w, q, e]
- struct AddOpsGroup {
- const NodeDef* root_node;
- TensorShapeProto root_shape;
- // Add/AddN operations below the root level that were absorbed by this group
- std::vector<NodeDef*> absorbed_nodes;
- // Inputs of absorbed nodes that will be forwarded to optimized AddN ops
- std::vector<InputAndShape> inputs;
- };
-
- // Check if all inputs can be broadcasted to the same shape
- bool HasAllInputsOfBroadcastableShape(
- const NodeDef& node, const OpInfo::TensorProperties& properties) const {
- const AddOpsRewriteStage* self = this;
- return std::all_of(
- node.input().begin(), node.input().end(),
- [self, &properties](const string& input) {
- OpInfo::TensorProperties input_properties;
- Status has_input_properties =
- self->GetTensorProperties(input, &input_properties);
- return has_input_properties.ok() &&
- ShapesBroadcastable(properties, input_properties);
- });
- }
-
- // TODO(ezhulenev): use GraphRewriter?
- bool IsDrivenByControlDependency(const NodeDef& node) const {
- return std::any_of(node.input().begin(), node.input().end(),
- IsControlInput);
- }
-
- // TODO(ezhulenev): use GraphRewriter?
- bool DrivesControlDependency(const NodeDef& node) const {
- int position;
- for (const NodeDef* output : ctx_.node_map->GetOutputs(node.name())) {
- for (int i = 0; i < output->input_size(); ++i) {
- auto input = output->input(i);
- string name = ParseNodeName(input, &position);
- if (name == node.name() && /*control input*/ position < 0) {
- return true;
- }
- }
- }
- return false;
+ HasAllInputsBroadcastableToShape(*node, properties);
}
- // Check if a node can be absorbed by current AddOpsGroup
- bool IsAbsorbableByAddOpsGroup(const string& name, const AddOpsGroup& group) {
+ protected:
+ // Check if a node can be absorbed by current OptimizedNodesGroup
+ bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
+ const string& input) const override {
NodeDef* node;
- Status node_status = GetInputNode(name, &node);
- if (!node_status.ok()) {
- return false;
- }
- // check basic preconditions
- if (!IsRewritable(node)) {
+ Status node_status = GetInputNode(input, &node);
+ if (!node_status.ok() || !CanOptimize(node)) return false;
+
+ if (!IsOnTheSameDevice(group, *node)) {
return false;
}
// with a single output data consumer (presumably if we reach this node from
@@ -447,102 +537,42 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
if (NumNonControlDataOutputs(*node, *ctx_.node_map) != 1) {
return false;
}
- // must be on the same device as a root node
- if (node->device() != group.root_node->device()) {
- return false;
- }
// All input shapes must be broadcastable to the node shape
OpInfo::TensorProperties properties;
- Status has_properties = GetTensorProperties(name, &properties);
+ Status has_properties = GetTensorProperties(input, &properties);
return has_properties.ok() &&
- HasAllInputsOfBroadcastableShape(*node, properties);
+ HasAllInputsBroadcastableToShape(*node, properties);
}
// Node requirements both for a root node and an absorbed node
- bool IsRewritable(const NodeDef* node) const {
- // only Add or AddN can be a root node
+ bool CanOptimize(const NodeDef* node) const {
// TODO(ezhulenev): check if AccumulateNV2 can be supported too
if (!IsAdd(*node) && !IsAddN(*node)) {
return false;
}
- // it must not be in a preserve set
- if (ctx_.nodes_to_preserve->find(node->name()) !=
- ctx_.nodes_to_preserve->end()) {
- return false;
- }
- // it must not be a node created or absorbed by previous iteration
- if (rewritten_nodes_.find(node->name()) != rewritten_nodes_.end()) {
+ if (IsInPreserveSet(*node) || IsAlreadyOptimized(*node)) {
return false;
}
// it must not be created by this stage at any of previous optimization runs
- if (StringPiece(node->name()).contains(stage_name_)) {
+ if (str_util::StrContains(node->name(), stage_name_)) {
return false;
}
- // should not drive or be driven by control dependency
// TODO(ezhulenev): relax this condition for root node
return !(IsDrivenByControlDependency(*node) ||
DrivesControlDependency(*node));
}
- // Create an AddOpsGroup with a root in a given node
- Status CreateAddOpsGroup(const NodeDef* root_node, AddOpsGroup* group) {
- OpInfo::TensorProperties root_node_output_properties;
- TF_RETURN_IF_ERROR(
- GetTensorProperties(root_node->name(), &root_node_output_properties));
-
- group->root_node = root_node;
- group->root_shape = root_node_output_properties.shape();
-
- group->absorbed_nodes.reserve(root_node->input_size());
- for (int i = 0; i < root_node->input_size(); ++i) {
- const string& input_i = root_node->input(i);
- if (!IsControlInput(input_i)) {
- TF_RETURN_IF_ERROR(AbsorbInputByAddOpsGroup(input_i, group));
- }
- }
-
- return Status::OK();
- }
-
- Status AbsorbInputByAddOpsGroup(const string& input, AddOpsGroup* group) {
- NodeDef* node;
- TF_RETURN_IF_ERROR(GetInputNode(input, &node));
-
- if (IsAbsorbableByAddOpsGroup(input, *group)) {
- group->absorbed_nodes.push_back(node);
- for (int i = 0; i < node->input_size(); ++i) {
- const string& input_i = node->input(i);
- if (!IsControlInput(input)) {
- TF_RETURN_IF_ERROR(AbsorbInputByAddOpsGroup(input_i, group));
- }
- }
- } else {
- // If node can't be absorbed, add it to AddOpsGroup input
- OpInfo::TensorProperties properties;
- TF_RETURN_IF_ERROR(GetTensorProperties(input, &properties));
- group->inputs.emplace_back(input, properties.shape());
- }
- return Status::OK();
- }
-
- // Rewrite an add ops group into a single AddN if all input shapes are
+ // Rewrite a group of add ops into a single AddN if all input shapes are
// symbolically equal. If not, create AddN for equal shapes first, and then
// build an Add tree, minimizing the cost of broadcasts.
- string RewriteAddOpsGroup(const AddOpsGroup& group) {
+ string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
// all new nodes will be placed under the scope of a root node
auto root_scope_and_name = ParseNodeScopeAndName(group.root_node->name());
- auto shape_sig = [](const TensorShapeProto& shape) {
- string name = strings::StrCat("r:", shape.dim_size(), ":d");
- for (int i = 0; i < shape.dim_size(); ++i)
- strings::StrAppend(&name, ":", shape.dim(i).size());
- return name;
- };
-
// Find what shapes are present in the inputs of absorbed nodes
std::unordered_map<string, std::vector<InputAndShape>> shape_sig_to_inputs;
for (const auto& input : group.inputs) {
- shape_sig_to_inputs[shape_sig(input.shape)].push_back(input);
+ shape_sig_to_inputs[ShapeSignature(input.shape)].push_back(input);
}
// Collect all the shapes from representative elements
@@ -556,8 +586,6 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
string node_name = OptimizedNodeName(root_scope_and_name);
AddInputsOfSymbolicallyEqualShape(*group.root_node, node_name,
group.inputs);
- // keep track of nodes that were created or absorbed as a part of rewrite
- rewritten_nodes_.insert(node_name);
return node_name;
}
@@ -586,7 +614,7 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
// Prepare leaf AddN nodes for inputs of equal shape
for (int i = 0; i < shapes.size(); ++i) {
const auto node_name = leaf_node_name(i);
- const auto& inputs = shape_sig_to_inputs[shape_sig(shapes[i])];
+ const auto& inputs = shape_sig_to_inputs[ShapeSignature(shapes[i])];
add_ops.push_back(AddInputsOfSymbolicallyEqualShape(*group.root_node,
node_name, inputs));
}
@@ -637,7 +665,7 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
node->add_input(inputAndShape.input);
}
- rewritten_nodes_.insert(node_name);
+ AddToOptimizedNodes(node);
return InputAndShape(node_name, shape);
}
@@ -661,13 +689,10 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
node->add_input(left.input);
node->add_input(right.input);
- rewritten_nodes_.insert(node_name);
+ AddToOptimizedNodes(node);
return InputAndShape(
node_name, TensorShapeProto()); // shape is not important at this point
}
-
- // keep nodes that were added or absorbed as a part of AddOpsGroup rewrite
- std::unordered_set<string> rewritten_nodes_;
};
// Use the commutativity and (left- and right-) distributive property of
@@ -693,7 +718,7 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- CHECK(IsSupported(node));
+ TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
std::set<string> common_factors;
std::vector<string> ctrl_deps;
@@ -839,6 +864,201 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
std::unordered_set<string> rewritten_nodes_;
};
+// Binary associative ops can be re-ordered to minimize the number of broadcasts
+// and the size of temporary tensors.
+//
+// Example: [a, c] - scalars, [b, d] - matrices
+// @ - binary associative op (Add or Mul)
+// @* - broadcast
+//
+// @ @*
+// / \ / \
+// @* @* -> @ @
+// / \ / \ / \ / \
+// a b c d a c b d
+class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
+ public:
+ explicit MinimizeBroadcasts(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticNodesGroupOptimizerStage("MinimizeBroadcasts", ctx, ctx_ext) {
+ }
+ ~MinimizeBroadcasts() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ if (!IsBinaryAssociative(*node)) return false;
+
+ // has a symbolically defined shape with broadcastable inputs
+ OpInfo::TensorProperties properties;
+ Status has_properties = GetTensorProperties(node->name(), &properties);
+ return has_properties.ok() && ShapeIsSymbolicallyDefined(properties) &&
+ HasAllInputsBroadcastableToShape(*node, properties);
+ }
+
+ protected:
+ bool IsBinaryAssociative(const NodeDef& node) const {
+ return IsMul(node) || IsAdd(node);
+ }
+
+ bool IsSameOp(const OptimizedNodesGroup& group, const NodeDef& node) const {
+ return group.root_node->op() == node.op();
+ }
+
+ // Check if a node can be absorbed by current OptimizedNodesGroup
+ bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
+ const string& input) const override {
+ NodeDef* node;
+ Status node_status = GetInputNode(input, &node);
+ if (!node_status.ok()) return false;
+
+ if (!IsSameOp(group, *node)) {
+ return false;
+ }
+ if (IsInPreserveSet(*node) || IsAlreadyOptimized(*node)) {
+ return false;
+ }
+ if (IsDrivenByControlDependency(*node) || DrivesControlDependency(*node)) {
+ return false;
+ }
+ if (!IsOnTheSameDevice(group, *node)) {
+ return false;
+ }
+ // Optimized nodes are updated in place, which would break the graph if the
+ // node has multiple output consumers.
+ if (NumNonControlOutputs(*node, *ctx_.node_map) != 1) {
+ return false;
+ }
+ // All input shapes must be broadcastable to the node shape
+ OpInfo::TensorProperties properties;
+ Status has_properties = GetTensorProperties(input, &properties);
+ return has_properties.ok() &&
+ HasAllInputsBroadcastableToShape(*node, properties);
+ }
+
+ std::size_t CountUniqueShapes(const std::vector<InputAndShape>& inputs) {
+ std::set<string> sigs;
+ for (const auto& ias : inputs) {
+ sigs.insert(ShapeSignature(ias.shape));
+ }
+ return sigs.size();
+ }
+
+ string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
+ if (CountUniqueShapes(group.inputs) <= 1) {
+ // nothing to optimize when all shapes are the same
+ return group.root_node->name();
+ }
+
+ auto num_nodes = /*root*/ 1 + group.optimized_nodes.size();
+ auto num_inputs = group.inputs.size();
+ CHECK_EQ(num_nodes, num_inputs - 1)
+ << "Can't build a tree with " << num_inputs << " inputs, using "
+ << num_nodes << " binary op nodes.";
+
+ std::deque<InputAndShape> add_ops(group.inputs.begin(), group.inputs.end());
+ std::deque<NodeDef*> optimized_nodes(group.optimized_nodes.begin(),
+ group.optimized_nodes.end());
+
+ // Sort inputs by their shapes, from smallest to largest.
+ std::stable_sort(add_ops.begin(), add_ops.end(),
+ [](const InputAndShape& lhs, const InputAndShape& rhs) {
+ return CompareSymbolicallyShapedTensorSizes(lhs.shape,
+ rhs.shape);
+ });
+
+ // If there is an odd number of inputs, the last one is the largest, and we
+ // want to attach it to the root node to build a well-balanced tree.
+ std::deque<InputAndShape> add_ops_leftover;
+ if (add_ops.size() % 2 != 0) {
+ add_ops_leftover.push_back(add_ops.back());
+ add_ops.pop_back();
+ }
+
+ // At this point add_ops is guaranteed to hold an even number of inputs.
+ do {
+ const InputAndShape lhs = add_ops.front();
+ add_ops.pop_front();
+ const InputAndShape rhs = add_ops.front();
+ add_ops.pop_front();
+
+ NodeDef* node;
+ if (!optimized_nodes.empty()) {
+ // re-purpose optimized nodes to build a new tree
+ node = optimized_nodes.front();
+ optimized_nodes.pop_front();
+ } else {
+ // or use the root node if no optimized nodes are left
+ node = group.root_node;
+ }
+ InputAndShape updated_node = UpdateInputs(lhs.input, rhs.input, node);
+
+ // Pushing the updated node to the back of the deque creates a wide and
+ // short tree, while pushing it to the front creates a tall tree. We prefer
+ // a wide tree because it minimizes the number of temporary tensors that
+ // must be kept in memory, though sometimes we push to the front instead to
+ // avoid propagating a broadcast from the leaves to the root. Example:
+ //
+ // inputs: [s, s, s, M] (s - scalar, M - matrix)
+ // @* - op with broadcast
+ //
+ // (only push_back) @* (push_front first op)
+ // / \
+ // @* @ M
+ // / \ / \
+ // @ @* -> @ s
+ // / \ / \ / \
+ // s s s M s s
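+ // For the [s, s, s, M] inputs above: (s, s) are combined first; since the
+ // remaining front pair is (s, M) with s the smaller one, the new node is
+ // pushed to the front and combined with that s, so only the final op at the
+ // root broadcasts against M.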
+ if (add_ops.size() >= 2 &&
+ CompareSymbolicallyShapedTensorSizes(add_ops.at(0).shape,
+ add_ops.at(1).shape)) {
+ add_ops.push_front(updated_node);
+ } else {
+ add_ops.push_back(updated_node);
+ }
+ } while (add_ops.size() > 1);
+ CHECK_EQ(1, add_ops.size());
+
+ // attach the largest tensor to the root op
+ if (!add_ops_leftover.empty()) {
+ const InputAndShape lhs = add_ops.front();
+ add_ops.pop_front();
+ const InputAndShape rhs = add_ops_leftover.front();
+ InputAndShape updated_node =
+ UpdateInputs(lhs.input, rhs.input, group.root_node);
+ add_ops.push_back(updated_node);
+ }
+
+ return add_ops.front().input;
+ }
+
+ InputAndShape UpdateInputs(const string& input_0, const string& input_1,
+ NodeDef* node) {
+ string old_input_0 = node->input(0);
+ string old_input_1 = node->input(1);
+
+ // Update inputs only if they changed
+ if (old_input_0 != input_0 || old_input_1 != input_1) {
+ node->set_input(0, input_0);
+ node->set_input(1, input_1);
+ // Invalidate node properties (shape)
+ ctx_.graph_properties->ClearOutputProperties(node->name());
+ ctx_.graph_properties->ClearInputProperties(node->name());
+ // Update the node map
+ ctx_.node_map->RemoveOutput(NodeName(old_input_0), node->name());
+ ctx_.node_map->RemoveOutput(NodeName(old_input_1), node->name());
+ ctx_.node_map->AddOutput(NodeName(input_0), node->name());
+ ctx_.node_map->AddOutput(NodeName(input_1), node->name());
+ // Add updated node to optimization queue
+ AddToOptimizationQueue(node);
+ }
+
+ // Do not add updated node to any other group
+ AddToOptimizedNodes(node);
+
+ TensorShapeProto shape; // shape is not important at this point
+ return InputAndShape(node->name(), shape);
+ }
+};
+
// Removes inverse transpose nodes
class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
public:
@@ -854,7 +1074,7 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
// TODO(rmlarsen): Forward control dependencies on the bypassed
// transpose nodes.
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- CHECK(IsSupported(node));
+ TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
NodeDef* input;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
@@ -943,7 +1163,7 @@ class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage {
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- CHECK(IsSupported(node));
+ TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
// Bypass Bitcast whose source type and destination type are equal.
if (GetSourceDataType(*node) == GetDestinationDataType(*node)) {
@@ -981,7 +1201,8 @@ class RemoveRedundantCastStage : public ArithmeticOptimizerStage {
bool IsSupported(const NodeDef* node) const override { return IsCast(*node); }
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- CHECK(IsSupported(node));
+ TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
+
// Bypass Cast whose source type and destination type are equal.
if (GetSourceDataType(*node) == GetDestinationDataType(*node)) {
*simplified_node_name = node->input(0);
@@ -1678,6 +1899,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext);
if (options_.hoist_common_factor_out_of_aggregation && can_use_shapes)
pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext);
+ if (options_.minimize_broadcasts && can_use_shapes)
+ pipeline.AddStage<MinimizeBroadcasts>(ctx, ctx_ext);
if (options_.remove_identity_transpose && can_use_shapes)
pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext);
if (options_.remove_redundant_bitcast)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 39b89dedba..c0fe8839ca 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -59,6 +59,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool enable_try_simplify_and_replace = true;
bool combine_add_to_addn = false;
bool hoist_common_factor_out_of_aggregation = true;
+ bool minimize_broadcasts = false;
bool remove_identity_transpose = true;
bool remove_redundant_bitcast = true;
bool remove_redundant_cast = true;
@@ -69,10 +70,10 @@ class ArithmeticOptimizer : public GraphOptimizer {
static ArithmeticOptimizerOptions Default(
RewriterConfig::Toggle opt_level) {
ArithmeticOptimizerOptions options;
- // TODO(ezhulenev): enable combine_add_to_addn by default after 1.8
- // release cut
+ // TODO(ezhulenev): enable by default after 1.8 release cut
if (opt_level == RewriterConfig::AGGRESSIVE) {
options.combine_add_to_addn = true;
+ options.minimize_broadcasts = true;
}
return options;
}
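As a quick sketch of what the new default wiring means for callers (assuming, as in the tests below, that the nested ArithmeticOptimizerOptions struct is reachable from the call site):

    // Sketch only: the AGGRESSIVE preset now turns on both staged rewrites,
    // while other levels keep them off until after the 1.8 release cut.
    const auto options = ArithmeticOptimizer::ArithmeticOptimizerOptions::Default(
        RewriterConfig::AGGRESSIVE);
    // options.combine_add_to_addn == true
    // options.minimize_broadcasts == true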
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index e117341ba3..9677175d2e 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -93,6 +93,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
options.enable_try_simplify_and_replace = false;
options.combine_add_to_addn = false;
options.hoist_common_factor_out_of_aggregation = false;
+ options.minimize_broadcasts = false;
options.remove_identity_transpose = false;
options.remove_redundant_bitcast = false;
options.remove_redundant_cast = false;
@@ -113,6 +114,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
optimizer->options_.hoist_common_factor_out_of_aggregation = true;
}
+ void EnableOnlyMinimizeBroadcasts(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.minimize_broadcasts = true;
+ }
+
void EnableOnlyRemoveIdentityTranspose(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_identity_transpose = true;
@@ -1841,5 +1847,160 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
EXPECT_EQ(5, found);
}
+TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
+ auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_FLOAT);
+ auto c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT);
+
+ auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
+ auto mul2 = ops::Mul(s.WithOpName("mul2"), mul1, c);
+
+ auto outputs = ops::Identity(s.WithOpName("outputs"), mul2);
+
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyMinimizeBroadcasts(&optimizer);
+
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ // We expect the following rewrite(s) to occur:
+ //
+ // * *
+ // / \ / \
+ // * c --> * b
+ // / \ / \
+ // a b a c
+ NodeMap node_map(&output);
+
+ const NodeDef* mul1_node = node_map.GetNode("mul1");
+ ASSERT_NE(mul1_node, nullptr);
+ EXPECT_EQ("a", mul1_node->input(0));
+ EXPECT_EQ("c", mul1_node->input(1));
+
+ const NodeDef* mul2_node = node_map.GetNode("mul2");
+ ASSERT_NE(mul2_node, nullptr);
+ EXPECT_EQ("mul1", mul2_node->input(0));
+ EXPECT_EQ("b", mul2_node->input(1));
+}
+
+TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_FlattenTallGraph) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
+ auto b = ops::Variable(s.WithOpName("b"), {32, 32}, DT_FLOAT);
+ auto c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT);
+ auto d = ops::Variable(s.WithOpName("d"), {32}, DT_FLOAT);
+ auto e = ops::Variable(s.WithOpName("e"), {32}, DT_FLOAT);
+
+ auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
+ auto mul2 = ops::Mul(s.WithOpName("mul2"), mul1, c);
+ auto mul3 = ops::Mul(s.WithOpName("mul3"), mul2, d);
+ auto mul4 = ops::Mul(s.WithOpName("mul4"), mul3, e);
+
+ auto outputs = ops::Identity(s.WithOpName("outputs"), mul4);
+
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyMinimizeBroadcasts(&optimizer);
+
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ // We expect the following rewrite(s) to occur: Graph is "flattened" and
+ // largest shape pushed to the top.
+ //
+ // *
+ // / \
+ // * e *
+ // / \ / \
+ // * d * b
+ // / \ / \
+ // * c --> * *
+ // / \ / \ / \
+ // a b a c d e
+ NodeMap node_map(&output);
+
+ const NodeDef* mul1_node = node_map.GetNode("mul1");
+ ASSERT_NE(mul1_node, nullptr);
+ EXPECT_EQ("a", mul1_node->input(0));
+ EXPECT_EQ("c", mul1_node->input(1));
+
+ const NodeDef* mul2_node = node_map.GetNode("mul2");
+ ASSERT_NE(mul2_node, nullptr);
+ EXPECT_EQ("d", mul2_node->input(0));
+ EXPECT_EQ("e", mul2_node->input(1));
+
+ const NodeDef* mul3_node = node_map.GetNode("mul3");
+ ASSERT_NE(mul3_node, nullptr);
+ EXPECT_EQ("mul1", mul3_node->input(0));
+ EXPECT_EQ("mul2", mul3_node->input(1));
+
+ const NodeDef* mul4_node = node_map.GetNode("mul4");
+ ASSERT_NE(mul4_node, nullptr);
+ EXPECT_EQ("mul3", mul4_node->input(0));
+ EXPECT_EQ("b", mul4_node->input(1));
+}
+
+TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ // [a, b, c] - scalars, [d] - matrix
+ auto a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
+ auto b = ops::Variable(s.WithOpName("b"), {32}, DT_FLOAT);
+ auto c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT);
+ auto d = ops::Variable(s.WithOpName("D"), {32, 32}, DT_FLOAT);
+
+ auto mul1 = ops::Mul(s.WithOpName("mul1"), a, b);
+ auto mul2 = ops::Mul(s.WithOpName("mul2"), c, d);
+ auto mul3 = ops::Mul(s.WithOpName("mul3"), mul1, mul2);
+
+ auto outputs = ops::Identity(s.WithOpName("outputs"), mul3);
+
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyMinimizeBroadcasts(&optimizer);
+
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ // We expect the following rewrite(s) to occur:
+ //
+ // *
+ // / \
+ // * * D
+ // / \ / \
+ // * * -> * c
+ // / \ / \ / \
+ // a b c D a b
+ NodeMap node_map(&output);
+
+ const NodeDef* mul1_node = node_map.GetNode("mul1");
+ ASSERT_NE(mul1_node, nullptr);
+ EXPECT_EQ("a", mul1_node->input(0));
+ EXPECT_EQ("b", mul1_node->input(1));
+
+ const NodeDef* mul2_node = node_map.GetNode("mul2");
+ ASSERT_NE(mul2_node, nullptr);
+ EXPECT_EQ("mul1", mul2_node->input(0));
+ EXPECT_EQ("c", mul2_node->input(1));
+
+ const NodeDef* mul3_node = node_map.GetNode("mul3");
+ ASSERT_NE(mul3_node, nullptr);
+ EXPECT_EQ("D", mul3_node->input(0));
+ EXPECT_EQ("mul2", mul3_node->input(1));
+}
+
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper.cc b/tensorflow/core/grappler/optimizers/debug_stripper.cc
index 8bd10171f1..9701a038d0 100644
--- a/tensorflow/core/grappler/optimizers/debug_stripper.cc
+++ b/tensorflow/core/grappler/optimizers/debug_stripper.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
namespace grappler {
@@ -40,10 +41,22 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item,
inp = AsControlDependency(inp);
}
}
- } else if (IsCheckNumerics(node)) {
+ } else if (IsCheckNumerics(node) || IsPrint(node)) {
// Replace with Identity op which will be pruned later.
node.set_op("Identity");
- node.mutable_attr()->erase("message");
+ // Only preserve T attribute.
+ protobuf::Map<string, AttrValue> new_attr;
+ if (node.attr().find("T") != node.attr().end()) {
+ new_attr.insert({"T", node.attr().at("T")});
+ }
+ node.mutable_attr()->swap(new_attr);
+ // Since the Identity op takes only one input, mark the redundant inputs as
+ // control inputs.
+ for (size_t i = 1; i < node.input_size(); ++i) {
+ if (!IsControlInput(node.input(i))) {
+ *node.mutable_input(i) = AsControlDependency(node.input(i));
+ }
+ }
}
}
return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
index 3f11febc64..96ceee791f 100644
--- a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
+++ b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
@@ -164,6 +164,42 @@ TEST_F(DebugStripperTest, StripCheckNumericsFromGraph) {
test::ExpectTensorEqual<float>(expected[0], optimized[0]);
}
+TEST_F(DebugStripperTest, StripPrintFromGraph) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+ ops::Placeholder::Shape({}));
+ Output print = ops::Print(s.WithOpName("Print"), x, {x});
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ DebugStripper optimizer;
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "x") {
+ EXPECT_EQ("Placeholder", node.op());
+ EXPECT_EQ(0, node.input_size());
+ } else if (node.name() == "Print") {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("^x", node.input(1));
+ EXPECT_EQ(1, node.attr_size());
+ }
+ }
+
+ EXPECT_EQ(2, output.node_size());
+
+ Tensor x_t(DT_FLOAT, TensorShape({}));
+ x_t.flat<float>()(0) = 1.0f;
+ std::vector<Tensor> expected =
+ EvaluateNodes(item.graph, {"Print"}, {{"x", x_t}});
+ std::vector<Tensor> optimized =
+ EvaluateNodes(output, {"Print"}, {{"x", x_t}});
+ test::ExpectTensorEqual<float>(expected[0], optimized[0]);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
index 7ed0474861..072f772946 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
@@ -134,6 +134,18 @@ class GraphOptimizerStage {
// and remove template parameter.
virtual Status TrySimplify(NodeDef* node, Result* result) = 0;
+ // Return InvalidArgumentError if node is not supported by the optimizer
+ // stage.
+ // TODO(ezhulenev): make this check part of non-virtual public API
+ // (TrySimplify), and make virtual implementation protected.
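+ // Stages are expected to call this at the top of TrySimplify() in place of
+ // the previous CHECK(IsSupported(node)), so an unsupported node yields a
+ // recoverable InvalidArgument error instead of crashing the process.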
+ Status EnsureNodeIsSupported(const NodeDef* node) const {
+ return IsSupported(node)
+ ? Status::OK()
+ : errors::InvalidArgument(
+ "Node ", node->name(), " is not supported by optimizer ",
+ optimizer_name_, " and stage ", stage_name_);
+ }
+
// Get a name for a new node, created by this stage, based on one or multiple
// nodes of an original graph.
const string OptimizedNodeName(const NodeScopeAndName& node) const {
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index 308eecd420..561226f945 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -17,9 +17,13 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
@@ -363,6 +367,28 @@ std::vector<int> DataInputPos(const NodeDef& node) {
return {};
}
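+// Returns true when the registered kernel produces this output in host memory
+// (or when no kernel is registered for the node's device); the layout
+// optimizer uses this below to pin inserted data-format ops to the CPU.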
+bool IsHostMemory(const NodeDef& node, int output_port) {
+ DeviceNameUtils::ParsedName parsed_name;
+ if (DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) {
+ DeviceType device_type(parsed_name.type);
+ Status s = FindKernelDef(device_type, node, nullptr, nullptr);
+ if (s.ok()) {
+ tensorflow::MemoryTypeVector in_mtypes;
+ tensorflow::MemoryTypeVector out_mtypes;
+ s = tensorflow::MemoryTypesForNode(OpRegistry::Global(), device_type,
+ node, &in_mtypes, &out_mtypes);
+ if (s.ok()) {
+ if (out_mtypes[output_port] == HOST_MEMORY) {
+ return true;
+ }
+ }
+ } else {
+ return true;
+ }
+ }
+ return false;
+}
+
class GraphProcessor {
public:
GraphProcessor(const GraphProperties& graph_properties,
@@ -883,6 +909,23 @@ class NodeProcessor : public GraphProcessor {
list->set_i(3, w);
}
+ string MaybeGetHostDevice(const string& input_name) const {
+ string device = node_->device();
+ DeviceNameUtils::ParsedName parsed_name;
+ if (DeviceNameUtils::ParseFullName(device, &parsed_name)) {
+ if (parsed_name.type != "CPU") {
+ NodeDef* input = node_map_->GetNode(input_name);
+ int port;
+ ParseNodeName(input_name, &port);
+ if (IsHostMemory(*input, port)) {
+ parsed_name.type = "CPU";
+ device = DeviceNameUtils::ParsedNameToString(parsed_name);
+ }
+ }
+ }
+ return device;
+ }
+
NodeDef* AddNodeDataFormatOp(const string& name, const string& input_name,
const string& op, DataType dtype,
bool nhwc_to_nchw) {
@@ -890,7 +933,9 @@ class NodeProcessor : public GraphProcessor {
added_node->set_name(name);
added_node->set_op(op);
node_map_->AddNode(added_node->name(), added_node);
- added_node->set_device(node_->device());
+ // The inputs of a DataFormat op could be in host memory for ops such as
+ // Reshape.
+ added_node->set_device(MaybeGetHostDevice(input_name));
AttrValue attr_data_type;
attr_data_type.set_type(dtype);
added_node->mutable_attr()->insert({"T", attr_data_type});
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
index 1c912fcaa2..260347b0e8 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
+#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -158,7 +159,7 @@ class LayoutOptimizerTest : public ::testing::Test {
return output.x_backprop;
}
- std::unique_ptr<VirtualCluster> virtual_cluster_;
+ std::unique_ptr<Cluster> virtual_cluster_;
};
TEST_F(LayoutOptimizerTest, Conv2DBackpropInput) {
@@ -1130,6 +1131,27 @@ TEST_F(LayoutOptimizerTest, LoopNoLiveLock) {
EXPECT_EQ(mul_node->input(0),
"Conv2D-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
}
+
+TEST_F(LayoutOptimizerTest, DevicePlacement) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 4, 2, "VALID");
+ auto shape = ops::Shape(s.WithOpName("s"), conv);
+ auto i = ops::Identity(s.WithOpName("i"), shape);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ VirtualPlacer virtual_placer(virtual_cluster_.get());
+ for (auto& node : *item.graph.mutable_node()) {
+ string device = virtual_placer.get_canonical_device_name(node);
+ node.set_device(device);
+ }
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map(&output);
+ auto vec_permute =
+ node_map.GetNode("s-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
+ EXPECT_EQ(vec_permute->device(), "/device:CPU:0");
+}
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
index a063dc3381..fff06dd2ac 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
@@ -16,18 +16,17 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include <algorithm>
+#include <deque>
#include <limits>
#include <unordered_map>
#include <unordered_set>
#include <vector>
-#include <deque>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
@@ -46,74 +45,36 @@ namespace tensorflow {
namespace grappler {
namespace {
-std::vector<int> GetStackPushNodesToConvert(
- const SimpleGraphView& graph_view,
- const std::unordered_set<string>& nodes_to_preserve, int stack_node_idx) {
- VLOG(1) << "Stack node: " << graph_view.graph()->node(stack_node_idx).name();
- const std::unordered_set<string> op_types_to_traverse(
- {"Stack", "StackV2", "Enter", "RefEnter", "Switch", "RefSwitch",
- "Identity", "RefIdentity"});
- std::vector<int> nodes_to_convert;
- std::set<int> fanout;
- graph_view.DepthFirstSearch(op_types_to_traverse, stack_node_idx, &fanout);
- for (int fanout_idx : fanout) {
- const NodeDef& fanout_node = graph_view.graph()->node(fanout_idx);
- VLOG(1) << "Fanout " << fanout_idx << " : " << fanout_node.name();
- if (IsStackPushOp(fanout_node)) {
- nodes_to_convert.push_back(fanout_idx);
- } else if (IsStackOp(fanout_node) || IsStackCloseOp(fanout_node) ||
- op_types_to_traverse.find(fanout_node.op()) !=
- op_types_to_traverse.end()) {
- continue;
- } else if (!IsStackPopOp(fanout_node) ||
- (!graph_view.outputs(fanout_idx).empty() ||
- nodes_to_preserve.find(fanout_node.name()) !=
- nodes_to_preserve.end())) {
- // The node is either a stack pop with consumers or something unexpected
- // so we leave the graph alone.
- nodes_to_convert.clear();
- break;
- }
- }
- return nodes_to_convert;
-}
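+// Hoists loop-invariant nodes out of while-loop frames. This logic previously
+// lived directly on LoopOptimizer (as the LINM* methods renamed below);
+// extracting it into a dedicated class keeps the per-run bookkeeping local.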
+class LoopInvariantNodeMotionOptimizer {
+ public:
+ explicit LoopInvariantNodeMotionOptimizer(GraphDef* optimized_graph)
+ : optimized_graph_(optimized_graph) {}
+ virtual ~LoopInvariantNodeMotionOptimizer() = default;
+ Status Optimize();
-Status RemoveStackOps(const GrapplerItem& item, GraphDef* optimized_graph) {
- const std::unordered_set<string> nodes_to_preserve = item.NodesToPreserve();
- const GraphDef& graph = item.graph;
- *optimized_graph = graph;
- NodeMap node_map(optimized_graph);
- SimpleGraphView graph_view;
- TF_RETURN_IF_ERROR(graph_view.Initialize(graph));
- for (int node_idx = 0; node_idx < graph.node_size(); ++node_idx) {
- if (IsStackOp(graph.node(node_idx))) {
- for (int push_node_idx : GetStackPushNodesToConvert(
- graph_view, nodes_to_preserve, node_idx)) {
- // We found push nodes without corresponding pops. Convert them to
- // Identity passing the data through and add a control dependency from
- // the op supplying the stack handle.
- NodeDef* push_node = optimized_graph->mutable_node(push_node_idx);
- VLOG(1) << "Converting " << push_node_idx << " : "
- << push_node->DebugString();
- if (push_node->attr().count("swap_memory") != 0) {
- push_node->mutable_attr()->erase("swap_memory");
- }
- push_node->set_op("Identity");
- push_node->mutable_input()->SwapElements(0, 1);
- const string ctrl_dep = ConstantFolding::AddControlDependency(
- push_node->input(1), optimized_graph, &node_map);
- push_node->set_input(1, ctrl_dep);
- VLOG(1) << "After converting: " << push_node->DebugString();
- }
- }
- }
- return Status::OK();
-}
+ private:
+ Status FindInvariantNodes(NodeDef* node);
+ Status RevertInvariantNodes();
+ Status MoveInvariantNodes(const int frame_id);
+ Status HandleInvariantNode(NodeDef* node, const int num_outputs,
+ const int frame_id);
+ Status HandleConst(NodeDef* node, const int num_outputs, const int frame_id);
+ Status HandleInvariantEnter(NodeDef* node, const int num_outputs);
-} // namespace
+ GraphDef* optimized_graph_; // Not owned.
+ std::unique_ptr<NodeMap> node_map_;
+ std::map<NodeDef*, int> invariant_nodes_;
+ std::set<int> empty_set_;
+ // TODO(rmlarsen): Use vector instead of map, since frames ids are dense.
+ std::map<int, std::set<int>> frame_children_;
+ std::map<int, int> frame_parent_;
+ std::map<int, const NodeDef*> loop_cond_;
+ std::map<int, std::vector<NodeDef*>> invariant_enters_;
+ int new_enter_id_;
+};
-Status LoopOptimizer::LINMHandleInvariantEnter(NodeDef* node,
- const int num_outputs) {
+Status LoopInvariantNodeMotionOptimizer::HandleInvariantEnter(
+ NodeDef* node, const int num_outputs) {
auto consumers = node_map_->GetOutputs(node->name());
std::vector<string> enter_control_inputs;
string enter_input;
@@ -142,9 +103,10 @@ Status LoopOptimizer::LINMHandleInvariantEnter(NodeDef* node,
return Status::OK();
}
-Status LoopOptimizer::LINMHandleConst(NodeDef* node,
- const int num_outputs, const int frame_id) {
- NodeDef* const_node;
+Status LoopInvariantNodeMotionOptimizer::HandleConst(NodeDef* node,
+ const int num_outputs,
+ const int frame_id) {
+ NodeDef* const_node = nullptr;
if (num_outputs == 0) {
// all successor nodes are invariant
// Remove the control inputs from this frame to the const node,
@@ -156,12 +118,17 @@ Status LoopOptimizer::LINMHandleConst(NodeDef* node,
// some successor nodes are variant
// Have to keep the const node in the frame,
// so create a new one outside the frame (in parent frame)
- const_node = optimized_graph_->add_node();
- const_node->set_name(AddPrefixToNodeName(node->name(), kLoopOptimizer));
- const_node->set_op("Const");
- const_node->set_device(node->device());
- *const_node->mutable_attr() = node->attr();
- node_map_->AddNode(const_node->name(), const_node);
+ const string const_node_name =
+ AddPrefixToNodeName(node->name(), kLoopOptimizer);
+ const_node = node_map_->GetNode(const_node_name);
+ if (const_node == nullptr) {
+ const_node = optimized_graph_->add_node();
+ const_node->set_name(const_node_name);
+ const_node->set_op("Const");
+ const_node->set_device(node->device());
+ *const_node->mutable_attr() = node->attr();
+ node_map_->AddNode(const_node->name(), const_node);
+ }
auto consumers = node_map_->GetOutputs(node->name());
for (auto* consumer : consumers) {
if (invariant_nodes_.count(consumer)) {
@@ -185,8 +152,8 @@ Status LoopOptimizer::LINMHandleConst(NodeDef* node,
int parent_id = parent_it->second;
auto loop_cond_it = loop_cond_.find(parent_id);
if (loop_cond_it == loop_cond_.end()) {
- return errors::InvalidArgument(
- "Frame ", frame_id, " doesn't have a LoopCond node");
+ return errors::InvalidArgument("Frame ", frame_id,
+ " doesn't have a LoopCond node");
}
auto& loop_cond_name = loop_cond_it->second->name();
NodeDef* switch_node = nullptr;
@@ -197,9 +164,8 @@ Status LoopOptimizer::LINMHandleConst(NodeDef* node,
}
}
if (!switch_node) {
- return errors::InvalidArgument(
- "LoopCond node of Frame ", frame_id,
- " doesn't connect to any Switch node");
+ return errors::InvalidArgument("LoopCond node of Frame ", frame_id,
+ " doesn't connect to any Switch node");
}
string switch_output = StrCat(switch_node->name(), ":1");
const string ctrl_dep = ConstantFolding::AddControlDependency(
@@ -210,8 +176,8 @@ Status LoopOptimizer::LINMHandleConst(NodeDef* node,
return Status::OK();
}
-Status LoopOptimizer::LINMHandleInvariantNode(NodeDef* node,
- const int num_outputs, const int frame_id) {
+Status LoopInvariantNodeMotionOptimizer::HandleInvariantNode(
+ NodeDef* node, const int num_outputs, const int frame_id) {
// have to remove control inputs to the invariant node from the same frame
// when moving this node out of this frame
for (int i = 0; i < node->input_size(); ++i) {
@@ -228,16 +194,14 @@ Status LoopOptimizer::LINMHandleInvariantNode(NodeDef* node,
DataTypeVector output_types;
OpRegistryInterface* op_registry = OpRegistry::Global();
const OpRegistrationData* op_reg_data = nullptr;
- TF_RETURN_IF_ERROR(
- op_registry->LookUp(node->op(), &op_reg_data));
- TF_RETURN_IF_ERROR(
- InOutTypesForNode(*node, op_reg_data->op_def,
- &input_types, &output_types));
+ TF_RETURN_IF_ERROR(op_registry->LookUp(node->op(), &op_reg_data));
+ TF_RETURN_IF_ERROR(InOutTypesForNode(*node, op_reg_data->op_def, &input_types,
+ &output_types));
auto consumers = node_map_->GetOutputs(node->name());
string fname = invariant_enters_[frame_id][0]->attr().at("frame_name").s();
- int piterations = invariant_enters_[frame_id][0]
- ->attr().at("parallel_iterations").i();
+ int piterations =
+ invariant_enters_[frame_id][0]->attr().at("parallel_iterations").i();
for (auto* consumer : consumers) {
if (!invariant_nodes_.count(consumer)) {
for (int i = 0; i < consumer->input_size(); ++i) {
@@ -281,28 +245,27 @@ Status LoopOptimizer::LINMHandleInvariantNode(NodeDef* node,
return Status::OK();
}
-Status LoopOptimizer::MoveInvariantNodes(const int frame_id) {
- for (auto iter = invariant_nodes_.begin();
- iter != invariant_nodes_.end(); ++iter) {
+Status LoopInvariantNodeMotionOptimizer::MoveInvariantNodes(
+ const int frame_id) {
+ for (auto iter = invariant_nodes_.begin(); iter != invariant_nodes_.end();
+ ++iter) {
auto* invariant_node = iter->first;
const int num_outputs = iter->second;
if (IsEnter(*invariant_node)) {
- TF_RETURN_IF_ERROR(
- LINMHandleInvariantEnter(invariant_node, num_outputs));
+ TF_RETURN_IF_ERROR(HandleInvariantEnter(invariant_node, num_outputs));
} else if (IsConstant(*invariant_node)) {
- TF_RETURN_IF_ERROR(
- LINMHandleConst(invariant_node, num_outputs, frame_id));
+ TF_RETURN_IF_ERROR(HandleConst(invariant_node, num_outputs, frame_id));
} else {
TF_RETURN_IF_ERROR(
- LINMHandleInvariantNode(invariant_node, num_outputs, frame_id));
+ HandleInvariantNode(invariant_node, num_outputs, frame_id));
}
}
return Status::OK();
}
-Status LoopOptimizer::RevertInvariantNodes() {
+Status LoopInvariantNodeMotionOptimizer::RevertInvariantNodes() {
std::deque<const NodeDef*> reverted_nodes;
- for (auto iter=invariant_nodes_.begin(); iter != invariant_nodes_.end();) {
+ for (auto iter = invariant_nodes_.begin(); iter != invariant_nodes_.end();) {
bool erased = false;
const auto* node = iter->first;
if (!IsConstant(*node) && !IsEnter(*node) && iter->second > 0) {
@@ -331,8 +294,8 @@ Status LoopOptimizer::RevertInvariantNodes() {
auto* producer = node_map_->GetNode(input);
auto iter = invariant_nodes_.find(producer);
if (iter != invariant_nodes_.end()) {
- if (IsControlInput(input) &&
- !IsConstant(*producer) && !IsEnter(*producer)) {
+ if (IsControlInput(input) && !IsConstant(*producer) &&
+ !IsEnter(*producer)) {
reverted_nodes.push_back(producer);
invariant_nodes_.erase(iter);
} else {
@@ -357,12 +320,11 @@ Status LoopOptimizer::RevertInvariantNodes() {
return Status::OK();
}
-Status LoopOptimizer::FindInvariantNodes(NodeDef* node) {
+Status LoopInvariantNodeMotionOptimizer::FindInvariantNodes(NodeDef* node) {
auto consumers = node_map_->GetOutputs(node->name());
invariant_nodes_.insert(std::make_pair(node, consumers.size()));
for (auto* consumer : consumers) {
- if (invariant_nodes_.count(consumer) ||
- ModifiesFrameInfo(*consumer)) {
+ if (invariant_nodes_.count(consumer) || ModifiesFrameInfo(*consumer)) {
continue;
}
bool is_invariant = true;
@@ -399,9 +361,14 @@ Status LoopOptimizer::FindInvariantNodes(NodeDef* node) {
return Status::OK();
}
-Status LoopOptimizer::LoopInvariantNodeMotion() {
+Status LoopInvariantNodeMotionOptimizer::Optimize() {
+ node_map_.reset(new NodeMap(optimized_graph_));
+ FrameMap frame_map;
+ int num_frames;
+ TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_,
+ &frame_map, &num_frames));
std::deque<int> worklist;
- for (auto iter = frame_map_.begin(); iter != frame_map_.end(); ++iter) {
+ for (auto iter = frame_map.begin(); iter != frame_map.end(); ++iter) {
auto* node = iter->first;
auto& frame_ids = iter->second;
if (frame_ids.size() >= 3) {
@@ -467,19 +434,82 @@ Status LoopOptimizer::LoopInvariantNodeMotion() {
return Status::OK();
}
-Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
- GraphDef* optimized_graph) {
+std::vector<int> GetStackPushNodesToConvert(
+ const SimpleGraphView& graph_view,
+ const std::unordered_set<string>& nodes_to_preserve, int stack_node_idx) {
+ VLOG(1) << "Stack node: " << graph_view.graph()->node(stack_node_idx).name();
+ const std::unordered_set<string> op_types_to_traverse(
+ {"Stack", "StackV2", "Enter", "RefEnter", "Switch", "RefSwitch",
+ "Identity", "RefIdentity"});
+ std::vector<int> nodes_to_convert;
+ std::set<int> fanout;
+ graph_view.DepthFirstSearch(op_types_to_traverse, stack_node_idx, &fanout);
+ for (int fanout_idx : fanout) {
+ const NodeDef& fanout_node = graph_view.graph()->node(fanout_idx);
+ VLOG(1) << "Fanout " << fanout_idx << " : " << fanout_node.name();
+ if (IsStackPushOp(fanout_node)) {
+ nodes_to_convert.push_back(fanout_idx);
+ } else if (IsStackOp(fanout_node) || IsStackCloseOp(fanout_node) ||
+ op_types_to_traverse.find(fanout_node.op()) !=
+ op_types_to_traverse.end()) {
+ continue;
+ } else if (!IsStackPopOp(fanout_node) ||
+ (!graph_view.outputs(fanout_idx).empty() ||
+ nodes_to_preserve.find(fanout_node.name()) !=
+ nodes_to_preserve.end())) {
+      // The node is either a stack pop with consumers or something unexpected,
+      // so we leave the graph alone.
+ nodes_to_convert.clear();
+ break;
+ }
+ }
+ return nodes_to_convert;
+}
+
+Status RemoveStackOps(const GrapplerItem& item, GraphDef* optimized_graph) {
+ const std::unordered_set<string> nodes_to_preserve = item.NodesToPreserve();
+ const GraphDef& graph = item.graph;
+ *optimized_graph = graph;
+ NodeMap node_map(optimized_graph);
+ SimpleGraphView graph_view;
+ TF_RETURN_IF_ERROR(graph_view.Initialize(graph));
+ for (int node_idx = 0; node_idx < graph.node_size(); ++node_idx) {
+ if (IsStackOp(graph.node(node_idx))) {
+ for (int push_node_idx : GetStackPushNodesToConvert(
+ graph_view, nodes_to_preserve, node_idx)) {
+        // We found push nodes without corresponding pops. Convert them to
+        // Identity ops that pass the data through, and add a control dependency
+        // on the op that supplies the stack handle.
+ NodeDef* push_node = optimized_graph->mutable_node(push_node_idx);
+ VLOG(1) << "Converting " << push_node_idx << " : "
+ << push_node->DebugString();
+ if (push_node->attr().count("swap_memory") != 0) {
+ push_node->mutable_attr()->erase("swap_memory");
+ }
+ push_node->set_op("Identity");
+ push_node->mutable_input()->SwapElements(0, 1);
+ const string ctrl_dep = ConstantFolding::AddControlDependency(
+ push_node->input(1), optimized_graph, &node_map);
+ push_node->set_input(1, ctrl_dep);
+ VLOG(1) << "After converting: " << push_node->DebugString();
+ }
+ }
+ }
+ return Status::OK();
+}
- TF_RETURN_IF_ERROR(RemoveStackOps(item, optimized_graph));
+} // namespace
- if (opt_level_ == RewriterConfig::AGGRESSIVE) {
- optimized_graph_ = optimized_graph;
- // Set up helper data structures.
- node_map_.reset(new NodeMap(optimized_graph_));
- int num_frames;
- TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_,
- &frame_map_, &num_frames));
- TF_RETURN_IF_ERROR(LoopInvariantNodeMotion());
+Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) {
+ *optimized_graph = item.graph;
+  // Run the optimizer stages that are enabled in options_.
+ if (options_.enable_loop_invariant_node_motion) {
+ LoopInvariantNodeMotionOptimizer linm_optimizer(optimized_graph);
+ TF_RETURN_IF_ERROR(linm_optimizer.Optimize());
+ }
+ if (options_.enable_stack_push_removal) {
+ TF_RETURN_IF_ERROR(RemoveStackOps(item, optimized_graph));
}
return Status::OK();
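To make the rewrite in RemoveStackOps above concrete, here is a minimal C++ sketch (not part of the patch; the node names, the "^..." control-input string, and the omission of the swap_memory attr handling are illustrative assumptions) of what happens to a single push node:

// Before:  while/StackPush = StackPushV2(while/Stack, while/add)
// After:   while/StackPush = Identity(while/add) with input "^while/Stack"
NodeDef push;                         // hypothetical node, for illustration only
push.set_op("StackPushV2");
push.add_input("while/Stack");        // input 0: the stack handle
push.add_input("while/add");          // input 1: the value being pushed
// The per-node conversion performed by RemoveStackOps():
push.set_op("Identity");
push.mutable_input()->SwapElements(0, 1);  // the pushed value becomes input 0
// The real code builds the control input via ConstantFolding::AddControlDependency;
// for an ordinary producer that amounts to a "^name" control dependency.
push.set_input(1, "^while/Stack");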
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.h b/tensorflow/core/grappler/optimizers/loop_optimizer.h
index c1b0321e4e..83c499bbe7 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer.h
@@ -30,9 +30,13 @@ constexpr char kLoopOptimizer[] = "LoopOptimizer";
class LoopOptimizer : public GraphOptimizer {
public:
- LoopOptimizer() : opt_level_(RewriterConfig::ON) {}
+ LoopOptimizer()
+ : opt_level_(RewriterConfig::ON),
+ options_(LoopOptimizerOptions::Default(RewriterConfig::ON)) {}
explicit LoopOptimizer(RewriterConfig::Toggle opt_level)
- : opt_level_(opt_level) {}
+ : opt_level_(opt_level),
+        options_(LoopOptimizerOptions::Default(opt_level)) {}
+
~LoopOptimizer() override {}
string name() const override { return "loop_optimizer"; };
@@ -44,29 +48,24 @@ class LoopOptimizer : public GraphOptimizer {
const GraphDef& optimized_graph, double result) override;
private:
- Status LoopInvariantNodeMotion();
- Status FindInvariantNodes(NodeDef* node);
- Status RevertInvariantNodes();
- Status MoveInvariantNodes(const int frame_id);
- Status LINMHandleInvariantNode(NodeDef* node, const int num_outputs,
- const int frame_id);
- Status LINMHandleConst(NodeDef* node, const int num_outputs,
- const int frame_id);
- Status LINMHandleInvariantEnter(NodeDef* node, const int num_outputs);
-
- std::map<NodeDef*, int> invariant_nodes_;
- std::set<int> empty_set_;
- std::map<int, std::set<int>> frame_children_;
- std::map<int, int> frame_parent_;
- std::map<int, const NodeDef*> loop_cond_;
- std::map<int, std::vector<NodeDef*>> invariant_enters_;
- int new_enter_id_;
- RewriterConfig::Toggle opt_level_;
+ friend class LoopOptimizerTest;
+
+ // Granular control for loop optimizer stages.
+ struct LoopOptimizerOptions {
+ bool enable_loop_invariant_node_motion = false;
+ bool enable_stack_push_removal = true;
+
+ static LoopOptimizerOptions Default(RewriterConfig::Toggle opt_level) {
+ LoopOptimizerOptions options;
+ if (opt_level == RewriterConfig::AGGRESSIVE) {
+ options.enable_loop_invariant_node_motion = true;
+ }
+ return options;
+ }
+ };
- std::unique_ptr<NodeMap> node_map_;
- FrameMap frame_map_;
- std::unique_ptr<GraphProperties> graph_properties_;
- GraphDef* optimized_graph_; // Not owned.
+ RewriterConfig::Toggle opt_level_;
+ LoopOptimizerOptions options_;
};
} // end namespace grappler
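As a usage sketch, mirroring the tests below (`graph` is a GraphDef assumed to be built elsewhere): a default-constructed LoopOptimizer starts from LoopOptimizerOptions::Default(RewriterConfig::ON), i.e. stack push removal enabled and loop invariant node motion disabled, while Default(RewriterConfig::AGGRESSIVE) also turns node motion on.

GrapplerItem item;
item.graph = graph;                      // GraphDef assumed to exist
LoopOptimizer optimizer;                 // Default(ON): stack push removal only
GraphDef output;
Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
TF_CHECK_OK(status);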
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc
index a0bd335197..10ec544424 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc
@@ -25,7 +25,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace {
class LoopOptimizerTest : public GrapplerTest {
protected:
@@ -57,6 +56,23 @@ class LoopOptimizerTest : public GrapplerTest {
attributes.emplace_back("T", type);
AddNode(name, op, inputs, attributes, graph);
}
+
+ void DisableAllStages(LoopOptimizer* optimizer) {
+ LoopOptimizer::LoopOptimizerOptions options;
+ options.enable_loop_invariant_node_motion = false;
+ options.enable_stack_push_removal = false;
+ optimizer->options_ = options;
+ }
+
+ void EnableOnlyLoopInvariantNodeMotion(LoopOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.enable_loop_invariant_node_motion = true;
+ }
+
+ void EnableOnlyStackPushRemoval(LoopOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.enable_stack_push_removal = true;
+ }
};
TEST_F(LoopOptimizerTest, Basic) {
@@ -81,7 +97,8 @@ TEST_F(LoopOptimizerTest, Basic) {
GrapplerItem item;
item.graph = graph;
- LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ LoopOptimizer optimizer;
+ EnableOnlyLoopInvariantNodeMotion(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -128,7 +145,8 @@ TEST_F(LoopOptimizerTest, Const) {
GrapplerItem item;
item.graph = graph;
- LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ LoopOptimizer optimizer;
+ EnableOnlyLoopInvariantNodeMotion(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -175,7 +193,8 @@ TEST_F(LoopOptimizerTest, ControlOutput) {
GrapplerItem item;
item.graph = graph;
- LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ LoopOptimizer optimizer;
+ EnableOnlyLoopInvariantNodeMotion(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -235,7 +254,8 @@ TEST_F(LoopOptimizerTest, NestedLoop1) {
GrapplerItem item;
item.graph = graph;
- LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ LoopOptimizer optimizer;
+ EnableOnlyLoopInvariantNodeMotion(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -302,7 +322,8 @@ TEST_F(LoopOptimizerTest, NestedLoop2) {
GrapplerItem item;
item.graph = graph;
- LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ LoopOptimizer optimizer;
+ EnableOnlyLoopInvariantNodeMotion(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -365,7 +386,8 @@ TEST_F(LoopOptimizerTest, NestedLoopConst1) {
GrapplerItem item;
item.graph = graph;
- LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ LoopOptimizer optimizer;
+ EnableOnlyLoopInvariantNodeMotion(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -429,7 +451,8 @@ TEST_F(LoopOptimizerTest, NestedLoopConst2) {
GrapplerItem item;
item.graph = graph;
- LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ LoopOptimizer optimizer;
+ EnableOnlyLoopInvariantNodeMotion(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -475,6 +498,7 @@ TEST_F(LoopOptimizerTest, NoOp) {
CHECK(fake_input.NextItem(&item));
LoopOptimizer optimizer;
+ EnableOnlyStackPushRemoval(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -504,6 +528,7 @@ TEST_F(LoopOptimizerTest, RemovePush_NoOp) {
AddSimpleNode("stop", "StopGradient", {"stack3"}, &graph);
LoopOptimizer optimizer;
+ EnableOnlyStackPushRemoval(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -534,6 +559,7 @@ TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) {
item.fetch.push_back("pop4");
LoopOptimizer optimizer;
+ EnableOnlyStackPushRemoval(&optimizer);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
@@ -563,6 +589,5 @@ TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) {
}
}
-} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index 5893f286ed..534fe670e0 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -430,18 +430,28 @@ Status SimpleGraphView::Initialize(const GraphDef& graph, bool dedup_inputs,
}
void SimpleGraphView::DepthFirstSearch(
- const std::unordered_set<string>& op_types_to_traverse, int node_idx,
+ const std::unordered_set<string>& op_types_to_traverse, int root_node,
std::set<int>* nodes_found) const {
- if (nodes_found->find(node_idx) != nodes_found->end()) {
- return;
- }
- nodes_found->insert(node_idx);
- const string& op_type = graph_->node(node_idx).op();
+ nodes_found->clear();
+ const string& op_type = graph_->node(root_node).op();
if (op_types_to_traverse.find(op_type) == op_types_to_traverse.end()) {
return;
}
- for (auto output_idx : this->outputs(node_idx)) {
- DepthFirstSearch(op_types_to_traverse, output_idx, nodes_found);
+ std::vector<int> stack;
+ stack.reserve(32);
+ stack.push_back(root_node);
+ while (!stack.empty()) {
+ const int node_idx = stack.back();
+ stack.pop_back();
+ nodes_found->insert(node_idx);
+ const string& op_type = graph_->node(node_idx).op();
+ if (op_types_to_traverse.find(op_type) != op_types_to_traverse.end()) {
+ for (auto output_idx : this->outputs(node_idx)) {
+ if (nodes_found->find(output_idx) == nodes_found->end()) {
+ stack.push_back(output_idx);
+ }
+ }
+ }
}
}
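For reference, a minimal call-site sketch of the now-iterative DepthFirstSearch, mirroring its use in GetStackPushNodesToConvert earlier in this patch (`graph`, `root_idx`, and the traversal set are assumptions for illustration):

SimpleGraphView graph_view;
TF_CHECK_OK(graph_view.Initialize(graph));   // 'graph' is a GraphDef
std::set<int> fanout;
// The callee clears 'fanout', then fills it with every node reachable from
// root_idx through ops whose type appears in the traversal set.
graph_view.DepthFirstSearch({"Stack", "StackV2", "Enter", "Identity"},
                            root_idx, &fanout);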
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 1857d8d655..1018e8d25c 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -132,6 +132,17 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "collective_ops",
+ prefix = "collective_ops",
+ deps = [
+ "//tensorflow/core:collective_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+tf_kernel_library(
name = "concat_lib",
srcs = [
"concat_lib_cpu.cc",
@@ -1395,6 +1406,7 @@ tf_kernel_library(
visibility = [":friends"],
deps = [
":bounds_check",
+ ":dense_update_functor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/eigen3",
@@ -5121,6 +5133,9 @@ filegroup(
"summary_interface.*",
"summary_kernels.*",
"spectrogram_convert_test_data.cc",
+ "decode_proto_op.cc",
+ "encode_proto_op.cc",
+ "rpc_op.cc",
# Excluded due to experimental status:
"debug_ops.*",
"scatter_nd_op*",
@@ -6153,6 +6168,50 @@ tf_kernel_library(
],
)
+tf_kernel_library(
+ name = "decode_proto_op",
+ srcs = [
+ "decode_proto_op.cc",
+ ],
+ deps = [
+ "//tensorflow/core:decode_proto_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/util/proto:decode",
+ "//tensorflow/core/util/proto:descriptors",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "encode_proto_op",
+ srcs = ["encode_proto_op.cc"],
+ deps = [
+ "//tensorflow/core:encode_proto_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/util/proto:descriptors",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "rpc_op",
+ srcs = [
+ "rpc_op.cc",
+ ],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:rpc_ops_op_lib",
+ "//tensorflow/core/util/rpc:call_container",
+ "//tensorflow/core/util/rpc:rpc_factory",
+ "//tensorflow/core/util/rpc:rpc_factory_registry",
+ "//third_party/eigen3",
+ ],
+)
+
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
new file mode 100644
index 0000000000..5de41bac72
--- /dev/null
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -0,0 +1,266 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+namespace {
+class CollectiveOpKernel : public AsyncOpKernel {
+ public:
+ explicit CollectiveOpKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {}
+
+ // A string encoding instance, frame and iter to be handed off to
+ // the implementation for use in generating RecvBuf keys.
+ string GetCollectiveKey(OpKernelContext* c) {
+ return strings::StrCat(col_params_.instance.instance_key, ":",
+ c->frame_iter().frame_id, ":",
+ c->frame_iter().iter_id);
+ }
+
+  // Returns false if the calling invocation of ComputeAsync should return
+ // immediately.
+ bool CanProceedWithCompute(OpKernelContext* c, CollectiveExecutor* col_exec,
+ const DoneCallback& done) {
+ if (col_params_.group.group_size >
+ col_params_.instance.device_names.size()) {
+ // This is the first invocation: Finish initializing col_params_.
+      // Run this in a thread that is allowed to block, because
+      // CompleteParamsAsync is not guaranteed to be non-blocking.
+ c->env()->SchedClosure([this, c, done, col_exec]() {
+ col_exec->CompleteParamsAsync(c->device()->name(), &col_params_,
+ c->cancellation_manager(),
+ [this, c, done](const Status& s) {
+ if (s.ok()) {
+ ComputeAsync(c, done);
+ } else {
+ c->SetStatus(s);
+ done();
+ }
+ });
+ });
+ return false;
+ }
+ return true;
+ }
+
+ CollectiveParams col_params_;
+};
+
+class CollectiveReduceOpKernel : public CollectiveOpKernel {
+ public:
+ explicit CollectiveReduceOpKernel(OpKernelConstruction* c)
+ : CollectiveOpKernel(c) {
+ col_params_.instance.type = REDUCTION_COLLECTIVE;
+ OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
+ OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
+ OP_REQUIRES_OK(
+ c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
+ OP_REQUIRES_OK(
+ c, c->GetAttr("subdiv_offsets",
+ &col_params_.instance.impl_details.subdiv_offsets));
+ string merge_op_name;
+ OP_REQUIRES_OK(c, c->GetAttr("merge_op", &merge_op_name));
+ OP_REQUIRES(c, merge_op_name == "Add" || merge_op_name == "Mul",
+ errors::InvalidArgument(
+ "merge_op must be one of {\"Add\", \"Mul\"} but got ",
+ merge_op_name));
+ string final_op_name;
+ OP_REQUIRES_OK(c, c->GetAttr("final_op", &final_op_name));
+ OP_REQUIRES(c, final_op_name == "Id" || final_op_name == "Div",
+ errors::InvalidArgument(
+ "final_op must be one of {\"Id\", \"Div\"} but got ",
+ final_op_name));
+ OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
+
+ const NodeDef& real_node = c->def();
+ col_params_.name = strings::StrCat(real_node.name(), ": Reduce(",
+ merge_op_name, ",", final_op_name, ")");
+ col_params_.group.device_type = c->device_type();
+
+ // Find the OpKernels by name, type and device type.
+ NodeDef sub_node;
+ // The merge_op takes two inputs
+ sub_node.add_input(real_node.input(0));
+ sub_node.add_input(real_node.input(0));
+ sub_node.set_device(real_node.device());
+ SetAttrValue(col_params_.instance.data_type,
+ &(*sub_node.mutable_attr())["T"]);
+ col_params_.merge_op = BuildOpKernel(c, merge_op_name, &sub_node);
+ col_params_.final_op = BuildOpKernel(c, final_op_name, &sub_node);
+ }
+
+ std::unique_ptr<OpKernel> BuildOpKernel(OpKernelConstruction* c,
+ const string& name,
+ NodeDef* sub_node) {
+ std::unique_ptr<OpKernel> k;
+ if (name.empty() || name == "Id") return k;
+ sub_node->set_name(name);
+ sub_node->set_op(name);
+ Status status;
+ k = CreateOpKernel(c->device_type(), c->device(),
+ c->device()->GetAllocator(AllocatorAttributes()),
+ *sub_node, c->graph_def_version(), &status);
+ if (!status.ok()) {
+ c->CtxFailureWithWarning(errors::Internal("Failed to build OpKernel for ",
+ name, " : ",
+ status.error_message()));
+ }
+ return k;
+ }
+
+ void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+ CollectiveExecutor* col_exec = c->collective_executor();
+ OP_REQUIRES_ASYNC(
+ c, col_exec,
+ errors::Internal(
+ "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+ col_params_.name),
+ done);
+ if (!CanProceedWithCompute(c, col_exec, done)) return;
+ // Allocate the output tensor, trying to reuse the input.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK_ASYNC(c,
+ c->forward_input_or_allocate_output(
+ {0}, 0, c->input(0).shape(), &output),
+ done);
+
+ auto actual_done = [c, col_exec, done](const Status& s) {
+ OP_REQUIRES_OK_ASYNC(c, s, done);
+ done();
+ };
+ col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(CollectiveReduceOpKernel);
+};
+
+REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_CPU),
+ CollectiveReduceOpKernel);
+REGISTER_KERNEL_BUILDER(Name("CollectiveReduce").Device(DEVICE_GPU),
+ CollectiveReduceOpKernel);
+
+class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
+ public:
+ explicit CollectiveBcastSendOpKernel(OpKernelConstruction* c)
+ : CollectiveOpKernel(c) {
+ col_params_.instance.type = BROADCAST_COLLECTIVE;
+ OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
+ OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
+ OP_REQUIRES_OK(
+ c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
+ OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
+ OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
+ col_params_.is_source = true;
+ col_params_.instance.impl_details.subdiv_offsets = {0};
+
+ col_params_.name =
+ strings::StrCat(name(), ": Broadcast(", col_params_.is_source, ")");
+ col_params_.group.device_type = c->device_type();
+ }
+
+ void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+ CollectiveExecutor* col_exec = c->collective_executor();
+ OP_REQUIRES_ASYNC(
+ c, col_exec,
+ errors::Internal(
+ "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+ col_params_.name),
+ done);
+ if (!CanProceedWithCompute(c, col_exec, done)) return;
+ OP_REQUIRES_ASYNC(
+ c, shape_.IsSameSize(c->input(0).shape()),
+ errors::Internal("Declared shape of op ", col_params_.name,
+ " does not match shape of input"),
+ done);
+ // Allocate the output Tensor, trying to reuse the input.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK_ASYNC(
+ c, c->forward_input_or_allocate_output({0}, 0, shape_, &output), done);
+
+ auto actual_done = [c, col_exec, done](const Status& s) {
+ OP_REQUIRES_OK_ASYNC(c, s, done);
+ done();
+ };
+ col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
+ }
+
+ private:
+ TensorShape shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastSendOpKernel);
+};
+
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_CPU),
+ CollectiveBcastSendOpKernel);
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastSend").Device(DEVICE_GPU),
+ CollectiveBcastSendOpKernel);
+
+class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
+ public:
+ explicit CollectiveBcastRecvOpKernel(OpKernelConstruction* c)
+ : CollectiveOpKernel(c) {
+ col_params_.instance.type = BROADCAST_COLLECTIVE;
+ OP_REQUIRES_OK(c, c->GetAttr("group_size", &col_params_.group.group_size));
+ OP_REQUIRES_OK(c, c->GetAttr("group_key", &col_params_.group.group_key));
+ OP_REQUIRES_OK(
+ c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
+ OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
+ OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
+ col_params_.is_source = false;
+ col_params_.instance.impl_details.subdiv_offsets = {0};
+
+ col_params_.name =
+ strings::StrCat(name(), ": Broadcast(", col_params_.is_source, ")");
+ col_params_.group.device_type = c->device_type();
+ }
+
+ void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
+ CollectiveExecutor* col_exec = c->collective_executor();
+ OP_REQUIRES_ASYNC(
+ c, col_exec,
+ errors::Internal(
+ "Failed to get CollectiveExecutor from OpKernelContext for Op ",
+ col_params_.name),
+ done);
+ if (!CanProceedWithCompute(c, col_exec, done)) return;
+ // No input, so must allocate output.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape_, &output), done);
+
+ auto actual_done = [c, col_exec, done](const Status& s) {
+ OP_REQUIRES_OK_ASYNC(c, s, done);
+ done();
+ };
+ col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
+ }
+
+ private:
+ TensorShape shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastRecvOpKernel);
+};
+
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_CPU),
+ CollectiveBcastRecvOpKernel);
+REGISTER_KERNEL_BUILDER(Name("CollectiveBcastRecv").Device(DEVICE_GPU),
+ CollectiveBcastRecvOpKernel);
+
+} // namespace
+} // namespace tensorflow
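To show what these kernels expect from the graph, here is a hedged sketch of a NodeDef for CollectiveReduce; the attr names match the GetAttr calls in CollectiveReduceOpKernel above, while the node name, input, and concrete values are illustrative assumptions:

NodeDef reduce;                              // hypothetical node, illustration only
reduce.set_name("all_reduce_0");
reduce.set_op("CollectiveReduce");
reduce.add_input("local_tensor");            // the tensor this member contributes
auto* attr = reduce.mutable_attr();
(*attr)["T"].set_type(DT_FLOAT);
(*attr)["group_size"].set_i(2);              // devices participating in the group
(*attr)["group_key"].set_i(1);
(*attr)["instance_key"].set_i(7);
(*attr)["merge_op"].set_s("Add");            // kernel accepts "Add" or "Mul"
(*attr)["final_op"].set_s("Div");            // kernel accepts "Id" or "Div"
(*attr)["subdiv_offsets"].mutable_list()->add_i(0);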
diff --git a/tensorflow/core/kernels/concat_lib_gpu.cc b/tensorflow/core/kernels/concat_lib_gpu.cc
index d8643c0b2f..93e392d303 100644
--- a/tensorflow/core/kernels/concat_lib_gpu.cc
+++ b/tensorflow/core/kernels/concat_lib_gpu.cc
@@ -118,6 +118,7 @@ TF_CALL_complex128(REGISTER);
TF_CALL_int64(REGISTER);
TF_CALL_bfloat16(REGISTER);
TF_CALL_bool(REGISTER);
+TF_CALL_uint8(REGISTER);
#undef REGISTER
diff --git a/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc b/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc
index 0f7adaf24a..a561d918bd 100644
--- a/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc
+++ b/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc
@@ -202,6 +202,7 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT32);
TF_CALL_complex64(REGISTER_GPUCONCAT32);
TF_CALL_complex128(REGISTER_GPUCONCAT32);
TF_CALL_int64(REGISTER_GPUCONCAT32);
+TF_CALL_uint8(REGISTER_GPUCONCAT32);
REGISTER_GPUCONCAT32(bfloat16);
REGISTER_GPUCONCAT32(bool);
@@ -209,6 +210,7 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPUCONCAT64);
TF_CALL_complex64(REGISTER_GPUCONCAT64);
TF_CALL_complex128(REGISTER_GPUCONCAT64);
TF_CALL_int64(REGISTER_GPUCONCAT64);
+TF_CALL_uint8(REGISTER_GPUCONCAT64);
REGISTER_GPUCONCAT64(bfloat16);
REGISTER_GPUCONCAT64(bool);
@@ -216,6 +218,7 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU32);
TF_CALL_complex64(REGISTER_GPU32);
TF_CALL_complex128(REGISTER_GPU32);
TF_CALL_int64(REGISTER_GPU32);
+TF_CALL_uint8(REGISTER_GPU32);
REGISTER_GPU32(bfloat16);
REGISTER_GPU32(bool);
@@ -223,6 +226,7 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU64);
TF_CALL_complex64(REGISTER_GPU64);
TF_CALL_complex128(REGISTER_GPU64);
TF_CALL_int64(REGISTER_GPU64);
+TF_CALL_uint8(REGISTER_GPU64);
REGISTER_GPU64(bfloat16);
REGISTER_GPU64(bool);
diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc
index f16766315f..a87b63f913 100644
--- a/tensorflow/core/kernels/concat_op.cc
+++ b/tensorflow/core/kernels/concat_op.cc
@@ -212,6 +212,7 @@ REGISTER_CONCAT(qint32);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
REGISTER_GPU(bfloat16);
+TF_CALL_uint8(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);
diff --git a/tensorflow/core/kernels/concat_op_test.cc b/tensorflow/core/kernels/concat_op_test.cc
index e3ba8ae9f6..39b44b2fcc 100644
--- a/tensorflow/core/kernels/concat_op_test.cc
+++ b/tensorflow/core/kernels/concat_op_test.cc
@@ -78,6 +78,9 @@ static void BM_ConcatDim1Float(int iters, int dim2) {
BENCHMARK(BM_ConcatDim0Float)->Arg(1000)->Arg(100000)->Arg(1000000);
BENCHMARK(BM_ConcatDim1Float)->Arg(1000)->Arg(100000)->Arg(1000000);
+static void BM_ConcatDim1uint8(int iters, int dim2) {
+ ConcatHelper<uint8>(iters, 1, dim2);
+}
static void BM_ConcatDim1int16(int iters, int dim2) {
ConcatHelper<int16>(iters, 1, dim2);
}
@@ -85,6 +88,7 @@ static void BM_ConcatDim1bfloat16(int iters, int dim2) {
ConcatHelper<bfloat16>(iters, 1, dim2);
}
+BENCHMARK(BM_ConcatDim1uint8)->Arg(1000)->Arg(100000)->Arg(1000000);
BENCHMARK(BM_ConcatDim1int16)->Arg(1000)->Arg(100000)->Arg(1000000);
BENCHMARK(BM_ConcatDim1bfloat16)->Arg(1000)->Arg(100000)->Arg(1000000);
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index 312c1a41d3..fe1a1ba5a3 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -258,13 +258,15 @@ REGISTER_KERNEL(GPU, Eigen::half);
REGISTER_KERNEL(GPU, bfloat16);
REGISTER_KERNEL(GPU, float);
REGISTER_KERNEL(GPU, double);
+REGISTER_KERNEL(GPU, complex64);
+REGISTER_KERNEL(GPU, complex128);
REGISTER_KERNEL(GPU, uint8);
REGISTER_KERNEL(GPU, int8);
REGISTER_KERNEL(GPU, uint16);
REGISTER_KERNEL(GPU, int16);
REGISTER_KERNEL(GPU, int64);
REGISTER_KERNEL(GPU, bool);
-// Currently we do not support filling strings and complex64 on GPU
+// Currently we do not support filling strings on GPU
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc
index 07dc786d9b..e4036ddaa9 100644
--- a/tensorflow/core/kernels/cudnn_rnn_ops.cc
+++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc
@@ -227,22 +227,43 @@ inline perftools::gputools::port::Status ToExecutorStatus(const Status& s) {
s.error_message());
}
-// A helper to allocate temporary scratch memory for Cudnn RNN models. It takes
-// the ownership of the underlying memory. The expectation is that the memory
-// should be alive for the span of the Cudnn RNN itself.
-class CudnnRNNWorkspaceAllocator : public ScratchAllocator {
+template <typename>
+struct ToTFDataType;
+
+template <>
+struct ToTFDataType<Eigen::half> : std::integral_constant<DataType, DT_HALF> {};
+
+template <>
+struct ToTFDataType<float> : std::integral_constant<DataType, DT_FLOAT> {};
+
+template <>
+struct ToTFDataType<double> : std::integral_constant<DataType, DT_DOUBLE> {};
+
+template <>
+struct ToTFDataType<uint8> : std::integral_constant<DataType, DT_UINT8> {};
+
+// A helper to allocate temporary scratch memory for Cudnn RNN models. It
+// takes ownership of the underlying memory. The expectation is that the
+// memory should be alive for the span of the Cudnn RNN itself.
+template <typename T>
+class CudnnRnnAllocatorInTemp : public ScratchAllocator {
public:
- ~CudnnRNNWorkspaceAllocator() override {}
- explicit CudnnRNNWorkspaceAllocator(OpKernelContext* context)
+ ~CudnnRnnAllocatorInTemp() = default;
+
+ explicit CudnnRnnAllocatorInTemp(OpKernelContext* context)
: context_(context) {}
int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
return std::numeric_limits<int64>::max();
}
+
StatusOr<DeviceMemory<uint8>> AllocateBytes(
perftools::gputools::Stream* stream, int64 byte_size) override {
Tensor temporary_memory;
+ const DataType tf_data_type = ToTFDataType<T>::value;
+ int64 allocate_count =
+ Eigen::divup(byte_size, static_cast<int64>(sizeof(T)));
Status allocation_status(context_->allocate_temp(
- DT_UINT8, TensorShape({byte_size}), &temporary_memory));
+ tf_data_type, TensorShape({allocate_count}), &temporary_memory));
if (!allocation_status.ok()) {
return ToExecutorStatus(allocation_status);
}
@@ -250,10 +271,16 @@ class CudnnRNNWorkspaceAllocator : public ScratchAllocator {
// allocator.
allocated_tensors_.push_back(temporary_memory);
total_byte_size_ += byte_size;
- return StatusOr<DeviceMemory<uint8>>(
- AsDeviceMemory<uint8>(&temporary_memory));
+ return DeviceMemory<uint8>::MakeFromByteSize(
+ temporary_memory.template flat<T>().data(),
+ temporary_memory.template flat<T>().size() * sizeof(T));
+ }
+
+ int64 TotalByteSize() const { return total_byte_size_; }
+
+ Tensor get_allocated_tensor(int index) const {
+ return allocated_tensors_[index];
}
- int64 TotalByteSize() { return total_byte_size_; }
private:
int64 total_byte_size_ = 0;
@@ -261,15 +288,15 @@ class CudnnRNNWorkspaceAllocator : public ScratchAllocator {
std::vector<Tensor> allocated_tensors_;
};
-// A helper to allocate reserve-space memory for Cudnn RNN models. The tensors
-// are allocated as a kernel output, and will be fed into the backward pass.
+// A helper to allocate memory for Cudnn RNN models as a kernel output. It is
+// used by the forward pass kernel to feed the output to the backward pass.
// The memory is expected to live long enough after the backward pass is
// finished.
template <typename T>
-class CudnnRNNReserveSpaceAllocator : public ScratchAllocator {
+class CudnnRnnAllocatorInOutput : public ScratchAllocator {
public:
- ~CudnnRNNReserveSpaceAllocator() override {}
- CudnnRNNReserveSpaceAllocator(OpKernelContext* context, int output_index)
+ ~CudnnRnnAllocatorInOutput() override {}
+ CudnnRnnAllocatorInOutput(OpKernelContext* context, int output_index)
: context_(context), output_index_(output_index) {}
int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
return std::numeric_limits<int64>::max();
@@ -343,13 +370,14 @@ struct CudnnModelTypes {
TFRNNInputMode rnn_input_mode;
RnnDirectionMode rnn_direction_mode;
bool HasInputC() const {
- // For Cudnn 5.0, only LSTM has input-c. All other models use only input-h.
+ // For Cudnn 5.0, only LSTM has input-c. All other models use only
+ // input-h.
return rnn_mode == RnnMode::kRnnLstm;
}
};
// A helper class that collects the shapes to describe an RNN model.
-struct CudnnModelShapes {
+struct CudnnRnnModelShapes {
int num_layers;
int input_size;
int num_units;
@@ -360,7 +388,7 @@ struct CudnnModelShapes {
TensorShape output_shape;
TensorShape hidden_state_shape;
// At present only fields related to cached RnnDescriptor are concerned.
- bool IsCompatibleWith(const CudnnModelShapes& rhs) const {
+ bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const {
return num_layers == rhs.num_layers && input_size == rhs.input_size &&
num_units == rhs.num_units && dir_count == rhs.dir_count;
}
@@ -371,9 +399,9 @@ struct CudnnModelShapes {
}
};
-// Utility class for using CudnnModelShapes as a hash table key.
-struct CudnnModelShapesHasher {
- uint64 operator()(const CudnnModelShapes& to_hash) const {
+// Utility class for using CudnnRnnModelShapes as a hash table key.
+struct CudnnRnnModelShapesHasher {
+ uint64 operator()(const CudnnRnnModelShapes& to_hash) const {
uint64 hash = static_cast<uint64>(to_hash.num_layers);
hash = tensorflow::FingerprintCat64(
hash, static_cast<uint64>(to_hash.input_size));
@@ -384,21 +412,21 @@ struct CudnnModelShapesHasher {
}
};
-// Utility class for using CudnnModelShapes as a hash table key.
-struct CudnnModelShapesComparator {
- bool operator()(const CudnnModelShapes& first,
- const CudnnModelShapes& second) const {
+// Utility class for using CudnnRnnModelShapes as a hash table key.
+struct CudnnRnnModelShapesComparator {
+ bool operator()(const CudnnRnnModelShapes& first,
+ const CudnnRnnModelShapes& second) const {
return first.IsCompatibleWith(second);
}
};
-// Extract and checks the forward input tensors, parameters, and shapes from the
-// OpKernelContext.
+// Extracts and checks the forward input tensors, parameters, and shapes from
+// the OpKernelContext.
Status ExtractForwardInput(OpKernelContext* context,
const CudnnModelTypes& model_types,
const Tensor** input, const Tensor** input_h,
const Tensor** input_c, const Tensor** params,
- CudnnModelShapes* model_shapes) {
+ CudnnRnnModelShapes* model_shapes) {
TF_RETURN_IF_ERROR(context->input("input", input));
TF_RETURN_IF_ERROR(context->input("input_h", input_h));
if (model_types.HasInputC()) {
@@ -810,7 +838,7 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
const Tensor* input_h = nullptr;
const Tensor* input_c = nullptr;
const Tensor* params = nullptr;
- CudnnModelShapes model_shapes;
+ CudnnRnnModelShapes model_shapes;
OP_REQUIRES_OK(context,
ExtractForwardInput(context, model_types(), &input, &input_h,
&input_c, &params, &model_shapes));
@@ -876,7 +904,7 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
// Creates a memory callback for the reserve_space. The memory lives in the
// output of this kernel. And it will be fed into the backward pass when
// needed.
- CudnnRNNReserveSpaceAllocator<T> reserve_space_allocator(context, 3);
+ CudnnRnnAllocatorInOutput<T> reserve_space_allocator(context, 3);
if (!is_training_) {
Tensor* dummy_reserve_space = nullptr;
OP_REQUIRES_OK(context,
@@ -884,7 +912,7 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
}
    // Creates a memory callback for the workspace. The memory lives until the
    // end of this kernel call.
- CudnnRNNWorkspaceAllocator workspace_allocator(context);
+ CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
bool launch_status = false;
{
mutex_lock l(mu_);
@@ -910,7 +938,7 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
input_c_data, params_data, *output_desc, &output_data,
*hidden_state_desc, &output_h_data, *hidden_state_desc,
&output_c_data, is_training_, &reserve_space_allocator,
- &workspace_allocator, /* output_result_profile */ nullptr)
+ &workspace_allocator, /*output_result_profile=*/nullptr)
.ok();
}
OP_REQUIRES(context, launch_status,
@@ -920,8 +948,8 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
private:
mutex mu_;
bool is_training_;
- std::unordered_map<CudnnModelShapes, RnnScratchSpace, CudnnModelShapesHasher,
- CudnnModelShapesComparator>
+ std::unordered_map<CudnnRnnModelShapes, RnnScratchSpace,
+ CudnnRnnModelShapesHasher, CudnnRnnModelShapesComparator>
rnn_state_cache_ GUARDED_BY(mu_);
};
@@ -949,7 +977,7 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
const Tensor* input_h = nullptr;
const Tensor* input_c = nullptr;
const Tensor* params = nullptr;
- CudnnModelShapes model_shapes;
+ CudnnRnnModelShapes model_shapes;
OP_REQUIRES_OK(context,
ExtractForwardInput(context, model_types(), &input, &input_h,
&input_c, &params, &model_shapes));
@@ -1090,7 +1118,7 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
auto reserve_space_uint8 = CastDeviceMemory<uint8, T>(reserve_space);
    // Creates a memory callback for the workspace. The memory lives until the
    // end of this kernel call.
- CudnnRNNWorkspaceAllocator workspace_allocator(context);
+ CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
bool launch_status = false;
{
mutex_lock l(mu_);
@@ -1119,7 +1147,7 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
output_c_backprop_data, &input_backprop_data,
&input_h_backprop_data, &input_c_backprop_data,
&params_backprop_data, &reserve_space_uint8,
- &workspace_allocator, /* output_result_profile */ nullptr)
+ &workspace_allocator, /*output_result_profile=*/nullptr)
.ok();
}
OP_REQUIRES(context, launch_status,
@@ -1128,8 +1156,8 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
private:
mutex mu_;
- std::unordered_map<CudnnModelShapes, RnnScratchSpace, CudnnModelShapesHasher,
- CudnnModelShapesComparator>
+ std::unordered_map<CudnnRnnModelShapes, RnnScratchSpace,
+ CudnnRnnModelShapesHasher, CudnnRnnModelShapesComparator>
rnn_state_cache_ GUARDED_BY(mu_);
};
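A quick worked example of the allocation arithmetic in CudnnRnnAllocatorInTemp<T>::AllocateBytes above, sketched with T = float and an assumed request size:

const int64 byte_size = 10;  // hypothetical request from the cudnn RNN plan
const int64 allocate_count =
    Eigen::divup(byte_size, static_cast<int64>(sizeof(float)));  // == 3
// A DT_FLOAT temp tensor of shape {3} holds 12 bytes, so the DeviceMemory
// returned over its flat data covers at least the 10 bytes requested.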
diff --git a/tensorflow/core/kernels/decode_proto_op.cc b/tensorflow/core/kernels/decode_proto_op.cc
new file mode 100644
index 0000000000..b4e5b776ed
--- /dev/null
+++ b/tensorflow/core/kernels/decode_proto_op.cc
@@ -0,0 +1,1011 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// DecodeProto is a TensorFlow Op which extracts arbitrary fields
+// from protos serialized as strings.
+//
+// See docs in ../ops/decode_proto_op.cc.
+//
+// This implementation reads the serialized format using a handful of
+// calls from the WireFormatLite API used by generated proto code.
+// WireFormatLite is marked as an "internal" proto API but is widely
+// used in practice and highly unlikely to change.
+// This will be much faster than the previous implementation based on
+// constructing a temporary dynamic message in memory and using the
+// proto reflection api to read it.
+// It can be used with any proto whose descriptors are available at
+// runtime but should be competitive in speed with approaches that
+// compile in the proto definitions.
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/util/proto/decode.h"
+#include "tensorflow/core/util/proto/descriptors.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+namespace {
+
+using ::tensorflow::MakeUnique;
+using ::tensorflow::protobuf::Descriptor;
+using ::tensorflow::protobuf::DescriptorPool;
+using ::tensorflow::protobuf::DynamicMessageFactory;
+using ::tensorflow::protobuf::FieldDescriptor;
+using ::tensorflow::protobuf::Message;
+using ::tensorflow::protobuf::TextFormat;
+using ::tensorflow::protobuf::internal::WireFormatLite;
+using ::tensorflow::protobuf::io::CodedInputStream;
+
+const bool kFailOnDecodeError = true;
+
+// Returns true if the proto field type can be converted to the
+// tensorflow::DataType.
+bool CheckOutputType(FieldDescriptor::Type field_type, DataType output_type) {
+ switch (field_type) {
+ case WireFormatLite::TYPE_DOUBLE:
+ return output_type == tensorflow::DT_DOUBLE;
+ case WireFormatLite::TYPE_FLOAT:
+ return output_type == tensorflow::DT_FLOAT ||
+ output_type == tensorflow::DT_DOUBLE;
+ case WireFormatLite::TYPE_INT64:
+ return output_type == tensorflow::DT_INT64;
+ case WireFormatLite::TYPE_UINT64:
+ return output_type == tensorflow::DT_INT64;
+ case WireFormatLite::TYPE_INT32:
+ return output_type == tensorflow::DT_INT32;
+ case WireFormatLite::TYPE_FIXED64:
+ return output_type == tensorflow::DT_INT64;
+ case WireFormatLite::TYPE_FIXED32:
+ return output_type == tensorflow::DT_INT32 ||
+ output_type == tensorflow::DT_INT64;
+ case WireFormatLite::TYPE_BOOL:
+ return output_type == tensorflow::DT_BOOL;
+ case WireFormatLite::TYPE_STRING:
+ return output_type == tensorflow::DT_STRING;
+ case WireFormatLite::TYPE_GROUP:
+ return output_type == tensorflow::DT_STRING;
+ case WireFormatLite::TYPE_MESSAGE:
+ return output_type == tensorflow::DT_STRING;
+ case WireFormatLite::TYPE_BYTES:
+ return output_type == tensorflow::DT_STRING;
+ case WireFormatLite::TYPE_UINT32:
+ return output_type == tensorflow::DT_INT32 ||
+ output_type == tensorflow::DT_INT64;
+ case WireFormatLite::TYPE_ENUM:
+ return output_type == tensorflow::DT_INT32;
+ case WireFormatLite::TYPE_SFIXED32:
+ return output_type == tensorflow::DT_INT32;
+ case WireFormatLite::TYPE_SFIXED64:
+ return output_type == tensorflow::DT_INT64;
+ case WireFormatLite::TYPE_SINT32:
+ return output_type == tensorflow::DT_INT32;
+ case WireFormatLite::TYPE_SINT64:
+ return output_type == tensorflow::DT_INT64;
+ // default: intentionally omitted in order to enable static checking.
+ }
+}
+
+// A FieldInfo holds a handful of details taken from the FieldDescriptor
+// and from the user-supplied attributes.
+struct FieldInfo {
+ FieldInfo(const FieldDescriptor* field_desc, int user_index)
+ : output_index(user_index) {
+ // Without this intermediate data structure, the profile had hotspots
+ // calling methods of FieldDescriptor.
+ number = field_desc->number();
+
+ // The wire format library defines the same constants used in
+ // descriptor.proto. This static_cast is safe because they
+ // are guaranteed to stay in sync.
+ // We need the field type from the FieldDescriptor here
+ // because the wire format doesn't tell us anything about
+ // what happens inside a packed repeated field: there is
+ // enough information in the wire format to skip the
+ // whole field but not enough to know how to parse what's
+ // inside. For that we go to the schema.
+ type = static_cast<WireFormatLite::FieldType>(field_desc->type());
+ is_repeated = field_desc->is_repeated();
+ }
+
+ // Disable copy and move.
+ FieldInfo(const FieldInfo&) = delete;
+ FieldInfo& operator=(const FieldInfo&) = delete;
+
+ // Internally we sort field descriptors by wire number for
+ // fast lookup. In general this is different from the order
+ // given by the user. Output_index gives the index into
+ // the field_names and output_types attributes and into
+ // the output tensor list.
+ int output_index = -1;
+
+ // This is a cache of the relevant fields from `FieldDescriptorProto`.
+ // This was added after noticing that FieldDescriptor->type() was
+ // using 6% of the cpu profile.
+ WireFormatLite::FieldType type;
+ int number;
+ bool is_repeated;
+};
+
+// A CountCollector counts sizes of repeated and optional fields in a proto.
+//
+// Each field is tracked by a single CountCollector instance. The
+// instance manages a single count, which is stored as a pointer (it
+// is intended to be a reference to the `sizes` output which is being
+// filled in). The pointer is passed in at initialization.
+//
+// Counting is done as a separate pass in order to allocate output tensors
+// all at once. This allows the TensorFlow runtime to optimize allocation
+// for the consumer, while removing the need for copying inside this op.
+// After this pass, the DenseCollector class (below) gathers the data:
+// It is more complex and provides better motivation for the API here.
+class CountCollector {
+ public:
+ // Default constructor allows the collector to be a vector element.
+ CountCollector() = default;
+
+ // The count may be stored inside an Eigen Tensor to eliminate copying.
+ explicit CountCollector(int32* count) : count_ptr_(count) {}
+
+ // Reads (in this case counts) a single value.
+ Status ReadValue(CodedInputStream* input, const FieldInfo& field) {
+ // Only repeated fields can have count > 1.
+ if (*count_ptr_ == 0 || field.is_repeated) {
+ (*count_ptr_)++;
+ }
+ // We expect a wire type based on the schema field_type, to allow
+ // a little more checking.
+ if (!SkipValue(input, field)) {
+ return errors::DataLoss("ReadValue: Failed skipping field when counting");
+ }
+ return Status::OK();
+ }
+
+ // Reads (in this case counts) a length-delimited list of values.
+ Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field,
+ size_t buf_size) {
+ if (buf_size == 0) {
+ return Status::OK();
+ }
+
+ const void* tmpbuf;
+ int unused_max_buf_size;
+
+ input->GetDirectBufferPointerInline(&tmpbuf, &unused_max_buf_size);
+ // This is safe because the underlying storage for the CodedInputStream is
+ // owned by the input tensor. If it were a Cord or file-backed stream this
+ // pointer would go stale after the bytes were skipped.
+ const uint8* buf = reinterpret_cast<const uint8*>(tmpbuf);
+
+ // Important: we skipped the input->{Push,Pop}Limit() calls for speed,
+ // so the bounds check on buf_size inside Skip() is critical, and
+ // must be done before scanning the contents.
+ if (!input->Skip(buf_size)) {
+ return errors::DataLoss("ReadPackedValues: Skipping packed field failed");
+ }
+
+ // Dispatch to the appropriately typed field reader based on the
+ // schema type.
+ Status st;
+ switch (field.type) {
+ case WireFormatLite::TYPE_DOUBLE:
+ st = CountPackedFixed<double>(buf, buf_size);
+ break;
+ case WireFormatLite::TYPE_FLOAT:
+ st = CountPackedFixed<float>(buf, buf_size);
+ break;
+ case WireFormatLite::TYPE_INT64:
+ st = CountPackedVarint(buf, buf_size);
+ break;
+ case WireFormatLite::TYPE_UINT64:
+ st = CountPackedVarint(buf, buf_size);
+ break;
+ case WireFormatLite::TYPE_INT32:
+ st = CountPackedVarint(buf, buf_size);
+ break;
+ case WireFormatLite::TYPE_FIXED64:
+ st = CountPackedFixed<uint64>(buf, buf_size);
+ break;
+ case WireFormatLite::TYPE_FIXED32:
+ st = CountPackedFixed<uint32>(buf, buf_size);
+ break;
+ case WireFormatLite::TYPE_BOOL:
+ st = CountPackedVarint(buf, buf_size);
+ break;
+ case WireFormatLite::TYPE_STRING:
+ st = errors::DataLoss("TYPE_STRING encountered as packed");
+ break;
+ case WireFormatLite::TYPE_GROUP:
+ st = errors::DataLoss("TYPE_GROUP encountered as packed");
+ break;
+ case WireFormatLite::TYPE_MESSAGE:
+ st = errors::DataLoss("TYPE_MESSAGE encountered as packed");
+ break;
+ case WireFormatLite::TYPE_BYTES:
+ st = errors::DataLoss("TYPE_BYTES encountered as packed");
+ break;
+ case WireFormatLite::TYPE_UINT32:
+ st = CountPackedVarint(buf, buf_size);
+ break;
+ case WireFormatLite::TYPE_ENUM:
+ st = CountPackedVarint(buf, buf_size);
+ break;
+ case WireFormatLite::TYPE_SFIXED32:
+ st = CountPackedFixed<int32>(buf, buf_size);
+ break;
+ case WireFormatLite::TYPE_SFIXED64:
+ st = CountPackedFixed<int64>(buf, buf_size);
+ break;
+ case WireFormatLite::TYPE_SINT32:
+ st = CountPackedVarint(buf, buf_size);
+ break;
+ case WireFormatLite::TYPE_SINT64:
+ st = CountPackedVarint(buf, buf_size);
+ break;
+ // default: intentionally omitted in order to enable static checking.
+ }
+ if (!st.ok()) {
+ return st;
+ }
+
+ if (!field.is_repeated && *count_ptr_ > 1) {
+ *count_ptr_ = 1;
+ }
+ return Status::OK();
+ }
+
+ private:
+ // Skips a length-delimited value.
+ static bool SkipBytes(CodedInputStream* input) {
+ uint32 length;
+ if (!input->ReadVarint32(&length)) {
+ return false;
+ }
+ return input->Skip(length);
+ }
+
+ // Counts the number of packed varints in an array.
+ // The end of a varint is signaled by a value < 0x80,
+ // so counting them requires parsing the bytestream.
+ // It is the caller's responsibility to ensure that len > 0.
+ Status CountPackedVarint(const uint8* buf, size_t len) {
+ const uint8* bound = buf + len;
+ int count;
+
+ // The last byte in a valid encoded varint is guaranteed to have
+ // the high bit unset. We rely on this property to prevent
+ // ReadVarint64FromArray from going out of bounds, so validate
+ // the end of the buf before scanning anything.
+ if (bound[-1] & 0x80) {
+ return errors::DataLoss("Corrupt packed varint");
+ }
+
+ // Now we can trust ReadVarint64FromArray to stay in bounds.
+ for (count = 0; buf < bound; ++count) {
+ uint64 temp;
+ bool ok;
+ buf = internal::ReadVarint64FromArray(buf, &ok, &temp);
+ if (!ok) {
+ return errors::DataLoss("Corrupt packed varint");
+ }
+ }
+
+ *count_ptr_ += count;
+ return Status::OK();
+ }
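// Worked example (not from the patch) of why checking only the final byte is
// enough before scanning: in the packed buffer {0x96, 0x01, 0x2A}, the bytes
// 0x96 0x01 decode to the varint 150 (the high bit of 0x96 marks continuation)
// and 0x2A decodes to 42, so CountPackedVarint reports count == 2. Because the
// last byte 0x2A has its high bit clear, ReadVarint64FromArray can never run
// past `bound`. A truncated buffer such as {0x96} ends in a byte with the high
// bit set and is rejected up front as "Corrupt packed varint".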
+
+ // Counts the number of fixed-size values in a packed field.
+ // This can be done without actually parsing anything.
+ template <typename T>
+ Status CountPackedFixed(const uint8* unused_buf, size_t len) {
+ int count = len / sizeof(T);
+ if (count * sizeof(T) != len) {
+ return errors::DataLoss(
+ "Illegal data length for packed fixed-size type: ", len);
+ }
+ *count_ptr_ += len / sizeof(T);
+ return Status::OK();
+ }
+
+ // Skips a single value in the input stream.
+ // Dispatches to the appropriately typed field skipper based on the
+ // schema type tag.
+ // This is not as permissive as just handling the wire type.
+ static bool SkipValue(CodedInputStream* input, const FieldInfo& field) {
+ uint32 tmp32;
+ protobuf_uint64 tmp64;
+ switch (field.type) {
+ case WireFormatLite::TYPE_DOUBLE:
+ return input->ReadLittleEndian64(&tmp64);
+ case WireFormatLite::TYPE_FLOAT:
+ return input->ReadLittleEndian32(&tmp32);
+ case WireFormatLite::TYPE_INT64:
+ return input->ReadVarint64(&tmp64);
+ case WireFormatLite::TYPE_UINT64:
+ return input->ReadVarint64(&tmp64);
+ case WireFormatLite::TYPE_INT32:
+ return input->ReadVarint32(&tmp32);
+ case WireFormatLite::TYPE_FIXED64:
+ return input->ReadLittleEndian64(&tmp64);
+ case WireFormatLite::TYPE_FIXED32:
+ return input->ReadLittleEndian32(&tmp32);
+ case WireFormatLite::TYPE_BOOL:
+ return input->ReadVarint32(&tmp32);
+ case WireFormatLite::TYPE_STRING:
+ return SkipBytes(input);
+ case WireFormatLite::TYPE_GROUP:
+ return WireFormatLite::SkipField(
+ input, WireFormatLite::MakeTag(
+ field.number, WireFormatLite::WIRETYPE_START_GROUP));
+ case WireFormatLite::TYPE_MESSAGE:
+ return SkipBytes(input);
+ case WireFormatLite::TYPE_BYTES:
+ return SkipBytes(input);
+ case WireFormatLite::TYPE_UINT32:
+ return input->ReadVarint32(&tmp32);
+ case WireFormatLite::TYPE_ENUM:
+ return input->ReadVarint32(&tmp32);
+ case WireFormatLite::TYPE_SFIXED32:
+ return input->ReadLittleEndian32(&tmp32);
+ case WireFormatLite::TYPE_SFIXED64:
+ return input->ReadLittleEndian64(&tmp64);
+ case WireFormatLite::TYPE_SINT32:
+ return input->ReadVarint32(&tmp32);
+ case WireFormatLite::TYPE_SINT64:
+ return input->ReadVarint64(&tmp64);
+ // default: intentionally omitted in order to enable static checking.
+ }
+ }
+
+ int32* count_ptr_ = nullptr;
+};
+
+// A DenseCollector accumulates values from a proto into a tensor.
+//
+// There is an instance of DenseCollector for each field of each
+// proto. The DenseCollector deserializes the value from the wire
+// directly into the preallocated output Tensor.
+//
+// This class is named DenseCollector because in the future there should
+// be a SparseCollector that accumulates field data into sparse tensors if
+// the user requests it.
+class DenseCollector {
+ public:
+ // Default constructor allows the collector to be a vector element.
+ DenseCollector() = default;
+
+ // A DenseCollector applies to one field of a serialized message.
+ DenseCollector(uint8* datap, DataType dtype, int max_repeat_count)
+ : datap_(datap), dtype_(dtype), max_repeat_count_(max_repeat_count) {}
+
+ // Reads a value from the input stream and stores it.
+ //
+ // Always inlining gave a ~50% speedup on microbenchmarks at one point.
+ // TODO(nix): try removing it to see if that still holds.
+ // TODO(jsimsa): ABSL_ATTRIBUTE_ALWAYS_INLINE
+ Status ReadValue(CodedInputStream* input, const FieldInfo& field) {
+ // For required and optional fields, we overwrite values[0] with
+ // the latest one in the wire stream.
+ // See https://developers.google.com/protocol-buffers/docs/encoding#optional
+ // Only for repeated fields do we advance the next_repeat_index_ past 1.
+ // TODO(nix): to handle oneof we must also zero out any previous values
+ // seen on the wire.
+ int32 index = 0;
+ if (field.is_repeated) {
+ index = next_repeat_index_;
+ }
+ next_repeat_index_ = index + 1;
+
+ return internal::ReadValue(input, field.type, field.number, dtype_, index,
+ datap_);
+ }
+
+ // Reads and stores a length-delimited list of values.
+ Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field,
+ const size_t buf_size) {
+ const void* buf;
+ int unused_max_buf_size;
+ input->GetDirectBufferPointerInline(&buf, &unused_max_buf_size);
+ // This is safe because the underlying storage for the CodedInputStream is
+ // owned by the input tensor. If it were a Cord or file-backed stream this
+ // pointer would go stale after the bytes were skipped.
+ if (!input->Skip(buf_size)) {
+ return errors::DataLoss(
+ "ReadPackedValues: Skipping packed field failed. Field tag: ",
+ field.number);
+ }
+
+ // Setting stride=0 causes new values to overwrite old ones for
+ // non-repeated fields.
+ const int stride = field.is_repeated ? 1 : 0;
+
+ if (next_repeat_index_ >= max_repeat_count_) {
+ return errors::DataLoss(
+ "ReadPackedValues: Tried to write more entries than allowed. "
+ "Field tag: ",
+ field.number, ", Max entries allowed: ", max_repeat_count_);
+ } else {
+ return internal::ReadPackedFromArray(buf, buf_size, field.type,
+ field.number, dtype_, stride,
+ &next_repeat_index_, datap_);
+ }
+ }
+
+ // Fills in any missing values in the output array with defaults.
+ // Dispatches to the appropriately typed field default based on the
+ // runtime type tag.
+ Status FillWithDefaults() {
+ switch (dtype_) {
+ case DataType::DT_FLOAT:
+ return FillDefault<float>();
+ case DataType::DT_DOUBLE:
+ return FillDefault<double>();
+ case DataType::DT_INT32:
+ return FillDefault<int32>();
+ case DataType::DT_UINT8:
+ return FillDefault<uint8>();
+ case DataType::DT_INT8:
+ return FillDefault<int8>();
+ case DataType::DT_STRING:
+ return FillDefault<string>();
+ case DataType::DT_INT64:
+ return FillDefault<int64>();
+ case DataType::DT_BOOL:
+ return FillDefault<bool>();
+ default:
+ // There are many tensorflow dtypes not handled here, but they
+ // should not come up unless type casting is added to the Op.
+ // Chaining with tf.cast() should do the right thing until then.
+ return errors::DataLoss(
+ "Failed filling defaults in unknown tf::DataType");
+ }
+ }
+
+ private:
+ // Fills empty values in the dense representation with a
+ // default value. This uses next_repeat_index_ which counts the number
+ // of parsed values for the field.
+ template <class T>
+ Status FillDefault() {
+ for (int i = next_repeat_index_; i < max_repeat_count_; i++) {
+ reinterpret_cast<T*>(datap_)[i] = T();
+ }
+ return Status::OK();
+ }
+
+ int32 next_repeat_index_ = 0;
+
+ // This is a pointer to data_[message_index_].
+ // There is no bounds checking at this level: we computed the max
+ // repeat size for each field in CountCollector and use the same
+ // code to traverse it here, so we are guaranteed not to be called
+  // for more items than we have allocated space for.
+ void* const datap_ = nullptr;
+
+ const DataType dtype_ = DataType::DT_INVALID;
+ const int max_repeat_count_ = 0;
+};
+
+class DecodeProtoOp : public OpKernel {
+ public:
+ explicit DecodeProtoOp(OpKernelConstruction* context) : OpKernel(context) {
+ string descriptor_source;
+ OP_REQUIRES_OK(context,
+ context->GetAttr("descriptor_source", &descriptor_source));
+
+ // We always get back a desc_pool, but we may not own it. If we own it,
+ // owned_desc_pool_ will be filled in.
+ DescriptorPool const* desc_pool;
+ OP_REQUIRES_OK(context, GetDescriptorPool(context->env(), descriptor_source,
+ &desc_pool, &owned_desc_pool_));
+
+ string message_type;
+ OP_REQUIRES_OK(context, context->GetAttr("message_type", &message_type));
+
+ const Descriptor* message_desc =
+ desc_pool->FindMessageTypeByName(message_type);
+ OP_REQUIRES(context, message_desc != nullptr,
+ errors::InvalidArgument("No descriptor found for message type ",
+ message_type));
+
+ std::vector<string> field_names;
+ OP_REQUIRES_OK(context, context->GetAttr("field_names", &field_names));
+ std::vector<DataType> output_types;
+ OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_types));
+ OP_REQUIRES(
+ context, field_names.size() == output_types.size(),
+ errors::InvalidArgument("field_names and output_types attributes must "
+ "have the same length"));
+
+ // Gather the field descriptors and check that requested output types match.
+
+ int field_index = 0;
+ std::vector<const FieldDescriptor*> field_descs;
+ for (const string& name : field_names) {
+ auto fd = message_desc->FindFieldByName(name);
+ OP_REQUIRES(context, fd != nullptr,
+ errors::InvalidArgument("Unknown field: ", name,
+ " in message type ", message_type));
+ OP_REQUIRES(context,
+ CheckOutputType(fd->type(), output_types[field_index]),
+ // Many TensorFlow types don't have corresponding proto types
+ // and the user will get an error if they are requested. It
+ // would be nice to allow conversions here, but tf.cast
+ // already exists so we don't duplicate the functionality.
+ // Known unhandled types:
+ // DT_INT16 DT_COMPLEX64 DT_QINT8 DT_QUINT8 DT_QINT32
+ // DT_BFLOAT16 DT_QINT16 DT_QUINT16 DT_UINT16
+ errors::InvalidArgument("Unexpected output type for ",
+ fd->full_name(), ": ", fd->cpp_type(),
+ " to ", output_types[field_index]));
+
+ field_index++;
+ field_descs.push_back(fd);
+ }
+
+ // Internally we want the field_descs sorted by their number on the wire.
+ // But the output tensors are allocated in the order given by the caller.
+ // Build a mapping i->j, where field_descs[i] corresponds to outputs[j].
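+    // For example, with field_names {"b", "a"} where "a" has field number 1
+    // and "b" has field number 2, fields_ ends up holding {a, b} in wire
+    // order, with output_index 1 and 0 respectively, so each field still
+    // feeds the output tensor position the caller asked for.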
+ std::vector<int> output_indices;
+ output_indices.reserve(field_names.size());
+ for (int i = 0; i < field_names.size(); i++) {
+ output_indices.push_back(i);
+ }
+ std::sort(output_indices.begin(), output_indices.end(),
+ [field_descs](int a, int b) {
+ return field_descs[a]->number() < field_descs[b]->number();
+ });
+
+ // Now store the fields in sorted order.
+ for (int i = 0; i < field_names.size(); i++) {
+ fields_.push_back(MakeUnique<FieldInfo>(field_descs[output_indices[i]],
+ output_indices[i]));
+ }
+
+ message_prototype_ = message_factory_.GetPrototype(message_desc);
+ OP_REQUIRES(context, message_prototype_ != nullptr,
+ errors::InvalidArgument("Couldn't get prototype message: ",
+ message_desc->full_name()));
+ string format;
+ OP_REQUIRES_OK(context, context->GetAttr("message_format", &format));
+ OP_REQUIRES(
+ context, format == "binary" || format == "text",
+ errors::InvalidArgument("format must be one of binary or text"));
+ is_binary_ = format == "binary";
+
+    // Read the 'sanitize' attribute, which enables the initial protobuf
+    // sanitizer. Sanitizing is much more expensive than the fast decoder.
+    // TODO(nix): Remove this attribute once the fast decoder
+    // has passed security review.
+ OP_REQUIRES_OK(context, context->GetAttr("sanitize", &sanitize_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& buf_tensor = ctx->input(0);
+ int message_count = buf_tensor.NumElements();
+ OP_REQUIRES(ctx, message_count >= 1,
+ errors::InvalidArgument(
+ "Bufs argument must contain at least one value"));
+
+ int field_count = fields_.size();
+
+ // Save the argument shape for later, then flatten the input
+ // Tensor since we are working componentwise. We will restore
+ // the same shape in the returned Tensor.
+ const TensorShape& shape_prefix = buf_tensor.shape();
+
+ TensorShape sizes_shape = shape_prefix;
+ sizes_shape.AddDim(field_count);
+ Tensor* sizes_tensor = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, sizes_shape, &sizes_tensor));
+
+    // This holds the converted or sanitized binary bufs, if any are
+    // allocated below. It serves only to define memory ownership.
+ std::vector<string> tmp_binary_bufs(message_count);
+
+ // These are the actual buffers to use, which may be in tmp_binary_bufs
+ // or may be pointers into the buf_tensor. Either way they are not owned
+ // here.
+ std::vector<const string*> bufs;
+
+ if (is_binary_ && !sanitize_) {
+ // Fast path.
+ for (int mi = 0; mi < message_count; ++mi) {
+ const string* buf = &buf_tensor.flat<string>()(mi);
+ bufs.push_back(buf);
+ }
+ } else {
+ // We will have to allocate a copy, either to convert from text to
+ // binary or to sanitize a binary proto.
+ for (int mi = 0; mi < message_count; ++mi) {
+ ReserializeMessage(ctx, buf_tensor.flat<string>()(mi),
+ &tmp_binary_bufs[mi]);
+ if (!ctx->status().ok()) {
+ return;
+ }
+ bufs.push_back(&tmp_binary_bufs[mi]);
+ }
+ }
+
+ // Walk through all the strings in the input tensor, counting
+ // the number of fields in each.
+ // We can't allocate our actual output Tensor until we know the
+ // maximum repeat count, so we do a first pass through the serialized
+ // proto just counting fields.
+ // We always allocate at least one value so that optional fields
+ // are populated with default values - this avoids a TF
+ // conditional when handling the output data.
+ // The caller can distinguish between real data and defaults
+ // using the repeat count matrix that is returned by decode_proto.
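+    // For example, for bufs of shape [batch] and a float field whose max
+    // repeat count across the batch is 5, the corresponding output is
+    // allocated as [batch, 5]; messages with fewer values are padded with
+    // defaults and the sizes output records the true counts.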
+ std::vector<int32> max_sizes(field_count, 1);
+ for (int mi = 0; mi < message_count; ++mi) {
+ CountFields(ctx, mi, *bufs[mi], sizes_tensor, &max_sizes);
+ if (!ctx->status().ok()) {
+ return;
+ }
+ }
+
+ // Allocate the output tensors now that we've seen the max size.
+ // TODO(nix): Use allocate_output_or_forward_input for the largest
+ // output tensor. This can avoid one large allocation by re-using
+ // the memory of the input tensor.
+ std::vector<Tensor*> outputs(field_count);
+ for (int fi = 0; fi < field_count; ++fi) {
+ TensorShape flat_shape = {static_cast<int64>(message_count),
+ max_sizes[fi]};
+ TensorShape out_shape = shape_prefix;
+ out_shape.AddDim(max_sizes[fi]);
+
+ // Surprisingly we don't specify the types from the output_types
+ // attribute: that is done for us based on the Op declaration:
+ // REGISTER_OP(...)
+ // .Attr("output_types: list(type) >= 0")
+ // .Output("values: output_types")
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(fields_[fi]->output_index + 1,
+ out_shape, &outputs[fi]));
+ }
+
+ // Make the second pass through the serialized proto, decoding
+ // into preallocated tensors.
+ AccumulateFields(ctx, bufs, outputs);
+ }
+
+ private:
+ // Copy a serialized message to binary, e.g. to handle text proto inputs.
+ void ReserializeMessage(OpKernelContext* ctx, const string& buf,
+ string* binary_buf) {
+ // Handle text protos by translating them to binary.
+ std::unique_ptr<Message> message(message_prototype_->New());
+ OP_REQUIRES(ctx, message, errors::DataLoss("Initializing message failed"));
+
+ if (is_binary_) {
+ // If we get here we are sanitizing the input protobuf by parsing
+ // and reserializing it with a trusted (but very slow) library.
+ OP_REQUIRES(ctx, message->ParseFromString(buf),
+ errors::DataLoss("Unable to parse binary protobuf"));
+ } else {
+ OP_REQUIRES(ctx, TextFormat::ParseFromString(buf, message.get()),
+ errors::DataLoss("Unable to parse text protobuf"));
+ }
+
+ OP_REQUIRES(ctx, message->SerializeToString(binary_buf),
+ errors::DataLoss("Unable to reserialize text proto as binary"));
+ }
+
+  // Counts the number of occurrences of each requested field in one
+  // message of the batch.
+ void CountFields(OpKernelContext* ctx, int message_index, const string& buf,
+ Tensor* sizes_tensor, std::vector<int32>* max_sizes) {
+ int field_count = fields_.size();
+
+ CodedInputStream input(reinterpret_cast<const uint8*>(buf.c_str()),
+ buf.size());
+
+ std::vector<int32> field_sizes(field_count, 0);
+ std::vector<CountCollector> counters;
+ counters.reserve(field_count);
+ for (int i = 0; i < field_count; i++) {
+ counters.emplace_back(&field_sizes[i]);
+ }
+
+ Status st = Collect(&input, &counters);
+ if (st.ok() && !input.ConsumedEntireMessage()) {
+ st = errors::DataLoss("CountFields: Failed to consume entire buffer");
+ }
+ if (kFailOnDecodeError) {
+ OP_REQUIRES_OK(ctx, st); // NOLINT
+ }
+ if (!st.ok()) {
+ // This code suppresses the corrupt proto, treating it as empty
+ // to avoid crashing the process.
+ LOG(WARNING) << "Proto counting error for message type " << message_type_
+ << ": " << st;
+
+ for (int fi = 0; fi < field_count; fi++) {
+ field_sizes[fi] = 0;
+ }
+ // Finished decoding this message.
+ return;
+ }
+
+ // Update the size tensor and max repeat size for each field.
+ auto sizes = sizes_tensor->flat_inner_dims<int32>();
+ for (int fi = 0; fi < field_count; fi++) {
+ int32 size = field_sizes[fi];
+ sizes(message_index, fields_[fi]->output_index) = size;
+ if ((*max_sizes)[fi] < size) {
+ (*max_sizes)[fi] = size;
+ }
+ }
+ }
+
+ // Parse fields from a serialized message into preallocated tensors.
+ void AccumulateFields(OpKernelContext* ctx,
+ const std::vector<const string*>& bufs,
+ std::vector<Tensor*> outputs) {
+ struct TensorInfo {
+ explicit TensorInfo(Tensor* tensor) {
+ // Note that we can decode only max_repeat_count values before overflow.
+ // No other bounds checking is done for repeated fields. For
+ // optional fields there is a check to make sure that only the last
+ // value on the wire appears in the output tensor.
+ dtype = tensor->dtype();
+ last_dim_size = tensor->dim_size(tensor->dims() - 1);
+
+ if (dtype != DT_STRING) {
+ const int element_size = DataTypeSize(dtype);
+ CHECK_GT(element_size, 0);
+ stride = last_dim_size * element_size;
+
+ const int64 flatshape[1] = {tensor->NumElements() * element_size};
+ data = tensor->bit_casted_shaped<uint8, 1>(flatshape).data();
+ } else {
+ // DataTypeSize() returns 0 for string types.
+ stride = last_dim_size * sizeof(string);
+ data = reinterpret_cast<uint8*>(tensor->flat<string>().data());
+ }
+ }
+
+ DataType dtype;
+ int last_dim_size;
+ int stride;
+ uint8* data;
+ };
+
+ int field_count = fields_.size();
+
+ std::vector<TensorInfo> tensors;
+ tensors.reserve(field_count);
+ for (int fi = 0; fi < field_count; fi++) {
+ tensors.emplace_back(outputs[fi]);
+ }
+
+ for (int message_index = 0; message_index < bufs.size(); ++message_index) {
+ const string& buf = *bufs[message_index];
+
+ std::vector<DenseCollector> collectors;
+ collectors.reserve(field_count);
+ for (const TensorInfo& info : tensors) {
+ collectors.emplace_back(info.data + message_index * info.stride,
+ info.dtype, info.last_dim_size);
+ }
+
+ // Fill in output tensors from the wire.
+ CodedInputStream input(reinterpret_cast<const uint8*>(buf.c_str()),
+ buf.size());
+ Status st = Collect(&input, &collectors);
+ if (st.ok() && !input.ConsumedEntireMessage()) {
+ st = errors::DataLoss(
+ "AccumulateFields: Failed to consume entire buffer");
+ }
+ if (kFailOnDecodeError) {
+ OP_REQUIRES_OK(ctx, st); // NOLINT
+ }
+ if (!st.ok()) {
+ // This code suppresses the corrupt proto, treating it as empty
+ // to avoid crashing training.
+        LOG(WARNING) << "Proto parsing error for message type "
+ << message_type_ << ": " << st;
+ }
+
+ // Fill the remainder of the dense outputs with default values.
+ for (auto& collector : collectors) {
+ OP_REQUIRES_OK(ctx, collector.FillWithDefaults());
+ }
+ }
+ }
+
+ // Look up the FieldDescriptor for a particular field number.
+ bool LookupField(int field_number, int* field_index) {
+ // Look up the FieldDescriptor using linear search.
+ // TODO(nix): this could be sped up with binary search, but we are
+ // already way off the fastpath at this point. If you see a hotspot
+ // here, somebody is sending you very inefficient protos.
+ for (int fi = fields_.size() - 1; fi >= 0; fi--) {
+ if (field_number == fields_[fi]->number) {
+ *field_index = fi;
+ return true;
+ }
+ }
+ return false;
+ }
+
+ // Traverses a serialized protobuf, dispatching values to the collectors.
+ template <class CollectorClass>
+ Status Collect(CodedInputStream* input,
+ std::vector<CollectorClass>* collectors) {
+ int last_good_field_index = -1;
+ bool fields_disordered = false;
+ int prev_field_number = -1;
+ int field_number = -1;
+ int last_good_field_number = -1;
+ int next_good_field_number = fields_[0]->number;
+
+ // The 'tag' variable should always be treated as tainted.
+ for (uint32 tag = input->ReadTag();
+ tag != 0 && WireFormatLite::GetTagWireType(tag) !=
+ WireFormatLite::WIRETYPE_END_GROUP;
+ tag = input->ReadTag(), prev_field_number = field_number) {
+ field_number = WireFormatLite::GetTagFieldNumber(tag);
+ const FieldInfo* field = nullptr;
+
+ // This takes advantage of the sorted field numbers in most serialized
+ // protos: it tries the next expected field first rather than doing
+ // a lookup by field number.
+ // TODO(nix): haberman@ suggests a hybrid approach with a lookup table
+ // for small field numbers and a hash table for larger ones. This would
+ // be a simpler approach that should offer comparable speed in most
+ // cases.
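+      // For example, with requested field numbers {2, 7}, a message with
+      // tags for fields 1, 2, 2, 5, 7 stays on this fast path (1 and 5 are
+      // skipped as unrequested); if field 2 then reappears after 7,
+      // fields_disordered is set and we fall back to LookupField.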
+ if (field_number == last_good_field_number) {
+ field = fields_[last_good_field_index].get();
+ } else {
+ if (field_number < prev_field_number) {
+ fields_disordered = true;
+ }
+
+ // If fields are out of order, fall back to slow lookup.
+ if (fields_disordered) {
+ int field_index;
+ if (LookupField(field_number, &field_index)) {
+ field = fields_[field_index].get();
+ last_good_field_index = field_index;
+ }
+ } else {
+ // If we see a field that is past the next field we want,
+ // it was empty. Look for the one after that.
+ // Repeat until we run out of fields that we care about.
+ while (field_number >= next_good_field_number) {
+ if (field_number == next_good_field_number) {
+ last_good_field_number = field_number;
+ field = fields_[last_good_field_index + 1].get();
+ }
+
+ // Start looking for the field after the current one.
+ ++last_good_field_index;
+ if (last_good_field_index < fields_.size() - 1) {
+ next_good_field_number =
+ fields_[last_good_field_index + 1]->number;
+ } else {
+ // Saw something past the last field we care about.
+ // Continue parsing the message just in case there
+ // are disordered fields later, but any remaining
+ // ordered fields will have no effect.
+ next_good_field_number = INT_MAX;
+ }
+ }
+ }
+ }
+
+ if (!field) {
+ // Unknown and unrequested fields are skipped.
+ if (!WireFormatLite::SkipField(input, tag)) {
+ return errors::DataLoss("Failed skipping unrequested field");
+ }
+ continue;
+ }
+
+ Status st = CollectField(*field, WireFormatLite::GetTagWireType(tag),
+ input, &(*collectors)[last_good_field_index]);
+ if (!st.ok()) {
+ return st;
+ }
+ }
+ return Status::OK();
+ }
+
+ // Collects values for a single field.
+ template <class CollectorClass>
+ Status CollectField(const FieldInfo& field,
+ WireFormatLite::WireType wire_type,
+ CodedInputStream* input, CollectorClass* collector) {
+ // The wire format library defines the same constants used in
+ // descriptor.proto. This static_cast is safe because they
+ // are guaranteed to stay in sync.
+ // We need the field type from the FieldDescriptor here
+ // because the wire format doesn't tell us anything about
+ // what happens inside a packed repeated field: there is
+ // enough information in the wire format to skip the
+ // whole field but not enough to know how to parse what's
+ // inside. For that we go to the schema.
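+    // For example, a repeated int32 field has schema wire type VARINT; if
+    // the tag on the wire says LENGTH_DELIMITED instead, the values are
+    // packed and are scanned below rather than skipped.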
+ WireFormatLite::WireType schema_wire_type =
+ WireFormatLite::WireTypeForFieldType(field.type);
+
+ // Handle packed repeated fields. SkipField would skip the
+ // whole length-delimited blob without letting us count the
+ // values, so we have to scan them ourselves.
+ if (wire_type == WireFormatLite::WIRETYPE_LENGTH_DELIMITED &&
+ schema_wire_type != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) {
+ // Handle packed repeated primitives.
+ int length;
+ if (!input->ReadVarintSizeAsInt(&length)) {
+ return errors::DataLoss("CollectField: Failed reading packed size");
+ }
+ Status st = collector->ReadPackedValues(input, field, length);
+ if (!st.ok()) {
+ return st;
+ }
+ return Status::OK();
+ }
+
+ // Read ordinary values, including strings, bytes, and messages.
+ if (wire_type != schema_wire_type) {
+ if (!WireFormatLite::SkipField(
+ input, WireFormatLite::MakeTag(field.number, wire_type))) {
+ return errors::DataLoss(
+ "CollectField: Failed skipping malformed field");
+ }
+ return Status::OK();
+ }
+ return collector->ReadValue(input, field);
+ }
+
+ string message_type_;
+ // Note that fields are sorted by increasing field number,
+ // which is not in general the order given by the user-specified
+ // field_names and output_types Op attributes.
+ std::vector<std::unique_ptr<const FieldInfo>> fields_;
+
+  // owned_desc_pool_ is null when using descriptor_source=local.
+ std::unique_ptr<DescriptorPool> owned_desc_pool_;
+ DynamicMessageFactory message_factory_;
+ const Message* message_prototype_;
+
+ // True if decoding binary format, false if decoding text format.
+ bool is_binary_;
+
+ // True if the protos should be sanitized before parsing.
+ // Enables the initial protobuf sanitizer, which is much
+ // more expensive than the decoder. The flag defaults to true
+ // but can be set to false for trusted sources.
+ // TODO(nix): flip the default to false when the fast decoder
+ // has passed security review.
+ bool sanitize_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(DecodeProtoOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("DecodeProtoV2").Device(DEVICE_CPU),
+ DecodeProtoOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/dense_update_functor.cc b/tensorflow/core/kernels/dense_update_functor.cc
index a878fe9a97..3ed3794e01 100644
--- a/tensorflow/core/kernels/dense_update_functor.cc
+++ b/tensorflow/core/kernels/dense_update_functor.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
@@ -70,4 +71,59 @@ struct DenseUpdate<CPUDevice, string, ASSIGN> {
} // namespace functor
+#define CPU_DENSE_COPY(T) \
+ case DataTypeToEnum<T>::value: { \
+ functor::DenseUpdate<CPUDevice, T, ASSIGN> copy_functor_; \
+ copy_functor_(context->eigen_device<CPUDevice>(), tensor->flat<T>(), \
+ from.flat<T>()); \
+ break; \
+ }
+
+#define INSTANTIATE_GET_VARIANT_COPY_FN(DEVICE, TYPE_CALLER, TYPE_DENSE_COPY) \
+ template <> \
+ Status VariantCopyFn<DEVICE>(OpKernelContext * context, const Tensor& from, \
+ Tensor* to) { \
+ PersistentTensor tmp; \
+ Tensor* tensor; \
+ AllocatorAttributes attr; \
+ attr.set_gpu_compatible(true); \
+ attr.set_nic_compatible(true); \
+ TF_RETURN_IF_ERROR(context->allocate_persistent( \
+ from.dtype(), from.shape(), &tmp, &tensor, attr)); \
+ switch (from.dtype()) { \
+ TYPE_CALLER(TYPE_DENSE_COPY); \
+ default: \
+ return errors::InvalidArgument( \
+ "VariantCopyFn: Could not perform a deep copy of variant " \
+ "element of type: ", \
+ DataTypeString(from.dtype()), \
+ " using device: ", context->device()->name()); \
+ } \
+ *to = *tensor; \
+ return Status::OK(); \
+ }
+
+INSTANTIATE_GET_VARIANT_COPY_FN(CPUDevice, TF_CALL_ALL_TYPES, CPU_DENSE_COPY);
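+// The instantiation above defines VariantCopyFn<CPUDevice>: it allocates a
+// persistent tensor with the same dtype and shape as 'from', dispatches
+// through CPU_DENSE_COPY for every type covered by TF_CALL_ALL_TYPES, and
+// returns InvalidArgument for any other dtype.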
+
+#if GOOGLE_CUDA
+#define GPU_DENSE_COPY(T) \
+ case DataTypeToEnum<T>::value: { \
+ functor::DenseUpdate<GPUDevice, T, ASSIGN> copy_functor_; \
+ copy_functor_(context->eigen_device<GPUDevice>(), tensor->flat<T>(), \
+ from.flat<T>()); \
+ break; \
+ }
+#define TF_CALL_GPU_AND_ADDITIONAL_TYPES(T) \
+ TF_CALL_GPU_ALL_TYPES(T); \
+ TF_CALL_int32(T); \
+ TF_CALL_int64(T);
+INSTANTIATE_GET_VARIANT_COPY_FN(GPUDevice, TF_CALL_GPU_AND_ADDITIONAL_TYPES,
+ GPU_DENSE_COPY);
+#undef TF_CALL_GPU_AND_ADDITIONAL_TYPES
+#undef GPU_DENSE_COPY
+#endif // GOOGLE_CUDA
+
+#undef CPU_DENSE_COPY
+#undef INSTANTIATE_GET_VARIANT_COPY_FN
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/dense_update_functor.h b/tensorflow/core/kernels/dense_update_functor.h
index 4aefe26c54..240c13261e 100644
--- a/tensorflow/core/kernels/dense_update_functor.h
+++ b/tensorflow/core/kernels/dense_update_functor.h
@@ -19,11 +19,14 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
@@ -89,6 +92,17 @@ struct DenseUpdate<SYCLDevice, T, ASSIGN> {
#endif // TENSORFLOW_USE_SYCL
} // end namespace functor
+
+template <typename Device>
+Status VariantCopyFn(OpKernelContext* context, const Tensor& from, Tensor* to);
+
+template <>
+Status VariantCopyFn<CPUDevice>(OpKernelContext* context, const Tensor& from,
+ Tensor* to);
+template <>
+Status VariantCopyFn<GPUDevice>(OpKernelContext* context, const Tensor& from,
+ Tensor* to);
+
} // end namespace tensorflow
#endif // TENSORFLOW_KERNELS_DENSE_UPDATE_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/encode_proto_op.cc b/tensorflow/core/kernels/encode_proto_op.cc
new file mode 100644
index 0000000000..3b02ae52a2
--- /dev/null
+++ b/tensorflow/core/kernels/encode_proto_op.cc
@@ -0,0 +1,591 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// EncodeProto is a TensorFlow Op which serializes tensors into
+// arbitrary protobufs.
+//
+// See the docstring in ../ops/encode_proto_op.cc for usage of the op.
+//
+// This implementation writes the serialized format using a handful of
+// calls from the WireFormatLite API.
+
+#include <memory>
+#include <vector>
+
+#include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/util/proto/descriptors.h"
+
+namespace tensorflow {
+namespace {
+
+using ::tensorflow::protobuf::Descriptor;
+using ::tensorflow::protobuf::DescriptorPool;
+using ::tensorflow::protobuf::FieldDescriptor;
+using ::tensorflow::protobuf::internal::WireFormatLite;
+using ::tensorflow::protobuf::io::CodedOutputStream;
+using ::tensorflow::protobuf::io::StringOutputStream;
+
+// Computes the total serialized size for a packed repeated field.
+// For fixed-size types this can just multiply, but for variable-sized
+// types it has to iterate through the values in the tensor.
+template <WireFormatLite::FieldType FieldType, typename TensorT>
+size_t TotalPackedSize(const Tensor& input, int message_index, int size);
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_DOUBLE, double>(const Tensor& input,
+ int message_index,
+ int size) {
+ return size * WireFormatLite::kDoubleSize;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_FLOAT, double>(const Tensor& input,
+ int message_index,
+ int size) {
+ return size * WireFormatLite::kFloatSize;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_FLOAT, float>(const Tensor& input,
+ int message_index,
+ int size) {
+ return size * WireFormatLite::kFloatSize;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_INT64, int64>(const Tensor& input,
+ int message_index,
+ int size) {
+ size_t data_size = 0;
+ auto input_t = input.flat_inner_dims<int64>();
+ for (int64 i = 0; i < size; i++) {
+ data_size += WireFormatLite::Int64Size(
+ input_t(static_cast<int64>(message_index), i));
+ }
+ return data_size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_UINT64, int64>(const Tensor& input,
+ int message_index,
+ int size) {
+ size_t data_size = 0;
+ auto input_t = input.flat_inner_dims<int64>();
+ for (int64 i = 0; i < size; i++) {
+ data_size += WireFormatLite::UInt64Size(
+ input_t(static_cast<int64>(message_index), i));
+ }
+ return data_size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_INT32, int32>(const Tensor& input,
+ int message_index,
+ int size) {
+ size_t data_size = 0;
+ auto input_t = input.flat_inner_dims<int32>();
+ for (int64 i = 0; i < size; i++) {
+ data_size += WireFormatLite::Int32Size(
+ input_t(static_cast<int64>(message_index), i));
+ }
+ return data_size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_FIXED64, int64>(const Tensor& input,
+ int message_index,
+ int size) {
+ return size * WireFormatLite::kFixed64Size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, int64>(const Tensor& input,
+ int message_index,
+ int size) {
+ return size * WireFormatLite::kFixed32Size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, int32>(const Tensor& input,
+ int message_index,
+ int size) {
+ return size * WireFormatLite::kFixed32Size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_BOOL, bool>(const Tensor& input,
+ int message_index,
+ int size) {
+ return size * WireFormatLite::kBoolSize;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, int64>(const Tensor& input,
+ int message_index,
+ int size) {
+ size_t data_size = 0;
+ auto input_t = input.flat_inner_dims<int64>();
+ for (int64 i = 0; i < size; i++) {
+ data_size += WireFormatLite::UInt32Size(
+ input_t(static_cast<int64>(message_index), i));
+ }
+ return data_size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, int32>(const Tensor& input,
+ int message_index,
+ int size) {
+ size_t data_size = 0;
+ auto input_t = input.flat_inner_dims<int32>();
+ for (int64 i = 0; i < size; i++) {
+ data_size += WireFormatLite::UInt32Size(
+ input_t(static_cast<int64>(message_index), i));
+ }
+ return data_size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_ENUM, int32>(const Tensor& input,
+ int message_index,
+ int size) {
+ size_t data_size = 0;
+ auto input_t = input.flat_inner_dims<int32>();
+ for (int64 i = 0; i < size; i++) {
+ data_size +=
+ WireFormatLite::EnumSize(input_t(static_cast<int64>(message_index), i));
+ }
+ return data_size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED32, int32>(
+ const Tensor& input, int message_index, int size) {
+ return size * WireFormatLite::kSFixed32Size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED64, int64>(
+ const Tensor& input, int message_index, int size) {
+ return size * WireFormatLite::kSFixed64Size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_SINT32, int32>(const Tensor& input,
+ int message_index,
+ int size) {
+ size_t data_size = 0;
+ auto input_t = input.flat_inner_dims<int32>();
+ for (int64 i = 0; i < size; i++) {
+ data_size += WireFormatLite::SInt32Size(
+ input_t(static_cast<int64>(message_index), i));
+ }
+ return data_size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_SINT64, int64>(const Tensor& input,
+ int message_index,
+ int size) {
+ size_t data_size = 0;
+ auto input_t = input.flat_inner_dims<int64>();
+ for (int64 i = 0; i < size; i++) {
+ data_size += WireFormatLite::SInt64Size(
+ input_t(static_cast<int64>(message_index), i));
+ }
+ return data_size;
+}
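+// Note on the SINT sizes above: SInt32Size/SInt64Size account for protobuf
+// zigzag encoding, e.g. -1 zigzags to 1 and -64 to 127, so small negative
+// values still take a single byte on the wire.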
+
+// Writes a possibly repeated primitive field.
+// TensorFlow does not have unsigned types, so unsigned proto fields are
+// represented as signed tensors and cast back to unsigned when written.
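+// For a packed field this writes one tag and one length followed by the raw
+// values; for a non-packed field it writes a (tag, value) pair per element.
+// E.g. three packed fixed32 values cost 1 tag + 1 length byte + 12 data
+// bytes, versus 3 tags + 12 data bytes unpacked (assuming 1-byte tags).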
+template <typename TensorT, typename ProtoT,
+ WireFormatLite::FieldType FieldType,
+ void Writer(ProtoT, CodedOutputStream*)>
+void WriteField(const FieldDescriptor& field_desc, const Tensor& input,
+ int message_index, int size, CodedOutputStream* output) {
+ auto wire_type = WireFormatLite::WireTypeForFieldType(
+ WireFormatLite::FieldType(field_desc.type()));
+
+ auto input_t = input.flat_inner_dims<TensorT>();
+ if (field_desc.options().packed()) {
+ // Write the tag for the packed field.
+ WireFormatLite::WriteTag(field_desc.number(),
+ WireFormatLite::WIRETYPE_LENGTH_DELIMITED, output);
+
+ // Write the total packed length.
+ size_t data_size =
+ TotalPackedSize<FieldType, TensorT>(input, message_index, size);
+ output->WriteVarint32(data_size);
+
+ // Write individual values.
+ for (int64 i = 0; i < size; i++) {
+ // Note implicit cast from signed to unsigned.
+ const ProtoT& value = input_t(static_cast<int64>(message_index), i);
+ Writer(value, output);
+ }
+ } else {
+ for (int64 i = 0; i < size; i++) {
+ WireFormatLite::WriteTag(field_desc.number(), wire_type, output);
+
+ // Note implicit cast from signed to unsigned.
+ const ProtoT& value = input_t(static_cast<int64>(message_index), i);
+ Writer(value, output);
+ }
+ }
+}
+
+// Writes a possibly repeated string, bytes, or message field.
+template <typename T, void Writer(int, const T&, CodedOutputStream*)>
+void WriteVarLenField(const FieldDescriptor& field_desc, const Tensor& input,
+ int message_index, int size, CodedOutputStream* output) {
+ auto input_t = input.flat_inner_dims<T>();
+ for (int64 i = 0; i < size; i++) {
+ const T& value = input_t(static_cast<int64>(message_index), i);
+ // TODO(nix): there doesn't seem to be an inlined version of
+ // WireFormatLite::WriteString or its relatives, which might allow a
+ // small speedup.
+ Writer(field_desc.number(), value, output);
+ }
+}
+
+// Writes a group field.
+// Groups are treated like submessages, but tag-delimited
+// instead of length-delimited. WireFormatLite handles this
+// differently so we code it ourselves.
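+// For illustration, a group with field number 3 is framed as
+//   0x1b (START_GROUP tag) <raw group bytes> 0x1c (END_GROUP tag)
+// with no length prefix, which is why WriteRaw is used below instead of a
+// length-delimited write.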
+void WriteGroup(const FieldDescriptor& field_desc, const Tensor& input,
+ int message_index, int size, CodedOutputStream* output) {
+ auto input_t = input.flat_inner_dims<string>();
+ for (int64 i = 0; i < size; i++) {
+ const string& value = input_t(static_cast<int64>(message_index), i);
+ WireFormatLite::WriteTag(field_desc.number(),
+ WireFormatLite::WIRETYPE_START_GROUP, output);
+ // Note the use of WriteRaw instead of WriteString to skip the length.
+ output->WriteRaw(value.data(), value.size());
+ WireFormatLite::WriteTag(field_desc.number(),
+ WireFormatLite::WIRETYPE_END_GROUP, output);
+ }
+}
+
+// Writes a (possibly repeated) field into an output stream.
+// It is the caller's responsibility to ensure that the type of
+// the input tensor is compatible with the type of the proto
+// field descriptor, and that (message_index, size-1) is within
+// bounds.
+void WriteField(const FieldDescriptor& field_desc, const Tensor& input,
+ int message_index, int size, CodedOutputStream* output) {
+ DataType tf_type = input.dtype();
+
+ switch (field_desc.type()) {
+ case WireFormatLite::TYPE_DOUBLE:
+ return WriteField<double, double, WireFormatLite::TYPE_DOUBLE,
+ WireFormatLite::WriteDoubleNoTag>(
+ field_desc, input, message_index, size, output);
+ case WireFormatLite::TYPE_FLOAT:
+ switch (tf_type) {
+ case DataType::DT_FLOAT:
+ return WriteField<float, float, WireFormatLite::TYPE_FLOAT,
+ WireFormatLite::WriteFloatNoTag>(
+ field_desc, input, message_index, size, output);
+ case DataType::DT_DOUBLE:
+ return WriteField<double, float, WireFormatLite::TYPE_FLOAT,
+ WireFormatLite::WriteFloatNoTag>(
+ field_desc, input, message_index, size, output);
+ default:
+ return;
+ }
+ case WireFormatLite::TYPE_INT64:
+ return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_INT64,
+ WireFormatLite::WriteInt64NoTag>(
+ field_desc, input, message_index, size, output);
+ case WireFormatLite::TYPE_UINT64:
+ return WriteField<int64, protobuf_uint64, WireFormatLite::TYPE_UINT64,
+ WireFormatLite::WriteUInt64NoTag>(
+ field_desc, input, message_index, size, output);
+ case WireFormatLite::TYPE_INT32:
+ return WriteField<int32, int32, WireFormatLite::TYPE_INT32,
+ WireFormatLite::WriteInt32NoTag>(
+ field_desc, input, message_index, size, output);
+ case WireFormatLite::TYPE_FIXED64:
+ return WriteField<int64, protobuf_uint64, WireFormatLite::TYPE_FIXED64,
+ WireFormatLite::WriteFixed64NoTag>(
+ field_desc, input, message_index, size, output);
+ case WireFormatLite::TYPE_FIXED32:
+ switch (tf_type) {
+ case DataType::DT_INT64:
+ return WriteField<int64, uint32, WireFormatLite::TYPE_FIXED32,
+ WireFormatLite::WriteFixed32NoTag>(
+ field_desc, input, message_index, size, output);
+ case DataType::DT_INT32:
+ return WriteField<int32, uint32, WireFormatLite::TYPE_FIXED32,
+ WireFormatLite::WriteFixed32NoTag>(
+ field_desc, input, message_index, size, output);
+ default:
+ return;
+ }
+ case WireFormatLite::TYPE_BOOL:
+ return WriteField<bool, bool, WireFormatLite::TYPE_BOOL,
+ WireFormatLite::WriteBoolNoTag>(
+ field_desc, input, message_index, size, output);
+ case WireFormatLite::TYPE_STRING:
+ return WriteVarLenField<string, WireFormatLite::WriteString>(
+ field_desc, input, message_index, size, output);
+ case WireFormatLite::TYPE_GROUP:
+ return WriteGroup(field_desc, input, message_index, size, output);
+ case WireFormatLite::TYPE_MESSAGE:
+ return WriteVarLenField<string, WireFormatLite::WriteBytes>(
+ field_desc, input, message_index, size, output);
+ case WireFormatLite::TYPE_BYTES:
+ return WriteVarLenField<string, WireFormatLite::WriteBytes>(
+ field_desc, input, message_index, size, output);
+ case WireFormatLite::TYPE_UINT32:
+ switch (tf_type) {
+ case DataType::DT_INT64:
+ return WriteField<int64, uint32, WireFormatLite::TYPE_UINT32,
+ WireFormatLite::WriteUInt32NoTag>(
+ field_desc, input, message_index, size, output);
+ case DataType::DT_INT32:
+ return WriteField<int32, uint32, WireFormatLite::TYPE_UINT32,
+ WireFormatLite::WriteUInt32NoTag>(
+ field_desc, input, message_index, size, output);
+ default:
+ return;
+ }
+ case WireFormatLite::TYPE_ENUM:
+ return WriteField<int32, int32, WireFormatLite::TYPE_ENUM,
+ WireFormatLite::WriteEnumNoTag>(
+ field_desc, input, message_index, size, output);
+ case WireFormatLite::TYPE_SFIXED32:
+ return WriteField<int32, int32, WireFormatLite::TYPE_SFIXED32,
+ WireFormatLite::WriteSFixed32NoTag>(
+ field_desc, input, message_index, size, output);
+ case WireFormatLite::TYPE_SFIXED64:
+ return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_SFIXED64,
+ WireFormatLite::WriteSFixed64NoTag>(
+ field_desc, input, message_index, size, output);
+ case WireFormatLite::TYPE_SINT32:
+ return WriteField<int32, int32, WireFormatLite::TYPE_SINT32,
+ WireFormatLite::WriteSInt32NoTag>(
+ field_desc, input, message_index, size, output);
+ case WireFormatLite::TYPE_SINT64:
+ return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_SINT64,
+ WireFormatLite::WriteSInt64NoTag>(
+ field_desc, input, message_index, size, output);
+ // default: intentionally omitted in order to enable static checking.
+ }
+}
+
+// Checks that a Protobuf field is compatible with a TensorFlow datatype.
+// This is separated from WriteField to lift it out of the inner loop.
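+// For example, uint32 and fixed32 fields accept either DT_INT32 or DT_INT64
+// tensors, while the 64-bit integer types require DT_INT64, mirroring the
+// WriteField dispatch above.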
+bool IsCompatibleType(const FieldDescriptor& field_desc, DataType tf_type) {
+ switch (field_desc.type()) {
+ case WireFormatLite::TYPE_DOUBLE:
+ return tf_type == DataType::DT_DOUBLE;
+ case WireFormatLite::TYPE_FLOAT:
+ return tf_type == DataType::DT_FLOAT || tf_type == DataType::DT_DOUBLE;
+ case WireFormatLite::TYPE_INT64:
+ case WireFormatLite::TYPE_SFIXED64:
+ case WireFormatLite::TYPE_SINT64:
+ return tf_type == DataType::DT_INT64;
+ case WireFormatLite::TYPE_UINT64:
+ return tf_type == DataType::DT_INT64;
+ case WireFormatLite::TYPE_INT32:
+ case WireFormatLite::TYPE_ENUM:
+ case WireFormatLite::TYPE_SFIXED32:
+ case WireFormatLite::TYPE_SINT32:
+ return tf_type == DataType::DT_INT32;
+ case WireFormatLite::TYPE_FIXED64:
+ return tf_type == DataType::DT_INT64;
+ case WireFormatLite::TYPE_FIXED32:
+ case WireFormatLite::TYPE_UINT32:
+ return tf_type == DataType::DT_INT64 || tf_type == DataType::DT_INT32;
+ case WireFormatLite::TYPE_BOOL:
+ return tf_type == DataType::DT_BOOL;
+ case WireFormatLite::TYPE_STRING:
+ case WireFormatLite::TYPE_GROUP:
+ case WireFormatLite::TYPE_MESSAGE:
+ case WireFormatLite::TYPE_BYTES:
+ return tf_type == DataType::DT_STRING;
+ // default: intentionally omitted in order to enable static checking.
+ }
+ return false;
+}
+
+class EncodeProtoOp : public OpKernel {
+ public:
+ explicit EncodeProtoOp(OpKernelConstruction* context) : OpKernel(context) {
+ string descriptor_source;
+ OP_REQUIRES_OK(context,
+ context->GetAttr("descriptor_source", &descriptor_source));
+ // We always get back a desc_pool, but we may not own it. If we own it,
+ // owned_desc_pool_ will be filled in.
+ DescriptorPool const* desc_pool;
+ OP_REQUIRES_OK(context, GetDescriptorPool(context->env(), descriptor_source,
+ &desc_pool, &owned_desc_pool_));
+
+ string message_type;
+ OP_REQUIRES_OK(context, context->GetAttr("message_type", &message_type));
+ const Descriptor* message_desc =
+ desc_pool->FindMessageTypeByName(message_type);
+ OP_REQUIRES(context, message_desc != nullptr,
+ errors::InvalidArgument("No descriptor found for message type ",
+ message_type));
+
+ OP_REQUIRES_OK(context, context->GetAttr("field_names", &field_names_));
+
+ // Gather the field descriptors for the given field_names.
+ field_descs_.resize(field_names_.size());
+ for (int i = 0; i < field_names_.size(); i++) {
+ const string& name = field_names_[i];
+ auto field_desc = message_desc->FindFieldByName(name);
+ OP_REQUIRES(context, field_desc != nullptr,
+ errors::InvalidArgument("Unknown field: ", name,
+ " in message type ", message_type));
+
+ field_descs_[i] = field_desc;
+ }
+
+ // Build a list of indices into field_descs sorted by increasing
+ // field_number. This will be used to output fields in sorted order,
+ // which is strongly encouraged when serializing protobufs.
+ sorted_field_index_.resize(field_names_.size());
+ // Start with the fields sorted by current index.
+ for (int i = 0; i < field_names_.size(); i++) sorted_field_index_[i] = i;
+ // Then sort the field indices by their proto field number.
+ std::sort(sorted_field_index_.begin(), sorted_field_index_.end(),
+ [this](int a, int b) -> bool {
+ return field_descs_[a]->number() < field_descs_[b]->number();
+ });
+ }
+
+ void Compute(OpKernelContext* cx) override {
+ const Tensor* sizes_tensor;
+ OP_REQUIRES_OK(cx, cx->input("sizes", &sizes_tensor));
+
+ OpInputList values;
+ OP_REQUIRES_OK(cx, cx->input_list("values", &values));
+
+ OP_REQUIRES(cx, field_descs_.size() == values.size(),
+ errors::InvalidArgument(
+ "Length of inputs list must match field_names"));
+
+ // Check the arguments for consistency.
+ TensorShape common_prefix;
+ int message_count;
+ for (int i = 0; i < field_descs_.size(); i++) {
+ const Tensor& v = values[i];
+
+ // The type of each value tensor must match the corresponding field.
+ OP_REQUIRES(cx, IsCompatibleType(*field_descs_[i], v.dtype()),
+ errors::InvalidArgument(
+ "Incompatible type for field " + field_names_[i] +
+ ". Saw dtype: ",
+ DataTypeString(v.dtype()),
+ " but field type is: ", field_descs_[i]->type_name()));
+
+ // All value tensors must have the same shape prefix (i.e. batch size).
+ TensorShape shape_prefix = v.shape();
+ shape_prefix.RemoveDim(shape_prefix.dims() - 1);
+
+ // Do some initialization on the first input value. The rest will
+ // have to match this one.
+ if (i == 0) {
+ OP_REQUIRES(cx, v.dims() >= 1,
+ errors::InvalidArgument(
+ "Expected value to be at least a vector, saw shape: ",
+ v.shape().DebugString()));
+ common_prefix = shape_prefix;
+ message_count = common_prefix.num_elements();
+ } else {
+ OP_REQUIRES(cx, shape_prefix == common_prefix,
+ errors::InvalidArgument(
+ "Values must match up to the last dimension"));
+ }
+ }
+
+ TensorShape expected_sizes_shape = common_prefix;
+ expected_sizes_shape.AddDim(field_descs_.size());
+
+ OP_REQUIRES(cx, sizes_tensor->shape() == expected_sizes_shape,
+ errors::InvalidArgument(
+ "sizes should be batch_size + [len(field_names)]. Saw: ",
+ sizes_tensor->shape().DebugString(),
+ " but expected: ", expected_sizes_shape.DebugString()));
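+    // For example, with two requested fields and a batch shape of [batch],
+    // sizes must have shape [batch, 2]; sizes(m, i) selects how many leading
+    // entries of row m of values[i] get serialized.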
+
+ auto sizes = sizes_tensor->flat_inner_dims<int32>();
+
+ for (int i = 0; i < field_descs_.size(); ++i) {
+ const Tensor& v = values[i];
+ int max_size = v.dim_size(v.dims() - 1);
+
+      // The last dimension of a value tensor must be at least as large as
+      // the corresponding size in the sizes tensor.
+ for (int message_index = 0; message_index < message_count;
+ message_index++) {
+ OP_REQUIRES(
+ cx, sizes(message_index, i) <= max_size,
+ errors::InvalidArgument(
+              "Size to write must not be larger than value tensor, but saw: ",
+ sizes(message_index, i), " > ", max_size, " at message ",
+ message_index, " field ", i));
+ }
+ }
+
+ // This pointer is owned by the context.
+ Tensor* output_tensor;
+ OP_REQUIRES_OK(cx, cx->allocate_output(0, common_prefix, &output_tensor));
+
+ auto bufs = output_tensor->flat<string>();
+ for (int message_index = 0; message_index < message_count;
+ message_index++) {
+ // TODO(nix): possibly optimize allocation here by calling
+ // bufs(message_index).reserve(DEFAULT_BUF_SIZE);
+ StringOutputStream output_string(&bufs(message_index));
+ CodedOutputStream out(&output_string);
+ // Write fields in ascending field_number order.
+ for (int i : sorted_field_index_) {
+ auto& field_desc = *field_descs_[i];
+ const Tensor& v = values[i];
+ int size = sizes(message_index, i);
+ if (!size) continue;
+ WriteField(field_desc, v, message_index, size, &out);
+ }
+ }
+ }
+
+ private:
+ std::vector<string> field_names_;
+ std::vector<const FieldDescriptor*> field_descs_;
+
+  // owned_desc_pool_ is null when using descriptor_source=local.
+ std::unique_ptr<DescriptorPool> owned_desc_pool_;
+
+ // Contains indices into field_names_, sorted by field number since
+ // that's the order of writing.
+ std::vector<int> sorted_field_index_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(EncodeProtoOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("EncodeProto").Device(DEVICE_CPU), EncodeProtoOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/fill_functor.cu.cc b/tensorflow/core/kernels/fill_functor.cu.cc
index 3487606778..050c95cf40 100644
--- a/tensorflow/core/kernels/fill_functor.cu.cc
+++ b/tensorflow/core/kernels/fill_functor.cu.cc
@@ -76,7 +76,7 @@ struct FillFunctor<GPUDevice, T> {
};
#define DEFINE_FILL_GPU(T) template struct FillFunctor<GPUDevice, T>;
-TF_CALL_REAL_NUMBER_TYPES(DEFINE_FILL_GPU);
+TF_CALL_NUMBER_TYPES(DEFINE_FILL_GPU);
TF_CALL_bool(DEFINE_FILL_GPU);
#undef DEFINE_FILL_GPU
diff --git a/tensorflow/core/kernels/gather_functor.h b/tensorflow/core/kernels/gather_functor.h
index 16ccb03b85..2c6e8bf3bc 100644
--- a/tensorflow/core/kernels/gather_functor.h
+++ b/tensorflow/core/kernels/gather_functor.h
@@ -28,6 +28,7 @@ limitations under the License.
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
namespace functor {
@@ -50,7 +51,7 @@ SliceIndex HandleCopies(OpKernelContext* ctx,
}
// Compute slice_bytes here so that static knowledge is available
const size_t slice_bytes = slice_elems * sizeof(T);
- auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
+ auto* worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
mutex mu;
// Store the value of invalidate index for printing error information, it's a
// shared variable.
@@ -162,6 +163,16 @@ struct GatherFunctor<CPUDevice, T, Index> {
}
};
+template <typename Index>
+struct GatherFunctor<GPUDevice, Variant, Index> {
+ int64 operator()(OpKernelContext* ctx,
+ typename TTypes<Variant, 3>::ConstTensor params,
+ typename TTypes<Index>::ConstFlat indices,
+ typename TTypes<Variant, 3>::Tensor out) {
+ return GatherFunctorCPU<Variant, Index>()(ctx, params, indices, out);
+ }
+};
+
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/inplace_ops.cc b/tensorflow/core/kernels/inplace_ops.cc
index a71d047ed1..ef6ce0546b 100644
--- a/tensorflow/core/kernels/inplace_ops.cc
+++ b/tensorflow/core/kernels/inplace_ops.cc
@@ -213,13 +213,13 @@ REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate")
typedef Eigen::GpuDevice GPUDevice;
-#define REGISTER_EMPTY(type) \
+#define REGISTER_PARALLEL_CONCAT_START(type) \
REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("dtype"), \
ParallelConcatStart<GPUDevice, type>);
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_EMPTY)
-#undef REGISTER_EMPTY
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_PARALLEL_CONCAT_START)
+#undef REGISTER_PARALLEL_CONCAT_START
#define REGISTER_PARALLEL_CONCAT(type) \
REGISTER_KERNEL_BUILDER( \
@@ -248,5 +248,295 @@ REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate")
ParallelConcatUpdate<CPUDevice>);
#endif
+class InplaceOpBase : public OpKernel {
+ public:
+ explicit InplaceOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ auto x = ctx->input(0);
+ auto i = ctx->input(1);
+ auto v = ctx->input(2);
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(i.shape()),
+ errors::InvalidArgument("i must be a vector. ",
+ i.shape().DebugString()));
+ OP_REQUIRES(ctx, x.dims() == v.dims(),
+ errors::InvalidArgument(
+                    "x and v shapes don't match (ranks differ): ",
+ x.shape().DebugString(), " vs. ", v.shape().DebugString()));
+ for (int i = 1; i < x.dims(); ++i) {
+ OP_REQUIRES(
+ ctx, x.dim_size(i) == v.dim_size(i),
+        errors::InvalidArgument("x and v shapes don't match at index ", i,
+ " : ", x.shape().DebugString(), " vs. ",
+ v.shape().DebugString()));
+ }
+ OP_REQUIRES(ctx, i.dim_size(0) == v.dim_size(0),
+ errors::InvalidArgument(
+                    "i and v shapes don't match at index 0: ",
+ i.shape().DebugString(), " vs. ", v.shape().DebugString()));
+
+ Tensor y = x; // This creates an alias intentionally.
+ OP_REQUIRES_OK(ctx, DoCompute(ctx, i, v, &y));
+ ctx->set_output(0, y);
+ }
+
+ protected:
+ virtual Status DoCompute(OpKernelContext* ctx, const Tensor& i,
+ const Tensor& v, Tensor* y) = 0;
+};
+
+} // end namespace
+
+namespace functor {
+
+template <typename T>
+void DoInplaceOp(const CPUDevice& d, InplaceOpType op, const Tensor& i,
+ const Tensor& v, Tensor* y) {
+ auto Ti = i.flat<int32>();
+ auto Tv = v.flat_outer_dims<T>();
+ auto Ty = y->flat_outer_dims<T>();
+ auto nrows = Ty.dimension(0);
+ for (int64 j = 0; j < Ti.size(); ++j) {
+ auto r = (Ti(j) % nrows + nrows) % nrows; // Guard index range.
+ switch (op) {
+ case I_UPDATE:
+ Ty.template chip<0>(r).device(d) = Tv.template chip<0>(j);
+ break;
+ case I_ADD:
+ Ty.template chip<0>(r).device(d) += Tv.template chip<0>(j);
+ break;
+ case I_SUB:
+ Ty.template chip<0>(r).device(d) -= Tv.template chip<0>(j);
+ break;
+ }
+ }
+}
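+// For example, with i = [1, -1], nrows = 4, and op = I_ADD, row 1 of y gets
+// v row 0 added to it, and row 3 (index -1 wrapped modulo nrows) gets v
+// row 1.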
+
+// String type only supports inplace update.
+void DoInplaceStringUpdateOp(const CPUDevice& d, const Tensor& i,
+ const Tensor& v, Tensor* y) {
+ auto Ti = i.flat<int32>();
+ auto Tv = v.flat_outer_dims<string>();
+ auto Ty = y->flat_outer_dims<string>();
+ auto nrows = Ty.dimension(0);
+ for (int64 j = 0; j < Ti.size(); ++j) {
+ auto r = (Ti(j) % nrows + nrows) % nrows; // Guard index range.
+ Ty.template chip<0>(r).device(d) = Tv.template chip<0>(j);
+ }
+}
+
+template <>
+Status DoInplace(const CPUDevice& device, InplaceOpType op, const Tensor& i,
+ const Tensor& v, Tensor* y) {
+ CHECK_EQ(v.dtype(), y->dtype());
+ if (op == I_UPDATE) {
+ if (v.dtype() == DT_STRING) {
+ DoInplaceStringUpdateOp(device, i, v, y);
+ return Status::OK();
+ } else if (v.dtype() == DT_BOOL) {
+ DoInplaceOp<bool>(device, op, i, v, y);
+ return Status::OK();
+ }
+ }
+ switch (v.dtype()) {
+#define CASE(type) \
+ case DataTypeToEnum<type>::value: \
+ DoInplaceOp<type>(device, op, i, v, y); \
+ break;
+ TF_CALL_NUMBER_TYPES(CASE);
+#undef CASE
+ default:
+ return errors::InvalidArgument("Unsupported data type: ", v.dtype());
+ }
+ return Status::OK();
+}
+
+} // end namespace functor
+
+namespace {
+template <typename Device, functor::InplaceOpType op>
+class InplaceOp : public InplaceOpBase {
+ public:
+ explicit InplaceOp(OpKernelConstruction* ctx) : InplaceOpBase(ctx) {}
+
+ protected:
+ Status DoCompute(OpKernelContext* ctx, const Tensor& i, const Tensor& v,
+ Tensor* y) override {
+ const auto& d = ctx->eigen_device<Device>();
+ return ::tensorflow::functor::DoInplace(d, op, i, v, y);
+ }
+};
+
+class CopyOpBase : public OpKernel {
+ public:
+ explicit CopyOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ auto x = ctx->input(0);
+ Tensor* y;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y));
+ OP_REQUIRES_OK(ctx, DoCompute(ctx, x, y));
+ }
+
+ protected:
+ virtual Status DoCompute(OpKernelContext* ctx, const Tensor& x,
+ Tensor* y) = 0;
+};
+
+template <typename Device>
+class CopyOp : public CopyOpBase {
+ public:
+ explicit CopyOp(OpKernelConstruction* ctx) : CopyOpBase(ctx) {}
+
+ protected:
+ Status DoCompute(OpKernelContext* ctx, const Tensor& x, Tensor* y) override {
+ const auto& d = ctx->eigen_device<Device>();
+ return ::tensorflow::functor::DoCopy(d, x, y);
+ }
+};
+
+} // end namespace
+
+namespace functor {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <>
+Status DoCopy(const CPUDevice& device, const Tensor& x, Tensor* y) {
+ CHECK_EQ(x.dtype(), y->dtype());
+ switch (x.dtype()) {
+#define CASE(type) \
+ case DataTypeToEnum<type>::value: \
+ y->flat<type>().device(device) = x.flat<type>(); \
+ break;
+
+ TF_CALL_NUMBER_TYPES(CASE);
+ TF_CALL_bool(CASE);
+#undef CASE
+ default:
+ return errors::InvalidArgument("Unsupported data type: ", x.dtype());
+ }
+ return Status::OK();
+}
+
+} // end namespace functor
+
+namespace {
+template <typename Device, typename T>
+class EmptyOp : public OpKernel {
+ public:
+ explicit EmptyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("init", &init_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& shape = ctx->input(0);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsVector(shape.shape()),
+ errors::InvalidArgument("shape must be a vector of int32, got shape ",
+ shape.shape().DebugString()));
+ auto dims = shape.flat<int32>();
+ TensorShape out_shape;
+ OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
+ reinterpret_cast<const int32*>(dims.data()),
+ dims.size(), &out_shape));
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
+
+ if (init_) {
+ functor::SetZeroFunctor<Device, T>()(ctx->eigen_device<Device>(),
+ out->flat<T>());
+ }
+ }
+
+ private:
+ bool init_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("InplaceUpdate").Device(DEVICE_CPU),
+ InplaceOp<CPUDevice, functor::I_UPDATE>);
+REGISTER_KERNEL_BUILDER(Name("InplaceAdd").Device(DEVICE_CPU),
+ InplaceOp<CPUDevice, functor::I_ADD>);
+REGISTER_KERNEL_BUILDER(Name("InplaceSub").Device(DEVICE_CPU),
+ InplaceOp<CPUDevice, functor::I_SUB>);
+REGISTER_KERNEL_BUILDER(Name("DeepCopy").Device(DEVICE_CPU), CopyOp<CPUDevice>);
+
+#define REGISTER_EMPTY(type, dev) \
+ REGISTER_KERNEL_BUILDER(Name("Empty") \
+ .Device(DEVICE_##dev) \
+ .HostMemory("shape") \
+ .TypeConstraint<type>("dtype"), \
+ EmptyOp<dev##Device, type>)
+
+REGISTER_EMPTY(float, CPU)
+REGISTER_EMPTY(double, CPU)
+REGISTER_EMPTY(Eigen::half, CPU)
+REGISTER_EMPTY(string, CPU)
+REGISTER_EMPTY(int32, CPU)
+REGISTER_EMPTY(int64, CPU)
+REGISTER_EMPTY(bool, CPU)
+
+#if GOOGLE_CUDA
+
+typedef Eigen::GpuDevice GPUDevice;
+
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("InplaceUpdate").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
+ InplaceOp<GPUDevice, functor::I_UPDATE>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("InplaceAdd").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
+ InplaceOp<GPUDevice, functor::I_ADD>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("InplaceSub").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
+ InplaceOp<GPUDevice, functor::I_SUB>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("DeepCopy").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
+ CopyOp<GPUDevice>);
+
+REGISTER(float);
+REGISTER(double);
+REGISTER(Eigen::half);
+REGISTER(int64);
+
+REGISTER_KERNEL_BUILDER(Name("InplaceUpdate")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("i")
+ .HostMemory("v")
+ .HostMemory("y")
+ .TypeConstraint<int32>("T"),
+ InplaceOp<CPUDevice, functor::I_UPDATE>);
+REGISTER_KERNEL_BUILDER(Name("InplaceAdd")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("i")
+ .HostMemory("v")
+ .HostMemory("y")
+ .TypeConstraint<int32>("T"),
+ InplaceOp<CPUDevice, functor::I_ADD>);
+REGISTER_KERNEL_BUILDER(Name("InplaceSub")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("i")
+ .HostMemory("v")
+ .HostMemory("y")
+ .TypeConstraint<int32>("T"),
+ InplaceOp<CPUDevice, functor::I_SUB>);
+
+REGISTER_KERNEL_BUILDER(Name("DeepCopy")
+ .Device(DEVICE_GPU)
+ .HostMemory("x")
+ .HostMemory("y")
+ .TypeConstraint<int32>("T"),
+ CopyOp<CPUDevice>);
+REGISTER_EMPTY(float, GPU);
+REGISTER_EMPTY(double, GPU);
+REGISTER_EMPTY(Eigen::half, GPU);
+REGISTER_EMPTY(int64, GPU);
+
+#endif // GOOGLE_CUDA
+
} // end namespace
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/inplace_ops_functor.h b/tensorflow/core/kernels/inplace_ops_functor.h
index 53529f5165..b806787e91 100644
--- a/tensorflow/core/kernels/inplace_ops_functor.h
+++ b/tensorflow/core/kernels/inplace_ops_functor.h
@@ -26,6 +26,23 @@ template <typename Device>
Status DoParallelConcat(const Device& device, const Tensor& value, int32 loc,
Tensor* output);
+// Inplace update/add/sub values in 'y'. It computes
+// y[i, :] = v if op is I_UPDATE
+// y[i, :] += v if op is I_ADD
+// y[i, :] -= v if op is I_SUB
+// Returns an error if the operation fails.
+enum InplaceOpType {
+  I_UPDATE,  // y[i, :] = v
+  I_ADD,     // y[i, :] += v
+  I_SUB,     // y[i, :] -= v
+};
+template <typename Device>
+Status DoInplace(const Device& device, InplaceOpType op, const Tensor& i,
+ const Tensor& v, Tensor* y);
+// Copies x into y.
+template <typename Device>
+Status DoCopy(const Device& device, const Tensor& x, Tensor* y);
+
} // end namespace functor
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc b/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc
index 8467360435..f1616b1ea8 100644
--- a/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc
+++ b/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc
@@ -77,6 +77,103 @@ Status DoParallelConcat(const Device& d, const Tensor& value, int32 loc,
return Status::OK();
}
+template <typename T, InplaceOpType op>
+__global__ void DoInplaceOpKernel(int nthreads, const int64 rows,
+ const int64 cols, const int64 n, const T* src,
+ const int32* rowids, T* dst) {
+ CUDA_1D_KERNEL_LOOP(idx, nthreads) {
+ int64 r = idx / cols;
+ int64 c = idx % cols;
+ r = (rowids[r] % rows + rows) % rows; // Guard index range.
+ T* p = dst + r * cols + c;
+ const T* q = src + idx;
+ switch (op) {
+ case I_UPDATE:
+ *p = ldg(q);
+ break;
+ case I_ADD:
+ *p += ldg(q);
+ break;
+ case I_SUB:
+ *p -= ldg(q);
+ break;
+ }
+ }
+}
+
+template <typename T>
+void DoInplaceOp(const Device& d, InplaceOpType op, const Tensor& i,
+ const Tensor& v, Tensor* y) {
+ const int64 nelem = v.NumElements();
+ CudaLaunchConfig cfg = GetCudaLaunchConfig(nelem, d);
+ auto Ty = y->flat_outer_dims<T>();
+ const int64 nrows = Ty.dimension(0);
+ const int64 ncols = Ty.dimension(1);
+ const int64 n = i.NumElements();
+ const T* src = v.flat<T>().data();
+ // TODO(sjhwang): Check that first dimension fits in int32 range.
+ const int32* rowids = i.flat<int32>().data();
+ T* dst = y->flat<T>().data();
+ switch (op) {
+ case I_UPDATE:
+ DoInplaceOpKernel<T, I_UPDATE>
+ <<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(
+ cfg.virtual_thread_count, nrows, ncols, n, src, rowids, dst);
+ break;
+ case I_ADD:
+ DoInplaceOpKernel<T, I_ADD>
+ <<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(
+ cfg.virtual_thread_count, nrows, ncols, n, src, rowids, dst);
+ break;
+ case I_SUB:
+ DoInplaceOpKernel<T, I_SUB>
+ <<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(
+ cfg.virtual_thread_count, nrows, ncols, n, src, rowids, dst);
+ break;
+ }
+}
+
+template <>
+Status DoInplace(const Device& d, InplaceOpType op, const Tensor& i,
+ const Tensor& v, Tensor* y) {
+ CHECK_EQ(v.dtype(), y->dtype());
+ switch (v.dtype()) {
+#define CASE(type) \
+ case DataTypeToEnum<type>::value: \
+ DoInplaceOp<type>(d, op, i, v, y); \
+ break;
+
+ CASE(float)
+ CASE(double)
+ CASE(Eigen::half)
+ CASE(int64)
+#undef CASE
+ default:
+ return errors::InvalidArgument("Unsupported data type: ", DataTypeString(v.dtype()));
+ }
+ return Status::OK();
+}
+
+template <>
+Status DoCopy(const Device& d, const Tensor& x, Tensor* y) {
+ CHECK_EQ(x.dtype(), y->dtype());
+ switch (x.dtype()) {
+#define CASE(type) \
+ case DataTypeToEnum<type>::value: \
+ y->flat<type>().device(d) = x.flat<type>(); \
+ break;
+
+ CASE(float)
+ CASE(double)
+ CASE(Eigen::half)
+ CASE(int64)
+#undef CASE
+ default:
+ return errors::InvalidArgument("Unsupported data type: ", DataTypeString(x.dtype()));
+ }
+ return Status::OK();
+}
+
} // end namespace functor
} // namespace tensorflow
#endif // GOOGLE_CUDA
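The CUDA kernel above maps each flat element index of 'v' to a (row, column) pair, then wraps the looked-up destination row with (rowids[r] % rows + rows) % rows so that negative indices count from the end. The same index arithmetic on the CPU, as a sketch (illustrative only):

#include <cassert>
#include <cstdint>

// Wrap a possibly negative row id into [0, rows), as the GPU kernel does.
inline int64_t WrapRow(int32_t rowid, int64_t rows) {
  return (rowid % rows + rows) % rows;
}

// Flat element index of 'v' -> (source row r, column c) -> offset in 'dst'.
inline int64_t DestOffset(int64_t idx, int64_t rows, int64_t cols,
                          const int32_t* rowids) {
  const int64_t r = idx / cols;
  const int64_t c = idx % cols;
  return WrapRow(rowids[r], rows) * cols + c;
}

int main() {
  const int32_t rowids[] = {-1};  // -1 selects the last row
  assert(WrapRow(-1, 4) == 3);
  assert(DestOffset(/*idx=*/1, /*rows=*/4, /*cols=*/2, rowids) == 7);
}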
diff --git a/tensorflow/core/kernels/lookup_util.cc b/tensorflow/core/kernels/lookup_util.cc
index 27031d9216..77386a16e0 100644
--- a/tensorflow/core/kernels/lookup_util.cc
+++ b/tensorflow/core/kernels/lookup_util.cc
@@ -101,9 +101,10 @@ class TextFileLineIterator
string line;
status_ = input_buffer_->ReadLine(&line);
if (!status_.ok()) {
- if (errors::IsOutOfRange(status_) && next_id_ != total_size()) {
+ if (errors::IsOutOfRange(status_) && vocab_size_ != -1 &&
+ next_id_ != vocab_size_) {
status_ = errors::InvalidArgument("Invalid vocab_size in ", filename_,
- ": expected ", total_size(),
+ ": expected ", vocab_size_,
" but got ", next_id_);
}
valid_ = false;
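With this change the EOF check fires only when an explicit vocab_size was supplied (vocab_size_ != -1 here); iterating a file of unknown size no longer reports a spurious InvalidArgument. A reduced sketch of the guard (illustrative names, not the TensorFlow class):

#include <iostream>
#include <string>

// Returns a non-empty error message only when a vocab size was requested
// (-1 means "unknown") and the number of lines read does not match it.
std::string CheckVocabSize(long long vocab_size, long long next_id,
                           const std::string& filename) {
  if (vocab_size != -1 && next_id != vocab_size) {
    return "Invalid vocab_size in " + filename + ": expected " +
           std::to_string(vocab_size) + " but got " + std::to_string(next_id);
  }
  return "";
}

int main() {
  std::cout << CheckVocabSize(-1, 10, "vocab.txt").empty() << "\n";  // 1
  std::cout << CheckVocabSize(5, 10, "vocab.txt") << "\n";  // mismatch message
}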
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index f49a05c70a..72504200cc 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -280,64 +280,6 @@ class AssignVariableOp : public OpKernel {
};
template <typename Device>
-Status VariantCopyFn(OpKernelContext* context, const Tensor& from, Tensor* to);
-
-#define CPU_DENSE_COPY(T) \
- case DataTypeToEnum<T>::value: { \
- functor::DenseUpdate<CPUDevice, T, ASSIGN> copy_functor_; \
- copy_functor_(context->eigen_device<CPUDevice>(), tensor->flat<T>(), \
- from.flat<T>()); \
- break; \
- }
-
-#define INSTANTIATE_GET_VARIANT_COPY_FN(Device, TYPE_CALLER, TYPE_DENSE_COPY) \
- template <> \
- Status VariantCopyFn<Device>(OpKernelContext * context, const Tensor& from, \
- Tensor* to) { \
- PersistentTensor tmp; \
- Tensor* tensor; \
- AllocatorAttributes attr; \
- attr.set_gpu_compatible(true); \
- attr.set_nic_compatible(true); \
- TF_RETURN_IF_ERROR(context->allocate_persistent( \
- from.dtype(), from.shape(), &tmp, &tensor, attr)); \
- switch (from.dtype()) { \
- TYPE_CALLER(TYPE_DENSE_COPY); \
- default: \
- return errors::InvalidArgument( \
- "VariantCopyFn: Could not perform a deep copy of variant " \
- "element of type: ", \
- DataTypeString(from.dtype()), \
- " using device: ", context->device()->name()); \
- } \
- *to = *tensor; \
- return Status::OK(); \
- }
-
-INSTANTIATE_GET_VARIANT_COPY_FN(CPUDevice, TF_CALL_ALL_TYPES, CPU_DENSE_COPY);
-
-#if GOOGLE_CUDA
-#define GPU_DENSE_COPY(T) \
- case DataTypeToEnum<T>::value: { \
- functor::DenseUpdate<GPUDevice, T, ASSIGN> copy_functor_; \
- copy_functor_(context->eigen_device<GPUDevice>(), tensor->flat<T>(), \
- from.flat<T>()); \
- break; \
- }
-#define TF_CALL_GPU_AND_ADDITIONAL_TYPES(T) \
- TF_CALL_GPU_ALL_TYPES(T); \
- TF_CALL_int32(T); \
- TF_CALL_int64(T);
-INSTANTIATE_GET_VARIANT_COPY_FN(GPUDevice, TF_CALL_GPU_AND_ADDITIONAL_TYPES,
- GPU_DENSE_COPY);
-#undef TF_CALL_GPU_AND_ADDITIONAL_TYPES
-#undef GPU_DENSE_COPY
-#endif // GOOGLE_CUDA
-
-#undef CPU_DENSE_COPY
-#undef INSTANTIATE_GET_VARIANT_COPY_FN
-
-template <typename Device>
class AssignVariableOp<Device, Variant> : public OpKernel {
public:
explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
@@ -370,9 +312,16 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
// Copying is unnecessary if we are the last user of the value
// tensor, we can just adopt the input tensor's buffer instead.
// Note that Variant objects themselves always reside on host.
+ //
+ // We nevertheless want to signal to the runtime that the tensor
+ // should reside in memory of the associated device, as Variant
+ // tensors may be marked as sitting on either CPU or GPU. This
+ // helps to elide one or more copies.
std::unique_ptr<Tensor> input_alias = context->forward_input(
1, OpKernelContext::Params::kNoReservation /*output_index*/, DT_VARIANT,
- value.shape(), HOST_MEMORY, attr);
+ value.shape(),
+ std::is_same<Device, CPUDevice>::value ? HOST_MEMORY : DEVICE_MEMORY,
+ attr);
mutex_lock ml(*variable->mu());
variable->is_initialized = true;
@@ -396,12 +345,8 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
const auto elements_in = value.flat<Variant>();
auto elements_out = variable->tensor()->flat<Variant>();
- auto copy_fn = std::bind(&VariantCopyFn<Device>, context,
- std::placeholders::_1, std::placeholders::_2);
for (int64 i = 0; i < elements_in.size(); ++i) {
- OP_REQUIRES_OK(context, VariantDeviceCopy(
- VariantDeviceCopyDirection::DEVICE_TO_DEVICE,
- elements_in(i), &elements_out(i), copy_fn));
+ elements_out(i) = elements_in(i);
}
}
@@ -560,7 +505,14 @@ class ResourceGatherOp : public OpKernel {
}
Tensor* out = nullptr;
- OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
+ Tensor tmp;
+ if (params.dtype() == DT_VARIANT) {
+ tmp = Tensor(DT_VARIANT, result_shape);
+ c->set_output(0, tmp);
+ out = &tmp;
+ } else {
+ OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
+ }
if (N > 0) {
const int64 gather_dim_size = params.dim_size(0);
int64 inner_size = 1;
@@ -607,6 +559,23 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_GPU);
+// Variant objects themselves sit on CPU, even if they contain data
+// pointing to a device.
+REGISTER_KERNEL_BUILDER(Name("ResourceGather")
+ .Device(DEVICE_GPU)
+ .HostMemory("resource")
+ .HostMemory("indices")
+ .TypeConstraint<Variant>("dtype")
+ .TypeConstraint<int32>("Tindices"),
+ ResourceGatherOp<GPUDevice, Variant, int32>)
+REGISTER_KERNEL_BUILDER(Name("ResourceGather")
+ .Device(DEVICE_GPU)
+ .HostMemory("resource")
+ .HostMemory("indices")
+ .TypeConstraint<Variant>("dtype")
+ .TypeConstraint<int64>("Tindices"),
+ ResourceGatherOp<GPUDevice, Variant, int64>)
+
#endif // GOOGLE_CUDA
#undef REGISTER_GATHER_CPU
@@ -721,6 +690,8 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate",
scatter_op::UpdateOp::ASSIGN);
+REGISTER_SCATTER_KERNEL(Variant, CPU, "ResourceScatterUpdate",
+ scatter_op::UpdateOp::ASSIGN);
// Registers GPU kernels.
#if GOOGLE_CUDA
@@ -733,6 +704,23 @@ REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate",
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_GPU);
+REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
+ .Device(DEVICE_GPU)
+ .HostMemory("resource")
+ .HostMemory("indices")
+ .TypeConstraint<Variant>("dtype")
+ .TypeConstraint<int32>("Tindices"),
+ ResourceScatterUpdateOp<GPUDevice, Variant, int32,
+ scatter_op::UpdateOp::ASSIGN>)
+REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
+ .Device(DEVICE_GPU)
+ .HostMemory("resource")
+ .HostMemory("indices")
+ .TypeConstraint<Variant>("dtype")
+ .TypeConstraint<int64>("Tindices"),
+ ResourceScatterUpdateOp<GPUDevice, Variant, int64,
+ scatter_op::UpdateOp::ASSIGN>)
+
#endif // GOOGLE_CUDA
#undef REGISTER_SCATTER_ARITHMETIC
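ResourceGatherOp, including the new DT_VARIANT registrations above, copies whole rows: out[i, :] = params[indices[i], :], where inner_size is the product of the non-gathered dimensions. A standalone sketch of that copy with a per-index bounds check (illustrative only):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Gathers rows of 'params' (gather_dim_size x inner_size, row-major) into
// 'out'. Returns the position of the first out-of-range index, or -1.
int64_t GatherRows(const std::vector<float>& params, int64_t gather_dim_size,
                   int64_t inner_size, const std::vector<int64_t>& indices,
                   std::vector<float>* out) {
  out->resize(indices.size() * inner_size);
  for (std::size_t i = 0; i < indices.size(); ++i) {
    const int64_t idx = indices[i];
    if (idx < 0 || idx >= gather_dim_size) return static_cast<int64_t>(i);
    std::copy(params.begin() + idx * inner_size,
              params.begin() + (idx + 1) * inner_size,
              out->begin() + static_cast<int64_t>(i) * inner_size);
  }
  return -1;
}

int main() {
  const std::vector<float> params = {1, 2, 3, 4, 5, 6};  // 3 rows x 2 cols
  std::vector<float> out;
  GatherRows(params, /*gather_dim_size=*/3, /*inner_size=*/2, {2, 0}, &out);
  // out == {5, 6, 1, 2}
}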
diff --git a/tensorflow/core/kernels/rpc_op.cc b/tensorflow/core/kernels/rpc_op.cc
new file mode 100644
index 0000000000..2447ef5040
--- /dev/null
+++ b/tensorflow/core/kernels/rpc_op.cc
@@ -0,0 +1,129 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// RpcOp is a TensorFlow op that sends and receives arbitrary messages.
+//
+// See docs in ../ops/rpc_op.cc.
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/rpc/call_container.h"
+#include "tensorflow/core/util/rpc/rpc_factory.h"
+#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
+
+namespace tensorflow {
+
+class RpcOp : public AsyncOpKernel {
+ public:
+ explicit RpcOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("protocol", &protocol_));
+ OP_REQUIRES(context, !protocol_.empty(),
+ errors::InvalidArgument("protocol must be non-empty."));
+ bool fail_fast;
+ OP_REQUIRES_OK(context, context->GetAttr("fail_fast", &fail_fast));
+ int64 timeout_in_ms;
+ OP_REQUIRES_OK(context, context->GetAttr("timeout_in_ms", &timeout_in_ms));
+
+ RPCFactoryRegistry::RPCFactoryFn* rpc_factory_fn =
+ RPCFactoryRegistry::Global()->Get(protocol_);
+ OP_REQUIRES(context, rpc_factory_fn != nullptr,
+ errors::InvalidArgument("The protocol ", protocol_,
+ " was not recognized."));
+
+ rpc_factory_.reset((*rpc_factory_fn)(context, fail_fast, timeout_in_ms));
+ }
+
+ ~RpcOp() override {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ const Tensor& address_t = ctx->input(0);
+ const Tensor& method_t = ctx->input(1);
+ const Tensor& request_t = ctx->input(2);
+
+ OP_REQUIRES_ASYNC(
+ ctx, address_t.dims() == 0 || address_t.dims() == 1,
+ errors::InvalidArgument("address must be a scalar or vector."), done);
+ OP_REQUIRES_ASYNC(
+ ctx, method_t.dims() == 0 || method_t.dims() == 1,
+ errors::InvalidArgument("method must be a scalar or vector."), done);
+ OP_REQUIRES_ASYNC(
+ ctx, request_t.dims() == 0 || request_t.dims() == 1,
+ errors::InvalidArgument("request must be a scalar or vector."), done);
+
+ TensorShape output_shape({});
+ for (const Tensor& t : {address_t, method_t, request_t}) {
+ if (t.dims() == 1) {
+ OP_REQUIRES_ASYNC(
+ ctx,
+ output_shape.dims() == 0 ||
+ output_shape.dim_size(0) == t.dim_size(0),
+ errors::InvalidArgument(
+ "Input vector shapes don't match: ", output_shape.DebugString(),
+ " vs. ", t.shape().DebugString()),
+ done);
+ output_shape = t.shape();
+ }
+ }
+
+ Tensor* response_t;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->allocate_output(0, output_shape, &response_t), done);
+
+ const bool try_rpc = (ctx->num_outputs() > 1);
+
+ Tensor* status_code_t = nullptr;
+ Tensor* status_message_t = nullptr;
+ if (try_rpc) {
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->allocate_output(1, output_shape, &status_code_t), done);
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->allocate_output(2, output_shape, &status_message_t), done);
+ }
+
+ if (request_t.NumElements() == 0) {
+ // Special case, we finished early!
+ done();
+ return;
+ }
+
+ int64 num_elements = output_shape.num_elements();
+
+ rpc_factory_->Call(ctx, num_elements, address_t, method_t, request_t,
+ try_rpc, response_t, status_code_t, status_message_t,
+ std::move(done));
+ }
+
+ private:
+ string protocol_;
+ std::unique_ptr<RPCFactory> rpc_factory_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RpcOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("Rpc").Device(DEVICE_CPU), RpcOp);
+REGISTER_KERNEL_BUILDER(Name("TryRpc").Device(DEVICE_CPU), RpcOp);
+
+} // namespace tensorflow
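ComputeAsync above accepts any mix of scalar and vector inputs for address, method and request, requires all vector inputs to share one length, and issues one RPC per element of the resulting shape. A standalone sketch of that agreement check (illustrative, no TensorFlow types):

#include <cstdint>
#include <stdexcept>
#include <vector>

// Each entry is the length of a vector input, or -1 for a scalar input.
// Returns the number of RPCs to issue (1 when every input is a scalar).
int64_t CommonBatchSize(const std::vector<int64_t>& input_lengths) {
  int64_t batch = -1;  // -1: no vector input seen yet
  for (const int64_t len : input_lengths) {
    if (len < 0) continue;  // scalar, compatible with anything
    if (batch != -1 && batch != len) {
      throw std::invalid_argument("Input vector shapes don't match");
    }
    batch = len;
  }
  return batch == -1 ? 1 : batch;
}

int main() {
  // address is a vector of 3, method and request are scalars -> 3 RPCs.
  return CommonBatchSize({3, -1, -1}) == 3 ? 0 : 1;
}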
diff --git a/tensorflow/core/kernels/scatter_functor.h b/tensorflow/core/kernels/scatter_functor.h
index 52666645bf..ebaa2bd9c6 100644
--- a/tensorflow/core/kernels/scatter_functor.h
+++ b/tensorflow/core/kernels/scatter_functor.h
@@ -20,8 +20,11 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@@ -203,9 +206,9 @@ struct ScatterFunctorBase {
const Index N = static_cast<Index>(indices.size());
const Index limit = static_cast<Index>(params.dimension(0));
for (Index i = 0; i < N; i++) {
- // Grab the index and check its validity. An earlier version of the
- // code checked it and then grabbed it from memory a second time, which
- // was a security risk since it could have changed in between.
+ // Grab the index and check its validity. Read the index from memory
+ // exactly once: re-reading it after the bounds check would be a
+ // security risk, since the value could change in between.
const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
if (!FastBoundsCheck(index, limit)) return i;
// Copy last Ndim-1 dimensions of updates[i] to params[index]
@@ -216,6 +219,42 @@ struct ScatterFunctorBase {
}
};
+template <typename Device, typename Index>
+struct ScatterFunctorVariantAssignBase {
+ Index operator()(OpKernelContext* c, const Device& d,
+ typename TTypes<Variant>::Matrix params,
+ typename TTypes<Variant>::ConstMatrix updates,
+ typename TTypes<Index>::ConstFlat indices) {
+ // indices and params sizes were validated in DoCompute().
+ const Index N = static_cast<Index>(indices.size());
+ const Index limit = static_cast<Index>(params.dimension(0));
+ const Index cols = static_cast<Index>(params.dimension(1));
+ DCHECK_EQ(N, updates.dimension(0));
+ DCHECK_EQ(cols, updates.dimension(1));
+ for (Index i = 0; i < N; i++) {
+ // Grab the index and check its validity. Read the index from memory
+ // exactly once: re-reading it after the bounds check would be a
+ // security risk, since the value could change in between.
+ const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
+ if (!FastBoundsCheck(index, limit)) return i;
+ // Copy last Ndim-1 dimensions of updates[i] to params[index]
+ for (int j = 0; j < cols; ++j) {
+ const Variant& to_scatter = updates(i, j);
+ params(index, j) = to_scatter;
+ }
+ }
+ return -1;
+ }
+};
+
+template <typename Index>
+struct ScatterFunctor<CPUDevice, Variant, Index, scatter_op::UpdateOp::ASSIGN>
+ : ScatterFunctorVariantAssignBase<CPUDevice, Index> {};
+
+template <typename Index>
+struct ScatterFunctor<GPUDevice, Variant, Index, scatter_op::UpdateOp::ASSIGN>
+ : ScatterFunctorVariantAssignBase<GPUDevice, Index> {};
+
#ifdef TENSORFLOW_USE_SYCL
template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctorBase<SYCLDevice, T, Index, op> {
@@ -227,9 +266,9 @@ struct ScatterFunctorBase<SYCLDevice, T, Index, op> {
const Index N = static_cast<Index>(indices.size());
const Index limit = static_cast<Index>(params.dimension(0));
for (Index i = 0; i < N; i++) {
- // Grab the index and check its validity. An earlier version of the
- // code checked it and then grabbed it from memory a second time, which
- // was a security risk since it could have changed in between.
+ // Grab the index and check its validity. Read the index from memory
+ // exactly once: re-reading it after the bounds check would be a
+ // security risk, since the value could change in between.
const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
if (!FastBoundsCheck(index, limit)) return i;
// Copy last Ndim-1 dimensions of updates[i] to params[index]
@@ -252,9 +291,10 @@ struct ScatterFunctorBase<CPUDevice, T, Index, scatter_op::UpdateOp::ASSIGN> {
const Index limit = static_cast<Index>(params.dimension(0));
if (!std::is_same<T, string>::value) {
for (Index i = 0; i < N; i++) {
- // Grab the index and check its validity. An earlier version of the
- // code checked it and then grabbed it from memory a second time, which
- // was a security risk since it could have changed in between.
+ // Grab the index and check its validity. Read the index from
+ // memory exactly once: re-reading it after the bounds check
+ // would be a security risk, since the value could change in
+ // between.
const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
if (!FastBoundsCheck(index, limit)) return i;
memmove(params.data() + index * params.dimension(1),
@@ -263,9 +303,10 @@ struct ScatterFunctorBase<CPUDevice, T, Index, scatter_op::UpdateOp::ASSIGN> {
}
} else {
for (Index i = 0; i < N; i++) {
- // Grab the index and check its validity. An earlier version of the
- // code checked it and then grabbed it from memory a second time, which
- // was a security risk since it could have changed in between.
+ // Grab the index and check its validity. Read the index from
+ // memory exactly once: re-reading it after the bounds check
+ // would be a security risk, since the value could change in
+ // between.
const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
if (!FastBoundsCheck(index, limit)) return i;
// Copy last Ndim-1 dimensions of updates[i] to params[index]
@@ -321,9 +362,9 @@ struct ScatterScalarFunctorBase {
const Index N = static_cast<Index>(indices.size());
const Index limit = static_cast<Index>(params.dimension(0));
for (Index i = 0; i < N; i++) {
- // Grab the index and check its validity. An earlier version of the
- // code checked it and then grabbed it from memory a second time, which
- // was a security risk since it could have changed in between.
+ // Grab the index and check its validity. Read the index from memory
+ // exactly once: re-reading it after the bounds check would be a
+ // security risk, since the value could change in between.
const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
if (!FastBoundsCheck(index, limit)) return i;
// Broadcast update to params[index]
@@ -334,6 +375,41 @@ struct ScatterScalarFunctorBase {
}
};
+template <typename Device, typename Index>
+struct ScatterScalarFunctorVariantAssignBase {
+ Index operator()(OpKernelContext* c, const Device& d,
+ typename TTypes<Variant>::Matrix params,
+ const typename TTypes<Variant>::ConstScalar update,
+ typename TTypes<Index>::ConstFlat indices) {
+ // indices and params sizes were validated in DoCompute().
+ const Index N = static_cast<Index>(indices.size());
+ const Index limit = static_cast<Index>(params.dimension(0));
+ const Index cols = static_cast<Index>(params.dimension(1));
+ const Variant& to_scatter = update();
+ for (Index i = 0; i < N; i++) {
+ // Grab the index and check its validity. Read the index from memory
+ // exactly once: re-reading it after the bounds check would be a
+ // security risk, since the value could change in between.
+ const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
+ if (!FastBoundsCheck(index, limit)) return i;
+ // Broadcast update to params[index]
+ for (Index j = 0; j < cols; ++j) {
+ params(index, j) = to_scatter;
+ }
+ }
+ return -1;
+ }
+};
+
+template <typename Index>
+struct ScatterScalarFunctor<CPUDevice, Variant, Index,
+ scatter_op::UpdateOp::ASSIGN>
+ : ScatterScalarFunctorVariantAssignBase<CPUDevice, Index> {};
+template <typename Index>
+struct ScatterScalarFunctor<GPUDevice, Variant, Index,
+ scatter_op::UpdateOp::ASSIGN>
+ : ScatterScalarFunctorVariantAssignBase<GPUDevice, Index> {};
+
#ifdef TENSORFLOW_USE_SYCL
template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctorBase<SYCLDevice, T, Index, op> {
@@ -345,9 +421,9 @@ struct ScatterScalarFunctorBase<SYCLDevice, T, Index, op> {
const Index N = static_cast<Index>(indices.size());
const Index limit = static_cast<Index>(params.dimension(0));
for (Index i = 0; i < N; i++) {
- // Grab the index and check its validity. An earlier version of the
- // code checked it and then grabbed it from memory a second time, which
- // was a security risk since it could have changed in between.
+ // Grab the index and check its validity. Read the index from memory
+ // exactly once: re-reading it after the bounds check would be a
+ // security risk, since the value could change in between.
const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
if (!FastBoundsCheck(index, limit)) return i;
// Broadcast update to params[index]
@@ -370,9 +446,9 @@ struct ScatterScalarFunctorBase<CPUDevice, T, Index,
const Index N = static_cast<Index>(indices.size());
const Index limit = static_cast<Index>(params.dimension(0));
for (Index i = 0; i < N; i++) {
- // Grab the index and check its validity. An earlier version of the
- // code checked it and then grabbed it from memory a second time, which
- // was a security risk since it could have changed in between.
+ // Grab the index and check its validity. Read the index from memory
+ // exactly once: re-reading it after the bounds check would be a
+ // security risk, since the value could change in between.
const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
if (!FastBoundsCheck(index, limit)) return i;
// Broadcast update to params[index]
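The Variant specializations above keep the contract of the dense scatter paths: walk the indices once, read each index from memory a single time before the bounds check, assign one full row per index, and return the position of the first invalid index (or -1 on success). The same contract on plain containers, as a sketch (illustrative names):

#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Scatter-assign rows: params[indices[i], :] = updates[i, :].
// Returns the position of the first out-of-range index, or -1 on success.
template <typename T>
int64_t ScatterAssignRows(std::vector<T>* params, int64_t limit, int64_t cols,
                          const std::vector<int64_t>& indices,
                          const std::vector<T>& updates) {
  for (std::size_t i = 0; i < indices.size(); ++i) {
    const int64_t index = indices[i];  // read once, use only this copy
    if (index < 0 || index >= limit) return static_cast<int64_t>(i);
    for (int64_t j = 0; j < cols; ++j) {
      (*params)[index * cols + j] = updates[i * cols + j];
    }
  }
  return -1;
}

int main() {
  std::vector<std::string> params(4);  // 2 rows x 2 cols
  ScatterAssignRows(&params, /*limit=*/2, /*cols=*/2, {1},
                    std::vector<std::string>{"a", "b"});
  // params == {"", "", "a", "b"}
}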
diff --git a/tensorflow/core/kernels/training_op_helpers.h b/tensorflow/core/kernels/training_op_helpers.h
index f6e2a5ae25..857daae177 100644
--- a/tensorflow/core/kernels/training_op_helpers.h
+++ b/tensorflow/core/kernels/training_op_helpers.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_KERNELS_TRAINING_OP_HELPERS_H_
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/variable_ops.h"
@@ -40,14 +41,27 @@ Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor) {
// updating.
PersistentTensor unused;
Tensor* tmp;
- AllocatorAttributes attr;
- attr.set_gpu_compatible(true);
- attr.set_nic_compatible(true);
- TF_RETURN_IF_ERROR(ctx->allocate_persistent(
- tensor->dtype(), tensor->shape(), &unused, &tmp, attr));
- functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
- copy_functor(ctx->eigen_device<Device>(), tmp->flat<T>(),
- const_cast<const Tensor*>(tensor)->flat<T>());
+ if (std::is_same<T, Variant>::value) {
+ AllocatorAttributes attr;
+ attr.set_on_host(true);
+ TF_RETURN_IF_ERROR(ctx->allocate_persistent(
+ tensor->dtype(), tensor->shape(), &unused, &tmp, attr));
+
+ const auto elements_in = tensor->flat<Variant>();
+ auto elements_out = tmp->flat<Variant>();
+ for (int64 i = 0; i < elements_in.size(); ++i) {
+ elements_out(i) = elements_in(i);
+ }
+ } else {
+ AllocatorAttributes attr;
+ attr.set_gpu_compatible(true);
+ attr.set_nic_compatible(true);
+ TF_RETURN_IF_ERROR(ctx->allocate_persistent(
+ tensor->dtype(), tensor->shape(), &unused, &tmp, attr));
+ functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
+ copy_functor(ctx->eigen_device<Device>(), tmp->flat<T>(),
+ const_cast<const Tensor*>(tensor)->flat<T>());
+ }
*tensor = *tmp;
}
return Status::OK();
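PrepareToUpdateVariable now branches at compile time on std::is_same<T, Variant>::value: variant tensors are copied element by element into a host-resident buffer, while everything else keeps using the dense copy functor. Both branches still have to compile for every T, which is why an ordinary runtime 'if' on the trait is enough. A reduced sketch of the pattern with a stand-in type (illustrative only):

#include <algorithm>
#include <cstddef>
#include <type_traits>
#include <vector>

struct VariantLike {            // stands in for tensorflow::Variant
  std::vector<char> payload;    // not trivially copyable
};

// Copy 'src' into 'dst'. Variant-like elements are assigned one by one;
// other element types go through a bulk copy. Both branches compile for
// every T, so a plain runtime branch on the trait is sufficient.
template <typename T>
void CopyElements(const std::vector<T>& src, std::vector<T>* dst) {
  dst->resize(src.size());
  if (std::is_same<T, VariantLike>::value) {
    for (std::size_t i = 0; i < src.size(); ++i) (*dst)[i] = src[i];
  } else {
    std::copy(src.begin(), src.end(), dst->begin());
  }
}

int main() {
  std::vector<float> a = {1.f, 2.f}, b;
  CopyElements(a, &b);               // dense path
  std::vector<VariantLike> v(1), w;
  CopyElements(v, &w);               // element-wise path
}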
diff --git a/tensorflow/core/lib/gtl/flatmap_test.cc b/tensorflow/core/lib/gtl/flatmap_test.cc
index bb65e5357a..0901eba926 100644
--- a/tensorflow/core/lib/gtl/flatmap_test.cc
+++ b/tensorflow/core/lib/gtl/flatmap_test.cc
@@ -321,7 +321,7 @@ TEST(FlatMap, Copy) {
NumMap copy2;
copy2 = src;
EXPECT_EQ(Contents(src), Contents(copy2));
- copy2 = copy2; // Self-assignment
+ copy2 = *&copy2; // Self-assignment, avoiding -Wself-assign.
EXPECT_EQ(Contents(src), Contents(copy2));
}
}
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 62ce70eb6b..2a8b9f9bee 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -27,6 +27,7 @@ namespace tensorflow {
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
+using shape_inference::UnchangedShape;
namespace {
@@ -341,6 +342,50 @@ REGISTER_OP("Pack")
return Status::OK();
});
+REGISTER_OP("DeepCopy")
+ .Input("x: T")
+ .Output("y: T")
+ .Attr("T: type")
+ .SetIsStateful()
+ .SetShapeFn(UnchangedShape);
+
+REGISTER_OP("InplaceUpdate")
+ .Input("x: T")
+ .Input("i: int32")
+ .Input("v: T")
+ .Output("y: T")
+ .Attr("T: type")
+ .SetShapeFn(UnchangedShape);
+
+REGISTER_OP("InplaceAdd")
+ .Input("x: T")
+ .Input("i: int32")
+ .Input("v: T")
+ .Output("y: T")
+ .Attr("T: type")
+ .SetShapeFn(UnchangedShape);
+
+REGISTER_OP("InplaceSub")
+ .Input("x: T")
+ .Input("i: int32")
+ .Input("v: T")
+ .Output("y: T")
+ .Attr("T: type")
+ .SetShapeFn(UnchangedShape);
+
+REGISTER_OP("Empty")
+ .Input("shape: int32")
+ .Output("output: dtype")
+ .Attr("dtype: type")
+ .Attr("init: bool = false")
+ .SetIsStateful()
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle out;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
+ c->set_output(0, out);
+ return Status::OK();
+ });
+
// --------------------------------------------------------------------------
REGISTER_OP("Unpack")
.Input("value: T")
@@ -622,7 +667,7 @@ REGISTER_OP("OnesLike")
.Input("x: T")
.Output("y: T")
.Attr(
- "T: {bfloat16, float, double, int8, uint8, int16, uint16, int32, "
+ "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, int32, "
"int64, complex64, complex128, bool}")
.SetShapeFn(shape_inference::UnchangedShape);
@@ -630,7 +675,9 @@ REGISTER_OP("OnesLike")
REGISTER_OP("Diag")
.Input("diagonal: T")
.Output("output: T")
- .Attr("T: {bfloat16, float, double, int32, int64, complex64, complex128}")
+ .Attr(
+ "T: {bfloat16, half, float, double, int32, int64, complex64, "
+ "complex128}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle in = c->input(0);
TF_RETURN_IF_ERROR(c->WithRankAtLeast(in, 1, &in));
@@ -645,7 +692,9 @@ REGISTER_OP("Diag")
REGISTER_OP("DiagPart")
.Input("input: T")
.Output("diagonal: T")
- .Attr("T: {bfloat16, float, double, int32, int64, complex64, complex128}")
+ .Attr(
+ "T: {bfloat16, half, float, double, int32, int64, complex64, "
+ "complex128}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle in = c->input(0);
if (!c->RankKnown(in)) {
@@ -789,7 +838,7 @@ REGISTER_OP("ReverseV2")
.Output("output: T")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr(
- "T: {uint8, int8, uint16, int16, int32, int64, bool, half, bfloat16, "
+ "T: {uint8, int8, uint16, int16, int32, int64, bool, bfloat16, half, "
"float, double, complex64, complex128, string}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input = c->input(0);
@@ -1165,7 +1214,7 @@ REGISTER_OP("PreventGradient")
REGISTER_OP("CheckNumerics")
.Input("tensor: T")
.Output("output: T")
- .Attr("T: {half, bfloat16, float, double}")
+ .Attr("T: {bfloat16, half, float, double}")
.Attr("message: string")
.SetShapeFn(shape_inference::UnchangedShape);
@@ -2450,13 +2499,12 @@ REGISTER_OP("Bitcast")
.Output("output: type")
// All supported dtypes are listed here to include qint16 and quint16.
.Attr(
- "T: {bfloat16, float, double, int64, int32, uint8, uint16, int8, int16,"
- " complex64, complex128, qint8, quint8, qint16, quint16, qint32,"
- " half}")
+ "T: {bfloat16, half, float, double, int64, int32, uint8, uint16, int8, "
+ "int16, complex64, complex128, qint8, quint8, qint16, quint16, qint32}")
.Attr(
- "type: {bfloat16, float, double, int64, int32, uint8, uint16, int8, "
- "int16, complex64, complex128, qint8, quint8, qint16, quint16, qint32,"
- " half}")
+ "type: {bfloat16, half, float, double, int64, int32, uint8, uint16, "
+ "int8, int16, complex64, complex128, qint8, quint8, qint16, quint16, "
+ "qint32}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input = c->input(0);
if (!c->RankKnown(input)) {
@@ -2552,7 +2600,7 @@ REGISTER_OP("QuantizeAndDequantize")
.Attr("input_min: float = 0")
.Attr("input_max: float = 0")
.Output("output: T")
- .Attr("T: {bfloat16, float, double}")
+ .Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.Deprecated(22, "Replaced by QuantizeAndDequantizeV2");
@@ -2565,7 +2613,7 @@ REGISTER_OP("QuantizeAndDequantizeV2")
.Attr("num_bits: int = 8")
.Attr("range_given: bool = false")
.Output("output: T")
- .Attr("T: {bfloat16, float, double}")
+ .Attr("T: {bfloat16, half, float, double}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
@@ -2582,7 +2630,7 @@ REGISTER_OP("QuantizeAndDequantizeV3")
.Attr("signed_input: bool = true")
.Attr("range_given: bool = true")
.Output("output: T")
- .Attr("T: {bfloat16, float, double}")
+ .Attr("T: {bfloat16, half, float, double}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
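The registrations above give DeepCopy and the three Inplace ops UnchangedShape, so y always has x's shape, and the kernels shown earlier read 'v' as one row of x per entry of 'i'; Empty instead derives its output shape from the host-resident shape input. A small validation-style sketch of the Inplace shape contract (illustrative helper, not shape-inference code):

#include <cstddef>
#include <cstdint>
#include <vector>

// For y = InplaceUpdate/Add/Sub(x, i, v): y has x's shape, and v is
// expected to be shaped [len(i)] + x.shape[1:].
bool InplaceShapesOk(const std::vector<int64_t>& x_shape, int64_t num_indices,
                     const std::vector<int64_t>& v_shape) {
  if (x_shape.empty() || v_shape.size() != x_shape.size()) return false;
  if (v_shape[0] != num_indices) return false;
  for (std::size_t d = 1; d < x_shape.size(); ++d) {
    if (v_shape[d] != x_shape[d]) return false;
  }
  return true;
}

int main() {
  return InplaceShapesOk({4, 3}, /*num_indices=*/2, {2, 3}) ? 0 : 1;
}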
diff --git a/tensorflow/core/ops/collective_ops.cc b/tensorflow/core/ops/collective_ops.cc
new file mode 100644
index 0000000000..d6157a69df
--- /dev/null
+++ b/tensorflow/core/ops/collective_ops.cc
@@ -0,0 +1,55 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("CollectiveReduce")
+ .Input("input: T")
+ .Output("data: T")
+ .Attr("T: {float, float16, float64, int32, int64}")
+ .Attr("group_size: int")
+ .Attr("group_key: int")
+ .Attr("instance_key: int")
+ .Attr("merge_op: {'Min', 'Max', 'Mul', 'Add'}")
+ .Attr("final_op: {'Id', 'Div'}")
+ .Attr("subdiv_offsets: list(int)")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::UnchangedShape);
+
+REGISTER_OP("CollectiveBcastSend")
+ .Input("input: T")
+ .Output("data: T")
+ .Attr("T: {float, float16, float64, int32, int64}")
+ .Attr("group_size: int")
+ .Attr("group_key: int")
+ .Attr("instance_key: int")
+ .Attr("shape: shape")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ExplicitShape);
+
+REGISTER_OP("CollectiveBcastRecv")
+ .Output("data: T")
+ .Attr("T: {float, float16, float64, int32, int64}")
+ .Attr("group_size: int")
+ .Attr("group_key: int")
+ .Attr("instance_key: int")
+ .Attr("shape: shape")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ExplicitShape);
+
+} // namespace tensorflow
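CollectiveReduce combines one contribution per group member with merge_op and then applies final_op to the merged result; merge_op='Add' followed by final_op='Div' is the usual way to express a group mean. A standalone sketch of that two-stage reduction for a single scalar (illustrative only):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Reduce one scalar across 'contributions' (one per group member, assumed
// non-empty) with merge_op in {Min, Max, Mul, Add}, then apply final_op in
// {Id, Div}, where Div divides by the group size.
float CollectiveReduceScalar(const std::vector<float>& contributions,
                             const std::string& merge_op,
                             const std::string& final_op) {
  float acc = contributions.front();
  for (std::size_t i = 1; i < contributions.size(); ++i) {
    const float x = contributions[i];
    if (merge_op == "Min")      acc = std::min(acc, x);
    else if (merge_op == "Max") acc = std::max(acc, x);
    else if (merge_op == "Mul") acc *= x;
    else                        acc += x;  // "Add"
  }
  if (final_op == "Div") acc /= static_cast<float>(contributions.size());
  return acc;
}

int main() {
  assert(CollectiveReduceScalar({2.f, 4.f, 6.f}, "Add", "Div") == 4.f);
}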
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 10b24c2d34..026bfa89cf 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -65,6 +65,31 @@ op {
}
}
op {
+ name: "Abs"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "AccumulateNV2"
input_arg {
name: "inputs"
@@ -608,6 +633,33 @@ op {
}
}
op {
+ name: "Acos"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "Acosh"
input_arg {
name: "x"
@@ -657,6 +709,31 @@ op {
}
}
op {
+ name: "Acosh"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "Add"
input_arg {
name: "x"
@@ -726,6 +803,41 @@ op {
}
}
op {
+ name: "Add"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_STRING
+ }
+ }
+ }
+}
+op {
name: "AddManySparseToTensorsMap"
input_arg {
name: "sparse_indices"
@@ -1095,6 +1207,42 @@ op {
is_commutative: true
}
op {
+ name: "AddV2"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ is_aggregate: true
+ is_commutative: true
+}
+op {
name: "AdjustContrast"
input_arg {
name: "images"
@@ -6167,6 +6315,33 @@ op {
}
}
op {
+ name: "Asin"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "Asinh"
input_arg {
name: "x"
@@ -6216,6 +6391,31 @@ op {
}
}
op {
+ name: "Asinh"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "Assert"
input_arg {
name: "condition"
@@ -6762,6 +6962,33 @@ op {
}
}
op {
+ name: "Atan"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "Atan2"
input_arg {
name: "y"
@@ -6813,6 +7040,33 @@ op {
}
}
op {
+ name: "Atan2"
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "Atanh"
input_arg {
name: "x"
@@ -6862,6 +7116,31 @@ op {
}
}
op {
+ name: "Atanh"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "AudioSpectrogram"
input_arg {
name: "input"
@@ -8329,6 +8608,50 @@ op {
}
}
op {
+ name: "BatchMatMul"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ attr {
+ name: "adj_x"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "adj_y"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+}
+op {
name: "BatchMatrixBandPart"
input_arg {
name: "input"
@@ -10155,6 +10478,67 @@ op {
}
}
op {
+ name: "Bitcast"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT16
+ type: DT_QUINT16
+ type: DT_QINT32
+ }
+ }
+ }
+ attr {
+ name: "type"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT16
+ type: DT_QUINT16
+ type: DT_QINT32
+ }
+ }
+ }
+}
+op {
name: "BitwiseAnd"
input_arg {
name: "x"
@@ -11082,6 +11466,29 @@ op {
}
}
op {
+ name: "Ceil"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "CheckNumerics"
input_arg {
name: "tensor"
@@ -11135,6 +11542,33 @@ op {
}
}
op {
+ name: "CheckNumerics"
+ input_arg {
+ name: "tensor"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "message"
+ type: "string"
+ }
+}
+op {
name: "Cholesky"
input_arg {
name: "input"
@@ -11212,6 +11646,147 @@ op {
is_stateful: true
}
op {
+ name: "CollectiveBcastRecv"
+ output_arg {
+ name: "data"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_HALF
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "group_size"
+ type: "int"
+ }
+ attr {
+ name: "group_key"
+ type: "int"
+ }
+ attr {
+ name: "instance_key"
+ type: "int"
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+ is_stateful: true
+}
+op {
+ name: "CollectiveBcastSend"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "data"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_HALF
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "group_size"
+ type: "int"
+ }
+ attr {
+ name: "group_key"
+ type: "int"
+ }
+ attr {
+ name: "instance_key"
+ type: "int"
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+ is_stateful: true
+}
+op {
+ name: "CollectiveReduce"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "data"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_HALF
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "group_size"
+ type: "int"
+ }
+ attr {
+ name: "group_key"
+ type: "int"
+ }
+ attr {
+ name: "instance_key"
+ type: "int"
+ }
+ attr {
+ name: "merge_op"
+ type: "string"
+ allowed_values {
+ list {
+ s: "Min"
+ s: "Max"
+ s: "Mul"
+ s: "Add"
+ }
+ }
+ }
+ attr {
+ name: "final_op"
+ type: "string"
+ allowed_values {
+ list {
+ s: "Id"
+ s: "Div"
+ }
+ }
+ }
+ attr {
+ name: "subdiv_offsets"
+ type: "list(int)"
+ }
+ is_stateful: true
+}
+op {
name: "CompareAndBitpack"
input_arg {
name: "input"
@@ -13319,6 +13894,31 @@ op {
}
}
op {
+ name: "Cos"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "Cosh"
input_arg {
name: "x"
@@ -13368,6 +13968,31 @@ op {
}
}
op {
+ name: "Cosh"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "CountUpTo"
input_arg {
name: "ref"
@@ -15907,6 +16532,55 @@ op {
}
}
op {
+ name: "DecodeProtoV2"
+ input_arg {
+ name: "bytes"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "sizes"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "values"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "message_type"
+ type: "string"
+ }
+ attr {
+ name: "field_names"
+ type: "list(string)"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "descriptor_source"
+ type: "string"
+ default_value {
+ s: "local://"
+ }
+ }
+ attr {
+ name: "message_format"
+ type: "string"
+ default_value {
+ s: "binary"
+ }
+ }
+ attr {
+ name: "sanitize"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+}
+op {
name: "DecodeRaw"
input_arg {
name: "bytes"
@@ -16005,6 +16679,22 @@ op {
}
}
op {
+ name: "DeepCopy"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ is_stateful: true
+}
+op {
name: "DeleteSessionTensor"
input_arg {
name: "handle"
@@ -17055,6 +17745,33 @@ op {
}
}
op {
+ name: "Diag"
+ input_arg {
+ name: "diagonal"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "DiagPart"
input_arg {
name: "input"
@@ -17106,6 +17823,33 @@ op {
}
}
op {
+ name: "DiagPart"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "diagonal"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "Digamma"
input_arg {
name: "x"
@@ -17151,6 +17895,29 @@ op {
}
}
op {
+ name: "Digamma"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "Dilation2D"
input_arg {
name: "input"
@@ -17924,6 +18691,41 @@ op {
}
}
op {
+ name: "Div"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "DrawBoundingBoxes"
input_arg {
name: "images"
@@ -18219,6 +19021,29 @@ op {
}
}
op {
+ name: "Empty"
+ input_arg {
+ name: "shape"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+ attr {
+ name: "init"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ is_stateful: true
+}
+op {
name: "EmptyTensorList"
input_arg {
name: "element_shape"
@@ -18380,6 +19205,42 @@ op {
}
}
op {
+ name: "EncodeProto"
+ input_arg {
+ name: "sizes"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "values"
+ type_list_attr: "Tinput_types"
+ }
+ output_arg {
+ name: "bytes"
+ type: DT_STRING
+ }
+ attr {
+ name: "field_names"
+ type: "list(string)"
+ }
+ attr {
+ name: "message_type"
+ type: "string"
+ }
+ attr {
+ name: "descriptor_source"
+ type: "string"
+ default_value {
+ s: "local://"
+ }
+ }
+ attr {
+ name: "Tinput_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "EncodeWav"
input_arg {
name: "audio"
@@ -18525,6 +19386,46 @@ op {
is_commutative: true
}
op {
+ name: "Equal"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type: DT_BOOL
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_QUINT8
+ type: DT_QINT8
+ type: DT_QINT32
+ type: DT_STRING
+ type: DT_BOOL
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ is_commutative: true
+}
+op {
name: "Erf"
input_arg {
name: "x"
@@ -18570,6 +19471,29 @@ op {
}
}
op {
+ name: "Erf"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "Erfc"
input_arg {
name: "x"
@@ -18615,6 +19539,29 @@ op {
}
}
op {
+ name: "Erfc"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "Exit"
input_arg {
name: "data"
@@ -18679,6 +19626,31 @@ op {
}
}
op {
+ name: "Exp"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "ExpandDims"
input_arg {
name: "input"
@@ -18760,6 +19732,31 @@ op {
}
}
op {
+ name: "Expm1"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "ExtractGlimpse"
input_arg {
name: "input"
@@ -20540,6 +21537,29 @@ op {
}
}
op {
+ name: "Floor"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "FloorDiv"
input_arg {
name: "x"
@@ -20609,6 +21629,41 @@ op {
}
}
op {
+ name: "FloorDiv"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "FloorMod"
input_arg {
name: "x"
@@ -20664,6 +21719,35 @@ op {
}
}
op {
+ name: "FloorMod"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "FlushSummaryWriter"
input_arg {
name: "writer"
@@ -23249,6 +24333,75 @@ op {
is_stateful: true
}
op {
+ name: "InplaceAdd"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "i"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "v"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+}
+op {
+ name: "InplaceSub"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "i"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "v"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+}
+op {
+ name: "InplaceUpdate"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "i"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "v"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+}
+op {
name: "InterleaveDataset"
input_arg {
name: "input_dataset"
@@ -23481,6 +24634,33 @@ op {
}
}
op {
+ name: "Inv"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "InvGrad"
input_arg {
name: "x"
@@ -23665,6 +24845,35 @@ op {
}
}
op {
+ name: "InvGrad"
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "dy"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "Invert"
input_arg {
name: "x"
@@ -23798,6 +25007,29 @@ op {
}
}
op {
+ name: "IsFinite"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type: DT_BOOL
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "IsInf"
input_arg {
name: "x"
@@ -23843,6 +25075,29 @@ op {
}
}
op {
+ name: "IsInf"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type: DT_BOOL
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "IsNan"
input_arg {
name: "x"
@@ -23888,6 +25143,29 @@ op {
}
}
op {
+ name: "IsNan"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type: DT_BOOL
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "IsVariableInitialized"
input_arg {
name: "ref"
@@ -24850,6 +26128,29 @@ op {
}
}
op {
+ name: "Lgamma"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "LinSpace"
input_arg {
name: "start"
@@ -25066,6 +26367,31 @@ op {
}
}
op {
+ name: "Log"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "Log1p"
input_arg {
name: "x"
@@ -25115,6 +26441,31 @@ op {
}
}
op {
+ name: "Log1p"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "LogMatrixDeterminant"
input_arg {
name: "input"
@@ -26130,6 +27481,50 @@ op {
}
}
op {
+ name: "MatMul"
+ input_arg {
+ name: "a"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "b"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "product"
+ type_attr: "T"
+ }
+ attr {
+ name: "transpose_a"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "transpose_b"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "MatchingFiles"
input_arg {
name: "pattern"
@@ -30008,6 +31403,36 @@ op {
is_commutative: true
}
op {
+ name: "Maximum"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ is_commutative: true
+}
+op {
name: "Mean"
input_arg {
name: "input"
@@ -30663,6 +32088,36 @@ op {
is_commutative: true
}
op {
+ name: "Minimum"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ is_commutative: true
+}
+op {
name: "MirrorPad"
input_arg {
name: "input"
@@ -30802,6 +32257,36 @@ op {
}
}
op {
+ name: "Mod"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_HALF
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "Mul"
input_arg {
name: "x"
@@ -30873,6 +32358,42 @@ op {
is_commutative: true
}
op {
+ name: "Mul"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ is_commutative: true
+}
+op {
name: "Multinomial"
input_arg {
name: "logits"
@@ -31475,6 +32996,33 @@ op {
}
}
op {
+ name: "Neg"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "NegTrain"
input_arg {
name: "w_in"
@@ -31658,6 +33206,46 @@ op {
is_commutative: true
}
op {
+ name: "NotEqual"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type: DT_BOOL
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_QUINT8
+ type: DT_QINT8
+ type: DT_QINT32
+ type: DT_STRING
+ type: DT_BOOL
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ is_commutative: true
+}
+op {
name: "NthElement"
input_arg {
name: "input"
@@ -31925,6 +33513,38 @@ op {
}
}
op {
+ name: "OnesLike"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT8
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_UINT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_BOOL
+ }
+ }
+ }
+}
+op {
name: "OrderedMapClear"
attr {
name: "capacity"
@@ -33435,6 +35055,37 @@ op {
}
}
op {
+ name: "Pow"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_HALF
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "PrefetchDataset"
input_arg {
name: "input_dataset"
@@ -34289,6 +35940,117 @@ op {
}
}
op {
+ name: "QuantizeAndDequantize"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "signed_input"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ attr {
+ name: "num_bits"
+ type: "int"
+ default_value {
+ i: 8
+ }
+ }
+ attr {
+ name: "range_given"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "input_min"
+ type: "float"
+ default_value {
+ f: 0
+ }
+ }
+ attr {
+ name: "input_max"
+ type: "float"
+ default_value {
+ f: 0
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ deprecation {
+ version: 22
+ }
+}
+op {
+ name: "QuantizeAndDequantizeV2"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "input_min"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "input_max"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "signed_input"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ attr {
+ name: "num_bits"
+ type: "int"
+ default_value {
+ i: 8
+ }
+ }
+ attr {
+ name: "range_given"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "QuantizeAndDequantizeV2"
input_arg {
name: "input"
@@ -34332,6 +36094,7 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -34383,6 +36146,7 @@ op {
allowed_values {
list {
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -34485,6 +36249,55 @@ op {
}
}
op {
+ name: "QuantizeAndDequantizeV3"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "input_min"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "input_max"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "num_bits"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "signed_input"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ attr {
+ name: "range_given"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "QuantizeDownAndShrinkRange"
input_arg {
name: "input"
@@ -38531,6 +40344,41 @@ op {
}
}
op {
+ name: "RealDiv"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "Reciprocal"
input_arg {
name: "x"
@@ -38584,6 +40432,33 @@ op {
}
}
op {
+ name: "Reciprocal"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "ReciprocalGrad"
input_arg {
name: "x"
@@ -38669,6 +40544,35 @@ op {
}
}
op {
+ name: "ReciprocalGrad"
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "dy"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "RecordInput"
output_arg {
name: "records"
@@ -48267,6 +50171,56 @@ op {
}
}
op {
+ name: "ReverseV2"
+ input_arg {
+ name: "tensor"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "axis"
+ type_attr: "Tidx"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "Tidx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_BOOL
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_STRING
+ }
+ }
+ }
+}
+op {
name: "RightShift"
input_arg {
name: "x"
@@ -48342,6 +50296,29 @@ op {
}
}
op {
+ name: "Rint"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "Roll"
input_arg {
name: "input"
@@ -48438,6 +50415,74 @@ op {
}
}
op {
+ name: "Round"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
+ name: "Rpc"
+ input_arg {
+ name: "address"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "method"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "request"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "response"
+ type: DT_STRING
+ }
+ attr {
+ name: "protocol"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "fail_fast"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ attr {
+ name: "timeout_in_ms"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ is_stateful: true
+}
+op {
name: "Rsqrt"
input_arg {
name: "x"
@@ -48487,6 +50532,31 @@ op {
}
}
op {
+ name: "Rsqrt"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "RsqrtGrad"
input_arg {
name: "x"
@@ -48572,6 +50642,35 @@ op {
}
}
op {
+ name: "RsqrtGrad"
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "dy"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "SampleDistortedBoundingBox"
input_arg {
name: "image_size"
@@ -52768,6 +54867,31 @@ op {
}
}
op {
+ name: "Sigmoid"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "SigmoidGrad"
input_arg {
name: "x"
@@ -52853,6 +54977,35 @@ op {
}
}
op {
+ name: "SigmoidGrad"
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "dy"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "Sign"
input_arg {
name: "x"
@@ -52906,6 +55059,33 @@ op {
}
}
op {
+ name: "Sign"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "Sin"
input_arg {
name: "x"
@@ -52955,6 +55135,31 @@ op {
}
}
op {
+ name: "Sin"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "Sinh"
input_arg {
name: "x"
@@ -53004,6 +55209,31 @@ op {
}
}
op {
+ name: "Sinh"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "Size"
input_arg {
name: "input"
@@ -62162,6 +64392,31 @@ op {
}
}
op {
+ name: "Sqrt"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "SqrtGrad"
input_arg {
name: "x"
@@ -62247,6 +64502,35 @@ op {
}
}
op {
+ name: "SqrtGrad"
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "dy"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "Square"
input_arg {
name: "x"
@@ -62300,6 +64584,33 @@ op {
}
}
op {
+ name: "Square"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "SquaredDifference"
input_arg {
name: "x"
@@ -62363,6 +64674,38 @@ op {
is_commutative: true
}
op {
+ name: "SquaredDifference"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ is_commutative: true
+}
+op {
name: "Squeeze"
input_arg {
name: "input"
@@ -63580,6 +65923,41 @@ op {
}
}
op {
+ name: "Sub"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "Substr"
input_arg {
name: "input"
@@ -64220,6 +66598,33 @@ op {
}
}
op {
+ name: "Tan"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "Tanh"
input_arg {
name: "x"
@@ -64269,6 +66674,31 @@ op {
}
}
op {
+ name: "Tanh"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "TanhGrad"
input_arg {
name: "x"
@@ -64354,6 +66784,35 @@ op {
}
}
op {
+ name: "TanhGrad"
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "dy"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "TemporaryVariable"
output_arg {
name: "ref"
@@ -66631,6 +69090,41 @@ op {
}
}
op {
+ name: "TruncateDiv"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "TruncateMod"
input_arg {
name: "x"
@@ -66686,6 +69180,35 @@ op {
}
}
op {
+ name: "TruncateMod"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "TruncatedNormal"
input_arg {
name: "shape"
@@ -66781,6 +69304,55 @@ op {
is_stateful: true
}
op {
+ name: "TryRpc"
+ input_arg {
+ name: "address"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "method"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "request"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "response"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "status_code"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "status_message"
+ type: DT_STRING
+ }
+ attr {
+ name: "protocol"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "fail_fast"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ attr {
+ name: "timeout_in_ms"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ is_stateful: true
+}
+op {
name: "Unbatch"
input_arg {
name: "batched_tensor"
diff --git a/tensorflow/core/ops/decode_proto_ops.cc b/tensorflow/core/ops/decode_proto_ops.cc
new file mode 100644
index 0000000000..3f6fb2f582
--- /dev/null
+++ b/tensorflow/core/ops/decode_proto_ops.cc
@@ -0,0 +1,67 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using tensorflow::shape_inference::InferenceContext;
+using tensorflow::shape_inference::ShapeHandle;
+
+REGISTER_OP("DecodeProtoV2")
+ .Input("bytes: string")
+ .Attr("message_type: string")
+ .Attr("field_names: list(string)")
+ .Attr("output_types: list(type) >= 0")
+ .Attr("descriptor_source: string = 'local://'")
+ .Attr("message_format: string = 'binary'")
+ .Attr("sanitize: bool = false")
+ .Output("sizes: int32")
+ .Output("values: output_types")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle input = c->input(0);
+
+ std::vector<tensorflow::DataType> output_types;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_types", &output_types));
+
+ ShapeHandle sizes;
+ TF_RETURN_IF_ERROR(
+ c->Concatenate(input, c->Vector(output_types.size()), &sizes));
+ c->set_output(0, sizes);
+
+ // TODO(nix): to do the best possible job of shape inference, we
+ // should examine the proto descriptors here in order to set shape
+ // indices to 1 instead of unknown for optional or required fields.
+ // Any general-purpose code will have to handle the unknown case,
+ // but there might be XLA code that could be sped up with the additional
+ // knowledge.
+ for (int i = 0; i < output_types.size(); ++i) {
+ ShapeHandle values;
+ TF_RETURN_IF_ERROR(
+ c->Concatenate(input, c->Vector(c->UnknownDim()), &values));
+ c->set_output(i + 1, values);
+ }
+
+ return Status::OK();
+ });
+
+// TODO(nix): Consider adding an additional input argument that truncates
+// repeated fields to a maximum count. For now this could be done by passing
+// the output through tf.slice.
+
+// TODO(nix): define missing value behavior.
+
+} // namespace tensorflow
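The shape function above only fixes the trailing dimensions: for a bytes input of shape input_shape and N requested fields, sizes gets shape input_shape + [N] and each values output gets input_shape + [unknown]. A minimal standalone sketch of that rule, using plain containers rather than the TF shape-inference API (purely illustrative; -1 marks an unknown dimension):

#include <vector>

// Illustrative mirror of the DecodeProtoV2 shape rule (not the TF API):
//   sizes     : input_shape + [num_fields]
//   values[i] : input_shape + [unknown]
std::vector<long long> DecodeProtoSizesShape(std::vector<long long> input_shape,
                                             int num_fields) {
  input_shape.push_back(num_fields);
  return input_shape;
}

std::vector<long long> DecodeProtoValuesShape(std::vector<long long> input_shape) {
  input_shape.push_back(-1);  // data-dependent per-batch field count
  return input_shape;
}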
diff --git a/tensorflow/core/ops/encode_proto_ops.cc b/tensorflow/core/ops/encode_proto_ops.cc
new file mode 100644
index 0000000000..f5ec3056e3
--- /dev/null
+++ b/tensorflow/core/ops/encode_proto_ops.cc
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using tensorflow::shape_inference::InferenceContext;
+using tensorflow::shape_inference::ShapeHandle;
+
+REGISTER_OP("EncodeProto")
+ .Input("sizes: int32")
+ .Input("values: Tinput_types")
+ .Attr("field_names: list(string)")
+ .Attr("message_type: string")
+ .Attr("descriptor_source: string = 'local://'")
+ .Attr("Tinput_types: list(type)")
+ .Output("bytes: string")
+ .SetShapeFn([](InferenceContext* c) {
+ int first_field_index = 1;
+ int num_fields = c->num_inputs() - 1;
+
+ ShapeHandle output;
+ for (int i = num_fields - 1; i >= 0; --i) {
+ ShapeHandle input = c->input(first_field_index + i);
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &input));
+ ShapeHandle inner;
+ TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &inner));
+ TF_RETURN_IF_ERROR(c->Merge(inner, output, &output));
+ }
+
+ c->set_output(0, output);
+ return Status::OK();
+ });
+
+} // namespace tensorflow
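EncodeProto is the inverse: each values input must have rank at least 2, its last dimension is dropped, and all of the resulting prefixes must agree; that common prefix becomes the shape of bytes. A minimal standalone sketch of that rule, again with plain containers instead of the TF shape-inference API:

#include <stdexcept>
#include <vector>

// Illustrative mirror of the EncodeProto shape rule (not the TF API): drop
// the last dimension of every values input and require the prefixes to match.
std::vector<long long> EncodeProtoBytesShape(
    const std::vector<std::vector<long long>>& value_shapes) {
  std::vector<long long> batch_shape;
  bool first = true;
  for (const auto& shape : value_shapes) {
    if (shape.size() < 2)
      throw std::invalid_argument("values inputs must have rank >= 2");
    std::vector<long long> prefix(shape.begin(), shape.end() - 1);
    if (!first && prefix != batch_shape)
      throw std::invalid_argument("mismatched batch shapes across fields");
    batch_shape = prefix;
    first = false;
  }
  return batch_shape;  // shape of the serialized "bytes" output
}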
diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc
index cad617638f..c151055ee6 100644
--- a/tensorflow/core/ops/list_ops.cc
+++ b/tensorflow/core/ops/list_ops.cc
@@ -30,7 +30,8 @@ REGISTER_OP("EmptyTensorList")
DataType t;
TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
shape_inference::ShapeHandle s;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(
+ c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(0, &s));
c->set_output_handle_shapes_and_types(
0, std::vector<shape_inference::ShapeAndType>{{s, t}});
return Status::OK();
@@ -193,6 +194,7 @@ REGISTER_OP("TensorListReserve")
.Attr("element_dtype: type")
.Attr("shape_type: {int32, int64}")
.SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->Scalar());
shape_inference::ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
DataType t;
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 4548e59fbf..8f8443a46c 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -65,7 +65,7 @@ REGISTER_OP("BatchMatMul")
.Input("x: T")
.Input("y: T")
.Output("output: T")
- .Attr("T: {half, bfloat16, float, double, int32, complex64, complex128}")
+ .Attr("T: {bfloat16, half, float, double, int32, complex64, complex128}")
.Attr("adj_x: bool = false")
.Attr("adj_y: bool = false")
.SetShapeFn([](InferenceContext* c) {
@@ -133,7 +133,7 @@ _HostCast requires its input and produces its output in host memory.
REGISTER_OP("Abs")
.Input("x: T")
.Output("y: T")
- .Attr("T: {half, bfloat16, float, double, int32, int64}")
+ .Attr("T: {bfloat16, half, float, double, int32, int64}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("ComplexAbs")
@@ -148,27 +148,27 @@ REGISTER_OP("ComplexAbs")
Input("x: T") \
.Output("y: T") \
.Attr( \
- "T: {half, bfloat16, float, double, int32, int64, complex64, " \
+ "T: {bfloat16, half, float, double, int32, int64, complex64, " \
"complex128}") \
.SetShapeFn(shape_inference::UnchangedShape)
#define UNARY_REAL() \
Input("x: T") \
.Output("y: T") \
- .Attr("T: {half, bfloat16, float, double}") \
+ .Attr("T: {bfloat16, half, float, double}") \
.SetShapeFn(shape_inference::UnchangedShape)
#define UNARY_COMPLEX() \
Input("x: T") \
.Output("y: T") \
- .Attr("T: {half, bfloat16, float, double, complex64, complex128}") \
+ .Attr("T: {bfloat16, half, float, double, complex64, complex128}") \
.SetShapeFn(shape_inference::UnchangedShape)
#define UNARY_GRADIENT_COMPLEX() \
Input("y: T") \
.Input("dy: T") \
.Output("z: T") \
- .Attr("T: {half, bfloat16, float, double, complex64, complex128}") \
+ .Attr("T: {bfloat16, half, float, double, complex64, complex128}") \
.SetShapeFn(shape_inference::UnchangedShape)
REGISTER_OP("Neg").UNARY();
@@ -246,57 +246,57 @@ REGISTER_OP("Atan").UNARY();
REGISTER_OP("IsNan")
.Input("x: T")
.Output("y: bool")
- .Attr("T: {half, bfloat16, float, double}")
+ .Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("IsInf")
.Input("x: T")
.Output("y: bool")
- .Attr("T: {half, bfloat16, float, double}")
+ .Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("IsFinite")
.Input("x: T")
.Output("y: bool")
- .Attr("T: {half, bfloat16, float, double}")
+ .Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("Sign")
.Input("x: T")
.Output("y: T")
.Attr(
- "T: {half, bfloat16, float, double, int32, int64, complex64, "
+ "T: {bfloat16, half, float, double, int32, int64, complex64, "
"complex128}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("Floor")
.Input("x: T")
.Output("y: T")
- .Attr("T: {half, bfloat16, float, double}")
+ .Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("Ceil")
.Input("x: T")
.Output("y: T")
- .Attr("T: {half, bfloat16, float, double}")
+ .Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("Rint")
.Input("x: T")
.Output("y: T")
- .Attr("T: {bfloat16, float, double}")
+ .Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
// Declares cwise binary operations signature: 't, 't -> 't.
#define BINARY_MORE() \
Input("x: T").Input("y: T").Output("z: T").Attr( \
- "T: {half, bfloat16, float, double, uint8, int8, uint16, int16, int32, " \
+ "T: {bfloat16, half, float, double, uint8, int8, uint16, int16, int32, " \
"int64, complex64, complex128}")
#define BINARY_FEWER() \
Input("x: T").Input("y: T").Output("z: T").Attr( \
- "T: {half, bfloat16, float, double, int32, int64, complex64, " \
+ "T: {bfloat16, half, float, double, int32, int64, complex64, " \
"complex128}")
REGISTER_OP("Add")
@@ -304,7 +304,7 @@ REGISTER_OP("Add")
.Input("y: T")
.Output("z: T")
.Attr(
- "T: {half, bfloat16, float, double, uint8, int8, int16, int32, int64, "
+ "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
"complex64, complex128, string}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
@@ -315,7 +315,7 @@ REGISTER_OP("AddV2")
.Input("y: T")
.Output("z: T")
.Attr(
- "T: {half, bfloat16, float, double, uint8, int8, int16, int32, int64, "
+ "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
"complex64, complex128}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.SetIsAggregate()
@@ -412,7 +412,7 @@ REGISTER_OP("Maximum")
.Input("x: T")
.Input("y: T")
.Output("z: T")
- .Attr("T: {half, bfloat16, float, double, int32, int64}")
+ .Attr("T: {bfloat16, half, float, double, int32, int64}")
.SetIsCommutative()
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
@@ -437,7 +437,7 @@ REGISTER_OP("Minimum")
.Input("x: T")
.Input("y: T")
.Output("z: T")
- .Attr("T: {half, bfloat16, float, double, int32, int64}")
+ .Attr("T: {bfloat16, half, float, double, int32, int64}")
.SetIsCommutative()
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
@@ -445,21 +445,21 @@ REGISTER_OP("Mod")
.Input("x: T")
.Input("y: T")
.Output("z: T")
- .Attr("T: {int32, int64, bfloat16, float, double}")
+ .Attr("T: {int32, int64, float16, half, bfloat16, float, double}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
REGISTER_OP("FloorMod")
.Input("x: T")
.Input("y: T")
.Output("z: T")
- .Attr("T: {int32, int64, bfloat16, float, double}")
+ .Attr("T: {int32, int64, bfloat16, half, float, double}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
REGISTER_OP("TruncateMod")
.Input("x: T")
.Input("y: T")
.Output("z: T")
- .Attr("T: {int32, int64, bfloat16, float, double}")
+ .Attr("T: {int32, int64, bfloat16, half, float, double}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
REGISTER_OP("Pow")
@@ -467,7 +467,7 @@ REGISTER_OP("Pow")
.Input("y: T")
.Output("z: T")
.Attr(
- "T: {half, bfloat16, float, double, int32, int64, complex64, "
+ "T: {bfloat16, float, half, double, int32, int64, complex64, "
"complex128}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
@@ -503,7 +503,7 @@ REGISTER_OP("Atan2")
.Input("y: T")
.Input("x: T")
.Output("z: T")
- .Attr("T: {bfloat16, float, double}")
+ .Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
REGISTER_OP("Betainc")
@@ -574,7 +574,7 @@ REGISTER_OP("GreaterEqual").COMPARISON();
.Output("z: bool") \
.SetIsCommutative() \
.Attr( \
- "T: {half, bfloat16, float, double, uint8, int8, int16, int32, " \
+ "T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \
"int64, complex64, quint8, qint8, qint32, string, bool, " \
"complex128}") \
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
@@ -713,7 +713,7 @@ REGISTER_OP("MatMul")
.Output("product: T")
.Attr("transpose_a: bool = false")
.Attr("transpose_b: bool = false")
- .Attr("T: {half, bfloat16, float, double, int32, complex64, complex128}")
+ .Attr("T: {bfloat16, half, float, double, int32, complex64, complex128}")
.SetShapeFn(shape_inference::MatMulShape);
REGISTER_OP("SparseMatMul")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 5764976aee..b61a3b0e64 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -30,8 +30,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
@@ -210,8 +210,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
@@ -237,8 +237,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -266,8 +266,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_UINT8
@@ -423,8 +423,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_UINT8
@@ -1932,8 +1932,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
@@ -1959,8 +1959,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -2191,8 +2191,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
@@ -2223,6 +2223,7 @@ op {
allowed_values {
list {
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -2244,8 +2245,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -3004,8 +3005,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
@@ -3854,6 +3855,7 @@ op {
allowed_values {
list {
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT64
@@ -3869,7 +3871,6 @@ op {
type: DT_QINT16
type: DT_QUINT16
type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -3879,6 +3880,7 @@ op {
allowed_values {
list {
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT64
@@ -3894,7 +3896,6 @@ op {
type: DT_QINT16
type: DT_QUINT16
type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -4637,8 +4638,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -4660,8 +4661,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -4729,6 +4730,147 @@ op {
is_stateful: true
}
op {
+ name: "CollectiveBcastRecv"
+ output_arg {
+ name: "data"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_HALF
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "group_size"
+ type: "int"
+ }
+ attr {
+ name: "group_key"
+ type: "int"
+ }
+ attr {
+ name: "instance_key"
+ type: "int"
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+ is_stateful: true
+}
+op {
+ name: "CollectiveBcastSend"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "data"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_HALF
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "group_size"
+ type: "int"
+ }
+ attr {
+ name: "group_key"
+ type: "int"
+ }
+ attr {
+ name: "instance_key"
+ type: "int"
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+ is_stateful: true
+}
+op {
+ name: "CollectiveReduce"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "data"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_HALF
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "group_size"
+ type: "int"
+ }
+ attr {
+ name: "group_key"
+ type: "int"
+ }
+ attr {
+ name: "instance_key"
+ type: "int"
+ }
+ attr {
+ name: "merge_op"
+ type: "string"
+ allowed_values {
+ list {
+ s: "Min"
+ s: "Max"
+ s: "Mul"
+ s: "Add"
+ }
+ }
+ }
+ attr {
+ name: "final_op"
+ type: "string"
+ allowed_values {
+ list {
+ s: "Id"
+ s: "Div"
+ }
+ }
+ }
+ attr {
+ name: "subdiv_offsets"
+ type: "list(int)"
+ }
+ is_stateful: true
+}
+op {
name: "CompareAndBitpack"
input_arg {
name: "input"
@@ -5759,8 +5901,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -5784,8 +5926,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -7345,6 +7487,55 @@ op {
}
}
op {
+ name: "DecodeProtoV2"
+ input_arg {
+ name: "bytes"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "sizes"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "values"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "message_type"
+ type: "string"
+ }
+ attr {
+ name: "field_names"
+ type: "list(string)"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "descriptor_source"
+ type: "string"
+ default_value {
+ s: "local://"
+ }
+ }
+ attr {
+ name: "message_format"
+ type: "string"
+ default_value {
+ s: "binary"
+ }
+ }
+ attr {
+ name: "sanitize"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+}
+op {
name: "DecodeRaw"
input_arg {
name: "bytes"
@@ -7409,6 +7600,22 @@ op {
}
}
op {
+ name: "DeepCopy"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ is_stateful: true
+}
+op {
name: "DeleteSessionTensor"
input_arg {
name: "handle"
@@ -7960,6 +8167,7 @@ op {
allowed_values {
list {
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
@@ -7986,6 +8194,7 @@ op {
allowed_values {
list {
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
@@ -8011,8 +8220,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -8217,8 +8426,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_UINT8
@@ -8433,6 +8642,29 @@ op {
}
}
op {
+ name: "Empty"
+ input_arg {
+ name: "shape"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+ attr {
+ name: "init"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ is_stateful: true
+}
+op {
name: "EmptyTensorList"
input_arg {
name: "element_shape"
@@ -8594,6 +8826,42 @@ op {
}
}
op {
+ name: "EncodeProto"
+ input_arg {
+ name: "sizes"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "values"
+ type_list_attr: "Tinput_types"
+ }
+ output_arg {
+ name: "bytes"
+ type: DT_STRING
+ }
+ attr {
+ name: "field_names"
+ type: "list(string)"
+ }
+ attr {
+ name: "message_type"
+ type: "string"
+ }
+ attr {
+ name: "descriptor_source"
+ type: "string"
+ default_value {
+ s: "local://"
+ }
+ }
+ attr {
+ name: "Tinput_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "EncodeWav"
input_arg {
name: "audio"
@@ -8678,8 +8946,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_UINT8
@@ -8714,8 +8982,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -8737,8 +9005,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -8775,8 +9043,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -8832,8 +9100,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -9700,8 +9968,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -9727,8 +9995,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_UINT8
@@ -9765,6 +10033,7 @@ op {
type: DT_INT32
type: DT_INT64
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -11622,6 +11891,75 @@ op {
is_stateful: true
}
op {
+ name: "InplaceAdd"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "i"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "v"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+}
+op {
+ name: "InplaceSub"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "i"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "v"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+}
+op {
+ name: "InplaceUpdate"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "i"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "v"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+}
+op {
name: "InterleaveDataset"
input_arg {
name: "input_dataset"
@@ -11680,8 +12018,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
@@ -11711,8 +12049,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -11799,8 +12137,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -11822,8 +12160,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -11845,8 +12183,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -12360,8 +12698,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -12508,8 +12846,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -12533,8 +12871,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -13390,8 +13728,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
@@ -14625,8 +14963,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
@@ -14881,8 +15219,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
@@ -14997,6 +15335,8 @@ op {
list {
type: DT_INT32
type: DT_INT64
+ type: DT_HALF
type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
@@ -15023,8 +15363,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_UINT8
@@ -15445,8 +15785,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
@@ -15581,8 +15921,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_UINT8
@@ -15746,6 +16086,7 @@ op {
allowed_values {
list {
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT8
@@ -17024,9 +17365,9 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
type: DT_FLOAT
+ type: DT_HALF
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
@@ -17456,6 +17797,7 @@ op {
allowed_values {
list {
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -17511,6 +17853,7 @@ op {
allowed_values {
list {
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -17559,6 +17902,7 @@ op {
allowed_values {
list {
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -20106,8 +20450,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_UINT8
@@ -20137,8 +20481,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
@@ -20168,8 +20512,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -23457,8 +23801,8 @@ op {
type: DT_INT32
type: DT_INT64
type: DT_BOOL
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -23516,6 +23860,7 @@ op {
allowed_values {
list {
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -23580,8 +23925,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
@@ -23593,6 +23938,47 @@ op {
}
}
op {
+ name: "Rpc"
+ input_arg {
+ name: "address"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "method"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "request"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "response"
+ type: DT_STRING
+ }
+ attr {
+ name: "protocol"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "fail_fast"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ attr {
+ name: "timeout_in_ms"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ is_stateful: true
+}
+op {
name: "Rsqrt"
input_arg {
name: "x"
@@ -23607,8 +23993,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -23636,8 +24022,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -25487,8 +25873,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -25516,8 +25902,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -25541,8 +25927,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
@@ -25568,8 +25954,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -25593,8 +25979,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -28988,8 +29374,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -29017,8 +29403,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -29042,8 +29428,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
@@ -29073,8 +29459,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
@@ -30022,8 +30408,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_UINT8
@@ -30407,8 +30793,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
@@ -30434,8 +30820,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -30463,8 +30849,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_COMPLEX64
@@ -32085,8 +32471,8 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_UINT8
@@ -32123,6 +32509,7 @@ op {
type: DT_INT32
type: DT_INT64
type: DT_BFLOAT16
+ type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -32178,6 +32565,55 @@ op {
is_stateful: true
}
op {
+ name: "TryRpc"
+ input_arg {
+ name: "address"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "method"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "request"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "response"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "status_code"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "status_message"
+ type: DT_STRING
+ }
+ attr {
+ name: "protocol"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "fail_fast"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ attr {
+ name: "timeout_in_ms"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ is_stateful: true
+}
+op {
name: "Unbatch"
input_arg {
name: "batched_tensor"
diff --git a/tensorflow/core/ops/rpc_ops.cc b/tensorflow/core/ops/rpc_ops.cc
new file mode 100644
index 0000000000..72fda5e6eb
--- /dev/null
+++ b/tensorflow/core/ops/rpc_ops.cc
@@ -0,0 +1,81 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using tensorflow::shape_inference::DimensionHandle;
+using tensorflow::shape_inference::InferenceContext;
+using tensorflow::shape_inference::ShapeHandle;
+
+Status RpcShapeOp(InferenceContext* c, bool try_rpc) {
+ ShapeHandle address;
+ ShapeHandle method;
+ ShapeHandle request;
+ ShapeHandle output;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &address));
+ if (c->Rank(address) == 1) {
+ TF_RETURN_IF_ERROR(c->Merge(output, address, &output));
+ }
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &method));
+ if (c->Rank(method) == 1) {
+ TF_RETURN_IF_ERROR(c->Merge(output, method, &output));
+ }
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &request));
+ if (c->Rank(request) == 1) {
+ TF_RETURN_IF_ERROR(c->Merge(output, request, &output));
+ }
+ if (!c->RankKnown(output)) {
+ output = request;
+ }
+ c->set_output(0, output); // response
+ if (try_rpc) {
+ c->set_output(1, output); // status_code
+ c->set_output(2, output); // status_message
+ }
+ return Status::OK();
+}
+
+REGISTER_OP("Rpc")
+ .Input("address: string")
+ .Input("method: string")
+ .Input("request: string")
+ .Attr("protocol: string = ''")
+ .Attr("fail_fast: bool = true")
+ .Attr("timeout_in_ms: int = 0")
+ .Output("response: string")
+ .SetIsStateful()
+ .SetShapeFn([](InferenceContext* c) {
+ return RpcShapeOp(c, /*try_rpc=*/false);
+ });
+
+REGISTER_OP("TryRpc")
+ .Input("address: string")
+ .Input("method: string")
+ .Input("request: string")
+ .Attr("protocol: string = ''")
+ .Attr("fail_fast: bool = true")
+ .Attr("timeout_in_ms: int = 0")
+ .Output("response: string")
+ .Output("status_code: int32")
+ .Output("status_message: string")
+ .SetIsStateful()
+ .SetShapeFn([](InferenceContext* c) {
+ return RpcShapeOp(c, /*try_rpc=*/true);
+ });
+
+} // namespace tensorflow
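RpcShapeOp effectively lets each of address, method and request be either a scalar or a vector: any rank-1 inputs must agree in length (enforced via Merge), the outputs take that common vector shape, and if nothing pins the rank down the response falls back to the request's shape. A simplified standalone illustration of that batching rule (it ignores unknown ranks and does not use the TF shape-inference API):

#include <optional>
#include <stdexcept>

// Simplified mirror of the Rpc/TryRpc shape rule: scalar inputs impose no
// batch dimension, vector inputs must all have the same length, and the
// outputs take the merged length (nullopt means every input was a scalar).
std::optional<long long> MergedRpcBatchLength(std::optional<long long> address_len,
                                              std::optional<long long> method_len,
                                              std::optional<long long> request_len) {
  std::optional<long long> merged;
  for (const auto& len : {address_len, method_len, request_len}) {
    if (!len) continue;  // scalar input
    if (merged && *merged != *len)
      throw std::invalid_argument("incompatible RPC input lengths");
    merged = len;
  }
  return merged;
}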
diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD
index 447056eb4b..44a89c3a96 100644
--- a/tensorflow/core/platform/default/build_config/BUILD
+++ b/tensorflow/core/platform/default/build_config/BUILD
@@ -114,6 +114,12 @@ cc_library(
)
cc_library(
+ name = "base",
+ srcs = [],
+ copts = tf_copts(),
+)
+
+cc_library(
name = "platformlib",
copts = tf_copts(),
deps = [
@@ -166,6 +172,13 @@ cc_library(
)
cc_library(
+ name = "test_lite_main",
+ testonly = 1,
+ linkstatic = 1,
+ deps = [],
+)
+
+cc_library(
name = "test_main",
testonly = 1,
linkstatic = 1,
diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc
index ee423699b2..6da679dc75 100644
--- a/tensorflow/core/platform/s3/s3_file_system.cc
+++ b/tensorflow/core/platform/s3/s3_file_system.cc
@@ -156,7 +156,7 @@ Status ParseS3Path(const string& fname, bool empty_object_ok, string* bucket,
return errors::InvalidArgument("S3 path doesn't contain a bucket name: ",
fname);
}
- objectp.Consume("/");
+ str_util::ConsumePrefix(&objectp, "/");
*object = objectp.ToString();
if (!empty_object_ok && object->empty()) {
return errors::InvalidArgument("S3 path doesn't contain an object name: ",
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index a3557e4721..c1a0075b64 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -409,6 +409,17 @@ message RunMetadata {
repeated GraphDef partition_graphs = 3;
}
+// Defines a connection between two tensors in a `GraphDef`.
+message TensorConnection {
+ // A tensor name. The value of this tensor will be substituted for
+ // the tensor named in `to_tensor`.
+ string from_tensor = 1;
+
+ // A tensor name. The value of this tensor will be bound to the
+ // value of the tensor named in `from_tensor`.
+ string to_tensor = 2;
+}
+
// Defines a subgraph in another `GraphDef` as a set of feed points and nodes
// to be fetched or executed.
//
@@ -429,5 +440,10 @@ message CallableOptions {
// Options that will be applied to each run.
RunOptions run_options = 4;
- // Next: 5
+ // Tensors to be connected in the callable. Each TensorConnection denotes
+ // a pair of tensors in the graph, between which an edge will be created
+ // in the callable.
+ repeated TensorConnection tensor_connection = 5;
+
+ // Next: 6
}
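The new field plugs into the generated C++ API like any other repeated message field. A minimal sketch of building a CallableOptions with a tensor connection; the tensor names, and the use of the pre-existing feed/fetch fields of CallableOptions, are assumptions of this example rather than anything defined in this change:

#include "tensorflow/core/protobuf/config.pb.h"

// Hypothetical example: feed "a:0", fetch "c:0", and rewire consumers of
// "p:0" to read the value produced by "b:0" inside the callable.
tensorflow::CallableOptions MakeExampleCallableOptions() {
  tensorflow::CallableOptions options;
  options.add_feed("a:0");
  options.add_fetch("c:0");
  tensorflow::TensorConnection* connection = options.add_tensor_connection();
  connection->set_from_tensor("b:0");
  connection->set_to_tensor("p:0");
  return options;
}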
diff --git a/tensorflow/core/protobuf/master.proto b/tensorflow/core/protobuf/master.proto
index 0437cb1b83..96c91536f7 100644
--- a/tensorflow/core/protobuf/master.proto
+++ b/tensorflow/core/protobuf/master.proto
@@ -23,6 +23,7 @@ option java_package = "org.tensorflow.distruntime";
import "tensorflow/core/framework/device_attributes.proto";
import "tensorflow/core/framework/graph.proto";
+import "tensorflow/core/framework/tensor.proto";
import "tensorflow/core/lib/core/error_codes.proto";
import "tensorflow/core/protobuf/config.proto";
import "tensorflow/core/protobuf/named_tensor.proto";
@@ -264,3 +265,70 @@ message ListDevicesResponse {
repeated DeviceAttributes local_device = 1;
repeated DeviceAttributes remote_device = 2;
}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// MakeCallable method request/response protos.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message MakeCallableRequest {
+ // REQUIRED: session_handle must be returned by a CreateSession call
+ // to the same master service.
+ string session_handle = 1;
+
+ // Options that define the behavior of the created callable.
+ CallableOptions options = 2;
+}
+
+message MakeCallableResponse {
+ // A handle to the created callable.
+ int64 handle = 1;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// RunCallable method request/response protos.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message RunCallableRequest {
+ // REQUIRED: session_handle must be returned by a CreateSession call
+ // to the same master service.
+ string session_handle = 1;
+ // REQUIRED: handle must be returned by a MakeCallable call to the same
+ // master service.
+ int64 handle = 2;
+
+ // Values of the tensors passed as arguments to the callable, in the order
+ // defined in the CallableOptions.feed field passed to MakeCallable.
+ repeated TensorProto feed = 3;
+}
+
+message RunCallableResponse {
+ // Values of the tensors returned by the callable, in the order defined in the
+ // CallableOptions.fetch field passed to MakeCallable.
+ repeated TensorProto fetch = 1;
+
+ // Returned metadata if requested in the options.
+ RunMetadata metadata = 2;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// ReleaseCallable method request/response protos.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message ReleaseCallableRequest {
+ // REQUIRED: session_handle must be returned by a CreateSession call
+ // to the same master service.
+ string session_handle = 1;
+
+ // REQUIRED: handle must be returned by a MakeCallable call to the same
+ // master service.
+ int64 handle = 2;
+}
+
+message ReleaseCallableResponse {
+}
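Client code drives the new methods purely through these request/response messages. A hedged sketch of populating them with the generated protobuf API; the actual stub call through MasterService is omitted, and session_handle, callable_handle and the feed value are placeholders:

#include <cstdint>
#include <string>

#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/master.pb.h"

// Sketch only: fills the request protos defined above; issuing them through a
// MasterService stub and reading the responses is not shown.
void FillCallableRequests(const std::string& session_handle,
                          const tensorflow::CallableOptions& options,
                          const tensorflow::TensorProto& feed_value,
                          int64_t callable_handle,
                          tensorflow::MakeCallableRequest* make_req,
                          tensorflow::RunCallableRequest* run_req,
                          tensorflow::ReleaseCallableRequest* release_req) {
  make_req->set_session_handle(session_handle);
  *make_req->mutable_options() = options;

  run_req->set_session_handle(session_handle);
  run_req->set_handle(callable_handle);  // returned in MakeCallableResponse
  *run_req->add_feed() = feed_value;     // order matches CallableOptions.feed

  release_req->set_session_handle(session_handle);
  release_req->set_handle(callable_handle);
}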
diff --git a/tensorflow/core/protobuf/master_service.proto b/tensorflow/core/protobuf/master_service.proto
index 771c80562a..1170611f37 100644
--- a/tensorflow/core/protobuf/master_service.proto
+++ b/tensorflow/core/protobuf/master_service.proto
@@ -107,4 +107,13 @@ service MasterService {
// will no longer affect fresh ones via the resources in containers listed in
// the ResetRequest. See ResetRequest for more details.
rpc Reset(ResetRequest) returns (ResetResponse);
+
+ // Registers a callable for execution with RunCallable.
+ rpc MakeCallable(MakeCallableRequest) returns (MakeCallableResponse);
+
+ // Executes a callable registered with MakeCallable.
+ rpc RunCallable(RunCallableRequest) returns (RunCallableResponse);
+
+ // Frees resources associated with a callable registered with MakeCallable.
+ rpc ReleaseCallable(ReleaseCallableRequest) returns (ReleaseCallableResponse);
}
diff --git a/tensorflow/core/util/proto/BUILD b/tensorflow/core/util/proto/BUILD
new file mode 100644
index 0000000000..ade14ed162
--- /dev/null
+++ b/tensorflow/core/util/proto/BUILD
@@ -0,0 +1,62 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+
+cc_library(
+ name = "decode",
+ hdrs = ["decode.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "descriptors",
+ srcs = ["descriptors.cc"],
+ hdrs = ["descriptors.h"],
+ deps = [
+ ":descriptor_pool_registry",
+ ":local_descriptor_pool_registration",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "descriptor_pool_registry",
+ srcs = ["descriptor_pool_registry.cc"],
+ hdrs = ["descriptor_pool_registry.h"],
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cc_test(
+ name = "descriptor_pool_registry_test",
+ srcs = ["descriptor_pool_registry_test.cc"],
+ deps = [
+ ":descriptor_pool_registry",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+# Depending on this target adds support for using the special
+# value "local://" (or "") for descriptor source, in which case
+# descriptors linked into the code will be searched.
+cc_library(
+ name = "local_descriptor_pool_registration",
+ srcs = ["local_descriptor_pool_registration.cc"],
+ deps = [
+ ":descriptor_pool_registry",
+ "//tensorflow/core:lib",
+ ],
+ alwayslink = 1,
+)
diff --git a/tensorflow/core/util/proto/decode.h b/tensorflow/core/util/proto/decode.h
new file mode 100644
index 0000000000..74634a356a
--- /dev/null
+++ b/tensorflow/core/util/proto/decode.h
@@ -0,0 +1,592 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Inline functions for parsing the protocol buffers wire format.
+//
+// These functions have been optimized at the expense of safety.
+// They are broken out into a separate file for readability but are
+// not intended for use by clients other than the decode_proto op.
+//
+// The calling code in the decode_proto op does some fairly
+// complicated things to ensure that this code is called
+// safely. Changes to this code should be thoroughly fuzz tested.
+
+#ifndef TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
+#define TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace internal {
+
+using tensorflow::protobuf::internal::WireFormatLite;
+using tensorflow::protobuf::io::CodedInputStream;
+using tensorflow::protobuf::io::CodedOutputStream;
+using tensorflow::protobuf::io::StringOutputStream;
+
+// Converts a uint64 to an int64 without loss of information.
+// Unsigned values greater than INT64_MAX are represented as
+// negative numbers by wrapping (same as two's-complement bit equivalence).
+inline int64 WrapUnsignedAsSigned64(uint64 unsigned_value) {
+ // For a detailed explanation of why this works to wrap unsigned ints, see
+ // http://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
+ // Both if tests should be optimized out.
+ if (unsigned_value <= INT64_MAX) {
+ return static_cast<int64>(unsigned_value);
+ }
+ // The C++ spec allows an architecture where this test is required.
+ if (unsigned_value >= INT64_MIN) {
+ return static_cast<int64>(unsigned_value - INT64_MIN) + INT64_MIN;
+ }
+ return 0; // This should never occur.
+}
+
+// Converts a uint32 to an int32 without loss of information.
+// Unsigned values greater than INT_MAX are represented as
+// negative numbers by wrapping (same as two's-complement bit equivalence).
+inline int32 WrapUnsignedAsSigned32(uint32 unsigned_value) {
+ // For a detailed explanation of why this works to wrap unsigned ints, see
+ // http://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
+ // Both if tests should be optimized out.
+ if (unsigned_value <= INT_MAX) {
+ return static_cast<int32>(unsigned_value);
+ }
+ // The C++ spec allows an architecture where this test is required.
+ if (unsigned_value >= INT_MIN) {
+ return static_cast<int32>(unsigned_value - INT_MIN) + INT_MIN;
+ }
+ return 0; // This should never occur.
+}
+
+// Reads a single varint64 from a byte array.
+// It is the caller's responsibility to ensure that there is enough
+// space in the buffer.
+// The ok value will be set to false if the buffer does not contain
+// a valid varint.
+inline const uint8* ReadVarint64FromArray(const uint8* buffer, bool* ok,
+ uint64* value);
+
+// Reads a single varint32 from a byte array.
+// It is the caller's responsibility to ensure that there is enough
+// space in the buffer.
+// The ok value will be set to false if the buffer does not contain
+// a valid varint.
+// This is slightly less efficient than the private version in
+// coded_stream.cc, but we avoid duplicating code by calling the
+// 64-bit version instead.
+inline const uint8* ReadVarint32FromArray(const uint8* buffer, bool* ok,
+ uint32* value) {
+ uint64 tmp;
+ const uint8* buf = ReadVarint64FromArray(buffer, ok, &tmp);
+ *value = tmp & 0xffffffff;
+ return buf;
+}
+
+// Reads a single proto field value from a byte array into an array.
+// The array is part of a Tensor that was allocated by the caller
+// with type TensorType, while DeclaredType is the proto field type.
+template <class TensorType, enum WireFormatLite::FieldType DeclaredType>
+const uint8* ReadFromArray(const uint8* buf, TensorType* value);
+
+template <>
+inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_INT32>(
+ const uint8* buf, int32* value) {
+ uint32 temp;
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
+ *value = static_cast<int32>(temp);
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_INT64>(
+ const uint8* buf, int64* value) {
+ uint64 temp;
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
+ *value = WrapUnsignedAsSigned64(temp);
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_UINT32>(
+ const uint8* buf, int64* value) {
+ uint32 temp;
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
+ *value = temp;
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_UINT32>(
+ const uint8* buf, int32* value) {
+ uint32 temp;
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
+ *value = WrapUnsignedAsSigned32(temp);
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_UINT64>(
+ const uint8* buf, int64* value) {
+ uint64 temp;
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
+ *value = static_cast<int64>(temp);
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SINT32>(
+ const uint8* buf, int32* value) {
+ uint32 temp;
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
+ *value = WireFormatLite::ZigZagDecode32(temp);
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SINT64>(
+ const uint8* buf, int64* value) {
+ uint64 temp;
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
+ *value = WireFormatLite::ZigZagDecode64(temp);
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED32>(
+ const uint8* buf, int64* value) {
+ uint32 temp;
+ buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
+ WireFormatLite::TYPE_FIXED32>(
+ buf, &temp);
+ *value = temp;
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_FIXED32>(
+ const uint8* buf, int32* value) {
+ uint32 temp;
+ buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
+ WireFormatLite::TYPE_FIXED32>(
+ buf, &temp);
+ *value = WrapUnsignedAsSigned32(temp);
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED64>(
+ const uint8* buf, int64* value) {
+ protobuf_uint64 temp;
+ buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_uint64,
+ WireFormatLite::TYPE_FIXED64>(
+ buf, &temp);
+ *value = WrapUnsignedAsSigned64(temp);
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SFIXED32>(
+ const uint8* buf, int32* value) {
+ return WireFormatLite::ReadPrimitiveFromArray<int32,
+ WireFormatLite::TYPE_SFIXED32>(
+ buf, value);
+}
+
+template <>
+inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SFIXED64>(
+ const uint8* buf, int64* value) {
+ protobuf_int64 temp;
+ buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_int64,
+ WireFormatLite::TYPE_SFIXED64>(
+ buf, &temp);
+ *value = temp;
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<float, WireFormatLite::TYPE_FLOAT>(
+ const uint8* buf, float* value) {
+ return WireFormatLite::ReadPrimitiveFromArray<float,
+ WireFormatLite::TYPE_FLOAT>(
+ buf, value);
+}
+
+template <>
+inline const uint8* ReadFromArray<double, WireFormatLite::TYPE_DOUBLE>(
+ const uint8* buf, double* value) {
+ return WireFormatLite::ReadPrimitiveFromArray<double,
+ WireFormatLite::TYPE_DOUBLE>(
+ buf, value);
+}
+
+template <>
+inline const uint8* ReadFromArray<bool, WireFormatLite::TYPE_BOOL>(
+ const uint8* buf, bool* value) {
+ uint64 temp;
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
+ *value = temp != 0;
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int, WireFormatLite::TYPE_ENUM>(
+ const uint8* buf, int* value) {
+ uint32 temp;
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
+ *value = static_cast<int>(temp);
+ return buf;
+}
+
+// Reads packed values from an array.
+// Stride is set to 1 for repeated fields, and 0 for non-repeated fields
+// (where any value overwrites previous values).
+template <class TensorType, enum WireFormatLite::FieldType DeclaredType>
+inline int ReadPackedPrimitives(const void* bufp, const size_t len,
+ const int index, const int stride,
+ void* datap) {
+ const uint8* buf = reinterpret_cast<const uint8*>(bufp);
+ const uint8* bound = buf + len;
+ TensorType* data = reinterpret_cast<TensorType*>(datap) + index;
+ int count;
+
+ // This could overrun the bound by stride-1. This is defended
+ // against in the caller, where it ensures that the input buffer
+ // contains complete values.
+ for (count = 0; buf < bound; count += stride) {
+ buf = ReadFromArray<TensorType, DeclaredType>(buf, data + count);
+ }
+ return count;
+}
+
+// Reads a primitive value field from a serialized proto.
+// The value is parsed from the serialized format, then static_cast
+// to the desired type for TensorFlow and stored.
+template <class ValueType, class TensorType,
+ enum WireFormatLite::FieldType DeclaredType>
+inline Status ReadPrimitive(CodedInputStream* input, int index, void* data) {
+ ValueType v;
+ if (!WireFormatLite::ReadPrimitive<ValueType, DeclaredType>(input, &v)) {
+ return errors::DataLoss("Failed reading primitive");
+ }
+
+ reinterpret_cast<TensorType*>(data)[index] = v;
+ return Status::OK();
+}
+
+// Reads a string, submessage, or other variable-length field from a
+// serialized proto.
+// May read all or part of a repeated field.
+inline Status ReadBytes(CodedInputStream* input, int index, void* datap) {
+ string* data = reinterpret_cast<string*>(datap) + index;
+ if (!WireFormatLite::ReadBytes(input, data)) {
+ return errors::DataLoss("Failed reading bytes");
+ }
+ return Status::OK();
+}
+
+// Reads a tag-delimited field (TYPE_GROUP) from a serialized proto,
+// as a bytestring.
+inline Status ReadGroupBytes(CodedInputStream* input, int field_number,
+ int index, void* datap) {
+ // WireFormatLite::SkipField has an option to emit the
+ // skipped bytes to an output stream. We could do better by implementing our
+ // own scanner but this is simpler for now.
+ // TODO(nix): there is a faster way to grab TYPE_GROUP bytes by relying
+ // on input->IsFlat() == true and using input->GetDirectBufferPointer()
+ // with input->CurrentPosition().
+ string* data = reinterpret_cast<string*>(datap) + index;
+ StringOutputStream string_stream(data);
+ CodedOutputStream out(&string_stream);
+ if (!WireFormatLite::SkipField(
+ input,
+ WireFormatLite::MakeTag(field_number,
+ WireFormatLite::WIRETYPE_START_GROUP),
+ &out)) {
+ return errors::DataLoss("Failed reading group");
+ }
+ return Status::OK();
+}
+
+// Reads a single field value from a CodedInputStream into a tensor.
+inline Status ReadValue(CodedInputStream* input,
+ WireFormatLite::FieldType field_type, int field_number,
+ DataType dtype, int index, void* datap) {
+ // Dispatch to the appropriately typed field reader based on the
+ // schema type.
+ switch (field_type) {
+ case WireFormatLite::TYPE_DOUBLE:
+ return ReadPrimitive<double, double, WireFormatLite::TYPE_DOUBLE>(
+ input, index, datap);
+ case WireFormatLite::TYPE_FLOAT:
+ if (dtype == DataType::DT_FLOAT) {
+ return ReadPrimitive<float, float, WireFormatLite::TYPE_FLOAT>(
+ input, index, datap);
+ }
+ if (dtype == DataType::DT_DOUBLE) {
+ return ReadPrimitive<float, double, WireFormatLite::TYPE_FLOAT>(
+ input, index, datap);
+ }
+ // Any case that reaches this point should have triggered an error
+ // already.
+ return errors::DataLoss("Failed reading TYPE_FLOAT");
+ case WireFormatLite::TYPE_INT64:
+ return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_INT64>(
+ input, index, datap);
+ case WireFormatLite::TYPE_UINT64:
+ return ReadPrimitive<protobuf_uint64, int64, WireFormatLite::TYPE_UINT64>(
+ input, index, datap);
+ case WireFormatLite::TYPE_INT32:
+ return ReadPrimitive<int32, int32, WireFormatLite::TYPE_INT32>(
+ input, index, datap);
+ case WireFormatLite::TYPE_FIXED64:
+ return ReadPrimitive<protobuf_uint64, int64,
+ WireFormatLite::TYPE_FIXED64>(input, index, datap);
+ case WireFormatLite::TYPE_FIXED32:
+ if (dtype == DataType::DT_INT64) {
+ return ReadPrimitive<uint32, int64, WireFormatLite::TYPE_FIXED32>(
+ input, index, datap);
+ }
+ if (dtype == DataType::DT_INT32) {
+ return ReadPrimitive<uint32, int32, WireFormatLite::TYPE_FIXED32>(
+ input, index, datap);
+ }
+ // Any case that reaches this point should have triggered an error
+ // already.
+ return errors::DataLoss("Failed reading TYPE_FIXED32");
+ case WireFormatLite::TYPE_BOOL:
+ return ReadPrimitive<bool, bool, WireFormatLite::TYPE_BOOL>(input, index,
+ datap);
+ case WireFormatLite::TYPE_STRING:
+ return ReadBytes(input, index, datap);
+ case WireFormatLite::TYPE_GROUP:
+ return ReadGroupBytes(input, field_number, index, datap);
+ case WireFormatLite::TYPE_MESSAGE:
+ return ReadBytes(input, index, datap);
+ case WireFormatLite::TYPE_BYTES:
+ return ReadBytes(input, index, datap);
+ case WireFormatLite::TYPE_UINT32:
+ if (dtype == DataType::DT_INT64) {
+ return ReadPrimitive<uint32, int64, WireFormatLite::TYPE_UINT32>(
+ input, index, datap);
+ }
+ if (dtype == DataType::DT_INT32) {
+ return ReadPrimitive<uint32, int32, WireFormatLite::TYPE_UINT32>(
+ input, index, datap);
+ }
+ // Any case that reaches this point should have triggered an error
+ // already.
+ return errors::DataLoss("Failed reading TYPE_UINT32");
+ case WireFormatLite::TYPE_ENUM:
+ return ReadPrimitive<int32, int32, WireFormatLite::TYPE_ENUM>(
+ input, index, datap);
+ case WireFormatLite::TYPE_SFIXED32:
+ return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SFIXED32>(
+ input, index, datap);
+ case WireFormatLite::TYPE_SFIXED64:
+ return ReadPrimitive<protobuf_int64, int64,
+ WireFormatLite::TYPE_SFIXED64>(input, index, datap);
+ case WireFormatLite::TYPE_SINT32:
+ return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SINT32>(
+ input, index, datap);
+ case WireFormatLite::TYPE_SINT64:
+ return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_SINT64>(
+ input, index, datap);
+ // default: intentionally omitted in order to enable static checking.
+ }
+ // Unreachable.
+ return errors::DataLoss("Failed reading unknown wire type");
+}
+
+// Reads and stores a length-delimited list of values.
+inline Status ReadPackedFromArray(const void* buf, size_t buf_size,
+ const WireFormatLite::FieldType field_type,
+ const int field_number, const DataType dtype,
+ const int stride, int* index, void* data) {
+ // Dispatch to the appropriately typed field reader based on the
+ // schema type.
+ switch (field_type) {
+ case WireFormatLite::TYPE_DOUBLE:
+ *index += ReadPackedPrimitives<double, WireFormatLite::TYPE_DOUBLE>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ case WireFormatLite::TYPE_FLOAT:
+ *index += ReadPackedPrimitives<float, WireFormatLite::TYPE_FLOAT>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ case WireFormatLite::TYPE_INT64:
+ *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_INT64>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ case WireFormatLite::TYPE_UINT64:
+ *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_UINT64>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ case WireFormatLite::TYPE_INT32:
+ *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_INT32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ case WireFormatLite::TYPE_FIXED64:
+ *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_FIXED64>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ case WireFormatLite::TYPE_FIXED32:
+ if (dtype == DataType::DT_INT64) {
+ *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_FIXED32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ }
+ if (dtype == DataType::DT_INT32) {
+ *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_FIXED32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ }
+ // Any case that reaches this point should have triggered an error
+ // already.
+ return errors::DataLoss("Failed reading TYPE_FIXED32");
+ case WireFormatLite::TYPE_BOOL:
+ *index += ReadPackedPrimitives<bool, WireFormatLite::TYPE_BOOL>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ case WireFormatLite::TYPE_STRING:
+ case WireFormatLite::TYPE_GROUP:
+ case WireFormatLite::TYPE_MESSAGE:
+ case WireFormatLite::TYPE_BYTES:
+ return errors::DataLoss("Non-primitive type encountered as packed");
+ case WireFormatLite::TYPE_UINT32:
+ if (dtype == DataType::DT_INT64) {
+ *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_UINT32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ }
+ if (dtype == DataType::DT_INT32) {
+ *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_UINT32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ }
+ // Any case that reaches this point should have triggered an error
+ // already.
+ return errors::DataLoss("Failed reading TYPE_UINT32");
+ case WireFormatLite::TYPE_ENUM:
+ *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_ENUM>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ case WireFormatLite::TYPE_SFIXED32:
+ *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SFIXED32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+
+ case WireFormatLite::TYPE_SFIXED64:
+ *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SFIXED64>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+
+ case WireFormatLite::TYPE_SINT32:
+ *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SINT32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+
+ case WireFormatLite::TYPE_SINT64:
+ *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SINT64>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ // default: intentionally omitted in order to enable static checking.
+ }
+ // Unreachable.
+ return errors::DataLoss("Failed reading unknown wire type");
+}
+
+// Reads a varint from the given buffer, writes it to *value, and returns the
+// new buffer pointer.
+// This was copied from coded_stream.cc where it is private.
+// Important: This routine may read as much as kMaxVarintBytes from
+// the buffer. It is the caller's responsibility to make sure that there is
+// enough space in the buffer.
+inline const uint8* ReadVarint64FromArray(const uint8* buffer, bool* ok,
+ uint64* value) {
+ const uint8* ptr = buffer;
+ uint32 b;
+
+ // Splitting into 32-bit pieces gives better performance on 32-bit
+ // processors.
+ uint32 part0 = 0, part1 = 0, part2 = 0;
+
+ b = *(ptr++);
+ part0 = b;
+ if (!(b & 0x80)) goto done;
+ part0 -= 0x80;
+ b = *(ptr++);
+ part0 += b << 7;
+ if (!(b & 0x80)) goto done;
+ part0 -= 0x80 << 7;
+ b = *(ptr++);
+ part0 += b << 14;
+ if (!(b & 0x80)) goto done;
+ part0 -= 0x80 << 14;
+ b = *(ptr++);
+ part0 += b << 21;
+ if (!(b & 0x80)) goto done;
+ part0 -= 0x80 << 21;
+ b = *(ptr++);
+ part1 = b;
+ if (!(b & 0x80)) goto done;
+ part1 -= 0x80;
+ b = *(ptr++);
+ part1 += b << 7;
+ if (!(b & 0x80)) goto done;
+ part1 -= 0x80 << 7;
+ b = *(ptr++);
+ part1 += b << 14;
+ if (!(b & 0x80)) goto done;
+ part1 -= 0x80 << 14;
+ b = *(ptr++);
+ part1 += b << 21;
+ if (!(b & 0x80)) goto done;
+ part1 -= 0x80 << 21;
+ b = *(ptr++);
+ part2 = b;
+ if (!(b & 0x80)) goto done;
+ part2 -= 0x80;
+ b = *(ptr++);
+ part2 += b << 7;
+ if (!(b & 0x80)) goto done;
+ // "part2 -= 0x80 << 7" is irrelevant because (0x80 << 7) << 56 is 0.
+
+ // We have overrun the maximum size of a varint (10 bytes). Assume
+ // the data is corrupt.
+ *ok = false;
+ return ptr;
+
+done:
+ *ok = true;
+ *value = (static_cast<uint64>(part0)) | (static_cast<uint64>(part1) << 28) |
+ (static_cast<uint64>(part2) << 56);
+ return ptr;
+}
+
+} // namespace internal
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
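Editor's note: a minimal standalone check of the wrapping helpers above. This assumes the header can be included outside the decode_proto op, which is not its intended use; the values simply illustrate the documented wrap-around behavior.

```c++
#include <cassert>
#include <cstdint>

#include "tensorflow/core/util/proto/decode.h"

int main() {
  using tensorflow::internal::WrapUnsignedAsSigned64;
  // Values representable as int64 pass through unchanged.
  assert(WrapUnsignedAsSigned64(42ULL) == 42);
  // Values above INT64_MAX wrap to negative numbers, matching the
  // two's-complement bit pattern.
  assert(WrapUnsignedAsSigned64(UINT64_MAX) == -1);
  return 0;
}
```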
diff --git a/tensorflow/core/util/proto/descriptor_pool_registry.cc b/tensorflow/core/util/proto/descriptor_pool_registry.cc
new file mode 100644
index 0000000000..5f0423f76b
--- /dev/null
+++ b/tensorflow/core/util/proto/descriptor_pool_registry.cc
@@ -0,0 +1,45 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "tensorflow/core/platform/logging.h"
+
+#include "tensorflow/core/util/proto/descriptor_pool_registry.h"
+
+namespace tensorflow {
+
+DescriptorPoolRegistry* DescriptorPoolRegistry::Global() {
+ static DescriptorPoolRegistry* registry = new DescriptorPoolRegistry;
+ return registry;
+}
+
+DescriptorPoolRegistry::DescriptorPoolFn* DescriptorPoolRegistry::Get(
+ const string& source) {
+ auto found = fns_.find(source);
+ if (found == fns_.end()) return nullptr;
+ return &found->second;
+}
+
+void DescriptorPoolRegistry::Register(
+ const string& source,
+ const DescriptorPoolRegistry::DescriptorPoolFn& pool_fn) {
+ auto existing = Get(source);
+ CHECK_EQ(existing, nullptr)
+ << "descriptor pool for source: " << source << " already registered";
+ fns_.insert(std::pair<const string&, DescriptorPoolFn>(source, pool_fn));
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/proto/descriptor_pool_registry.h b/tensorflow/core/util/proto/descriptor_pool_registry.h
new file mode 100644
index 0000000000..66c20e9e41
--- /dev/null
+++ b/tensorflow/core/util/proto/descriptor_pool_registry.h
@@ -0,0 +1,76 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTOR_POOL_REGISTRY_H_
+#define TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTOR_POOL_REGISTRY_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <map>
+#include <utility>
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+class DescriptorPoolRegistry {
+ public:
+ typedef std::function<Status(
+ tensorflow::protobuf::DescriptorPool const** desc_pool,
+ std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool)>
+ DescriptorPoolFn;
+
+ // Returns a pointer to a global DescriptorPoolRegistry object.
+ static DescriptorPoolRegistry* Global();
+
+ // Returns a pointer to a descriptor pool function for the given source.
+ DescriptorPoolFn* Get(const string& source);
+
+ // Registers a descriptor pool factory.
+ void Register(const string& source, const DescriptorPoolFn& pool_fn);
+
+ private:
+ std::map<string, DescriptorPoolFn> fns_;
+};
+
+namespace descriptor_pool_registration {
+
+class DescriptorPoolRegistration {
+ public:
+ DescriptorPoolRegistration(
+ const string& source,
+ const DescriptorPoolRegistry::DescriptorPoolFn& pool_fn) {
+ DescriptorPoolRegistry::Global()->Register(source, pool_fn);
+ }
+};
+
+} // namespace descriptor_pool_registration
+
+#define REGISTER_DESCRIPTOR_POOL(source, pool_fn) \
+ REGISTER_DESCRIPTOR_POOL_UNIQ_HELPER(__COUNTER__, source, pool_fn)
+
+#define REGISTER_DESCRIPTOR_POOL_UNIQ_HELPER(ctr, source, pool_fn) \
+ REGISTER_DESCRIPTOR_POOL_UNIQ(ctr, source, pool_fn)
+
+#define REGISTER_DESCRIPTOR_POOL_UNIQ(ctr, source, pool_fn) \
+ static descriptor_pool_registration::DescriptorPoolRegistration \
+ descriptor_pool_registration_fn_##ctr(source, pool_fn)
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTOR_POOL_REGISTRY_H_
diff --git a/tensorflow/core/util/proto/descriptor_pool_registry_test.cc b/tensorflow/core/util/proto/descriptor_pool_registry_test.cc
new file mode 100644
index 0000000000..a6899998ab
--- /dev/null
+++ b/tensorflow/core/util/proto/descriptor_pool_registry_test.cc
@@ -0,0 +1,43 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/util/proto/descriptor_pool_registry.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+struct Value {
+ static Status Function(
+ tensorflow::protobuf::DescriptorPool const** desc_pool,
+ std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
+ return Status::OK();
+ }
+};
+
+REGISTER_DESCRIPTOR_POOL("TEST POOL 1", Value::Function);
+REGISTER_DESCRIPTOR_POOL("TEST POOL 2", Value::Function);
+} // namespace
+
+TEST(DescriptorPoolRegistryTest, TestBasic) {
+ EXPECT_EQ(DescriptorPoolRegistry::Global()->Get("NON-EXISTENT"), nullptr);
+ auto pool1 = DescriptorPoolRegistry::Global()->Get("TEST POOL 1");
+ EXPECT_NE(pool1, nullptr);
+ auto pool2 = DescriptorPoolRegistry::Global()->Get("TEST POOL 2");
+ EXPECT_NE(pool2, nullptr);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/proto/descriptors.cc b/tensorflow/core/util/proto/descriptors.cc
new file mode 100644
index 0000000000..271c85efd8
--- /dev/null
+++ b/tensorflow/core/util/proto/descriptors.cc
@@ -0,0 +1,85 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/reader_op_kernel.h"
+#include "tensorflow/core/util/proto/descriptor_pool_registry.h"
+
+#include "tensorflow/core/util/proto/descriptors.h"
+
+namespace tensorflow {
+namespace {
+
+// Build a `DescriptorPool` from the named file or URI. The file or URI
+// must be available to the current TensorFlow environment.
+//
+// The file must contain a serialized `FileDescriptorSet`. See
+// `GetDescriptorPool()` for more information.
+Status GetDescriptorPoolFromFile(
+ tensorflow::Env* env, const string& filename,
+ std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
+ Status st = env->FileExists(filename);
+ if (!st.ok()) {
+ return st;
+ }
+
+ // Read and parse the FileDescriptorSet.
+ tensorflow::protobuf::FileDescriptorSet descs;
+ std::unique_ptr<tensorflow::ReadOnlyMemoryRegion> buf;
+ st = env->NewReadOnlyMemoryRegionFromFile(filename, &buf);
+ if (!st.ok()) {
+ return st;
+ }
+ if (!descs.ParseFromArray(buf->data(), buf->length())) {
+ return errors::InvalidArgument(
+ "descriptor_source contains invalid FileDescriptorSet: ", filename);
+ }
+
+ // Build a DescriptorPool from the FileDescriptorSet.
+ owned_desc_pool->reset(new tensorflow::protobuf::DescriptorPool());
+ for (const auto& filedesc : descs.file()) {
+ if ((*owned_desc_pool)->BuildFile(filedesc) == nullptr) {
+ return errors::InvalidArgument(
+ "Problem loading FileDescriptorProto (missing dependencies?): ",
+ filename);
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status GetDescriptorPool(
+ tensorflow::Env* env, string const& descriptor_source,
+ tensorflow::protobuf::DescriptorPool const** desc_pool,
+ std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
+  // Attempt to look up the pool in the registry.
+ auto pool_fn = DescriptorPoolRegistry::Global()->Get(descriptor_source);
+ if (pool_fn != nullptr) {
+ return (*pool_fn)(desc_pool, owned_desc_pool);
+ }
+
+ // If there is no pool function registered for the given source, let the
+ // runtime find the file or URL.
+ Status status =
+ GetDescriptorPoolFromFile(env, descriptor_source, owned_desc_pool);
+  if (status.ok()) {
+    *desc_pool = owned_desc_pool->get();
+  }
+ return status;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/proto/descriptors.h b/tensorflow/core/util/proto/descriptors.h
new file mode 100644
index 0000000000..92ee8997ab
--- /dev/null
+++ b/tensorflow/core/util/proto/descriptors.h
@@ -0,0 +1,42 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTORS_H_
+#define TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTORS_H_
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+class Env;
+class Status;
+
+// Get a `DescriptorPool` object from the named `descriptor_source`.
+// `descriptor_source` may be a path to a file accessible to TensorFlow, in
+// which case it is parsed as a `FileDescriptorSet` and used to build the
+// `DescriptorPool`.
+//
+// `owned_desc_pool` will be filled in with the same pointer as `desc_pool` if
+// the caller should take ownership.
+extern tensorflow::Status GetDescriptorPool(
+ tensorflow::Env* env, string const& descriptor_source,
+ tensorflow::protobuf::DescriptorPool const** desc_pool,
+ std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTORS_H_
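Editor's note: a hedged usage sketch of `GetDescriptorPool`. The function name `LoadPool` and the path `/tmp/my_descriptors.pb` are illustrative only; the file is expected to contain a serialized `FileDescriptorSet`.

```c++
#include <memory>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/proto/descriptors.h"

tensorflow::Status LoadPool() {
  const tensorflow::protobuf::DescriptorPool* pool = nullptr;
  std::unique_ptr<tensorflow::protobuf::DescriptorPool> owned_pool;
  TF_RETURN_IF_ERROR(tensorflow::GetDescriptorPool(
      tensorflow::Env::Default(), "/tmp/my_descriptors.pb", &pool,
      &owned_pool));
  // On success, `pool` is valid; `owned_pool` is non-null when the pool was
  // built from the file and the caller must keep it alive.
  return tensorflow::Status::OK();
}
```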
diff --git a/tensorflow/core/util/proto/local_descriptor_pool_registration.cc b/tensorflow/core/util/proto/local_descriptor_pool_registration.cc
new file mode 100644
index 0000000000..48fe0102d0
--- /dev/null
+++ b/tensorflow/core/util/proto/local_descriptor_pool_registration.cc
@@ -0,0 +1,39 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/util/proto/descriptor_pool_registry.h"
+
+namespace tensorflow {
+namespace {
+
+struct LocalDescriptorPool {
+ static Status Function(
+ tensorflow::protobuf::DescriptorPool const** desc_pool,
+ std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
+ *desc_pool = ::tensorflow::protobuf::DescriptorPool::generated_pool();
+ if (*desc_pool == nullptr) {
+ return errors::InvalidArgument("Problem loading protobuf generated_pool");
+ }
+ return Status::OK();
+ }
+};
+
+REGISTER_DESCRIPTOR_POOL("", LocalDescriptorPool::Function);
+REGISTER_DESCRIPTOR_POOL("local://", LocalDescriptorPool::Function);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/util/reporter.cc b/tensorflow/core/util/reporter.cc
index ee38f81f3e..a595c9509e 100644
--- a/tensorflow/core/util/reporter.cc
+++ b/tensorflow/core/util/reporter.cc
@@ -47,6 +47,18 @@ Status TestReporter::Benchmark(int64 iters, double cpu_time, double wall_time,
return Status::OK();
}
+Status TestReporter::SetProperty(const string& name, const string& value) {
+ if (closed_) return Status::OK();
+ (*benchmark_entry_.mutable_extras())[name].set_string_value(value);
+ return Status::OK();
+}
+
+Status TestReporter::SetProperty(const string& name, double value) {
+ if (closed_) return Status::OK();
+ (*benchmark_entry_.mutable_extras())[name].set_double_value(value);
+ return Status::OK();
+}
+
Status TestReporter::Initialize() {
if (fname_.empty()) {
return Status::OK();
diff --git a/tensorflow/core/util/reporter.h b/tensorflow/core/util/reporter.h
index bcae12204e..e551e2e4f5 100644
--- a/tensorflow/core/util/reporter.h
+++ b/tensorflow/core/util/reporter.h
@@ -34,11 +34,13 @@ namespace tensorflow {
//
// If this environment variable is not defined, no logging is performed.
//
-// The intended use is via the following 4 lines:
+// The intended use is via the following lines:
//
// TestReporter reporter(test_name);
// TF_CHECK_OK(reporter.Initialize()));
// TF_CHECK_OK(reporter.Benchmark(iters, cpu_time, wall_time, throughput));
+//   TF_CHECK_OK(reporter.SetProperty("some_string_property", "some_value"));
+//   TF_CHECK_OK(reporter.SetProperty("some_double_property", double_value));
// TF_CHECK_OK(reporter.Close());
//
// For example, if the environment variable
@@ -75,6 +77,12 @@ class TestReporter {
Status Benchmark(int64 iters, double cpu_time, double wall_time,
double throughput);
+  // Sets a double-valued property on the Benchmark entry.
+  Status SetProperty(const string& name, double value);
+
+  // Sets a string-valued property on the Benchmark entry.
+  Status SetProperty(const string& name, const string& value);
+
// TODO(b/32704451): Don't just ignore the ::tensorflow::Status object!
~TestReporter() { Close().IgnoreError(); } // Autoclose in destructor.
diff --git a/tensorflow/core/util/reporter_test.cc b/tensorflow/core/util/reporter_test.cc
index 90ea09876e..0972b86ea5 100644
--- a/tensorflow/core/util/reporter_test.cc
+++ b/tensorflow/core/util/reporter_test.cc
@@ -115,5 +115,28 @@ TEST(TestReporter, Benchmark) {
EXPECT_EQ(benchmark_entry.throughput(), 3.0);
}
+TEST(TestReporter, SetProperties) {
+ string fname =
+ strings::StrCat(testing::TmpDir(), "/test_reporter_benchmarks_");
+ TestReporter test_reporter(fname, "b2/3/4");
+ TF_EXPECT_OK(test_reporter.Initialize());
+ TF_EXPECT_OK(test_reporter.SetProperty("string_prop", "abc"));
+ TF_EXPECT_OK(test_reporter.SetProperty("double_prop", 4.0));
+
+ TF_EXPECT_OK(test_reporter.Close());
+ string expected_fname = strings::StrCat(fname, "b2__3__4");
+ string read;
+ TF_EXPECT_OK(ReadFileToString(Env::Default(), expected_fname, &read));
+
+ BenchmarkEntries benchmark_entries;
+ ASSERT_TRUE(benchmark_entries.ParseFromString(read));
+ ASSERT_EQ(1, benchmark_entries.entry_size());
+ const BenchmarkEntry& benchmark_entry = benchmark_entries.entry(0);
+ const auto& extras = benchmark_entry.extras();
+ ASSERT_EQ(2, extras.size());
+ EXPECT_EQ("abc", extras.at("string_prop").string_value());
+ EXPECT_EQ(4.0, extras.at("double_prop").double_value());
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/util/rpc/BUILD b/tensorflow/core/util/rpc/BUILD
new file mode 100644
index 0000000000..f0f161ecc0
--- /dev/null
+++ b/tensorflow/core/util/rpc/BUILD
@@ -0,0 +1,48 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+
+cc_library(
+ name = "call_container",
+ hdrs = ["call_container.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_library(
+ name = "rpc_factory",
+ srcs = ["rpc_factory.cc"],
+ hdrs = ["rpc_factory.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "rpc_factory_registry",
+ srcs = ["rpc_factory_registry.cc"],
+ hdrs = ["rpc_factory_registry.h"],
+ deps = [
+ ":rpc_factory",
+ "//tensorflow/core:framework",
+ ],
+)
+
+tf_cc_test(
+ name = "rpc_factory_registry_test",
+ srcs = ["rpc_factory_registry_test.cc"],
+ deps = [
+ ":rpc_factory_registry",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
diff --git a/tensorflow/core/util/rpc/call_container.h b/tensorflow/core/util/rpc/call_container.h
new file mode 100644
index 0000000000..7f36056797
--- /dev/null
+++ b/tensorflow/core/util/rpc/call_container.h
@@ -0,0 +1,90 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
+#define TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
+
+#include <list>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/util/reffed_status_callback.h"
+
+namespace tensorflow {
+
+template <typename Call>
+class CallContainer {
+ public:
+ explicit CallContainer(OpKernelContext* ctx, int num_calls, bool fail_fast,
+ bool try_rpc, AsyncOpKernel::DoneCallback done,
+ CancellationToken token)
+ : ctx_(ctx),
+ done_(std::move(done)),
+ token_(token),
+ fail_fast_(fail_fast),
+ try_rpc_(try_rpc) {
+ CHECK_GT(num_calls, 0);
+
+ // This will run when all RPCs are finished.
+ reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) {
+ ctx_->cancellation_manager()->DeregisterCallback(token_);
+ ctx_->SetStatus(s);
+ done_();
+ delete this;
+ });
+
+ // Subtract reference count from the initial creation.
+ core::ScopedUnref unref(reffed_status_callback_);
+
+ for (int i = 0; i < num_calls; ++i) {
+ // Increase the reference on the callback for each new RPC.
+ reffed_status_callback_->Ref();
+ }
+ }
+
+ std::list<Call>* calls() { return &calls_; }
+
+ void StartCancel() {
+    // Once this loop is done, nothing in this object can be assumed valid,
+    // because "delete this" may already have been called.
+    // Nothing should run after this loop.
+ for (auto& call : calls_) {
+ call.StartCancel();
+ }
+ }
+
+ void Done(const Status& s, int index) {
+ if (!try_rpc_) {
+ reffed_status_callback_->UpdateStatus(s);
+ }
+ reffed_status_callback_->Unref();
+ }
+
+ private:
+ OpKernelContext* ctx_;
+ std::list<Call> calls_;
+ const AsyncOpKernel::DoneCallback done_;
+ const CancellationToken token_;
+ const bool fail_fast_;
+ const bool try_rpc_;
+
+ // Performs its own reference counting.
+ ReffedStatusCallback* reffed_status_callback_;
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
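Editor's note: a rough, hypothetical sketch of the ownership pattern `CallContainer` expects. `MyCall` and `Launch` are illustrative names; a real caller would also register cancellation under `token` so that `StartCancel()` is invoked, and each call's completion callback must eventually invoke `Done()` exactly once.

```c++
#include <utility>

#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/rpc/call_container.h"

struct MyCall {
  void StartCancel() {}  // Would cancel the in-flight RPC.
};

void Launch(tensorflow::OpKernelContext* ctx,
            tensorflow::AsyncOpKernel::DoneCallback done,
            tensorflow::CancellationToken token, int num_calls) {
  auto* container = new tensorflow::CallContainer<MyCall>(
      ctx, num_calls, /*fail_fast=*/true, /*try_rpc=*/false, std::move(done),
      token);
  for (int i = 0; i < num_calls; ++i) {
    container->calls()->emplace_back();
    // ... start the i-th RPC; its completion callback must eventually call
    // container->Done(status, i).
  }
  // After the final Done() the container deletes itself; do not use it here.
}
```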
diff --git a/tensorflow/core/util/rpc/rpc_factory.cc b/tensorflow/core/util/rpc/rpc_factory.cc
new file mode 100644
index 0000000000..8530f02b6e
--- /dev/null
+++ b/tensorflow/core/util/rpc/rpc_factory.cc
@@ -0,0 +1,53 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/strings/numbers.h"
+
+#include "tensorflow/core/util/rpc/rpc_factory.h"
+
+namespace tensorflow {
+
+template <>
+bool GetEnvVar(const char* key, const string& default_value, string* value) {
+ const char* env_value = std::getenv(key);
+ if (!env_value || env_value[0] == '\0') {
+ *value = default_value;
+ } else {
+ *value = env_value;
+ }
+ return true;
+}
+
+template <>
+bool GetEnvVar(const char* key, const int64& default_value, int64* value) {
+ const char* env_value = std::getenv(key);
+ if (!env_value || env_value[0] == '\0') {
+ *value = default_value;
+ return true;
+ }
+ return strings::safe_strto64(env_value, value);
+}
+
+template <>
+bool GetEnvVar(const char* key, const uint64& default_value, uint64* value) {
+ const char* env_value = std::getenv(key);
+ if (!env_value || env_value[0] == '\0') {
+ *value = default_value;
+ return true;
+ }
+ return strings::safe_strtou64(env_value, value);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/rpc/rpc_factory.h b/tensorflow/core/util/rpc/rpc_factory.h
new file mode 100644
index 0000000000..9bf078c0f4
--- /dev/null
+++ b/tensorflow/core/util/rpc/rpc_factory.h
@@ -0,0 +1,70 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_
+#define TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+
+// Looks up the environment variable `key` and parses it into `*value`. If the
+// variable is not set, `*value` is set to `default_value`. Returns `false` if
+// the variable is set but cannot be parsed; otherwise returns `true`.
+template <typename T>
+bool GetEnvVar(const char* key, const T& default_value, T* value);
+
+class RPCFactory {
+ public:
+ RPCFactory() {}
+ virtual ~RPCFactory() {}
+
+  // Starts calls to the methods in `method_t` at the addresses in
+  // `address_t` with request strings from `request_t`. Any of these may be
+  // scalar Tensors, in which case they are broadcast.
+  // Upon completion of all requests, `response_t` will be populated.
+ //
+ // If `try_rpc` is `true`, then `status_message_t` and
+ // `status_code_t` will be populated as well.
+ //
+ // If `try_rpc` is `false`, then `status_message_t` and
+ // `status_code_t` are ignored (and may be nullptr). Instead, the
+ // status of any failed call will be propagated to the op.
+ //
+ // REQUIRES:
+ // - `response_t` is not null, and is a string Tensor with the same shape as
+ // `request_t`.
+ //
+ // If `try_rpc` is `true`:
+ // - `status_code_t` and `status_message_t` are not null.
+ // - `status_code_t` is an int32 Tensor with the same shape as
+ // `request_t`.
+ // - `status_message_t` is a string Tensor with the same shape as
+ // `request_t`.
+ virtual void Call(OpKernelContext* ctx, int64 num_elements,
+ const Tensor& address_t, const Tensor& method_t,
+ const Tensor& request_t, const bool try_rpc,
+ Tensor* response_t, Tensor* status_code_t,
+ Tensor* status_message_t,
+ AsyncOpKernel::DoneCallback done) = 0;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(RPCFactory);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_
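Editor's note: a hedged usage sketch of the `GetEnvVar` helper. The environment variable name, default, and `ReadTimeout` wrapper are illustrative; the int64 specialization is defined in rpc_factory.cc and must be linked in.

```c++
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/rpc/rpc_factory.h"

bool ReadTimeout(tensorflow::int64* timeout_ms) {
  // Falls back to 60000 ms when the variable is unset; returns false only if
  // the variable is set but cannot be parsed as an int64.
  return tensorflow::GetEnvVar("MY_RPC_TIMEOUT_MS",
                               static_cast<tensorflow::int64>(60000),
                               timeout_ms);
}
```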
diff --git a/tensorflow/core/util/rpc/rpc_factory_registry.cc b/tensorflow/core/util/rpc/rpc_factory_registry.cc
new file mode 100644
index 0000000000..a148b5c04d
--- /dev/null
+++ b/tensorflow/core/util/rpc/rpc_factory_registry.cc
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "tensorflow/core/util/rpc/rpc_factory.h"
+
+#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
+
+namespace tensorflow {
+
+RPCFactoryRegistry* RPCFactoryRegistry::Global() {
+ static RPCFactoryRegistry* registry = new RPCFactoryRegistry;
+ return registry;
+}
+
+RPCFactoryRegistry::RPCFactoryFn* RPCFactoryRegistry::Get(
+ const string& protocol) {
+ auto found = fns_.find(protocol);
+ if (found == fns_.end()) return nullptr;
+ return &found->second;
+}
+
+void RPCFactoryRegistry::Register(const string& protocol,
+ const RPCFactoryFn& factory_fn) {
+ auto existing = Get(protocol);
+ CHECK_EQ(existing, nullptr)
+ << "RPC factory for protocol: " << protocol << " already registered";
+ fns_.insert(std::pair<const string&, RPCFactoryFn>(protocol, factory_fn));
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/rpc/rpc_factory_registry.h b/tensorflow/core/util/rpc/rpc_factory_registry.h
new file mode 100644
index 0000000000..2635a4012e
--- /dev/null
+++ b/tensorflow/core/util/rpc/rpc_factory_registry.h
@@ -0,0 +1,72 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_
+#define TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_
+
+#include <map>
+#include <string>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/util/rpc/rpc_factory.h"
+
+namespace tensorflow {
+
+class RPCFactoryRegistry {
+ public:
+ typedef std::function<RPCFactory*(OpKernelConstruction* ctx, bool fail_fast,
+ int64 timeout_in_ms)>
+ RPCFactoryFn;
+
+ // Returns a pointer to a global RPCFactoryRegistry object.
+ static RPCFactoryRegistry* Global();
+
+  // Returns a pointer to a function that creates an RPC factory for the given
+  // protocol.
+ RPCFactoryFn* Get(const string& protocol);
+
+  // Registers a function that creates an RPC factory for the given protocol.
+ // The function should transfer the ownership of the factory to its caller.
+ void Register(const string& protocol, const RPCFactoryFn& factory_fn);
+
+ private:
+ std::map<string, RPCFactoryFn> fns_;
+};
+
+namespace rpc_factory_registration {
+
+class RPCFactoryRegistration {
+ public:
+ RPCFactoryRegistration(const string& protocol,
+ const RPCFactoryRegistry::RPCFactoryFn& factory_fn) {
+ RPCFactoryRegistry::Global()->Register(protocol, factory_fn);
+ }
+};
+
+} // namespace rpc_factory_registration
+
+#define REGISTER_RPC_FACTORY(protocol, factory_fn) \
+ REGISTER_RPC_FACTORY_UNIQ_HELPER(__COUNTER__, protocol, factory_fn)
+
+#define REGISTER_RPC_FACTORY_UNIQ_HELPER(ctr, protocol, factory_fn) \
+ REGISTER_RPC_FACTORY_UNIQ(ctr, protocol, factory_fn)
+
+#define REGISTER_RPC_FACTORY_UNIQ(ctr, protocol, factory_fn) \
+ static rpc_factory_registration::RPCFactoryRegistration \
+ rpc_factory_registration_fn_##ctr(protocol, factory_fn)
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_
diff --git a/tensorflow/core/util/rpc/rpc_factory_registry_test.cc b/tensorflow/core/util/rpc/rpc_factory_registry_test.cc
new file mode 100644
index 0000000000..cfd0f95016
--- /dev/null
+++ b/tensorflow/core/util/rpc/rpc_factory_registry_test.cc
@@ -0,0 +1,41 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+struct Value {
+ static RPCFactory* Function(OpKernelConstruction* ctx, bool fail_fast,
+ int64 timeout_in_ms) {
+ return nullptr;
+ }
+};
+
+REGISTER_RPC_FACTORY("TEST FACTORY 1", Value::Function);
+REGISTER_RPC_FACTORY("TEST FACTORY 2", Value::Function);
+} // namespace
+
+TEST(RPCFactoryRegistryTest, TestBasic) {
+ EXPECT_EQ(RPCFactoryRegistry::Global()->Get("NON-EXISTENT"), nullptr);
+ auto factory1 = RPCFactoryRegistry::Global()->Get("TEST FACTORY 1");
+ EXPECT_NE(factory1, nullptr);
+ auto factory2 = RPCFactoryRegistry::Global()->Get("TEST FACTORY 2");
+ EXPECT_NE(factory2, nullptr);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/docs_src/extend/new_data_formats.md b/tensorflow/docs_src/extend/new_data_formats.md
index 10e717c280..2c33a6b6f7 100644
--- a/tensorflow/docs_src/extend/new_data_formats.md
+++ b/tensorflow/docs_src/extend/new_data_formats.md
@@ -1,4 +1,4 @@
-# Custom Data Readers
+# Reading custom file and record formats
PREREQUISITES:
@@ -9,187 +9,273 @@ PREREQUISITES:
We divide the task of supporting a file format into two pieces:
-* File formats: We use a *Reader* Op to read a *record* (which can be any
- string) from a file.
-* Record formats: We use decoder or parsing Ops to turn a string record
+* File formats: We use a reader `tf.data.Dataset` to read raw *records* (which
+ are typically represented by scalar string tensors, but can have more
+ structure) from a file.
+* Record formats: We use decoder or parsing ops to turn a string record
into tensors usable by TensorFlow.
For example, to read a
[CSV file](https://en.wikipedia.org/wiki/Comma-separated_values), we use
-@{tf.TextLineReader$a Reader for text files}
-followed by
-@{tf.decode_csv$an Op that parses CSV data from a line of text}.
+@{tf.data.TextLineDataset$a dataset for reading text files line-by-line}
+and then @{tf.data.Dataset.map$map} an
+@{tf.decode_csv$op} that parses CSV data from each line of text in the dataset.
[TOC]
-## Writing a Reader for a file format
+## Writing a `Dataset` for a file format
-A `Reader` is something that reads records from a file. There are some examples
-of Reader Ops already built into TensorFlow:
+A @{tf.data.Dataset} represents a sequence of *elements*, which can be the
+individual records in a file. There are several examples of "reader" datasets
+that are already built into TensorFlow:
-* @{tf.TFRecordReader}
- ([source in `kernels/tf_record_reader_op.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/tf_record_reader_op.cc))
-* @{tf.FixedLengthRecordReader}
- ([source in `kernels/fixed_length_record_reader_op.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/fixed_length_record_reader_op.cc))
-* @{tf.TextLineReader}
- ([source in `kernels/text_line_reader_op.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/text_line_reader_op.cc))
+* @{tf.data.TFRecordDataset}
+ ([source in `kernels/data/reader_dataset_ops.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/data/reader_dataset_ops.cc))
+* @{tf.data.FixedLengthRecordDataset}
+ ([source in `kernels/data/reader_dataset_ops.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/data/reader_dataset_ops.cc))
+* @{tf.data.TextLineDataset}
+ ([source in `kernels/data/reader_dataset_ops.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/data/reader_dataset_ops.cc))
-You can see these all expose the same interface, the only differences
-are in their constructors. The most important method is `read`.
-It takes a queue argument, which is where it gets filenames to
-read from whenever it needs one (e.g. when the `read` op first runs, or
-the previous `read` reads the last record from a file). It produces
-two scalar tensors: a string key and a string value.
+Each of these implementations comprises three related classes:
-To create a new reader called `SomeReader`, you will need to:
+* A `tensorflow::DatasetOpKernel` subclass (e.g. `TextLineDatasetOp`), which
+  tells TensorFlow how to construct a dataset object from an op's inputs and
+  attrs, in its `MakeDataset()` method.
-1. In C++, define a subclass of
- [`tensorflow::ReaderBase`](https://www.tensorflow.org/code/tensorflow/core/framework/reader_base.h)
- called `SomeReader`.
-2. In C++, register a new reader op and kernel with the name `"SomeReader"`.
-3. In Python, define a subclass of @{tf.ReaderBase} called `SomeReader`.
+* A `tensorflow::GraphDatasetBase` subclass (e.g. `TextLineDatasetOp::Dataset`),
+ which represents the *immutable* definition of the dataset itself, and tells
+ TensorFlow how to construct an iterator object over that dataset, in its
+ `MakeIterator()` method.
-You can put all the C++ code in a file in
-`tensorflow/core/user_ops/some_reader_op.cc`. The code to read a file will live
-in a descendant of the C++ `ReaderBase` class, which is defined in
-[`tensorflow/core/kernels/reader_base.h`](https://www.tensorflow.org/code/tensorflow/core/framework/reader_base.h).
-You will need to implement the following methods:
+* A `tensorflow::DatasetIterator<Dataset>` subclass (e.g.
+ `TextLineDatasetOp::Dataset::Iterator`), which represents the *mutable* state
+ of an iterator over a particular dataset, and tells TensorFlow how to get the
+ next element from the iterator, in its `GetNextInternal()` method.
-* `OnWorkStartedLocked`: open the next file
-* `ReadLocked`: read a record or report EOF/error
-* `OnWorkFinishedLocked`: close the current file, and
-* `ResetLocked`: get a clean slate after, e.g., an error
+The most important of these methods is `GetNextInternal()`, since it defines
+how to actually read records from the file and represent them as one or more
+`Tensor` objects.
-These methods have names ending in "Locked" since `ReaderBase` makes sure
-to acquire a mutex before calling any of these methods, so you generally don't
-have to worry about thread safety (though that only protects the members of the
-class, not global state).
+To create a new reader dataset called (for example) `MyReaderDataset`, you will
+need to:
-For `OnWorkStartedLocked`, the name of the file to open is the value returned by
-the `current_work()` method. `ReadLocked` has this signature:
+1. In C++, define subclasses of `tensorflow::DatasetOpKernel`,
+ `tensorflow::GraphDatasetBase`, and `tensorflow::DatasetIterator<Dataset>`
+ that implement the reading logic.
+2. In C++, register a new reader op and kernel with the name
+ `"MyReaderDataset"`.
+3. In Python, define a subclass of @{tf.data.Dataset} called `MyReaderDataset`.
-```c++
-Status ReadLocked(string* key, string* value, bool* produced, bool* at_end)
-```
-
-If `ReadLocked` successfully reads a record from the file, it should fill in:
-
-* `*key`: with an identifier for the record, that a human could use to find
- this record again. You can include the filename from `current_work()`,
- and append a record number or whatever.
-* `*value`: with the contents of the record.
-* `*produced`: set to `true`.
-
-If you hit the end of a file (EOF), set `*at_end` to `true`. In either case,
-return `Status::OK()`. If there is an error, simply return it using one of the
-helper functions from
-[`tensorflow/core/lib/core/errors.h`](https://www.tensorflow.org/code/tensorflow/core/lib/core/errors.h)
-without modifying any arguments.
-
-Next you will create the actual Reader op. It will help if you are familiar
-with @{$adding_an_op$the adding an op how-to}. The main steps
-are:
-
-* Registering the op.
-* Define and register an `OpKernel`.
-
-To register the op, you will use a `REGISTER_OP` call defined in
-[`tensorflow/core/framework/op.h`](https://www.tensorflow.org/code/tensorflow/core/framework/op.h).
-Reader ops never take any input and always have a single output with type
-`resource`. They should have string `container` and `shared_name` attrs.
-You may optionally define additional attrs
-for configuration or include documentation in a `Doc`. For examples, see
-[`tensorflow/core/ops/io_ops.cc`](https://www.tensorflow.org/code/tensorflow/core/ops/io_ops.cc),
-e.g.:
+You can put all the C++ code in a single file, such as
+`my_reader_dataset_op.cc`. It will help if you are
+familiar with @{$adding_an_op$the adding an op how-to}. The following skeleton
+can be used as a starting point for your implementation:
```c++
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
-REGISTER_OP("TextLineReader")
- .Output("reader_handle: resource")
- .Attr("skip_header_lines: int = 0")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .SetIsStateful()
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-A Reader that outputs the lines of a file delimited by '\n'.
-)doc");
-```
-
-To define an `OpKernel`, Readers can use the shortcut of descending from
-`ReaderOpKernel`, defined in
-[`tensorflow/core/framework/reader_op_kernel.h`](https://www.tensorflow.org/code/tensorflow/core/framework/reader_op_kernel.h),
-and implement a constructor that calls `SetReaderFactory`. After defining
-your class, you will need to register it using `REGISTER_KERNEL_BUILDER(...)`.
-An example with no attrs:
+namespace tensorflow {
+namespace {
-```c++
-#include "tensorflow/core/framework/reader_op_kernel.h"
-
-class TFRecordReaderOp : public ReaderOpKernel {
+class MyReaderDatasetOp : public DatasetOpKernel {
public:
- explicit TFRecordReaderOp(OpKernelConstruction* context)
- : ReaderOpKernel(context) {
- Env* env = context->env();
- SetReaderFactory([this, env]() { return new TFRecordReader(name(), env); });
- }
-};
-REGISTER_KERNEL_BUILDER(Name("TFRecordReader").Device(DEVICE_CPU),
- TFRecordReaderOp);
-```
+ MyReaderDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
+ // Parse and validate any attrs that define the dataset using
+ // `ctx->GetAttr()`, and store them in member variables.
+ }
-An example with attrs:
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+    // Parse and validate any input tensors that define the dataset using
+ // `ctx->input()` or the utility function
+ // `ParseScalarArgument<T>(ctx, &arg)`.
-```c++
-#include "tensorflow/core/framework/reader_op_kernel.h"
-
-class TextLineReaderOp : public ReaderOpKernel {
- public:
- explicit TextLineReaderOp(OpKernelConstruction* context)
- : ReaderOpKernel(context) {
- int skip_header_lines = -1;
- OP_REQUIRES_OK(context,
- context->GetAttr("skip_header_lines", &skip_header_lines));
- OP_REQUIRES(context, skip_header_lines >= 0,
- errors::InvalidArgument("skip_header_lines must be >= 0 not ",
- skip_header_lines));
- Env* env = context->env();
- SetReaderFactory([this, skip_header_lines, env]() {
- return new TextLineReader(name(), skip_header_lines, env);
- });
+ // Create the dataset object, passing any (already-validated) arguments from
+ // attrs or input tensors.
+ *output = new Dataset(ctx);
}
-};
-REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU),
- TextLineReaderOp);
-```
-
-The last step is to add the Python wrapper. You can either do this by
-@{$adding_an_op#build_the_op_library$compiling a dynamic library}
-or, if you are building TensorFlow from source, adding to `user_ops.py`.
-For the latter, you will import `tensorflow.python.ops.io_ops` in
-[`tensorflow/python/user_ops/user_ops.py`](https://www.tensorflow.org/code/tensorflow/python/user_ops/user_ops.py)
-and add a descendant of [`io_ops.ReaderBase`](https://www.tensorflow.org/code/tensorflow/python/ops/io_ops.py).
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx) : GraphDatasetBase(ctx) {}
+
+ std::unique_ptr<IteratorBase> MakeIterator(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::MyReader")}));
+ }
+
+ // Record structure: Each record is represented by a scalar string tensor.
+ //
+ // Dataset elements can have a fixed number of components of different
+ // types and shapes; replace the following two methods to customize this
+ // aspect of the dataset.
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
+ return *dtypes;
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}});
+ return *shapes;
+ }
+
+ string DebugString() override { return "MyReaderDatasetOp::Dataset"; }
+
+ protected:
+ // Optional: Implementation of `GraphDef` serialization for this dataset.
+ //
+ // Implement this method if you want to be able to save and restore
+ // instances of this dataset (and any iterators over it).
+ Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ // Construct nodes to represent any of the input tensors from this
+ // object's member variables using `b->AddScalar()` and `b->AddVector()`.
+ std::vector<Node*> input_tensors;
+ TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params), i_(0) {}
+
+ // Implementation of the reading logic.
+ //
+ // The example implementation in this file yields the string "MyReader!"
+ // ten times. In general there are three cases:
+ //
+ // 1. If an element is successfully read, store it as one or more tensors
+ // in `*out_tensors`, set `*end_of_sequence = false` and return
+ // `Status::OK()`.
+ // 2. If the end of input is reached, set `*end_of_sequence = true` and
+ // return `Status::OK()`.
+ // 3. If an error occurs, return an error status using one of the helper
+ // functions from "tensorflow/core/lib/core/errors.h".
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ // NOTE: `GetNextInternal()` may be called concurrently, so it is
+ // recommended that you protect the iterator state with a mutex.
+ mutex_lock l(mu_);
+ if (i_ < 10) {
+ // Create a scalar string tensor and add it to the output.
+ Tensor record_tensor(ctx->allocator({}), DT_STRING, {});
+ record_tensor.scalar<string>()() = "MyReader!";
+ out_tensors->emplace_back(std::move(record_tensor));
+ ++i_;
+ *end_of_sequence = false;
+ } else {
+ *end_of_sequence = true;
+ }
+ return Status::OK();
+ }
+
+ protected:
+ // Optional: Implementation of iterator state serialization for this
+ // iterator.
+ //
+ // Implement these two methods if you want to be able to save and restore
+ // instances of this iterator.
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
+ return Status::OK();
+ }
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ int64 i_ GUARDED_BY(mu_);
+ };
+ };
+};
-```python
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import common_shapes
-from tensorflow.python.ops import io_ops
+// Register the op definition for MyReaderDataset.
+//
+// Dataset ops always have a single output, of type `variant`, which represents
+// the constructed `Dataset` object.
+//
+// Add any attrs and input tensors that define the dataset here.
+REGISTER_OP("MyReaderDataset")
+ .Output("handle: variant")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ScalarShape);
-class SomeReader(io_ops.ReaderBase):
+// Register the kernel implementation for MyReaderDataset.
+REGISTER_KERNEL_BUILDER(Name("MyReaderDataset").Device(DEVICE_CPU),
+ MyReaderDatasetOp);
- def __init__(self, name=None):
- rr = gen_user_ops.some_reader(name=name)
- super(SomeReader, self).__init__(rr)
+} // namespace
+} // namespace tensorflow
+```
+The last step is to build the C++ code and add a Python wrapper. The easiest way
+to do this is by @{$adding_an_op#build_the_op_library$compiling a dynamic
+library} (e.g. called `"my_reader_dataset_op.so"`), and adding a Python class
+that subclasses @{tf.data.Dataset} to wrap it. An example Python program is
+given here:
-ops.NotDifferentiable("SomeReader")
+```python
+import tensorflow as tf
+
+# Assumes the file is in the current working directory.
+my_reader_dataset_module = tf.load_op_library("./my_reader_dataset_op.so")
+
+class MyReaderDataset(tf.data.Dataset):
+
+ def __init__(self):
+ super(MyReaderDataset, self).__init__()
+ # Create any input attrs or tensors as members of this class.
+
+ def _as_variant_tensor(self):
+ # Actually construct the graph node for the dataset op.
+ #
+ # This method will be invoked when you create an iterator on this dataset
+ # or a dataset derived from it.
+ return my_reader_dataset_module.my_reader_dataset()
+
+ # The following properties define the structure of each element: a scalar
+ # `tf.string` tensor. Change these properties to match the `output_dtypes()`
+ # and `output_shapes()` methods of `MyReaderDataset::Dataset` if you modify
+ # the structure of each element.
+ @property
+ def output_types(self):
+ return tf.string
+
+ @property
+ def output_shapes(self):
+ return tf.TensorShape([])
+
+ @property
+ def output_classes(self):
+ return tf.Tensor
+
+if __name__ == "__main__":
+ # Create a MyReaderDataset and print its elements.
+ with tf.Session() as sess:
+ iterator = MyReaderDataset().make_one_shot_iterator()
+ next_element = iterator.get_next()
+ try:
+ while True:
+ print(sess.run(next_element)) # Prints "MyReader!" ten times.
+ except tf.errors.OutOfRangeError:
+ pass
```
-You can see some examples in
-[`tensorflow/python/ops/io_ops.py`](https://www.tensorflow.org/code/tensorflow/python/ops/io_ops.py).
+You can see some examples of `Dataset` wrapper classes in
+[`tensorflow/python/data/ops/dataset_ops.py`](https://www.tensorflow.org/code/tensorflow/python/data/ops/dataset_ops.py).
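+
+Once wrapped this way, the dataset composes with the standard `tf.data`
+transformations. As a brief sketch (reusing the `MyReaderDataset` class defined
+above):
+
+```python
+# Repeat and batch the string records produced by MyReaderDataset.
+dataset = MyReaderDataset().repeat(2).batch(5)
+iterator = dataset.make_one_shot_iterator()
+next_batch = iterator.get_next()
+
+with tf.Session() as sess:
+  print(sess.run(next_batch))  # A batch of five "MyReader!" strings.
+```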
## Writing an Op for a record format
@@ -201,9 +287,7 @@ track down where the bad data came from.
Examples of Ops useful for decoding records:
-* @{tf.parse_single_example}
- (and
- @{tf.parse_example})
+* @{tf.parse_single_example} (and @{tf.parse_example})
* @{tf.decode_csv}
* @{tf.decode_raw}
@@ -211,11 +295,6 @@ Note that it can be useful to use multiple Ops to decode a particular record
format. For example, you may have an image saved as a string in
[a `tf.train.Example` protocol buffer](https://www.tensorflow.org/code/tensorflow/core/example/example.proto).
Depending on the format of that image, you might take the corresponding output
-from a
-@{tf.parse_single_example}
-op and call @{tf.image.decode_jpeg},
-@{tf.image.decode_png}, or
-@{tf.decode_raw}. It is common to
-take the output of `tf.decode_raw` and use
-@{tf.slice} and
-@{tf.reshape} to extract pieces.
+from a @{tf.parse_single_example} op and call @{tf.image.decode_jpeg},
+@{tf.image.decode_png}, or @{tf.decode_raw}. It is common to take the output
+of `tf.decode_raw` and use @{tf.slice} and @{tf.reshape} to extract pieces.
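+
+As a rough sketch of that pattern (the feature key `"image_raw"`, the file
+name, and the image shape below are hypothetical):
+
+```python
+def parse_example(serialized_example):
+  features = tf.parse_single_example(
+      serialized_example,
+      features={"image_raw": tf.FixedLenFeature([], tf.string)})
+  # The image was stored as raw bytes; use tf.image.decode_jpeg or
+  # tf.image.decode_png instead if it was stored in an encoded format.
+  image = tf.decode_raw(features["image_raw"], tf.uint8)
+  return tf.reshape(image, [28, 28, 1])
+
+dataset = tf.data.TFRecordDataset(["images.tfrecords"]).map(parse_example)
+```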
diff --git a/tensorflow/docs_src/programmers_guide/eager.md b/tensorflow/docs_src/programmers_guide/eager.md
index dc5b403428..595e6be4af 100644
--- a/tensorflow/docs_src/programmers_guide/eager.md
+++ b/tensorflow/docs_src/programmers_guide/eager.md
@@ -102,11 +102,11 @@ print(a.numpy())
# [3 4]]
```
-The `tfe` module contains symbols available to both eager and graph execution
+The `tf.contrib.eager` module contains symbols available to both eager and graph execution
environments and is useful for writing code to [work with graphs](#work_with_graphs):
```py
-import tensorflow.contrib.eager as tfe
+tfe = tf.contrib.eager
```
## Dynamic control flow
@@ -213,25 +213,25 @@ their objects.
[Automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation)
is useful for implementing machine learning algorithms such as
[backpropagation](https://en.wikipedia.org/wiki/Backpropagation) for training
-neural networks. During eager execution, use `tfe.GradientTape` to trace
+neural networks. During eager execution, use `tf.GradientTape` to trace
operations for computing gradients later.
-`tfe.GradientTape` is an opt-in feature to provide maximal performance when
+`tf.GradientTape` is an opt-in feature to provide maximal performance when
not tracing. Since different operations can occur during each call, all
forward-pass operations get recorded to a "tape". To compute the gradient, play
-the tape backwards and then discard. A particular `tfe.GradientTape` can only
+the tape backwards and then discard it. A particular `tf.GradientTape` can only
compute one gradient; subsequent calls throw a runtime error.
```py
w = tfe.Variable([[1.0]])
-with tfe.GradientTape() as tape:
+with tf.GradientTape() as tape:
loss = w * w
grad = tape.gradient(loss, [w])
print(grad) # => [tf.Tensor([[ 2.]], shape=(1, 1), dtype=float32)]
```
-Here's an example of `tfe.GradientTape` that records forward-pass operations
+Here's an example of `tf.GradientTape` that records forward-pass operations
to train a simple model:
```py
@@ -251,8 +251,8 @@ def loss(weights, biases):
# Return the derivative of loss with respect to weight and bias
def grad(weights, biases):
- with tfe.GradientTape() as tape:
- loss_value = loss(weights, biases)
+ with tf.GradientTape() as tape:
+ loss_value = loss(weights, biases)
return tape.gradient(loss_value, [weights, biases])
train_steps = 200
@@ -292,7 +292,7 @@ Final loss: 0.974
W = 3.01582956314, B = 2.1191945076
```
-Replay the `tfe.GradientTape` to compute the gradients and apply them in a
+Replay the `tf.GradientTape` to compute the gradients and apply them in a
training loop. This is demonstrated in an excerpt from the
[mnist_eager.py](https://github.com/tensorflow/models/blob/master/official/mnist/mnist_eager.py)
example:
@@ -301,9 +301,9 @@ example:
dataset = tf.data.Dataset.from_tensor_slices((data.train.images,
data.train.labels))
...
-for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)):
+for (batch, (images, labels)) in enumerate(dataset):
...
- with tfe.GradientTape() as tape:
+ with tf.GradientTape() as tape:
logits = model(images, training=True)
loss_value = loss(logits, labels)
...
@@ -353,17 +353,17 @@ def loss(model, x, y):
return tf.losses.sparse_softmax_cross_entropy(labels=y, logits=prediction)
def grad(model, inputs, targets):
- with tfe.GradientTape() as tape:
+ with tf.GradientTape() as tape:
loss_value = loss(model, inputs, targets)
return tape.gradient(loss_value, model.variables)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
-x, y = tfe.Iterator(dataset_train).next()
+x, y = iter(dataset_train).next()
print("Initial loss: {:.3f}".format(loss(model, x, y)))
# Training loop
-for (i, (x, y)) in enumerate(tfe.Iterator(dataset_train)):
+for (i, (x, y)) in enumerate(dataset_train):
# Calculate derivatives of the input function with respect to its parameters.
grads = grad(model, x, y)
# Apply the gradient to the model
@@ -398,7 +398,7 @@ And for faster training, move the computation to a GPU:
```py
with tf.device("/gpu:0"):
- for (i, (x, y)) in enumerate(tfe.Iterator(dataset_train)):
+ for (i, (x, y)) in enumerate(dataset_train):
# minimize() is equivalent to the grad() and apply_gradients() calls.
optimizer.minimize(lambda: loss(model, x, y),
global_step=tf.train.get_or_create_global_step())
@@ -411,7 +411,7 @@ training to make automatic differentiation easier. The parameters of a model can
be encapsulated in classes as variables.
Better encapsulate model parameters by using `tfe.Variable` with
-`tfe.GradientTape`. For example, the automatic differentiation example above
+`tf.GradientTape`. For example, the automatic differentiation example above
can be rewritten:
```py
@@ -435,7 +435,7 @@ def loss(model, inputs, targets):
return tf.reduce_mean(tf.square(error))
def grad(model, inputs, targets):
- with tfe.GradientTape() as tape:
+ with tf.GradientTape() as tape:
loss_value = loss(model, inputs, targets)
return tape.gradient(loss_value, [model.W, model.B])
@@ -585,14 +585,14 @@ for _ in range(iterations):
### Dynamic models
-`tfe.GradientTape` can also be used in dynamic models. This example for a
+`tf.GradientTape` can also be used in dynamic models. This example for a
[backtracking line search](https://wikipedia.org/wiki/Backtracking_line_search)
algorithm looks like normal NumPy code, except that it has gradients and is
differentiable, despite the complex control flow:
```py
def line_search_step(fn, init_x, rate=1.0):
- with tfe.GradientTape() as tape:
+ with tf.GradientTape() as tape:
# Variables are automatically recorded, but manually watch a tensor
tape.watch(init_x)
value = fn(init_x)
@@ -608,7 +608,7 @@ def line_search_step(fn, init_x, rate=1.0):
### Additional functions to compute gradients
-`tfe.GradientTape` is a powerful interface for computing gradients, but there
+`tf.GradientTape` is a powerful interface for computing gradients, but there
is another [Autograd](https://github.com/HIPS/autograd)-style API available for
automatic differentiation. These functions are useful if writing math code with
only tensors and gradient functions, and without `tfe.Variables`:
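+
+For example, a minimal sketch using `tfe.gradients_function` (assuming eager
+execution is enabled, as elsewhere in this guide):
+
+```py
+def square(x):
+  return x * x
+
+grad_fn = tfe.gradients_function(square)
+print(square(3.))   # => 9.0
+print(grad_fn(3.))  # => [6.0]
+```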
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 0fd2177df7..3d261c9d0a 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -1845,6 +1845,286 @@ func ReverseSequence(scope *Scope, input tf.Output, seq_lengths tf.Output, seq_d
return op.Output(0)
}
+// UniqueWithCountsAttr is an optional argument to UniqueWithCounts.
+type UniqueWithCountsAttr func(optionalAttr)
+
+// UniqueWithCountsOutIdx sets the optional out_idx attribute to value.
+// If not specified, defaults to DT_INT32
+func UniqueWithCountsOutIdx(value tf.DataType) UniqueWithCountsAttr {
+ return func(m optionalAttr) {
+ m["out_idx"] = value
+ }
+}
+
+// Finds unique elements in a 1-D tensor.
+//
+// This operation returns a tensor `y` containing all of the unique elements of `x`
+// sorted in the same order that they occur in `x`. This operation also returns a
+// tensor `idx` the same size as `x` that contains the index of each value of `x`
+// in the unique output `y`. Finally, it returns a third tensor `count` that
+// contains the count of each element of `y` in `x`. In other words:
+//
+// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]`
+//
+// For example:
+//
+// ```
+// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8]
+// y, idx, count = unique_with_counts(x)
+// y ==> [1, 2, 4, 7, 8]
+// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4]
+// count ==> [2, 1, 3, 1, 2]
+// ```
+//
+// Arguments:
+// x: 1-D.
+//
+// Returns three 1-D tensors: y, idx, and count.
+func UniqueWithCounts(scope *Scope, x tf.Output, optional ...UniqueWithCountsAttr) (y tf.Output, idx tf.Output, count tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "UniqueWithCounts",
+ Input: []tf.Input{
+ x,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// UniqueV2Attr is an optional argument to UniqueV2.
+type UniqueV2Attr func(optionalAttr)
+
+// UniqueV2OutIdx sets the optional out_idx attribute to value.
+// If not specified, defaults to DT_INT32
+func UniqueV2OutIdx(value tf.DataType) UniqueV2Attr {
+ return func(m optionalAttr) {
+ m["out_idx"] = value
+ }
+}
+
+// Finds unique elements in a 1-D tensor.
+//
+// This operation returns a tensor `y` containing all of the unique elements of `x`
+// sorted in the same order that they occur in `x`. This operation also returns a
+// tensor `idx` the same size as `x` that contains the index of each value of `x`
+// in the unique output `y`. In other words:
+//
+// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]`
+//
+// For example:
+//
+// ```
+// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8]
+// y, idx = unique(x)
+// y ==> [1, 2, 4, 7, 8]
+// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4]
+// ```
+//
+// Arguments:
+// x: A `Tensor`.
+// axis: A `Tensor` of type `int64` (default: 0). The axis of the Tensor to
+// find the unique elements.
+//
+// Returns a `Tensor` y containing the unique elements along the `axis` of `Tensor` x,
+// and a 1-D `Tensor` idx, with the same type as x, that contains the index of each
+// value of x in the output y.
+func UniqueV2(scope *Scope, x tf.Output, axis tf.Output, optional ...UniqueV2Attr) (y tf.Output, idx tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "UniqueV2",
+ Input: []tf.Input{
+ x, axis,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
+// UniqueAttr is an optional argument to Unique.
+type UniqueAttr func(optionalAttr)
+
+// UniqueOutIdx sets the optional out_idx attribute to value.
+// If not specified, defaults to DT_INT32
+func UniqueOutIdx(value tf.DataType) UniqueAttr {
+ return func(m optionalAttr) {
+ m["out_idx"] = value
+ }
+}
+
+// Finds unique elements in a 1-D tensor.
+//
+// This operation returns a tensor `y` containing all of the unique elements of `x`
+// sorted in the same order that they occur in `x`. This operation also returns a
+// tensor `idx` the same size as `x` that contains the index of each value of `x`
+// in the unique output `y`. In other words:
+//
+// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]`
+//
+// For example:
+//
+// ```
+// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8]
+// y, idx = unique(x)
+// y ==> [1, 2, 4, 7, 8]
+// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4]
+// ```
+//
+// Arguments:
+// x: 1-D.
+//
+// Returns two 1-D tensors: y and idx.
+func Unique(scope *Scope, x tf.Output, optional ...UniqueAttr) (y tf.Output, idx tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Unique",
+ Input: []tf.Input{
+ x,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
+// Shuffle dimensions of x according to a permutation and conjugate the result.
+//
+// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
+// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]`
+// `y[i,j,k,...,s,t,u] == conj(x[perm[i], perm[j], perm[k],...,perm[s], perm[t], perm[u]])`
+func ConjugateTranspose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ConjugateTranspose",
+ Input: []tf.Input{
+ x, perm,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Reshapes a tensor.
+//
+// Given `tensor`, this operation returns a tensor that has the same values
+// as `tensor` with shape `shape`.
+//
+// If one component of `shape` is the special value -1, the size of that dimension
+// is computed so that the total size remains constant. In particular, a `shape`
+// of `[-1]` flattens into 1-D. At most one component of `shape` can be -1.
+//
+// If `shape` is 1-D or higher, then the operation returns a tensor with shape
+// `shape` filled with the values of `tensor`. In this case, the number of elements
+// implied by `shape` must be the same as the number of elements in `tensor`.
+//
+// For example:
+//
+// ```
+// # tensor 't' is [1, 2, 3, 4, 5, 6, 7, 8, 9]
+// # tensor 't' has shape [9]
+// reshape(t, [3, 3]) ==> [[1, 2, 3],
+// [4, 5, 6],
+// [7, 8, 9]]
+//
+// # tensor 't' is [[[1, 1], [2, 2]],
+// # [[3, 3], [4, 4]]]
+// # tensor 't' has shape [2, 2, 2]
+// reshape(t, [2, 4]) ==> [[1, 1, 2, 2],
+// [3, 3, 4, 4]]
+//
+// # tensor 't' is [[[1, 1, 1],
+// # [2, 2, 2]],
+// # [[3, 3, 3],
+// # [4, 4, 4]],
+// # [[5, 5, 5],
+// # [6, 6, 6]]]
+// # tensor 't' has shape [3, 2, 3]
+// # pass '[-1]' to flatten 't'
+// reshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6]
+//
+// # -1 can also be used to infer the shape
+//
+// # -1 is inferred to be 9:
+// reshape(t, [2, -1]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3],
+// [4, 4, 4, 5, 5, 5, 6, 6, 6]]
+// # -1 is inferred to be 2:
+// reshape(t, [-1, 9]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3],
+// [4, 4, 4, 5, 5, 5, 6, 6, 6]]
+// # -1 is inferred to be 3:
+// reshape(t, [ 2, -1, 3]) ==> [[[1, 1, 1],
+// [2, 2, 2],
+// [3, 3, 3]],
+// [[4, 4, 4],
+// [5, 5, 5],
+// [6, 6, 6]]]
+//
+// # tensor 't' is [7]
+// # shape `[]` reshapes to a scalar
+// reshape(t, []) ==> 7
+// ```
+//
+// Arguments:
+//
+// shape: Defines the shape of the output tensor.
+func Reshape(scope *Scope, tensor tf.Output, shape tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Reshape",
+ Input: []tf.Input{
+ tensor, shape,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Checks a tensor for NaN and Inf values.
+//
+// When run, reports an `InvalidArgument` error if `tensor` has any values
+// that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
+//
+// Arguments:
+//
+// message: Prefix of the error message.
+func CheckNumerics(scope *Scope, tensor tf.Output, message string) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"message": message}
+ opspec := tf.OpSpec{
+ Type: "CheckNumerics",
+ Input: []tf.Input{
+ tensor,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Returns the complex conjugate of a complex number.
//
// Given a tensor `input` of complex numbers, this operation returns a tensor of
@@ -2098,6 +2378,68 @@ func SparseSegmentSumWithNumSegments(scope *Scope, data tf.Output, indices tf.Ou
return op.Output(0)
}
+// PreventGradientAttr is an optional argument to PreventGradient.
+type PreventGradientAttr func(optionalAttr)
+
+// PreventGradientMessage sets the optional message attribute to value.
+//
+// value: Will be printed in the error when anyone tries to differentiate
+// this operation.
+// If not specified, defaults to ""
+func PreventGradientMessage(value string) PreventGradientAttr {
+ return func(m optionalAttr) {
+ m["message"] = value
+ }
+}
+
+// An identity op that triggers an error if a gradient is requested.
+//
+// When executed in a graph, this op outputs its input tensor as-is.
+//
+// When building ops to compute gradients, the TensorFlow gradient system
+// will return an error when trying to look up the gradient of this op,
+// because no gradient must ever be registered for this function. This
+// op exists to prevent subtle bugs from silently returning unimplemented
+// gradients in some corner cases.
+//
+// Arguments:
+// input: any tensor.
+//
+// Returns the same input tensor.
+func PreventGradient(scope *Scope, input tf.Output, optional ...PreventGradientAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "PreventGradient",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes asin of x element-wise.
+func Asin(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Asin",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// SparseToDenseAttr is an optional argument to SparseToDense.
type SparseToDenseAttr func(optionalAttr)
@@ -2671,120 +3013,6 @@ func Igammac(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) {
return op.Output(0)
}
-// LogUniformCandidateSamplerAttr is an optional argument to LogUniformCandidateSampler.
-type LogUniformCandidateSamplerAttr func(optionalAttr)
-
-// LogUniformCandidateSamplerSeed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func LogUniformCandidateSamplerSeed(value int64) LogUniformCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// LogUniformCandidateSamplerSeed2 sets the optional seed2 attribute to value.
-//
-// value: An second seed to avoid seed collision.
-// If not specified, defaults to 0
-func LogUniformCandidateSamplerSeed2(value int64) LogUniformCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Generates labels for candidate sampling with a log-uniform distribution.
-//
-// See explanations of candidate sampling and the data formats at
-// go/candidate-sampling.
-//
-// For each batch, this op picks a single set of sampled candidate labels.
-//
-// The advantages of sampling candidates per-batch are simplicity and the
-// possibility of efficient dense matrix multiplication. The disadvantage is that
-// the sampled candidates must be chosen independently of the context and of the
-// true labels.
-//
-// Arguments:
-// true_classes: A batch_size * num_true matrix, in which each row contains the
-// IDs of the num_true target_classes in the corresponding original label.
-// num_true: Number of true labels per context.
-// num_sampled: Number of candidates to randomly sample.
-// unique: If unique is true, we sample with rejection, so that all sampled
-// candidates in a batch are unique. This requires some approximation to
-// estimate the post-rejection sampling probabilities.
-// range_max: The sampler will sample integers from the interval [0, range_max).
-//
-// Returns A vector of length num_sampled, in which each element is
-// the ID of a sampled candidate.A batch_size * num_true matrix, representing
-// the number of times each candidate is expected to occur in a batch
-// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
-// candidate representing the number of times the candidate is expected
-// to occur in a batch of sampled candidates. If unique=true, then this is a
-// probability.
-func LogUniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LogUniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "LogUniformCandidateSampler",
- Input: []tf.Input{
- true_classes,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
-// Returns (x - y)(x - y) element-wise.
-//
-// *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func SquaredDifference(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SquaredDifference",
- Input: []tf.Input{
- x, y,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Forwards the input to the output.
-//
-// This operator represents the loop termination condition used by the
-// "pivot" switches of a loop.
-//
-// Arguments:
-// input: A boolean scalar, representing the branch predicate of the Switch op.
-//
-// Returns The same tensor as `input`.
-func LoopCond(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "LoopCond",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// ApproximateEqualAttr is an optional argument to ApproximateEqual.
type ApproximateEqualAttr func(optionalAttr)
@@ -3257,6 +3485,69 @@ func Digamma(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
+// Shuffle dimensions of x according to a permutation.
+//
+// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
+// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]`
+func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Transpose",
+ Input: []tf.Input{
+ x, perm,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// MinAttr is an optional argument to Min.
+type MinAttr func(optionalAttr)
+
+// MinKeepDims sets the optional keep_dims attribute to value.
+//
+// value: If true, retain reduced dimensions with length 1.
+// If not specified, defaults to false
+func MinKeepDims(value bool) MinAttr {
+ return func(m optionalAttr) {
+ m["keep_dims"] = value
+ }
+}
+
+// Computes the minimum of elements across dimensions of a tensor.
+//
+// Reduces `input` along the dimensions given in `axis`. Unless
+// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+// `axis`. If `keep_dims` is true, the reduced dimensions are
+// retained with length 1.
+//
+// Arguments:
+// input: The tensor to reduce.
+// axis: The dimensions to reduce. Must be in the range
+// `[-rank(input), rank(input))`.
+//
+// Returns The reduced tensor.
+func Min(scope *Scope, input tf.Output, axis tf.Output, optional ...MinAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Min",
+ Input: []tf.Input{
+ input, axis,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Conv2DBackpropFilterAttr is an optional argument to Conv2DBackpropFilter.
type Conv2DBackpropFilterAttr func(optionalAttr)
@@ -4419,6 +4710,66 @@ func MatrixDiag(scope *Scope, diagonal tf.Output) (output tf.Output) {
return op.Output(0)
}
+// Computes the inverse permutation of a tensor.
+//
+// This operation computes the inverse of an index permutation. It takes a 1-D
+// integer tensor `x`, which represents the indices of a zero-based array, and
+// swaps each value with its index position. In other words, for an output tensor
+// `y` and an input tensor `x`, this operation computes the following:
+//
+// `y[x[i]] = i for i in [0, 1, ..., len(x) - 1]`
+//
+// The values must include 0. There can be no duplicate values or negative values.
+//
+// For example:
+//
+// ```
+// # tensor `x` is [3, 4, 0, 2, 1]
+// invert_permutation(x) ==> [2, 4, 3, 0, 1]
+// ```
+//
+// Arguments:
+// x: 1-D.
+//
+// Returns 1-D.
+func InvertPermutation(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "InvertPermutation",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes log softmax activations.
+//
+// For each batch `i` and class `j` we have
+//
+// logsoftmax[i, j] = logits[i, j] - log(sum(exp(logits[i])))
+//
+// Arguments:
+// logits: 2-D with shape `[batch_size, num_classes]`.
+//
+// Returns Same shape as `logits`.
+func LogSoftmax(scope *Scope, logits tf.Output) (logsoftmax tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "LogSoftmax",
+ Input: []tf.Input{
+ logits,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Returns the truth value of (x <= y) element-wise.
//
// *NOTE*: `LessEqual` supports broadcasting. More about broadcasting
@@ -5657,66 +6008,6 @@ func Reverse(scope *Scope, tensor tf.Output, dims tf.Output) (output tf.Output)
return op.Output(0)
}
-// Computes log softmax activations.
-//
-// For each batch `i` and class `j` we have
-//
-// logsoftmax[i, j] = logits[i, j] - log(sum(exp(logits[i])))
-//
-// Arguments:
-// logits: 2-D with shape `[batch_size, num_classes]`.
-//
-// Returns Same shape as `logits`.
-func LogSoftmax(scope *Scope, logits tf.Output) (logsoftmax tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "LogSoftmax",
- Input: []tf.Input{
- logits,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes the inverse permutation of a tensor.
-//
-// This operation computes the inverse of an index permutation. It takes a 1-D
-// integer tensor `x`, which represents the indices of a zero-based array, and
-// swaps each value with its index position. In other words, for an output tensor
-// `y` and an input tensor `x`, this operation computes the following:
-//
-// `y[x[i]] = i for i in [0, 1, ..., len(x) - 1]`
-//
-// The values must include 0. There can be no duplicate values or negative values.
-//
-// For example:
-//
-// ```
-// # tensor `x` is [3, 4, 0, 2, 1]
-// invert_permutation(x) ==> [2, 4, 3, 0, 1]
-// ```
-//
-// Arguments:
-// x: 1-D.
-//
-// Returns 1-D.
-func InvertPermutation(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "InvertPermutation",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// BiasAddGradAttr is an optional argument to BiasAddGrad.
type BiasAddGradAttr func(optionalAttr)
@@ -5919,6 +6210,363 @@ func Acos(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
+// QuantizeAndDequantizeAttr is an optional argument to QuantizeAndDequantize.
+type QuantizeAndDequantizeAttr func(optionalAttr)
+
+// QuantizeAndDequantizeSignedInput sets the optional signed_input attribute to value.
+// If not specified, defaults to true
+func QuantizeAndDequantizeSignedInput(value bool) QuantizeAndDequantizeAttr {
+ return func(m optionalAttr) {
+ m["signed_input"] = value
+ }
+}
+
+// QuantizeAndDequantizeNumBits sets the optional num_bits attribute to value.
+// If not specified, defaults to 8
+func QuantizeAndDequantizeNumBits(value int64) QuantizeAndDequantizeAttr {
+ return func(m optionalAttr) {
+ m["num_bits"] = value
+ }
+}
+
+// QuantizeAndDequantizeRangeGiven sets the optional range_given attribute to value.
+// If not specified, defaults to false
+func QuantizeAndDequantizeRangeGiven(value bool) QuantizeAndDequantizeAttr {
+ return func(m optionalAttr) {
+ m["range_given"] = value
+ }
+}
+
+// QuantizeAndDequantizeInputMin sets the optional input_min attribute to value.
+// If not specified, defaults to 0
+func QuantizeAndDequantizeInputMin(value float32) QuantizeAndDequantizeAttr {
+ return func(m optionalAttr) {
+ m["input_min"] = value
+ }
+}
+
+// QuantizeAndDequantizeInputMax sets the optional input_max attribute to value.
+// If not specified, defaults to 0
+func QuantizeAndDequantizeInputMax(value float32) QuantizeAndDequantizeAttr {
+ return func(m optionalAttr) {
+ m["input_max"] = value
+ }
+}
+
+// Use QuantizeAndDequantizeV2 instead.
+//
+// DEPRECATED at GraphDef version 22: Replaced by QuantizeAndDequantizeV2
+func QuantizeAndDequantize(scope *Scope, input tf.Output, optional ...QuantizeAndDequantizeAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "QuantizeAndDequantize",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Returns locations of nonzero / true values in a tensor.
+//
+// This operation returns the coordinates of true elements in `condition`. The
+// coordinates are returned in a 2-D tensor where the first dimension (rows)
+// represents the number of true elements, and the second dimension (columns)
+// represents the coordinates of the true elements. Keep in mind, the shape of
+// the output tensor can vary depending on how many true values there are in
+// `condition`. Indices are output in row-major order.
+//
+// For example:
+//
+// ```
+// # 'input' tensor is [[True, False]
+// # [True, False]]
+// # 'input' has two true values, so output has two coordinates.
+// # 'input' has rank of 2, so coordinates have two indices.
+// where(input) ==> [[0, 0],
+// [1, 0]]
+//
+// # `condition` tensor is [[[True, False]
+// # [True, False]]
+// # [[False, True]
+// # [False, True]]
+// # [[False, False]
+// # [False, True]]]
+// # 'input' has 5 true values, so output has 5 coordinates.
+// # 'input' has rank of 3, so coordinates have three indices.
+// where(input) ==> [[0, 0, 0],
+// [0, 1, 0],
+// [1, 0, 1],
+// [1, 1, 1],
+// [2, 1, 1]]
+//
+// # `condition` tensor is [[[1.5, 0.0]
+// # [-0.5, 0.0]]
+// # [[0.0, 0.25]
+// # [0.0, 0.75]]
+// # [[0.0, 0.0]
+// # [0.0, 0.01]]]
+// # 'input' has 5 nonzero values, so output has 5 coordinates.
+// # 'input' has rank of 3, so coordinates have three indices.
+// where(input) ==> [[0, 0, 0],
+// [0, 1, 0],
+// [1, 0, 1],
+// [1, 1, 1],
+// [2, 1, 1]]
+//
+// # `condition` tensor is [[[1.5 + 0.0j, 0.0 + 0.0j]
+// # [0.0 + 0.5j, 0.0 + 0.0j]]
+// # [[0.0 + 0.0j, 0.25 + 1.5j]
+// # [0.0 + 0.0j, 0.75 + 0.0j]]
+// # [[0.0 + 0.0j, 0.0 + 0.0j]
+// # [0.0 + 0.0j, 0.01 + 0.0j]]]
+// # 'input' has 5 nonzero magnitude values, so output has 5 coordinates.
+// # 'input' has rank of 3, so coordinates have three indices.
+// where(input) ==> [[0, 0, 0],
+// [0, 1, 0],
+// [1, 0, 1],
+// [1, 1, 1],
+// [2, 1, 1]]
+// ```
+func Where(scope *Scope, condition tf.Output) (index tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Where",
+ Input: []tf.Input{
+ condition,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// QueueDequeueV2Attr is an optional argument to QueueDequeueV2.
+type QueueDequeueV2Attr func(optionalAttr)
+
+// QueueDequeueV2TimeoutMs sets the optional timeout_ms attribute to value.
+//
+// value: If the queue is empty, this operation will block for up to
+// timeout_ms milliseconds.
+// Note: This option is not supported yet.
+// If not specified, defaults to -1
+func QueueDequeueV2TimeoutMs(value int64) QueueDequeueV2Attr {
+ return func(m optionalAttr) {
+ m["timeout_ms"] = value
+ }
+}
+
+// Dequeues a tuple of one or more tensors from the given queue.
+//
+// This operation has k outputs, where k is the number of components
+// in the tuples stored in the given queue, and output i is the ith
+// component of the dequeued tuple.
+//
+// N.B. If the queue is empty, this operation will block until an element
+// has been dequeued (or 'timeout_ms' elapses, if specified).
+//
+// Arguments:
+// handle: The handle to a queue.
+// component_types: The type of each component in a tuple.
+//
+// Returns One or more tensors that were dequeued as a tuple.
+func QueueDequeueV2(scope *Scope, handle tf.Output, component_types []tf.DataType, optional ...QueueDequeueV2Attr) (components []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"component_types": component_types}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "QueueDequeueV2",
+ Input: []tf.Input{
+ handle,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if components, idx, err = makeOutputList(op, idx, "components"); err != nil {
+ scope.UpdateErr("QueueDequeueV2", err)
+ return
+ }
+ return components
+}
+
+// Computes the Gauss error function of `x` element-wise.
+func Erf(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Erf",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Returns element-wise largest integer not greater than x.
+func Floor(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Floor",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// OneHotAttr is an optional argument to OneHot.
+type OneHotAttr func(optionalAttr)
+
+// OneHotAxis sets the optional axis attribute to value.
+//
+// value: The axis to fill (default: -1, a new inner-most axis).
+// If not specified, defaults to -1
+func OneHotAxis(value int64) OneHotAttr {
+ return func(m optionalAttr) {
+ m["axis"] = value
+ }
+}
+
+// Returns a one-hot tensor.
+//
+// The locations represented by indices in `indices` take value `on_value`,
+// while all other locations take value `off_value`.
+//
+// If the input `indices` is rank `N`, the output will have rank `N+1`.
+// The new axis is created at dimension `axis` (default: the new axis is
+// appended at the end).
+//
+// If `indices` is a scalar the output shape will be a vector of length `depth`.
+//
+// If `indices` is a vector of length `features`, the output shape will be:
+// ```
+// features x depth if axis == -1
+// depth x features if axis == 0
+// ```
+//
+// If `indices` is a matrix (batch) with shape `[batch, features]`,
+// the output shape will be:
+// ```
+// batch x features x depth if axis == -1
+// batch x depth x features if axis == 1
+// depth x batch x features if axis == 0
+// ```
+//
+//
+// Examples
+// =========
+//
+// Suppose that
+//
+// ```
+// indices = [0, 2, -1, 1]
+// depth = 3
+// on_value = 5.0
+// off_value = 0.0
+// axis = -1
+// ```
+//
+// Then output is `[4 x 3]`:
+//
+// ```output =
+// [5.0 0.0 0.0] // one_hot(0)
+// [0.0 0.0 5.0] // one_hot(2)
+// [0.0 0.0 0.0] // one_hot(-1)
+// [0.0 5.0 0.0] // one_hot(1)
+// ```
+//
+// Suppose that
+//
+// ```
+// indices = [0, 2, -1, 1]
+// depth = 3
+// on_value = 0.0
+// off_value = 3.0
+// axis = 0
+// ```
+//
+// Then output is `[3 x 4]`:
+//
+// ```output =
+// [0.0 3.0 3.0 3.0]
+// [3.0 3.0 3.0 0.0]
+// [3.0 3.0 3.0 3.0]
+// [3.0 0.0 3.0 3.0]
+// // ^ one_hot(0)
+// // ^ one_hot(2)
+// // ^ one_hot(-1)
+// // ^ one_hot(1)
+// ```
+// Suppose that
+//
+// ```
+// indices = [[0, 2], [1, -1]]
+// depth = 3
+// on_value = 1.0
+// off_value = 0.0
+// axis = -1
+// ```
+//
+// Then output is `[2 x 2 x 3]`:
+//
+// ```output =
+// [
+// [1.0, 0.0, 0.0] // one_hot(0)
+// [0.0, 0.0, 1.0] // one_hot(2)
+// ][
+// [0.0, 1.0, 0.0] // one_hot(1)
+// [0.0, 0.0, 0.0] // one_hot(-1)
+// ]```
+//
+// Arguments:
+// indices: A tensor of indices.
+// depth: A scalar defining the depth of the one hot dimension.
+// on_value: A scalar defining the value to fill in output when `indices[j] = i`.
+// off_value: A scalar defining the value to fill in output when `indices[j] != i`.
+//
+// Returns The one-hot tensor.
+func OneHot(scope *Scope, indices tf.Output, depth tf.Output, on_value tf.Output, off_value tf.Output, optional ...OneHotAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "OneHot",
+ Input: []tf.Input{
+ indices, depth, on_value, off_value,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Real-valued fast Fourier transform.
//
// Computes the 1-dimensional discrete Fourier transform of a real-valued signal
@@ -6541,6 +7189,34 @@ func DataFormatVecPermute(scope *Scope, x tf.Output, optional ...DataFormatVecPe
return op.Output(0)
}
+// Reads the value of a variable.
+//
+// The tensor returned by this operation is immutable.
+//
+// The value returned by this operation is guaranteed to be influenced by all the
+// writes on which this operation depends directly or indirectly, and to not be
+// influenced by any of the writes which depend directly or indirectly on this
+// operation.
+//
+// Arguments:
+// resource: handle to the resource in which to store the variable.
+// dtype: the dtype of the value.
+func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"dtype": dtype}
+ opspec := tf.OpSpec{
+ Type: "ReadVariableOp",
+ Input: []tf.Input{
+ resource,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes tan of x element-wise.
func Tan(scope *Scope, x tf.Output) (y tf.Output) {
if scope.Err() != nil {
@@ -6843,60 +7519,6 @@ func Complex(scope *Scope, real tf.Output, imag tf.Output, optional ...ComplexAt
return op.Output(0)
}
-// UniqueWithCountsAttr is an optional argument to UniqueWithCounts.
-type UniqueWithCountsAttr func(optionalAttr)
-
-// UniqueWithCountsOutIdx sets the optional out_idx attribute to value.
-// If not specified, defaults to DT_INT32
-func UniqueWithCountsOutIdx(value tf.DataType) UniqueWithCountsAttr {
- return func(m optionalAttr) {
- m["out_idx"] = value
- }
-}
-
-// Finds unique elements in a 1-D tensor.
-//
-// This operation returns a tensor `y` containing all of the unique elements of `x`
-// sorted in the same order that they occur in `x`. This operation also returns a
-// tensor `idx` the same size as `x` that contains the index of each value of `x`
-// in the unique output `y`. Finally, it returns a third tensor `count` that
-// contains the count of each element of `y` in `x`. In other words:
-//
-// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]`
-//
-// For example:
-//
-// ```
-// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8]
-// y, idx, count = unique_with_counts(x)
-// y ==> [1, 2, 4, 7, 8]
-// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4]
-// count ==> [2, 1, 3, 1, 2]
-// ```
-//
-// Arguments:
-// x: 1-D.
-//
-// Returns 1-D.1-D.1-D.
-func UniqueWithCounts(scope *Scope, x tf.Output, optional ...UniqueWithCountsAttr) (y tf.Output, idx tf.Output, count tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "UniqueWithCounts",
- Input: []tf.Input{
- x,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// StatelessRandomNormalAttr is an optional argument to StatelessRandomNormal.
type StatelessRandomNormalAttr func(optionalAttr)
@@ -7475,85 +8097,6 @@ func ComputeAccidentalHits(scope *Scope, true_classes tf.Output, sampled_candida
return op.Output(0), op.Output(1), op.Output(2)
}
-// CumsumAttr is an optional argument to Cumsum.
-type CumsumAttr func(optionalAttr)
-
-// CumsumExclusive sets the optional exclusive attribute to value.
-//
-// value: If `True`, perform exclusive cumsum.
-// If not specified, defaults to false
-func CumsumExclusive(value bool) CumsumAttr {
- return func(m optionalAttr) {
- m["exclusive"] = value
- }
-}
-
-// CumsumReverse sets the optional reverse attribute to value.
-//
-// value: A `bool` (default: False).
-// If not specified, defaults to false
-func CumsumReverse(value bool) CumsumAttr {
- return func(m optionalAttr) {
- m["reverse"] = value
- }
-}
-
-// Compute the cumulative sum of the tensor `x` along `axis`.
-//
-// By default, this op performs an inclusive cumsum, which means that the first
-// element of the input is identical to the first element of the output:
-//
-// ```python
-// tf.cumsum([a, b, c]) # => [a, a + b, a + b + c]
-// ```
-//
-// By setting the `exclusive` kwarg to `True`, an exclusive cumsum is
-// performed instead:
-//
-// ```python
-// tf.cumsum([a, b, c], exclusive=True) # => [0, a, a + b]
-// ```
-//
-// By setting the `reverse` kwarg to `True`, the cumsum is performed in the
-// opposite direction:
-//
-// ```python
-// tf.cumsum([a, b, c], reverse=True) # => [a + b + c, b + c, c]
-// ```
-//
-// This is more efficient than using separate `tf.reverse` ops.
-//
-// The `reverse` and `exclusive` kwargs can also be combined:
-//
-// ```python
-// tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0]
-// ```
-//
-// Arguments:
-// x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
-// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,
-// `complex128`, `qint8`, `quint8`, `qint32`, `half`.
-// axis: A `Tensor` of type `int32` (default: 0). Must be in the range
-// `[-rank(x), rank(x))`.
-func Cumsum(scope *Scope, x tf.Output, axis tf.Output, optional ...CumsumAttr) (out tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Cumsum",
- Input: []tf.Input{
- x, axis,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// QuantizedRelu6Attr is an optional argument to QuantizedRelu6.
type QuantizedRelu6Attr func(optionalAttr)
@@ -8142,85 +8685,6 @@ func ResourceApplyPowerSign(scope *Scope, var_ tf.Output, m tf.Output, lr tf.Out
return scope.AddOperation(opspec)
}
-// CumprodAttr is an optional argument to Cumprod.
-type CumprodAttr func(optionalAttr)
-
-// CumprodExclusive sets the optional exclusive attribute to value.
-//
-// value: If `True`, perform exclusive cumprod.
-// If not specified, defaults to false
-func CumprodExclusive(value bool) CumprodAttr {
- return func(m optionalAttr) {
- m["exclusive"] = value
- }
-}
-
-// CumprodReverse sets the optional reverse attribute to value.
-//
-// value: A `bool` (default: False).
-// If not specified, defaults to false
-func CumprodReverse(value bool) CumprodAttr {
- return func(m optionalAttr) {
- m["reverse"] = value
- }
-}
-
-// Compute the cumulative product of the tensor `x` along `axis`.
-//
-// By default, this op performs an inclusive cumprod, which means that the first
-// element of the input is identical to the first element of the output:
-//
-// ```python
-// tf.cumprod([a, b, c]) # => [a, a * b, a * b * c]
-// ```
-//
-// By setting the `exclusive` kwarg to `True`, an exclusive cumprod is
-// performed instead:
-//
-// ```python
-// tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b]
-// ```
-//
-// By setting the `reverse` kwarg to `True`, the cumprod is performed in the
-// opposite direction:
-//
-// ```python
-// tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c]
-// ```
-//
-// This is more efficient than using separate `tf.reverse` ops.
-//
-// The `reverse` and `exclusive` kwargs can also be combined:
-//
-// ```python
-// tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1]
-// ```
-//
-// Arguments:
-// x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
-// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,
-// `complex128`, `qint8`, `quint8`, `qint32`, `half`.
-// axis: A `Tensor` of type `int32` (default: 0). Must be in the range
-// `[-rank(x), rank(x))`.
-func Cumprod(scope *Scope, x tf.Output, axis tf.Output, optional ...CumprodAttr) (out tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Cumprod",
- Input: []tf.Input{
- x, axis,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Computes the mean along segments of a tensor.
//
// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
@@ -9607,24 +10071,6 @@ func StringToHashBucketFast(scope *Scope, input tf.Output, num_buckets int64) (o
return op.Output(0)
}
-// Returns the max of x and y (i.e. x > y ? x : y) element-wise.
-//
-// *NOTE*: `Maximum` supports broadcasting. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func Maximum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Maximum",
- Input: []tf.Input{
- x, y,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// TensorArrayGatherV3Attr is an optional argument to TensorArrayGatherV3.
type TensorArrayGatherV3Attr func(optionalAttr)
@@ -9914,194 +10360,6 @@ func ResourceSparseApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.O
return scope.AddOperation(opspec)
}
-// Returns element-wise largest integer not greater than x.
-func Floor(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Floor",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes the Gauss error function of `x` element-wise.
-func Erf(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Erf",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// OneHotAttr is an optional argument to OneHot.
-type OneHotAttr func(optionalAttr)
-
-// OneHotAxis sets the optional axis attribute to value.
-//
-// value: The axis to fill (default: -1, a new inner-most axis).
-// If not specified, defaults to -1
-func OneHotAxis(value int64) OneHotAttr {
- return func(m optionalAttr) {
- m["axis"] = value
- }
-}
-
-// Returns a one-hot tensor.
-//
-// The locations represented by indices in `indices` take value `on_value`,
-// while all other locations take value `off_value`.
-//
-// If the input `indices` is rank `N`, the output will have rank `N+1`.
-// The new axis is created at dimension `axis` (default: the new axis is
-// appended at the end).
-//
-// If `indices` is a scalar the output shape will be a vector of length `depth`.
-//
-// If `indices` is a vector of length `features`, the output shape will be:
-// ```
-// features x depth if axis == -1
-// depth x features if axis == 0
-// ```
-//
-// If `indices` is a matrix (batch) with shape `[batch, features]`,
-// the output shape will be:
-// ```
-// batch x features x depth if axis == -1
-// batch x depth x features if axis == 1
-// depth x batch x features if axis == 0
-// ```
-//
-//
-// Examples
-// =========
-//
-// Suppose that
-//
-// ```
-// indices = [0, 2, -1, 1]
-// depth = 3
-// on_value = 5.0
-// off_value = 0.0
-// axis = -1
-// ```
-//
-// Then output is `[4 x 3]`:
-//
-// ```output =
-// [5.0 0.0 0.0] // one_hot(0)
-// [0.0 0.0 5.0] // one_hot(2)
-// [0.0 0.0 0.0] // one_hot(-1)
-// [0.0 5.0 0.0] // one_hot(1)
-// ```
-//
-// Suppose that
-//
-// ```
-// indices = [0, 2, -1, 1]
-// depth = 3
-// on_value = 0.0
-// off_value = 3.0
-// axis = 0
-// ```
-//
-// Then output is `[3 x 4]`:
-//
-// ```output =
-// [0.0 3.0 3.0 3.0]
-// [3.0 3.0 3.0 0.0]
-// [3.0 3.0 3.0 3.0]
-// [3.0 0.0 3.0 3.0]
-// // ^ one_hot(0)
-// // ^ one_hot(2)
-// // ^ one_hot(-1)
-// // ^ one_hot(1)
-// ```
-// Suppose that
-//
-// ```
-// indices = [[0, 2], [1, -1]]
-// depth = 3
-// on_value = 1.0
-// off_value = 0.0
-// axis = -1
-// ```
-//
-// Then output is `[2 x 2 x 3]`:
-//
-// ```output =
-// [
-// [1.0, 0.0, 0.0] // one_hot(0)
-// [0.0, 0.0, 1.0] // one_hot(2)
-// ][
-// [0.0, 1.0, 0.0] // one_hot(1)
-// [0.0, 0.0, 0.0] // one_hot(-1)
-// ]```
-//
-// Arguments:
-// indices: A tensor of indices.
-// depth: A scalar defining the depth of the one hot dimension.
-// on_value: A scalar defining the value to fill in output when `indices[j] = i`.
-// off_value: A scalar defining the value to fill in output when `indices[j] != i`.
-//
-// Returns The one-hot tensor.
-func OneHot(scope *Scope, indices tf.Output, depth tf.Output, on_value tf.Output, off_value tf.Output, optional ...OneHotAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "OneHot",
- Input: []tf.Input{
- indices, depth, on_value, off_value,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
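The `OneHot` wrapper above is only being relocated within this generated file by the diff; the op itself is unchanged. As a hedged illustration of how these generated wrappers are driven from Go (not part of the diff), the sketch below reproduces the first example in the comment: indices `[0, 2, -1, 1]`, depth 3, `on_value` 5.0. It assumes the standard Go bindings (the core `tf` package plus this `op` package) and the usual `Finalize`/`NewSession`/`Run` flow; everything outside the `OneHot` call itself is illustrative.

```go
package main

import (
	"fmt"
	"log"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	indices := op.Const(s.SubScope("indices"), []int32{0, 2, -1, 1})
	depth := op.Const(s.SubScope("depth"), int32(3))
	on := op.Const(s.SubScope("on"), float32(5.0))
	off := op.Const(s.SubScope("off"), float32(0.0))
	// axis defaults to -1 (a new inner-most axis); pass op.OneHotAxis(0) to change it.
	oneHot := op.OneHot(s, indices, depth, on, off)

	graph, err := s.Finalize()
	if err != nil {
		log.Fatal(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		log.Fatal(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{oneHot}, nil)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out[0].Value()) // [[5 0 0] [0 0 5] [0 0 0] [0 5 0]]
}
```

The later sketches in this section reuse the same two imports and the same Finalize/Run pattern, so they are shown as small graph-building helpers only.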
-// Reads the value of a variable.
-//
-// The tensor returned by this operation is immutable.
-//
-// The value returned by this operation is guaranteed to be influenced by all the
-// writes on which this operation depends directly or indirectly, and to not be
-// influenced by any of the writes which depend directly or indirectly on this
-// operation.
-//
-// Arguments:
-// resource: handle to the resource in which to store the variable.
-// dtype: the dtype of the value.
-func ReadVariableOp(scope *Scope, resource tf.Output, dtype tf.DataType) (value tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtype": dtype}
- opspec := tf.OpSpec{
- Type: "ReadVariableOp",
- Input: []tf.Input{
- resource,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// MaxPool3DGradAttr is an optional argument to MaxPool3DGrad.
type MaxPool3DGradAttr func(optionalAttr)
@@ -11406,6 +11664,97 @@ func Sub(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
return op.Output(0)
}
+// LogUniformCandidateSamplerAttr is an optional argument to LogUniformCandidateSampler.
+type LogUniformCandidateSamplerAttr func(optionalAttr)
+
+// LogUniformCandidateSamplerSeed sets the optional seed attribute to value.
+//
+// value: If either seed or seed2 are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func LogUniformCandidateSamplerSeed(value int64) LogUniformCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// LogUniformCandidateSamplerSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func LogUniformCandidateSamplerSeed2(value int64) LogUniformCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Generates labels for candidate sampling with a log-uniform distribution.
+//
+// See explanations of candidate sampling and the data formats at
+// go/candidate-sampling.
+//
+// For each batch, this op picks a single set of sampled candidate labels.
+//
+// The advantages of sampling candidates per-batch are simplicity and the
+// possibility of efficient dense matrix multiplication. The disadvantage is that
+// the sampled candidates must be chosen independently of the context and of the
+// true labels.
+//
+// Arguments:
+// true_classes: A batch_size * num_true matrix, in which each row contains the
+// IDs of the num_true target_classes in the corresponding original label.
+// num_true: Number of true labels per context.
+// num_sampled: Number of candidates to randomly sample.
+// unique: If unique is true, we sample with rejection, so that all sampled
+// candidates in a batch are unique. This requires some approximation to
+// estimate the post-rejection sampling probabilities.
+// range_max: The sampler will sample integers from the interval [0, range_max).
+//
+// Returns:
+// sampled_candidates: A vector of length num_sampled, in which each element is
+// the ID of a sampled candidate.
+// true_expected_count: A batch_size * num_true matrix, representing the number
+// of times each candidate is expected to occur in a batch of sampled
+// candidates. If unique=true, then this is a probability.
+// sampled_expected_count: A vector of length num_sampled, for each sampled
+// candidate representing the number of times the candidate is expected to
+// occur in a batch of sampled candidates. If unique=true, then this is a
+// probability.
+func LogUniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LogUniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "LogUniformCandidateSampler",
+ Input: []tf.Input{
+ true_classes,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
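A hedged sketch of calling the newly added sampler wrapper (not part of the diff; same binding assumptions as the `OneHot` sketch earlier). The required attrs are plain Go arguments, the optional seed comes in through the `...LogUniformCandidateSamplerAttr` varargs, and the concrete values here are illustrative only.

```go
package sample

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// buildSampler adds a LogUniformCandidateSampler node: one true label per row,
// four sampled candidates drawn (with rejection) from [0, 100).
func buildSampler(s *op.Scope) (sampled, trueExpected, sampledExpected tf.Output) {
	trueClasses := op.Const(s.SubScope("true_classes"), [][]int64{{3}, {12}})
	return op.LogUniformCandidateSampler(s, trueClasses,
		1,    // num_true
		4,    // num_sampled
		true, // unique: sample with rejection
		100,  // range_max
		op.LogUniformCandidateSamplerSeed(42), // optional, for determinism
	)
}
```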
+// Returns the max of x and y (i.e. x > y ? x : y) element-wise.
+//
+// *NOTE*: `Maximum` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func Maximum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Maximum",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
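To make the broadcasting note above concrete, a hedged sketch (same assumptions as earlier): a scalar constant broadcasts against whatever shape the other operand has, so only one side needs an explicit shape.

```go
package sample

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// clampAtThree computes max(x, 3) element-wise; x is assumed to be float32 so
// the dtypes match, and the scalar 3 broadcasts to x's shape.
func clampAtThree(s *op.Scope, x tf.Output) tf.Output {
	three := op.Const(s.SubScope("three"), float32(3))
	return op.Maximum(s, x, three)
}
```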
// Computes softmax cross entropy cost and gradients to backpropagate.
//
// Inputs are the logits, not probabilities.
@@ -12768,69 +13117,6 @@ func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) {
return op.Output(0)
}
-// MinAttr is an optional argument to Min.
-type MinAttr func(optionalAttr)
-
-// MinKeepDims sets the optional keep_dims attribute to value.
-//
-// value: If true, retain reduced dimensions with length 1.
-// If not specified, defaults to false
-func MinKeepDims(value bool) MinAttr {
- return func(m optionalAttr) {
- m["keep_dims"] = value
- }
-}
-
-// Computes the minimum of elements across dimensions of a tensor.
-//
-// Reduces `input` along the dimensions given in `axis`. Unless
-// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
-// `axis`. If `keep_dims` is true, the reduced dimensions are
-// retained with length 1.
-//
-// Arguments:
-// input: The tensor to reduce.
-// axis: The dimensions to reduce. Must be in the range
-// `[-rank(input), rank(input))`.
-//
-// Returns The reduced tensor.
-func Min(scope *Scope, input tf.Output, axis tf.Output, optional ...MinAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Min",
- Input: []tf.Input{
- input, axis,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Shuffle dimensions of x according to a permutation.
-//
-// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
-// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]`
-func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Transpose",
- Input: []tf.Input{
- x, perm,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
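The `Min` and `Transpose` wrappers above are likewise only moving within the generated file. For reference, a hedged sketch of the reduction call (same assumptions as earlier): `axis` is an int32 tensor listing the dimensions to reduce, and `MinKeepDims` keeps the reduced axes with length 1.

```go
package sample

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// rowMins reduces a [rows, cols] float32 matrix along axis 1, keeping the
// reduced dimension so the result has shape [rows, 1].
func rowMins(s *op.Scope, m tf.Output) tf.Output {
	axis := op.Const(s.SubScope("axis"), []int32{1})
	return op.Min(s, m, axis, op.MinKeepDims(true))
}
```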
// Computes sigmoid of `x` element-wise.
//
// Specifically, `y = 1 / (1 + exp(-x))`.
@@ -16533,30 +16819,6 @@ func SparseMatMul(scope *Scope, a tf.Output, b tf.Output, optional ...SparseMatM
return op.Output(0)
}
-// Computes the power of one value to another.
-//
-// Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for
-// corresponding elements in `x` and `y`. For example:
-//
-// ```
-// # tensor 'x' is [[2, 2], [3, 3]]
-// # tensor 'y' is [[8, 16], [2, 3]]
-// tf.pow(x, y) ==> [[256, 65536], [9, 27]]
-// ```
-func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Pow",
- Input: []tf.Input{
- x, y,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// ShapeAttr is an optional argument to Shape.
type ShapeAttr func(optionalAttr)
@@ -16597,6 +16859,30 @@ func Shape(scope *Scope, input tf.Output, optional ...ShapeAttr) (output tf.Outp
return op.Output(0)
}
+// Computes the power of one value to another.
+//
+// Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for
+// corresponding elements in `x` and `y`. For example:
+//
+// ```
+// # tensor 'x' is [[2, 2], [3, 3]]
+// # tensor 'y' is [[8, 16], [2, 3]]
+// tf.pow(x, y) ==> [[256, 65536], [9, 27]]
+// ```
+func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Pow",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
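A hedged sketch of the example from the comment above, expressed through the Go wrapper (same assumptions as earlier):

```go
package sample

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// powExample reproduces the documented values: 2^8, 2^16, 3^2, 3^3.
func powExample(s *op.Scope) tf.Output {
	x := op.Const(s.SubScope("x"), [][]float32{{2, 2}, {3, 3}})
	y := op.Const(s.SubScope("y"), [][]float32{{8, 16}, {2, 3}})
	return op.Pow(s, x, y) // evaluates to [[256 65536] [9 27]]
}
```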
// Computes fingerprints of the input strings.
//
// Arguments:
@@ -16871,6 +17157,70 @@ func DepthwiseConv2dNativeBackpropInput(scope *Scope, input_sizes tf.Output, fil
return op.Output(0)
}
+// Stops gradient computation.
+//
+// When executed in a graph, this op outputs its input tensor as-is.
+//
+// When building ops to compute gradients, this op prevents the contribution of
+// its inputs to be taken into account. Normally, the gradient generator adds ops
+// to a graph to compute the derivatives of a specified 'loss' by recursively
+// finding out inputs that contributed to its computation. If you insert this op
+// in the graph it inputs are masked from the gradient generator. They are not
+// taken into account for computing gradients.
+//
+// This is useful any time you want to compute a value with TensorFlow but need
+// to pretend that the value was a constant. Some examples include:
+//
+// * The *EM* algorithm where the *M-step* should not involve backpropagation
+// through the output of the *E-step*.
+// * Contrastive divergence training of Boltzmann machines where, when
+// differentiating the energy function, the training must not backpropagate
+// through the graph that generated the samples from the model.
+// * Adversarial training, where no backprop should happen through the adversarial
+// example generation process.
+func StopGradient(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "StopGradient",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
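To make the "pretend the value is a constant" idea concrete, a hedged sketch (same assumptions as earlier; `Square` is another wrapper assumed to be generated elsewhere in this file): the forward value is unchanged, but any gradient built through the result treats `frozen` as a constant.

```go
package sample

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// freezeAndSquare blocks gradient flow through x before squaring it.
func freezeAndSquare(s *op.Scope, x tf.Output) tf.Output {
	frozen := op.StopGradient(s, x)
	return op.Square(s, frozen)
}
```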
+// Eagerly executes a python function to compute func(input)->output.
+//
+// The semantics of the input, output, and attributes are the same as those for
+// PyFunc.
+func EagerPyFunc(scope *Scope, input []tf.Output, token string, Tout []tf.DataType) (output []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"token": token, "Tout": Tout}
+ opspec := tf.OpSpec{
+ Type: "EagerPyFunc",
+ Input: []tf.Input{
+ tf.OutputList(input),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if output, idx, err = makeOutputList(op, idx, "output"); err != nil {
+ scope.UpdateErr("EagerPyFunc", err)
+ return
+ }
+ return output
+}
+
// Adds sparse updates to the variable referenced by `resource`.
//
// This operation computes
@@ -16951,6 +17301,47 @@ func InTopK(scope *Scope, predictions tf.Output, targets tf.Output, k int64) (pr
return op.Output(0)
}
+// Returns (x - y)(x - y) element-wise.
+//
+// *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func SquaredDifference(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SquaredDifference",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Forwards the input to the output.
+//
+// This operator represents the loop termination condition used by the
+// "pivot" switches of a loop.
+//
+// Arguments:
+// input: A boolean scalar, representing the branch predicate of the Switch op.
+//
+// Returns The same tensor as `input`.
+func LoopCond(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "LoopCond",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the gradient for the inverse of `x` wrt its input.
//
// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy`
@@ -17124,203 +17515,6 @@ func RandomGamma(scope *Scope, shape tf.Output, alpha tf.Output, optional ...Ran
return op.Output(0)
}
-// QuantizeAndDequantizeAttr is an optional argument to QuantizeAndDequantize.
-type QuantizeAndDequantizeAttr func(optionalAttr)
-
-// QuantizeAndDequantizeSignedInput sets the optional signed_input attribute to value.
-// If not specified, defaults to true
-func QuantizeAndDequantizeSignedInput(value bool) QuantizeAndDequantizeAttr {
- return func(m optionalAttr) {
- m["signed_input"] = value
- }
-}
-
-// QuantizeAndDequantizeNumBits sets the optional num_bits attribute to value.
-// If not specified, defaults to 8
-func QuantizeAndDequantizeNumBits(value int64) QuantizeAndDequantizeAttr {
- return func(m optionalAttr) {
- m["num_bits"] = value
- }
-}
-
-// QuantizeAndDequantizeRangeGiven sets the optional range_given attribute to value.
-// If not specified, defaults to false
-func QuantizeAndDequantizeRangeGiven(value bool) QuantizeAndDequantizeAttr {
- return func(m optionalAttr) {
- m["range_given"] = value
- }
-}
-
-// QuantizeAndDequantizeInputMin sets the optional input_min attribute to value.
-// If not specified, defaults to 0
-func QuantizeAndDequantizeInputMin(value float32) QuantizeAndDequantizeAttr {
- return func(m optionalAttr) {
- m["input_min"] = value
- }
-}
-
-// QuantizeAndDequantizeInputMax sets the optional input_max attribute to value.
-// If not specified, defaults to 0
-func QuantizeAndDequantizeInputMax(value float32) QuantizeAndDequantizeAttr {
- return func(m optionalAttr) {
- m["input_max"] = value
- }
-}
-
-// Use QuantizeAndDequantizeV2 instead.
-//
-// DEPRECATED at GraphDef version 22: Replaced by QuantizeAndDequantizeV2
-func QuantizeAndDequantize(scope *Scope, input tf.Output, optional ...QuantizeAndDequantizeAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "QuantizeAndDequantize",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Returns locations of nonzero / true values in a tensor.
-//
-// This operation returns the coordinates of true elements in `condition`. The
-// coordinates are returned in a 2-D tensor where the first dimension (rows)
-// represents the number of true elements, and the second dimension (columns)
-// represents the coordinates of the true elements. Keep in mind, the shape of
-// the output tensor can vary depending on how many true values there are in
-// `condition`. Indices are output in row-major order.
-//
-// For example:
-//
-// ```
-// # 'input' tensor is [[True, False]
-// # [True, False]]
-// # 'input' has two true values, so output has two coordinates.
-// # 'input' has rank of 2, so coordinates have two indices.
-// where(input) ==> [[0, 0],
-// [1, 0]]
-//
-// # `condition` tensor is [[[True, False]
-// # [True, False]]
-// # [[False, True]
-// # [False, True]]
-// # [[False, False]
-// # [False, True]]]
-// # 'input' has 5 true values, so output has 5 coordinates.
-// # 'input' has rank of 3, so coordinates have three indices.
-// where(input) ==> [[0, 0, 0],
-// [0, 1, 0],
-// [1, 0, 1],
-// [1, 1, 1],
-// [2, 1, 1]]
-//
-// # `condition` tensor is [[[1.5, 0.0]
-// # [-0.5, 0.0]]
-// # [[0.0, 0.25]
-// # [0.0, 0.75]]
-// # [[0.0, 0.0]
-// # [0.0, 0.01]]]
-// # 'input' has 5 nonzero values, so output has 5 coordinates.
-// # 'input' has rank of 3, so coordinates have three indices.
-// where(input) ==> [[0, 0, 0],
-// [0, 1, 0],
-// [1, 0, 1],
-// [1, 1, 1],
-// [2, 1, 1]]
-//
-// # `condition` tensor is [[[1.5 + 0.0j, 0.0 + 0.0j]
-// # [0.0 + 0.5j, 0.0 + 0.0j]]
-// # [[0.0 + 0.0j, 0.25 + 1.5j]
-// # [0.0 + 0.0j, 0.75 + 0.0j]]
-// # [[0.0 + 0.0j, 0.0 + 0.0j]
-// # [0.0 + 0.0j, 0.01 + 0.0j]]]
-// # 'input' has 5 nonzero magnitude values, so output has 5 coordinates.
-// # 'input' has rank of 3, so coordinates have three indices.
-// where(input) ==> [[0, 0, 0],
-// [0, 1, 0],
-// [1, 0, 1],
-// [1, 1, 1],
-// [2, 1, 1]]
-// ```
-func Where(scope *Scope, condition tf.Output) (index tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Where",
- Input: []tf.Input{
- condition,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
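`Where` (also just relocated by this diff) returns an int64 matrix of coordinates. A hedged sketch matching the first example in the comment above, under the same assumptions:

```go
package sample

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// trueCoords returns the [num_true, 2] coordinates of the true entries in a
// 2-D boolean constant: here [[0 0] [1 0]].
func trueCoords(s *op.Scope) tf.Output {
	cond := op.Const(s.SubScope("cond"), [][]bool{{true, false}, {true, false}})
	return op.Where(s, cond)
}
```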
-// QueueDequeueV2Attr is an optional argument to QueueDequeueV2.
-type QueueDequeueV2Attr func(optionalAttr)
-
-// QueueDequeueV2TimeoutMs sets the optional timeout_ms attribute to value.
-//
-// value: If the queue is empty, this operation will block for up to
-// timeout_ms milliseconds.
-// Note: This option is not supported yet.
-// If not specified, defaults to -1
-func QueueDequeueV2TimeoutMs(value int64) QueueDequeueV2Attr {
- return func(m optionalAttr) {
- m["timeout_ms"] = value
- }
-}
-
-// Dequeues a tuple of one or more tensors from the given queue.
-//
-// This operation has k outputs, where k is the number of components
-// in the tuples stored in the given queue, and output i is the ith
-// component of the dequeued tuple.
-//
-// N.B. If the queue is empty, this operation will block until an element
-// has been dequeued (or 'timeout_ms' elapses, if specified).
-//
-// Arguments:
-// handle: The handle to a queue.
-// component_types: The type of each component in a tuple.
-//
-// Returns One or more tensors that were dequeued as a tuple.
-func QueueDequeueV2(scope *Scope, handle tf.Output, component_types []tf.DataType, optional ...QueueDequeueV2Attr) (components []tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"component_types": component_types}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "QueueDequeueV2",
- Input: []tf.Input{
- handle,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if components, idx, err = makeOutputList(op, idx, "components"); err != nil {
- scope.UpdateErr("QueueDequeueV2", err)
- return
- }
- return components
-}
-
// RandomUniformIntAttr is an optional argument to RandomUniformInt.
type RandomUniformIntAttr func(optionalAttr)
@@ -17816,6 +18010,164 @@ func MatrixBandPart(scope *Scope, input tf.Output, num_lower tf.Output, num_uppe
return op.Output(0)
}
+// CumsumAttr is an optional argument to Cumsum.
+type CumsumAttr func(optionalAttr)
+
+// CumsumExclusive sets the optional exclusive attribute to value.
+//
+// value: If `True`, perform exclusive cumsum.
+// If not specified, defaults to false
+func CumsumExclusive(value bool) CumsumAttr {
+ return func(m optionalAttr) {
+ m["exclusive"] = value
+ }
+}
+
+// CumsumReverse sets the optional reverse attribute to value.
+//
+// value: If `True`, perform the cumsum in the reverse direction.
+// If not specified, defaults to false
+func CumsumReverse(value bool) CumsumAttr {
+ return func(m optionalAttr) {
+ m["reverse"] = value
+ }
+}
+
+// Compute the cumulative sum of the tensor `x` along `axis`.
+//
+// By default, this op performs an inclusive cumsum, which means that the first
+// element of the input is identical to the first element of the output:
+//
+// ```python
+// tf.cumsum([a, b, c]) # => [a, a + b, a + b + c]
+// ```
+//
+// By setting the `exclusive` kwarg to `True`, an exclusive cumsum is
+// performed instead:
+//
+// ```python
+// tf.cumsum([a, b, c], exclusive=True) # => [0, a, a + b]
+// ```
+//
+// By setting the `reverse` kwarg to `True`, the cumsum is performed in the
+// opposite direction:
+//
+// ```python
+// tf.cumsum([a, b, c], reverse=True) # => [a + b + c, b + c, c]
+// ```
+//
+// This is more efficient than using separate `tf.reverse` ops.
+//
+// The `reverse` and `exclusive` kwargs can also be combined:
+//
+// ```python
+// tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0]
+// ```
+//
+// Arguments:
+// x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
+// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,
+// `complex128`, `qint8`, `quint8`, `qint32`, `half`.
+// axis: A `Tensor` of type `int32` (default: 0). Must be in the range
+// `[-rank(x), rank(x))`.
+func Cumsum(scope *Scope, x tf.Output, axis tf.Output, optional ...CumsumAttr) (out tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Cumsum",
+ Input: []tf.Input{
+ x, axis,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// CumprodAttr is an optional argument to Cumprod.
+type CumprodAttr func(optionalAttr)
+
+// CumprodExclusive sets the optional exclusive attribute to value.
+//
+// value: If `True`, perform exclusive cumprod.
+// If not specified, defaults to false
+func CumprodExclusive(value bool) CumprodAttr {
+ return func(m optionalAttr) {
+ m["exclusive"] = value
+ }
+}
+
+// CumprodReverse sets the optional reverse attribute to value.
+//
+// value: If `True`, perform the cumprod in the reverse direction.
+// If not specified, defaults to false
+func CumprodReverse(value bool) CumprodAttr {
+ return func(m optionalAttr) {
+ m["reverse"] = value
+ }
+}
+
+// Compute the cumulative product of the tensor `x` along `axis`.
+//
+// By default, this op performs an inclusive cumprod, which means that the first
+// element of the input is identical to the first element of the output:
+//
+// ```python
+// tf.cumprod([a, b, c]) # => [a, a * b, a * b * c]
+// ```
+//
+// By setting the `exclusive` kwarg to `True`, an exclusive cumprod is
+// performed instead:
+//
+// ```python
+// tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b]
+// ```
+//
+// By setting the `reverse` kwarg to `True`, the cumprod is performed in the
+// opposite direction:
+//
+// ```python
+// tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c]
+// ```
+//
+// This is more efficient than using separate `tf.reverse` ops.
+//
+// The `reverse` and `exclusive` kwargs can also be combined:
+//
+// ```python
+// tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1]
+// ```
+//
+// Arguments:
+// x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
+// `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,
+// `complex128`, `qint8`, `quint8`, `qint32`, `half`.
+// axis: A `Tensor` of type `int32` (default: 0). Must be in the range
+// `[-rank(x), rank(x))`.
+func Cumprod(scope *Scope, x tf.Output, axis tf.Output, optional ...CumprodAttr) (out tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Cumprod",
+ Input: []tf.Input{
+ x, axis,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
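The two wrappers re-added above mirror the documented Python behavior directly. A hedged sketch of combining the `exclusive` and `reverse` options (same assumptions as earlier):

```go
package sample

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// suffixSumsAndProducts computes, for x = [a b c] along axis 0:
//   sums  = [b+c, c, 0]       (exclusive + reverse cumsum)
//   prods = [a*b*c, b*c, c]   (reverse cumprod)
func suffixSumsAndProducts(s *op.Scope, x tf.Output) (sums, prods tf.Output) {
	axis := op.Const(s.SubScope("axis"), int32(0))
	sums = op.Cumsum(s.SubScope("cumsum"), x, axis,
		op.CumsumExclusive(true), op.CumsumReverse(true))
	prods = op.Cumprod(s.SubScope("cumprod"), x, axis, op.CumprodReverse(true))
	return sums, prods
}
```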
// QuantizedMatMulAttr is an optional argument to QuantizedMatMul.
type QuantizedMatMulAttr func(optionalAttr)
@@ -21902,80 +22254,64 @@ func NonMaxSuppressionV2(scope *Scope, boxes tf.Output, scores tf.Output, max_ou
return op.Output(0)
}
-// Reshapes a tensor.
-//
-// Given `tensor`, this operation returns a tensor that has the same values
-// as `tensor` with shape `shape`.
+// Creates a TensorArray for storing the gradients of values in the given handle.
//
-// If one component of `shape` is the special value -1, the size of that dimension
-// is computed so that the total size remains constant. In particular, a `shape`
-// of `[-1]` flattens into 1-D. At most one component of `shape` can be -1.
+// If the given TensorArray gradient already exists, returns a reference to it.
//
-// If `shape` is 1-D or higher, then the operation returns a tensor with shape
-// `shape` filled with the values of `tensor`. In this case, the number of elements
-// implied by `shape` must be the same as the number of elements in `tensor`.
+// Locks the size of the original TensorArray by disabling its dynamic size flag.
//
-// For example:
+// **A note about the input flow_in:**
//
-// ```
-// # tensor 't' is [1, 2, 3, 4, 5, 6, 7, 8, 9]
-// # tensor 't' has shape [9]
-// reshape(t, [3, 3]) ==> [[1, 2, 3],
-// [4, 5, 6],
-// [7, 8, 9]]
+// The handle flow_in forces the execution of the gradient lookup to occur
+// only after certain other operations have occurred. For example, when
+// the forward TensorArray is dynamically sized, writes to this TensorArray
+// may resize the object. The gradient TensorArray is statically sized based
+// on the size of the forward TensorArray when this operation executes.
+// Furthermore, the size of the forward TensorArray is frozen by this call.
+// As a result, the flow is used to ensure that the call to generate the gradient
+// TensorArray only happens after all writes are executed.
//
-// # tensor 't' is [[[1, 1], [2, 2]],
-// # [[3, 3], [4, 4]]]
-// # tensor 't' has shape [2, 2, 2]
-// reshape(t, [2, 4]) ==> [[1, 1, 2, 2],
-// [3, 3, 4, 4]]
+// In the case of dynamically sized TensorArrays, gradient computation should
+// only be performed on read operations that have themselves been chained via
+// flow to occur only after all writes have executed. That way the final size
+// of the forward TensorArray is known when this operation is called.
//
-// # tensor 't' is [[[1, 1, 1],
-// # [2, 2, 2]],
-// # [[3, 3, 3],
-// # [4, 4, 4]],
-// # [[5, 5, 5],
-// # [6, 6, 6]]]
-// # tensor 't' has shape [3, 2, 3]
-// # pass '[-1]' to flatten 't'
-// reshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6]
+// **A note about the source attribute:**
//
-// # -1 can also be used to infer the shape
+// TensorArray gradient calls use an accumulator TensorArray object. If
+// multiple gradients are calculated and run in the same session, the multiple
+// gradient nodes may accidentally flow through the same accumulator TensorArray.
+// This double counts and generally breaks the TensorArray gradient flow.
//
-// # -1 is inferred to be 9:
-// reshape(t, [2, -1]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3],
-// [4, 4, 4, 5, 5, 5, 6, 6, 6]]
-// # -1 is inferred to be 2:
-// reshape(t, [-1, 9]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3],
-// [4, 4, 4, 5, 5, 5, 6, 6, 6]]
-// # -1 is inferred to be 3:
-// reshape(t, [ 2, -1, 3]) ==> [[[1, 1, 1],
-// [2, 2, 2],
-// [3, 3, 3]],
-// [[4, 4, 4],
-// [5, 5, 5],
-// [6, 6, 6]]]
+// The solution is to identify which gradient call this particular
+// TensorArray gradient is being called in. This is performed by identifying
+// a unique string (e.g. "gradients", "gradients_1", ...) from the input
+// gradient Tensor's name. This string is used as a suffix when creating
+// the TensorArray gradient object here (the attribute `source`).
//
-// # tensor 't' is [7]
-// # shape `[]` reshapes to a scalar
-// reshape(t, []) ==> 7
-// ```
+// The attribute `source` is added as a suffix to the forward TensorArray's
+// name when performing the creation / lookup, so that each separate gradient
+// calculation gets its own TensorArray accumulator.
//
// Arguments:
-//
-// shape: Defines the shape of the output tensor.
-func Reshape(scope *Scope, tensor tf.Output, shape tf.Output) (output tf.Output) {
+// handle: The handle to the forward TensorArray.
+// flow_in: A float scalar that enforces proper chaining of operations.
+// source: The gradient source string, used to decide which gradient TensorArray
+// to return.
+func TensorArrayGradV3(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output, flow_out tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{"source": source}
opspec := tf.OpSpec{
- Type: "Reshape",
+ Type: "TensorArrayGradV3",
Input: []tf.Input{
- tensor, shape,
+ handle, flow_in,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0)
+ return op.Output(0), op.Output(1)
}
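A hedged sketch of the handle/flow plumbing described above (not part of the diff). It assumes a forward `TensorArrayV3` wrapper elsewhere in this generated file with a `(size, dtype)` signature returning `(handle, flow)`; if that differs, adjust accordingly. The `source` string is the per-gradient-call suffix discussed in the comment.

```go
package sample

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// gradientArray looks up (or creates) the accumulator TensorArray for one
// particular gradient computation, keyed by the "gradients" source suffix.
// Threading flow through the call enforces the write-before-gradient ordering.
func gradientArray(s *op.Scope) (gradHandle, flowOut tf.Output) {
	size := op.Const(s.SubScope("size"), int32(4))
	handle, flow := op.TensorArrayV3(s.SubScope("forward"), size, tf.Float) // assumed signature
	return op.TensorArrayGradV3(s, handle, flow, "gradients")
}
```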
// Creates a dataset that splits a SparseTensor into elements row-wise.
@@ -24260,66 +24596,6 @@ func DecodeCompressed(scope *Scope, bytes tf.Output, optional ...DecodeCompresse
return op.Output(0)
}
-// Creates a TensorArray for storing the gradients of values in the given handle.
-//
-// If the given TensorArray gradient already exists, returns a reference to it.
-//
-// Locks the size of the original TensorArray by disabling its dynamic size flag.
-//
-// **A note about the input flow_in:**
-//
-// The handle flow_in forces the execution of the gradient lookup to occur
-// only after certain other operations have occurred. For example, when
-// the forward TensorArray is dynamically sized, writes to this TensorArray
-// may resize the object. The gradient TensorArray is statically sized based
-// on the size of the forward TensorArray when this operation executes.
-// Furthermore, the size of the forward TensorArray is frozen by this call.
-// As a result, the flow is used to ensure that the call to generate the gradient
-// TensorArray only happens after all writes are executed.
-//
-// In the case of dynamically sized TensorArrays, gradient computation should
-// only be performed on read operations that have themselves been chained via
-// flow to occur only after all writes have executed. That way the final size
-// of the forward TensorArray is known when this operation is called.
-//
-// **A note about the source attribute:**
-//
-// TensorArray gradient calls use an accumulator TensorArray object. If
-// multiple gradients are calculated and run in the same session, the multiple
-// gradient nodes may accidentally flow through the same accumulator TensorArray.
-// This double counts and generally breaks the TensorArray gradient flow.
-//
-// The solution is to identify which gradient call this particular
-// TensorArray gradient is being called in. This is performed by identifying
-// a unique string (e.g. "gradients", "gradients_1", ...) from the input
-// gradient Tensor's name. This string is used as a suffix when creating
-// the TensorArray gradient object here (the attribute `source`).
-//
-// The attribute `source` is added as a suffix to the forward TensorArray's
-// name when performing the creation / lookup, so that each separate gradient
-// calculation gets its own TensorArray accumulator.
-//
-// Arguments:
-// handle: The handle to the forward TensorArray.
-// flow_in: A float scalar that enforces proper chaining of operations.
-// source: The gradient source string, used to decide which gradient TensorArray
-// to return.
-func TensorArrayGradV3(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output, flow_out tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"source": source}
- opspec := tf.OpSpec{
- Type: "TensorArrayGradV3",
- Input: []tf.Input{
- handle, flow_in,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
-
// Compare values of `input` to `threshold` and pack resulting bits into a `uint8`.
//
// Each comparison returns a boolean `true` (if `input_value > threshold`)
@@ -26991,58 +27267,6 @@ func DecodeWav(scope *Scope, contents tf.Output, optional ...DecodeWavAttr) (aud
return op.Output(0), op.Output(1)
}
-// UniqueAttr is an optional argument to Unique.
-type UniqueAttr func(optionalAttr)
-
-// UniqueOutIdx sets the optional out_idx attribute to value.
-// If not specified, defaults to DT_INT32
-func UniqueOutIdx(value tf.DataType) UniqueAttr {
- return func(m optionalAttr) {
- m["out_idx"] = value
- }
-}
-
-// Finds unique elements in a 1-D tensor.
-//
-// This operation returns a tensor `y` containing all of the unique elements of `x`
-// sorted in the same order that they occur in `x`. This operation also returns a
-// tensor `idx` the same size as `x` that contains the index of each value of `x`
-// in the unique output `y`. In other words:
-//
-// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]`
-//
-// For example:
-//
-// ```
-// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8]
-// y, idx = unique(x)
-// y ==> [1, 2, 4, 7, 8]
-// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4]
-// ```
-//
-// Arguments:
-// x: 1-D.
-//
-// Returns 1-D unique values `y` and 1-D indices `idx`.
-func Unique(scope *Scope, x tf.Output, optional ...UniqueAttr) (y tf.Output, idx tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Unique",
- Input: []tf.Input{
- x,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
-
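`Unique` (relocated above) returns two outputs that the Go wrapper surfaces as a pair. A hedged sketch matching the documented example, under the same assumptions; `UniqueOutIdx` could be passed to request int64 indices instead.

```go
package sample

import (
	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

// uniqueExample yields y = [1 2 4 7 8] and idx = [0 0 1 2 2 2 3 4 4] for the
// documented input.
func uniqueExample(s *op.Scope) (y, idx tf.Output) {
	x := op.Const(s.SubScope("x"), []int32{1, 1, 2, 4, 4, 4, 7, 8, 8})
	return op.Unique(s, x)
}
```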
// Concatenates a list of `N` tensors along the first dimension.
//
// The input tensors are all required to have size 1 in the first dimension.
@@ -27663,227 +27887,3 @@ func GatherNd(scope *Scope, params tf.Output, indices tf.Output) (output tf.Outp
op := scope.AddOperation(opspec)
return op.Output(0)
}
-
-// Eagerly executes a python function to compute func(input)->output.
-//
-// The semantics of the input, output, and attributes are the same as those for
-// PyFunc.
-func EagerPyFunc(scope *Scope, input []tf.Output, token string, Tout []tf.DataType) (output []tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"token": token, "Tout": Tout}
- opspec := tf.OpSpec{
- Type: "EagerPyFunc",
- Input: []tf.Input{
- tf.OutputList(input),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if output, idx, err = makeOutputList(op, idx, "output"); err != nil {
- scope.UpdateErr("EagerPyFunc", err)
- return
- }
- return output
-}
-
-// Stops gradient computation.
-//
-// When executed in a graph, this op outputs its input tensor as-is.
-//
-// When building ops to compute gradients, this op prevents the contribution of
-// its inputs to be taken into account. Normally, the gradient generator adds ops
-// to a graph to compute the derivatives of a specified 'loss' by recursively
-// finding out inputs that contributed to its computation. If you insert this op
-// in the graph it inputs are masked from the gradient generator. They are not
-// taken into account for computing gradients.
-//
-// This is useful any time you want to compute a value with TensorFlow but need
-// to pretend that the value was a constant. Some examples include:
-//
-// * The *EM* algorithm where the *M-step* should not involve backpropagation
-// through the output of the *E-step*.
-// * Contrastive divergence training of Boltzmann machines where, when
-// differentiating the energy function, the training must not backpropagate
-// through the graph that generated the samples from the model.
-// * Adversarial training, where no backprop should happen through the adversarial
-// example generation process.
-func StopGradient(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "StopGradient",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes asin of x element-wise.
-func Asin(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Asin",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// PreventGradientAttr is an optional argument to PreventGradient.
-type PreventGradientAttr func(optionalAttr)
-
-// PreventGradientMessage sets the optional message attribute to value.
-//
-// value: Will be printed in the error when anyone tries to differentiate
-// this operation.
-// If not specified, defaults to ""
-func PreventGradientMessage(value string) PreventGradientAttr {
- return func(m optionalAttr) {
- m["message"] = value
- }
-}
-
-// An identity op that triggers an error if a gradient is requested.
-//
-// When executed in a graph, this op outputs its input tensor as-is.
-//
-// When building ops to compute gradients, the TensorFlow gradient system
-// will return an error when trying to lookup the gradient of this op,
-// because no gradient must ever be registered for this function. This
-// op exists to prevent subtle bugs from silently returning unimplemented
-// gradients in some corner cases.
-//
-// Arguments:
-// input: any tensor.
-//
-// Returns the same input tensor.
-func PreventGradient(scope *Scope, input tf.Output, optional ...PreventGradientAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "PreventGradient",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Checks a tensor for NaN and Inf values.
-//
-// When run, reports an `InvalidArgument` error if `tensor` has any values
-// that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
-//
-// Arguments:
-//
-// message: Prefix of the error message.
-func CheckNumerics(scope *Scope, tensor tf.Output, message string) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"message": message}
- opspec := tf.OpSpec{
- Type: "CheckNumerics",
- Input: []tf.Input{
- tensor,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Shuffle dimensions of x according to a permutation and conjugate the result.
-//
-// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
-// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]`
-// `y[i,j,k,...,s,t,u] == conj(x[perm[i], perm[j], perm[k],...,perm[s], perm[t], perm[u]])`
-func ConjugateTranspose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ConjugateTranspose",
- Input: []tf.Input{
- x, perm,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// UniqueV2Attr is an optional argument to UniqueV2.
-type UniqueV2Attr func(optionalAttr)
-
-// UniqueV2OutIdx sets the optional out_idx attribute to value.
-// If not specified, defaults to DT_INT32
-func UniqueV2OutIdx(value tf.DataType) UniqueV2Attr {
- return func(m optionalAttr) {
- m["out_idx"] = value
- }
-}
-
-// Finds unique elements in a 1-D tensor.
-//
-// This operation returns a tensor `y` containing all of the unique elements of `x`
-// sorted in the same order that they occur in `x`. This operation also returns a
-// tensor `idx` the same size as `x` that contains the index of each value of `x`
-// in the unique output `y`. In other words:
-//
-// `y[idx[i]] = x[i] for i in [0, 1,...,rank(x) - 1]`
-//
-// For example:
-//
-// ```
-// # tensor 'x' is [1, 1, 2, 4, 4, 4, 7, 8, 8]
-// y, idx = unique(x)
-// y ==> [1, 2, 4, 7, 8]
-// idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4]
-// ```
-//
-// Arguments:
-// x: A `Tensor`.
-// axis: A `Tensor` of type `int64` (default: 0). The axis of the Tensor to
-// find the unique elements.
-//
-// Returns A `Tensor`. Unique elements along the `axis` of `Tensor` x. A 1-D Tensor
-// that has the same type as x and contains the index of each value of x in the
-// output y.
-func UniqueV2(scope *Scope, x tf.Output, axis tf.Output, optional ...UniqueV2Attr) (y tf.Output, idx tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "UniqueV2",
- Input: []tf.Input{
- x, axis,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 753be82425..87b5c596ac 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1617,7 +1617,10 @@ py_library(
py_library(
name = "array_ops",
- srcs = ["ops/array_ops.py"],
+ srcs = [
+ "ops/array_ops.py",
+ "ops/inplace_ops.py",
+ ],
srcs_version = "PY2AND3",
deps = [
":array_ops_gen",
@@ -3368,6 +3371,7 @@ tf_py_wrap_cc(
"//tensorflow/c:python_api",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/eager:c_api",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_rpc_factory_registration",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
"//tensorflow/core/grappler:grappler_item",
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 8729e085a3..c28de3d054 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -121,7 +121,7 @@ class Dataset(object):
An `Iterator` over the elements of this dataset.
Raises:
- RuntimeError: If eager execution is enabled.
+ RuntimeError: If eager execution is not enabled.
"""
if context.executing_eagerly():
return iterator_ops.EagerIterator(self)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 61859d6be3..5168ad3b18 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -223,6 +223,16 @@ class HelperContext(object):
else:
return val
+ def EnterGradientColocation(self, op, gradient_uid):
+ """Start building a gradient colocated with an op."""
+ if self._outer_context:
+ self._outer_context.EnterGradientColocation(op, gradient_uid)
+
+ def ExitGradientColocation(self, op, gradient_uid):
+ """Start building a gradient colocated with an op."""
+ if self._outer_context:
+ self._outer_context.ExitGradientColocation(op, gradient_uid)
+
def __enter__(self):
# pylint: disable=protected-access
self._g = ops.get_default_graph()
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index bb033d3495..189b81aeea 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -57,8 +57,8 @@ _PREDICT_SERVING_KEY = 'predict'
# A LossSpec contains
# * a scalar `Tensor` representing reduced weighted training loss
-# * a scalar `Tensor` representing the unreduced unweighted loss
-# * a scalar `Tensor` representing the example weights
+# * a `Tensor` representing the unreduced unweighted loss
+# * a `Tensor` representing the example weights
# * possibly processed labels (e.g. vocabulary lookup, shape manipulation, etc)
LossSpec = collections.namedtuple(
'LossSpec', ['training_loss', 'unreduced_loss', 'weights',
@@ -163,8 +163,8 @@ class _Head(object):
Returns:
A LossSpec that contains
* the scalar `Tensor` representing reduced weighted training loss
- * the scalar `Tensor` representing the unreduced unweighted loss
- * the scalar `Tensor` representing the example weights
+ * the `Tensor` representing the unreduced unweighted loss
+ * the `Tensor` representing the example weights
* possibly processed labels (e.g. vocabulary lookup, shape manipulation,
etc.)
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index 0edae92fd4..a31c424263 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -345,7 +345,7 @@ tf_export("uint16").export_constant(__name__, "uint16")
uint32 = DType(types_pb2.DT_UINT32)
tf_export("uint32").export_constant(__name__, "uint32")
uint64 = DType(types_pb2.DT_UINT64)
-tf_export("uint64").export_constant(__name__, "uint32")
+tf_export("uint64").export_constant(__name__, "uint64")
int16 = DType(types_pb2.DT_INT16)
tf_export("int16").export_constant(__name__, "int16")
int8 = DType(types_pb2.DT_INT8)
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 2574fa57a4..662cda2a7d 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import collections
import copy
+import functools
import linecache
import os
import re
@@ -4180,6 +4181,19 @@ class Graph(object):
return self._name_stack
@tf_contextlib.contextmanager
+ def _colocate_with_for_gradient(self, op, gradient_uid,
+ ignore_existing=False):
+ with self.colocate_with(op, ignore_existing):
+ if gradient_uid is not None and self._control_flow_context is not None:
+ try:
+ self._control_flow_context.EnterGradientColocation(op, gradient_uid)
+ yield
+ finally:
+ self._control_flow_context.ExitGradientColocation(op, gradient_uid)
+ else:
+ yield
+
+ @tf_contextlib.contextmanager
def colocate_with(self, op, ignore_existing=False):
"""Returns a context manager that specifies an op to colocate with.
@@ -4958,8 +4972,7 @@ def container(container_name):
return get_default_graph().container(container_name)
-@tf_export("colocate_with")
-def colocate_with(op, ignore_existing=False):
+def _colocate_with_for_gradient(op, gradient_uid, ignore_existing=False):
if context.executing_eagerly():
if op is not None:
return device(op.device)
@@ -4973,7 +4986,13 @@ def colocate_with(op, ignore_existing=False):
else:
raise ValueError("Encountered an Eager-defined Tensor during graph "
"construction, but a function was not being built.")
- return default_graph.colocate_with(op, ignore_existing)
+ return default_graph._colocate_with_for_gradient(
+ op, gradient_uid=gradient_uid, ignore_existing=ignore_existing)
+
+
+@tf_export("colocate_with")
+def colocate_with(op, ignore_existing=False):
+ return _colocate_with_for_gradient(op, None, ignore_existing=ignore_existing)
@tf_export("control_dependencies")
@@ -5226,14 +5245,35 @@ class _DefaultGraphStack(_DefaultStack): # pylint: disable=protected-access
@tf_contextlib.contextmanager
def get_controller(self, default):
try:
- context.context().context_switches.push(default.building_function,
- default.as_default)
+ if context.executing_eagerly():
+ # A Graph alone on the context stack would keep init_scope-wrapped
+ # operations graph building when entered (assuming init_scope is called
+ # in a graph building context). Instead, we push a context which first
+ # enables eager execution and then re-enters the Graph.
+ context.context().context_switches.push(
+ default.building_function,
+ functools.partial(
+ _enter_context_and_graph,
+ context.eager_mode,
+ default.as_default))
+ else:
+ # This Graph is being used from a graph building context. A lack of
+ # context switch implies that the context is graph building.
+ context.context().context_switches.push(default.building_function,
+ default.as_default)
with super(_DefaultGraphStack, self).get_controller(default) as g:
yield g
finally:
context.context().context_switches.pop()
+@tf_contextlib.contextmanager
+def _enter_context_and_graph(context_fn, graph_fn):
+ """Combines two context managers."""
+ with context_fn(), graph_fn():
+ yield
+
+
_default_graph_stack = _DefaultGraphStack()
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 58bead91ed..c9c1a3d66b 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -2305,6 +2305,13 @@ class InitScopeTest(test_util.TensorFlowTestCase):
self.assertEqual(ops.get_name_scope(), "inner")
self.assertEqual(ops.get_name_scope(), "")
+ def testEagerGraphContextsExecuteEagerly(self):
+ with context.eager_mode():
+ with ops.Graph().as_default():
+ with context.graph_mode():
+ with ops.init_scope():
+ self.assertTrue(context.executing_eagerly())
+
def testPreservesNameScopeInEagerExecution(self):
with context.eager_mode():
def foo():
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 64b0fa6c00..8cf24206ed 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -822,17 +822,32 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name
all-or-nothing.
Args:
- tensor: The rank-1 Tensor to be evaluated.
+ tensor: The rank-0 or rank-1 Tensor to be evaluated.
Returns:
A `TensorShape` based on the constant value of the given `tensor`.
+
+ Raises:
+ ValueError: If the shape is rank-0 and is not statically known to be -1.
"""
if isinstance(tensor, ops.EagerTensor):
return tensor_shape.as_shape(
[dim if dim != -1 else None for dim in tensor.numpy()])
+ if tensor.get_shape().ndims == 0:
+ value = constant_value(tensor)
+ if value is None:
+ raise ValueError(
+ "Received a scalar with unknown value as shape; require a statically "
+ "known scalar with value '-1' to describe an unknown shape.")
+ if value != -1:
+ raise ValueError(
+ "Received a scalar value '%s' as shape; require a statically known "
+ "scalar with value '-1' to describe an unknown shape." % value)
+ return tensor_shape.unknown_shape()
+
shape = tensor.get_shape().with_rank(1)
- if tensor.get_shape() == [0]:
+ if shape == [0]:
return tensor_shape.scalar()
elif tensor.op.type == "Shape":
return tensor.op.inputs[0].get_shape()
diff --git a/tensorflow/python/framework/versions.py b/tensorflow/python/framework/versions.py
index d08b4bf48a..472ccbcac7 100644
--- a/tensorflow/python/framework/versions.py
+++ b/tensorflow/python/framework/versions.py
@@ -31,13 +31,17 @@ __monolithic_build__ = pywrap_tensorflow.__monolithic_build__
VERSION = __version__
tf_export("VERSION", "__version__").export_constant(__name__, "VERSION")
GIT_VERSION = __git_version__
-tf_export("GIT_VERSION").export_constant(__name__, "GIT_VERSION")
+tf_export("GIT_VERSION", "__git_version__").export_constant(
+ __name__, "GIT_VERSION")
COMPILER_VERSION = __compiler_version__
-tf_export("COMPILER_VERSION").export_constant(__name__, "COMPILER_VERSION")
+tf_export("COMPILER_VERSION", "__compiler_version__").export_constant(
+ __name__, "COMPILER_VERSION")
CXX11_ABI_FLAG = __cxx11_abi_flag__
-tf_export("CXX11_ABI_FLAG").export_constant(__name__, "CXX11_ABI_FLAG")
+tf_export("CXX11_ABI_FLAG", "__cxx11_abi_flag__").export_constant(
+ __name__, "CXX11_ABI_FLAG")
MONOLITHIC_BUILD = __monolithic_build__
-tf_export("MONOLITHIC_BUILD").export_constant(__name__, "MONOLITHIC_BUILD")
+tf_export("MONOLITHIC_BUILD", "__monolithic_build__").export_constant(
+ __name__, "MONOLITHIC_BUILD")
GRAPH_DEF_VERSION = pywrap_tensorflow.GRAPH_DEF_VERSION
tf_export("GRAPH_DEF_VERSION").export_constant(__name__, "GRAPH_DEF_VERSION")
diff --git a/tensorflow/python/grappler/cluster.i b/tensorflow/python/grappler/cluster.i
index 067c8213d4..6816e20407 100644
--- a/tensorflow/python/grappler/cluster.i
+++ b/tensorflow/python/grappler/cluster.i
@@ -320,7 +320,8 @@ static PyObject* TF_MeasureCosts(
tensorflow::OpPerformanceList op_performance_data;
tensorflow::StepStats step_stats;
- tensorflow::grappler::MeasuringCostEstimator cost_measure(cluster.get(), 10, 0);
+ const int num_measurements = cluster->type() == "virtual" ? 1 : 10;
+ tensorflow::grappler::MeasuringCostEstimator cost_measure(cluster.get(), num_measurements, 0);
tensorflow::grappler::Costs costs;
tensorflow::Status status = _GetOpPerformanceDataAndRunTime(
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 57f5097639..f6e1d0eec3 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -609,6 +609,7 @@ py_test(
srcs = ["_impl/keras/utils/data_utils_test.py"],
srcs_version = "PY2AND3",
tags = [
+ "no_oss",
"no_windows",
"noasan", # times out
"notsan",
diff --git a/tensorflow/python/keras/_impl/keras/applications/resnet50.py b/tensorflow/python/keras/_impl/keras/applications/resnet50.py
index f8c6aff4f2..c3a92bea89 100644
--- a/tensorflow/python/keras/_impl/keras/applications/resnet50.py
+++ b/tensorflow/python/keras/_impl/keras/applications/resnet50.py
@@ -237,9 +237,8 @@ def ResNet50(include_top=True,
else:
bn_axis = 1
- x = ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input)
x = Conv2D(
- 64, (7, 7), strides=(2, 2), padding='valid', name='conv1')(x)
+ 64, (7, 7), strides=(2, 2), padding='same', name='conv1')(img_input)
x = BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
x = Activation('relu')(x)
x = MaxPooling2D((3, 3), strides=(2, 2))(x)
diff --git a/tensorflow/python/keras/_impl/keras/estimator.py b/tensorflow/python/keras/_impl/keras/estimator.py
index 5d370ebbb5..8043242b70 100644
--- a/tensorflow/python/keras/_impl/keras/estimator.py
+++ b/tensorflow/python/keras/_impl/keras/estimator.py
@@ -26,6 +26,7 @@ from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import export as export_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config as run_config_lib
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
@@ -465,11 +466,21 @@ def model_to_estimator(keras_model=None,
estimator = estimator_lib.Estimator(
keras_model_fn, model_dir=model_dir, config=config)
+ old_session = K._SESSION
# Pass the config into keras backend's default session.
sess = session.Session(config=estimator._session_config)
K.set_session(sess)
+ try:
+ keras_weights = keras_model.get_weights()
+ except errors.FailedPreconditionError as e:
+ if old_session is None:
+ raise e
+ logging.warning(
+ 'The Keras backend session has already been '
+ 'set. The _session_config passed to model_to_estimator is not used.')
+ K.set_session(old_session)
+ keras_weights = keras_model.get_weights()
- keras_weights = keras_model.get_weights()
if keras_model._is_graph_network:
# TODO(yifeif): move checkpoint initialization to scaffold.init_fn
_save_first_checkpoint(keras_model,
diff --git a/tensorflow/python/keras/_impl/keras/estimator_test.py b/tensorflow/python/keras/_impl/keras/estimator_test.py
index e076dc25b1..27b7ec7dd4 100644
--- a/tensorflow/python/keras/_impl/keras/estimator_test.py
+++ b/tensorflow/python/keras/_impl/keras/estimator_test.py
@@ -512,6 +512,26 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
._config.gpu_options.per_process_gpu_memory_fraction,
gpu_options.per_process_gpu_memory_fraction)
+ def test_pretrained_weights(self):
+ keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer=rmsprop.RMSPropOptimizer(1e-3),
+ metrics=['mse', keras.metrics.categorical_accuracy])
+
+ keras_model.train_on_batch(
+ np.random.random((10,) + _INPUT_SIZE), np.random.random((10,
+ _NUM_CLASS)))
+ weights = keras_model.get_weights()
+ keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()
+ keras_model.set_weights(weights)
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer=rmsprop.RMSPropOptimizer(1e-3),
+ metrics=['mse', keras.metrics.categorical_accuracy])
+ keras.estimator.model_to_estimator(
+ keras_model=keras_model, config=self._config)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 6c34ea1816..3033b48977 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1191,6 +1191,22 @@ cuda_py_test(
)
cuda_py_test(
+ name = "inplace_ops_test",
+ size = "small",
+ srcs = ["inplace_ops_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ ],
+ shard_count = 10,
+)
+
+cuda_py_test(
name = "batch_matmul_op_test",
size = "small",
srcs = ["batch_matmul_op_test.py"],
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index 18796f7095..749313b00d 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -653,12 +653,12 @@ class FillTest(test.TestCase):
self._compareAll([2, 3], np_ans[0][0], np_ans)
def testFillComplex64(self):
- np_ans = np.array([[0.15] * 3] * 2).astype(np.complex64)
- self._compare([2, 3], np_ans[0][0], np_ans, use_gpu=False)
+ np_ans = np.array([[0.15 + 0.3j] * 3] * 2).astype(np.complex64)
+ self._compareAll([2, 3], np_ans[0][0], np_ans)
def testFillComplex128(self):
- np_ans = np.array([[0.15] * 3] * 2).astype(np.complex128)
- self._compare([2, 3], np_ans[0][0], np_ans, use_gpu=False)
+ np_ans = np.array([[0.15 + 0.3j] * 3] * 2).astype(np.complex128)
+ self._compareAll([2, 3], np_ans[0][0], np_ans)
def testFillString(self):
np_ans = np.array([[b"yolo"] * 3] * 2)
diff --git a/tensorflow/python/kernel_tests/inplace_ops_test.py b/tensorflow/python/kernel_tests/inplace_ops_test.py
new file mode 100644
index 0000000000..0f95e13187
--- /dev/null
+++ b/tensorflow/python/kernel_tests/inplace_ops_test.py
@@ -0,0 +1,198 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for inplace_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import inplace_ops
+from tensorflow.python.platform import test as test_lib
+
+
+class InplaceOpsTest(test_util.TensorFlowTestCase):
+
+ def testBasicUpdate(self):
+ for dtype in [dtypes.float32, dtypes.int32, dtypes.int64]:
+ with self.test_session(use_gpu=True):
+ x = array_ops.ones([7, 3], dtype)
+ y = np.ones([7, 3], dtype.as_numpy_dtype)
+ self.assertAllClose(x.eval(), y)
+ x = inplace_ops.inplace_update(x, [3], array_ops.ones([1, 3], dtype))
+ y[3, :] = 1
+ self.assertAllClose(x.eval(), y)
+ x = inplace_ops.inplace_update(x, [-1],
+ array_ops.ones([1, 3], dtype) * 2)
+ y[-1, :] = 2
+ self.assertAllClose(x.eval(), y)
+ x = inplace_ops.inplace_update(x, 5, array_ops.ones([3], dtype) * 7)
+ y[5, :] = 7
+ self.assertAllClose(x.eval(), y)
+
+ def testBasicUpdateBool(self):
+ with self.test_session(use_gpu=True):
+ x = array_ops.ones([7, 3], dtypes.bool)
+ y = np.ones([7, 3], dtypes.bool.as_numpy_dtype)
+ self.assertAllClose(x.eval(), y)
+ x = inplace_ops.inplace_update(x, [3], array_ops.ones([1, 3],
+ dtypes.bool))
+ y[3, :] = True
+ self.assertAllClose(x.eval(), y)
+ x = inplace_ops.inplace_update(x, [-1],
+ array_ops.zeros([1, 3], dtypes.bool))
+ y[-1, :] = False
+ self.assertAllClose(x.eval(), y)
+ x = inplace_ops.inplace_update(x, 5, array_ops.zeros([3], dtypes.bool))
+ y[5, :] = False
+ self.assertAllClose(x.eval(), y)
+
+ def testBasicAdd(self):
+ for dtype in [dtypes.float32, dtypes.int32, dtypes.int64]:
+ with self.test_session(use_gpu=True):
+ x = array_ops.ones([7, 3], dtype)
+ y = np.ones([7, 3], dtype.as_numpy_dtype)
+ self.assertAllClose(x.eval(), y)
+ x = inplace_ops.inplace_add(x, [3], array_ops.ones([1, 3], dtype))
+ y[3, :] += 1
+ self.assertAllClose(x.eval(), y)
+ x = inplace_ops.inplace_add(x, [-1], array_ops.ones([1, 3], dtype) * 2)
+ y[-1, :] += 2
+ self.assertAllClose(x.eval(), y)
+ x = inplace_ops.inplace_add(x, 5, array_ops.ones([3], dtype) * 7)
+ y[5, :] += 7
+ self.assertAllClose(x.eval(), y)
+ x = inplace_ops.inplace_add(x, None, array_ops.ones([7, 3], dtype) * 99)
+ y[:, :] += 99
+ self.assertAllClose(x.eval(), y)
+
+ def testBasicSub(self):
+ for dtype in [dtypes.float32, dtypes.int32, dtypes.int64]:
+ with self.test_session(use_gpu=True):
+ x = array_ops.ones([7, 3], dtype)
+ y = np.ones([7, 3], dtype.as_numpy_dtype)
+ self.assertAllClose(x.eval(), y)
+ x = inplace_ops.inplace_sub(x, [3], array_ops.ones([1, 3], dtype))
+ y[3, :] -= 1
+ self.assertAllClose(x.eval(), y)
+ x = inplace_ops.inplace_sub(x, [-1], array_ops.ones([1, 3], dtype) * 2)
+ y[-1, :] -= 2
+ self.assertAllClose(x.eval(), y)
+ x = inplace_ops.inplace_sub(x, 5, array_ops.ones([3], dtype) * 7)
+ y[5, :] -= 7
+ self.assertAllClose(x.eval(), y)
+ x = inplace_ops.inplace_sub(x, None, array_ops.ones([7, 3], dtype) * 99)
+ y[:, :] -= 99
+ self.assertAllClose(x.eval(), y)
+
+ def testRandom(self):
+ with self.test_session(use_gpu=True):
+ d0, d1, d2 = 100, 3, 5
+ x = array_ops.zeros([d0, d1, d2])
+ y = np.zeros([d0, d1, d2])
+ for _ in xrange(20):
+ idx = np.random.choice(d0, d0 // 10, replace=False)
+ val = np.random.randint(10, size=(d0 // 10, d1, d2))
+ op = np.random.randint(3)
+ if op == 0:
+ x = inplace_ops.inplace_update(x, idx, val)
+ y[idx, :] = val
+ elif op == 1:
+ x = inplace_ops.inplace_add(x, idx, val)
+ y[idx, :] += val
+ elif op == 2:
+ x = inplace_ops.inplace_sub(x, idx, val)
+ y[idx, :] -= val
+ self.assertAllClose(x.eval(), y)
+
+ def testRandom1D(self):
+ with self.test_session(use_gpu=True):
+ d0 = 100
+ x = array_ops.zeros([d0])
+ y = np.zeros([d0])
+ for _ in xrange(20):
+ idx = np.random.choice(d0, d0 // 10, replace=False)
+ val = np.random.randint(10, size=(d0 // 10))
+ op = np.random.randint(3)
+ if op == 0:
+ x = inplace_ops.inplace_update(x, idx, val)
+ y[idx] = val
+ elif op == 1:
+ x = inplace_ops.inplace_add(x, idx, val)
+ y[idx] += val
+ elif op == 2:
+ x = inplace_ops.inplace_sub(x, idx, val)
+ y[idx] -= val
+ self.assertAllClose(x.eval(), y)
+
+ def testAlias(self):
+ with self.test_session(use_gpu=True) as sess:
+ x = array_ops.ones([2, 3])
+ y = inplace_ops.alias_inplace_add(x, [0], [[1, 2, 3]])
+ with ops.control_dependencies([y]):
+ z = array_ops.identity(x)
+ _, vy, vz = sess.run([x, y, z])
+ self.assertAllClose(vy, vz)
+
+ def testError(self):
+ with self.test_session():
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "must be a vector"):
+ _ = inplace_ops.inplace_update([[1.]], [[0]], [[10]]).eval()
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "x and v shape doesn't match"):
+ _ = inplace_ops.inplace_update([[1.]], [0], [10]).eval()
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "i and x shape doesn't match"):
+ _ = inplace_ops.inplace_update([[1.]], [0, 1], [[10]]).eval()
+
+ def testEmpty(self):
+ for dtype in [
+ dtypes.float32, dtypes.float64, dtypes.int32, dtypes.int64, dtypes.bool
+ ]:
+ with self.test_session(use_gpu=True):
+ test_shapes = [(), (1,), (2, 3), (0, 2), (2, 3, 5), (2, 0, 5)]
+ for shape in test_shapes:
+ val = inplace_ops.empty(shape, dtype).eval()
+ self.assertEqual(val.shape, shape)
+ self.assertEqual(val.dtype, dtype.as_numpy_dtype)
+ val = inplace_ops.empty(shape, dtype, init=True).eval()
+ self.assertEqual(val.shape, shape)
+ self.assertEqual(val.dtype, dtype.as_numpy_dtype)
+ self.assertAllEqual(val, np.zeros(shape, dtype.as_numpy_dtype))
+ val = inplace_ops.empty_like(array_ops.zeros(shape, dtype)).eval()
+ self.assertEqual(val.shape, shape)
+ self.assertEqual(val.dtype, dtype.as_numpy_dtype)
+ val = inplace_ops.empty_like(
+ array_ops.zeros(shape, dtype), init=True).eval()
+ self.assertEqual(val.shape, shape)
+ self.assertEqual(val.dtype, dtype.as_numpy_dtype)
+ self.assertAllEqual(val, np.zeros(shape, dtype.as_numpy_dtype))
+
+ val = inplace_ops.empty((1, 2), dtypes.string, init=True).eval()
+ self.assertEqual(val.tolist(), [[b"", b""]])
+
+ val = inplace_ops.empty((1, 2), dtypes.string, init=False).eval()
+ self.assertEqual(val.tolist(), [[b"", b""]])
+
+
+if __name__ == "__main__":
+ test_lib.main()
diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py
index dbbed39c72..6173a1def3 100644
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -33,6 +33,8 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
@@ -43,71 +45,83 @@ def scalar_shape():
class ListOpsTest(test_util.TensorFlowTestCase):
+ @test_util.run_in_graph_and_eager_modes()
def testPushPop(self):
l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
element_shape=scalar_shape())
l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
- self.assertAllEqual(e, 1.0)
+ self.assertAllEqual(self.evaluate(e), 1.0)
+ @test_util.run_in_graph_and_eager_modes()
def testPushPopGPU(self):
if not context.num_gpus():
return
with context.device("gpu:0"):
self.testPushPop()
+ @test_util.run_in_graph_and_eager_modes()
def testStack(self):
l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
element_shape=scalar_shape())
l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
- self.assertAllEqual(t, [1.0, 2.0])
+ self.assertAllEqual(self.evaluate(t), [1.0, 2.0])
+ @test_util.run_in_graph_and_eager_modes()
def testStackGPU(self):
if not context.num_gpus():
return
with context.device("gpu:0"):
self.testStack()
+ @test_util.run_in_graph_and_eager_modes()
def testTensorListFromTensor(self):
t = constant_op.constant([1.0, 2.0])
l = list_ops.tensor_list_from_tensor(t, element_shape=scalar_shape())
l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
- self.assertAllEqual(e, 2.0)
+ self.assertAllEqual(self.evaluate(e), 2.0)
l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
- self.assertAllEqual(e, 1.0)
- self.assertAllEqual(list_ops.tensor_list_length(l), 0)
+ self.assertAllEqual(self.evaluate(e), 1.0)
+ self.assertAllEqual(self.evaluate(list_ops.tensor_list_length(l)), 0)
+ @test_util.run_in_graph_and_eager_modes()
def testFromTensorGPU(self):
if not context.num_gpus():
return
with context.device("gpu:0"):
self.testTensorListFromTensor()
+ @test_util.run_in_graph_and_eager_modes()
def testGetSetItem(self):
t = constant_op.constant([1.0, 2.0])
l = list_ops.tensor_list_from_tensor(t, element_shape=scalar_shape())
e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
- self.assertAllEqual(e0, 1.0)
+ self.assertAllEqual(self.evaluate(e0), 1.0)
l = list_ops.tensor_list_set_item(l, 0, 3.0)
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
- self.assertAllEqual(t, [3.0, 2.0])
+ self.assertAllEqual(self.evaluate(t), [3.0, 2.0])
+ @test_util.run_in_graph_and_eager_modes()
def testGetSetGPU(self):
if not context.num_gpus():
return
with context.device("gpu:0"):
self.testGetSetItem()
+ @test_util.run_in_graph_and_eager_modes()
def testUnknownShape(self):
- l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
- element_shape=-1)
+ l = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32, element_shape=-1)
l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
l = list_ops.tensor_list_push_back(l, constant_op.constant([1.0, 2.0]))
- _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
- self.assertAllEqual(e, [1.0, 2.0])
+ l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
+ self.assertAllEqual(self.evaluate(e), [1.0, 2.0])
+ l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
+ self.assertAllEqual(self.evaluate(e), 1.0)
+ @test_util.run_in_graph_and_eager_modes()
def testCPUGPUCopy(self):
if not context.num_gpus():
return
@@ -116,15 +130,16 @@ class ListOpsTest(test_util.TensorFlowTestCase):
with context.device("gpu:0"):
l_gpu = array_ops.identity(l)
self.assertAllEqual(
- list_ops.tensor_list_pop_back(
- l_gpu, element_dtype=dtypes.float32)[1],
- 2.0)
+ self.evaluate(
+ list_ops.tensor_list_pop_back(
+ l_gpu, element_dtype=dtypes.float32)[1]), 2.0)
l_cpu = array_ops.identity(l_gpu)
self.assertAllEqual(
- list_ops.tensor_list_pop_back(
- l_cpu, element_dtype=dtypes.float32)[1],
- 2.0)
+ self.evaluate(
+ list_ops.tensor_list_pop_back(
+ l_cpu, element_dtype=dtypes.float32)[1]), 2.0)
+ @test_util.run_in_graph_and_eager_modes()
def testGraphStack(self):
with context.graph_mode(), self.test_session():
tl = list_ops.empty_tensor_list(
@@ -132,9 +147,11 @@ class ListOpsTest(test_util.TensorFlowTestCase):
element_dtype=dtypes.int32)
tl = list_ops.tensor_list_push_back(tl, [1])
self.assertAllEqual(
- list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32).eval(),
+ self.evaluate(
+ list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32)),
[[1]])
+ @test_util.run_in_graph_and_eager_modes()
def testGraphStackInLoop(self):
with context.graph_mode(), self.test_session():
t1 = list_ops.empty_tensor_list(
@@ -149,9 +166,10 @@ class ListOpsTest(test_util.TensorFlowTestCase):
i, t1 = control_flow_ops.while_loop(lambda i, t1: math_ops.less(i, 4),
body, [i, t1])
- s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.int32).eval()
- self.assertAllEqual(s1, [0, 1, 2, 3])
+ s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.int32)
+ self.assertAllEqual(self.evaluate(s1), [0, 1, 2, 3])
+ @test_util.run_in_graph_and_eager_modes()
def testGraphStackSwitchDtype(self):
with context.graph_mode(), self.test_session():
list_ = list_ops.empty_tensor_list(
@@ -169,11 +187,11 @@ class ListOpsTest(test_util.TensorFlowTestCase):
for _ in range(2):
list_, m = body(list_, m)
- s1 = list_ops.tensor_list_stack(
- list_, element_dtype=dtypes.float32).eval()
+ s1 = list_ops.tensor_list_stack(list_, element_dtype=dtypes.float32)
np_s1 = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)
- self.assertAllEqual(s1, np_s1)
+ self.assertAllEqual(self.evaluate(s1), np_s1)
+ @test_util.run_in_graph_and_eager_modes()
def testGraphStackInLoopSwitchDtype(self):
with context.graph_mode(), self.test_session():
t1 = list_ops.empty_tensor_list(
@@ -193,10 +211,11 @@ class ListOpsTest(test_util.TensorFlowTestCase):
i, m, t1 = control_flow_ops.while_loop(
lambda i, m, t1: math_ops.less(i, 4), body, [i, m, t1])
- s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.float32).eval()
+ s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.float32)
np_s1 = np.vstack([np.arange(1, 4) * i for i in range(4)])
- self.assertAllEqual(s1, np_s1)
+ self.assertAllEqual(self.evaluate(s1), np_s1)
+ @test_util.run_in_graph_and_eager_modes()
def testSerialize(self):
# pylint: disable=g-import-not-at-top
try:
@@ -226,8 +245,9 @@ class ListOpsTest(test_util.TensorFlowTestCase):
l_ps, element_dtype=dtypes.float32)
with ops.device("/job:worker"):
worker_e = array_ops.identity(e)
- self.assertAllEqual(worker_e.eval(), [2.0])
+ self.assertAllEqual(self.evaluate(worker_e), [2.0])
+ @test_util.run_in_graph_and_eager_modes()
def testPushPopGradients(self):
with backprop.GradientTape() as tape:
l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
@@ -237,18 +257,21 @@ class ListOpsTest(test_util.TensorFlowTestCase):
l = list_ops.tensor_list_push_back(l, c)
l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
e = 2 * e
- self.assertAllEqual(tape.gradient(e, [c])[0], 2.0)
+ self.assertAllEqual(self.evaluate(tape.gradient(e, [c])[0]), 2.0)
+ @test_util.run_in_graph_and_eager_modes()
def testStackFromTensorGradients(self):
with backprop.GradientTape() as tape:
c = constant_op.constant([1.0, 2.0])
tape.watch(c)
l = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
c2 = list_ops.tensor_list_stack(
- l, element_dtype=dtypes.float32)
+ l, element_dtype=dtypes.float32, num_elements=2)
result = c2 * 2.0
- self.assertAllEqual(tape.gradient(result, [c])[0], [2.0, 2.0])
+ grad = tape.gradient(result, [c])[0]
+ self.assertAllEqual(self.evaluate(grad), [2.0, 2.0])
+ @test_util.run_in_graph_and_eager_modes()
def testGetSetGradients(self):
with backprop.GradientTape() as tape:
c = constant_op.constant([1.0, 2.0])
@@ -261,16 +284,40 @@ class ListOpsTest(test_util.TensorFlowTestCase):
ee = list_ops.tensor_list_get_item(l, 1, element_dtype=dtypes.float32)
y = e * e + ee * ee
grad_c, grad_c2 = tape.gradient(y, [c, c2])
- self.assertAllEqual(grad_c, [0.0, 4.0])
- self.assertAllEqual(grad_c2, 6.0)
+ self.assertAllEqual(self.evaluate(grad_c), [0.0, 4.0])
+ self.assertAllEqual(self.evaluate(grad_c2), 6.0)
+ @test_util.run_in_graph_and_eager_modes()
def testSetOutOfBounds(self):
c = constant_op.constant([1.0, 2.0])
l = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
with self.assertRaises(errors.InvalidArgumentError):
- list_ops.tensor_list_set_item(l, 20, 3.0)
+ self.evaluate(list_ops.tensor_list_set_item(l, 20, 3.0))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testResourceVariableScatterGather(self):
+ c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
+ l = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
+ v = vs.get_variable("var", initializer=[l] * 10, use_resource=True)
+ v_r_0_stacked = list_ops.tensor_list_stack(v[0], dtypes.float32)
+ self.evaluate(v.initializer)
+ self.assertAllEqual([1.0, 2.0], self.evaluate(v_r_0_stacked))
+ v_r_sparse_stacked = list_ops.tensor_list_stack(
+ v.sparse_read(0), dtypes.float32)
+ self.assertAllEqual([1.0, 2.0], self.evaluate(v_r_sparse_stacked))
+ l_new_0 = list_ops.tensor_list_from_tensor(
+ [3.0, 4.0], element_shape=scalar_shape())
+ l_new_1 = list_ops.tensor_list_from_tensor(
+ [5.0, 6.0], element_shape=scalar_shape())
+ updated_v = state_ops.scatter_update(v, [3, 5], [l_new_0, l_new_1])
+ updated_v_elems = array_ops.unstack(updated_v)
+ updated_v_stacked = [
+ list_ops.tensor_list_stack(el, dtypes.float32) for el in updated_v_elems
+ ]
+ expected = ([[1.0, 2.0]] * 3 + [[3.0, 4.0], [1.0, 2.0], [5.0, 6.0]] +
+ [[1.0, 2.0]] * 4)
+ self.assertAllEqual(self.evaluate(updated_v_stacked), expected)
if __name__ == "__main__":
- ops.enable_eager_execution()
test.main()
diff --git a/tensorflow/python/ops/batch_norm_benchmark.py b/tensorflow/python/ops/batch_norm_benchmark.py
index 5d68b47aea..d83b819097 100644
--- a/tensorflow/python/ops/batch_norm_benchmark.py
+++ b/tensorflow/python/ops/batch_norm_benchmark.py
@@ -25,6 +25,7 @@ import time
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradients_impl
@@ -39,7 +40,7 @@ from tensorflow.python.platform import test
def batch_norm_op(tensor, mean, variance, beta, gamma, scale):
"""Fused kernel for batch normalization."""
# _batch_norm_with_global_normalization is deprecated in v9
- ops.get_default_graph().graph_def_versions.producer = 8
+ test_util.set_producer_version(ops.get_default_graph(), 8)
# pylint: disable=protected-access
return gen_nn_ops._batch_norm_with_global_normalization(
tensor, mean, variance, beta, gamma, 0.001, scale)
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 710287012e..fb53d9ffea 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -1598,6 +1598,16 @@ class ControlFlowContext(object):
last_context = self._context_stack.pop()
graph._set_control_flow_context(last_context)
+ def EnterGradientColocation(self, op, gradient_uid):
+ """Start building a gradient colocated with an op."""
+ if self._outer_context:
+ self._outer_context.EnterGradientColocation(op, gradient_uid)
+
+ def ExitGradientColocation(self, op, gradient_uid):
+ """Start building a gradient colocated with an op."""
+ if self._outer_context:
+ self._outer_context.ExitGradientColocation(op, gradient_uid)
+
def ExitResult(self, result):
"""Make a list of tensors available in the outer context."""
if self._outer_context:
@@ -3184,12 +3194,18 @@ def while_loop(cond,
body = lambda i, lv: (i + 1, orig_body(*lv))
if context.executing_eagerly():
+ try_to_pack = len(loop_vars) == 1
+ packed = False # whether the body result was packed into a 1-item tuple
+
while cond(*loop_vars):
loop_vars = body(*loop_vars)
+ if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
+ packed = True
+ loop_vars = (loop_vars,)
if maximum_iterations is not None:
return loop_vars[1]
else:
- return loop_vars
+ return loop_vars[0] if packed else loop_vars
if shape_invariants is not None:
if maximum_iterations is not None:
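
In eager execution this makes a single-variable loop round-trip its structure: a body that returns a bare tensor now yields a bare tensor, while a body that returns a 1-tuple keeps the tuple (the new tests further below exercise both cases). A minimal sketch under TF 1.x-style eager execution:

    import tensorflow as tf
    tf.enable_eager_execution()

    i = tf.constant(0)

    # Body returns a bare tensor, so the result is a bare tensor.
    r = tf.while_loop(lambda i: i < 10, lambda i: i + 1, [i])
    print(int(r))     # 10

    # Body returns a 1-tuple, so the result stays a tuple.
    r = tf.while_loop(lambda i: i < 10, lambda i: (i + 1,), [i])
    print(int(r[0]))  # 10
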
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index f22f3059d1..289df6f301 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -947,5 +947,28 @@ class CaseTest(test_util.TensorFlowTestCase):
sess.run(output, feed_dict={x: 4})
+@test_util.with_c_api
+class WhileLoopTestCase(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testWhileLoopWithSingleVariable(self):
+ i = constant_op.constant(0)
+ c = lambda i: math_ops.less(i, 10)
+ b = lambda i: math_ops.add(i, 1)
+ r = control_flow_ops.while_loop(c, b, [i])
+
+ self.assertEqual(self.evaluate(r), 10)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testEagerWhileLoopWithSingleVariable_bodyReturnsTuple(self):
+ i = constant_op.constant(0)
+ c = lambda i: math_ops.less(i, 10)
+ b = lambda i: (math_ops.add(i, 1),)
+ r = control_flow_ops.while_loop(c, b, [i])
+
+ # Expect a tuple since that is what the body returns.
+ self.assertEqual(self.evaluate(r), (10,))
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 44473ec69c..13420b7f0e 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -208,7 +208,10 @@ def _AsList(x):
return x if isinstance(x, (list, tuple)) else [x]
-def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops):
+def _DefaultGradYs(grad_ys,
+ ys,
+ colocate_gradients_with_ops,
+ gradient_uid="__unsupported__"):
"""Fill in default values for grad_ys.
Args:
@@ -216,6 +219,9 @@ def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops):
ys: List of tensors.
colocate_gradients_with_ops: If True, try colocating gradients with
the corresponding op.
+ gradient_uid: A unique identifier within the graph indicating
+ which invocation of gradients is being executed. Used to cluster
+ ops for compilation.
Returns:
A list of gradients to use, without None.
@@ -231,7 +237,7 @@ def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops):
for i in xrange(len(grad_ys)):
grad_y = grad_ys[i]
y = ys[i]
- with _maybe_colocate_with(y.op, colocate_gradients_with_ops):
+ with _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops):
if grad_y is None:
if y.dtype.is_complex:
raise TypeError(
@@ -338,10 +344,10 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count):
@contextlib.contextmanager
-def _maybe_colocate_with(op, colocate_gradients_with_ops):
+def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops): # pylint: disable=invalid-name
"""Context to colocate with `op` if `colocate_gradients_with_ops`."""
if colocate_gradients_with_ops:
- with ops.colocate_with(op):
+ with ops._colocate_with_for_gradient(op, gradient_uid): # pylint: disable=protected-access
yield
else:
yield
@@ -506,6 +512,9 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
with ops.name_scope(
name, "gradients",
list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
+ # Get a uid for this call to gradients that can be used to help
+ # cluster ops for compilation.
+ gradient_uid = ops.get_default_graph().unique_name("uid")
ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
xs = [
x.handle if resource_variable_ops.is_resource_variable(x) else x
@@ -513,7 +522,8 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
]
xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
xs, name="x", as_ref=True)
- grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops)
+ grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
+ gradient_uid)
# The approach we take here is as follows: Create a list of all ops in the
# subgraph between the ys and xs. Visit these ops in reverse order of ids
@@ -570,10 +580,11 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
while queue:
# generate gradient subgraph for op.
op = queue.popleft()
- with _maybe_colocate_with(op, colocate_gradients_with_ops):
+ with _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):
if loop_state:
loop_state.EnterGradWhileContext(op, before=True)
- out_grads = _AggregatedGrads(grads, op, loop_state, aggregation_method)
+ out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state,
+ aggregation_method)
if loop_state:
loop_state.ExitGradWhileContext(op, before=True)
@@ -633,7 +644,10 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
if gate_gradients and len([x for x in in_grads
if x is not None]) > 1:
with ops.device(None):
- with ops.colocate_with(None, ignore_existing=True):
+ with ops._colocate_with_for_gradient( # pylint: disable=protected-access
+ None,
+ gradient_uid,
+ ignore_existing=True):
in_grads = control_flow_ops.tuple(in_grads)
_LogOpGradients(op, out_grads, in_grads)
else:
@@ -789,7 +803,7 @@ def _LogOpGradients(op, out_grads, in_grads):
", ".join([x.name for x in in_grads if _FilterGrad(x)]))
-def _MultiDeviceAddN(tensor_list):
+def _MultiDeviceAddN(tensor_list, gradient_uid):
"""Adds tensors from potentially multiple devices."""
# Basic function structure comes from control_flow_ops.group().
# Sort tensors according to their devices.
@@ -808,7 +822,10 @@ def _MultiDeviceAddN(tensor_list):
for dev in sorted(six.iterkeys(tensors_on_device), key=DeviceKey):
tensors = tensors_on_device[dev]
- with ops.colocate_with(tensors[0].op, ignore_existing=True):
+ with ops._colocate_with_for_gradient( # pylint: disable=protected-access
+ tensors[0].op,
+ gradient_uid,
+ ignore_existing=True):
summands.append(math_ops.add_n(tensors))
return math_ops.add_n(summands)
@@ -834,12 +851,19 @@ class AggregationMethod(object):
EXPERIMENTAL_ACCUMULATE_N = 2
-def _AggregatedGrads(grads, op, loop_state, aggregation_method=None):
+def _AggregatedGrads(grads,
+ op,
+ gradient_uid,
+ loop_state,
+ aggregation_method=None):
"""Get the aggregated gradients for op.
Args:
grads: The map of memoized gradients.
op: The op to get gradients for.
+ gradient_uid: A unique identifier within the graph indicating
+ which invocation of gradients is being executed. Used to cluster
+ ops for compilation.
loop_state: An object for maintaining the state of the while loops in the
graph. It is of type ControlFlowState. None if the graph
contains no while loops.
@@ -916,7 +940,7 @@ def _AggregatedGrads(grads, op, loop_state, aggregation_method=None):
out_grads[i] = running_sum
else:
used = "add_n"
- out_grads[i] = _MultiDeviceAddN(out_grad)
+ out_grads[i] = _MultiDeviceAddN(out_grad, gradient_uid)
logging.vlog(2, " _AggregatedGrads %d x %s using %s", len(out_grad),
tensor_shape, used)
else:
diff --git a/tensorflow/python/ops/inplace_ops.py b/tensorflow/python/ops/inplace_ops.py
new file mode 100644
index 0000000000..e5b000086b
--- /dev/null
+++ b/tensorflow/python/ops/inplace_ops.py
@@ -0,0 +1,227 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Inplace operations.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import math_ops
+
+
+def _inplace_helper(x, i, v, op):
+ """Applies an inplace op on (x, i, v).
+
+ op is one of gen_array_ops.inplace_update,
+ gen_array_ops.inplace_add, or gen_array_ops.inplace_sub (all of which alias x).
+
+ If i is None, x and v must be the same shape. Computes
+ x op v;
+ If i is a scalar, x has a rank 1 higher than v's. Computes
+ x[i, :] op v;
+ Otherwise, x and v must have the same rank. Computes
+ x[i, :] op v;
+
+ Args:
+ x: A Tensor.
+ i: None, a scalar or a vector.
+ v: A Tensor.
+ op: alias_inplace_update, alias_inplace_add, or alias_inplace_sub.
+
+ Returns:
+ Returns x.
+
+ """
+ x = ops.convert_to_tensor(x)
+ v = ops.convert_to_tensor(v, x.dtype)
+ if i is None:
+ # Full tensor.
+ return array_ops.reshape(
+ op(array_ops.reshape(x, [1, -1]), [0], array_ops.reshape(v, [1, -1])),
+ array_ops.shape(x))
+ i = math_ops.to_int32(i)
+ if i.get_shape().ndims == 0:
+ # Single 0-dim update.
+ return op(x, array_ops.reshape(i, [1]), array_ops.expand_dims(v, 0))
+ return op(x, i, v)
+
+
+def alias_inplace_update(x, i, v):
+ """Applies an inplace update on input x at index i with value v. Aliases x.
+
+ If i is None, x and v must be the same shape. Computes
+ x = v;
+ If i is a scalar, x has a rank 1 higher than v's. Computes
+ x[i, :] = v;
+ Otherwise, x and v must have the same rank. Computes
+ x[i, :] = v;
+
+ Args:
+ x: A Tensor.
+ i: None, a scalar or a vector.
+ v: A Tensor.
+
+ Returns:
+ Returns x.
+
+ """
+ return _inplace_helper(x, i, v, gen_array_ops.inplace_update)
+
+
+def alias_inplace_add(x, i, v):
+ """Applies an inplace add on input x at index i with value v. Aliases x.
+
+ If i is None, x and v must be the same shape. Computes
+ x += v;
+ If i is a scalar, x has a rank 1 higher than v's. Computes
+ x[i, :] += v;
+ Otherwise, x and v must have the same rank. Computes
+ x[i, :] += v;
+
+ Args:
+ x: A Tensor.
+ i: None, a scalar or a vector.
+ v: A Tensor.
+
+ Returns:
+ Returns x.
+
+ """
+ return _inplace_helper(x, i, v, gen_array_ops.inplace_add)
+
+
+def alias_inplace_sub(x, i, v):
+ """Applies an inplace sub on input x at index i with value v. Aliases x.
+
+ If i is None, x and v must be the same shape. Computes
+ x -= v;
+ If i is a scalar, x has a rank 1 higher than v's. Computes
+ x[i, :] -= v;
+ Otherwise, x and v must have the same rank. Computes
+ x[i, :] -= v;
+
+ Args:
+ x: A Tensor.
+ i: None, a scalar or a vector.
+ v: A Tensor.
+
+ Returns:
+ Returns x.
+
+ """
+ return _inplace_helper(x, i, v, gen_array_ops.inplace_sub)
+
+
+def empty_like(x, init=None):
+ """Returns a non-initialized tensor with the same shape and dtype as x.
+
+ Args:
+ x: A Tensor.
+ init: Initialize the returned tensor with the default value of
+ x.dtype(), if True. Otherwise, do not initialize. Defaults to
+ None.
+
+ Returns:
+ A tensor y, whose dtype and shape are the same as those of x.
+ y is guaranteed not to be an alias of x. Upon return, y may contain
+ arbitrary data.
+
+ """
+ x = ops.convert_to_tensor(x)
+ return gen_array_ops.empty(array_ops.shape(x), x.dtype, init=init)
+
+
+def inplace_update(x, i, v):
+ """Applies an inplace update on input x at index i with value v.
+
+ Note that this function is not actually inplace - it allocates
+ a copy of x. The utility is not avoiding memory copies but rather
+ specifying a sparse update.
+
+ If i is None, x and v must be the same shape. Computes
+ y = x; y = v;
+ If i is a scalar, x has a rank 1 higher than v's. Computes
+ y = x; y[i, :] = v;
+ Otherwise, x and v must have the same rank. Computes
+ y = x; y[i, :] = v;
+
+ Args:
+ x: A Tensor.
+ i: None, a scalar or a vector.
+ v: A Tensor.
+
+ Returns:
+ Returns y, which is guaranteed not to be an alias of x.
+
+ """
+ return alias_inplace_update(gen_array_ops.deep_copy(x), i, v)
+
+
+def inplace_add(x, i, v):
+ """Applies an inplace add on input x at index i with value v.
+
+ Note that this function is not actually inplace - it allocates
+ a copy of x. The utility is not avoiding memory copies but rather
+ specifying a sparse update.
+
+ If i is None, x and v must be the same shape. Computes
+ y = x; y += v;
+ If i is a scalar, x has a rank 1 higher than v's. Computes
+ y = x; y[i, :] += v;
+ Otherwise, x and v must have the same rank. Computes
+ y = x; y[i, :] += v;
+
+ Args:
+ x: A Tensor.
+ i: None, a scalar or a vector.
+ v: A Tensor.
+
+ Returns:
+ Returns y, which is guaranteed not to be an alias of x.
+
+ """
+ return alias_inplace_add(gen_array_ops.deep_copy(x), i, v)
+
+
+def inplace_sub(x, i, v):
+ """Applies an inplace sub on input x at index i with value v.
+
+ Note that this function is not actually inplace - it allocates
+ a copy of x. The utility is not avoiding memory copies but rather
+ specifying a sparse update.
+
+ If i is None, x and v must be the same shape. Computes
+ y = x; y -= v;
+ If i is a scalar, x has a rank 1 higher than v's. Computes
+ y = x; y[i, :] -= v;
+ Otherwise, x and v must have the same rank. Computes
+ y = x; y[i, :] -= v;
+
+ Args:
+ x: A Tensor.
+ i: None, a scalar or a vector.
+ v: A Tensor.
+
+ Returns:
+ Returns y, which is guaranteed not to be an alias of x.
+
+ """
+ return alias_inplace_sub(gen_array_ops.deep_copy(x), i, v)
+
+empty = gen_array_ops.empty
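
A minimal usage sketch of the new helpers (an internal module, not part of the public API; import paths follow the test file added above):

    from tensorflow.python.client import session
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import inplace_ops

    x = array_ops.zeros([4, 3], dtypes.float32)
    x = inplace_ops.inplace_update(x, 1, array_ops.ones([3]))            # row 1 := 1
    x = inplace_ops.inplace_add(x, [0, 2], 5. * array_ops.ones([2, 3]))  # rows 0, 2 += 5

    with session.Session() as sess:
      print(sess.run(x))
    # [[5. 5. 5.]
    #  [1. 1. 1.]
    #  [5. 5. 5.]
    #  [0. 0. 0.]]
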
diff --git a/tensorflow/python/ops/list_ops.py b/tensorflow/python/ops/list_ops.py
index bba59ebcef..bdf0774bbf 100644
--- a/tensorflow/python/ops/list_ops.py
+++ b/tensorflow/python/ops/list_ops.py
@@ -54,8 +54,8 @@ def _TensorListStackGrad(unused_op, dtensor):
@ops.RegisterGradient("TensorListFromTensor")
def _TensorListFromTensorGrad(op, dlist):
"""Gradient for TensorListFromTensor."""
- if op.inputs[0].shape[0] is not None:
- num_elements = op.inputs[0].shape[0]
+ if op.inputs[0].shape[0].value is not None:
+ num_elements = op.inputs[0].shape[0].value
else:
num_elements = None
if dlist is None:
@@ -63,9 +63,10 @@ def _TensorListFromTensorGrad(op, dlist):
element_dtype=op.inputs[0].dtype,
element_shape=gen_list_ops.tensor_list_element_shape(
op.outputs[0], shape_type=dtypes.int32))
- return gen_list_ops.tensor_list_stack(
- dlist, element_dtype=op.inputs[0].dtype,
- num_elements=num_elements)
+ tensor_grad = gen_list_ops.tensor_list_stack(
+ dlist, element_dtype=op.inputs[0].dtype, num_elements=num_elements)
+ shape_grad = None
+ return tensor_grad, shape_grad
@ops.RegisterGradient("TensorListGetItem")
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index fe380c44da..cbc2dcf419 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -1206,7 +1206,16 @@ class DeviceWrapper(RNNCell):
@tf_export("nn.rnn_cell.MultiRNNCell")
class MultiRNNCell(RNNCell):
- """RNN cell composed sequentially of multiple simple cells."""
+ """RNN cell composed sequentially of multiple simple cells.
+
+ Example:
+
+ ```python
+ num_units = [128, 64]
+ cells = [BasicLSTMCell(num_units=n) for n in num_units]
+ stacked_rnn_cell = MultiRNNCell(cells)
+ ```
+ """
def __init__(self, cells, state_is_tuple=True):
"""Create a RNN cell composed sequentially of a number of RNNCells.
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 3fd9275289..1dc7f991b3 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -297,7 +297,6 @@ CUDNN_DNN_ROUTINE_EACH_R7(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
namespace {
-// Forward declaration.
cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
cudnnHandle_t ToHandle(void* opaque_handle) {
@@ -478,6 +477,13 @@ port::Status CudnnSupport::Init() {
ToString(status))};
}
+port::StatusOr<std::tuple<int, int, int>> CudnnSupport::GetVersion() {
+ CudnnVersion version;
+ TF_RETURN_IF_ERROR(GetLoadedCudnnVersion(&version));
+ return std::make_tuple(version.major_version, version.minor_version,
+ version.patch_level);
+}
+
// Turns a BatchDescriptor structure into a cudnn tensor handle within a scope.
class ScopedTensorDescriptor {
public:
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index e40ba9b012..0e5368aca8 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -46,6 +46,7 @@ class CudnnSupport : public dnn::DnnSupport {
~CudnnSupport() override;
port::Status Init() override;
+ port::StatusOr<std::tuple<int, int, int>> GetVersion() override;
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
int num_layers, int hidden_size, int input_size,
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 43cfd313c1..3c47d2c2e8 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -25,6 +25,7 @@ limitations under the License.
#include <functional>
#include <limits>
#include <memory>
+#include <tuple>
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
@@ -885,6 +886,12 @@ class DnnSupport {
virtual port::Status Init() = 0;
+ // Gets the version of the backing library, as a {major, minor, patch} tuple.
+ virtual port::StatusOr<std::tuple<int, int, int>> GetVersion() {
+ return port::UnimplementedError(
+ "DnnSupport::GetVersion not implemented on this platform.");
+ }
+
// Performs a single-precision forward batch normalization operation onto
// the stream.
//
diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD
index 9f1bdd8aae..a1c569951e 100644
--- a/tensorflow/tools/api/generator/BUILD
+++ b/tensorflow/tools/api/generator/BUILD
@@ -32,6 +32,7 @@ genrule(
# api/module1/module2/__init__.py and api/module3/__init__.py.
# keep sorted
outs = [
+ # BEGIN GENERATED FILES
"api/__init__.py",
"api/app/__init__.py",
"api/bitwise/__init__.py",
@@ -117,6 +118,7 @@ genrule(
"api/train/__init__.py",
"api/train/queue_runner/__init__.py",
"api/user_ops/__init__.py",
+ # END GENERATED FILES
],
cmd = "$(location create_python_api) $(OUTS)",
tools = ["create_python_api"],
diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/tools/api/generator/create_python_api.py
index 70f9776b08..c7748f5b7a 100644
--- a/tensorflow/tools/api/generator/create_python_api.py
+++ b/tensorflow/tools/api/generator/create_python_api.py
@@ -67,18 +67,23 @@ def format_import(source_module_name, source_name, dest_name):
return 'import %s as %s' % (source_name, dest_name)
-class _ModuleImportsBuilder(object):
+class _ModuleInitCodeBuilder(object):
"""Builds a map from module name to imports included in that module."""
def __init__(self):
- self.module_imports = collections.defaultdict(list)
- self._seen_api_names = set()
+ self.module_imports = collections.defaultdict(
+ lambda: collections.defaultdict(set))
+ self._dest_import_to_id = collections.defaultdict(int)
+ # Names that start with underscore in the root module.
+ self._underscore_names_in_root = []
def add_import(
- self, dest_module_name, source_module_name, source_name, dest_name):
+ self, symbol_id, dest_module_name, source_module_name, source_name,
+ dest_name):
"""Adds this import to module_imports.
Args:
+ symbol_id: (number) Unique identifier of the symbol to import.
dest_module_name: (string) Module name to add import to.
source_module_name: (string) Module to import from.
source_name: (string) Name of the symbol to import.
@@ -89,34 +94,67 @@ class _ModuleImportsBuilder(object):
dest_name has already been added to dest_module_name.
"""
import_str = format_import(source_module_name, source_name, dest_name)
- if import_str in self.module_imports[dest_module_name]:
- return
# Check if we are trying to expose two different symbols with same name.
full_api_name = dest_name
if dest_module_name:
full_api_name = dest_module_name + '.' + full_api_name
- if full_api_name in self._seen_api_names:
+ if (full_api_name in self._dest_import_to_id and
+ symbol_id != self._dest_import_to_id[full_api_name] and
+ symbol_id != -1):
raise SymbolExposedTwiceError(
'Trying to export multiple symbols with same name: %s.' %
full_api_name)
- self._seen_api_names.add(full_api_name)
+ self._dest_import_to_id[full_api_name] = symbol_id
- self.module_imports[dest_module_name].append(import_str)
+ if not dest_module_name and dest_name.startswith('_'):
+ self._underscore_names_in_root.append(dest_name)
+ # The same symbol can be available in multiple modules.
+ # We store all possible ways of importing this symbol and later pick just
+ # one.
+ self.module_imports[dest_module_name][full_api_name].add(import_str)
-def get_api_imports():
- """Get a map from destination module to formatted imports.
+ def build(self):
+ """Get a map from destination module to __init__.py code for that module.
+
+ Returns:
+ A dictionary where
+ key: (string) destination module (for e.g. tf or tf.consts).
+ value: (string) text that should be in __init__.py files for
+ corresponding modules.
+ """
+ module_text_map = {}
+ for dest_module, dest_name_to_imports in self.module_imports.items():
+ # Sort all possible imports for a symbol and pick the first one.
+ imports_list = [
+ sorted(imports)[0]
+ for _, imports in dest_name_to_imports.items()]
+ module_text_map[dest_module] = '\n'.join(sorted(imports_list))
+
+ # Expose exported symbols with underscores in root module
+ # since we import from it using * import.
+ underscore_names_str = ', '.join(
+ '\'%s\'' % name for name in self._underscore_names_in_root)
+ module_text_map[''] += '''
+_names_with_underscore = [%s]
+__all__ = [s for s in dir() if not s.startswith('_')]
+__all__.extend([s for s in _names_with_underscore])
+''' % underscore_names_str
+
+ return module_text_map
+
+
+def get_api_init_text():
+ """Get a map from destination module to __init__.py code for that module.
Returns:
A dictionary where
key: (string) destination module (for e.g. tf or tf.consts).
- value: List of strings representing module imports
- (for e.g. 'from foo import bar') and constant
- assignments (for e.g. 'FOO = 123').
+ value: (string) text that should be in __init__.py files for
+ corresponding modules.
"""
- module_imports_builder = _ModuleImportsBuilder()
- visited_symbols = set()
+ module_code_builder = _ModuleInitCodeBuilder()
# Traverse over everything imported above. Specifically,
# we want to traverse over TensorFlow Python modules.
@@ -131,8 +169,6 @@ def get_api_imports():
for module_contents_name in dir(module):
attr = getattr(module, module_contents_name)
- if id(attr) in visited_symbols:
- continue
# If attr is _tf_api_constants attribute, then add the constants.
if module_contents_name == _API_CONSTANTS_ATTR:
@@ -140,30 +176,25 @@ def get_api_imports():
for export in exports:
names = export.split('.')
dest_module = '.'.join(names[:-1])
- module_imports_builder.add_import(
- dest_module, module.__name__, value, names[-1])
+ module_code_builder.add_import(
+ -1, dest_module, module.__name__, value, names[-1])
continue
_, attr = tf_decorator.unwrap(attr)
# If attr is a symbol with _tf_api_names attribute, then
# add import for it.
if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__:
- # If the same symbol is available using multiple names, only create
- # imports for it once.
- if id(attr) in visited_symbols:
- continue
- visited_symbols.add(id(attr))
-
for export in attr._tf_api_names: # pylint: disable=protected-access
names = export.split('.')
dest_module = '.'.join(names[:-1])
- module_imports_builder.add_import(
- dest_module, module.__name__, module_contents_name, names[-1])
+ module_code_builder.add_import(
+ id(attr), dest_module, module.__name__, module_contents_name,
+ names[-1])
# Import all required modules in their parent modules.
# For e.g. if we import 'foo.bar.Value'. Then, we also
# import 'bar' in 'foo'.
- imported_modules = set(module_imports_builder.module_imports.keys())
+ imported_modules = set(module_code_builder.module_imports.keys())
for module in imported_modules:
if not module:
continue
@@ -176,11 +207,11 @@ def get_api_imports():
parent_module += ('.' + module_split[submodule_index-1] if parent_module
else module_split[submodule_index-1])
import_from += '.' + parent_module
- module_imports_builder.add_import(
- parent_module, import_from, module_split[submodule_index],
- module_split[submodule_index])
+ module_code_builder.add_import(
+ -1, parent_module, import_from,
+ module_split[submodule_index], module_split[submodule_index])
- return module_imports_builder.module_imports
+ return module_code_builder.build()
def create_api_files(output_files):
@@ -196,16 +227,19 @@ def create_api_files(output_files):
"""
module_name_to_file_path = {}
for output_file in output_files:
+ # Convert path separators to '/' for easier parsing below.
+ normalized_output_file = output_file.replace(os.sep, '/')
if _API_DIR not in output_file:
raise ValueError(
'Output files must be in api/ directory, found %s.' % output_file)
# Get the module name that corresponds to output_file.
# First get module directory under _API_DIR.
module_dir = os.path.dirname(
- output_file[output_file.rfind(_API_DIR)+len(_API_DIR):])
+ normalized_output_file[
+ normalized_output_file.rfind(_API_DIR)+len(_API_DIR):])
# Convert / to .
module_name = module_dir.replace('/', '.').strip('.')
- module_name_to_file_path[module_name] = output_file
+ module_name_to_file_path[module_name] = os.path.normpath(output_file)
# Create file for each expected output in genrule.
for module, file_path in module_name_to_file_path.items():
@@ -213,11 +247,11 @@ def create_api_files(output_files):
os.makedirs(os.path.dirname(file_path))
open(file_path, 'a').close()
- module_imports = get_api_imports()
+ module_text_map = get_api_init_text()
# Add imports to output files.
missing_output_files = []
- for module, exports in module_imports.items():
+ for module, text in module_text_map.items():
# Make sure genrule output file list is in sync with API exports.
if module not in module_name_to_file_path:
module_file_path = '"api/%s/__init__.py"' % (
@@ -225,7 +259,7 @@ def create_api_files(output_files):
missing_output_files.append(module_file_path)
continue
with open(module_name_to_file_path[module], 'w') as fp:
- fp.write(_GENERATED_FILE_HEADER + '\n'.join(exports))
+ fp.write(_GENERATED_FILE_HEADER + text)
if missing_output_files:
raise ValueError(
@@ -242,6 +276,16 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'outputs', metavar='O', type=str, nargs='+',
- help='Python files that we expect this script to output.')
+ help='If a single file is passed in, then we assume it contains a '
+ 'semicolon-separated list of Python files that we expect this script to '
+ 'output. If multiple files are passed in, then we assume output files '
+ 'are listed directly as arguments.')
args = parser.parse_args()
- main(args.outputs)
+ if len(args.outputs) == 1:
+ # If we only get a single argument, then it must be a file containing
+ # list of outputs.
+ with open(args.outputs[0]) as output_list_file:
+ outputs = [line.strip() for line in output_list_file.read().split(';')]
+ else:
+ outputs = args.outputs
+ main(outputs)
diff --git a/tensorflow/tools/api/generator/create_python_api_test.py b/tensorflow/tools/api/generator/create_python_api_test.py
index 2760779e6e..218c812045 100644
--- a/tensorflow/tools/api/generator/create_python_api_test.py
+++ b/tensorflow/tools/api/generator/create_python_api_test.py
@@ -56,7 +56,7 @@ class CreatePythonApiTest(test.TestCase):
del sys.modules[_MODULE_NAME]
def testFunctionImportIsAdded(self):
- imports = create_python_api.get_api_imports()
+ imports = create_python_api.get_api_init_text()
expected_import = (
'from test.tensorflow.test_module import test_op as test_op1')
self.assertTrue(
@@ -69,14 +69,14 @@ class CreatePythonApiTest(test.TestCase):
msg='%s not in %s' % (expected_import, str(imports)))
def testClassImportIsAdded(self):
- imports = create_python_api.get_api_imports()
+ imports = create_python_api.get_api_init_text()
expected_import = 'from test.tensorflow.test_module import TestClass'
self.assertTrue(
'TestClass' in str(imports),
msg='%s not in %s' % (expected_import, str(imports)))
def testConstantIsAdded(self):
- imports = create_python_api.get_api_imports()
+ imports = create_python_api.get_api_init_text()
expected = 'from test.tensorflow.test_module import _TEST_CONSTANT'
self.assertTrue(expected in str(imports),
msg='%s not in %s' % (expected, str(imports)))
diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD
index 0dc154b6d2..724b12cd47 100644
--- a/tensorflow/tools/api/tests/BUILD
+++ b/tensorflow/tools/api/tests/BUILD
@@ -23,7 +23,6 @@ py_test(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow:experimental_tensorflow_py",
"//tensorflow:tensorflow_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:lib",
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py
index 7eeae05847..1ad6b6d1c0 100644
--- a/tensorflow/tools/api/tests/api_compatibility_test.py
+++ b/tensorflow/tools/api/tests/api_compatibility_test.py
@@ -34,7 +34,6 @@ import sys
import unittest
import tensorflow as tf
-from tensorflow import experimental_api as api
from google.protobuf import text_format
@@ -47,8 +46,6 @@ from tensorflow.tools.api.lib import python_object_to_proto_visitor
from tensorflow.tools.common import public_api
from tensorflow.tools.common import traverse
-if hasattr(tf, 'experimental_api'):
- del tf.experimental_api
# FLAGS defined at the bottom:
FLAGS = None
@@ -145,9 +142,6 @@ class ApiCompatibilityTest(test.TestCase):
verbose_diff_message = ''
# First check if the key is not found in one or the other.
if key in only_in_expected:
- # TODO(annarev): remove once we switch to using tf_export decorators.
- if key == 'tensorflow.math':
- continue
diff_message = 'Object %s expected but not found (removed). %s' % (
key, additional_missing_object_message)
verbose_diff_message = diff_message
@@ -232,13 +226,6 @@ class ApiCompatibilityTest(test.TestCase):
for filename in golden_file_list
}
- # TODO(annarev): remove once we switch to using tf_export decorators.
- tf_module = golden_proto_dict['tensorflow'].tf_module
- for i in range(len(tf_module.member)):
- if tf_module.member[i].name == 'math':
- del tf_module.member[i]
- break
-
# Diff them. Do not fail if called with update.
# If the test is run to update goldens, only report diffs but do not fail.
self._AssertProtoDictEquals(
@@ -247,49 +234,6 @@ class ApiCompatibilityTest(test.TestCase):
verbose=FLAGS.verbose_diffs,
update_goldens=FLAGS.update_goldens)
- @unittest.skipUnless(
- sys.version_info.major == 2,
- 'API compabitility test goldens are generated using python2.')
- def testNewAPIBackwardsCompatibility(self):
- # Extract all API stuff.
- visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
-
- public_api_visitor = public_api.PublicAPIVisitor(visitor)
- public_api_visitor.do_not_descend_map['tf'].append('contrib')
- public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
- # TODO(annarev): Make slide_dataset available in API.
- public_api_visitor.private_map['tf'] = ['slide_dataset']
- traverse.traverse(api, public_api_visitor)
-
- proto_dict = visitor.GetProtos()
-
- # Read all golden files.
- expression = os.path.join(
- resource_loader.get_root_dir_with_all_resources(),
- _KeyToFilePath('*'))
- golden_file_list = file_io.get_matching_files(expression)
-
- def _ReadFileToProto(filename):
- """Read a filename, create a protobuf from its contents."""
- ret_val = api_objects_pb2.TFAPIObject()
- text_format.Merge(file_io.read_file_to_string(filename), ret_val)
- return ret_val
-
- golden_proto_dict = {
- _FileNameToKey(filename): _ReadFileToProto(filename)
- for filename in golden_file_list
- }
-
- # Diff them. Do not fail if called with update.
- # If the test is run to update goldens, only report diffs but do not fail.
- self._AssertProtoDictEquals(
- golden_proto_dict,
- proto_dict,
- verbose=FLAGS.verbose_diffs,
- update_goldens=False,
- additional_missing_object_message=
- 'Check if tf_export decorator/call is missing for this symbol.')
-
if __name__ == '__main__':
parser = argparse.ArgumentParser()
diff --git a/tensorflow/tools/ci_build/windows/cpu/cmake/run_py.bat b/tensorflow/tools/ci_build/windows/cpu/cmake/run_py.bat
index 3c3b223a00..30554a084c 100644
--- a/tensorflow/tools/ci_build/windows/cpu/cmake/run_py.bat
+++ b/tensorflow/tools/ci_build/windows/cpu/cmake/run_py.bat
@@ -28,6 +28,9 @@ IF DEFINED TF_NIGHTLY (ECHO TF_NIGHTLY is set to %TF_NIGHTLY%) ELSE (SET TF_NIGH
:: Set pip binary location. Do not override if it is set already.
IF DEFINED PIP_EXE (ECHO PIP_EXE is set to %PIP_EXE%) ELSE (SET PIP_EXE="C:\Program Files\Anaconda3\Scripts\pip.exe")
+:: Install absl-py.
+%PIP_EXE% install --upgrade absl-py
+
:: Run the CMAKE build to build the pip package.
CALL %REPO_ROOT%\tensorflow\tools\ci_build\windows\cpu\cmake\run_build.bat
if %errorlevel% neq 0 exit /b %errorlevel%
@@ -37,9 +40,6 @@ DIR %REPO_ROOT%\%BUILD_DIR%\tf_python\dist\ /S /B > wheel_filename_file
set /p WHEEL_FILENAME=<wheel_filename_file
del wheel_filename_file
-:: Install absl-py.
-%PIP_EXE% install --upgrade absl-py
-
:: Install the pip package.
echo Installing PIP package...
%PIP_EXE% install --upgrade --no-deps %WHEEL_FILENAME% -v -v
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index ed941c3bc2..6511a50b3b 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -35,7 +35,6 @@ REQUIRED_PACKAGES = [
'absl-py >= 0.1.6',
'astor >= 0.6.0',
'gast >= 0.2.0',
- 'grpcio >= 1.8.6',
'numpy >= 1.13.3',
'six >= 1.10.0',
'protobuf >= 3.4.0',
@@ -43,6 +42,12 @@ REQUIRED_PACKAGES = [
'termcolor >= 1.1.0',
]
+if sys.byteorder == 'little':
+ # grpcio does not build correctly on big-endian machines due to lack of
+ # BoringSSL support.
+ # See https://github.com/tensorflow/tensorflow/issues/17882.
+ REQUIRED_PACKAGES.append('grpcio >= 1.8.6')
+
project_name = 'tensorflow'
if '--project_name' in sys.argv:
project_name_idx = sys.argv.index('--project_name')