author    Michael Case <mikecase@google.com>  2018-06-01 12:58:16 -0700
committer GitHub <noreply@github.com>  2018-06-01 12:58:16 -0700
commit    744cf3d3e06fb63ffa40086766137daedc01a5ba (patch)
tree      ad87aeeaae9cce01bfdf017b6650aea23f9a2817
parent    eebb9e0449b38703869ae7ccd0aa2c649f9f5aaf (diff)
parent    8f79ab773fe44e4779138a77a3bda4b18245d658 (diff)
Merge pull request #19680 from case540/branch_198811639
Branch 198811639
-rw-r--r--  tensorflow/BUILD | 17
-rw-r--r--  tensorflow/__init__.py | 3
-rw-r--r--  tensorflow/api_template.__init__.py | 43
-rw-r--r--  tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc | 8
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_launch_op.cc | 2
-rw-r--r--  tensorflow/compiler/jit/xla_compile_on_demand_op.cc | 3
-rw-r--r--  tensorflow/compiler/tf2xla/tf2xla.cc | 3
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler.cc | 73
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler.h | 9
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler_test.cc | 78
-rw-r--r--  tensorflow/compiler/xla/BUILD | 31
-rw-r--r--  tensorflow/compiler/xla/client/BUILD | 1
-rw-r--r--  tensorflow/compiler/xla/client/client.h | 2
-rw-r--r--  tensorflow/compiler/xla/client/executable_build_options.h | 9
-rw-r--r--  tensorflow/compiler/xla/client/local_client.cc | 5
-rw-r--r--  tensorflow/compiler/xla/client/local_client.h | 5
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/BUILD | 1
-rw-r--r--  tensorflow/compiler/xla/rpc/grpc_service.cc | 88
-rw-r--r--  tensorflow/compiler/xla/rpc/grpc_service.h | 47
-rw-r--r--  tensorflow/compiler/xla/rpc/grpc_stub.cc | 93
-rw-r--r--  tensorflow/compiler/xla/rpc/grpc_stub.h | 39
-rw-r--r--  tensorflow/compiler/xla/rpc/xla_service.proto | 60
-rw-r--r--  tensorflow/compiler/xla/scanner.cc | 197
-rw-r--r--  tensorflow/compiler/xla/scanner.h | 102
-rw-r--r--  tensorflow/compiler/xla/scanner_test.cc | 124
-rw-r--r--  tensorflow/compiler/xla/service/BUILD | 70
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc | 126
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 26
-rw-r--r--  tensorflow/compiler/xla/service/batchnorm_expander.cc | 286
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment_test.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/channel_tracker.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/compile_only_service.cc | 53
-rw-r--r--  tensorflow/compiler/xla/service/compile_only_service.h | 33
-rw-r--r--  tensorflow/compiler/xla/service/compiler.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/compiler.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/computation_tracker.cc | 256
-rw-r--r--  tensorflow/compiler/xla/service/computation_tracker.h | 147
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD | 10
-rw-r--r--  tensorflow/compiler/xla/service/gpu/backend_configs.proto | 27
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc | 14
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc | 15
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 21
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc | 136
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h | 46
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction_test.cc | 66
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_group_metadata.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_group_metadata.h | 3
-rw-r--r--  tensorflow/compiler/xla/service/local_service.cc | 83
-rw-r--r--  tensorflow/compiler/xla/service/local_service.h | 17
-rw-r--r--  tensorflow/compiler/xla/service/service.cc | 909
-rw-r--r--  tensorflow/compiler/xla/service/service.h | 123
-rw-r--r--  tensorflow/compiler/xla/service/user_computation.cc | 3557
-rw-r--r--  tensorflow/compiler/xla/service/user_computation.h | 413
-rw-r--r--  tensorflow/compiler/xla/service/user_computation_test.cc | 340
-rw-r--r--  tensorflow/compiler/xla/service/while_util.cc | 10
-rw-r--r--  tensorflow/compiler/xla/service/while_util.h | 12
-rw-r--r--  tensorflow/compiler/xla/service/while_util_test.cc | 43
-rw-r--r--  tensorflow/compiler/xla/service_interface.h | 41
-rw-r--r--  tensorflow/compiler/xla/tools/BUILD | 1
-rw-r--r--  tensorflow/compiler/xla/tools/dumped_computation_to_text.cc | 1
-rw-r--r--  tensorflow/compiler/xla/tools/parser/BUILD | 1
-rw-r--r--  tensorflow/compiler/xla/tools/parser/hlo_parser.cc | 65
-rw-r--r--  tensorflow/compiler/xla/tools/parser/hlo_parser.h | 11
-rw-r--r--  tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc | 23
-rw-r--r--  tensorflow/compiler/xla/tools/replay_computation.cc | 104
-rw-r--r--  tensorflow/compiler/xla/util.h | 7
-rw-r--r--  tensorflow/contrib/autograph/converters/break_statements.py | 16
-rw-r--r--  tensorflow/contrib/autograph/converters/continue_statements.py | 174
-rw-r--r--  tensorflow/contrib/autograph/operators/BUILD | 12
-rw-r--r--  tensorflow/contrib/autograph/operators/__init__.py | 13
-rw-r--r--  tensorflow/contrib/autograph/operators/data_structures.py | 249
-rw-r--r--  tensorflow/contrib/autograph/operators/data_structures_test.py | 87
-rw-r--r--  tensorflow/contrib/autograph/operators/slices.py | 133
-rw-r--r--  tensorflow/contrib/autograph/operators/slices_test.py | 51
-rw-r--r--  tensorflow/contrib/autograph/pyct/transformer.py | 148
-rw-r--r--  tensorflow/contrib/autograph/pyct/transformer_test.py | 42
-rw-r--r--  tensorflow/contrib/cloud/python/ops/gcs_config_ops.py | 12
-rwxr-xr-x  tensorflow/contrib/cmake/tf_python.cmake | 18
-rw-r--r--  tensorflow/contrib/cmake/tf_tests.cmake | 4
-rw-r--r--  tensorflow/contrib/data/kernels/csv_dataset_op.cc | 4
-rw-r--r--  tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc | 26
-rw-r--r--  tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc | 13
-rw-r--r--  tensorflow/contrib/data/kernels/threadpool_dataset_op.cc | 13
-rw-r--r--  tensorflow/contrib/data/kernels/unique_dataset_op.cc | 11
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/BUILD | 2
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py | 8
-rw-r--r--  tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb | 54
-rw-r--r--  tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb | 551
-rw-r--r--  tensorflow/contrib/factorization/python/ops/factorization_ops.py | 129
-rw-r--r--  tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc | 4
-rw-r--r--  tensorflow/contrib/learn/BUILD | 1
-rw-r--r--  tensorflow/contrib/lite/build_def.bzl | 1
-rw-r--r--  tensorflow/contrib/lite/builtin_op_data.h | 4
-rw-r--r--  tensorflow/contrib/lite/builtin_ops.h | 1
-rw-r--r--  tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md | 15
-rw-r--r--  tensorflow/contrib/lite/kernels/BUILD | 14
-rw-r--r--  tensorflow/contrib/lite/kernels/basic_rnn.cc | 41
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/common.h | 6
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/kernel_utils.cc | 7
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/kernel_utils.h | 6
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h | 68
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 156
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/types.h | 100
-rw-r--r--  tensorflow/contrib/lite/kernels/register.cc | 2
-rw-r--r--  tensorflow/contrib/lite/kernels/sparse_to_dense.cc | 275
-rw-r--r--  tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc | 155
-rw-r--r--  tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc | 41
-rw-r--r--  tensorflow/contrib/lite/model.cc | 10
-rw-r--r--  tensorflow/contrib/lite/nnapi_delegate.cc | 1
-rw-r--r--  tensorflow/contrib/lite/op_resolver.h | 4
-rw-r--r--  tensorflow/contrib/lite/profiling/BUILD | 7
-rw-r--r--  tensorflow/contrib/lite/profiling/profile_buffer.h | 12
-rw-r--r--  tensorflow/contrib/lite/profiling/time.cc | 29
-rw-r--r--  tensorflow/contrib/lite/profiling/time.h | 27
-rw-r--r--  tensorflow/contrib/lite/python/BUILD | 19
-rw-r--r--  tensorflow/contrib/lite/python/convert.py | 63
-rw-r--r--  tensorflow/contrib/lite/python/convert_saved_model.py | 118
-rw-r--r--  tensorflow/contrib/lite/python/convert_saved_model_test.py | 55
-rw-r--r--  tensorflow/contrib/lite/python/lite.py | 226
-rw-r--r--  tensorflow/contrib/lite/python/lite_test.py | 241
-rw-r--r--  tensorflow/contrib/lite/python/tflite_convert.py | 324
-rw-r--r--  tensorflow/contrib/lite/schema/schema.fbs | 6
-rwxr-xr-x  tensorflow/contrib/lite/schema/schema_generated.h | 141
-rw-r--r--  tensorflow/contrib/lite/testing/generate_examples.py | 77
-rw-r--r--  tensorflow/contrib/lite/toco/export_tensorflow.cc | 19
-rw-r--r--  tensorflow/contrib/lite/toco/g3doc/python_api.md | 49
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc | 10
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 32
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.cc | 20
-rw-r--r--  tensorflow/contrib/lite/toco/model.h | 14
-rw-r--r--  tensorflow/contrib/lite/toco/python/BUILD | 6
-rw-r--r--  tensorflow/contrib/lite/toco/python/toco_wrapper.py | 40
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/export.h | 21
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.cc | 23
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator_test.cc | 9
-rw-r--r--  tensorflow/contrib/lite/toco/tooling_util.cc | 1
-rw-r--r--  tensorflow/contrib/lite/tools/BUILD | 75
-rw-r--r--  tensorflow/contrib/lite/tools/benchmark_main.cc | 37
-rw-r--r--  tensorflow/contrib/lite/tools/benchmark_model.cc | 518
-rw-r--r--  tensorflow/contrib/lite/tools/benchmark_model.h | 161
-rw-r--r--  tensorflow/contrib/lite/tools/benchmark_tflite_model.cc | 352
-rw-r--r--  tensorflow/contrib/lite/tools/benchmark_tflite_model.h | 90
-rw-r--r--  tensorflow/contrib/lite/tools/command_line_flags.cc | 189
-rw-r--r--  tensorflow/contrib/lite/tools/command_line_flags.h | 112
-rw-r--r--  tensorflow/contrib/lite/tools/command_line_flags_test.cc | 153
-rw-r--r--  tensorflow/contrib/lite/tools/logging.h | 75
-rw-r--r--  tensorflow/contrib/lite/util.cc | 10
-rw-r--r--  tensorflow/contrib/lite/util.h | 2
-rw-r--r--  tensorflow/contrib/tensorrt/segment/segment_test.cc | 4
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 4
-rw-r--r--  tensorflow/core/BUILD | 59
-rw-r--r--  tensorflow/core/common_runtime/function_test.cc | 2
-rw-r--r--  tensorflow/core/framework/dataset.h | 18
-rw-r--r--  tensorflow/core/framework/variable.proto | 3
-rw-r--r--  tensorflow/core/graph/algorithm_test.cc | 4
-rw-r--r--  tensorflow/core/graph/graph_constructor.cc | 15
-rw-r--r--  tensorflow/core/graph/graph_partition_test.cc | 16
-rw-r--r--  tensorflow/core/graph/optimizer_cse_test.cc | 32
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties.cc | 23
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 301
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.h | 2
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 165
-rw-r--r--  tensorflow/core/kernels/batching_util/BUILD | 21
-rw-r--r--  tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h | 548
-rw-r--r--  tensorflow/core/kernels/batching_util/serial_device_batch_scheduler_test.cc | 394
-rw-r--r--  tensorflow/core/kernels/data/batch_dataset_op.cc | 11
-rw-r--r--  tensorflow/core/kernels/data/cache_dataset_ops.cc | 22
-rw-r--r--  tensorflow/core/kernels/data/concatenate_dataset_op.cc | 24
-rw-r--r--  tensorflow/core/kernels/data/dataset_utils.cc | 5
-rw-r--r--  tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc | 12
-rw-r--r--  tensorflow/core/kernels/data/filter_dataset_op.cc | 11
-rw-r--r--  tensorflow/core/kernels/data/flat_map_dataset_op.cc | 14
-rw-r--r--  tensorflow/core/kernels/data/generator_dataset_op.cc | 6
-rw-r--r--  tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc | 13
-rw-r--r--  tensorflow/core/kernels/data/group_by_window_dataset_op.cc | 17
-rw-r--r--  tensorflow/core/kernels/data/interleave_dataset_op.cc | 13
-rw-r--r--  tensorflow/core/kernels/data/iterator_ops.cc | 26
-rw-r--r--  tensorflow/core/kernels/data/map_and_batch_dataset_op.cc | 13
-rw-r--r--  tensorflow/core/kernels/data/map_dataset_op.cc | 13
-rw-r--r--  tensorflow/core/kernels/data/padded_batch_dataset_op.cc | 14
-rw-r--r--  tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc | 9
-rw-r--r--  tensorflow/core/kernels/data/parallel_map_dataset_op.cc | 11
-rw-r--r--  tensorflow/core/kernels/data/prefetch_autotuner_test.cc | 2
-rw-r--r--  tensorflow/core/kernels/data/prefetch_dataset_op.cc | 11
-rw-r--r--  tensorflow/core/kernels/data/random_dataset_op.cc | 4
-rw-r--r--  tensorflow/core/kernels/data/range_dataset_op.cc | 4
-rw-r--r--  tensorflow/core/kernels/data/reader_dataset_ops.cc | 12
-rw-r--r--  tensorflow/core/kernels/data/repeat_dataset_op.cc | 21
-rw-r--r--  tensorflow/core/kernels/data/scan_dataset_op.cc | 11
-rw-r--r--  tensorflow/core/kernels/data/shuffle_dataset_op.cc | 38
-rw-r--r--  tensorflow/core/kernels/data/skip_dataset_op.cc | 15
-rw-r--r--  tensorflow/core/kernels/data/slide_dataset_op.cc | 29
-rw-r--r--  tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc | 4
-rw-r--r--  tensorflow/core/kernels/data/sql_dataset_ops.cc | 4
-rw-r--r--  tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc | 11
-rw-r--r--  tensorflow/core/kernels/data/stats_dataset_ops.cc | 24
-rw-r--r--  tensorflow/core/kernels/data/take_dataset_op.cc | 19
-rw-r--r--  tensorflow/core/kernels/data/tensor_dataset_op.cc | 4
-rw-r--r--  tensorflow/core/kernels/data/tensor_queue_dataset_op.cc | 25
-rw-r--r--  tensorflow/core/kernels/data/tensor_slice_dataset_op.cc | 6
-rw-r--r--  tensorflow/core/kernels/data/unbatch_dataset_op.cc | 9
-rw-r--r--  tensorflow/core/kernels/data/window_dataset.cc | 4
-rw-r--r--  tensorflow/core/kernels/data/writer_ops.cc | 8
-rw-r--r--  tensorflow/core/kernels/data/zip_dataset_op.cc | 19
-rw-r--r--  tensorflow/core/kernels/inplace_ops.cc | 1
-rw-r--r--  tensorflow/core/platform/cloud/gcs_file_system.cc | 70
-rw-r--r--  tensorflow/core/platform/cloud/gcs_file_system_test.cc | 46
-rw-r--r--  tensorflow/core/platform/cloud/ram_file_block_cache.h | 2
-rw-r--r--  tensorflow/core/platform/default/build_config.bzl | 3
-rw-r--r--  tensorflow/core/platform/default/human_readable_json.cc | 54
-rw-r--r--  tensorflow/core/platform/human_readable_json.h | 37
-rw-r--r--  tensorflow/docs_src/extend/new_data_formats.md | 6
-rw-r--r--  tensorflow/docs_src/performance/benchmarks.md | 2
-rw-r--r--  tensorflow/docs_src/performance/index.md | 39
-rw-r--r--  tensorflow/docs_src/performance/leftnav_files | 1
-rw-r--r--  tensorflow/python/BUILD | 34
-rw-r--r--  tensorflow/python/eager/function.py | 2
-rw-r--r--  tensorflow/python/eager/graph_callable.py | 2
-rw-r--r--  tensorflow/python/eager/pywrap_tfe_src.cc | 4
-rw-r--r--  tensorflow/python/estimator/BUILD | 20
-rw-r--r--  tensorflow/python/estimator/estimator.py | 91
-rw-r--r--  tensorflow/python/estimator/estimator_test.py | 33
-rw-r--r--  tensorflow/python/estimator/export/export.py | 36
-rw-r--r--  tensorflow/python/estimator/export/export_test.py | 35
-rw-r--r--  tensorflow/python/estimator/util.py | 57
-rw-r--r--  tensorflow/python/estimator/util_test.py | 102
-rw-r--r--  tensorflow/python/framework/function.py | 8
-rw-r--r--  tensorflow/python/framework/function_def_to_graph.py | 189
-rw-r--r--  tensorflow/python/framework/function_def_to_graph_test.py | 184
-rw-r--r--  tensorflow/python/framework/ops.py | 8
-rw-r--r--  tensorflow/python/grappler/layout_optimizer_test.py | 4
-rw-r--r--  tensorflow/python/keras/engine/network.py | 52
-rw-r--r--  tensorflow/python/keras/layers/normalization.py | 26
-rw-r--r--  tensorflow/python/keras/model_subclassing_test.py | 45
-rw-r--r--  tensorflow/python/keras/models_test.py | 21
-rw-r--r--  tensorflow/python/keras/optimizers.py | 3
-rw-r--r--  tensorflow/python/keras/utils/layer_utils.py | 55
-rw-r--r--  tensorflow/python/kernel_tests/BUILD | 58
-rw-r--r--  tensorflow/python/kernel_tests/ackermann_op.cc (renamed from tensorflow/user_ops/ackermann_op.cc) | 0
-rw-r--r--  tensorflow/python/kernel_tests/ackermann_test.py (renamed from tensorflow/user_ops/ackermann_test.py) | 14
-rw-r--r--  tensorflow/python/kernel_tests/duplicate_op.cc (renamed from tensorflow/user_ops/duplicate_op.cc) | 0
-rw-r--r--  tensorflow/python/kernel_tests/duplicate_op_test.py (renamed from tensorflow/user_ops/duplicate_op_test.py) | 17
-rw-r--r--  tensorflow/python/kernel_tests/inplace_ops_test.py | 12
-rw-r--r--  tensorflow/python/kernel_tests/invalid_op.cc (renamed from tensorflow/user_ops/invalid_op.cc) | 0
-rw-r--r--  tensorflow/python/kernel_tests/invalid_op_test.py (renamed from tensorflow/user_ops/invalid_op_test.py) | 17
-rw-r--r--  tensorflow/python/kernel_tests/resource_variable_ops_test.py | 26
-rw-r--r--  tensorflow/python/kernel_tests/variables_test.py | 17
-rw-r--r--  tensorflow/python/ops/functional_ops.py | 3
-rw-r--r--  tensorflow/python/ops/resource_variable_ops.py | 31
-rw-r--r--  tensorflow/python/ops/variable_scope.py | 6
-rw-r--r--  tensorflow/python/ops/variables.py | 7
-rw-r--r--  tensorflow/python/training/checkpointable/data_structures.py | 36
-rw-r--r--  tensorflow/python/training/checkpointable/data_structures_test.py | 19
-rw-r--r--  tensorflow/python/training/moving_averages.py | 13
-rw-r--r--  tensorflow/python/training/moving_averages_test.py | 1
-rw-r--r--  tensorflow/python/util/stat_summarizer.i | 5
-rw-r--r--  tensorflow/security/advisory/tfsa-2018-001.md | 4
-rw-r--r--  tensorflow/security/advisory/tfsa-2018-002.md | 2
-rw-r--r--  tensorflow/security/advisory/tfsa-2018-003.md | 4
-rw-r--r--  tensorflow/security/advisory/tfsa-2018-004.md | 2
-rw-r--r--  tensorflow/security/advisory/tfsa-2018-005.md | 2
-rw-r--r--  tensorflow/security/advisory/tfsa-2018-006.md | 2
-rw-r--r--  tensorflow/security/index.md | 12
-rw-r--r--  tensorflow/stream_executor/dnn.cc | 2
-rw-r--r--  tensorflow/tools/api/generator/BUILD | 116
-rw-r--r--  tensorflow/tools/api/generator/api_gen.bzl | 125
-rw-r--r--  tensorflow/tools/api/generator/create_python_api.py | 85
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.-variable.pbtxt | 4
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.train.-exponential-moving-average.pbtxt | 4
-rw-r--r--  tensorflow/tools/api/lib/api_objects.proto | 4
-rw-r--r--  tensorflow/tools/pip_package/BUILD | 4
-rwxr-xr-x  tensorflow/tools/pip_package/build_pip_package.sh | 4
-rw-r--r--  tensorflow/tools/pip_package/setup.py | 3
-rw-r--r--  tensorflow/user_ops/BUILD | 52
-rw-r--r--  tensorflow/workspace.bzl | 8
277 files changed, 9208 insertions, 9133 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index f4351f9dce..9b07669a5d 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -19,6 +19,10 @@ load(
"//tensorflow/core:platform/default/build_config.bzl",
"tf_additional_binary_deps",
)
+load(
+ "//tensorflow/tools/api/generator:api_gen.bzl",
+ "gen_api_init_files", # @unused
+)
# Config setting for determining if we are building for Android.
config_setting(
@@ -534,13 +538,16 @@ exports_files(
],
)
+gen_api_init_files(
+ name = "python_api_gen",
+ srcs = ["api_template.__init__.py"],
+ root_init_template = "api_template.__init__.py",
+)
+
py_library(
name = "tensorflow_py",
- srcs = ["__init__.py"],
+ srcs = [":python_api_gen"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
- deps = [
- "//tensorflow/python",
- "//tensorflow/tools/api/generator:python_api",
- ],
+ deps = ["//tensorflow/python"],
)
diff --git a/tensorflow/__init__.py b/tensorflow/__init__.py
index c8683e3976..440e9f8dbd 100644
--- a/tensorflow/__init__.py
+++ b/tensorflow/__init__.py
@@ -22,9 +22,6 @@ from __future__ import print_function
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
-# pylint: disable=wildcard-import
-from tensorflow.tools.api.generator.api import * # pylint: disable=redefined-builtin
-# pylint: enable=wildcard-import
from tensorflow.python.util.lazy_loader import LazyLoader
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py
new file mode 100644
index 0000000000..9b0d7d48af
--- /dev/null
+++ b/tensorflow/api_template.__init__.py
@@ -0,0 +1,43 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Bring in all of the public TensorFlow interface into this module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=g-bad-import-order
+from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
+# API IMPORTS PLACEHOLDER
+
+from tensorflow.python.util.lazy_loader import LazyLoader
+contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
+del LazyLoader
+
+from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
+app.flags = flags # pylint: disable=undefined-variable
+
+del absolute_import
+del division
+del print_function
+
+# These symbols appear because we import the python package which
+# in turn imports from tensorflow.core and tensorflow.python. They
+# must come from this module. So python adds these symbols for the
+# resolution to succeed.
+# pylint: disable=undefined-variable
+del python
+del core
+# pylint: enable=undefined-variable
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
index 5ec24d39a2..eef113a354 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
@@ -1050,7 +1050,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
.WithAttr("_outside", "O1"));
Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
{DT_FLOAT, DT_FLOAT}, shape2.opts());
- Node* h = Binary(ops::NodeOut(recv2, 0), e,
+ Node* h = Binary(ops::NodeOut(recv2, 1), e,
shape2.opts()
.WithName("H")
.WithAttr("_encapsulate", "F1")
@@ -1075,7 +1075,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O2_host_compute"},
"XlaHostCompute",
- {"D:o:0", "F:o:0"},
+ {"F:o:0", "D:o:0"},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"ancestors",
@@ -1123,13 +1123,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2",
{DT_FLOAT, DT_FLOAT}, b2.opts());
- Node* g = Binary(e, ops::NodeOut(recv2, 1),
+ Node* g = Binary(e, ops::NodeOut(recv2, 0),
b2.opts()
.WithName("G")
.WithControlInputs({recv2, e})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O2"));
- Node* h = Binary(ops::NodeOut(recv2, 0), e,
+ Node* h = Binary(ops::NodeOut(recv2, 1), e,
b2.opts()
.WithName("H")
.WithAttr("_encapsulate", "F1")
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 27287e0f96..902fe27acd 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -148,7 +148,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
XlaCompiler::Options options;
options.client = client;
- options.device_type = &cache->device_type();
+ options.device_type = cache->device_type();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index ab644ff5a6..b1943d3e1a 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -151,8 +151,7 @@ Status XlaCompileOnDemandOp::Compile(
core::ScopedUnref cache_ref(cache);
XlaCompiler::Options options;
- DeviceType device_type = metadata.jit_device_type();
- options.device_type = &device_type;
+ options.device_type = metadata.jit_device_type();
options.client = metadata.client();
options.flib_def =
new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{});
diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
index 3a08aa8cf4..ac768b206e 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -263,8 +263,7 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph, xla::Client* client,
// Compile the graph into an XLA computation.
XlaCompiler::Options compiler_options;
compiler_options.client = client;
- DeviceType device_type(DEVICE_CPU_XLA_JIT);
- compiler_options.device_type = &device_type;
+ compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
compiler_options.flib_def = &graph->flib_def();
compiler_options.graph_def_version = graph->versions().producer();
compiler_options.allow_cpu_custom_calls = true;
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index f7098917b1..a8bd199675 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -83,12 +83,9 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
: options_(options),
initialization_status_(Status::OK()),
next_step_id_(1),
- device_(
- new XlaCompilationDevice(SessionOptions(), *options_.device_type)),
+ device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
device_mgr_({device_}) {
- // We no longer need the device_type.
- options_.device_type = nullptr;
-
+ CHECK(!options_.device_type.type_string().empty());
if (options_.populate_resource_manager) {
initialization_status_ =
(*options_.populate_resource_manager)(device_->resource_manager());
@@ -228,7 +225,7 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
// Computes the XLA shape for argument 'arg'.
Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
bool is_entry_computation,
- xla::Shape* xla_shape) {
+ xla::Shape* xla_shape) const {
switch (arg.kind) {
case XlaCompiler::Argument::kConstant:
LOG(FATAL) << "Unreachable case";
@@ -659,6 +656,65 @@ Status XlaCompiler::CompileSingleOp(
return CompileGraph(options, name, std::move(graph), args, result);
}
+namespace {
+
+// Check that the ops of all non-functional nodes have been registered.
+string ValidateFunctionDef(const FunctionDef* fdef,
+ const FunctionLibraryDefinition& flib_def) {
+ std::vector<string> invalid_ops;
+ for (const NodeDef& node : fdef->node_def()) {
+ const string& op = node.op();
+ if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) {
+ continue;
+ }
+ const OpDef* op_def;
+ if (!OpRegistry::Global()->LookUpOpDef(op, &op_def).ok()) {
+ invalid_ops.push_back(op);
+ }
+ }
+ return tensorflow::str_util::Join(invalid_ops, ", ");
+}
+
+// Check that the graph doesn't have any invalid nodes (e.g. incompatible with
+// given device_type, invalid data type, missing attributes...)
+Status ValidateGraph(const Graph* graph,
+ const FunctionLibraryDefinition& flib_def,
+ const DeviceType& device_type, const string& name) {
+ std::vector<string> invalid_ops;
+ for (const Node* node : graph->nodes()) {
+ if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
+ continue;
+ }
+ const FunctionDef* fdef = flib_def.Find(node->def().op());
+ if (fdef) {
+ string error_msg = ValidateFunctionDef(fdef, flib_def);
+ if (!error_msg.empty()) {
+ invalid_ops.push_back(
+ strings::StrCat(node->def().op(), ":{", error_msg, "}"));
+ }
+ continue;
+ }
+ const OpDef* op_def;
+ if (!OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def).ok()) {
+ invalid_ops.push_back(node->def().op());
+ continue;
+ }
+ TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def));
+ if (!FindKernelDef(device_type, node->def(), nullptr, nullptr).ok()) {
+ invalid_ops.push_back(node->def().op());
+ }
+ }
+ if (!invalid_ops.empty()) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Detected unsupported operations when trying to compile graph ", name,
+ " on ", device_type.type_string(), ":",
+ tensorflow::str_util::Join(invalid_ops, ", ")));
+ }
+ return Status::OK();
+}
+
+} // namespace
+
Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
string const& name,
std::unique_ptr<Graph> graph,
@@ -681,6 +737,11 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(),
graph.get(), local_flib_def_.get()));
+ // Detect invalid nodes.
+ // FunctionalizeControlFlow may remove some nodes from the graph.
+ TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def,
+ options_.device_type, name));
+
xla::XlaBuilder builder(name);
XlaContext* context = new XlaContext(
this, &builder, options_.allow_cpu_custom_calls,
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index bf496bd8bc..c93850ce27 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@@ -244,9 +245,9 @@ class XlaCompiler {
typedef std::function<TensorShape(const TensorShape&, DataType)>
ShapeRepresentationFn;
struct Options {
- // Name of the compilation device to use. Needs to be live only during
- // XlaCompiler's constructor.
- const DeviceType* device_type = nullptr;
+ // Name of the compilation device to use. It must be set by the caller.
+ // The default empty value is invalid.
+ DeviceType device_type = DeviceType("");
xla::Client* client = nullptr;
@@ -313,7 +314,7 @@ class XlaCompiler {
// See the class comment for more details about the argument passing
// convention.
Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation,
- xla::Shape* xla_shape);
+ xla::Shape* xla_shape) const;
// Retrieves the channel handle associated with `key`. Allocates
// a new channel handle if none exists.
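As context for the change above: XlaCompiler::Options now holds the DeviceType by value, so callers no longer need to keep an external DeviceType alive for the compiler's lifetime. A minimal caller sketch, mirroring DefaultOptions() in the test diff below; 'client' and 'flib_def' are assumed to already exist (e.g. from xla::ClientLibrary::LocalClientOrDie() and a pre-built FunctionLibraryDefinition):

    // Sketch only; 'client' is an xla::Client*, 'flib_def' a
    // FunctionLibraryDefinition*, both assumed set up elsewhere.
    XlaCompiler::Options options;
    options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);  // by value; empty is invalid
    options.client = client;
    options.flib_def = flib_def;
    XlaCompiler compiler(options);  // CHECK-fails on an empty device_type string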
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 55772ca324..5fbf4b952c 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -45,8 +45,6 @@ namespace tensorflow {
class XlaCompilerTest : public ::testing::Test {
protected:
- XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {}
-
void SetUp() override {
client_ = xla::ClientLibrary::LocalClientOrDie();
@@ -58,7 +56,7 @@ class XlaCompilerTest : public ::testing::Test {
XlaCompiler::Options DefaultOptions() {
XlaCompiler::Options options;
- options.device_type = &cpu_device_type_;
+ options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
options.client = client_;
options.flib_def = flib_def_.get();
return options;
@@ -68,7 +66,6 @@ class XlaCompilerTest : public ::testing::Test {
return compiler->local_flib_def_.get();
}
- DeviceType cpu_device_type_;
xla::Client* client_;
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
};
@@ -979,5 +976,78 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
+// Tests a graph which has a function with an invalid op.
+TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
+ XlaCompiler compiler(DefaultOptions());
+
+ FunctionDefLibrary flib;
+ FunctionDef fn = FillFn();
+ NodeDef* node = fn.add_node_def();
+ node->set_name("Invalid");
+ node->set_op("InvalidOp"); /* unsupported op */
+ node = fn.add_node_def();
+ node->set_name("Switch");
+ node->set_op("Switch"); /* control flow node */
+ *flib.add_function() = fn;
+
+ TF_ASSERT_OK(flib_def_->AddFunctionDef(fn));
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto value = ops::Const<int32>(scope.WithOpName("value"), 1, {});
+ auto shape = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1});
+ TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib));
+
+ NodeDef def;
+ TF_ASSERT_OK(NodeDefBuilder("fill_fn", "FillFn", flib_def_.get())
+ .Input(value.name(), 0, DT_INT32)
+ .Input(shape.name(), 1, DT_INT32)
+ .Finalize(&def));
+ Status status;
+ Node* fill = scope.graph()->AddNode(def, &status);
+ TF_ASSERT_OK(status);
+ TF_ASSERT_OK(scope.DoShapeInference(fill));
+ scope.graph()->AddEdge(value.node(), 0, fill, 0);
+ scope.graph()->AddEdge(shape.node(), 0, fill, 1);
+
+ auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0);
+
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+ std::vector<XlaCompiler::Argument> args;
+ XlaCompiler::CompilationResult result;
+ status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
+ std::move(graph), args, &result);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(
+ str_util::StrContains(status.error_message(), "FillFn:{InvalidOp}"))
+ << status.error_message();
+}
+
+// Tests a graph which has a node with invalid data type.
+TEST_F(XlaCompilerTest, NodeWithInvalidDataType) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ NodeDef shape;
+ shape.set_name("Shape");
+ shape.set_op("Shape");
+ (*shape.mutable_attr())["T"].set_type(DT_INT32);
+ (*shape.mutable_attr())["out_type"].set_type(DT_BOOL); /* invalid type */
+ Status status;
+ Node* shape_node = graph->AddNode(shape, &status);
+ TF_ASSERT_OK(status);
+ graph->AddControlEdge(graph->source_node(), shape_node);
+
+ std::vector<XlaCompiler::Argument> args;
+ XlaCompiler::CompilationResult result;
+ XlaCompiler compiler(DefaultOptions());
+ status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type",
+ std::move(graph), args, &result);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(str_util::StrContains(status.error_message(),
+ "is not in the list of allowed values"))
+ << status.error_message();
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index c08db7e3fb..c6deb959a5 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -500,37 +500,6 @@ cc_library(
)
cc_library(
- name = "scanner",
- srcs = ["scanner.cc"],
- hdrs = ["scanner.h"],
- visibility = [":internal"],
- deps = [
- ":status",
- ":status_macros",
- ":types",
- ":util",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- ],
-)
-
-tf_cc_test(
- name = "scanner_test",
- srcs = ["scanner_test.cc"],
- deps = [
- ":scanner",
- ":status",
- ":status_macros",
- ":test",
- ":types",
- ":util",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:test_main",
- ],
-)
-
-cc_library(
name = "text_literal_reader",
srcs = ["text_literal_reader.cc"],
hdrs = ["text_literal_reader.h"],
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index aacb394ae5..c4f0c4468f 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -86,6 +86,7 @@ cc_library(
hdrs = ["executable_build_options.h"],
deps = [
"//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h
index cda8a71f71..68f0d0ac78 100644
--- a/tensorflow/compiler/xla/client/client.h
+++ b/tensorflow/compiler/xla/client/client.h
@@ -153,8 +153,6 @@ class Client {
//
// If output_layout is non-null, then the output of the computation will be
// stored using that layout.
- //
- // TODO(b/74197823): This is a part of a NOT YET ready refactor.
StatusOr<std::unique_ptr<Literal>> ComputeConstant(
const XlaComputation& computation,
const Layout* output_layout = nullptr) const;
diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h
index 11f1098360..393da381fb 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.h
+++ b/tensorflow/compiler/xla/client/executable_build_options.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/optional.h"
@@ -76,6 +77,13 @@ class ExecutableBuildOptions {
ExecutableBuildOptions& set_hlo_profile(bool enabled);
tensorflow::gtl::optional<bool> hlo_profile() const;
+ void add_disabled_hlo_pass(tensorflow::StringPiece pass_name) {
+ disabled_hlo_passes_.push_back(std::string(pass_name));
+ }
+ const tensorflow::gtl::ArraySlice<std::string> disabled_hlo_passes() const {
+ return disabled_hlo_passes_;
+ }
+
// Returns a string representation of the build options, suitable for
// debugging.
string ToString() const;
@@ -89,6 +97,7 @@ class ExecutableBuildOptions {
tensorflow::gtl::optional<string> dump_optimized_hlo_proto_to_;
tensorflow::gtl::optional<string> dump_per_pass_hlo_proto_to_;
DeviceMemoryAllocator* device_allocator_ = nullptr;
+ std::vector<std::string> disabled_hlo_passes_;
};
} // namespace xla
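A minimal sketch of the new disabled-pass plumbing added above; the pass name here is illustrative only (actual HLO pass names are backend-specific and not part of this diff):

    // Sketch only; the pass name is an assumption for illustration.
    ExecutableBuildOptions build_options;
    build_options.add_disabled_hlo_pass("algsimp");
    for (const std::string& pass : build_options.disabled_hlo_passes()) {
      // A backend consulting this list would skip the named HLO passes.
    }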
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index a7c55c6b2b..f9003373a6 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -304,6 +304,11 @@ StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral(
shaped_buffer);
}
+StatusOr<const ShapedBuffer*> LocalClient::GlobalDataToShapedBuffer(
+ const GlobalDataHandle& data, int replica_number) {
+ return local_service_->GlobalDataToShapedBuffer(data, replica_number);
+}
+
Status LocalClient::TransferToInfeedLocal(const Literal& literal,
int device_ordinal) {
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index 3f23e52fc2..5b408cc6b2 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -136,6 +136,11 @@ class LocalClient : public Client {
StatusOr<std::unique_ptr<Literal>> ShapedBufferToLiteral(
const ShapedBuffer& shaped_buffer);
+ // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
+ // as long as the handle is valid.
+ StatusOr<const ShapedBuffer*> GlobalDataToShapedBuffer(
+ const GlobalDataHandle& data, int replica_number);
+
// Transfer the given literal to the infeed queue of the given device.
// TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
// not inherit from Client and there is no possibility of confusion with
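A short usage sketch for the new GlobalDataToShapedBuffer method, assuming 'client' is a LocalClient* and 'data' is a GlobalDataHandle returned by an earlier execution (both assumptions, error handling condensed):

    // Sketch only; replica 0 is assumed.
    StatusOr<const ShapedBuffer*> buffer_or =
        client->GlobalDataToShapedBuffer(data, /*replica_number=*/0);
    if (buffer_or.ok()) {
      const ShapedBuffer* buffer = buffer_or.ValueOrDie();
      // 'buffer' remains valid for as long as 'data' stays registered.
    }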
diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD
index 0d6e207971..507a2dc5f0 100644
--- a/tensorflow/compiler/xla/client/xla_client/BUILD
+++ b/tensorflow/compiler/xla/client/xla_client/BUILD
@@ -37,7 +37,6 @@ cc_library(
],
)
-# TODO(b/74197823): Replace computation_builder with xla_builder.
cc_library(
name = "xla_builder",
srcs = ["xla_builder.cc"],
diff --git a/tensorflow/compiler/xla/rpc/grpc_service.cc b/tensorflow/compiler/xla/rpc/grpc_service.cc
index 5f4dc6bd08..4e1435fa30 100644
--- a/tensorflow/compiler/xla/rpc/grpc_service.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_service.cc
@@ -32,19 +32,6 @@ namespace xla {
return tensorflow::ToGrpcStatus(s);
}
-::grpc::Status GRPCService::Computation(::grpc::ServerContext* context,
- const ComputationRequest* arg,
- ComputationResponse* result) {
- return DelegateRPC(
- [this, arg, result]() { return service_->Computation(arg, result); });
-}
-
-::grpc::Status GRPCService::CreateOp(::grpc::ServerContext* context,
- const OpRequest* arg, OpResponse* result) {
- return DelegateRPC(
- [this, arg, result]() { return service_->Op(arg, result); });
-}
-
::grpc::Status GRPCService::Unregister(::grpc::ServerContext* context,
const UnregisterRequest* arg,
UnregisterResponse* result) {
@@ -60,21 +47,6 @@ namespace xla {
});
}
-::grpc::Status GRPCService::SetReturnValue(::grpc::ServerContext* context,
- const SetReturnValueRequest* arg,
- SetReturnValueResponse* results) {
- return DelegateRPC([this, arg, results]() {
- return service_->SetReturnValue(arg, results);
- });
-}
-
-::grpc::Status GRPCService::Execute(::grpc::ServerContext* context,
- const ExecuteRequest* arg,
- ExecuteResponse* result) {
- return DelegateRPC(
- [this, arg, result]() { return service_->Execute(arg, result); });
-}
-
::grpc::Status GRPCService::ExecuteGraph(::grpc::ServerContext* /*context*/,
const ExecuteGraphRequest* arg,
ExecuteResponse* result) {
@@ -82,13 +54,6 @@ namespace xla {
[this, arg, result]() { return service_->ExecuteGraph(arg, result); });
}
-::grpc::Status GRPCService::ExecuteAsync(::grpc::ServerContext* context,
- const ExecuteAsyncRequest* arg,
- ExecuteAsyncResponse* result) {
- return DelegateRPC(
- [this, arg, result]() { return service_->ExecuteAsync(arg, result); });
-}
-
::grpc::Status GRPCService::WaitForExecution(::grpc::ServerContext* context,
const WaitForExecutionRequest* arg,
WaitForExecutionResponse* result) {
@@ -136,20 +101,6 @@ namespace xla {
[this, arg, result]() { return service_->ResetDevice(arg, result); });
}
-::grpc::Status GRPCService::IsConstant(::grpc::ServerContext* context,
- const IsConstantRequest* arg,
- IsConstantResponse* result) {
- return DelegateRPC(
- [this, arg, result]() { return service_->IsConstant(arg, result); });
-}
-
-::grpc::Status GRPCService::ComputeConstant(::grpc::ServerContext* context,
- const ComputeConstantRequest* arg,
- ComputeConstantResponse* result) {
- return DelegateRPC(
- [this, arg, result]() { return service_->ComputeConstant(arg, result); });
-}
-
::grpc::Status GRPCService::GetShape(::grpc::ServerContext* context,
const GetShapeRequest* arg,
GetShapeResponse* result) {
@@ -157,43 +108,4 @@ namespace xla {
[this, arg, result]() { return service_->GetShape(arg, result); });
}
-::grpc::Status GRPCService::GetComputationShape(
- ::grpc::ServerContext* context, const GetComputationShapeRequest* arg,
- GetComputationShapeResponse* result) {
- return DelegateRPC([this, arg, result]() {
- return service_->GetComputationShape(arg, result);
- });
-}
-
-::grpc::Status GRPCService::GetLocalShape(::grpc::ServerContext* context,
- const GetLocalShapeRequest* arg,
- GetLocalShapeResponse* result) {
- return DelegateRPC(
- [this, arg, result]() { return service_->GetLocalShape(arg, result); });
-}
-
-::grpc::Status GRPCService::GetComputationStats(
- ::grpc::ServerContext* context, const ComputationStatsRequest* arg,
- ComputationStatsResponse* result) {
- return DelegateRPC([this, arg, result]() {
- return service_->GetComputationStats(arg, result);
- });
-}
-
-::grpc::Status GRPCService::SnapshotComputation(
- ::grpc::ServerContext* context, const SnapshotComputationRequest* arg,
- SnapshotComputationResponse* result) {
- return DelegateRPC([this, arg, result]() {
- return service_->SnapshotComputation(arg, result);
- });
-}
-
-::grpc::Status GRPCService::LoadComputationSnapshot(
- ::grpc::ServerContext* context, const LoadComputationSnapshotRequest* arg,
- LoadComputationSnapshotResponse* result) {
- return DelegateRPC([this, arg, result]() {
- return service_->LoadComputationSnapshot(arg, result);
- });
-}
-
} // namespace xla
diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h
index 50f02796f2..5cd573167a 100644
--- a/tensorflow/compiler/xla/rpc/grpc_service.h
+++ b/tensorflow/compiler/xla/rpc/grpc_service.h
@@ -31,13 +31,6 @@ class GRPCService : public grpc::XlaService::Service {
static StatusOr<std::unique_ptr<GRPCService>> NewService(
se::Platform* platform = nullptr);
- ::grpc::Status Computation(::grpc::ServerContext* context,
- const ComputationRequest* arg,
- ComputationResponse* result) override;
-
- ::grpc::Status CreateOp(::grpc::ServerContext* context, const OpRequest* arg,
- OpResponse* result) override;
-
::grpc::Status Unregister(::grpc::ServerContext* context,
const UnregisterRequest* arg,
UnregisterResponse* result) override;
@@ -46,22 +39,10 @@ class GRPCService : public grpc::XlaService::Service {
const DeconstructTupleRequest* arg,
DeconstructTupleResponse* result) override;
- ::grpc::Status SetReturnValue(::grpc::ServerContext* context,
- const SetReturnValueRequest* arg,
- SetReturnValueResponse* results) override;
-
- ::grpc::Status Execute(::grpc::ServerContext* context,
- const ExecuteRequest* arg,
- ExecuteResponse* result) override;
-
::grpc::Status ExecuteGraph(::grpc::ServerContext* context,
const ExecuteGraphRequest* arg,
ExecuteResponse* result) override;
- ::grpc::Status ExecuteAsync(::grpc::ServerContext* context,
- const ExecuteAsyncRequest* arg,
- ExecuteAsyncResponse* result) override;
-
::grpc::Status WaitForExecution(::grpc::ServerContext* context,
const WaitForExecutionRequest* arg,
WaitForExecutionResponse* result) override;
@@ -86,38 +67,10 @@ class GRPCService : public grpc::XlaService::Service {
const ResetDeviceRequest* arg,
ResetDeviceResponse* result) override;
- ::grpc::Status IsConstant(::grpc::ServerContext* context,
- const IsConstantRequest* arg,
- IsConstantResponse* result) override;
-
- ::grpc::Status ComputeConstant(::grpc::ServerContext* context,
- const ComputeConstantRequest* arg,
- ComputeConstantResponse* result) override;
-
::grpc::Status GetShape(::grpc::ServerContext* context,
const GetShapeRequest* arg,
GetShapeResponse* result) override;
- ::grpc::Status GetComputationShape(
- ::grpc::ServerContext* context, const GetComputationShapeRequest* arg,
- GetComputationShapeResponse* result) override;
-
- ::grpc::Status GetLocalShape(::grpc::ServerContext* context,
- const GetLocalShapeRequest* arg,
- GetLocalShapeResponse* result) override;
-
- ::grpc::Status GetComputationStats(::grpc::ServerContext* context,
- const ComputationStatsRequest* arg,
- ComputationStatsResponse* result) override;
-
- ::grpc::Status SnapshotComputation(
- ::grpc::ServerContext* context, const SnapshotComputationRequest* arg,
- SnapshotComputationResponse* result) override;
-
- ::grpc::Status LoadComputationSnapshot(
- ::grpc::ServerContext* context, const LoadComputationSnapshotRequest* arg,
- LoadComputationSnapshotResponse* result) override;
-
private:
std::unique_ptr<::xla::Service> service_;
diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.cc b/tensorflow/compiler/xla/rpc/grpc_stub.cc
index 620ac6cec4..7b8ab158e1 100644
--- a/tensorflow/compiler/xla/rpc/grpc_stub.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_stub.cc
@@ -62,21 +62,6 @@ Status GRPCStub::ResetDevice(const ResetDeviceRequest* request,
});
}
-Status GRPCStub::LoadComputationSnapshot(
- const LoadComputationSnapshotRequest* request,
- LoadComputationSnapshotResponse* response) {
- return MakeRPC([this, request, response](::grpc::ClientContext* context) {
- return grpc_stub_->LoadComputationSnapshot(context, *request, response);
- });
-}
-
-Status GRPCStub::Execute(const ExecuteRequest* request,
- ExecuteResponse* response) {
- return MakeRPC([this, request, response](::grpc::ClientContext* context) {
- return grpc_stub_->Execute(context, *request, response);
- });
-}
-
Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request,
ExecuteResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
@@ -84,13 +69,6 @@ Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request,
});
}
-Status GRPCStub::ExecuteParallel(const ExecuteParallelRequest* request,
- ExecuteParallelResponse* response) {
- return MakeRPC([this, request, response](::grpc::ClientContext* context) {
- return grpc_stub_->ExecuteParallel(context, *request, response);
- });
-}
-
Status GRPCStub::ExecuteGraphParallel(
const ExecuteGraphParallelRequest* request,
ExecuteParallelResponse* response) {
@@ -99,13 +77,6 @@ Status GRPCStub::ExecuteGraphParallel(
});
}
-Status GRPCStub::ExecuteAsync(const ExecuteAsyncRequest* request,
- ExecuteAsyncResponse* response) {
- return MakeRPC([this, request, response](::grpc::ClientContext* context) {
- return grpc_stub_->ExecuteAsync(context, *request, response);
- });
-}
-
Status GRPCStub::WaitForExecution(const WaitForExecutionRequest* request,
WaitForExecutionResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
@@ -120,13 +91,6 @@ Status GRPCStub::DeconstructTuple(const DeconstructTupleRequest* request,
});
}
-Status GRPCStub::GetComputationStats(const ComputationStatsRequest* request,
- ComputationStatsResponse* response) {
- return MakeRPC([this, request, response](::grpc::ClientContext* context) {
- return grpc_stub_->GetComputationStats(context, *request, response);
- });
-}
-
Status GRPCStub::GetComputationGraphStats(
const ComputationGraphStatsRequest* request,
ComputationStatsResponse* response) {
@@ -135,13 +99,6 @@ Status GRPCStub::GetComputationGraphStats(
});
}
-Status GRPCStub::GetComputationShape(const GetComputationShapeRequest* request,
- GetComputationShapeResponse* response) {
- return MakeRPC([this, request, response](::grpc::ClientContext* context) {
- return grpc_stub_->GetComputationShape(context, *request, response);
- });
-}
-
Status GRPCStub::GetShape(const GetShapeRequest* request,
GetShapeResponse* response) {
return MakeRPC([this, request, response](::grpc::ClientContext* context) {
@@ -163,48 +120,6 @@ Status GRPCStub::CreateChannelHandle(const CreateChannelHandleRequest* request,
});
}
-// Methods used by ComputationBuilder.
-Status GRPCStub::Computation(const ComputationRequest* request,
- ComputationResponse* response) {
- return MakeRPC([this, request, response](::grpc::ClientContext* context) {
- return grpc_stub_->Computation(context, *request, response);
- });
-}
-
-Status GRPCStub::Op(const OpRequest* request, OpResponse* response) {
- return MakeRPC([this, request, response](::grpc::ClientContext* context) {
- return grpc_stub_->CreateOp(context, *request, response);
- });
-}
-
-Status GRPCStub::GetLocalShape(const GetLocalShapeRequest* request,
- GetLocalShapeResponse* response) {
- return MakeRPC([this, request, response](::grpc::ClientContext* context) {
- return grpc_stub_->GetLocalShape(context, *request, response);
- });
-}
-
-Status GRPCStub::SetReturnValue(const SetReturnValueRequest* request,
- SetReturnValueResponse* responses) {
- return MakeRPC([this, request, responses](::grpc::ClientContext* context) {
- return grpc_stub_->SetReturnValue(context, *request, responses);
- });
-}
-
-Status GRPCStub::IsConstant(const IsConstantRequest* request,
- IsConstantResponse* response) {
- return MakeRPC([this, request, response](::grpc::ClientContext* context) {
- return grpc_stub_->IsConstant(context, *request, response);
- });
-}
-
-Status GRPCStub::ComputeConstant(const ComputeConstantRequest* request,
- ComputeConstantResponse* response) {
- return MakeRPC([this, request, response](::grpc::ClientContext* context) {
- return grpc_stub_->ComputeConstant(context, *request, response);
- });
-}
-
Status GRPCStub::ComputeConstantGraph(
const ComputeConstantGraphRequest* request,
ComputeConstantResponse* response) {
@@ -213,14 +128,6 @@ Status GRPCStub::ComputeConstantGraph(
});
}
-// Methods used by Computation.
-Status GRPCStub::SnapshotComputation(const SnapshotComputationRequest* request,
- SnapshotComputationResponse* response) {
- return MakeRPC([this, request, response](::grpc::ClientContext* context) {
- return grpc_stub_->SnapshotComputation(context, *request, response);
- });
-}
-
// Methods used by GlobalData.
Status GRPCStub::Unregister(const UnregisterRequest* request,
UnregisterResponse* response) {
diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.h b/tensorflow/compiler/xla/rpc/grpc_stub.h
index 5906d45769..8dfcb76138 100644
--- a/tensorflow/compiler/xla/rpc/grpc_stub.h
+++ b/tensorflow/compiler/xla/rpc/grpc_stub.h
@@ -43,39 +43,21 @@ class GRPCStub : public ServiceInterface {
Status ResetDevice(const ResetDeviceRequest* arg,
ResetDeviceResponse* result) override;
- Status LoadComputationSnapshot(
- const LoadComputationSnapshotRequest* request,
- LoadComputationSnapshotResponse* result) override;
-
- Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override;
-
Status ExecuteGraph(const ExecuteGraphRequest* request,
ExecuteResponse* response) override;
- Status ExecuteParallel(const ExecuteParallelRequest* arg,
- ExecuteParallelResponse* result) override;
-
Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* request,
ExecuteParallelResponse* response) override;
- Status ExecuteAsync(const ExecuteAsyncRequest* arg,
- ExecuteAsyncResponse* result) override;
-
Status WaitForExecution(const WaitForExecutionRequest* arg,
WaitForExecutionResponse* result) override;
Status DeconstructTuple(const DeconstructTupleRequest* arg,
DeconstructTupleResponse* result) override;
- Status GetComputationStats(const ComputationStatsRequest* arg,
- ComputationStatsResponse* result) override;
-
Status GetComputationGraphStats(const ComputationGraphStatsRequest* request,
ComputationStatsResponse* response) override;
- Status GetComputationShape(const GetComputationShapeRequest* arg,
- GetComputationShapeResponse* result) override;
-
Status GetShape(const GetShapeRequest* arg,
GetShapeResponse* result) override;
@@ -85,30 +67,9 @@ class GRPCStub : public ServiceInterface {
Status CreateChannelHandle(const CreateChannelHandleRequest* arg,
CreateChannelHandleResponse* result) override;
- // Methods used by ComputationBuilder.
- Status Computation(const ComputationRequest* arg,
- ComputationResponse* result) override;
-
- Status Op(const OpRequest* arg, OpResponse* result) override;
- Status GetLocalShape(const GetLocalShapeRequest* arg,
- GetLocalShapeResponse* result) override;
-
- Status SetReturnValue(const SetReturnValueRequest* arg,
- SetReturnValueResponse* results) override;
-
- Status IsConstant(const IsConstantRequest* arg,
- IsConstantResponse* result) override;
-
- Status ComputeConstant(const ComputeConstantRequest* arg,
- ComputeConstantResponse* result) override;
-
Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
ComputeConstantResponse* result) override;
- // Methods used by Computation.
- Status SnapshotComputation(const SnapshotComputationRequest* ag,
- SnapshotComputationResponse* result) override;
-
// Methods used by GlobalData.
Status Unregister(const UnregisterRequest* arg,
UnregisterResponse* result) override;
diff --git a/tensorflow/compiler/xla/rpc/xla_service.proto b/tensorflow/compiler/xla/rpc/xla_service.proto
index c47164ee1b..92eb19ec0f 100644
--- a/tensorflow/compiler/xla/rpc/xla_service.proto
+++ b/tensorflow/compiler/xla/rpc/xla_service.proto
@@ -75,19 +75,7 @@ service XlaService {
rpc GetShape(GetShapeRequest) returns (GetShapeResponse) {
}
- // Requests the program shape of the referenced computation.
- rpc GetComputationShape(GetComputationShapeRequest)
- returns (GetComputationShapeResponse) {
- }
-
// Requests the statistics of the given computation.
- rpc GetComputationStats(ComputationStatsRequest)
- returns (ComputationStatsResponse) {
- }
-
- // Requests the statistics of the given computation.
- //
- // TODO(b/74197823): This is a part of a NOT YET ready refactor.
rpc GetComputationGraphStats(ComputationGraphStatsRequest)
returns (ComputationStatsResponse) {
}
@@ -121,15 +109,6 @@ service XlaService {
rpc ResetDevice(ResetDeviceRequest) returns (ResetDeviceResponse) {
}
- // Tests if an expression is a compile-time constant.
- rpc IsConstant(IsConstantRequest) returns (IsConstantResponse) {
- }
-
- // Computes the value of a constant expression.
- rpc ComputeConstant(ComputeConstantRequest)
- returns (ComputeConstantResponse) {
- }
-
// Computes the value of a constant expression. The request contains the
// computation graph for the constant expression.
rpc ComputeConstantGraph(ComputeConstantGraphRequest)
@@ -165,20 +144,6 @@ service XlaService {
rpc SetReturnValue(SetReturnValueRequest) returns (SetReturnValueResponse) {
}
- // Computation creates a new computation with the given name.
- // A unique ComputationHandle is returned.
- rpc Computation(ComputationRequest) returns (ComputationResponse) {
- }
-
- // Adds a new op to a computation.
- rpc CreateOp(OpRequest) returns (OpResponse) {
- }
-
- // Invokes the provided computation with the provided global data passed as
- // immutable arguments. Returns global data output and execution timing.
- rpc Execute(ExecuteRequest) returns (ExecuteResponse) {
- }
-
// Invokes the provided computation with the provided global data passed as
// immutable arguments. The request contains the whole computation graph.
// Returns global data output and execution timing.
@@ -188,38 +153,13 @@ service XlaService {
// Invokes the provided list of computations in parallel with the provided
// global data for each computation. Returns a list of global data output and
// execution timing.
- rpc ExecuteParallel(ExecuteParallelRequest)
- returns (ExecuteParallelResponse) {
- }
-
- // Invokes the provided list of computations in parallel with the provided
- // global data for each computation. Returns a list of global data output and
- // execution timing.
- //
- // TODO(b/74197823): This is a part of a NOT YET ready refactor.
rpc ExecuteGraphParallel(ExecuteGraphParallelRequest)
returns (ExecuteParallelResponse) {
}
- // Invokes the provided computation with the provided global data passed as
- // immutable arguments. Returns a handle to the execution.
- rpc ExecuteAsync(ExecuteAsyncRequest) returns (ExecuteAsyncResponse) {
- }
-
// Waits until the given execution (asynchronously launched) is complete, and
// returns the global data output.
rpc WaitForExecution(WaitForExecutionRequest)
returns (WaitForExecutionResponse) {
}
-
- // Serializes a computation to proto form, so it can be loaded via
- // LoadComputationSnapshot.
- rpc SnapshotComputation(SnapshotComputationRequest)
- returns (SnapshotComputationResponse) {
- }
-
- // Loads a computation from a captured snapshot.
- rpc LoadComputationSnapshot(LoadComputationSnapshotRequest)
- returns (LoadComputationSnapshotResponse) {
- }
}
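With the session-handle RPCs gone, every surviving entry point on ServiceInterface takes a request/response proto pair and returns a Status, as the GRPCStub declarations above show. A minimal sketch of driving one of the remaining graph-based RPCs (request population and result handling are assumed, not part of this change):

    // Sketch only: evaluate a constant expression via the graph-based RPC.
    // `stub` may be any ServiceInterface implementation, e.g. a GRPCStub.
    xla::Status EvaluateConstant(xla::ServiceInterface* stub,
                                 const xla::ComputeConstantGraphRequest& request) {
      xla::ComputeConstantResponse response;
      TF_RETURN_IF_ERROR(stub->ComputeConstantGraph(&request, &response));
      // On success, `response` carries the computed value.
      return xla::Status::OK();
    }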
diff --git a/tensorflow/compiler/xla/scanner.cc b/tensorflow/compiler/xla/scanner.cc
deleted file mode 100644
index f23a1417fc..0000000000
--- a/tensorflow/compiler/xla/scanner.cc
+++ /dev/null
@@ -1,197 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/scanner.h"
-
-#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-
-namespace xla {
-namespace {
-
-// Returns true if c can be the first character in an identifier.
-bool IsIdentifierFirst(int c) { return std::isalpha(c) || c == '_'; }
-
-// Returns true if c can be the non-first character in an identifier.
-bool IsIdentifierLater(int c) { return std::isalnum(c) || c == '_'; }
-
-// Returns true if str is an identifier.
-bool IsIdentifier(tensorflow::StringPiece str) {
- if (str.empty() || !IsIdentifierFirst(str[0])) {
- return false;
- }
- for (int64 i = 1; i < str.size(); ++i) {
- if (!IsIdentifierLater(str[i])) {
- return false;
- }
- }
- return true;
-}
-
-} // namespace
-
-Scanner::Scanner(tensorflow::StringPiece input) : input_(input), position_(0) {}
-
-bool Scanner::ok() const { return status().ok(); }
-
-const Status& Scanner::status() const { return status_; }
-
-bool Scanner::Match(tensorflow::StringPiece match) {
- SkipWhitespace();
- if (ok() && position_ + match.size() <= input_.size() &&
- std::equal(match.begin(), match.end(), input_.begin() + position_)) {
- SkipChars(match.size());
-
- VLOG(10) << "Matched \"" << match << "\"";
- return true;
- } else {
- return false;
- }
-}
-
-void Scanner::Expect(tensorflow::StringPiece expect) {
- if (!Match(expect)) {
- SetError(tensorflow::strings::StrCat("Expected \"", expect, "\"."));
- }
-}
-
-bool Scanner::MatchReadIdentifier(string* identifier) {
- SkipWhitespace();
- if (!IsIdentifierFirst(PeekChar())) {
- return false;
- }
- identifier->clear();
- do {
- *identifier += ReadChar();
- } while (IsIdentifierLater(PeekChar()));
-
- VLOG(10) << "Read identifier " << identifier;
- CHECK(IsIdentifier(*identifier));
- return true;
-}
-
-string Scanner::ReadIdentifier() {
- string identifier;
- if (!MatchReadIdentifier(&identifier)) {
- SetError("Expected identifier.");
- }
- return identifier;
-}
-
-void Scanner::ExpectIdentifier(tensorflow::StringPiece expect) {
- CHECK(IsIdentifier(expect));
-
- string identifier;
- if (!MatchReadIdentifier(&identifier)) {
- SetError(tensorflow::strings::StrCat("Expected identifier ", expect, "."));
- }
- if (identifier != expect) {
- SetError(tensorflow::strings::StrCat("Expected identifier ", expect,
- ", but got ", identifier, "."));
- }
-}
-
-// Matches the end of the input, also known as End Of File (EOF).
-bool Scanner::MatchEof() {
- SkipWhitespace();
- return PeekChar() == EOF;
-}
-
-void Scanner::ExpectEof() {
- if (!MatchEof()) {
- SetError("Expected end of input.");
- }
-}
-
-// Reads a vector of the format "(1, 2, 3)".
-std::vector<int64> Scanner::ReadIntVector() {
- std::vector<int64> ints;
- Expect("(");
- if (!Match(")") && ok()) {
- ints.push_back(ReadInt());
- while (Match(",")) {
- ints.push_back(ReadInt());
- }
- Expect(")");
- }
-
- VLOG(10) << "Read int vector with " << ints.size() << " elements.";
- return ints;
-}
-
-int64 Scanner::ReadInt() {
- bool negative = Match("-");
- if (!PeekDigit()) {
- SetError("Expected integer.");
- return 0;
- }
-
- int64 integer = 0;
- do {
- integer = (ReadChar() - '0') + integer * 10;
- } while (PeekDigit());
- integer = negative ? -integer : integer;
-
- VLOG(10) << "Read integer " << integer;
- return integer;
-}
-
-void Scanner::SkipWhitespace() {
- while (PeekWhitespace()) {
- SkipChars(1);
- }
-}
-
-int Scanner::ReadChar() {
- int c = PeekChar();
- SkipChars(1);
-
- VLOG(20) << "Read char " << c;
- return c;
-}
-
-int Scanner::PeekChar() const {
- return ok() && position_ < input_.size() ? input_[position_] : EOF;
-}
-
-bool Scanner::PeekDigit() const {
- // Do not use std::isdigit since it depends on the locale and we do not
- // handle any digits beyond 0-9.
- const char c = PeekChar();
- return '0' <= c && c <= '9';
-}
-
-bool Scanner::PeekAlnum() const { return std::isalnum(PeekChar()); }
-
-bool Scanner::PeekWhitespace() const { return std::isspace(PeekChar()); }
-
-void Scanner::SkipChars(int64 count) {
- CHECK_GE(count, 0);
- position_ += count;
-}
-
-void Scanner::SetError(string error_message) {
- // Only the first error is recorded since any later errors will likely be a
- // consequence of the first error.
- if (ok()) {
- status_ = InvalidArgumentStrCat(std::move(error_message));
- position_ = input_.size();
- VLOG(10) << "Failed scanner with error " << status_.ToString();
- } else {
- VLOG(10) << "Error on already failed scanner is " << error_message;
- }
-}
-
-} // namespace xla
diff --git a/tensorflow/compiler/xla/scanner.h b/tensorflow/compiler/xla/scanner.h
deleted file mode 100644
index 86b04ae7f9..0000000000
--- a/tensorflow/compiler/xla/scanner.h
+++ /dev/null
@@ -1,102 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SCANNER_H_
-#define TENSORFLOW_COMPILER_XLA_SCANNER_H_
-
-#include "tensorflow/compiler/xla/status.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/stringpiece.h"
-
-namespace xla {
-
-// Simple class for parsing data. The concepts for the interface are:
-//
-// Match(x): Returns true if x is next in the input and in that case skips
-// past it. Otherwise returns false.
-//
-// Expect(x): As Match(x), but requires x to be next in the input.
-//
-// MatchReadX(x): Returns true if an X is next in the input and in that case
-// skips past it and assigns it to x. Otherwise returns false.
-//
-// ReadX(): As MatchReadX(), but requires an X to be next in the input and
-// returns it.
-//
-// PeekX(): Returns true if an X is next in the input and does not skip
-// past it either way.
-//
-// All of these, except those that work on individual characters, skip
-// whitespace.
-//
-// If a requirement is not met, the error is available in status(). A Scanner
-// with a failed status() will behave as though the rest of the input is EOF and
-// will not record further errors after that point.
-class Scanner {
- public:
- Scanner(tensorflow::StringPiece input);
-
- bool ok() const;
- const Status& status() const;
-
- bool Match(tensorflow::StringPiece match);
- void Expect(tensorflow::StringPiece expect);
-
- // Match-reads an identifier. An identifier starts with an alphabetic
- // character or an underscore followed by any number of characters that are
- // each alphanumeric or underscore.
- bool MatchReadIdentifier(string* identifier);
-
- string ReadIdentifier();
-
- void ExpectIdentifier(tensorflow::StringPiece expect);
-
- // Matches the end of the input, also known as End Of File (EOF).
- bool MatchEof();
- void ExpectEof();
-
- // Reads a vector of the format "(1, 4, 5)".
- std::vector<int64> ReadIntVector();
-
- // Reads an integer. Can start with a minus but not a plus.
- int64 ReadInt();
-
- // Keeps skipping until encountering a non-whitespace character.
- void SkipWhitespace();
-
- // *** Below here are character-level methods that do not skip whitespace.
-
- int ReadChar();
- int PeekChar() const;
- bool PeekDigit() const;
- bool PeekAlnum() const;
- bool PeekWhitespace() const;
-
- // Skip past the next count characters.
- void SkipChars(int64 count);
-
- private:
- // Sets a failed status. The input is in effect replaced with EOF after
- // this. Only the first error is recorded.
- void SetError(string error_message);
-
- const tensorflow::StringPiece input_;
- int64 position_;
- Status status_;
-};
-
-} // namespace xla
-
-#endif // TENSORFLOW_COMPILER_XLA_SCANNER_H_
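For reference, this is roughly how the removed Scanner was driven, reconstructed from the interface comment above and the tests below (the input string is illustrative):

    xla::Scanner scanner("(1, 2, 3) done");
    std::vector<xla::int64> dims = scanner.ReadIntVector();  // {1, 2, 3}
    scanner.ExpectIdentifier("done");  // records an error on mismatch
    CHECK(scanner.MatchEof());
    CHECK(scanner.ok());  // status() holds the first recorded error, if any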
diff --git a/tensorflow/compiler/xla/scanner_test.cc b/tensorflow/compiler/xla/scanner_test.cc
deleted file mode 100644
index 10cd0c6a04..0000000000
--- a/tensorflow/compiler/xla/scanner_test.cc
+++ /dev/null
@@ -1,124 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// TODO(b/80179519): Fix open source build for real.
-#if 0
-#include "tensorflow/compiler/xla/scanner.h"
-
-#include <string>
-
-#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/platform/env.h"
-
-namespace xla {
-namespace {
-
-TEST(Scanner, Empty) {
- Scanner scanner("");
-
- EXPECT_EQ(scanner.PeekChar(), EOF);
- EXPECT_TRUE(scanner.MatchEof());
- EXPECT_TRUE(scanner.Match(""));
- EXPECT_FALSE(scanner.Match("1"));
- EXPECT_TRUE(scanner.ok());
-}
-
-TEST(Scanner, Prefix) {
- Scanner scanner("1234 5");
- EXPECT_FALSE(scanner.MatchEof());
- EXPECT_TRUE(scanner.Match("12"));
- EXPECT_TRUE(scanner.Match("34 "));
- EXPECT_FALSE(scanner.MatchEof());
- EXPECT_FALSE(scanner.Match("5 "));
- EXPECT_TRUE(scanner.Match("5"));
- EXPECT_TRUE(scanner.MatchEof());
-}
-
-TEST(Scanner, Whitespace) {
- Scanner scanner(" \t\n\r 1\t2\n\n");
-
- EXPECT_FALSE(scanner.Match(" "));
- EXPECT_TRUE(scanner.Match("1"));
- EXPECT_TRUE(scanner.Match("2"));
- EXPECT_TRUE(scanner.MatchEof());
- EXPECT_TRUE(scanner.ok());
-}
-
-TEST(Scanner, Fail) {
- Scanner scanner("153 4q");
-
- scanner.Expect("5");
- EXPECT_FALSE(scanner.ok());
- EXPECT_FALSE(scanner.status().ok());
-
- EXPECT_TRUE(scanner.MatchEof());
-}
-
-TEST(Scanner, Identifier) {
- Scanner scanner("1 q1 _1_ _1a= qqb");
-
- string identifier = "foo";
- EXPECT_FALSE(scanner.MatchReadIdentifier(&identifier));
- EXPECT_EQ(identifier, "foo");
- scanner.Match("1");
-
- EXPECT_TRUE(scanner.MatchReadIdentifier(&identifier));
- EXPECT_EQ(identifier, "q1");
-
- scanner.ExpectIdentifier("_1_");
- EXPECT_TRUE(scanner.ok());
-
- scanner.ExpectIdentifier("_1a");
- EXPECT_TRUE(scanner.ok());
-
- // The = after _1a is not included in the identifier.
- scanner.Expect("=");
-
- // The expected identifier matches a prefix but is not the full identifier in
- // the input.
- EXPECT_TRUE(scanner.ok());
- scanner.ExpectIdentifier("qq");
- EXPECT_FALSE(scanner.ok());
-}
-
-TEST(Scanner, Int) {
- Scanner scanner("1_2 3% -1 124345 -363 0 -0");
- EXPECT_EQ(1, scanner.ReadInt());
- EXPECT_TRUE(scanner.Match("_"));
- EXPECT_EQ(2, scanner.ReadInt());
- EXPECT_EQ(3, scanner.ReadInt());
- EXPECT_TRUE(scanner.Match("%"));
- EXPECT_EQ(-1, scanner.ReadInt());
- EXPECT_EQ(124345, scanner.ReadInt());
- EXPECT_EQ(-363, scanner.ReadInt());
- EXPECT_EQ(0, scanner.ReadInt());
- EXPECT_EQ(0, scanner.ReadInt());
- EXPECT_TRUE(scanner.MatchEof());
-}
-
-TEST(Scanner, IntVector) {
- Scanner scanner("()(0) (-1,2) ( 3 , 4 )");
- EXPECT_THAT(scanner.ReadIntVector(), testing::IsEmpty());
- EXPECT_THAT(scanner.ReadIntVector(), testing::ElementsAre(0));
- EXPECT_THAT(scanner.ReadIntVector(), testing::ElementsAre(-1, 2));
- EXPECT_THAT(scanner.ReadIntVector(), testing::ElementsAre(3, 4));
- EXPECT_TRUE(scanner.MatchEof());
- EXPECT_TRUE(scanner.ok());
-}
-
-} // namespace
-} // namespace xla
-#endif
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 4d653a0196..2b14b63ea8 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -309,6 +309,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:human_readable_json",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
@@ -425,6 +426,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
@@ -548,45 +550,6 @@ tf_cc_test(
)
cc_library(
- name = "user_computation",
- srcs = ["user_computation.cc"],
- hdrs = ["user_computation.h"],
- deps = [
- ":hlo",
- ":session_proto",
- ":shape_inference",
- ":versioned_computation_handle",
- "//tensorflow/compiler/xla:literal_util",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla:xla_proto",
- "//tensorflow/core:lib",
- ],
-)
-
-tf_cc_test(
- name = "user_computation_test",
- srcs = ["user_computation_test.cc"],
- deps = [
- ":hlo_matchers",
- ":user_computation",
- "//tensorflow/compiler/xla:literal_util",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla:test",
- "//tensorflow/compiler/xla:test_helpers",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/service:hlo",
- "//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/core:test",
- ],
-)
-
-cc_library(
name = "platform_util",
srcs = ["platform_util.cc"],
hdrs = ["platform_util.h"],
@@ -634,7 +597,6 @@ cc_library(
":compilation_cache",
":compiler",
":computation_layout",
- ":computation_tracker",
":device_memory_allocator",
":executable",
":execution_tracker",
@@ -648,7 +610,6 @@ cc_library(
":session_proto",
":source_map_util",
":transfer_manager",
- ":user_computation",
":versioned_computation_handle",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:execution_options_util",
@@ -676,7 +637,6 @@ cc_library(
":backend",
":compiler",
":computation_layout",
- ":computation_tracker",
":device_memory_allocator",
":executable",
":hlo",
@@ -685,7 +645,6 @@ cc_library(
":platform_util",
":service",
":shaped_buffer",
- ":user_computation",
":versioned_computation_handle",
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:shape_layout",
@@ -710,7 +669,6 @@ cc_library(
":backend",
":compiler",
":computation_layout",
- ":computation_tracker",
":platform_util",
":service",
"//tensorflow/compiler/xla:status_macros",
@@ -906,32 +864,12 @@ cc_library(
)
cc_library(
- name = "computation_tracker",
- srcs = ["computation_tracker.cc"],
- hdrs = ["computation_tracker.h"],
- deps = [
- ":hlo",
- ":hlo_module_config",
- ":session_proto",
- ":user_computation",
- ":versioned_computation_handle",
- "//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/core:lib",
- ],
-)
-
-cc_library(
name = "channel_tracker",
srcs = ["channel_tracker.cc"],
hdrs = ["channel_tracker.h"],
deps = [
":hlo",
":session_proto",
- ":user_computation",
":versioned_computation_handle",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
@@ -1038,7 +976,6 @@ tf_cc_test(
":buffer_assignment",
":buffer_value",
":call_graph",
- ":computation_tracker",
":copy_insertion",
":cpu_plugin",
":flatten_call_graph",
@@ -1710,13 +1647,11 @@ tf_cc_test(
name = "hlo_cost_analysis_test",
srcs = ["hlo_cost_analysis_test.cc"],
deps = [
- ":computation_tracker",
":cpu_plugin",
":hlo",
":hlo_cost_analysis",
":local_service",
":service",
- ":user_computation",
":versioned_computation_handle",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@@ -2920,6 +2855,7 @@ tf_cc_test(
deps = [
":while_util",
"//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index c65c91e8e0..e1a45e453e 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -233,10 +233,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
HloInstruction* operand, HloInstruction* max,
HloInstruction* max_operand);
- // A Reshape or Broadcast that feeds an element-wise operation with a unique
- // non-scalar operand can sink to after the operation.
- StatusOr<bool> TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(
- HloInstruction* reshape_or_broadcast);
+ // A Broadcast that feeds an element-wise operation with a unique non-scalar
+ // operand can sink to after the operation.
+ StatusOr<bool> TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
+ HloInstruction* broadcast);
// Replaces the existing HLO instruction old_instruction, with
// new_instruction, and marks the optimizer status as changed.
@@ -1305,7 +1305,7 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
// broadcast after the unary element-wise operation.
TF_ASSIGN_OR_RETURN(
bool sink_succeeded,
- TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(broadcast));
+ TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(broadcast));
changed_ |= sink_succeeded;
if (sink_succeeded) {
return Status::OK();
@@ -1557,15 +1557,16 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
return Status::OK();
}
-StatusOr<bool> AlgebraicSimplifierVisitor::
- TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(
- HloInstruction* reshape_or_broadcast) {
+StatusOr<bool>
+AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
+ HloInstruction* broadcast) {
+ TF_RET_CHECK(broadcast->opcode() == HloOpcode::kBroadcast);
bool changed = false;
- if (ShapeUtil::IsScalar(reshape_or_broadcast->shape())) {
+ if (ShapeUtil::IsScalar(broadcast->shape())) {
return false;
}
- HloInstruction* operand = reshape_or_broadcast->mutable_operand(0);
- for (HloInstruction* user : reshape_or_broadcast->users()) {
+ HloInstruction* operand = broadcast->mutable_operand(0);
+ for (HloInstruction* user : broadcast->users()) {
if (user->user_count() == 0 && user != computation_->root_instruction()) {
continue;
}
@@ -1583,55 +1584,50 @@ StatusOr<bool> AlgebraicSimplifierVisitor::
continue;
}
- int64 reshape_or_broadcast_operand_index = -1;
// Find the unique non-scalar operand or continue if there isn't one.
- int64 scalar_count = 0;
- for (int64 i = 0; i < user->operand_count(); ++i) {
- if (ShapeUtil::IsScalar(user->operand(i)->shape())) {
- ++scalar_count;
- } else {
- reshape_or_broadcast_operand_index = i;
+ int64 scalar_broadcast_count = 0;
+ int64 broadcast_use_count = 0;
+ for (HloInstruction* user_operand : user->operands()) {
+ if (user_operand->opcode() == HloOpcode::kBroadcast &&
+ ShapeUtil::IsScalar(user_operand->operand(0)->shape())) {
+ ++scalar_broadcast_count;
+ } else if (broadcast == user_operand) {
+ ++broadcast_use_count;
}
}
- if (scalar_count != user->operand_count() - 1) {
+ if (scalar_broadcast_count + broadcast_use_count != user->operand_count()) {
continue;
}
- VLOG(4) << "Sinking reshape or broadcast after user:";
- VLOG(4) << " old reshape/broadcast: " << reshape_or_broadcast->ToString();
+ std::vector<HloInstruction*> new_operands;
+ new_operands.reserve(user->operand_count());
+
+ for (HloInstruction* user_operand : user->operands()) {
+ if (user_operand->opcode() == HloOpcode::kBroadcast &&
+ ShapeUtil::IsScalar(user_operand->operand(0)->shape())) {
+ new_operands.push_back(
+ computation_->AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::ChangeElementType(
+ operand->shape(), user_operand->shape().element_type()),
+ user_operand->mutable_operand(0), {})));
+ } else {
+ CHECK_EQ(broadcast, user_operand);
+ new_operands.push_back(operand);
+ }
+ }
+ VLOG(4) << "Sinking broadcast after user:";
+ VLOG(4) << " old broadcast: " << broadcast->ToString();
VLOG(4) << " old user: " << user->ToString();
- CHECK_EQ(user->operand(reshape_or_broadcast_operand_index),
- reshape_or_broadcast);
- auto new_user_operands = user->operands();
- new_user_operands[reshape_or_broadcast_operand_index] = operand;
- auto new_user = computation_->AddInstruction(user->CloneWithNewOperands(
- ShapeUtil::MakeShapeWithLayout(
- user->shape().element_type(),
- AsInt64Slice(operand->shape().dimensions()),
- LayoutUtil::MinorToMajor(operand->shape())),
- new_user_operands));
+ HloInstruction* new_user =
+ computation_->AddInstruction(user->CloneWithNewOperands(
+ ShapeUtil::ChangeElementType(operand->shape(),
+ user->shape().element_type()),
+ new_operands));
VLOG(4) << " new user: " << new_user->ToString();
- HloInstruction* new_reshape_or_broadcast = nullptr;
- if (reshape_or_broadcast->opcode() == HloOpcode::kReshape) {
- new_reshape_or_broadcast =
- computation_->AddInstruction(HloInstruction::CreateReshape(
- ShapeUtil::MakeShapeWithLayout(
- user->shape().element_type(),
- AsInt64Slice(reshape_or_broadcast->shape().dimensions()),
- LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())),
- new_user));
- } else {
- TF_RET_CHECK(reshape_or_broadcast->opcode() == HloOpcode::kBroadcast);
- new_reshape_or_broadcast =
- computation_->AddInstruction(HloInstruction::CreateBroadcast(
- ShapeUtil::MakeShapeWithLayout(
- user->shape().element_type(),
- AsInt64Slice(reshape_or_broadcast->shape().dimensions()),
- LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())),
- new_user, reshape_or_broadcast->dimensions()));
- }
- VLOG(4) << " new reshape/broadcast: "
- << new_reshape_or_broadcast->ToString();
- TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_reshape_or_broadcast));
+ HloInstruction* new_broadcast =
+ computation_->AddInstruction(HloInstruction::CreateBroadcast(
+ user->shape(), new_user, broadcast->dimensions()));
+ VLOG(4) << " new broadcast: " << new_broadcast->ToString();
+ TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast));
changed = true;
}
return changed;
@@ -1674,16 +1670,6 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
}
}
- // A Reshape that feeds a unary element-wise operation can sink the
- // reshape after the unary element-wise operation.
- TF_ASSIGN_OR_RETURN(
- bool sink_succeeded,
- TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(reshape));
- changed_ |= sink_succeeded;
- if (sink_succeeded) {
- return Status::OK();
- }
-
// Make this a bitcast if possible.
if (is_layout_sensitive_ &&
ReshapeIsBitcast(reshape, valid_bitcast_callback_)) {
@@ -1788,6 +1774,11 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
new_reduce_dimensions, function));
}
+ if (ShapeUtil::ElementsIn(reduce->shape()) ==
+ ShapeUtil::ElementsIn(arg->shape())) {
+ return ReplaceWithNewInstruction(
+ reduce, HloInstruction::CreateReshape(reduce->shape(), arg));
+ }
// A reshape that collapses multiple dimensions into a dimension being
// reduced can just reduce all of those dimensions instead of doing a
// collapsing reshape before a reduction.
@@ -1832,15 +1823,6 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
new_reduce_dimensions, function));
}
}
- if (ShapeUtil::ElementsIn(reduce->shape()) ==
- ShapeUtil::ElementsIn(arg->shape()) ||
- ShapeUtil::HasZeroElements(arg->shape())) {
- auto reshape = computation_->AddInstruction(
- HloInstruction::CreateReshape(reduce->shape(), arg));
- return ReplaceWithNewInstruction(
- reduce, HloInstruction::CreateMap(reduce->shape(),
- {init_value, reshape}, function));
- }
return Status::OK();
}
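The narrowed transformation above no longer touches reshapes; in HLO terms it now rewrites patterns of the following shape, where %p is the unique non-scalar operand and %s is a scalar fed through its own broadcast (shapes illustrative):

    before: %out = f32[2,4] add(f32[2,4] broadcast(%p), f32[2,4] broadcast(%s))
    after:  %inner = f32[4] add(%p, f32[4] broadcast(%s))
            %out   = f32[2,4] broadcast(%inner)

so the element-wise op runs on the smaller pre-broadcast shape and a single broadcast is re-applied to the result. Separately, a reduce whose output has exactly as many elements as its input is now replaced by a plain reshape, dropping the earlier map-over-reshape variant.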
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index d5f0afe960..cda157f9fa 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -1351,32 +1351,6 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape));
}
-TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) {
- HloComputation::Builder builder(TestName());
- HloInstruction* param =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), "param"));
- HloInstruction* movable_reshape =
- builder.AddInstruction(HloInstruction::CreateReshape(
- ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), param));
- HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
- builder.AddInstruction(
- HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}),
- HloOpcode::kMaximum, movable_reshape, zero));
- auto computation = module().AddEntryComputation(builder.Build());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Maximum(op::Reshape(param), zero));
-
- AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
- bitcasting_callback());
-
- simplifier.Run(&module()).ValueOrDie();
- EXPECT_THAT(computation->root_instruction(),
- op::Reshape(op::Maximum(param, zero)));
-}
-
// Regression test for a bug in the reshape sinking transformation, where
// moving a reshape to a scalar led to a crash.
TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) {
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index 96e02b82b9..598718c72c 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -98,21 +98,67 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
return *scalar_add_computation;
}
- // Current HloComputation instance the BatchNormExpander is
- // traversing.
- HloComputation* computation_;
+ // TODO(b/80534766): Remove maps after performance issues with scalar
+ // broadcasts are resolved on all backends.
+ HloComputation* GetOrCreateScalarRsqrtComputation(
+ PrimitiveType primitive_type) {
+ HloComputation** scalar_rsqrt_computation =
+ &scalar_rsqrt_computations_[primitive_type];
+ if (*scalar_rsqrt_computation) {
+ return *scalar_rsqrt_computation;
+ }
- bool rewrite_training_op_;
- bool rewrite_inference_op_;
- bool rewrite_grad_op_;
- bool use_fusion_;
+ HloComputation::Builder b("scalar_rsqrt_computation");
+ Shape shape = ShapeUtil::MakeShape(primitive_type, {});
+ auto scalar_lhs = b.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
+ auto scalar_rhs = b.AddInstruction(HloInstruction::CreateConvert(
+ shape, b.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR0<float>(-0.5f)))));
+ auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kPower, scalar_lhs, scalar_rhs));
+ *scalar_rsqrt_computation =
+ computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
+ return *scalar_rsqrt_computation;
+ }
- // Whether rewrite has occurred.
- bool changed_ = false;
+ std::unique_ptr<HloInstruction> Rsqrt(HloInstruction* operand) {
+ return HloInstruction::CreateMap(
+ operand->shape(), {operand},
+ GetOrCreateScalarRsqrtComputation(operand->shape().element_type()));
+ }
- // Cached computations for adding two scalars.
- tensorflow::gtl::FlatMap<PrimitiveType, HloComputation*>
- scalar_add_computations_;
+ HloComputation* GetOrCreateScalarMeanComputation(PrimitiveType primitive_type,
+ int64 element_count) {
+ HloComputation** scalar_mean_computation =
+ &scalar_mean_computations_[std::pair<PrimitiveType, int64>(
+ primitive_type, element_count)];
+ if (*scalar_mean_computation) {
+ return *scalar_mean_computation;
+ }
+
+ HloComputation::Builder b("scalar_mean_computation");
+ Shape shape = ShapeUtil::MakeShape(primitive_type, {});
+ auto scalar_lhs = b.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
+ auto scalar_rhs = b.AddInstruction(HloInstruction::CreateConvert(
+ shape, b.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(
+ 1.0f / static_cast<float>(element_count))))));
+ auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
+ shape, HloOpcode::kMultiply, scalar_lhs, scalar_rhs));
+ *scalar_mean_computation =
+ computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
+ return *scalar_mean_computation;
+ }
+
+ std::unique_ptr<HloInstruction> Mean(int64 element_count,
+ HloInstruction* operand) {
+ return HloInstruction::CreateMap(
+ operand->shape(), {operand},
+ GetOrCreateScalarMeanComputation(operand->shape().element_type(),
+ element_count));
+ }
// Replaces the existing HLO instruction old_instruction, with
// new_instruction, and marks the optimizer status as changed.
@@ -136,6 +182,25 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
changed_ = true;
return Status::OK();
}
+ // Current HloComputation instance the BatchNormExpander is
+ // traversing.
+ HloComputation* computation_;
+
+ bool rewrite_training_op_;
+ bool rewrite_inference_op_;
+ bool rewrite_grad_op_;
+ bool use_fusion_;
+
+ // Whether rewrite has occurred.
+ bool changed_ = false;
+
+ // Cached scalar computations for the add, rsqrt, and mean maps.
+ tensorflow::gtl::FlatMap<PrimitiveType, HloComputation*>
+ scalar_add_computations_;
+ tensorflow::gtl::FlatMap<PrimitiveType, HloComputation*>
+ scalar_rsqrt_computations_;
+ tensorflow::gtl::FlatMap<std::pair<PrimitiveType, int64>, HloComputation*>
+ scalar_mean_computations_;
};
} // namespace
@@ -167,6 +232,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
added_instructions.push_back(added_inst);
return added_inst;
};
+ auto add_binary = [&](const Shape& shape, const HloOpcode opcode,
+ HloInstruction* a, HloInstruction* b) {
+ return add(HloInstruction::CreateBinary(shape, opcode, a, b));
+ };
int64 instruction_count_before = computation_->instruction_count();
// Expand batch norm training into smaller HLO ops.
@@ -176,12 +245,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
int64 feature_index = batch_norm->feature_index();
const int64 feature_count = operand_shape.dimensions(feature_index);
const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape);
- auto elements_per_feature_literal =
- Literal::CreateR0<float>(size_in_elements / feature_count);
- TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
- elements_per_feature_literal->Convert(ptype));
- auto elements_per_feature = add(
- HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));
+ int64 elements_per_feature_int64 = size_in_elements / feature_count;
HloInstruction* scale = batch_norm->mutable_operand(1);
HloInstruction* offset = batch_norm->mutable_operand(2);
@@ -193,8 +257,9 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
- auto epsilon =
- add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
+ auto epsilon = add(HloInstruction::CreateBroadcast(
+ operand_shape,
+ add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {}));
std::vector<int64> dimensions_without_feature;
for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
@@ -213,8 +278,8 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
GetOrCreateScalarAddComputation(ptype);
// X^2.
- auto operand_squared = add(HloInstruction::CreateBinary(
- operand_shape, HloOpcode::kMultiply, operand, operand));
+ auto operand_squared =
+ add_binary(operand_shape, HloOpcode::kMultiply, operand, operand);
// Sum[X].
auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero,
dimensions_without_feature,
@@ -240,56 +305,47 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
}
// E[X].
- auto mean = add(HloInstruction::CreateBinary(
- feature_shape, HloOpcode::kDivide, sum, elements_per_feature));
+ auto mean = add(Mean(elements_per_feature_int64, sum));
auto mean_broadcasted = add(
HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));
// E[X^2].
- auto square_mean = add(HloInstruction::CreateBinary(
- feature_shape, HloOpcode::kDivide, squared_sum, elements_per_feature));
+ auto square_mean = add(Mean(elements_per_feature_int64, squared_sum));
// E^2[X].
- auto mean_square = add(HloInstruction::CreateBinary(
- feature_shape, HloOpcode::kMultiply, mean, mean));
+ auto mean_square =
+ add_binary(feature_shape, HloOpcode::kMultiply, mean, mean);
// Var[X].
- auto var = add(HloInstruction::CreateBinary(
- feature_shape, HloOpcode::kSubtract, square_mean, mean_square));
+ auto var =
+ add_binary(feature_shape, HloOpcode::kSubtract, square_mean, mean_square);
auto var_broadcasted =
add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));
// Var[X] + epsilon.
- auto var_add_epsilon = add(HloInstruction::CreateBinary(
- operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));
-
- auto neg_half_literal = Literal::CreateR0(-0.5f);
- TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
- auto neg_half =
- add(HloInstruction::CreateConstant(std::move(neg_half_literal)));
+ auto var_add_epsilon =
+ add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon);
// 1 / Sqrt[Var[X] + epsilon].
- auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
- operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half));
+ auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon));
// X - E[X].
- auto operand_minus_mean = add(HloInstruction::CreateBinary(
- operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted));
+ auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract,
+ operand, mean_broadcasted);
// (X - E[X]) / Sqrt[Var[X] + epsilon].
- auto normalized = add(
- HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply,
- operand_minus_mean, rsqrt_var_add_epsilon));
+ auto normalized = add_binary(operand_shape, HloOpcode::kMultiply,
+ operand_minus_mean, rsqrt_var_add_epsilon);
// (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
- auto scaled_normalized = add(HloInstruction::CreateBinary(
- operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted));
+ auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply,
+ normalized, scale_broadcasted);
// (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
- auto shifted_normalized = add(HloInstruction::CreateBinary(
- operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted));
+ auto shifted_normalized = add_binary(operand_shape, HloOpcode::kAdd,
+ scaled_normalized, offset_broadcasted);
auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var});
@@ -331,8 +387,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
- auto epsilon = computation_->AddInstruction(
- HloInstruction::CreateConstant(std::move(epsilon_literal)));
+ auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
+ operand_shape,
+ computation_->AddInstruction(
+ HloInstruction::CreateConstant(std::move(epsilon_literal))),
+ {}));
std::vector<int64> dimensions_without_feature;
@@ -349,6 +408,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
added_instructions.push_back(added_inst);
return added_inst;
};
+ auto add_binary = [&](const Shape& shape, const HloOpcode opcode,
+ HloInstruction* a, HloInstruction* b) {
+ return add(HloInstruction::CreateBinary(shape, opcode, a, b));
+ };
int64 instruction_count_before = computation_->instruction_count();
auto scale_broadcasted = add(
@@ -364,30 +427,23 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));
// Var[X] + epsilon.
- auto var_add_epsilon = add(HloInstruction::CreateBinary(
- operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));
-
- auto neg_half_literal = Literal::CreateR0(-0.5f);
- TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
- auto neg_half =
- add(HloInstruction::CreateConstant(std::move(neg_half_literal)));
+ auto var_add_epsilon =
+ add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon);
// 1 / Sqrt[Var[X] + epsilon].
- auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
- operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half));
+ auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon));
// X - E[X].
- auto operand_minus_mean = add(HloInstruction::CreateBinary(
- operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted));
+ auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract,
+ operand, mean_broadcasted);
// (X - E[X]) / Sqrt[Var[X] + epsilon].
- auto normalized = add(
- HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply,
- operand_minus_mean, rsqrt_var_add_epsilon));
+ auto normalized = add_binary(operand_shape, HloOpcode::kMultiply,
+ operand_minus_mean, rsqrt_var_add_epsilon);
// (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
- auto scaled_normalized = add(HloInstruction::CreateBinary(
- operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted));
+ auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply,
+ normalized, scale_broadcasted);
// (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
auto shifted_normalized = HloInstruction::CreateBinary(
@@ -435,6 +491,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
added_instructions.push_back(added_inst);
return added_inst;
};
+ auto add_binary = [&](const Shape& shape, const HloOpcode opcode,
+ HloInstruction* a, HloInstruction* b) {
+ return add(HloInstruction::CreateBinary(shape, opcode, a, b));
+ };
int64 instruction_count_before = computation_->instruction_count();
HloInstruction* activation = batch_norm->mutable_operand(0);
@@ -450,26 +510,20 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape);
const int64 feature_count = activation_shape.dimensions(feature_index);
- auto elements_per_feature_literal =
- Literal::CreateR0<float>(size_in_elements / feature_count);
- TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
- elements_per_feature_literal->Convert(ptype));
- auto elements_per_feature = add(
- HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));
+ const int64 elements_per_feature_int64 = size_in_elements / feature_count;
auto zero_literal = Literal::CreateR0(0.0f);
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
- auto neg_half_literal = Literal::CreateR0(-0.5f);
- TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
- auto neg_half =
- add(HloInstruction::CreateConstant(std::move(neg_half_literal)));
-
auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
- auto epsilon =
+ auto epsilon_scalar =
add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
+ auto epsilon_activation = add(
+ HloInstruction::CreateBroadcast(activation_shape, epsilon_scalar, {}));
+ auto epsilon_feature =
+ add(HloInstruction::CreateBroadcast(feature_shape, epsilon_scalar, {}));
std::vector<int64> dimensions_without_feature;
@@ -489,26 +543,21 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index}));
// rsqrt[Var[X] + epsilon].
- auto rsqrt_var_add_epsilon_broadcasted = add(HloInstruction::CreateBinary(
- activation_shape, HloOpcode::kPower,
- add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd,
- variance_broadcasted, epsilon)),
- neg_half));
-
- auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
- feature_shape, HloOpcode::kPower,
- add(HloInstruction::CreateBinary(feature_shape, HloOpcode::kAdd, variance,
- epsilon)),
- neg_half));
+ auto rsqrt_var_add_epsilon_broadcasted =
+ add(Rsqrt(add_binary(activation_shape, HloOpcode::kAdd,
+ variance_broadcasted, epsilon_activation)));
+
+ auto rsqrt_var_add_epsilon = add(Rsqrt(
+ add_binary(feature_shape, HloOpcode::kAdd, variance, epsilon_feature)));
// X - E[X].
- auto activation_minus_mean = add(HloInstruction::CreateBinary(
- activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted));
+ auto activation_minus_mean = add_binary(
+ activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted);
// Grad[Y] * (X - E[X]).
auto grad_output_times_activiation_minus_mean =
- add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply,
- grad_output, activation_minus_mean));
+ add_binary(activation_shape, HloOpcode::kMultiply, grad_output,
+ activation_minus_mean);
HloComputation* add_reduce_computation =
GetOrCreateScalarAddComputation(ptype);
@@ -540,9 +589,9 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
}
// Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]).
- auto grad_scale = add(HloInstruction::CreateBinary(
- feature_shape, HloOpcode::kMultiply,
- sum_grad_output_times_activiation_minus_mean, rsqrt_var_add_epsilon));
+ auto grad_scale = add_binary(feature_shape, HloOpcode::kMultiply,
+ sum_grad_output_times_activiation_minus_mean,
+ rsqrt_var_add_epsilon);
// I2 = Sum(Grad[Y])
auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta,
@@ -554,39 +603,40 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
{feature_index}));
// I4 = (X - E[X]) * I3
- auto i4 = add(HloInstruction::CreateBinary(
- activation_shape, HloOpcode::kMultiply, i3, activation_minus_mean));
+ auto i4 = add_binary(activation_shape, HloOpcode::kMultiply, i3,
+ activation_minus_mean);
// I5 = I4 / (Var[X] + epsilon)
- auto i5 = add(HloInstruction::CreateBinary(
- activation_shape, HloOpcode::kDivide, i4,
- add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd,
- variance_broadcasted, epsilon))));
+ auto i5 = add_binary(activation_shape, HloOpcode::kDivide, i4,
+ add_binary(activation_shape, HloOpcode::kAdd,
+ variance_broadcasted, epsilon_activation));
// scale * rsqrt[Var[X] + epsilon] * 1/N
- auto scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
- activation_shape, HloOpcode::kMultiply, scale_broadcasted,
- rsqrt_var_add_epsilon_broadcasted));
+ auto scale_times_rsqrt_var_add_epsilon =
+ add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted,
+ rsqrt_var_add_epsilon_broadcasted);
- scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
- activation_shape, HloOpcode::kDivide, scale_times_rsqrt_var_add_epsilon,
- elements_per_feature));
+ scale_times_rsqrt_var_add_epsilon =
+ add(Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon));
- auto i1 =
- add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply,
- grad_output, elements_per_feature));
+ auto elements_per_feature_literal =
+ Literal::CreateR0<float>(elements_per_feature_int64);
+ TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
+ elements_per_feature_literal->Convert(ptype));
+ auto elements_per_feature = add(
+ HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));
+ auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output,
+ add(HloInstruction::CreateBroadcast(
+ activation_shape, elements_per_feature, {})));
// I6 = I1 - I2 - I5
- auto i6 = add(HloInstruction::CreateBinary(
+ auto i6 = add_binary(
activation_shape, HloOpcode::kSubtract,
- add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract,
- i1, i2)),
- i5));
+ add_binary(activation_shape, HloOpcode::kSubtract, i1, i2), i5);
// Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6.
- auto grad_activation =
- add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply,
- scale_times_rsqrt_var_add_epsilon, i6));
+ auto grad_activation = add_binary(activation_shape, HloOpcode::kMultiply,
+ scale_times_rsqrt_var_add_epsilon, i6);
auto tuple =
HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta});
if (batch_norm->has_sharding()) {
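The two new helpers replace the former broadcast-of-constant binary ops with maps over scalar computations (per the TODO above, pending scalar-broadcast performance fixes on all backends). Their effect, stated as illustrative equivalences:

    // Rsqrt(x)   ~ Map(x, v -> pow(v, -0.5))     // 1 / sqrt(v), elementwise
    // Mean(n, x) ~ Map(x, v -> v * (1.0f / n))   // divide-by-n as a multiply

so, for example, E[X] becomes Mean(elements_per_feature, Sum[X]) instead of a divide by a broadcast elements_per_feature constant.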
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index a4fb0eefac..bdcea92882 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
-#include "tensorflow/compiler/xla/service/computation_tracker.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
@@ -82,7 +81,7 @@ const std::vector<const HloInstruction*> GetInstructions(HloInstruction* root) {
class BufferAssignmentTest : public HloTestBase {
protected:
- BufferAssignmentTest() : computation_tracker_() {}
+ BufferAssignmentTest() {}
~BufferAssignmentTest() override {}
std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
@@ -252,9 +251,6 @@ class BufferAssignmentTest : public HloTestBase {
return total_size;
}
- // Computation tracker for nested computations.
- ComputationTracker computation_tracker_;
-
// Shapes for use in the examples.
Shape s32_ = ShapeUtil::MakeShape(xla::S32, {});
Shape r0f32_ = ShapeUtil::MakeShape(xla::F32, {});
diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h
index c7763f2ca3..e415fb27e6 100644
--- a/tensorflow/compiler/xla/service/channel_tracker.h
+++ b/tensorflow/compiler/xla/service/channel_tracker.h
@@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
-#include "tensorflow/compiler/xla/service/user_computation.h"
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index d39fd7307a..d8fdccf9bb 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
-#include "tensorflow/compiler/xla/service/computation_tracker.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -104,56 +103,4 @@ CompileOnlyService::CompileAheadOfTime(
return compiler_->CompileAheadOfTime(std::move(hlo_modules), options);
}
-StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
-CompileOnlyService::CompileAheadOfTime(
- const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
- const AotCompilationOptions& options) {
- std::vector<std::unique_ptr<HloModule>> hlo_modules;
- for (const AotComputationInstance& instance : computations) {
- TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
- computation_tracker_.Resolve(instance.computation));
- VersionedComputationHandle versioned_handle =
- user_computation->GetVersionedHandle();
-
- const DebugOptions& debug_options = options.debug_options();
-
- // Dump computation proto state if flag is set.
- const string& directory_path = debug_options.xla_dump_computations_to();
- if (!directory_path.empty()) {
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<SessionModule> session_module,
- computation_tracker_.SnapshotComputation(versioned_handle.handle));
- string filename = tensorflow::strings::StrCat(
- "computation_", versioned_handle.handle.handle(), "__",
- session_module->entry().name(), "__version_",
- versioned_handle.version);
- const string& per_host_path = tensorflow::io::JoinPath(
- directory_path, tensorflow::port::Hostname());
-
- TF_RETURN_IF_ERROR(Executable::DumpToDirectory(per_host_path, filename,
- *session_module));
- }
-
- TF_ASSIGN_OR_RETURN(
- std::shared_ptr<const ProgramShape> program_shape,
- user_computation->ComputeProgramShape(versioned_handle.version));
-
- ExecutionOptions execution_options;
- *execution_options.mutable_debug_options() = debug_options;
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(*program_shape, instance.argument_layouts,
- &execution_options, user_computation));
-
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> hlo_module,
- computation_tracker_.BuildHloModule(
- versioned_handle, *module_config,
- /*include_unreachable_instructions=*/true));
- TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module));
- hlo_modules.push_back(std::move(hlo_module));
- }
-
- return compiler_->CompileAheadOfTime(std::move(hlo_modules), options);
-}
-
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h
index 7f2ce0e897..e6a66c202d 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.h
+++ b/tensorflow/compiler/xla/service/compile_only_service.h
@@ -38,24 +38,7 @@ class CompileOnlyService : public Service {
static StatusOr<std::unique_ptr<CompileOnlyService>> NewService(
const ServiceOptions& options);
- // A description of a computation to compile using CompileAheadOfTime.
- struct AotComputationInstance {
- ComputationHandle computation;
- std::vector<const Shape*> argument_layouts;
- const Shape* result_layout = nullptr;
- };
-
- // Compiles a list of computations for ahead-of-time execution. This is
- // intended for use in static compilation. See
- // |CompileOnlyClient::CompileAheadOfTime| for additional details.
- StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
- CompileAheadOfTime(
- const tensorflow::gtl::ArraySlice<AotComputationInstance> computations,
- const AotCompilationOptions& Options);
-
// A description of an xla computation to compile using CompileAheadOfTime.
- //
- // TODO(b/74197823): This is a part of a NOT YET ready refactor.
struct AotXlaComputationInstance {
HloModuleProto computation;
std::vector<const Shape*> argument_layouts;
@@ -65,31 +48,15 @@ class CompileOnlyService : public Service {
// Compiles a list of xla computations for ahead-of-time execution. This is
// intended for use in static compilation. See
// |CompileOnlyClient::CompileAheadOfTime| for additional details.
- //
- // TODO(b/74197823): This is a part of a NOT YET ready refactor.
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(
const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
const AotCompilationOptions& options);
- // Override Service methods that require or imply the existence of an
- // execute backend. Note that this does not include TransferToClient, as
- // computing constants produces global data that we may wish to transfer.
- Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override {
- return Unimplemented("CompileOnlyService does not support execution.");
- }
- Status ExecuteParallel(const ExecuteParallelRequest* arg,
- ExecuteParallelResponse* result) override {
- return Unimplemented("CompileOnlyService does not support execution.");
- }
Status GetDeviceHandles(const GetDeviceHandlesRequest* arg,
GetDeviceHandlesResponse* result) override {
return Unimplemented("CompileOnlyService does not support devices.");
}
- Status ExecuteAsync(const ExecuteAsyncRequest* arg,
- ExecuteAsyncResponse* result) override {
- return Unimplemented("CompileOnlyService does not support execution.");
- }
Status WaitForExecution(const WaitForExecutionRequest* arg,
WaitForExecutionResponse* result) override {
return Unimplemented("CompileOnlyService does not support execution.");
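With the handle-based overload removed, ahead-of-time clients hand the service serialized HLO modules directly. A sketch of the surviving path, using only the fields shown above (hlo_module_proto, arg0_shape, and aot_options are placeholders):

    xla::CompileOnlyService::AotXlaComputationInstance instance;
    instance.computation = hlo_module_proto;  // an HloModuleProto
    instance.argument_layouts = {&arg0_shape};
    TF_ASSIGN_OR_RETURN(auto aot_results,
                        service->CompileAheadOfTime({instance}, aot_options));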
diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc
index 31f84e88f8..6f06bba679 100644
--- a/tensorflow/compiler/xla/service/compiler.cc
+++ b/tensorflow/compiler/xla/service/compiler.cc
@@ -28,8 +28,9 @@ namespace xla {
/* static */ tensorflow::mutex Compiler::platform_compiler_mutex_(
tensorflow::LINKER_INITIALIZED);
-std::vector<string> Compiler::ComputeBackendConfigs(
- const HloInstruction& hlo, se::StreamExecutor* executor) const {
+std::vector<std::unique_ptr<tensorflow::protobuf::Message>>
+Compiler::ComputeBackendConfigs(const HloInstruction& hlo,
+ se::StreamExecutor* executor) const {
CHECK(executor != nullptr);
return {};
}
diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h
index c39db58b78..6c52ffd800 100644
--- a/tensorflow/compiler/xla/service/compiler.h
+++ b/tensorflow/compiler/xla/service/compiler.h
@@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -161,8 +162,9 @@ class Compiler {
//
// The stream executor is passed in to provide information about the hardware
// that the backend configurations would be targeting.
- virtual std::vector<string> ComputeBackendConfigs(
- const HloInstruction& hlo, se::StreamExecutor* executor) const;
+ virtual std::vector<std::unique_ptr<tensorflow::protobuf::Message>>
+ ComputeBackendConfigs(const HloInstruction& hlo,
+ se::StreamExecutor* executor) const;
// Compiles the HLO module for ahead-of-time execution. This is intended for
// use in static compilation.
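As a sketch only, a backend override of the new proto-returning signature might look like the following; MyCompiler and MyBackendConfig are illustrative stand-ins, not types introduced here:

    std::vector<std::unique_ptr<tensorflow::protobuf::Message>>
    MyCompiler::ComputeBackendConfigs(const HloInstruction& hlo,
                                      se::StreamExecutor* executor) const {
      CHECK(executor != nullptr);
      std::vector<std::unique_ptr<tensorflow::protobuf::Message>> configs;
      auto config = MakeUnique<MyBackendConfig>();  // hypothetical proto type
      config->set_use_fast_path(true);              // hypothetical field
      configs.push_back(std::move(config));
      return configs;
    }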
diff --git a/tensorflow/compiler/xla/service/computation_tracker.cc b/tensorflow/compiler/xla/service/computation_tracker.cc
deleted file mode 100644
index 70e25eebdb..0000000000
--- a/tensorflow/compiler/xla/service/computation_tracker.cc
+++ /dev/null
@@ -1,256 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/computation_tracker.h"
-
-#include <list>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "tensorflow/compiler/xla/ptr_util.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
-#include "tensorflow/core/platform/logging.h"
-
-using ::tensorflow::strings::Appendf;
-
-namespace xla {
-
-ComputationTracker::ComputationTracker() : next_computation_(1) {}
-
-ComputationHandle ComputationTracker::NewComputation(
- const string& computation_name) {
- tensorflow::mutex_lock lock(computation_mutex_);
- ComputationHandle computation_handle;
- int64 handle_value = next_computation_++;
- computation_handle.set_handle(handle_value);
- opaque_to_computation_[handle_value] =
- MakeUnique<UserComputation>(computation_name, computation_handle);
- return computation_handle;
-}
-
-StatusOr<ComputationHandle> ComputationTracker::LoadSessionModule(
- const SessionModule& session_module) {
- tensorflow::mutex_lock lock(computation_mutex_);
-
- // For each embedded computation, create a new computation based on its
- // serialized data, and place the mapping from the old computation handle to
- // the new computation handle.
-
- // Build a mapping from old embedded computation handles to new computation
- // handles. We build the ID mapping first since the embedded computations are
- // in no particular order and may refer to each other.
- std::map<int64, ComputationHandle> old_to_new;
- for (const SessionComputation& computation :
- session_module.embedded_computations()) {
- const int64 old_handle = computation.computation_handle().handle();
- if (!old_to_new.emplace(old_handle, AllocateHandle()).second) {
- return InvalidArgument("Duplicate embedded computation handle %lld",
- old_handle);
- }
- }
-
- // Create a new computation from each serialized embedded computation.
- for (const SessionComputation& computation :
- session_module.embedded_computations()) {
- const int64 old_handle = computation.computation_handle().handle();
- const ComputationHandle& new_handle = old_to_new[old_handle];
- TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()],
- UserComputation::MakeWithRemapping(
- computation, new_handle, old_to_new));
- }
-
- // Finally, place the entry computation in the tracker with all of the
- // remappings populated from the above.
- const int64 old_handle = session_module.entry().computation_handle().handle();
- TF_ASSIGN_OR_RETURN(
- old_to_new[old_handle],
- LoadSessionComputation(session_module.entry(), &old_to_new));
- return old_to_new[old_handle];
-}
-
-StatusOr<std::unique_ptr<SessionModule>>
-ComputationTracker::SnapshotComputation(const ComputationHandle& computation) {
- TF_ASSIGN_OR_RETURN(UserComputation * user_computation, Resolve(computation));
- const VersionedComputationHandle entry_versioned_handle =
- user_computation->GetVersionedHandle();
- std::set<VersionedComputationHandle> visited;
- std::list<VersionedComputationHandle> post_order;
- {
- tensorflow::mutex_lock lock(computation_mutex_);
- ComputeComputationPostOrder(entry_versioned_handle, &visited, &post_order);
- }
- auto session_module = MakeUnique<SessionModule>();
- *session_module->mutable_entry() =
- Resolve(entry_versioned_handle.handle)
- .ValueOrDie()
- ->CloneSessionComputation(entry_versioned_handle.version);
- for (auto it = ++post_order.rbegin(); it != post_order.rend(); ++it) {
- *session_module->add_embedded_computations() =
- Resolve(it->handle).ValueOrDie()->CloneSessionComputation(it->version);
- }
- return std::move(session_module);
-}
-
-StatusOr<UserComputation*> ComputationTracker::Resolve(
- const ComputationHandle& computation) const {
- tensorflow::mutex_lock lock(computation_mutex_);
- return ResolveInternal(computation);
-}
-
-ComputationHandle ComputationTracker::AllocateHandle() {
- int64 handle_value = next_computation_++;
- ComputationHandle result;
- result.set_handle(handle_value);
- return result;
-}
-
-StatusOr<ComputationHandle> ComputationTracker::LoadSessionComputation(
- const SessionComputation& session_computation,
- std::map<int64, ComputationHandle>* old_to_new) {
- TF_RET_CHECK(old_to_new != nullptr);
- const ComputationHandle new_handle = AllocateHandle();
- (*old_to_new)[session_computation.computation_handle().handle()] = new_handle;
- TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()],
- UserComputation::MakeWithRemapping(
- session_computation, new_handle, *old_to_new));
- return new_handle;
-}
-
-StatusOr<UserComputation*> ComputationTracker::ResolveInternal(
- const ComputationHandle& computation) const {
- auto it = opaque_to_computation_.find(computation.handle());
- if (it == opaque_to_computation_.end()) {
- return NotFound("computation handle not found: %lld", computation.handle());
- }
- UserComputation* user_computation = it->second.get();
- return user_computation;
-}
-
-void ComputationTracker::ComputeComputationPostOrder(
- const VersionedComputationHandle& versioned_handle,
- std::set<VersionedComputationHandle>* visited,
- std::list<VersionedComputationHandle>* post_order) const {
- if (visited->count(versioned_handle) > 0) {
- CHECK_EQ(1, visited->count(versioned_handle));
- return;
- }
-
- UserComputation* computation =
- ResolveInternal(versioned_handle.handle).ValueOrDie();
- std::vector<VersionedComputationHandle> embedded_handles =
- computation->GetEmbeddedComputations(versioned_handle.version);
-
- for (const auto& embedded_handle : embedded_handles) {
- ComputeComputationPostOrder(embedded_handle, visited, post_order);
- }
-
- visited->insert(versioned_handle);
- post_order->push_back(versioned_handle);
-}
-
-StatusOr<std::unique_ptr<HloModule>> ComputationTracker::BuildHloModule(
- const VersionedComputationHandle& entry_handle,
- const HloModuleConfig& config,
- bool include_unreachable_instructions) const {
- tensorflow::mutex_lock lock(computation_mutex_);
-
- VLOG(1) << "BuildHloModule(" << entry_handle
- << ", include_unreachable_instructions="
- << include_unreachable_instructions << ")";
- XLA_VLOG_LINES(1, ToStringInternal());
-
- TF_ASSIGN_OR_RETURN(UserComputation * entry_computation,
- ResolveInternal(entry_handle.handle));
-
- // Build a topological sort of the entry and any embedded computations as a
- // list. The root of the computation will be the last element in the list.
- std::set<VersionedComputationHandle> visited;
- std::list<VersionedComputationHandle> post_order;
- ComputeComputationPostOrder(entry_handle, &visited, &post_order);
-
- // Map from ComputationHandle value and computation version to HloComputation.
- std::map<VersionedComputationHandle, HloComputation*> hlo_computations;
-
- // The resolver lambda resolves VersionedHandles to embedded
- // HloComputation*. This is required by UserComputation::BuildHloComputation
- // when lowering calling operations (map, reduce etc).
- auto resolver = [&hlo_computations](
- const VersionedComputationHandle& versioned_handle) -> HloComputation* {
- CHECK_GT(hlo_computations.count(versioned_handle), 0);
- return hlo_computations.at(versioned_handle);
- };
-
- // Print the post-order list for this entry computation.
- if (VLOG_IS_ON(2)) {
- VLOG(2) << "Visiting UserComputations in post order:";
- for (const VersionedComputationHandle& versioned_handle : post_order) {
- VLOG(2) << " " << versioned_handle;
- }
- }
-
- string module_name =
- tensorflow::strings::StrCat(entry_computation->name(), "_module");
- auto module = MakeUnique<HloModule>(module_name, entry_handle, config);
- for (auto versioned_handle : post_order) {
- UserComputation* computation =
- ResolveInternal(versioned_handle.handle).ValueOrDie();
-
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<HloComputation> hlo_computation,
- computation->BuildHloComputation(versioned_handle.version, resolver,
- config.debug_options(),
- include_unreachable_instructions));
-
- // Add the newly created computation to VersionedHandle-to-HloComputation
- // map.
- DCHECK_EQ(0, hlo_computations.count(versioned_handle));
- hlo_computations[versioned_handle] = hlo_computation.get();
-
- if (computation == entry_computation) {
- module->AddEntryComputation(std::move(hlo_computation));
- } else {
- module->AddEmbeddedComputation(std::move(hlo_computation));
- }
- }
-
- return std::move(module);
-}
-
-string ComputationTracker::ToString() const {
- tensorflow::mutex_lock lock(computation_mutex_);
- return ToStringInternal();
-}
-
-string ComputationTracker::ToStringInternal() const {
- string out;
- Appendf(&out, "ComputationTracker(%p):\n", this);
- for (const auto& handle_computation : opaque_to_computation_) {
- int64 handle = handle_computation.first;
- const std::unique_ptr<UserComputation>& computation =
- handle_computation.second;
- Appendf(&out, " %4lld : %s \"%s\"\n", handle,
- computation->GetVersionedHandle().ToString().c_str(),
- computation->name().c_str());
- }
- return out;
-}
-
-} // namespace xla
diff --git a/tensorflow/compiler/xla/service/computation_tracker.h b/tensorflow/compiler/xla/service/computation_tracker.h
deleted file mode 100644
index d42d66adef..0000000000
--- a/tensorflow/compiler/xla/service/computation_tracker.h
+++ /dev/null
@@ -1,147 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_
-
-#include <list>
-#include <map>
-#include <memory>
-#include <set>
-#include <string>
-
-#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_module_config.h"
-#include "tensorflow/compiler/xla/service/session.pb.h"
-#include "tensorflow/compiler/xla/service/user_computation.h"
-#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/thread_annotations.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace xla {
-
-// Tracks computations for the XLA service; computations can be registered
-// with a UserComputation instance and can be resolved from a handle for later
-// use.
-//
-// This class is also capable of serializing/deserializing computations that it
-// tracks (and to serialize properly you need to serialize all referred-to
-// computations as well).
-class ComputationTracker {
- public:
- ComputationTracker();
-
- // Creates a new UserComputation object and returns the corresponding
- // ComputationHandle for it.
- //
- // Precondition: user_computation is not already present in the map.
- ComputationHandle NewComputation(const string& computation_name);
-
- // Restores session data for a computation that has been serialized, and
- // allocates a new computation handle for it.
- StatusOr<ComputationHandle> LoadSessionModule(
- const SessionModule& session_module);
-
- // Snapshots a computation (referenced by the provided handle) at its latest
- // version, returning a module where it is the entry, and any referred-to
- // computations are entrained as "embedded" (non-entry) computations.
- StatusOr<std::unique_ptr<SessionModule>> SnapshotComputation(
- const ComputationHandle& computation);
-
- // Resolves a ComputationHandle to a UserComputation that is present in the
- // map.
- StatusOr<UserComputation*> Resolve(
- const ComputationHandle& computation) const;
-
- // Builds an HLO module using the specified computation as the entry. The
- // module will include the entry computation as well as all computations which
- // are called directly or indirectly from the entry computation via operations
- // like "map". config is the HLO module configuration to use for the
- // constructed module.
- // If include_unreachable_instructions is true, then instructions
- // which are not reachable from the root are lowered into HloInstructions
- // including unreachable parameters. This ensures the entry HloComputation has
- // the same program shape (ProgramShape) as the entry UserComputation.
- StatusOr<std::unique_ptr<HloModule>> BuildHloModule(
- const VersionedComputationHandle& entry_handle,
- const HloModuleConfig& config,
- bool include_unreachable_instructions = true) const;
-
- string ToString() const;
-
- private:
- // Bumps the next_computation_ number and returns the allocated number wrapped
- // in a ComputationHandle.
- ComputationHandle AllocateHandle()
- EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
-
- // Loads a session computation into a UserComputation, registers it, and
- // returns the computation handle of the registered computation. If old_to_new
- // is provided, it is used for remapping references to computations present in
- // session_computation.
- //
- // old_to_new will be updated with the mapping from session_computation's old
- // handle to the returned handle value, and may not be null.
- StatusOr<ComputationHandle> LoadSessionComputation(
- const SessionComputation& session_computation,
- std::map<int64, ComputationHandle>* old_to_new)
- EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
-
- // Internal implementation of Resolve method which requires, but does not
- // acquire the mutex.
- StatusOr<UserComputation*> ResolveInternal(
- const ComputationHandle& computation) const
- EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
-
- // Builds a post order sort of a computation ("entry") and all of its embedded
- // computations including all transitively embedded computations. An embedded
- // computation (the callee) will always appear in the sort before the
- // computation which calls the embedded computation (the caller). Necessarily,
- // the entry computation is the last element in the sort. visited and
- // post_order should be empty when calling. post_order contains the post order
- // sort when the function returns.
- void ComputeComputationPostOrder(
- const VersionedComputationHandle& versioned_handle,
- std::set<VersionedComputationHandle>* visited,
- std::list<VersionedComputationHandle>* post_order) const
- EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
-
- string ToStringInternal() const EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
-
- // Guards the computation mapping. Marked mutable so that the Resolve method
- // can remain const; Resolve doesn't really modify the tracker in any way, but
- // it has to lock the mutex for safety.
- mutable tensorflow::mutex computation_mutex_;
-
- // The next sequence number to assign to a computation, guarded by the same
- // mutex as the mapping as they'll be mutated at the same time.
- int64 next_computation_ GUARDED_BY(computation_mutex_);
-
- // Mapping from ComputationHandle value to the corresponding registered
- // UserComputation object.
- std::map<int64, std::unique_ptr<UserComputation>> opaque_to_computation_
- GUARDED_BY(computation_mutex_);
-
- TF_DISALLOW_COPY_AND_ASSIGN(ComputationTracker);
-};
-
-} // namespace xla
-
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 2794930248..68297ad4ae 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -1,6 +1,8 @@
# Description:
# GPU-specific components in XLA service implementation.
+load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
+
licenses(["notice"]) # Apache 2.0
package(default_visibility = [":friends"])
@@ -23,6 +25,11 @@ filegroup(
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+xla_proto_library(
+ name = "backend_configs",
+ srcs = ["backend_configs.proto"],
+)
+
cc_library(
name = "gpu_constants",
srcs = ["gpu_constants.cc"],
@@ -133,6 +140,7 @@ cc_library(
"ir_emitter_unnested.h",
],
deps = [
+ ":backend_configs",
":cudnn_convolution_runner",
":elemental_ir_emitter",
":gpu_constants",
@@ -266,6 +274,7 @@ cc_library(
"while_thunk.h",
],
deps = [
+ ":backend_configs",
":buffer_allocations",
":cudnn_convolution_runner",
":infeed_manager",
@@ -322,6 +331,7 @@ cc_library(
srcs = ["cudnn_convolution_algorithm_picker.cc"],
hdrs = ["cudnn_convolution_algorithm_picker.h"],
deps = [
+ ":backend_configs",
":cudnn_convolution_runner",
":gpu_executable",
":ir_emission_utils",
diff --git a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto
new file mode 100644
index 0000000000..640c6392b8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto
@@ -0,0 +1,27 @@
+syntax = "proto3";
+
+package xla.gpu;
+
+// Backend configs for XLA:GPU.
+//
+// These are metadata that the GPU backend attaches to HloInstructions and later
+// uses during e.g. codegen.
+//
+// Remember that proto3 doesn't give clients a way to tell the difference
+// between a field not being present and a field having the default value.
+// Choose your defaults carefully.
+//
+// No guarantee is made about the stability of these protos.
+//
+// See HloInstruction::backend_config() for more info.
+
+// Backend config for a convolution that runs through cudnn.
+message CudnnConvBackendConfig {
+ // Opaque algorithm number of the cudnn algorithm chosen for this conv.
+ int64 algorithm = 1;
+
+ // Whether we may use tensor cores when running this conv. Even if this is
+ // true, cudnn may choose not to use tensor cores, e.g. because the GPU or
+ // selected algorithm doesn't support it.
+ bool tensor_ops_enabled = 2;
+}
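Producer-side sketch of attaching this proto to a cudnn convolution custom call, mirroring the algorithm-picker change below; conv and chosen_algorithm are assumed caller state:

    CudnnConvBackendConfig backend_config;
    backend_config.set_algorithm(chosen_algorithm);  // assumed int64
    backend_config.set_tensor_ops_enabled(true);
    TF_RETURN_IF_ERROR(conv->set_backend_config(backend_config));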
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index 6a46bdb9b4..3dc98c4c93 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/core/lib/gtl/optional.h"
@@ -316,21 +317,20 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
Shape new_call_shape =
ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0),
ShapeUtil::MakeShape(U8, {scratch_bytes})});
- HloInstruction* algorithm_hlo = computation->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int64>(algorithm)));
- HloInstruction* tensor_ops_enabled_hlo =
- computation->AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR0<bool>(tensor_ops_enabled)));
+
+ CudnnConvBackendConfig backend_config;
+ backend_config.set_algorithm(algorithm);
+ backend_config.set_tensor_ops_enabled(tensor_ops_enabled);
HloInstruction* new_call =
computation->AddInstruction(HloInstruction::CreateCustomCall(
new_call_shape,
- {instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo,
- tensor_ops_enabled_hlo},
+ {instr->mutable_operand(0), instr->mutable_operand(1)},
instr->custom_call_target()));
new_call->set_window(instr->window());
new_call->set_convolution_dimension_numbers(
instr->convolution_dimension_numbers());
+ TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config));
// Repackage new_call so it has the same shape as the original call, namely
// (conv_result, u8[0]).
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
index d9560779f3..c5ccdd4a7d 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
@@ -78,12 +78,6 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
for (int64 i = 0; i < hlo->operand_count() - 2; ++i) {
TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
}
- } else if (IsCustomCallToDnnConvolution(*hlo)) {
- // The last two arguments to a CUDNN convolution are two HLO constants for
- // cudnn algorithm and tensor_ops_enabled flag, which shouldn't be copied.
- for (int64 i = 0; i < hlo->operand_count() - 2; ++i) {
- TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
- }
} else if (ImplementedAsLibraryCall(*hlo) ||
hlo->opcode() == HloOpcode::kCrossReplicaSum) {
// For all other library calls and cross-replica-sum, materialize all the
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 22e7150995..67890bfed1 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -162,19 +162,8 @@ static HloInstruction* CreateCudnnConv(
Shape call_shape =
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})});
- // Our CustomCall takes four arguments: The conv lhs and rhs, the cudnn
- // algorithm to use, and a boolean indicating whether to use tensor cores.
- //
- // It's up to a later pass to choose the algorithm and decide whether to use
- // tensor cores, so to indicate that we haven't yet made a choice, we specify
- // -1 and false for those args.
- HloInstruction* negative_one = computation->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int64>(-1)));
- HloInstruction* false_constant = computation->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
- HloInstruction* custom_call =
- computation->AddInstruction(HloInstruction::CreateCustomCall(
- call_shape, {lhs, rhs, negative_one, false_constant}, call_target));
+ HloInstruction* custom_call = computation->AddInstruction(
+ HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target));
custom_call->set_window(window);
custom_call->set_convolution_dimension_numbers(dnums);
return custom_call;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index ae4e305b80..0f5c003341 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
@@ -423,15 +424,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
- const HloInstruction* algorithm_inst = custom_call->operand(2);
- CHECK(algorithm_inst->IsConstant()) << algorithm_inst->ToString();
- int64 algorithm = algorithm_inst->literal().Get<int64>({});
-
- const HloInstruction* tensor_ops_enabled_inst = custom_call->operand(3);
- CHECK(tensor_ops_enabled_inst->IsConstant())
- << tensor_ops_enabled_inst->ToString();
- bool tensor_ops_enabled = tensor_ops_enabled_inst->literal().Get<bool>({});
-
+ TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
+ custom_call->backend_config<CudnnConvBackendConfig>());
const auto& target = custom_call->custom_call_target();
std::unique_ptr<ConvolutionThunk> thunk;
if (target == kCudnnConvForwardCallTarget) {
@@ -446,7 +440,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
/*filter_shape=*/rhs_shape,
/*output_shape=*/conv_result_shape, //
custom_call->window(), custom_call->convolution_dimension_numbers(),
- algorithm, tensor_ops_enabled, custom_call);
+ backend_config.algorithm(), backend_config.tensor_ops_enabled(),
+ custom_call);
} else if (target == kCudnnConvBackwardInputCallTarget) {
thunk = MakeUnique<ConvolutionThunk>(
CudnnConvKind::kBackwardInput,
@@ -459,7 +454,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
/*filter_shape=*/rhs_shape,
/*output_shape=*/lhs_shape, //
custom_call->window(), custom_call->convolution_dimension_numbers(),
- algorithm, tensor_ops_enabled, custom_call);
+ backend_config.algorithm(), backend_config.tensor_ops_enabled(),
+ custom_call);
} else if (target == kCudnnConvBackwardFilterCallTarget) {
thunk = MakeUnique<ConvolutionThunk>(
CudnnConvKind::kBackwardFilter,
@@ -472,7 +468,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
/*filter_shape=*/conv_result_shape,
/*output_shape=*/rhs_shape, //
custom_call->window(), custom_call->convolution_dimension_numbers(),
- algorithm, tensor_ops_enabled, custom_call);
+ backend_config.algorithm(), backend_config.tensor_ops_enabled(),
+ custom_call);
} else {
LOG(FATAL) << "Unexpected custom call target: "
<< custom_call->custom_call_target();
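Consumer-side counterpart, condensed from the emitter change above: the thunk emitter now recovers the typed config instead of peeling constants off the operand list.

    TF_ASSIGN_OR_RETURN(
        CudnnConvBackendConfig backend_config,
        custom_call->backend_config<CudnnConvBackendConfig>());
    const int64 algorithm = backend_config.algorithm();
    const bool tensor_ops = backend_config.tensor_ops_enabled();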
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 672b1c017a..05adb45713 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -1085,11 +1085,11 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
string HloDotDumper::GetInstructionNodeBackendConfig(
const HloInstruction* instr) {
- if (!show_backend_config_ || instr->backend_config().empty()) {
+ if (!show_backend_config_ || instr->raw_backend_config_string().empty()) {
return "";
}
- return StrCat("backend_config=\"", instr->backend_config(), "\"");
+ return StrCat("backend_config=\"", instr->raw_backend_config_string(), "\"");
}
string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index dc351e9968..4095b3d337 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/human_readable_json.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -110,7 +111,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->name_ = proto.name();
instruction->metadata_ = proto.metadata();
- instruction->set_backend_config(proto.backend_config());
+ instruction->backend_config_ = proto.backend_config();
if (proto.has_literal()) {
TF_ASSIGN_OR_RETURN(instruction->literal_,
Literal::CreateFromProto(proto.literal()));
@@ -1329,6 +1330,14 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
break;
case HloOpcode::kCustomCall:
clone = CreateCustomCall(shape, new_operands, custom_call_target_);
+ if (window_ != nullptr) {
+ clone->window_ = MakeUnique<Window>(*window_);
+ }
+ if (convolution_dimension_numbers_ != nullptr) {
+ clone->convolution_dimension_numbers_ =
+ MakeUnique<ConvolutionDimensionNumbers>(
+ *convolution_dimension_numbers_);
+ }
break;
case HloOpcode::kHostCompute:
clone = CreateHostCompute(shape, new_operands, channel_name_,
@@ -1521,7 +1530,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
}
SetupDerivedInstruction(clone.get());
clone->set_parent(parent_);
- clone->set_backend_config(backend_config());
+ clone->set_raw_backend_config_string(backend_config_);
if (context != nullptr) {
context->MapInstruction(this, clone.get());
clone->ReplaceCalledComputations([&](HloComputation* callee) {
@@ -1881,6 +1890,19 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kMap:
return eq_computations(to_apply(), other.to_apply());
case HloOpcode::kCustomCall:
+ if ((window_ == nullptr) != (other.window_ == nullptr) ||
+ (window_ != nullptr &&
+ !protobuf_util::ProtobufEquals(window(), other.window()))) {
+ return false;
+ }
+ if ((convolution_dimension_numbers_ == nullptr) !=
+ (other.convolution_dimension_numbers_ == nullptr) ||
+ (convolution_dimension_numbers_ != nullptr &&
+ !protobuf_util::ProtobufEquals(
+ convolution_dimension_numbers(),
+ other.convolution_dimension_numbers()))) {
+ return false;
+ }
return custom_call_target_ == other.custom_call_target_;
case HloOpcode::kReverse:
return dimensions() == other.dimensions();
@@ -2182,8 +2204,8 @@ string HloInstruction::ToStringWithCanonicalNameMap(
!metadata_.source_file().empty())) {
StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}");
}
- if (options.print_backend_config() && !backend_config().empty()) {
- StrAppend(&result, ", backend_config=\"", CEscape(backend_config()), "\"");
+ if (options.print_backend_config() && !backend_config_.empty()) {
+ StrAppend(&result, ", backend_config=\"", CEscape(backend_config_), "\"");
}
return result;
}
@@ -2299,7 +2321,9 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
}
if (convolution_dimension_numbers_ != nullptr) {
- extra.push_back(ConvolutionDimensionNumbersToString());
+ extra.push_back(StrCat(
+ "dim_labels=",
+ ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_)));
}
if (dot_dimension_numbers_ != nullptr) {
extra.push_back(DotDimensionNumbersToString());
@@ -2461,7 +2485,7 @@ HloInstructionProto HloInstruction::ToProto() const {
}
*proto.mutable_metadata() = metadata_;
- proto.set_backend_config(backend_config());
+ proto.set_backend_config(backend_config_);
if (literal_ != nullptr) {
*proto.mutable_literal() = literal_->ToProto();
}
@@ -3419,42 +3443,8 @@ string RandomDistributionToString(const RandomDistribution& distribution) {
return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution));
}
-StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
- static std::unordered_map<string, RandomDistribution>* map = [] {
- static auto* map = new std::unordered_map<string, RandomDistribution>;
- for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) {
- if (RandomDistribution_IsValid(i)) {
- auto value = static_cast<RandomDistribution>(i);
- (*map)[RandomDistributionToString(value)] = value;
- }
- }
- return map;
- }();
- auto found = map->find(tensorflow::str_util::Lowercase(name));
- if (found == map->end()) {
- return InvalidArgument("Unknown distribution");
- }
- return found->second;
-}
-
-std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
- return os << ToString(kind);
-}
-
-string HloInstruction::ConvolutionDimensionNumbersToString() const {
- string result;
- if (convolution_dimension_numbers_ == nullptr) {
- return result;
- }
- const ConvolutionDimensionNumbers& dnums = *convolution_dimension_numbers_;
- // Show the given dimension labels in order of major to minor based on the
- // shape's layout.
- const auto append_dims = [&](const std::vector<string>& dims,
- const Shape& shape) {
- CHECK_EQ(dims.size(), ShapeUtil::Rank(shape));
- StrAppend(&result, Join(dims, ""));
- };
-
+string ConvolutionDimensionNumbersToString(
+ const ConvolutionDimensionNumbers& dnums) {
// lhs_dims[i] is the symbol of the logical dimension i for the lhs
// operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b".
std::vector<string> lhs_dims(2 + dnums.input_spatial_dimensions().size());
@@ -3478,19 +3468,8 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const {
output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i);
}
- result += "dim_labels=";
- append_dims(lhs_dims, operand(0)->shape());
- result += "_";
- append_dims(rhs_dims, operand(1)->shape());
- result += "->";
-
- // A convolution can be represented as a kConvolution HLO or as a CustomCall
- // that returns a tuple, the first element of which is the result of the
- // convolution.
- Shape this_shape =
- ShapeUtil::IsTuple(shape()) ? shape().tuple_shapes(0) : shape();
- append_dims(output_dims, this_shape);
- return result;
+ return StrCat(Join(lhs_dims, ""), "_", Join(rhs_dims, ""), "->",
+ Join(output_dims, ""));
}
string HloInstruction::DotDimensionNumbersToString() const {
@@ -3516,6 +3495,28 @@ string HloInstruction::DotDimensionNumbersToString() const {
return Join(result, ", ");
}
+StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
+ static std::unordered_map<string, RandomDistribution>* map = [] {
+ static auto* map = new std::unordered_map<string, RandomDistribution>;
+ for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) {
+ if (RandomDistribution_IsValid(i)) {
+ auto value = static_cast<RandomDistribution>(i);
+ (*map)[RandomDistributionToString(value)] = value;
+ }
+ }
+ return map;
+ }();
+ auto found = map->find(tensorflow::str_util::Lowercase(name));
+ if (found == map->end()) {
+ return InvalidArgument("Unknown distribution");
+ }
+ return found->second;
+}
+
+std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
+ return os << ToString(kind);
+}
+
string HloInstruction::GatherDimensionNumbersToString() const {
CHECK_NE(gather_dimension_numbers_.get(), nullptr);
string output_window_dims =
@@ -3547,6 +3548,31 @@ bool HloInstruction::CouldBeBitcast() const {
}
}
+Status HloInstruction::GetBackendConfigInternal(
+ tensorflow::protobuf::Message* proto) const {
+ proto->Clear();
+
+ // Empty string does not parse as valid JSON, but it's a valid backend config,
+ // corresponding to the empty proto.
+ if (backend_config_.empty()) {
+ return Status::OK();
+ }
+ return tensorflow::HumanReadableJsonToProto(backend_config_, proto);
+}
+
+Status HloInstruction::set_backend_config(
+ const tensorflow::protobuf::Message& proto) {
+ TF_ASSIGN_OR_RETURN(backend_config_, BackendConfigToRawString(proto));
+ return Status::OK();
+}
+
+/* static */ StatusOr<string> HloInstruction::BackendConfigToRawString(
+ const tensorflow::protobuf::Message& proto) {
+ string ret;
+ TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(proto, &ret));
+ return ret;
+}
+
HloModule* HloInstruction::GetModule() const {
if (parent_) {
return parent_->parent();
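The config is stored as a human-readable JSON string, so the setter and getter form a serialize/parse round trip. A sketch of the invariant, with instr assumed to be an HloInstruction* (the exact JSON text is a proto3 JSON-mapping detail and may vary):

    CudnnConvBackendConfig in;
    in.set_algorithm(1);
    in.set_tensor_ops_enabled(true);
    TF_ASSIGN_OR_RETURN(string json,
                        HloInstruction::BackendConfigToRawString(in));
    instr->set_raw_backend_config_string(json);
    TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig out,
                        instr->backend_config<CudnnConvBackendConfig>());
    // `out` now matches `in` field for field.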
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 6df97c40ba..d47af6c018 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -52,6 +52,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -775,6 +776,10 @@ class HloInstruction {
}
}
+ if (backend_config_ != other.backend_config_) {
+ return false;
+ }
+
return IdenticalSlowPath(other, eq_computations);
}
@@ -1313,9 +1318,6 @@ class HloInstruction {
return fft_length_;
}
- // Returns the dump string of the convolution dimension numbers.
- string ConvolutionDimensionNumbersToString() const;
-
// Returns data on the dimension numbers used for a dot operation.
const DotDimensionNumbers& dot_dimension_numbers() const {
CHECK(dot_dimension_numbers_ != nullptr);
@@ -1449,13 +1451,34 @@ class HloInstruction {
// this field and they cannot interpret it due to its meaning being backend
// specific.
//
- // TODO(b/78194644): Introduce structured configuration format as per
- // go/xla-heuristics.
- const string& backend_config() const { return backend_config_; }
- void set_backend_config(string backend_config) {
- backend_config_ = std::move(backend_config);
+ // ConfigProto should be a protobuf Message type.
+ template <typename ConfigProto>
+ StatusOr<ConfigProto> backend_config() const {
+ ConfigProto proto;
+ TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto));
+ return std::move(proto);
+ }
+ Status set_backend_config(const tensorflow::protobuf::Message& proto);
+
+ // Getter/setter for raw JSON-encoded backend config. Prefer the
+ // functions above that deal in proto Messages where possible.
+ const string& raw_backend_config_string() const { return backend_config_; }
+ void set_raw_backend_config_string(string config_str) {
+ backend_config_ = std::move(config_str);
}
+ // Returns a string representation of a proto in the format used by
+ // raw_backend_config_string.
+ //
+ // This is morally equivalent to:
+ //
+ // HloInstruction instr;
+ // TF_RETURN_IF_ERROR(instr.set_backend_config(proto));
+ // return instr.raw_backend_config_string();
+ //
+ static StatusOr<string> BackendConfigToRawString(
+ const tensorflow::protobuf::Message& proto);
+
// Sets the debug metadata for this instruction.
void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
const OpMetadata& metadata() const { return metadata_; }
@@ -1576,6 +1599,10 @@ class HloInstruction {
// Returns how this instruction uses elements of its `i`th operand.
UseKind OperandElementUse(int64 i) const;
+ // Helper for implementing backend_config(). Parses backend_config_ into the
+ // given proto.
+ Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const;
+
int unique_id_; // Unique to this HloInstruction within a HloModule
// Opcode for this instruction.
@@ -1749,6 +1776,9 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
string PaddingConfigToString(const PaddingConfig& padding);
string OpMetadataToString(const OpMetadata& metadata);
string RandomDistributionToString(const RandomDistribution& distribution);
+string ConvolutionDimensionNumbersToString(
+ const ConvolutionDimensionNumbers& dnums);
+
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
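Since the string conversion no longer consults operand shapes, the free function can be exercised standalone. For a single spatial dimension laid out as (batch, feature, spatial) on both operands and the output, it yields "bf0_oi0->bf0":

    ConvolutionDimensionNumbers dnums;
    dnums.set_input_batch_dimension(0);
    dnums.set_input_feature_dimension(1);
    dnums.add_input_spatial_dimensions(2);
    dnums.set_kernel_output_feature_dimension(0);
    dnums.set_kernel_input_feature_dimension(1);
    dnums.add_kernel_spatial_dimensions(2);
    dnums.set_output_batch_dimension(0);
    dnums.set_output_feature_dimension(1);
    dnums.add_output_spatial_dimensions(2);
    string s = ConvolutionDimensionNumbersToString(dnums);  // "bf0_oi0->bf0"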
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index e91cf2076f..a1a8814384 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
namespace xla {
namespace {
@@ -1542,5 +1543,70 @@ ENTRY entry (param: s32[]) -> s32[] {
}
}
+TEST_F(HloInstructionTest, IdenticalAccountsForBackendConfig) {
+ const Shape shape = ShapeUtil::MakeShape(F32, {42});
+ HloComputation::Builder builder("test");
+ HloInstruction* p =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p"));
+
+ HloInstruction* add1 = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p, p));
+ HloInstruction* add2 = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p, p));
+
+ EXPECT_TRUE(add1->Identical(*add2));
+ add1->set_raw_backend_config_string("abc");
+ EXPECT_FALSE(add1->Identical(*add2));
+}
+
+TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallWindow) {
+ auto instr1 = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
+ /*operands=*/{},
+ /*custom_call_target=*/"foo");
+ auto instr2 = instr1->Clone();
+ EXPECT_TRUE(instr1->Identical(*instr2));
+
+ Window w = window_util::MakeWindow({1, 2, 3});
+ instr1->set_window(w);
+ EXPECT_FALSE(instr1->Identical(*instr2));
+}
+
+TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallDnums) {
+ auto instr1 = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
+ /*operands=*/{},
+ /*custom_call_target=*/"foo");
+ auto instr2 = instr1->Clone();
+ EXPECT_TRUE(instr1->Identical(*instr2));
+
+ ConvolutionDimensionNumbers dnums;
+ dnums.set_output_batch_dimension(42);
+ instr1->set_convolution_dimension_numbers(dnums);
+ EXPECT_FALSE(instr1->Identical(*instr2));
+}
+
+TEST_F(HloInstructionTest, CloneWindowOnCustomCall) {
+ auto instr = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
+ /*operands=*/{},
+ /*custom_call_target=*/"foo");
+ Window w = window_util::MakeWindow({1, 2, 3});
+ instr->set_window(w);
+ auto clone = instr->Clone();
+ EXPECT_TRUE(protobuf_util::ProtobufEquals(clone->window(), w))
+ << clone->window().DebugString();
+}
+
+TEST_F(HloInstructionTest, CloneDnumsOnCustomCall) {
+ auto instr = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
+ /*operands=*/{},
+ /*custom_call_target=*/"foo");
+ ConvolutionDimensionNumbers dnums;
+ dnums.set_output_batch_dimension(42);
+ instr->set_convolution_dimension_numbers(dnums);
+ auto clone = instr->Clone();
+ EXPECT_TRUE(protobuf_util::ProtobufEquals(
+ clone->convolution_dimension_numbers(), dnums))
+ << clone->convolution_dimension_numbers().DebugString();
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index 7d706b5fd0..f6fa45a6b7 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -247,6 +247,13 @@ tensorflow::gtl::optional<int64> HloModuleGroupMetadata::GetInstructionDevice(
return device;
}
+int64 HloModuleGroupMetadata::GetDeviceModulesCount() const {
+ return std::count_if(modules_.begin(), modules_.end(),
+ [](const HloModule* module) {
+ return !module->config().is_host_module();
+ });
+}
+
Status HloModuleGroupMetadata::RecordInstructions() {
const auto visitor = [this](HloInstruction* hlo) -> Status {
if (hlo->opcode() == HloOpcode::kWhile) {
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index 5f5bf27479..f68d4028dc 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -155,6 +155,9 @@ class HloModuleGroupMetadata {
tensorflow::gtl::optional<int64> GetInstructionDevice(
const HloInstruction& instruction) const;
+ // Returns the number of modules for devices (excluding the host module).
+ int64 GetDeviceModulesCount() const;
+
// Returns the companion instructions for the given instruction.
//
// Precondition: IsCompanionWhile(instruction) is true.
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 0fa4061738..375c4a6780 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -24,14 +24,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
-#include "tensorflow/compiler/xla/service/computation_tracker.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
-#include "tensorflow/compiler/xla/service/user_computation.h"
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -124,75 +122,17 @@ ExecutionOptions CreateExecutionOptions(
LayoutUtil::SetToDefaultLayout(
execution_options.mutable_shape_with_output_layout());
}
- return execution_options;
-}
-
-} // namespace
-
-StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
- const ComputationHandle& computation,
- const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
- const ExecutableBuildOptions& build_options) {
- TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
- computation_tracker_.Resolve(computation));
- VersionedComputationHandle versioned_handle =
- user_computation->GetVersionedHandle();
-
- TF_ASSIGN_OR_RETURN(
- std::shared_ptr<const ProgramShape> program_shape,
- user_computation->ComputeProgramShape(versioned_handle.version));
- // Validate incoming layouts.
- if (argument_layouts.size() != program_shape->parameters_size()) {
- return InvalidArgument(
- "Invalid number of arguments for computation: expected %d, got %zu.",
- program_shape->parameters_size(), argument_layouts.size());
+ for (const std::string& disabled_pass : build_options.disabled_hlo_passes()) {
+ execution_options.mutable_debug_options()->add_xla_disable_hlo_passes(
+ disabled_pass);
}
- for (int i = 0; i < argument_layouts.size(); ++i) {
- const Shape& argument_shape = *argument_layouts[i];
- TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape));
- if (!ShapeUtil::Compatible(argument_shape, program_shape->parameters(i))) {
- tensorflow::gtl::optional<const OpMetadata*> metadata =
- user_computation->ParameterMetadata(i);
- auto metadata_string = [&metadata]() -> string {
- if (!metadata.has_value()) {
- return "";
- }
- CHECK(metadata.value() != nullptr);
- const OpMetadata& m = *metadata.value();
- if (!m.source_file().empty()) {
- return tensorflow::strings::Printf(
- " (%s:%d)", m.source_file().c_str(), m.source_line());
- }
- return "";
- };
- return InvalidArgument(
- "Invalid argument shape for argument %d%s, expected %s, got %s.", i,
- metadata_string().c_str(),
- ShapeUtil::HumanString(program_shape->parameters(i)).c_str(),
- ShapeUtil::HumanString(argument_shape).c_str());
- }
- }
- if (build_options.result_layout() != nullptr) {
- TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(
- *build_options.result_layout(), program_shape->result()));
- }
-
- ExecutionOptions execution_options =
- CreateExecutionOptions(build_options, program_shape.get());
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(*program_shape, argument_layouts,
- &execution_options, user_computation));
- TF_ASSIGN_OR_RETURN(
- se::StreamExecutor * executor,
- execute_backend_->stream_executor(build_options.device_ordinal()));
-
- return BuildExecutable(versioned_handle, std::move(module_config),
- execute_backend_.get(), executor,
- build_options.device_allocator());
+ return execution_options;
}
+} // namespace
+
StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
const XlaComputation& computation,
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
@@ -260,4 +200,15 @@ StatusOr<int> LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) {
/*computation_count=*/1);
}
+StatusOr<const ShapedBuffer*> LocalService::GlobalDataToShapedBuffer(
+ const GlobalDataHandle& data, int replica_number) {
+ TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data));
+ if (replica_number >= buffers.size()) {
+ return InvalidArgument(
+ "replica_number %d out of range; must be less than num_replicas = %zu.",
+ replica_number, buffers.size());
+ }
+ return buffers[replica_number];
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h
index 06567cabd6..39d6734c3f 100644
--- a/tensorflow/compiler/xla/service/local_service.h
+++ b/tensorflow/compiler/xla/service/local_service.h
@@ -41,23 +41,11 @@ class LocalService : public Service {
static StatusOr<std::unique_ptr<LocalService>> NewService(
const ServiceOptions& options);
- // Builds an Executable with the given argument layouts and options. If
- // result_layout is non-null, then the executable is compiled to produce a
- // result of the given layout. If device_allocator is non-null, then the
- // compiler may use it to allocate temp space on the device. The compiler is
- // responsible for freeing any memory it allocates this way.
- StatusOr<std::unique_ptr<Executable>> CompileExecutable(
- const ComputationHandle& computation,
- const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
- const ExecutableBuildOptions& options);
-
// Builds an Executable with the given XlaComputation, argument layouts and
// options. If result_layout is non-null, then the executable is compiled to
// produce a result of the given layout. If device_allocator is non-null,
// then the compiler may use it to allocate temp space on the device. The
// compiler is responsible for freeing any memory it allocates this way.
- //
- // TODO(b/74197823): This is a part of a NOT YET ready refactor.
StatusOr<std::unique_ptr<Executable>> CompileExecutable(
const XlaComputation& computation,
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
@@ -70,6 +58,11 @@ class LocalService : public Service {
// the "easy" case where a single replica is a single device.
StatusOr<int> ReplicaNumberToDeviceOrdinal(int replica_number);
+ // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
+ // as long as the handle is valid.
+ StatusOr<const ShapedBuffer*> GlobalDataToShapedBuffer(
+ const GlobalDataHandle& data, int replica_number);
+
private:
explicit LocalService(const ServiceOptions& options,
std::unique_ptr<Backend> backend);
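A hedged usage sketch for the new accessor; service and data_handle are assumed caller state:

    TF_ASSIGN_OR_RETURN(
        const ShapedBuffer* buffer,
        service->GlobalDataToShapedBuffer(data_handle, /*replica_number=*/0));
    // `buffer` remains valid only while `data_handle` stays registered.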
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index cb0f76ebe4..82be6bcf4f 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -195,20 +195,6 @@ Service::Service(const ServiceOptions& options,
}
}
-Status Service::Computation(const ComputationRequest* arg,
- ComputationResponse* result) {
- if (arg->name().empty()) {
- return InvalidArgument("computation request needs a name");
- }
-
- *result->mutable_computation() =
- computation_tracker_.NewComputation(arg->name());
- VLOG(1) << Printf("Created new computation %s on service %p, name %s",
- result->computation().ShortDebugString().c_str(), this,
- arg->name().c_str());
- return Status::OK();
-}
-
Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg,
CreateChannelHandleResponse* result) {
*result->mutable_channel() = channel_tracker_.NewChannel();
@@ -288,8 +274,7 @@ Service::ResolveAndValidateArguments(
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
- const ExecutionOptions* execution_options,
- const UserComputation* user_computation) {
+ const ExecutionOptions* execution_options) {
auto config = MakeUnique<HloModuleConfig>(program_shape);
ComputationLayout* host_computation_layout =
config->mutable_host_entry_computation_layout();
@@ -305,17 +290,9 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
// ProgramShape.
if (!ShapeUtil::Compatible(*argument_shapes[i],
program_shape.parameters(i))) {
- if (user_computation == nullptr) {
- return InvalidArgument(
- "Argument does not match shape of computation parameter %d: want "
- "%s, got %s",
- i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
- ShapeUtil::HumanString(*argument_shapes[i]).c_str());
- }
- return InvalidParameterArgument(
- *user_computation->ParameterMetadata(i).value(),
- "Argument does not match shape of computation parameter %d: want %s, "
- "got %s",
+ return InvalidArgument(
+ "Argument does not match shape of computation parameter %d: want "
+ "%s, got %s",
i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
ShapeUtil::HumanString(*argument_shapes[i]).c_str());
}
@@ -366,76 +343,12 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- const ExecutionOptions& execution_options,
- const UserComputation* user_computation) {
+ const ExecutionOptions& execution_options) {
std::vector<const Shape*> argument_shapes;
for (const auto* arg : arguments) {
argument_shapes.push_back(&arg->on_host_shape());
}
- return CreateModuleConfig(program_shape, argument_shapes, &execution_options,
- user_computation);
-}
-
-StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
- std::vector<VersionedComputationHandle> versioned_handles,
- std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
- Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
- DeviceMemoryAllocator* device_allocator) {
- VLOG(1) << Printf("BuildExecutable on service %p", this);
-
- // Dump computation proto state if flag is set.
- std::vector<std::unique_ptr<SessionModule>> session_modules;
- for (int64 i = 0; i < versioned_handles.size(); ++i) {
- const string& directory_path =
- module_configs[i]->debug_options().xla_dump_computations_to();
- const string& other_directory_path =
- module_configs[i]->debug_options().xla_dump_executions_to();
- if (directory_path.empty() && other_directory_path.empty()) {
- continue;
- }
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<SessionModule> session_module,
- computation_tracker_.SnapshotComputation(versioned_handles[i].handle));
- if (!directory_path.empty()) {
- string filename = Printf("computation_%lld__%s__version_%lld",
- versioned_handles[i].handle.handle(),
- session_module->entry().name().c_str(),
- versioned_handles[i].version);
- TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename,
- *session_module));
- session_modules.push_back(std::move(session_module));
- }
- }
-
- VLOG(1) << "Computation handles:";
- for (const VersionedComputationHandle& versioned_handle : versioned_handles) {
- VLOG(1) << versioned_handle;
- }
-
- CHECK_EQ(versioned_handles.size(), module_configs.size());
- std::vector<std::unique_ptr<HloModule>> modules;
- for (int64 i = 0; i < versioned_handles.size(); ++i) {
- const VersionedComputationHandle& versioned_handle = versioned_handles[i];
- const HloModuleConfig& config = *module_configs[i];
- TF_ASSIGN_OR_RETURN(auto module,
- computation_tracker_.BuildHloModule(
- versioned_handle, config,
- /*include_unreachable_instructions=*/true));
- modules.push_back(std::move(module));
- }
-
- TF_ASSIGN_OR_RETURN(
- std::vector<std::unique_ptr<Executable>> executables,
- backend->compiler()->Compile(std::move(modules), std::move(executors),
- device_allocator));
-
- for (size_t i = 0; i < versioned_handles.size(); ++i) {
- if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) {
- executables[i]->set_session_module(std::move(session_modules[i]));
- }
- }
-
- return std::move(executables);
+ return CreateModuleConfig(program_shape, argument_shapes, &execution_options);
}
StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
@@ -512,98 +425,6 @@ Status Service::ValidateEntryComputationLayout(HloModule* module) {
return Status::OK();
}
-StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
- const VersionedComputationHandle& versioned_handle,
- std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
- se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) {
- VLOG(1) << Printf("BuildExecutable on service %p with handle %s", this,
- versioned_handle.ToString().c_str());
-
- // Dump computation proto state if flag is set.
- std::unique_ptr<SessionModule> session_module;
- const string& directory_path =
- module_config->debug_options().xla_dump_computations_to();
- const string& other_directory_path =
- module_config->debug_options().xla_dump_executions_to();
- if (!directory_path.empty() || !other_directory_path.empty()) {
- TF_ASSIGN_OR_RETURN(
- session_module,
- computation_tracker_.SnapshotComputation(versioned_handle.handle));
- if (!directory_path.empty()) {
- string filename = Printf("computation_%lld__%s__version_%lld",
- versioned_handle.handle.handle(),
- session_module->entry().name().c_str(),
- versioned_handle.version);
- TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename,
- *session_module));
- }
- }
-
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<HloModule> module,
- computation_tracker_.BuildHloModule(versioned_handle, *module_config,
- /*include_unreachable_instructions=*/
- true));
-
- TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module));
-
- TF_ASSIGN_OR_RETURN(
- module, backend->compiler()->RunHloPasses(std::move(module), executor,
- device_allocator));
- // Check that on-host and on-device shapes are consistent.
- TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get()));
-
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
- backend->compiler()->RunBackend(
- std::move(module), executor, device_allocator));
-
- if (!other_directory_path.empty()) {
- executable->set_session_module(std::move(session_module));
- }
-
- return std::move(executable);
-}
-
-StatusOr<std::shared_ptr<Executable>> Service::BuildAndCacheExecutable(
- const VersionedComputationHandle& versioned_handle,
- std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
- se::StreamExecutor* executor, ExecutionProfile* profile,
- DeviceMemoryAllocator* device_allocator) {
- std::shared_ptr<Executable> executable =
- compilation_cache_.LookUp(versioned_handle, *module_config);
-
- if (executable != nullptr) {
- // Executable found in the computation cache.
- if (profile != nullptr) {
- profile->set_compilation_cache_hit(true);
- }
- return executable;
- }
-
- uint64 start_micros =
- // Avoid reading the clock if we don't want timing info.
- (profile != nullptr) ? tensorflow::Env::Default()->NowMicros() : 0;
-
- // Take a copy of the module config, as compilation introduces layouts where
- // layouts were optional before.
- HloModuleConfig original_module_config = *module_config;
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Executable> executable_unique_ptr,
- BuildExecutable(versioned_handle, std::move(module_config), backend,
- executor, device_allocator));
-
- if (profile != nullptr) {
- uint64 end_micros = tensorflow::Env::Default()->NowMicros();
- uint64 milliseconds = (end_micros - start_micros) / 1000;
- profile->set_compilation_cache_hit(false);
- profile->set_compile_time_ms(milliseconds);
- }
-
- // Insert executable into the cache.
- return compilation_cache_.Insert(std::move(executable_unique_ptr),
- original_module_config);
-}
-
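
The deleted BuildAndCacheExecutable above is a plain look-up-or-build cache.
Condensed to its core, using only the calls visible in the removed body (a
sketch, not a drop-in replacement):

    std::shared_ptr<Executable> executable =
        compilation_cache_.LookUp(versioned_handle, *module_config);
    if (executable != nullptr) {
      return executable;  // Cache hit; skip compilation entirely.
    }
    // Cache on the pre-compilation config: compilation fills in layouts that
    // were optional before, which would otherwise change the lookup key.
    HloModuleConfig original_module_config = *module_config;
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<Executable> built,
        BuildExecutable(versioned_handle, std::move(module_config), backend,
                        executor, device_allocator));
    return compilation_cache_.Insert(std::move(built), original_module_config);
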
StatusOr<std::vector<GlobalDataHandle>>
Service::ExecuteParallelAndRegisterResult(
tensorflow::gtl::ArraySlice<Executable*> executables,
@@ -624,9 +445,16 @@ Service::ExecuteParallelAndRegisterResult(
// profiled.
std::map<int64, se::Stream*> index_to_profiled_streams;
- TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
- backend->computation_placer()->AssignDevices(
- options_.number_of_replicas(), executables.size()));
+ // Build DeviceAssignment for all cores based on the provided device handles.
+ DeviceAssignment device_assignment(options_.number_of_replicas(),
+ executables.size());
+ for (int64 i = 0; i < executables.size(); i++) {
+ TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
+ CHECK_EQ(replicas.size(), arguments[i].size());
+ for (int64 replica = 0; replica < replicas.size(); ++replica) {
+ device_assignment(replica, i) = replicas[replica]->device_ordinal();
+ }
+ }
for (int64 i = 0; i < executables.size(); i++) {
// Stream executors for the replicas of the current computation.
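
The DeviceAssignment built above is a (replica, computation) matrix of device
ordinals. A hypothetical sketch of how a consumer reads it back:

    for (int64 computation = 0; computation < executables.size();
         ++computation) {
      for (int64 replica = 0; replica < options_.number_of_replicas();
           ++replica) {
        // Each cell names the device ordinal this replica of this
        // computation was placed on.
        int device_ordinal = device_assignment(replica, computation);
        VLOG(2) << "replica " << replica << " of computation " << computation
                << " runs on device " << device_ordinal;
      }
    }
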
@@ -799,13 +627,6 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
result_tag);
}
-Status Service::SetReturnValue(const SetReturnValueRequest* arg,
- SetReturnValueResponse* results) {
- TF_ASSIGN_OR_RETURN(UserComputation * computation,
- computation_tracker_.Resolve(arg->computation()));
- return computation->SetReturnValue(arg->operand());
-}
-
StatusOr<std::vector<se::StreamExecutor*>> Service::GetExecutors(
const ExecutionOptions& execution_options, int64 requests_size,
int64 request_index) const {
@@ -847,117 +668,6 @@ StatusOr<std::vector<std::vector<const ShapedBuffer*>>> Service::GetArguments(
return replicated_arguments;
}
-Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
- ExecuteParallelResponse* result) {
- VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString();
-
- std::vector<std::vector<std::vector<const ShapedBuffer*>>> all_arguments;
- std::vector<std::vector<se::StreamExecutor*>> all_executors;
- std::vector<VersionedComputationHandle> versioned_handles;
- std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
- std::vector<string> computation_names;
- std::vector<DeviceHandle> device_handles;
-
- int num_requested_devices =
- std::accumulate(arg->requests().begin(), arg->requests().end(), 0,
- [](int a, const ExecuteRequest& r) -> int {
- return a + r.execution_options().device_handles_size();
- });
- if (num_requested_devices * options_.number_of_replicas() >
- execute_backend_->device_count()) {
- return FailedPrecondition(
- "there are not enough stream executors to execute %d computations",
- num_requested_devices);
- }
-
- for (int64 i = 0; i < arg->requests_size(); ++i) {
- // Get the stream executor for the i'th computation. This stream executor
- // is one of the executors used to run the replicated computation.
- const ExecutionOptions& execution_options =
- arg->requests(i).execution_options();
-
- // Get the executors.
- TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options,
- arg->requests_size(), i));
-
- // Resolve the UserComputation object associated with the requested
- // computation and compute the program shape.
- const ExecuteRequest& request = arg->requests(i);
- TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
- computation_tracker_.Resolve(request.computation()));
- VersionedComputationHandle versioned_handle =
- user_computation->GetVersionedHandle();
- if (user_computation->request_count(versioned_handle.version) == 0) {
- return InvalidArgument("computations may not be empty");
- }
-
- TF_ASSIGN_OR_RETURN(
- std::shared_ptr<const ProgramShape> program_shape,
- user_computation->ComputeProgramShape(versioned_handle.version));
-
- // Get the replicated arguments.
- TF_ASSIGN_OR_RETURN(auto replicated_arguments,
- GetArguments(execution_options, request.arguments()));
-
- // Create an HloModuleConfig object for the computation, given the shape of
- // the program and the argument allocations. Here, we care only about the
- // shapes of the arguments, so it is sufficient to use the arguments of
- // replica 0.
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(*program_shape, replicated_arguments.front(),
- request.execution_options(), user_computation));
- VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: "
- << module_config->host_entry_computation_layout().ToString();
-
- // Add to the vectors used to build and execute the computations after the loop.
- all_arguments.push_back(replicated_arguments);
- all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}});
- versioned_handles.push_back(versioned_handle);
- module_configs.push_back(std::move(module_config));
- computation_names.insert(computation_names.end(), executors.size(),
- user_computation->name());
- all_executors.push_back(executors);
- device_handles.insert(device_handles.end(),
- execution_options.device_handles().begin(),
- execution_options.device_handles().end());
- }
-
- // Build the user computations into HloModules and compile to generate the
- // executables.
- //
- // TODO(jlebar): There's currently no way to pass a device allocator to
- // ExecuteParallel, so we have to pass a null device_allocator below.
- TF_ASSIGN_OR_RETURN(
- std::vector<std::unique_ptr<Executable>> executables,
- BuildExecutables(versioned_handles, std::move(module_configs),
- execute_backend_.get(), all_executors,
- /*device_allocator=*/nullptr));
- std::vector<Executable*> executable_ptrs;
- executable_ptrs.reserve(executables.size());
- for (const auto& executable : executables) {
- executable_ptrs.push_back(executable.get());
- }
-
- // Execute the generated executables in parallel and return the device
- // handles for each computation's output.
- ExecutionProfile profile;
- TF_ASSIGN_OR_RETURN(
- std::vector<GlobalDataHandle> outputs,
- ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments,
- execute_backend_.get(), device_handles,
- computation_names, &profile));
- for (const GlobalDataHandle& output : outputs) {
- ExecuteResponse response;
- *response.mutable_output() = output;
- *response.mutable_profile() = profile;
- *result->add_responses() = response;
- }
-
- VLOG(1) << "successfully completed 'execute-parallel' request";
- return Status::OK();
-}
-
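
The device-count guard near the top of the removed ExecuteParallel is simple
arithmetic: requests naming three device handles with two replicas each need
six executors. A standalone sketch of the same check, with all values
hypothetical:

    #include <numeric>
    #include <vector>

    int main() {
      // Device handles requested by each ExecuteRequest (hypothetical).
      std::vector<int> handles_per_request = {1, 2};
      int num_requested_devices = std::accumulate(
          handles_per_request.begin(), handles_per_request.end(), 0);
      int number_of_replicas = 2;  // hypothetical
      int device_count = 4;        // hypothetical backend capacity
      // 3 * 2 > 4, so the service would return FailedPrecondition here.
      return num_requested_devices * number_of_replicas > device_count;
    }
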
Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
ExecuteParallelResponse* result) {
VLOG(1) << "running execute-graph-parallel request";
@@ -1007,8 +717,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(request.computation().program_shape(),
replicated_arguments.front(),
- request.execution_options(),
- /*user_computation=*/nullptr));
+ request.execution_options()));
VLOG(3)
<< "ExecuteGraphParallel created HloModuleConfig computation layout: "
<< module_config->host_entry_computation_layout().ToString();
@@ -1083,15 +792,6 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
return Status::OK();
}
-Status Service::ExecuteOneToN(const ExecuteRequest* arg,
- ExecuteResponse* result) {
- ExecuteParallelRequest parallel_arg;
- *parallel_arg.add_requests() = *arg;
- ExecuteParallelResponse parallel_result;
- TF_RETURN_IF_ERROR(ExecuteParallel(&parallel_arg, &parallel_result));
- return PickParallelResponse(parallel_result, result);
-}
-
Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg,
ExecuteResponse* result) {
ExecuteGraphParallelRequest parallel_arg;
@@ -1124,80 +824,6 @@ Status Service::PickParallelResponse(
return Status::OK();
}
-Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) {
- VLOG(1) << "running execute request: " << arg->ShortDebugString();
-
- TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
- computation_tracker_.Resolve(arg->computation()));
-
- VersionedComputationHandle versioned_handle =
- user_computation->GetVersionedHandle();
-
- if (user_computation->request_count(versioned_handle.version) == 0) {
- return InvalidArgument("computations may not be empty");
- }
-
- // If we received multiple device handles, we must partition the module.
- if (arg->execution_options().device_handles_size() > 1) {
- return ExecuteOneToN(arg, result);
- }
-
- TF_ASSIGN_OR_RETURN(
- std::shared_ptr<const ProgramShape> program_shape,
- user_computation->ComputeProgramShape(versioned_handle.version));
-
- TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
- SingleComputationDeviceHandle()));
- TF_ASSIGN_OR_RETURN(
- std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
- ResolveAndValidateArguments(arg->arguments(), replicas));
-
- // Since we care only about the shapes of the arguments, it is sufficient to
- // use the arguments of replica 0.
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(*program_shape, replicated_arguments.front(),
- arg->execution_options(), user_computation));
-
- VLOG(3) << "Execute created HloModuleConfig computation layout: "
- << module_config->host_entry_computation_layout().ToString();
-
- TF_ASSIGN_OR_RETURN(
- std::shared_ptr<Executable> executable,
- BuildAndCacheExecutable(versioned_handle, std::move(module_config),
- execute_backend_.get(),
- execute_backend_->default_stream_executor(),
- result->mutable_profile()));
-
- if (executable->dumping()) {
- executable->session_module()->set_execution_platform(
- execute_backend_->platform()->Name());
- TF_RETURN_IF_ERROR(RecordArguments(
- replicated_arguments.front(),
- execute_backend_->default_stream_executor(),
- execute_backend_->transfer_manager(), executable->session_module()));
- }
-
- TF_ASSIGN_OR_RETURN(
- *result->mutable_output(),
- ExecuteAndRegisterResult(
- executable.get(), replicated_arguments, execute_backend_.get(),
- "result of " + user_computation->name(), result->mutable_profile()));
-
- if (executable->dumping()) {
- TF_ASSIGN_OR_RETURN(
- const ShapedBuffer* result_buffer,
- allocation_tracker_.ResolveForReplica(result->output(), 0));
- TF_RETURN_IF_ERROR(RecordResult(
- *result_buffer, execute_backend_->default_stream_executor(),
- execute_backend_->transfer_manager(), executable->session_module()));
- TF_RETURN_IF_ERROR(executable->DumpSessionModule());
- }
-
- VLOG(1) << "successfully completed 'execute' request";
- return Status::OK();
-}
-
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
const HloModuleProto& module_proto,
std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
@@ -1303,86 +929,6 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
return Status::OK();
}
-Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
- ExecuteAsyncResponse* result) {
- VLOG(1) << "running execute-async request: " << arg->ShortDebugString();
-
- TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
- computation_tracker_.Resolve(arg->computation()));
-
- VersionedComputationHandle versioned_handle =
- user_computation->GetVersionedHandle();
- if (user_computation->request_count(versioned_handle.version) == 0) {
- return InvalidArgument("computations may not be empty");
- }
-
- TF_ASSIGN_OR_RETURN(
- std::shared_ptr<const ProgramShape> program_shape,
- user_computation->ComputeProgramShape(versioned_handle.version));
-
- TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
- SingleComputationDeviceHandle()));
- TF_RET_CHECK(!replicas.empty());
- TF_ASSIGN_OR_RETURN(
- std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
- ResolveAndValidateArguments(arg->arguments(), replicas));
-
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(*program_shape, replicated_arguments.front(),
- arg->execution_options(), user_computation));
-
- VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: "
- << module_config->host_entry_computation_layout().ToString();
-
- ExecutionProfile profile;
-
- TF_ASSIGN_OR_RETURN(
- std::shared_ptr<Executable> executable,
- BuildAndCacheExecutable(
- versioned_handle, std::move(module_config), execute_backend_.get(),
- execute_backend_->default_stream_executor(), &profile));
-
- // Set up streams.
- std::vector<Pool<se::Stream>::SmartPtr> streams;
- for (se::StreamExecutor* executor : replicas) {
- TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
- execute_backend_->BorrowStream(executor));
- streams.push_back(std::move(stream));
- }
-
- std::vector<ScopedShapedBuffer> result_buffers;
- for (size_t i = 0; i < streams.size(); ++i) {
- const auto& stream = streams[i];
- ExecutableRunOptions options;
- options.set_stream(stream.get());
- options.set_allocator(execute_backend_->memory_allocator());
- options.set_intra_op_thread_pool(
- execute_backend_->eigen_intra_op_thread_pool_device());
-
- ServiceExecutableRunOptions service_options(
- options, execute_backend_->StreamBorrower());
-
- TF_ASSIGN_OR_RETURN(ScopedShapedBuffer this_result_buffer,
- executable->ExecuteAsyncOnStream(
- &service_options, replicated_arguments[i]));
-
- result_buffers.emplace_back(std::move(this_result_buffer));
- }
-
- TF_ASSIGN_OR_RETURN(
- GlobalDataHandle output,
- allocation_tracker_.RegisterReplicatedBuffers(
- std::move(result_buffers), "result of " + user_computation->name()));
-
- *result->mutable_execution() = execution_tracker_.Register(
- execute_backend_.get(), std::move(streams), profile, output);
- streams.clear();
-
- VLOG(1) << "successfully completed 'execute-async' request";
- return Status::OK();
-}
-
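
Stream setup in the removed ExecuteAsync follows a borrow-per-replica
pattern. Condensed from the body above, using only the calls shown there:

    std::vector<Pool<se::Stream>::SmartPtr> streams;
    for (se::StreamExecutor* executor : replicas) {
      // One borrowed stream per replica, returned to the backend's pool
      // when the SmartPtr is destroyed.
      TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
                          execute_backend_->BorrowStream(executor));
      streams.push_back(std::move(stream));
    }
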
Status Service::WaitForExecution(const WaitForExecutionRequest* arg,
WaitForExecutionResponse* result) {
TF_ASSIGN_OR_RETURN(const auto execution,
@@ -1549,117 +1095,6 @@ Status Service::ResetDevice(const ResetDeviceRequest* arg,
return execute_backend_->ResetDevices();
}
-Status Service::IsConstant(const IsConstantRequest* arg,
- IsConstantResponse* result) {
- TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
- computation_tracker_.Resolve(arg->computation()));
-
- VersionedComputationHandle versioned_handle =
- user_computation->GetVersionedHandleAtOperation(arg->operand());
-
- if (user_computation->request_count(versioned_handle.version) == 0) {
- return InvalidArgument("computations may not be empty");
- }
-
- TF_ASSIGN_OR_RETURN(
- bool is_constant,
- user_computation->IsConstant(arg->operand(), arg->num_parameters()));
-
- result->set_is_constant(is_constant);
- return Status::OK();
-}
-
-Status Service::ComputeConstant(const ComputeConstantRequest* arg,
- ComputeConstantResponse* result) {
- TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
- computation_tracker_.Resolve(arg->computation()));
-
- VersionedComputationHandle versioned_handle =
- user_computation->GetVersionedHandleAtOperation(arg->operand());
-
- if (user_computation->request_count(versioned_handle.version) == 0) {
- return InvalidArgument("computations may not be empty");
- }
-
- TF_ASSIGN_OR_RETURN(
- bool is_constant,
- user_computation->IsConstant(arg->operand(), arg->parameters_size()));
- if (!is_constant) {
- StatusOr<const OperationRequest*> op_request_status =
- user_computation->LookUpRequestForErrorReporting(arg->operand());
- string op_request_string = "<unknown operation>";
- if (op_request_status.ok()) {
- op_request_string = op_request_status.ValueOrDie()->ShortDebugString();
- }
- return InvalidArgument(
- "Operand to ComputeConstant depends on a parameter.\n\n"
- " op requested for constant evaluation: %s\n\n"
- "This is an internal error that typically happens when the XLA user "
- "(e.g. TensorFlow) is attempting to determine a value that must be a "
- "compile-time constant (e.g. an array dimension) but it is not capable "
- "of being evaluated at XLA compile time.\n\n"
- "Please file a usability bug with the framework being used (e.g. "
- "TensorFlow).",
- op_request_string.c_str());
- }
-
- // We can't use ComputeProgramShape because it checks that all parameter
- // instructions are present and contiguous. Instead, construct ProgramShape
- // directly.
- ProgramShape program_shape;
- TF_ASSIGN_OR_RETURN(*program_shape.mutable_result(),
- user_computation->GetShape(arg->operand()));
-
- TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
-
- ExecutionOptions execution_options = xla::CreateDefaultExecutionOptions();
- execution_options.mutable_debug_options()->set_xla_enable_fast_math(false);
- execution_options.mutable_debug_options()
- ->set_xla_eliminate_hlo_implicit_broadcast(true);
- *execution_options.mutable_shape_with_output_layout() =
- program_shape.result();
-
- Shape shape_with_output_layout(program_shape.result());
- if (arg->has_output_layout()) {
- TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(
- arg->output_layout(), execution_options.shape_with_output_layout()));
- *execution_options.mutable_shape_with_output_layout()->mutable_layout() =
- arg->output_layout();
- }
-
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(program_shape, {}, execution_options,
- user_computation));
-
- // Exclude dead parameter instructions for the purpose of computing constants.
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<HloModule> module,
- computation_tracker_.BuildHloModule(versioned_handle, *module_config,
- /*include_unreachable_instructions=*/
- false));
-
- std::vector<std::unique_ptr<Literal>> parameters(arg->parameters_size());
- for (int64 i = 0; i < arg->parameters_size(); ++i) {
- TF_ASSIGN_OR_RETURN(parameters[i],
- Literal::CreateFromProto(arg->parameters(i)));
- }
- HloEvaluator evaluator;
- TF_ASSIGN_OR_RETURN(
- auto result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(*module, parameters));
-
- // Since the shape_with_output_layout option in ExecutionOptions has no
- // effect on the Evaluator's results, explicitly relayout here.
- //
- // TODO(b/77824332): Make HloEvaluator take care of the re-layout.
- if (arg->has_output_layout()) {
- result_literal = result_literal->Relayout(arg->output_layout());
- }
- *result->mutable_literal() = result_literal->ToProto();
-
- return Status::OK();
-}
-
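
The removed ComputeConstant reduces to: build an HloModule that skips
unreachable instructions, run HloEvaluator over it, then relayout on request.
The essential core, condensed from the body above with module and arg as
defined there:

    std::vector<std::unique_ptr<Literal>> parameters(arg->parameters_size());
    for (int64 i = 0; i < arg->parameters_size(); ++i) {
      TF_ASSIGN_OR_RETURN(parameters[i],
                          Literal::CreateFromProto(arg->parameters(i)));
    }
    HloEvaluator evaluator;
    TF_ASSIGN_OR_RETURN(auto result_literal,
                        evaluator.Evaluate<std::unique_ptr<Literal>>(
                            *module, parameters));
    if (arg->has_output_layout()) {
      // The evaluator ignores shape_with_output_layout, so relayout here.
      result_literal = result_literal->Relayout(arg->output_layout());
    }
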
Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
ComputeConstantResponse* result) {
if (!arg->has_computation()) {
@@ -1709,60 +1144,6 @@ Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) {
return Status::OK();
}
-Status Service::GetComputationShape(const GetComputationShapeRequest* arg,
- GetComputationShapeResponse* result) {
- TF_ASSIGN_OR_RETURN(UserComputation * computation,
- computation_tracker_.Resolve(arg->computation()));
-
- VersionedComputationHandle versioned_handle =
- computation->GetVersionedHandle();
-
- TF_ASSIGN_OR_RETURN(auto program_shape, computation->ComputeProgramShape(
- versioned_handle.version));
- *result->mutable_program_shape() = *program_shape;
- return Status::OK();
-}
-
-Status Service::GetLocalShape(const GetLocalShapeRequest* arg,
- GetLocalShapeResponse* result) {
- TF_ASSIGN_OR_RETURN(UserComputation * computation,
- computation_tracker_.Resolve(arg->computation()));
-
- TF_ASSIGN_OR_RETURN(*result->mutable_shape(),
- computation->GetShape(arg->operand()));
- return Status::OK();
-}
-
-Status Service::GetComputationStats(const ComputationStatsRequest* arg,
- ComputationStatsResponse* result) {
- TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
- computation_tracker_.Resolve(arg->computation()));
-
- VersionedComputationHandle versioned_handle =
- user_computation->GetVersionedHandle();
-
- HloModuleConfig config;
- config.set_debug_options(arg->debug_options());
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<HloModule> module,
- computation_tracker_.BuildHloModule(versioned_handle, config));
-
- hlo_graph_dumper::MaybeDumpHloModule(*module,
- "computation statistics subject");
-
- // Run HLO analysis to get the computation statistics.
- HloCostAnalysis analysis(
- execute_backend_->compiler()->ShapeSizeBytesFunction());
-
- TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis));
-
- ComputationStats stats;
- stats.set_flop_count(analysis.flop_count());
- stats.set_transcendental_count(analysis.transcendental_count());
- *result->mutable_stats() = stats;
- return Status::OK();
-}
-
Status Service::GetComputationGraphStats(
const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) {
if (!arg->has_computation()) {
@@ -1793,262 +1174,6 @@ Status Service::GetComputationGraphStats(
return Status::OK();
}
-template <typename RequestT, typename ResponseT>
-Status Service::AddInstruction(
- const RequestT* arg, ResponseT* result,
- const std::function<StatusOr<ComputationDataHandle>(UserComputation*)>&
- adder) {
- TF_ASSIGN_OR_RETURN(UserComputation * computation,
- computation_tracker_.Resolve(arg->computation()));
-
- TF_ASSIGN_OR_RETURN(*result->mutable_output(), adder(computation));
- return Status::OK();
-}
-
-Status Service::Op(const OpRequest* arg, OpResponse* result) {
- TF_ASSIGN_OR_RETURN(UserComputation * computation,
- computation_tracker_.Resolve(arg->computation()));
- StatusOr<ComputationDataHandle> handle_status;
-
- switch (arg->op_case()) {
- case OpRequest::kBatchNormTrainingRequest:
- handle_status = computation->AddBatchNormTrainingInstruction(
- arg->batch_norm_training_request());
- break;
- case OpRequest::kBatchNormInferenceRequest:
- handle_status = computation->AddBatchNormInferenceInstruction(
- arg->batch_norm_inference_request());
- break;
- case OpRequest::kBatchNormGradRequest:
- handle_status = computation->AddBatchNormGradInstruction(
- arg->batch_norm_grad_request());
- break;
- case OpRequest::kBinaryOpRequest:
- handle_status =
- computation->AddBinaryInstruction(arg->binary_op_request());
- break;
- case OpRequest::kBroadcastRequest:
- handle_status =
- computation->AddBroadcastInstruction(arg->broadcast_request());
- break;
- case OpRequest::kCallRequest: {
- TF_ASSIGN_OR_RETURN(
- UserComputation * to_apply,
- computation_tracker_.Resolve(arg->call_request().to_apply()));
- handle_status =
- computation->AddCallInstruction(arg->call_request(), *to_apply);
- break;
- }
- case OpRequest::kConcatenateRequest:
- handle_status =
- computation->AddConcatenateInstruction(arg->concatenate_request());
- break;
- case OpRequest::kConditionalRequest: {
- TF_ASSIGN_OR_RETURN(UserComputation * true_computation,
- computation_tracker_.Resolve(
- arg->conditional_request().true_computation()));
- TF_ASSIGN_OR_RETURN(UserComputation * false_computation,
- computation_tracker_.Resolve(
- arg->conditional_request().false_computation()));
- handle_status = computation->AddConditionalInstruction(
- arg->conditional_request(), *true_computation, *false_computation);
- break;
- }
- case OpRequest::kConstantRequest:
- handle_status =
- computation->AddConstantInstruction(arg->constant_request());
- break;
- case OpRequest::kConvertRequest:
- handle_status =
- computation->AddConvertInstruction(arg->convert_request());
- break;
- case OpRequest::kBitcastConvertRequest:
- handle_status = computation->AddBitcastConvertInstruction(
- arg->bitcast_convert_request());
- break;
- case OpRequest::kConvolveRequest:
- handle_status =
- computation->AddConvolveInstruction(arg->convolve_request());
- break;
- case OpRequest::kCrossReplicaSumRequest:
- handle_status = computation->AddCrossReplicaSumInstruction(
- arg->cross_replica_sum_request());
- break;
- case OpRequest::kCustomCallRequest:
- handle_status =
- computation->AddCustomCallInstruction(arg->custom_call_request());
- break;
- case OpRequest::kDotRequest:
- handle_status = computation->AddDotInstruction(arg->dot_request());
- break;
- case OpRequest::kDynamicSliceRequest:
- handle_status =
- computation->AddDynamicSliceInstruction(arg->dynamic_slice_request());
- break;
- case OpRequest::kDynamicUpdateSliceRequest:
- handle_status = computation->AddDynamicUpdateSliceInstruction(
- arg->dynamic_update_slice_request());
- break;
- case OpRequest::kFftRequest:
- handle_status = computation->AddFftInstruction(arg->fft_request());
- break;
- case OpRequest::kGatherRequest:
- handle_status = computation->AddGatherInstruction(arg->gather_request());
- break;
- case OpRequest::kGetTupleElementRequest:
- handle_status = computation->AddGetTupleElementInstruction(
- arg->get_tuple_element_request());
- break;
- case OpRequest::kInfeedRequest:
- handle_status = computation->AddInfeedInstruction(arg->infeed_request());
- break;
- case OpRequest::kOutfeedRequest:
- handle_status =
- computation->AddOutfeedInstruction(arg->outfeed_request());
- break;
- case OpRequest::kHostComputeRequest:
- handle_status =
- computation->AddHostComputeInstruction(arg->host_compute_request());
- break;
- case OpRequest::kMapRequest: {
- TF_ASSIGN_OR_RETURN(
- UserComputation * to_apply,
- computation_tracker_.Resolve(arg->map_request().to_apply()));
- handle_status =
- computation->AddMapInstruction(arg->map_request(), *to_apply);
- break;
- }
- case OpRequest::kPadRequest:
- handle_status = computation->AddPadInstruction(arg->pad_request());
- break;
- case OpRequest::kParameterRequest:
- handle_status =
- computation->AddParameterInstruction(arg->parameter_request());
- break;
- case OpRequest::kReduceRequest: {
- TF_ASSIGN_OR_RETURN(
- UserComputation * to_apply,
- computation_tracker_.Resolve(arg->reduce_request().to_apply()));
- handle_status =
- computation->AddReduceInstruction(arg->reduce_request(), *to_apply);
- break;
- }
- case OpRequest::kReducePrecisionRequest: {
- handle_status = computation->AddReducePrecisionInstruction(
- arg->reduce_precision_request());
- break;
- }
- case OpRequest::kReduceWindowRequest: {
- TF_ASSIGN_OR_RETURN(UserComputation * to_apply,
- computation_tracker_.Resolve(
- arg->reduce_window_request().to_apply()));
- handle_status = computation->AddReduceWindowInstruction(
- arg->reduce_window_request(), *to_apply);
- break;
- }
- case OpRequest::kReshapeRequest:
- handle_status =
- computation->AddReshapeInstruction(arg->reshape_request());
- break;
- case OpRequest::kReverseRequest:
- handle_status =
- computation->AddReverseInstruction(arg->reverse_request());
- break;
- case OpRequest::kRngRequest:
- handle_status = computation->AddRngInstruction(arg->rng_request());
- break;
- case OpRequest::kSelectAndScatterRequest: {
- TF_ASSIGN_OR_RETURN(UserComputation * select,
- computation_tracker_.Resolve(
- arg->select_and_scatter_request().select()));
- TF_ASSIGN_OR_RETURN(UserComputation * scatter,
- computation_tracker_.Resolve(
- arg->select_and_scatter_request().scatter()));
- handle_status = computation->AddSelectAndScatterInstruction(
- arg->select_and_scatter_request(), *select, *scatter);
- break;
- }
- case OpRequest::kSliceRequest:
- handle_status = computation->AddSliceInstruction(arg->slice_request());
- break;
- case OpRequest::kTernaryOpRequest:
- handle_status =
- computation->AddTernaryInstruction(arg->ternary_op_request());
- break;
- case OpRequest::kTraceRequest:
- return computation->AddTraceInstruction(arg->trace_request());
- case OpRequest::kTransposeRequest:
- handle_status =
- computation->AddTransposeInstruction(arg->transpose_request());
- break;
- case OpRequest::kUnaryOpRequest:
- handle_status = computation->AddUnaryInstruction(arg->unary_op_request());
- break;
- case OpRequest::kVariadicOpRequest:
- handle_status =
- computation->AddVariadicInstruction(arg->variadic_op_request());
- break;
- case OpRequest::kWhileRequest: {
- TF_ASSIGN_OR_RETURN(
- UserComputation * condition,
- computation_tracker_.Resolve(arg->while_request().condition()));
- TF_ASSIGN_OR_RETURN(
- UserComputation * body,
- computation_tracker_.Resolve(arg->while_request().body()));
- handle_status = computation->AddWhileInstruction(arg->while_request(),
- *condition, *body);
- break;
- }
- case OpRequest::kSendRequest: {
- TF_RETURN_IF_ERROR(
- channel_tracker_.RegisterSend(arg->send_request().channel_handle()));
- // Send does not return a value, but we need a handle to be able to
- // set OpMetadata and OpSharding (device assignment).
- handle_status = computation->AddSendInstruction(arg->send_request());
- break;
- }
- case OpRequest::kRecvRequest: {
- TF_RETURN_IF_ERROR(
- channel_tracker_.RegisterRecv(arg->recv_request().channel_handle()));
- handle_status = computation->AddRecvInstruction(arg->recv_request());
- break;
- }
- case OpRequest::OP_NOT_SET:
- return InvalidArgument("XLA service received OpRequest with OP_NOT_SET");
- default:
- return InvalidArgument("Unsupported operation in XLA service");
- }
- TF_ASSIGN_OR_RETURN(*result->mutable_output(), handle_status);
-
- // We set the debug metadata here, because we slice off part of the OpRequest
- // proto in the above switch statement.
- TF_ASSIGN_OR_RETURN(ComputationDataHandle handle, handle_status);
- TF_RETURN_IF_ERROR(computation->SetOpMetadata(handle, arg->metadata()));
- if (arg->has_sharding()) {
- TF_RETURN_IF_ERROR(computation->SetOpSharding(handle, arg->sharding()));
- }
- return Status::OK();
-}
-
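
The removed Op handler is a single dispatch on op_case() followed by metadata
bookkeeping on whichever handle came back. A condensed sketch of that shape,
with one representative case:

    StatusOr<ComputationDataHandle> handle_status;
    switch (arg->op_case()) {
      case OpRequest::kPadRequest:
        handle_status = computation->AddPadInstruction(arg->pad_request());
        break;
      // ... one case per op kind, as in the removed body ...
      case OpRequest::OP_NOT_SET:
        return InvalidArgument(
            "XLA service received OpRequest with OP_NOT_SET");
      default:
        return InvalidArgument("Unsupported operation in XLA service");
    }
    // Metadata is applied after the switch because each case consumed only
    // its own slice of the OpRequest proto.
    TF_ASSIGN_OR_RETURN(ComputationDataHandle handle, handle_status);
    TF_RETURN_IF_ERROR(computation->SetOpMetadata(handle, arg->metadata()));
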
-Status Service::SnapshotComputation(const SnapshotComputationRequest* arg,
- SnapshotComputationResponse* result) {
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<SessionModule> module,
- computation_tracker_.SnapshotComputation(arg->computation()));
-
- result->set_allocated_module(module.release());
-
- return Status::OK();
-}
-
-Status Service::LoadComputationSnapshot(
- const LoadComputationSnapshotRequest* arg,
- LoadComputationSnapshotResponse* result) {
- TF_ASSIGN_OR_RETURN(*result->mutable_computation(),
- computation_tracker_.LoadSessionModule(arg->module()));
- return Status::OK();
-}
-
DeviceHandle Service::SingleComputationDeviceHandle() const {
DeviceHandle device_handle;
device_handle.set_handle(0);
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 81fbd41957..422bb95657 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/channel_tracker.h"
#include "tensorflow/compiler/xla/service/compilation_cache.h"
-#include "tensorflow/compiler/xla/service/computation_tracker.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/execution_tracker.h"
@@ -35,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
-#include "tensorflow/compiler/xla/service/user_computation.h"
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/service_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -83,11 +81,6 @@ class Service : public ServiceInterface {
static StatusOr<std::unique_ptr<Service>> NewService(
const ServiceOptions& options);
- // Creates a new computation with the given name.
- // A unique ComputationHandle is returned.
- Status Computation(const ComputationRequest* arg,
- ComputationResponse* result) override;
-
// Unregisters a previously-allocated global handle.
//
// If the handle given is not currently allocated, a NOT_FOUND status is
@@ -100,35 +93,15 @@ class Service : public ServiceInterface {
Status DeconstructTuple(const DeconstructTupleRequest* arg,
DeconstructTupleResponse* result) override;
- // Modifies the provided computation so that subsequent executions
- // will compute the provided ComputationDataHandle, rather than the
- // last expression enqueued on that Computation.
- Status SetReturnValue(const SetReturnValueRequest* arg,
- SetReturnValueResponse* results) override;
-
- // Executes a computation with the provided global data passed as
- // immutable arguments. Returns global data output and execution timing.
- Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override;
-
// Executes a computation with the provided global data passed as
// immutable arguments. The request contains the whole computation graph.
// Returns global data output and execution timing.
- //
- // TODO(b/74197823): This is a part of a NOT YET ready refactor.
Status ExecuteGraph(const ExecuteGraphRequest* arg,
ExecuteResponse* result) override;
// Executes one or more computations in parallel with the provided global data
// passed as immutable arguments. Returns global data output for each
// computation.
- Status ExecuteParallel(const ExecuteParallelRequest* arg,
- ExecuteParallelResponse* result) override;
-
- // Executes one or more computations in parallel with the provided global data
- // passed as immutable arguments. Returns global data output for each
- // computation.
- //
- // TODO(b/74197823): This is a part of a NOT YET ready refactor.
Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
ExecuteParallelResponse* result) override;
@@ -143,16 +116,6 @@ class Service : public ServiceInterface {
Status GetDeviceHandles(const GetDeviceHandlesRequest* arg,
GetDeviceHandlesResponse* result) override;
- // Asynchronously executes a computation with provided arguments. Invokes
- // the provided computation with the provided global data passed as
- // immutable arguments. Returns a handle to the execution.
- //
- // (Note: The corresponding function in xla::Client was removed as part of
- // b/64116060, in an attempt to simplify our API. We're keeping this around
- // for now in case we want to expose this to clients in a different way.)
- Status ExecuteAsync(const ExecuteAsyncRequest* arg,
- ExecuteAsyncResponse* result) override;
-
// Waits until the specified execution is complete and returns the result.
// Calling this API multiple times with the same execution handle fails with
// an error, since the execution handle is destroyed after the
@@ -190,13 +153,6 @@ class Service : public ServiceInterface {
Status ResetDevice(const ResetDeviceRequest* arg,
ResetDeviceResponse* result) override;
- // Tests if an expression is a compile-time constant.
- Status IsConstant(const IsConstantRequest* arg,
- IsConstantResponse* result) override;
-
- // Computes the value of a constant expression.
- Status ComputeConstant(const ComputeConstantRequest* arg,
- ComputeConstantResponse* result) override;
Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
ComputeConstantResponse* result) override;
@@ -205,54 +161,15 @@ class Service : public ServiceInterface {
Status GetShape(const GetShapeRequest* arg,
GetShapeResponse* result) override;
- // Returns the program shape of the computation associated with the given
- // handle.
- Status GetComputationShape(const GetComputationShapeRequest* arg,
- GetComputationShapeResponse* result) override;
-
- /////
- // Computation-oriented methods.
-
- // Enqueues an Op on the computation.
- Status Op(const OpRequest* arg, OpResponse* result) override;
-
- // Retrieves the inferred shape for a value within a computation.
- Status GetLocalShape(const GetLocalShapeRequest* arg,
- GetLocalShapeResponse* result) override;
-
- // Retrieves the statistics of a computation.
- Status GetComputationStats(const ComputationStatsRequest* arg,
- ComputationStatsResponse* result) override;
-
// Retrieves the statistics of a computation.
- //
- // TODO(b/74197823): This is a part of a NOT YET ready refactor.
Status GetComputationGraphStats(const ComputationGraphStatsRequest* arg,
ComputationStatsResponse* result) override;
- // Snapshots the current state of a computation handle into a serializable
- // protocol buffer form, so it can be loaded via
- // LoadComputationSnapshot.
- Status SnapshotComputation(const SnapshotComputationRequest* arg,
- SnapshotComputationResponse* result) override;
-
- // Loads a computation from a serialized protocol buffer created via
- // SnapshotComputation.
- Status LoadComputationSnapshot(
- const LoadComputationSnapshotRequest* arg,
- LoadComputationSnapshotResponse* result) override;
-
// Creates a unique channel handle that can be used for Send/Recv
// instructions.
Status CreateChannelHandle(const CreateChannelHandleRequest* arg,
CreateChannelHandleResponse* result) override;
- // Returns the ComputationTracker of the current service instance.
- // Only used in unit tests to access user computations from client.
- const ComputationTracker& computation_tracker() {
- return computation_tracker_;
- }
-
// Returns the backend used to execute computations.
const Backend& backend() const { return *execute_backend_; }
Backend* mutable_backend() { return execute_backend_.get(); }
@@ -263,8 +180,7 @@ class Service : public ServiceInterface {
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- const ExecutionOptions& execution_options,
- const UserComputation* user_computation = nullptr);
+ const ExecutionOptions& execution_options);
// Picks a parallel response and fills the result.
Status PickParallelResponse(const ExecuteParallelResponse& parallel_result,
@@ -305,8 +221,7 @@ class Service : public ServiceInterface {
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
- const ExecutionOptions* execution_options,
- const UserComputation* user_computation = nullptr);
+ const ExecutionOptions* execution_options);
// Builds an Executable for the given parameters.
//
@@ -314,15 +229,6 @@ class Service : public ServiceInterface {
// buffers, which the compiler is responsible for freeing. The allocator
// given here need not match the allocator used when running the executable.
StatusOr<std::unique_ptr<Executable>> BuildExecutable(
- const VersionedComputationHandle& versioned_handle,
- std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
- se::StreamExecutor* executor,
- DeviceMemoryAllocator* device_allocator = nullptr);
-
- // Builds an Executable for the given HLO module proto.
- //
- // TODO(b/74197823): This is a part of a NOT YET ready refactor.
- StatusOr<std::unique_ptr<Executable>> BuildExecutable(
const HloModuleProto& module_proto,
std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
se::StreamExecutor* executor,
@@ -331,25 +237,11 @@ class Service : public ServiceInterface {
// Same as BuildExecutable() above, but builds a list of Executables for the
// given computations that may interact with each other.
StatusOr<std::vector<std::unique_ptr<Executable>>> BuildExecutables(
- std::vector<VersionedComputationHandle> versioned_handles,
- std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
- Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
- DeviceMemoryAllocator* device_allocator);
- StatusOr<std::vector<std::unique_ptr<Executable>>> BuildExecutables(
const std::vector<const HloModuleProto*>& module_protos,
std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
DeviceMemoryAllocator* device_allocator);
- // Similar to BuildExecutable, but look in the compilation cache for the
- // executable first. If the executable is not in the cache, it is built and
- // inserted into the cache.
- StatusOr<std::shared_ptr<Executable>> BuildAndCacheExecutable(
- const VersionedComputationHandle& versioned_handle,
- std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
- se::StreamExecutor* executor, ExecutionProfile* profile,
- DeviceMemoryAllocator* device_allocator = nullptr);
-
// Runs the given executable with the given arguments and registers the result
// in the allocation tracker. The handle of the result from the tracker is
// returned. If the parameter "profile" is not null, it points to an
@@ -372,17 +264,9 @@ class Service : public ServiceInterface {
tensorflow::gtl::ArraySlice<string> result_tags,
ExecutionProfile* profile);
- // Convenience function for adding a function to a user computation.
- template <typename RequestT, typename ResponseT>
- Status AddInstruction(
- const RequestT* arg, ResponseT* result,
- const std::function<StatusOr<ComputationDataHandle>(UserComputation*)>&
- adder);
-
// Executes a single computation which has more than one target device.
// All but one of the N devices are expected to return an empty tuple; the
// remaining device's output will be the result of this computation.
- Status ExecuteOneToN(const ExecuteRequest* arg, ExecuteResponse* result);
Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result);
// Convenience function which checks whether the given shape_with_layout
@@ -405,9 +289,6 @@ class Service : public ServiceInterface {
ServiceOptions options_;
- // Tracks computations built via the API.
- ComputationTracker computation_tracker_;
-
// Tracks channels created via the API.
ChannelTracker channel_tracker_;
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
deleted file mode 100644
index 9e62d0acfb..0000000000
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ /dev/null
@@ -1,3557 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/user_computation.h"
-
-#include <algorithm>
-#include <set>
-#include <stack>
-#include <unordered_map>
-#include <utility>
-#include <vector>
-
-#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/service/shape_inference.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/protobuf.h"
-
-namespace xla {
-namespace {
-
-HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) {
- switch (unop) {
- case UNOP_ABS:
- return HloOpcode::kAbs;
- case UNOP_CEIL:
- return HloOpcode::kCeil;
- case UNOP_CLZ:
- return HloOpcode::kClz;
- case UNOP_COS:
- return HloOpcode::kCos;
- case UNOP_EXP:
- return HloOpcode::kExp;
- case UNOP_EXPM1:
- return HloOpcode::kExpm1;
- case UNOP_FLOOR:
- return HloOpcode::kFloor;
- case UNOP_IMAG:
- return HloOpcode::kImag;
- case UNOP_IS_FINITE:
- return HloOpcode::kIsFinite;
- case UNOP_LOG:
- return HloOpcode::kLog;
- case UNOP_LOG1P:
- return HloOpcode::kLog1p;
- case UNOP_NOT:
- return HloOpcode::kNot;
- case UNOP_NEGATE:
- return HloOpcode::kNegate;
- case UNOP_REAL:
- return HloOpcode::kReal;
- case UNOP_ROUND_NEAREST_AFZ:
- return HloOpcode::kRoundNearestAfz;
- case UNOP_SIGN:
- return HloOpcode::kSign;
- case UNOP_SIN:
- return HloOpcode::kSin;
- case UNOP_SORT:
- return HloOpcode::kSort;
- case UNOP_TANH:
- return HloOpcode::kTanh;
- default:
- LOG(FATAL) << "unhandled operation " << unop;
- }
-}
-
-HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) {
- switch (binop) {
- case BINOP_ATAN2:
- return HloOpcode::kAtan2;
- case BINOP_COMPLEX:
- return HloOpcode::kComplex;
- case BINOP_MUL:
- return HloOpcode::kMultiply;
- case BINOP_ADD:
- return HloOpcode::kAdd;
- case BINOP_SUB:
- return HloOpcode::kSubtract;
- case BINOP_DIV:
- return HloOpcode::kDivide;
- case BINOP_EQ:
- return HloOpcode::kEq;
- case BINOP_GE:
- return HloOpcode::kGe;
- case BINOP_GT:
- return HloOpcode::kGt;
- case BINOP_LE:
- return HloOpcode::kLe;
- case BINOP_LT:
- return HloOpcode::kLt;
- case BINOP_NE:
- return HloOpcode::kNe;
- case BINOP_MAX:
- return HloOpcode::kMaximum;
- case BINOP_MIN:
- return HloOpcode::kMinimum;
- case BINOP_POW:
- return HloOpcode::kPower;
- case BINOP_REM:
- return HloOpcode::kRemainder;
- case BINOP_OR:
- return HloOpcode::kOr;
- case BINOP_AND:
- return HloOpcode::kAnd;
- case BINOP_SHIFT_LEFT:
- return HloOpcode::kShiftLeft;
- case BINOP_SHIFT_RIGHT_ARITHMETIC:
- return HloOpcode::kShiftRightArithmetic;
- case BINOP_SHIFT_RIGHT_LOGICAL:
- return HloOpcode::kShiftRightLogical;
- default:
- LOG(FATAL) << "unhandled operation " << binop;
- }
-}
-
-HloOpcode TernaryOperationToHloOpcode(TernaryOperation triop) {
- switch (triop) {
- case TRIOP_CLAMP:
- return HloOpcode::kClamp;
- case TRIOP_SELECT:
- return HloOpcode::kSelect;
- default:
- LOG(FATAL) << "unhandled operation " << triop;
- }
-}
-
-HloOpcode VariadicOperationToHloOpcode(VariadicOperation varop) {
- switch (varop) {
- case VAROP_TUPLE:
- return HloOpcode::kTuple;
- default:
- LOG(FATAL) << "unhandled operation " << varop;
- }
-}
-
-} // namespace
-
-/* static */ StatusOr<std::unique_ptr<UserComputation>>
-UserComputation::MakeWithRemapping(
- const SessionComputation& session_computation,
- const ComputationHandle& handle,
- const std::map<int64, ComputationHandle>& old_to_new) {
- auto user_computation =
- MakeUnique<UserComputation>(session_computation.name(), handle);
- {
- tensorflow::mutex_lock lock(user_computation->mutex_);
- user_computation->session_computation_ = session_computation;
- user_computation->next_handle_value_ =
- std::max_element(session_computation.requests().begin(),
- session_computation.requests().end(),
- [](const std::pair<int64, OperationRequest>& lhs,
- const std::pair<int64, OperationRequest>& rhs) {
- return lhs.first < rhs.first;
- })
- ->first +
- 1;
- TF_RETURN_IF_ERROR(user_computation->RemapEmbeddedComputations(old_to_new));
- }
-
- return std::move(user_computation);
-}
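
The max_element expression above only recovers the largest existing request
key, so that handle numbering stays consecutive. A minimal standalone
equivalent, with a hypothetical map type standing in for the requests proto
map:

    #include <cstdint>
    #include <map>

    // Returns the next consecutive handle value given existing requests,
    // keyed by handle value (hypothetical stand-in for OperationRequest).
    int64_t NextHandleValue(const std::map<int64_t, int>& requests) {
      int64_t max_key = 0;
      for (const auto& kv : requests) {
        if (kv.first > max_key) max_key = kv.first;
      }
      return max_key + 1;
    }
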
-
-UserComputation::UserComputation(const string& name,
- const ComputationHandle& handle)
- : name_(name), next_handle_value_(1) {
- *session_computation_.mutable_computation_handle() = handle;
- session_computation_.set_name(name);
-
- VLOG(1) << "New UserComputation \"" << name
- << "\", handle: " << handle.handle();
-}
-
-ComputationDataHandle UserComputation::CreateComputationDataHandle() {
- ComputationDataHandle handle;
- handle.set_handle(next_handle_value_);
- // Handles are used as Version values and *must* be assigned consecutively for
- // computation versioning to work.
- next_handle_value_++;
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddParameterInstruction(
- const ParameterRequest& parameter_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- int64 parameter_number = parameter_request.parameter();
- if (parameters_.count(parameter_number) != 0) {
- return InvalidArgument("parameter %lld already registered",
- parameter_number);
- }
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- const Shape& validated_shape = parameter_request.shape();
- TF_RETURN_IF_ERROR(
- ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape));
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = validated_shape;
- *request.mutable_request()->mutable_parameter_request() = parameter_request;
-
- parameters_[parameter_number] = &request;
-
- VLOG(1) << "AddParameterInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << parameter_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddSendInstruction(
- const SendRequest& send_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- // Check if the operand of the instruction is valid.
- TF_RETURN_IF_ERROR(LookUpRequest(send_request.operand()).status());
-
- // No handle is returned, but a handle must be assigned to this instruction
- // for computation versioning.
- ComputationDataHandle handle = CreateComputationDataHandle();
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = ShapeUtil::MakeNil();
- *request.mutable_request()->mutable_send_request() = send_request;
-
- VLOG(1) << "AddSendInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << send_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddRecvInstruction(
- const RecvRequest& recv_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- const Shape& shape = recv_request.shape();
- TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
- ComputationDataHandle handle = CreateComputationDataHandle();
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = shape;
- *request.mutable_request()->mutable_recv_request() = recv_request;
-
- VLOG(1) << "AddRecvInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << recv_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddPadInstruction(
- const PadRequest& pad_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
- LookUpRequest(pad_request.operand()));
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* padding_value,
- LookUpRequest(pad_request.padding_value()));
-
- TF_ASSIGN_OR_RETURN(Shape inferred_shape, ShapeInference::InferPadShape(
- operand->output_shape(),
- padding_value->output_shape(),
- pad_request.padding_config()));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = inferred_shape;
- *request.mutable_request()->mutable_pad_request() = pad_request;
-
- VLOG(1) << "AddPadInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << pad_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddConstantInstruction(
- const ConstantRequest& constant_request) {
- const Shape& validated_shape = constant_request.literal().shape();
- TF_RETURN_IF_ERROR(
- ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape));
-
- tensorflow::mutex_lock lock(mutex_);
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = validated_shape;
- *request.mutable_request()->mutable_constant_request() = constant_request;
-
- VLOG(1) << "AddConstantInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddGatherInstruction(
- const GatherRequest& gather_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* input_request,
- LookUpRequest(gather_request.input()));
- TF_ASSIGN_OR_RETURN(const OperationRequest* gather_indices_request,
- LookUpRequest(gather_request.gather_indices()));
-
- TF_ASSIGN_OR_RETURN(
- Shape shape,
- ShapeInference::InferGatherShape(
- input_request->output_shape(), gather_indices_request->output_shape(),
- gather_request.dimension_numbers(),
- AsInt64Slice(gather_request.window_bounds())));
-
- const ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = shape;
- *request.mutable_request()->mutable_gather_request() = gather_request;
-
- VLOG(1) << "AddGatherInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << gather_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddGetTupleElementInstruction(
- const GetTupleElementRequest& get_tuple_element_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
- LookUpRequest(get_tuple_element_request.operand()));
- if (!ShapeUtil::IsTuple(operand->output_shape())) {
- return InvalidArgument(
- "Operand to GetTupleElement() is not a tuple; got %s",
- ShapeUtil::HumanString(operand->output_shape()).c_str());
- }
- Shape element_shape = ShapeUtil::GetTupleElementShape(
- operand->output_shape(), get_tuple_element_request.index());
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = element_shape;
- *request.mutable_request()->mutable_get_tuple_element_request() =
- get_tuple_element_request;
-
- VLOG(1) << "AddGetTupleElementInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << get_tuple_element_request.ShortDebugString();
- return handle;
-}
-
-Status UserComputation::AddTraceInstruction(const TraceRequest& trace_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- // Verify that the operand index is valid.
- TF_RETURN_IF_ERROR(LookUpRequest(trace_request.operand()).status());
-
- ComputationDataHandle handle = CreateComputationDataHandle();
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = ShapeUtil::MakeNil();
- *request.mutable_request()->mutable_trace_request() = trace_request;
-
- VLOG(1) << "AddTraceInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << trace_request.ShortDebugString();
- return Status::OK();
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddRngInstruction(
- const RngRequest& rng_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- // Check the number of parameters per RNG distribution.
- switch (rng_request.distribution()) {
- case RandomDistribution::RNG_NORMAL:
- case RandomDistribution::RNG_UNIFORM:
- if (rng_request.parameter_size() != 2) {
- return InvalidArgument(
- "RNG distribution (%s) expects 2 parameters, but got %d",
- RandomDistribution_Name(rng_request.distribution()).c_str(),
- rng_request.parameter_size());
- }
- break;
- default:
- LOG(FATAL) << "unhandled distribution " << rng_request.distribution();
- }
-
- // Verify that the parameter indices are valid.
- for (const ComputationDataHandle& param : rng_request.parameter()) {
- TF_RETURN_IF_ERROR(LookUpRequest(param).status());
- }
- const Shape& validated_shape = rng_request.shape();
- TF_RETURN_IF_ERROR(
- ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = validated_shape;
- *request.mutable_request()->mutable_rng_request() = rng_request;
-
- VLOG(1) << "AddRngInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << rng_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddMapInstruction(
- const MapRequest& map_request,
- const UserComputation& to_apply_computation) {
- tensorflow::mutex_lock lock(mutex_);
-
- std::vector<const Shape*> operand_shapes;
- for (const ComputationDataHandle& handle : map_request.operands()) {
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle));
- operand_shapes.push_back(&operand->output_shape());
- }
-
- VersionedComputationHandle::Version to_apply_version =
- to_apply_computation.version();
- TF_ASSIGN_OR_RETURN(
- std::shared_ptr<const ProgramShape> to_apply_program_shape,
- to_apply_computation.ComputeProgramShape(to_apply_version));
- TF_ASSIGN_OR_RETURN(
- Shape inferred_shape,
- ShapeInference::InferMapShape(operand_shapes, *to_apply_program_shape,
- AsInt64Slice(map_request.dimensions())));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = inferred_shape;
- request.add_embedded_computation_versions(to_apply_version);
- *request.mutable_request()->mutable_map_request() = map_request;
-
- VLOG(1) << "AddMapInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << map_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddReduceInstruction(
- const ReduceRequest& reduce_request,
- const UserComputation& to_apply_computation) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
- LookUpRequest(reduce_request.operand()));
- TF_ASSIGN_OR_RETURN(const OperationRequest* init_value,
- LookUpRequest(reduce_request.init_value()));
-
- VersionedComputationHandle::Version to_apply_version =
- to_apply_computation.version();
- TF_ASSIGN_OR_RETURN(
- std::shared_ptr<const ProgramShape> to_apply_program_shape,
- to_apply_computation.ComputeProgramShape(to_apply_version));
-
- TF_ASSIGN_OR_RETURN(
- Shape inferred_shape,
- ShapeInference::InferReduceShape(
- operand->output_shape(), init_value->output_shape(),
- AsInt64Slice(reduce_request.dimensions()), *to_apply_program_shape));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = inferred_shape;
- request.add_embedded_computation_versions(to_apply_version);
- *request.mutable_request()->mutable_reduce_request() = reduce_request;
-
- VLOG(1) << "AddReduceInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << reduce_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle>
-UserComputation::AddBatchNormTrainingInstruction(
- const BatchNormTrainingRequest& batch_norm_training_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
- LookUpRequest(batch_norm_training_request.operand()));
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* scale,
- LookUpRequest(batch_norm_training_request.scale()));
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* offset,
- LookUpRequest(batch_norm_training_request.offset()));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
-
- TF_ASSIGN_OR_RETURN(
- Shape inferred_shape,
- ShapeInference::InferBatchNormTrainingShape(
- operand->output_shape(), scale->output_shape(),
- offset->output_shape(), batch_norm_training_request.feature_index()));
-
- *request.mutable_output_shape() = inferred_shape;
-
- *request.mutable_output_handle() = handle;
-
- *request.mutable_request()->mutable_batch_norm_training_request() =
- batch_norm_training_request;
-
- VLOG(1) << "AddBatchNormTrainingInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << batch_norm_training_request.ShortDebugString();
-
- return handle;
-}
-
-StatusOr<ComputationDataHandle>
-UserComputation::AddBatchNormInferenceInstruction(
- const BatchNormInferenceRequest& batch_norm_inference_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
- LookUpRequest(batch_norm_inference_request.operand()));
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* scale,
- LookUpRequest(batch_norm_inference_request.scale()));
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* offset,
- LookUpRequest(batch_norm_inference_request.offset()));
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* mean,
- LookUpRequest(batch_norm_inference_request.mean()));
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* variance,
- LookUpRequest(batch_norm_inference_request.variance()));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
-
- TF_ASSIGN_OR_RETURN(Shape inferred_shape,
- ShapeInference::InferBatchNormInferenceShape(
- operand->output_shape(), scale->output_shape(),
- offset->output_shape(), mean->output_shape(),
- variance->output_shape(),
- batch_norm_inference_request.feature_index()));
-
- *request.mutable_output_shape() = inferred_shape;
-
- *request.mutable_output_handle() = handle;
-
- *request.mutable_request()->mutable_batch_norm_inference_request() =
- batch_norm_inference_request;
-
- VLOG(1) << "AddBatchNormInferenceInstruction ("
- << GetVersionedHandleInternal() << "), data handle "
- << handle.handle() << ": "
- << batch_norm_inference_request.ShortDebugString();
-
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddBatchNormGradInstruction(
- const BatchNormGradRequest& batch_norm_grad_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
- LookUpRequest(batch_norm_grad_request.operand()));
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* scale,
- LookUpRequest(batch_norm_grad_request.scale()));
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* mean,
- LookUpRequest(batch_norm_grad_request.mean()));
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* variance,
- LookUpRequest(batch_norm_grad_request.variance()));
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* grad_output,
- LookUpRequest(batch_norm_grad_request.grad_output()));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
-
- TF_ASSIGN_OR_RETURN(
- Shape inferred_shape,
- ShapeInference::InferBatchNormGradShape(
- operand->output_shape(), scale->output_shape(), mean->output_shape(),
- variance->output_shape(), grad_output->output_shape(),
- batch_norm_grad_request.feature_index()));
-
- *request.mutable_output_shape() = inferred_shape;
-
- *request.mutable_output_handle() = handle;
-
- *request.mutable_request()->mutable_batch_norm_grad_request() =
- batch_norm_grad_request;
-
- VLOG(1) << "AddBatchNormGradInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << batch_norm_grad_request.ShortDebugString();
-
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddReduceWindowInstruction(
- const ReduceWindowRequest& reduce_window_request,
- const UserComputation& to_apply_computation) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
- LookUpRequest(reduce_window_request.operand()));
- TF_ASSIGN_OR_RETURN(const OperationRequest* init_value,
- LookUpRequest(reduce_window_request.init_value()));
-
- VersionedComputationHandle::Version to_apply_version =
- to_apply_computation.version();
- TF_ASSIGN_OR_RETURN(
- std::shared_ptr<const ProgramShape> to_apply_program_shape,
- to_apply_computation.ComputeProgramShape(to_apply_version));
-
- TF_ASSIGN_OR_RETURN(
- Shape inferred_shape,
- ShapeInference::InferReduceWindowShape(
- operand->output_shape(), init_value->output_shape(),
- reduce_window_request.window(), *to_apply_program_shape));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = inferred_shape;
- request.add_embedded_computation_versions(to_apply_version);
- *request.mutable_request()->mutable_reduce_window_request() =
- reduce_window_request;
-
- VLOG(1) << "AddReduceWindowInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << reduce_window_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddSelectAndScatterInstruction(
- const SelectAndScatterRequest& select_and_scatter_request,
- const UserComputation& select_computation,
- const UserComputation& scatter_computation) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
- LookUpRequest(select_and_scatter_request.operand()));
- TF_ASSIGN_OR_RETURN(const OperationRequest* source,
- LookUpRequest(select_and_scatter_request.source()));
- TF_ASSIGN_OR_RETURN(const OperationRequest* init_value,
- LookUpRequest(select_and_scatter_request.init_value()));
-
- VersionedComputationHandle::Version select_version =
- select_computation.version();
- TF_ASSIGN_OR_RETURN(std::shared_ptr<const ProgramShape> select_program_shape,
- select_computation.ComputeProgramShape(select_version));
- VersionedComputationHandle::Version scatter_version =
- scatter_computation.version();
- TF_ASSIGN_OR_RETURN(std::shared_ptr<const ProgramShape> scatter_program_shape,
- scatter_computation.ComputeProgramShape(scatter_version));
-
- TF_ASSIGN_OR_RETURN(
- Shape inferred_shape,
- ShapeInference::InferSelectAndScatterShape(
- operand->output_shape(), *select_program_shape,
- select_and_scatter_request.window(), source->output_shape(),
- init_value->output_shape(), *scatter_program_shape));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = inferred_shape;
- request.add_embedded_computation_versions(select_version);
- request.add_embedded_computation_versions(scatter_version);
- *request.mutable_request()->mutable_select_and_scatter_request() =
- select_and_scatter_request;
-
- VLOG(1) << "AddSelectAndScatterInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << select_and_scatter_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddReverseInstruction(
- const ReverseRequest& reverse_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
- LookUpRequest(reverse_request.operand()));
- TF_ASSIGN_OR_RETURN(
- Shape inferred_shape,
- ShapeInference::InferReverseShape(
- operand->output_shape(), AsInt64Slice(reverse_request.dimensions())));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = inferred_shape;
- *request.mutable_request()->mutable_reverse_request() = reverse_request;
- VLOG(1) << "AddReverseInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << reverse_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddWhileInstruction(
- const WhileRequest& while_request,
- const UserComputation& condition_computation,
- const UserComputation& body_computation) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* init,
- LookUpRequest(while_request.init()));
-
- VersionedComputationHandle::Version condition_version =
- condition_computation.version();
- TF_ASSIGN_OR_RETURN(
- std::shared_ptr<const ProgramShape> condition_program_shape,
- condition_computation.ComputeProgramShape(condition_version));
-
- VersionedComputationHandle::Version body_version = body_computation.version();
- TF_ASSIGN_OR_RETURN(std::shared_ptr<const ProgramShape> body_program_shape,
- body_computation.ComputeProgramShape(body_version));
-
- TF_ASSIGN_OR_RETURN(
- Shape inferred_shape,
- ShapeInference::InferWhileShape(
- *condition_program_shape, *body_program_shape, init->output_shape()));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = inferred_shape;
- request.add_embedded_computation_versions(condition_version);
- request.add_embedded_computation_versions(body_version);
- *request.mutable_request()->mutable_while_request() = while_request;
-
- VLOG(1) << "AddWhileInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << while_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddConditionalInstruction(
- const ConditionalRequest& conditional_request,
- const UserComputation& true_computation,
- const UserComputation& false_computation) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* pred,
- LookUpRequest(conditional_request.predicate()));
- TF_ASSIGN_OR_RETURN(const OperationRequest* true_operand,
- LookUpRequest(conditional_request.true_operand()));
- TF_ASSIGN_OR_RETURN(const OperationRequest* false_operand,
- LookUpRequest(conditional_request.false_operand()));
-
- VersionedComputationHandle::Version true_computation_version =
- true_computation.version();
- TF_ASSIGN_OR_RETURN(
- std::shared_ptr<const ProgramShape> true_computation_shape,
- true_computation.ComputeProgramShape(true_computation_version));
-
- VersionedComputationHandle::Version false_computation_version =
- false_computation.version();
- TF_ASSIGN_OR_RETURN(
- std::shared_ptr<const ProgramShape> false_computation_shape,
- false_computation.ComputeProgramShape(false_computation_version));
-
- TF_ASSIGN_OR_RETURN(Shape inferred_shape,
- ShapeInference::InferConditionalShape(
- pred->output_shape(), true_operand->output_shape(),
- false_operand->output_shape(),
- *true_computation_shape, *false_computation_shape));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = inferred_shape;
- request.add_embedded_computation_versions(true_computation_version);
- request.add_embedded_computation_versions(false_computation_version);
- *request.mutable_request()->mutable_conditional_request() =
- conditional_request;
-
- VLOG(1) << "AddConditionalInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << conditional_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddBroadcastInstruction(
- const BroadcastRequest& broadcast_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- // Fetches and validates the operand.
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
- LookUpRequest(broadcast_request.operand()));
- TF_ASSIGN_OR_RETURN(Shape inferred_shape,
- ShapeInference::InferBroadcastShape(
- operand->output_shape(),
- AsInt64Slice(broadcast_request.broadcast_sizes())));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = inferred_shape;
- *request.mutable_request()->mutable_broadcast_request() = broadcast_request;
-
- VLOG(1) << "AddBroadcastInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << broadcast_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddReshapeInstruction(
- const ReshapeRequest& reshape_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- // Fetches and validates the operand.
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
- LookUpRequest(reshape_request.operand()));
-
- TF_ASSIGN_OR_RETURN(
- Shape inferred_shape,
- ShapeInference::InferReshapeShape(
- operand->output_shape(), AsInt64Slice(reshape_request.dimensions()),
- AsInt64Slice(reshape_request.new_sizes())));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = inferred_shape;
- *request.mutable_request()->mutable_reshape_request() = reshape_request;
-
- VLOG(1) << "AddReshapeInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << reshape_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddTransposeInstruction(
- const TransposeRequest& transpose_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- // Fetches and validates the operand.
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
- LookUpRequest(transpose_request.operand()));
-
- TF_ASSIGN_OR_RETURN(Shape inferred_shape,
- ShapeInference::InferTransposeShape(
- operand->output_shape(),
- AsInt64Slice(transpose_request.dimensions())));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = inferred_shape;
- *request.mutable_request()->mutable_transpose_request() = transpose_request;
-
- VLOG(1) << "AddTransposeInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << transpose_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddSliceInstruction(
- const SliceRequest& slice_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
- LookUpRequest(slice_request.operand()));
-
- TF_ASSIGN_OR_RETURN(
- Shape new_shape,
- ShapeInference::InferSliceShape(
- operand->output_shape(), AsInt64Slice(slice_request.start_indices()),
- AsInt64Slice(slice_request.limit_indices()),
- AsInt64Slice(slice_request.strides())));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = new_shape;
- *request.mutable_request()->mutable_slice_request() = slice_request;
-
- VLOG(1) << "AddSliceInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << slice_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddDynamicSliceInstruction(
- const DynamicSliceRequest& dynamic_slice_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
- LookUpRequest(dynamic_slice_request.operand()));
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* start_indices,
- LookUpRequest(dynamic_slice_request.start_indices()));
-
- TF_ASSIGN_OR_RETURN(
- Shape new_shape,
- ShapeInference::InferDynamicSliceShape(
- operand->output_shape(), start_indices->output_shape(),
- AsInt64Slice(dynamic_slice_request.slice_sizes())));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = new_shape;
- *request.mutable_request()->mutable_dynamic_slice_request() =
- dynamic_slice_request;
-
- VLOG(1) << "AddDynamicSliceInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << dynamic_slice_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle>
-UserComputation::AddDynamicUpdateSliceInstruction(
- const DynamicUpdateSliceRequest& dynamic_update_slice_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
- LookUpRequest(dynamic_update_slice_request.operand()));
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* update,
- LookUpRequest(dynamic_update_slice_request.update()));
-
- TF_ASSIGN_OR_RETURN(
- const OperationRequest* start_indices,
- LookUpRequest(dynamic_update_slice_request.start_indices()));
-
- TF_ASSIGN_OR_RETURN(Shape new_shape,
- ShapeInference::InferDynamicUpdateSliceShape(
- operand->output_shape(), update->output_shape(),
- start_indices->output_shape()));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = new_shape;
- *request.mutable_request()->mutable_dynamic_update_slice_request() =
- dynamic_update_slice_request;
-
- VLOG(1) << "AddDynamicUpdateSliceInstruction ("
- << GetVersionedHandleInternal() << "), data handle "
- << handle.handle() << ": "
- << dynamic_update_slice_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddConcatenateInstruction(
- const ConcatenateRequest& concatenate_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- std::vector<const Shape*> operand_shapes;
- for (const ComputationDataHandle& handle : concatenate_request.operands()) {
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle));
- operand_shapes.push_back(&operand->output_shape());
- }
-
- TF_ASSIGN_OR_RETURN(Shape new_shape,
- ShapeInference::InferConcatOpShape(
- operand_shapes, concatenate_request.dimension()));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = new_shape;
- *request.mutable_request()->mutable_concatenate_request() =
- concatenate_request;
-
- VLOG(1) << "AddConcatenateInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << concatenate_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddConvertInstruction(
- const ConvertRequest& convert_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
- LookUpRequest(convert_request.operand()));
-
- TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape(
- operand->output_shape(),
- convert_request.new_element_type()));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = new_shape;
- *request.mutable_request()->mutable_convert_request() = convert_request;
-
- VLOG(1) << "AddConvertInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << convert_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddBitcastConvertInstruction(
- const ConvertRequest& convert_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
- LookUpRequest(convert_request.operand()));
-
- TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape(
- operand->output_shape(),
- convert_request.new_element_type()));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = new_shape;
- *request.mutable_request()->mutable_bitcast_convert_request() =
- convert_request;
-
- VLOG(1) << "AddBitcastConvertInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << convert_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddReducePrecisionInstruction(
- const ReducePrecisionRequest& reduce_precision_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
- LookUpRequest(reduce_precision_request.operand()));
-
- TF_ASSIGN_OR_RETURN(
- Shape new_shape,
- ShapeInference::InferReducePrecisionShape(
- operand->output_shape(), reduce_precision_request.exponent_bits(),
- reduce_precision_request.mantissa_bits()));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = new_shape;
- *request.mutable_request()->mutable_reduce_precision_request() =
- reduce_precision_request;
-
- VLOG(1) << "AddReducePrecisionInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << reduce_precision_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddConvolveInstruction(
- const ConvolveRequest& convolve_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* lhs,
- LookUpRequest(convolve_request.lhs()));
- TF_ASSIGN_OR_RETURN(const OperationRequest* rhs,
- LookUpRequest(convolve_request.rhs()));
- TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvolveShape(
- lhs->output_shape(), rhs->output_shape(),
- convolve_request.window(),
- convolve_request.dimension_numbers()));
-
- const ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = shape;
- *request.mutable_request()->mutable_convolve_request() = convolve_request;
-
- VLOG(1) << "AddConvolveInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << convolve_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddFftInstruction(
- const FftRequest& fft_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
- LookUpRequest(fft_request.operand()));
- TF_ASSIGN_OR_RETURN(Shape shape,
- ShapeInference::InferFftShape(
- operand->output_shape(), fft_request.fft_type(),
- AsInt64Slice(fft_request.fft_length())));
-
- const ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = shape;
- *request.mutable_request()->mutable_fft_request() = fft_request;
-
- VLOG(1) << "AddFftInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << fft_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddCrossReplicaSumInstruction(
- const CrossReplicaSumRequest& cross_replica_sum_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
- LookUpRequest(cross_replica_sum_request.operand()));
- TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape(
- {&operand->output_shape()}));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = shape;
- *request.mutable_request()->mutable_cross_replica_sum_request() =
- cross_replica_sum_request;
-
- VLOG(1) << "AddCrossreplicaSumInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << cross_replica_sum_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddInfeedInstruction(
- const InfeedRequest& infeed_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- const Shape& shape = infeed_request.shape();
- if (!LayoutUtil::HasLayout(shape)) {
- return InvalidArgument("Given shape to Infeed must have a layout");
- }
-
- const ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = shape;
- *request.mutable_request()->mutable_infeed_request() = infeed_request;
-
- VLOG(1) << "AddInfeedInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << infeed_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddOutfeedInstruction(
- const OutfeedRequest& outfeed_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- const Shape& shape = outfeed_request.shape();
- if (!LayoutUtil::HasLayout(shape)) {
- return InvalidArgument("Given shape to Outfeed must have a layout");
- }
-
- // Verify that the operand is valid.
- TF_RETURN_IF_ERROR(LookUpRequest(outfeed_request.operand()).status());
-
- ComputationDataHandle handle = CreateComputationDataHandle();
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = shape;
- *request.mutable_request()->mutable_outfeed_request() = outfeed_request;
-
- VLOG(1) << "AddOutfeedInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << outfeed_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddCallInstruction(
- const CallRequest& call_request,
- const UserComputation& to_apply_computation) {
- tensorflow::mutex_lock lock(mutex_);
-
- std::vector<const Shape*> operand_shapes;
- for (const ComputationDataHandle& handle : call_request.operands()) {
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle));
- operand_shapes.push_back(&operand->output_shape());
- }
-
- VersionedComputationHandle::Version to_apply_version =
- to_apply_computation.version();
- TF_ASSIGN_OR_RETURN(
- std::shared_ptr<const ProgramShape> to_apply_program_shape,
- to_apply_computation.ComputeProgramShape(to_apply_version));
- TF_ASSIGN_OR_RETURN(
- Shape inferred_shape,
- ShapeInference::InferCallShape(operand_shapes, *to_apply_program_shape));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = inferred_shape;
- request.add_embedded_computation_versions(to_apply_version);
- *request.mutable_request()->mutable_call_request() = call_request;
-
- VLOG(1) << "AddCallInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << call_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddCustomCallInstruction(
- const CustomCallRequest& custom_call_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- for (const ComputationDataHandle& handle : custom_call_request.operands()) {
- TF_RETURN_IF_ERROR(LookUpRequest(handle).status());
- }
-
- if (tensorflow::str_util::StartsWith(custom_call_request.call_target_name(),
- "$")) {
- return InvalidArgument(
- "Invalid custom_call_target \"%s\": Call targets that start with '$' "
- "are reserved for internal use.",
- custom_call_request.call_target_name().c_str());
- }
-
- const ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = custom_call_request.shape();
- *request.mutable_request()->mutable_custom_call_request() =
- custom_call_request;
-
- VLOG(1) << "AddCustomCallInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << custom_call_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddHostComputeInstruction(
- const HostComputeRequest& host_compute_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- for (const ComputationDataHandle& handle : host_compute_request.operands()) {
- TF_RETURN_IF_ERROR(LookUpRequest(handle).status());
- }
-
- ComputationDataHandle handle = CreateComputationDataHandle();
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = host_compute_request.shape();
- *request.mutable_request()->mutable_host_compute_request() =
- host_compute_request;
-
- VLOG(1) << "AddHostComputeInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << host_compute_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddDotInstruction(
- const DotRequest& dot_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* lhs,
- LookUpRequest(dot_request.lhs()));
- TF_ASSIGN_OR_RETURN(const OperationRequest* rhs,
- LookUpRequest(dot_request.rhs()));
-
- TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape(
- lhs->output_shape(), rhs->output_shape(),
- dot_request.dimension_numbers()));
-
- const ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = shape;
- *request.mutable_request()->mutable_dot_request() = dot_request;
-
- VLOG(1) << "AddDotInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << dot_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddUnaryInstruction(
- const UnaryOpRequest& unary_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
- LookUpRequest(unary_request.operand()));
- TF_ASSIGN_OR_RETURN(
- Shape shape, ShapeInference::InferUnaryOpShape(unary_request.unop(),
- operand->output_shape()));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = shape;
- *request.mutable_request()->mutable_unary_op_request() = unary_request;
-
- VLOG(1) << "AddUnaryInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << unary_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddBinaryInstruction(
- const BinaryOpRequest& binary_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* lhs,
- LookUpRequest(binary_request.lhs()));
- TF_ASSIGN_OR_RETURN(const OperationRequest* rhs,
- LookUpRequest(binary_request.rhs()));
- TF_ASSIGN_OR_RETURN(
- Shape shape,
- ShapeInference::InferBinaryOpShape(
- binary_request.binop(), lhs->output_shape(), rhs->output_shape(),
- AsInt64Slice(binary_request.broadcast_dimensions())));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = shape;
- *request.mutable_request()->mutable_binary_op_request() = binary_request;
-
- VLOG(1) << "AddBinaryInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << binary_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddTernaryInstruction(
- const TernaryOpRequest& ternary_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* lhs,
- LookUpRequest(ternary_request.lhs()));
- TF_ASSIGN_OR_RETURN(const OperationRequest* rhs,
- LookUpRequest(ternary_request.rhs()));
- TF_ASSIGN_OR_RETURN(const OperationRequest* ehs,
- LookUpRequest(ternary_request.ehs()));
- TF_ASSIGN_OR_RETURN(Shape shape,
- ShapeInference::InferTernaryOpShape(
- ternary_request.triop(), lhs->output_shape(),
- rhs->output_shape(), ehs->output_shape()));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = shape;
- *request.mutable_request()->mutable_ternary_op_request() = ternary_request;
-
- VLOG(1) << "AddTernaryInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << ternary_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<ComputationDataHandle> UserComputation::AddVariadicInstruction(
- const VariadicOpRequest& variadic_request) {
- tensorflow::mutex_lock lock(mutex_);
-
- std::vector<const Shape*> operand_shapes;
- for (const ComputationDataHandle& handle : variadic_request.operands()) {
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle));
- operand_shapes.push_back(&operand->output_shape());
- }
-
- TF_ASSIGN_OR_RETURN(Shape shape,
- ShapeInference::InferVariadicOpShape(
- variadic_request.varop(), operand_shapes));
-
- ComputationDataHandle handle = CreateComputationDataHandle();
-
- OperationRequest& request =
- (*session_computation_.mutable_requests())[handle.handle()];
- *request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = shape;
- *request.mutable_request()->mutable_variadic_op_request() = variadic_request;
-
- VLOG(1) << "AddVariadicInstruction (" << GetVersionedHandleInternal()
- << "), data handle " << handle.handle() << ": "
- << variadic_request.ShortDebugString();
- return handle;
-}
-
-StatusOr<Shape> UserComputation::GetShape(const ComputationDataHandle& handle) {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle));
- return operand->output_shape();
-}
-
-Status UserComputation::SetOpMetadata(const ComputationDataHandle& handle,
- const OpMetadata& metadata) {
- tensorflow::mutex_lock lock(mutex_);
-
- int64 handle_value = handle.handle();
- if (session_computation_.requests().count(handle_value) == 0) {
- return InvalidArgument("Invalid handle in SetOpMetadata (%lld)",
- handle_value);
- }
- *session_computation_.mutable_requests()
- ->at(handle_value)
- .mutable_request()
- ->mutable_metadata() = metadata;
- return Status::OK();
-}
-
-Status UserComputation::SetOpSharding(const ComputationDataHandle& handle,
- const OpSharding& sharding) {
- tensorflow::mutex_lock lock(mutex_);
-
- int64 handle_value = handle.handle();
- if (session_computation_.requests().count(handle_value) == 0) {
- return InvalidArgument("Invalid handle in SetOpSharding (%lld)",
- handle_value);
- }
- *session_computation_.mutable_requests()
- ->at(handle_value)
- .mutable_request()
- ->mutable_sharding() = sharding;
- return Status::OK();
-}
-
-Status UserComputation::SetReturnValue(const ComputationDataHandle& handle) {
- tensorflow::mutex_lock lock(mutex_);
-
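- // Handles are assigned sequentially starting at 1, so any value strictly
- // between 0 and next_handle_value_ should name a previously added request.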
- if (!(handle.handle() > 0 && handle.handle() < next_handle_value_)) {
- return InvalidArgument("Invalid handle in SetReturnValue");
- }
-
- handle_to_return_ = handle;
-
- VLOG(1) << "SetReturnValue of computation \"" << name() << "\" fixed to "
- << GetVersionedHandleInternal();
-
- return Status::OK();
-}
-
-VersionedComputationHandle UserComputation::GetVersionedHandle() const {
- tensorflow::mutex_lock lock(mutex_);
- return GetVersionedHandleInternal();
-}
-
-VersionedComputationHandle UserComputation::GetVersionedHandleInternal() const {
- VersionedComputationHandle versioned_handle;
- versioned_handle.handle = session_computation_.computation_handle();
-
- if (handle_to_return_.handle() > 0) {
- // A specific handle has been requested for the result of the computation.
- versioned_handle.version = handle_to_return_.handle();
- } else {
- // A version value is simply the most recently assigned
- // ComputationDataHandle value, i.e. the handle value of the root of the
- // computation.
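- // For example, if handles 1 through 5 have been assigned (so
- // next_handle_value_ is 6) and no return value was requested, the
- // version is 5: the last operation added is treated as the root.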
- versioned_handle.version = next_handle_value_ - 1;
- }
-
- return versioned_handle;
-}
-
-VersionedComputationHandle UserComputation::GetVersionedHandleAtOperation(
- const ComputationDataHandle& operation) const {
- tensorflow::mutex_lock lock(mutex_);
-
- // The version at which an operation was added is simply the handle value of
- // the ComputationDataHandle.
- VersionedComputationHandle versioned_handle;
- versioned_handle.handle = session_computation_.computation_handle();
- versioned_handle.version = operation.handle();
- return versioned_handle;
-}
-
-VersionedComputationHandle::Version UserComputation::version() const {
- return GetVersionedHandle().version;
-}
-
-namespace {
-
-// Returns true if the operation type corresponding to the given opcase can be
-// the root of the computation.
-bool CanBeRoot(const OpRequest::OpCase& op_case) {
- switch (op_case) {
- case OpRequest::kTraceRequest:
- case OpRequest::kSendRequest:
- case OpRequest::kOutfeedRequest:
- return false;
- default:
- return true;
- }
-}
-
-// Returns a pointer to the operation with the given data handle value in the
-// given SessionComputation.
-StatusOr<const OperationRequest*> LookUpRequest(
- int64 handle_value, const SessionComputation& session_computation) {
- if (session_computation.requests().count(handle_value) == 0) {
- return InvalidArgument("no ComputationDataHandle value %lld", handle_value);
- }
- return &session_computation.requests().at(handle_value);
-}
-
-// Returns the OperationRequest corresponding to the root (result) of the
-// session computation.
-StatusOr<const OperationRequest*> GetRoot(
- VersionedComputationHandle::Version version,
- const SessionComputation& session_computation) {
- TF_RET_CHECK(version > 0);
- // Not all instructions can be roots. Walk backwards from the operation
- // indicated by this version until a valid root is found.
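- // For example, if the operation at `version` is an Outfeed, which cannot
- // be a root, the search settles on the most recent earlier operation for
- // which CanBeRoot() returns true.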
- const OperationRequest* root_request = nullptr;
- while (version > 0) {
- TF_ASSIGN_OR_RETURN(root_request,
- LookUpRequest(version, session_computation));
- if (CanBeRoot(root_request->request().op_case())) {
- break;
- }
- version--;
- }
- if (version == 0) {
- return InternalError("Computation contains no root operation");
- }
- return root_request;
-}
-
-} // namespace
-
-StatusOr<std::shared_ptr<const ProgramShape>>
-UserComputation::ComputeProgramShape(
- VersionedComputationHandle::Version version) const {
- tensorflow::mutex_lock lock(mutex_);
-
- TF_RET_CHECK(version > 0 && version < next_handle_value_);
-
- if (program_shape_ == nullptr || program_shape_version_ != version) {
- // The ProgramShape has not been computed yet, or is for a different
- // version. Compute it now.
- TF_RETURN_IF_ERROR(CheckParametersAreContiguous(version));
-
- auto program_shape = MakeUnique<ProgramShape>();
- for (int64 request_num = 1; request_num <= version; ++request_num) {
- const OperationRequest& request =
- session_computation_.requests().at(request_num);
- if (request.request().op_case() == OpRequest::kParameterRequest) {
- const ParameterRequest& parameter_request =
- request.request().parameter_request();
- int64 param_no = parameter_request.parameter();
- // Parameters may be out of order, so expand the ProgramShape's parameter
- // list until it is at least large enough to hold the current parameter
- // number.
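- // For example, if parameter 2 is visited before parameters 0 and 1,
- // slots 0 through 2 are created here, and slots 0 and 1 are filled in
- // when their ParameterRequests are reached later in the walk.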
- while (program_shape->parameters_size() <= param_no) {
- program_shape->add_parameters();
- program_shape->add_parameter_names();
- }
- *program_shape->mutable_parameters(param_no) = request.output_shape();
- *program_shape->mutable_parameter_names(param_no) =
- parameter_request.name();
- }
- }
-
- // The root determines the output shape.
- TF_ASSIGN_OR_RETURN(const OperationRequest* root_request,
- GetRoot(version, session_computation_));
- *program_shape->mutable_result() = root_request->output_shape();
- if (ShapeUtil::IsOpaque(program_shape->result())) {
- return Unimplemented("Computation results cannot be opaque");
- }
-
- program_shape_ = std::move(program_shape);
- program_shape_version_ = version;
- }
-
- return program_shape_;
-}
-
-namespace {
-
-// A visitor which checks whether an operation is purely functional, meaning
-// that it doesn't depend on any parameter with an index of num_parameters or
-// higher. The visitor walks the computation starting at a given operation and
-// sets is_functional to false if it encounters such a parameter, or an
-// operation that has side effects, is nondeterministic, or embeds a
-// computation that is not analyzed (e.g. RNG, Infeed, Send, Call, or While).
-void PureFunctionalVisitor(const SessionComputation& session_computation,
- const ComputationDataHandle& handle,
- int64 num_parameters, std::set<int64>* visited,
- bool* is_functional) {
- if (visited->count(handle.handle()) != 0 || !*is_functional) {
- return;
- }
-
- const OperationRequest& request =
- session_computation.requests().at(handle.handle());
- switch (request.request().op_case()) {
- case OpRequest::kRngRequest:
- *is_functional = false;
- break;
-
- case OpRequest::kConstantRequest:
- break;
-
- case OpRequest::kGetTupleElementRequest: {
- const GetTupleElementRequest& get_tuple_element_request =
- request.request().get_tuple_element_request();
- PureFunctionalVisitor(session_computation,
- get_tuple_element_request.operand(), num_parameters,
- visited, is_functional);
- break;
- }
-
- case OpRequest::kSliceRequest: {
- const SliceRequest& slice_request = request.request().slice_request();
- PureFunctionalVisitor(session_computation, slice_request.operand(),
- num_parameters, visited, is_functional);
- break;
- }
-
- case OpRequest::kDynamicSliceRequest: {
- const DynamicSliceRequest& dynamic_slice_request =
- request.request().dynamic_slice_request();
- PureFunctionalVisitor(session_computation,
- dynamic_slice_request.operand(), num_parameters,
- visited, is_functional);
- PureFunctionalVisitor(session_computation,
- dynamic_slice_request.start_indices(),
- num_parameters, visited, is_functional);
- break;
- }
-
- case OpRequest::kDynamicUpdateSliceRequest: {
- const DynamicUpdateSliceRequest& dynamic_update_slice_request =
- request.request().dynamic_update_slice_request();
- PureFunctionalVisitor(session_computation,
- dynamic_update_slice_request.operand(),
- num_parameters, visited, is_functional);
- PureFunctionalVisitor(session_computation,
- dynamic_update_slice_request.update(),
- num_parameters, visited, is_functional);
- PureFunctionalVisitor(session_computation,
- dynamic_update_slice_request.start_indices(),
- num_parameters, visited, is_functional);
- break;
- }
-
- case OpRequest::kConcatenateRequest: {
- const ConcatenateRequest& concatenate_request =
- request.request().concatenate_request();
- for (const ComputationDataHandle& handle :
- concatenate_request.operands()) {
- PureFunctionalVisitor(session_computation, handle, num_parameters,
- visited, is_functional);
- }
- break;
- }
-
- case OpRequest::kConvolveRequest: {
- const ConvolveRequest& convolve_request =
- request.request().convolve_request();
- PureFunctionalVisitor(session_computation, convolve_request.lhs(),
- num_parameters, visited, is_functional);
- PureFunctionalVisitor(session_computation, convolve_request.rhs(),
- num_parameters, visited, is_functional);
- break;
- }
-
- case OpRequest::kFftRequest: {
- const FftRequest& fft_request = request.request().fft_request();
- PureFunctionalVisitor(session_computation, fft_request.operand(),
- num_parameters, visited, is_functional);
- break;
- }
-
- case OpRequest::kCrossReplicaSumRequest: {
- // TODO(b/33009255): Implement constant folding for cross replica sum.
- *is_functional = false;
- break;
- }
-
- case OpRequest::kInfeedRequest: {
- *is_functional = false;
- break;
- }
-
- case OpRequest::kOutfeedRequest: {
- *is_functional = false;
- break;
- }
-
- case OpRequest::kHostComputeRequest: {
- *is_functional = false;
- break;
- }
-
- case OpRequest::kCallRequest: {
- const CallRequest& call_request = request.request().call_request();
- for (const ComputationDataHandle& handle : call_request.operands()) {
- PureFunctionalVisitor(session_computation, handle, num_parameters,
- visited, is_functional);
- }
- // TODO(b/32495713): We aren't checking the to_apply computation itself,
- // so we conservatively say that computations containing the Call op
- // cannot be constant. We cannot set is_functional=false in other similar
- // cases since we're already relying on IsConstant to return true.
- *is_functional = false;
- break;
- }
-
- case OpRequest::kCustomCallRequest: {
- *is_functional = false;
- break;
- }
-
- case OpRequest::kDotRequest: {
- const DotRequest& dot_request = request.request().dot_request();
- PureFunctionalVisitor(session_computation, dot_request.lhs(),
- num_parameters, visited, is_functional);
- PureFunctionalVisitor(session_computation, dot_request.rhs(),
- num_parameters, visited, is_functional);
- break;
- }
-
- case OpRequest::kSendRequest: {
- *is_functional = false;
- break;
- }
-
- case OpRequest::kRecvRequest: {
- *is_functional = false;
- break;
- }
-
- case OpRequest::kMapRequest: {
- const MapRequest& map_request = request.request().map_request();
- for (const ComputationDataHandle& handle : map_request.operands()) {
- PureFunctionalVisitor(session_computation, handle, num_parameters,
- visited, is_functional);
- }
- // TODO(b/32495713): We aren't checking the to_apply computation itself.
- break;
- }
-
- case OpRequest::kReduceRequest: {
- const ReduceRequest& reduce_request = request.request().reduce_request();
- PureFunctionalVisitor(session_computation, reduce_request.operand(),
- num_parameters, visited, is_functional);
- PureFunctionalVisitor(session_computation, reduce_request.init_value(),
- num_parameters, visited, is_functional);
- // TODO(b/32495713): We aren't checking the to_apply computation itself.
- break;
- }
-
- case OpRequest::kReduceWindowRequest: {
- const ReduceWindowRequest& reduce_window_request =
- request.request().reduce_window_request();
- PureFunctionalVisitor(session_computation,
- reduce_window_request.operand(), num_parameters,
- visited, is_functional);
- PureFunctionalVisitor(session_computation,
- reduce_window_request.init_value(), num_parameters,
- visited, is_functional);
- // TODO(b/32495713): We aren't checking the to_apply computation itself.
- break;
- }
-
- case OpRequest::kSelectAndScatterRequest: {
- const SelectAndScatterRequest& select_and_scatter_request =
- request.request().select_and_scatter_request();
- PureFunctionalVisitor(session_computation,
- select_and_scatter_request.operand(),
- num_parameters, visited, is_functional);
- PureFunctionalVisitor(session_computation,
- select_and_scatter_request.source(), num_parameters,
- visited, is_functional);
- PureFunctionalVisitor(session_computation,
- select_and_scatter_request.init_value(),
- num_parameters, visited, is_functional);
- // TODO(b/32495713): We aren't checking the select and scatter
- // computations themselves.
- break;
- }
-
- case OpRequest::kBroadcastRequest: {
- const BroadcastRequest& broadcast_request =
- request.request().broadcast_request();
- PureFunctionalVisitor(session_computation, broadcast_request.operand(),
- num_parameters, visited, is_functional);
- break;
- }
-
- case OpRequest::kReshapeRequest: {
- const ReshapeRequest& reshape_request =
- request.request().reshape_request();
- PureFunctionalVisitor(session_computation, reshape_request.operand(),
- num_parameters, visited, is_functional);
- break;
- }
-
- case OpRequest::kReverseRequest: {
- const ReverseRequest& reverse_request =
- request.request().reverse_request();
- PureFunctionalVisitor(session_computation, reverse_request.operand(),
- num_parameters, visited, is_functional);
- break;
- }
-
- case OpRequest::kPadRequest: {
- const PadRequest& pad_request = request.request().pad_request();
- PureFunctionalVisitor(session_computation, pad_request.operand(),
- num_parameters, visited, is_functional);
- PureFunctionalVisitor(session_computation, pad_request.padding_value(),
- num_parameters, visited, is_functional);
- break;
- }
-
- case OpRequest::kParameterRequest: {
- const ParameterRequest& parameter_request =
- request.request().parameter_request();
- if (parameter_request.parameter() >= num_parameters) {
- *is_functional = false;
- }
- break;
- }
-
- case OpRequest::kConvertRequest: {
- const ConvertRequest& convert_request =
- request.request().convert_request();
- PureFunctionalVisitor(session_computation, convert_request.operand(),
- num_parameters, visited, is_functional);
- break;
- }
-
- case OpRequest::kBitcastConvertRequest: {
- const ConvertRequest& convert_request =
- request.request().bitcast_convert_request();
- PureFunctionalVisitor(session_computation, convert_request.operand(),
- num_parameters, visited, is_functional);
- break;
- }
-
- case OpRequest::kWhileRequest: {
- const WhileRequest& while_request = request.request().while_request();
- PureFunctionalVisitor(session_computation, while_request.init(),
- num_parameters, visited, is_functional);
- // TODO(b/32495713): We aren't checking the condition and body
- // computations themselves.
- *is_functional = false;
- break;
- }
-
- case OpRequest::kConditionalRequest: {
- const ConditionalRequest& conditional_request =
- request.request().conditional_request();
- PureFunctionalVisitor(session_computation,
- conditional_request.predicate(), num_parameters,
- visited, is_functional);
- PureFunctionalVisitor(session_computation,
- conditional_request.true_operand(), num_parameters,
- visited, is_functional);
- PureFunctionalVisitor(session_computation,
- conditional_request.false_operand(), num_parameters,
- visited, is_functional);
- // TODO(b/32495713): We aren't checking the true and false computations
- // themselves.
- break;
- }
-
- case OpRequest::kTernaryOpRequest: {
- const TernaryOpRequest& ternary_op_request =
- request.request().ternary_op_request();
- PureFunctionalVisitor(session_computation, ternary_op_request.lhs(),
- num_parameters, visited, is_functional);
- PureFunctionalVisitor(session_computation, ternary_op_request.rhs(),
- num_parameters, visited, is_functional);
- PureFunctionalVisitor(session_computation, ternary_op_request.ehs(),
- num_parameters, visited, is_functional);
- break;
- }
-
- case OpRequest::kTransposeRequest: {
- const TransposeRequest& transpose_request =
- request.request().transpose_request();
- PureFunctionalVisitor(session_computation, transpose_request.operand(),
- num_parameters, visited, is_functional);
- break;
- }
-
- case OpRequest::kVariadicOpRequest: {
- const VariadicOpRequest& variadic_op_request =
- request.request().variadic_op_request();
- for (const ComputationDataHandle& handle :
- variadic_op_request.operands()) {
- PureFunctionalVisitor(session_computation, handle, num_parameters,
- visited, is_functional);
- }
- break;
- }
-
- case OpRequest::kUnaryOpRequest: {
- const UnaryOpRequest& unary_op_request =
- request.request().unary_op_request();
- PureFunctionalVisitor(session_computation, unary_op_request.operand(),
- num_parameters, visited, is_functional);
- break;
- }
-
- case OpRequest::kBatchNormTrainingRequest: {
- const BatchNormTrainingRequest& batch_norm_training_request =
- request.request().batch_norm_training_request();
- PureFunctionalVisitor(session_computation,
- batch_norm_training_request.operand(),
- num_parameters, visited, is_functional);
- PureFunctionalVisitor(session_computation,
- batch_norm_training_request.scale(), num_parameters,
- visited, is_functional);
- PureFunctionalVisitor(session_computation,
- batch_norm_training_request.offset(),
- num_parameters, visited, is_functional);
- break;
- }
-
- case OpRequest::kBatchNormInferenceRequest: {
- const BatchNormInferenceRequest& batch_norm_inference_request =
- request.request().batch_norm_inference_request();
- PureFunctionalVisitor(session_computation,
- batch_norm_inference_request.operand(),
- num_parameters, visited, is_functional);
- PureFunctionalVisitor(session_computation,
- batch_norm_inference_request.scale(),
- num_parameters, visited, is_functional);
- PureFunctionalVisitor(session_computation,
- batch_norm_inference_request.offset(),
- num_parameters, visited, is_functional);
- PureFunctionalVisitor(session_computation,
- batch_norm_inference_request.mean(), num_parameters,
- visited, is_functional);
- PureFunctionalVisitor(session_computation,
- batch_norm_inference_request.variance(),
- num_parameters, visited, is_functional);
- break;
- }
-
- case OpRequest::kBatchNormGradRequest: {
- const BatchNormGradRequest& batch_norm_grad_request =
- request.request().batch_norm_grad_request();
- PureFunctionalVisitor(session_computation,
- batch_norm_grad_request.operand(), num_parameters,
- visited, is_functional);
- PureFunctionalVisitor(session_computation,
- batch_norm_grad_request.scale(), num_parameters,
- visited, is_functional);
- PureFunctionalVisitor(session_computation, batch_norm_grad_request.mean(),
- num_parameters, visited, is_functional);
- PureFunctionalVisitor(session_computation,
- batch_norm_grad_request.variance(), num_parameters,
- visited, is_functional);
- PureFunctionalVisitor(session_computation,
- batch_norm_grad_request.grad_output(),
- num_parameters, visited, is_functional);
- break;
- }
-
- case OpRequest::kBinaryOpRequest: {
- const BinaryOpRequest& binary_op_request =
- request.request().binary_op_request();
- PureFunctionalVisitor(session_computation, binary_op_request.lhs(),
- num_parameters, visited, is_functional);
- PureFunctionalVisitor(session_computation, binary_op_request.rhs(),
- num_parameters, visited, is_functional);
- break;
- }
-
- case OpRequest::kGatherRequest: {
- PureFunctionalVisitor(session_computation,
- request.request().gather_request().input(),
- num_parameters, visited, is_functional);
- PureFunctionalVisitor(session_computation,
- request.request().gather_request().gather_indices(),
- num_parameters, visited, is_functional);
- break;
- }
-
- case OpRequest::OP_NOT_SET:
- LOG(FATAL) << "OperationRequest doesn't contain a request";
-
- default:
- LOG(FATAL) << "Unexpected request type: " << request.request().op_case();
- }
- if (!*is_functional) {
- VLOG(1) << "Non-functional: " << request.request().DebugString();
- }
- visited->insert(handle.handle());
-}
-
-} // namespace
-
-StatusOr<bool> UserComputation::IsConstant(const ComputationDataHandle& handle,
- int64 num_parameters) {
- tensorflow::mutex_lock lock(mutex_);
-
- // Verify that the handle is valid.
- auto operation_status = LookUpRequest(handle);
- if (!operation_status.ok()) {
- return operation_status.status();
- }
-
- bool is_constant = true;
- std::set<int64> visited;
- PureFunctionalVisitor(session_computation_, handle, num_parameters, &visited,
- &is_constant);
-
- return is_constant;
-}
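// [Editorial sketch, not part of the original diff.] IsConstant drives
// PureFunctionalVisitor, a memoized DFS: a handle is constant only if every
// transitive operand is, and the visited set keeps shared subgraphs from
// being re-walked. The same pattern on a hypothetical Node type (all names
// below are illustrative assumptions, not XLA APIs):
#include <set>
#include <vector>

struct Node {
  int id;
  bool pure;  // Intrinsic purity of this op considered in isolation.
  std::vector<const Node*> operands;
};

// Clears *is_pure if any node reachable from 'node' is impure. 'visited'
// mirrors the role of the std::set<int64> in PureFunctionalVisitor.
void PureVisitor(const Node* node, std::set<int>* visited, bool* is_pure) {
  if (visited->count(node->id) > 0) {
    return;
  }
  if (!node->pure) {
    *is_pure = false;
  }
  for (const Node* operand : node->operands) {
    PureVisitor(operand, visited, is_pure);
  }
  visited->insert(node->id);
}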
-
-std::vector<VersionedComputationHandle>
-UserComputation::GetEmbeddedComputations(
- VersionedComputationHandle::Version version) const {
- tensorflow::mutex_lock lock(mutex_);
-
- VLOG(1)
- << "GetEmbeddedComputations(" << name() << " "
- << VersionedComputationHandle{session_computation_.computation_handle(),
- version}
- << ")";
- XLA_VLOG_LINES(3, session_computation_.DebugString());
-
- std::vector<VersionedComputationHandle> computations;
- std::vector<int64> sorted_handles;
- for (const auto& handle_request : session_computation_.requests()) {
- sorted_handles.push_back(handle_request.first);
- }
- std::sort(sorted_handles.begin(), sorted_handles.end());
- for (int64 handle : sorted_handles) {
- const auto& handle_request = session_computation_.requests().find(handle);
- CHECK(handle_request != session_computation_.requests().end());
- int64 handle_value = handle_request->first;
- if (handle_value <= version) {
- const OperationRequest& request = handle_request->second;
- switch (request.request().op_case()) {
- case OpRequest::kCallRequest: {
- CHECK_EQ(1, request.embedded_computation_versions_size());
- const CallRequest& call_request = request.request().call_request();
- const VersionedComputationHandle versioned_handle = {
- call_request.to_apply(),
- request.embedded_computation_versions(0)};
- computations.push_back(versioned_handle);
- break;
- }
-
- case OpRequest::kMapRequest: {
- CHECK_EQ(1, request.embedded_computation_versions_size());
- const MapRequest& map_request = request.request().map_request();
- const VersionedComputationHandle versioned_handle = {
- map_request.to_apply(), request.embedded_computation_versions(0)};
- computations.push_back(versioned_handle);
- break;
- }
-
- case OpRequest::kReduceRequest: {
- CHECK_EQ(1, request.embedded_computation_versions_size());
- const ReduceRequest& reduce_request =
- request.request().reduce_request();
- const VersionedComputationHandle versioned_handle = {
- reduce_request.to_apply(),
- request.embedded_computation_versions(0)};
- computations.push_back(versioned_handle);
- break;
- }
-
- case OpRequest::kReduceWindowRequest: {
- CHECK_EQ(1, request.embedded_computation_versions_size());
- const ReduceWindowRequest& reduce_window_request =
- request.request().reduce_window_request();
- const VersionedComputationHandle versioned_handle = {
- reduce_window_request.to_apply(),
- request.embedded_computation_versions(0)};
- computations.push_back(versioned_handle);
- break;
- }
-
- case OpRequest::kSelectAndScatterRequest: {
- CHECK_EQ(2, request.embedded_computation_versions_size());
- const SelectAndScatterRequest& select_and_scatter_request =
- request.request().select_and_scatter_request();
- const VersionedComputationHandle select_versioned_handle = {
- select_and_scatter_request.select(),
- request.embedded_computation_versions(0)};
- computations.push_back(select_versioned_handle);
- const VersionedComputationHandle scatter_versioned_handle = {
- select_and_scatter_request.scatter(),
- request.embedded_computation_versions(1)};
- computations.push_back(scatter_versioned_handle);
- break;
- }
-
- case OpRequest::kWhileRequest: {
- CHECK_EQ(2, request.embedded_computation_versions_size());
- const WhileRequest& while_request = request.request().while_request();
- const VersionedComputationHandle condition_versioned_handle = {
- while_request.condition(),
- request.embedded_computation_versions(0)};
- computations.push_back(condition_versioned_handle);
- const VersionedComputationHandle body_versioned_handle = {
- while_request.body(), request.embedded_computation_versions(1)};
- computations.push_back(body_versioned_handle);
- break;
- }
-
- case OpRequest::kConditionalRequest: {
- CHECK_EQ(2, request.embedded_computation_versions_size());
- const ConditionalRequest& conditional_request =
- request.request().conditional_request();
- const VersionedComputationHandle true_computation_versioned_handle = {
- conditional_request.true_computation(),
- request.embedded_computation_versions(0)};
- computations.push_back(true_computation_versioned_handle);
- const VersionedComputationHandle false_computation_versioned_handle =
- {conditional_request.false_computation(),
- request.embedded_computation_versions(1)};
- computations.push_back(false_computation_versioned_handle);
- break;
- }
-
- default:
- // No embedded computation.
- break;
- }
- }
- }
- VLOG(2) << "Embedded computations: "
- << tensorflow::str_util::Join(
- computations, ", ",
- [](string* out, const VersionedComputationHandle& h) {
- out->append(h.ToString());
- });
- return computations;
-}
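// [Editorial sketch, assumed container type.] Proto maps iterate in
// unspecified order, which is why GetEmbeddedComputations sorts the handles
// before walking them: the returned computation list becomes deterministic.
// The sort-the-keys idiom in isolation:
#include <algorithm>
#include <unordered_map>
#include <vector>

template <typename V>
std::vector<long long> SortedKeys(const std::unordered_map<long long, V>& m) {
  std::vector<long long> keys;
  keys.reserve(m.size());
  for (const auto& kv : m) {
    keys.push_back(kv.first);
  }
  std::sort(keys.begin(), keys.end());
  return keys;
}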
-
-StatusOr<const OperationRequest*>
-UserComputation::LookUpRequestForErrorReporting(
- const ComputationDataHandle& handle) const {
- tensorflow::mutex_lock lock(mutex_);
- return LookUpRequest(handle);
-}
-
-tensorflow::gtl::optional<const OpMetadata*> UserComputation::ParameterMetadata(
- int parameter_number) const {
- tensorflow::mutex_lock lock(mutex_);
- auto it = parameters_.find(parameter_number);
- if (it == parameters_.end()) {
- return tensorflow::gtl::nullopt;
- }
- OperationRequest* op = it->second;
- return &op->request().metadata();
-}
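// [Editorial sketch.] ParameterMetadata returns optional<const OpMetadata*>
// so that "no such parameter" stays distinct from every valid pointer value;
// gtl::optional here predates std::optional. The same lookup shape with the
// standard type (hypothetical names):
#include <map>
#include <optional>
#include <string>

std::optional<const std::string*> FindMetadata(
    const std::map<int, std::string>& table, int key) {
  auto it = table.find(key);
  if (it == table.end()) {
    return std::nullopt;
  }
  return &it->second;
}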
-
-Status UserComputation::RemapEmbeddedComputations(
- const std::map<int64, ComputationHandle>& old_to_new) {
- auto update = [&old_to_new](ComputationHandle* to_update) -> Status {
- int64 old = to_update->handle();
- auto it = old_to_new.find(old);
- if (it == old_to_new.end()) {
- string mapping = tensorflow::str_util::Join(
- old_to_new, ", ",
- [](string* out, std::pair<int64, ComputationHandle> element) {
- tensorflow::strings::Appendf(out, "%lld:%lld", element.first,
- element.second.handle());
- });
- return NotFound(
- "could not find referenced (old) computation handle in mapping: "
- "%lld; mapping: {%s}",
- old, mapping.c_str());
- }
- VLOG(2) << "remapping " << old << " to " << it->second.handle();
- *to_update = it->second;
- return Status::OK();
- };
- TF_RETURN_IF_ERROR(update(session_computation_.mutable_computation_handle()));
- for (auto& handle_request : *session_computation_.mutable_requests()) {
- OperationRequest& request = handle_request.second;
- switch (request.request().op_case()) {
- case OpRequest::kCallRequest: {
- TF_RET_CHECK(1 == request.embedded_computation_versions_size());
- CallRequest* call_request =
- request.mutable_request()->mutable_call_request();
- TF_RETURN_IF_ERROR(update(call_request->mutable_to_apply()));
- break;
- }
- case OpRequest::kMapRequest: {
- TF_RET_CHECK(1 == request.embedded_computation_versions_size());
- MapRequest* map_request =
- request.mutable_request()->mutable_map_request();
- TF_RETURN_IF_ERROR(update(map_request->mutable_to_apply()));
- break;
- }
- case OpRequest::kReduceRequest: {
- TF_RET_CHECK(1 == request.embedded_computation_versions_size());
- ReduceRequest* reduce_request =
- request.mutable_request()->mutable_reduce_request();
- TF_RETURN_IF_ERROR(update(reduce_request->mutable_to_apply()));
- break;
- }
- case OpRequest::kReduceWindowRequest: {
- TF_RET_CHECK(1 == request.embedded_computation_versions_size());
- ReduceWindowRequest* reduce_window_request =
- request.mutable_request()->mutable_reduce_window_request();
- TF_RETURN_IF_ERROR(update(reduce_window_request->mutable_to_apply()));
- break;
- }
- case OpRequest::kSelectAndScatterRequest: {
- TF_RET_CHECK(2 == request.embedded_computation_versions_size());
- SelectAndScatterRequest* select_and_scatter_request =
- request.mutable_request()->mutable_select_and_scatter_request();
- TF_RETURN_IF_ERROR(
- update(select_and_scatter_request->mutable_select()));
- TF_RETURN_IF_ERROR(
- update(select_and_scatter_request->mutable_scatter()));
- break;
- }
- case OpRequest::kWhileRequest: {
- TF_RET_CHECK(2 == request.embedded_computation_versions_size());
- WhileRequest* while_request =
- request.mutable_request()->mutable_while_request();
- TF_RETURN_IF_ERROR(update(while_request->mutable_condition()));
- TF_RETURN_IF_ERROR(update(while_request->mutable_body()));
- break;
- }
- case OpRequest::kConditionalRequest: {
- TF_RET_CHECK(2 == request.embedded_computation_versions_size());
- ConditionalRequest* conditional_request =
- request.mutable_request()->mutable_conditional_request();
- TF_RETURN_IF_ERROR(
- update(conditional_request->mutable_true_computation()));
- TF_RETURN_IF_ERROR(
- update(conditional_request->mutable_false_computation()));
- break;
- }
- default:
- // No embedded computation.
- TF_RET_CHECK(0 == request.embedded_computation_versions_size());
- break;
- }
- }
- return Status::OK();
-}
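// [Editorial sketch, hypothetical names.] The 'update' lambda above is a
// lookup-or-fail rewrite: every old computation handle must have an entry in
// old_to_new, and one missing entry fails the whole remap. Reduced to its
// core:
#include <map>

// Rewrites *handle in place when a mapping exists; returns false otherwise
// (the original reports this case as a NotFound status).
bool RemapHandle(const std::map<long long, long long>& old_to_new,
                 long long* handle) {
  auto it = old_to_new.find(*handle);
  if (it == old_to_new.end()) {
    return false;
  }
  *handle = it->second;
  return true;
}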
-
-SessionComputation UserComputation::CloneSessionComputation(
- VersionedComputationHandle::Version version) const {
- tensorflow::mutex_lock lock(mutex_);
- SessionComputation result = session_computation_;
- // Erase all the requests that exceed the version specified.
- // There's no lower_bound method on tensorflow::protobuf::Map, so we iterate
- // over all the elements.
- auto it = result.mutable_requests()->begin();
- while (it != result.mutable_requests()->end()) {
- if (it->first > version) {
- it = result.mutable_requests()->erase(it);
- } else {
- ++it;
- }
- }
- return result;
-}
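// [Editorial sketch, assumed map type.] CloneSessionComputation depends on
// the erase-returns-next-iterator idiom, which is what makes deleting
// entries while iterating safe. The same loop over a std::map:
#include <map>
#include <string>

// Returns a copy of 'requests' with every entry above 'version' removed.
std::map<long long, std::string> CloneUpToVersion(
    std::map<long long, std::string> requests, long long version) {
  for (auto it = requests.begin(); it != requests.end();) {
    if (it->first > version) {
      it = requests.erase(it);  // erase() yields the next valid iterator.
    } else {
      ++it;
    }
  }
  return requests;
}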
-
-StatusOr<const OperationRequest*> UserComputation::LookUpRequest(
- const ComputationDataHandle& handle) const {
- int64 handle_value = handle.handle();
- if (session_computation_.requests().count(handle_value) == 0) {
- return InvalidArgument("no ComputationDataHandle value %lld", handle_value);
- }
- return &session_computation_.requests().at(handle_value);
-}
-
-Status UserComputation::CheckParametersAreContiguous(
- VersionedComputationHandle::Version version) const {
- TF_RET_CHECK(version > 0 && version < next_handle_value_);
-
- // Determine number of parameter inputs at the given version.
- std::map<int64, const ParameterRequest*> parameter_requests;
- for (int64 request_num = 1; request_num <= version; ++request_num) {
- const OperationRequest& request =
- session_computation_.requests().at(request_num);
-
- if (request.request().op_case() == OpRequest::kParameterRequest) {
- const ParameterRequest& parameter_request =
- request.request().parameter_request();
- // Duplicate parameters should be checked when parameter requests are
- // added.
- TF_RET_CHECK(0 ==
- parameter_requests.count(parameter_request.parameter()));
- parameter_requests[parameter_request.parameter()] = &parameter_request;
- }
- }
-
- for (int64 i = 0; i < parameter_requests.size(); ++i) {
- auto it = parameter_requests.find(i);
- if (it == parameter_requests.end()) {
- return FailedPrecondition(
- "computation %s does not have all its parameters populated "
- "sequentially, missing parameter %lld",
- name_.c_str(), i);
- }
- }
-
- return Status::OK();
-}
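// [Editorial sketch, hypothetical name.] Because parameter numbers in the
// map above are distinct, probing 0..size-1 finds a hole exactly when the
// numbers are not the contiguous range {0, ..., n-1}. The check minus the
// XLA plumbing:
#include <map>

// Returns -1 when parameters 0..n-1 are all present, else the first missing
// parameter number.
long long FirstMissingParameter(const std::map<long long, int>& params) {
  for (long long i = 0; i < static_cast<long long>(params.size()); ++i) {
    if (params.find(i) == params.end()) {
      return i;
    }
  }
  return -1;
}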
-
-namespace {
-
-// Helper class which builds an HLO computation from a SessionComputation. To
-// construct the HLO computation, the SessionComputation graph is walked in
-// DFS order, lowering each OperationRequest to an HLO instruction.
-class ComputationLowerer {
- public:
- static StatusOr<std::unique_ptr<HloComputation>> Lower(
- const string& computation_name,
- const SessionComputation& session_computation,
- VersionedComputationHandle::Version version,
- UserComputation::HloComputationResolver hlo_resolver,
- const DebugOptions& debug_options,
- bool include_unreachable_instructions) {
- ComputationLowerer lowerer(computation_name, session_computation, version,
- std::move(hlo_resolver), debug_options,
- include_unreachable_instructions);
- return lowerer.Lower();
- }
-
- private:
- ComputationLowerer(const string& computation_name,
- const SessionComputation& session_computation,
- VersionedComputationHandle::Version version,
- UserComputation::HloComputationResolver hlo_resolver,
- const DebugOptions& debug_options,
- bool include_unreachable_instructions)
- : hlo_builder_(computation_name),
- session_computation_(session_computation),
- version_(version),
- hlo_resolver_(std::move(hlo_resolver)),
- debug_options_(debug_options),
- include_unreachable_instructions_(include_unreachable_instructions) {}
-
- // Build an HLO computation from the SessionComputation at the given
- // version.
- StatusOr<std::unique_ptr<HloComputation>> Lower();
-
- private:
- // Traverses the computation 'root' using a DFS, calling 'visit' in postorder.
- void TraversePostorder(
- const ComputationDataHandle& root,
- std::unordered_map<int64, HloInstruction*>* visited,
- const std::function<void(const ComputationDataHandle&)>& visit);
-
- // DFS visitor of the UserComputation operations which lowers the operations
- // to HLO instructions.
- void Visit(const ComputationDataHandle& handle,
- std::unordered_map<int64, HloInstruction*>* instructions);
-
- // Resolves a ComputationHandle and Version to a previously lowered
- // HloComputation using the hlo_resolver_ function.
- HloComputation* ResolveComputation(
- const ComputationHandle& handle,
- VersionedComputationHandle::Version version);
-
- // Takes an input value that is being implicitly broadcast into an output
- // shape and emits the kBroadcast instruction(s) needed to make the implicit
- // broadcast semantics explicit.
- HloInstruction* ImplicitBroadcastToExplicitBroadcast(
- HloInstruction* operand, const Shape& output_shape);
-
- HloComputation::Builder hlo_builder_;
- const SessionComputation& session_computation_;
- const VersionedComputationHandle::Version version_;
- const UserComputation::HloComputationResolver hlo_resolver_;
- const DebugOptions& debug_options_;
- const bool include_unreachable_instructions_;
-};
-
-// Calls 'apply' on each operand of 'request'.
-static void ForEachOperand(
- const OperationRequest& request,
- const std::function<void(const ComputationDataHandle& param)>& apply) {
- switch (request.request().op_case()) {
- case OpRequest::kRngRequest: {
- const RngRequest& rng_request = request.request().rng_request();
- for (const ComputationDataHandle& param : rng_request.parameter()) {
- apply(param);
- }
- break;
- }
-
- case OpRequest::kConstantRequest:
- break;
- case OpRequest::kGetTupleElementRequest: {
- const GetTupleElementRequest& get_tuple_element_request =
- request.request().get_tuple_element_request();
- apply(get_tuple_element_request.operand());
- break;
- }
-
- case OpRequest::kSliceRequest: {
- const SliceRequest& slice_request = request.request().slice_request();
- apply(slice_request.operand());
- break;
- }
-
- case OpRequest::kDynamicSliceRequest: {
- const DynamicSliceRequest& dynamic_slice_request =
- request.request().dynamic_slice_request();
- apply(dynamic_slice_request.operand());
- apply(dynamic_slice_request.start_indices());
- break;
- }
-
- case OpRequest::kDynamicUpdateSliceRequest: {
- const DynamicUpdateSliceRequest& dynamic_update_slice_request =
- request.request().dynamic_update_slice_request();
- apply(dynamic_update_slice_request.operand());
- apply(dynamic_update_slice_request.update());
- apply(dynamic_update_slice_request.start_indices());
- break;
- }
-
- case OpRequest::kConcatenateRequest: {
- const ConcatenateRequest& concatenate_request =
- request.request().concatenate_request();
- for (const ComputationDataHandle& handle :
- concatenate_request.operands()) {
- apply(handle);
- }
- break;
- }
-
- case OpRequest::kConvolveRequest: {
- const ConvolveRequest& convolve_request =
- request.request().convolve_request();
- apply(convolve_request.lhs());
- apply(convolve_request.rhs());
- break;
- }
-
- case OpRequest::kFftRequest: {
- const FftRequest& fft_request = request.request().fft_request();
- apply(fft_request.operand());
- break;
- }
-
- case OpRequest::kBatchNormTrainingRequest: {
- const BatchNormTrainingRequest& batch_norm_training_request =
- request.request().batch_norm_training_request();
-
- apply(batch_norm_training_request.operand());
- apply(batch_norm_training_request.scale());
- apply(batch_norm_training_request.offset());
- break;
- }
-
- case OpRequest::kBatchNormInferenceRequest: {
- const BatchNormInferenceRequest& batch_norm_inference_request =
- request.request().batch_norm_inference_request();
-
- apply(batch_norm_inference_request.operand());
- apply(batch_norm_inference_request.scale());
- apply(batch_norm_inference_request.offset());
- apply(batch_norm_inference_request.mean());
- apply(batch_norm_inference_request.variance());
- break;
- }
-
- case OpRequest::kBatchNormGradRequest: {
- const BatchNormGradRequest& batch_norm_grad_request =
- request.request().batch_norm_grad_request();
-
- apply(batch_norm_grad_request.operand());
- apply(batch_norm_grad_request.scale());
- apply(batch_norm_grad_request.mean());
- apply(batch_norm_grad_request.variance());
- apply(batch_norm_grad_request.grad_output());
- break;
- }
-
- case OpRequest::kCrossReplicaSumRequest: {
- const CrossReplicaSumRequest& cross_replica_sum_request =
- request.request().cross_replica_sum_request();
- apply(cross_replica_sum_request.operand());
- break;
- }
-
- case OpRequest::kInfeedRequest:
- break;
-
- case OpRequest::kOutfeedRequest: {
- const OutfeedRequest& outfeed_request =
- request.request().outfeed_request();
- apply(outfeed_request.operand());
- break;
- }
-
- case OpRequest::kMapRequest: {
- const MapRequest& map_request = request.request().map_request();
- for (const ComputationDataHandle& handle : map_request.operands()) {
- apply(handle);
- }
- break;
- }
-
- case OpRequest::kReduceRequest: {
- const ReduceRequest& reduce_request = request.request().reduce_request();
- apply(reduce_request.operand());
- apply(reduce_request.init_value());
- break;
- }
-
- case OpRequest::kReduceWindowRequest: {
- const ReduceWindowRequest& reduce_window_request =
- request.request().reduce_window_request();
- apply(reduce_window_request.operand());
- apply(reduce_window_request.init_value());
- break;
- }
-
- case OpRequest::kSelectAndScatterRequest: {
- const SelectAndScatterRequest& select_and_scatter_request =
- request.request().select_and_scatter_request();
- apply(select_and_scatter_request.operand());
- apply(select_and_scatter_request.source());
- apply(select_and_scatter_request.init_value());
-
- break;
- }
-
- case OpRequest::kBroadcastRequest: {
- const BroadcastRequest& broadcast_request =
- request.request().broadcast_request();
- apply(broadcast_request.operand());
- break;
- }
-
- case OpRequest::kReshapeRequest: {
- const ReshapeRequest& reshape_request =
- request.request().reshape_request();
- apply(reshape_request.operand());
- break;
- }
-
- case OpRequest::kTransposeRequest: {
- const TransposeRequest& transpose_request =
- request.request().transpose_request();
- apply(transpose_request.operand());
- break;
- }
-
- case OpRequest::kReverseRequest: {
- const ReverseRequest& reverse_request =
- request.request().reverse_request();
- apply(reverse_request.operand());
- break;
- }
-
- case OpRequest::kPadRequest: {
- const PadRequest& pad_request = request.request().pad_request();
- apply(pad_request.operand());
- apply(pad_request.padding_value());
- break;
- }
-
- case OpRequest::kRecvRequest:
- case OpRequest::kParameterRequest:
- break;
-
- case OpRequest::kConvertRequest: {
- const ConvertRequest& convert_request =
- request.request().convert_request();
- apply(convert_request.operand());
- break;
- }
-
- case OpRequest::kBitcastConvertRequest: {
- const ConvertRequest& convert_request =
- request.request().bitcast_convert_request();
- apply(convert_request.operand());
- break;
- }
-
- case OpRequest::kWhileRequest: {
- const WhileRequest& while_request = request.request().while_request();
- apply(while_request.init());
- break;
- }
-
- case OpRequest::kConditionalRequest: {
- const ConditionalRequest& conditional_request =
- request.request().conditional_request();
- apply(conditional_request.predicate());
- apply(conditional_request.true_operand());
- apply(conditional_request.false_operand());
- break;
- }
-
- case OpRequest::kTernaryOpRequest: {
- const TernaryOpRequest& ternary_op_request =
- request.request().ternary_op_request();
- apply(ternary_op_request.lhs());
- apply(ternary_op_request.rhs());
- apply(ternary_op_request.ehs());
- break;
- }
-
- case OpRequest::kVariadicOpRequest: {
- const VariadicOpRequest& variadic_op_request =
- request.request().variadic_op_request();
- for (const ComputationDataHandle& handle :
- variadic_op_request.operands()) {
- apply(handle);
- }
- break;
- }
-
- case OpRequest::kCallRequest: {
- const CallRequest& call_request = request.request().call_request();
- for (const ComputationDataHandle& handle : call_request.operands()) {
- apply(handle);
- }
- break;
- }
-
- case OpRequest::kCustomCallRequest: {
- const CustomCallRequest& cc_request =
- request.request().custom_call_request();
- for (const ComputationDataHandle& operand : cc_request.operands()) {
- apply(operand);
- }
- break;
- }
-
- case OpRequest::kHostComputeRequest: {
- const HostComputeRequest& hc_request =
- request.request().host_compute_request();
- for (const ComputationDataHandle& operand : hc_request.operands()) {
- apply(operand);
- }
- break;
- }
-
- case OpRequest::kDotRequest: {
- const DotRequest& dot_request = request.request().dot_request();
- apply(dot_request.rhs());
- apply(dot_request.lhs());
- break;
- }
-
- case OpRequest::kUnaryOpRequest: {
- const UnaryOpRequest& unary_op_request =
- request.request().unary_op_request();
- apply(unary_op_request.operand());
- break;
- }
-
- case OpRequest::kBinaryOpRequest: {
- const BinaryOpRequest& binary_op_request =
- request.request().binary_op_request();
- apply(binary_op_request.rhs());
- apply(binary_op_request.lhs());
- break;
- }
-
- case OpRequest::kReducePrecisionRequest: {
- const ReducePrecisionRequest& reduce_precision_request =
- request.request().reduce_precision_request();
- apply(reduce_precision_request.operand());
- break;
- }
-
- case OpRequest::kTraceRequest: {
- const TraceRequest& trace_request = request.request().trace_request();
- apply(trace_request.operand());
- break;
- }
-
- case OpRequest::kSendRequest: {
- const SendRequest& send_request = request.request().send_request();
- apply(send_request.operand());
- break;
- }
-
- case OpRequest::kGatherRequest: {
- const GatherRequest& gather_request = request.request().gather_request();
- apply(gather_request.input());
- apply(gather_request.gather_indices());
- break;
- }
-
- case OpRequest::OP_NOT_SET:
- LOG(FATAL) << "OperationRequest doesn't contain a request";
-
- default:
- LOG(FATAL) << "Unexpected request type: " << request.request().op_case();
- }
-}
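// [Editorial sketch with assumed types.] ForEachOperand is a hand-written
// visitor over a proto oneof: switch on op_case(), extract the one set
// sub-message, and hand each operand handle to a caller-supplied callback.
// The shape of that pattern, minus the ~40 XLA request types:
#include <functional>

struct Handle { long long value; };
struct UnaryReq { Handle operand; };
struct BinaryReq { Handle lhs, rhs; };

struct MiniRequest {
  enum class Case { kUnary, kBinary } op_case;
  UnaryReq unary;    // A real proto keeps these in a oneof, not side by side.
  BinaryReq binary;
};

void ForEachMiniOperand(const MiniRequest& request,
                        const std::function<void(const Handle&)>& apply) {
  switch (request.op_case) {
    case MiniRequest::Case::kUnary:
      apply(request.unary.operand);
      break;
    case MiniRequest::Case::kBinary:
      apply(request.binary.lhs);
      apply(request.binary.rhs);
      break;
  }
}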
-
-void ComputationLowerer::TraversePostorder(
- const ComputationDataHandle& root,
- std::unordered_map<int64, HloInstruction*>* visited,
- const std::function<void(const ComputationDataHandle&)>& visit) {
- // Stack containing {handle, enter} pairs. The 'enter' value describes whether
- // we are entering or leaving 'handle'.
- std::stack<std::pair<ComputationDataHandle, bool>> work;
- work.push({root, true});
- while (!work.empty()) {
- ComputationDataHandle handle;
- bool enter;
- std::tie(handle, enter) = work.top();
- work.pop();
-
- if (enter) {
- // We are entering 'handle'. The first time we enter 'handle', we add it
- // to 'visited' with a nullptr value. If 'handle' is already in 'visited',
- // we do not visit it again. This algorithm only uses the presence of
- // a handle in 'visited', but we use a map so we can use the same data
- // structure to store the HloInstruction outputs.
- if (visited->emplace(handle.handle(), nullptr).second) {
- const OperationRequest& request =
- session_computation_.requests().at(handle.handle());
- // Push the corresponding 'leave' action onto the stack, followed by
- // the operands.
- work.push({handle, false});
- ForEachOperand(request, [&work](const ComputationDataHandle& child) {
- work.push({child, true});
- });
- }
- } else {
- // We are leaving 'handle'. We have visited the operands of 'handle', and
- // now can visit the 'handle' itself.
- visit(handle);
- }
- }
-}
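// [Editorial sketch, hypothetical graph type.] TraversePostorder replaces
// recursion with an explicit stack that holds each node twice: an "enter"
// entry that schedules the operands, then a "leave" entry that fires the
// visit once all operands are done. The same two-phase stack in minimal,
// self-contained form (C++17):
#include <functional>
#include <stack>
#include <unordered_set>
#include <utility>
#include <vector>

struct GraphNode {
  int id;
  std::vector<const GraphNode*> operands;
};

void PostorderVisit(const GraphNode* root,
                    const std::function<void(const GraphNode*)>& visit) {
  std::unordered_set<int> visited;
  std::stack<std::pair<const GraphNode*, bool>> work;  // {node, entering?}
  work.push({root, true});
  while (!work.empty()) {
    auto [node, enter] = work.top();
    work.pop();
    if (enter) {
      if (visited.insert(node->id).second) {
        work.push({node, false});  // Come back after the operands.
        for (const GraphNode* operand : node->operands) {
          work.push({operand, true});
        }
      }
    } else {
      visit(node);  // Every operand has already been visited: postorder.
    }
  }
}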
-
-StatusOr<std::unique_ptr<HloComputation>> ComputationLowerer::Lower() {
- // Map from ComputationDataHandle to HLO instruction. Serves as a record of
- // which operations have been visited as well as a cache for looking up
- // ComputationDataHandles as HloInstructions.
- std::unordered_map<int64, HloInstruction*> instructions;
-
- TF_ASSIGN_OR_RETURN(const OperationRequest* root_request,
- GetRoot(version_, session_computation_));
-
- auto visit = [&](const ComputationDataHandle& handle) {
- Visit(handle, &instructions);
- };
- TraversePostorder(root_request->output_handle(), &instructions, visit);
- HloInstruction* hlo_root =
- instructions.at(root_request->output_handle().handle());
-
- if (include_unreachable_instructions_) {
- // Iterate through all computation data handles, and visit any unvisited
- // operations.
- for (int64 request_num = 1; request_num <= version_; ++request_num) {
- TF_ASSIGN_OR_RETURN(const OperationRequest* request,
- LookUpRequest(request_num, session_computation_));
- TraversePostorder(request->output_handle(), &instructions, visit);
- }
- }
-
- return hlo_builder_.Build(hlo_root);
-}
-
-HloComputation* ComputationLowerer::ResolveComputation(
- const ComputationHandle& handle,
- VersionedComputationHandle::Version version) {
- const VersionedComputationHandle checked_handle = {handle, version};
- return hlo_resolver_(checked_handle);
-}
-
-HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast(
- HloInstruction* operand, const Shape& output_shape) {
- auto fadd = [this](std::unique_ptr<HloInstruction> x) {
- return hlo_builder_.AddInstruction(std::move(x));
- };
- return fadd(
- HloInstruction::CreateBroadcastSequence(output_shape, operand, fadd));
-}
-
-void ComputationLowerer::Visit(
- const ComputationDataHandle& handle,
- std::unordered_map<int64, HloInstruction*>* instructions) {
- CHECK_LE(handle.handle(), version_);
- CHECK(instructions->at(handle.handle()) == nullptr);
- const OperationRequest& request =
- session_computation_.requests().at(handle.handle());
- auto add_instruction = [&](std::unique_ptr<HloInstruction> instruction) {
- HloInstruction* hlo_instruction =
- hlo_builder_.AddInstruction(std::move(instruction));
- hlo_instruction->set_metadata(request.request().metadata());
- if (request.request().has_sharding()) {
- OpSharding op_sharding = request.request().sharding();
- hlo_instruction->set_sharding(
- HloSharding::FromProto(op_sharding).ValueOrDie());
- }
- return hlo_instruction;
- };
- auto lookup_instruction = [&](const ComputationDataHandle& handle) {
- return instructions->at(handle.handle());
- };
- HloInstruction* hlo_instruction;
- switch (request.request().op_case()) {
- case OpRequest::kRngRequest: {
- const RngRequest& rng_request = request.request().rng_request();
- std::vector<HloInstruction*> parameters;
- for (const ComputationDataHandle& param : rng_request.parameter()) {
- parameters.push_back(lookup_instruction(param));
- }
- hlo_instruction = add_instruction(HloInstruction::CreateRng(
- request.output_shape(), rng_request.distribution(), parameters));
- break;
- }
-
- case OpRequest::kConstantRequest: {
- const ConstantRequest& constant_request =
- request.request().constant_request();
- hlo_instruction = add_instruction(HloInstruction::CreateConstant(
- Literal::CreateFromProto(constant_request.literal())
- .ConsumeValueOrDie()));
- break;
- }
-
- case OpRequest::kGetTupleElementRequest: {
- const GetTupleElementRequest& get_tuple_element_request =
- request.request().get_tuple_element_request();
- HloInstruction* operand =
- lookup_instruction(get_tuple_element_request.operand());
- hlo_instruction = add_instruction(HloInstruction::CreateGetTupleElement(
- request.output_shape(), operand, get_tuple_element_request.index()));
- break;
- }
-
- case OpRequest::kSliceRequest: {
- const SliceRequest& slice_request = request.request().slice_request();
- HloInstruction* operand = lookup_instruction(slice_request.operand());
- hlo_instruction = add_instruction(HloInstruction::CreateSlice(
- request.output_shape(), operand,
- AsInt64Slice(slice_request.start_indices()),
- AsInt64Slice(slice_request.limit_indices()),
- AsInt64Slice(slice_request.strides())));
- break;
- }
-
- case OpRequest::kDynamicSliceRequest: {
- const DynamicSliceRequest& dynamic_slice_request =
- request.request().dynamic_slice_request();
- HloInstruction* operand =
- lookup_instruction(dynamic_slice_request.operand());
- HloInstruction* start_indices =
- lookup_instruction(dynamic_slice_request.start_indices());
-
- hlo_instruction = add_instruction(HloInstruction::CreateDynamicSlice(
- request.output_shape(), operand, start_indices,
- AsInt64Slice(dynamic_slice_request.slice_sizes())));
- break;
- }
-
- case OpRequest::kDynamicUpdateSliceRequest: {
- const DynamicUpdateSliceRequest& dynamic_update_slice_request =
- request.request().dynamic_update_slice_request();
- HloInstruction* operand =
- lookup_instruction(dynamic_update_slice_request.operand());
- HloInstruction* update =
- lookup_instruction(dynamic_update_slice_request.update());
- HloInstruction* start_indices =
- lookup_instruction(dynamic_update_slice_request.start_indices());
- hlo_instruction =
- add_instruction(HloInstruction::CreateDynamicUpdateSlice(
- request.output_shape(), operand, update, start_indices));
- break;
- }
-
- case OpRequest::kConcatenateRequest: {
- const ConcatenateRequest& concatenate_request =
- request.request().concatenate_request();
- std::vector<HloInstruction*> operands;
- for (const ComputationDataHandle& handle :
- concatenate_request.operands()) {
- HloInstruction* operand = lookup_instruction(handle);
- operands.push_back(operand);
- }
- hlo_instruction = add_instruction(HloInstruction::CreateConcatenate(
- request.output_shape(), operands, concatenate_request.dimension()));
- break;
- }
-
- case OpRequest::kConvolveRequest: {
- const ConvolveRequest& convolve_request =
- request.request().convolve_request();
- HloInstruction* lhs = lookup_instruction(convolve_request.lhs());
- HloInstruction* rhs = lookup_instruction(convolve_request.rhs());
- hlo_instruction = add_instruction(HloInstruction::CreateConvolve(
- request.output_shape(), lhs, rhs, convolve_request.window(),
- convolve_request.dimension_numbers()));
- break;
- }
-
- case OpRequest::kFftRequest: {
- const FftRequest& fft_request = request.request().fft_request();
- HloInstruction* operand = lookup_instruction(fft_request.operand());
- hlo_instruction = add_instruction(HloInstruction::CreateFft(
- request.output_shape(), operand, fft_request.fft_type(),
- AsInt64Slice(fft_request.fft_length())));
- break;
- }
-
- case OpRequest::kDotRequest: {
- const DotRequest& dot_request = request.request().dot_request();
- HloInstruction* lhs = lookup_instruction(dot_request.lhs());
- HloInstruction* rhs = lookup_instruction(dot_request.rhs());
- hlo_instruction = add_instruction(HloInstruction::CreateDot(
- request.output_shape(), lhs, rhs, dot_request.dimension_numbers()));
- break;
- }
-
- case OpRequest::kCrossReplicaSumRequest: {
- const CrossReplicaSumRequest& cross_replica_sum_request =
- request.request().cross_replica_sum_request();
- HloInstruction* operand =
- lookup_instruction(cross_replica_sum_request.operand());
- hlo_instruction = add_instruction(HloInstruction::CreateCrossReplicaSum(
- request.output_shape(), {operand}));
- break;
- }
-
- case OpRequest::kInfeedRequest: {
- const InfeedRequest& infeed_request = request.request().infeed_request();
- hlo_instruction = add_instruction(HloInstruction::CreateInfeed(
- request.output_shape(), infeed_request.config()));
- break;
- }
-
- case OpRequest::kOutfeedRequest: {
- const OutfeedRequest& outfeed_request =
- request.request().outfeed_request();
- HloInstruction* operand = lookup_instruction(outfeed_request.operand());
- hlo_instruction = add_instruction(HloInstruction::CreateOutfeed(
- outfeed_request.shape(), operand, outfeed_request.outfeed_config()));
- break;
- }
-
- case OpRequest::kMapRequest: {
- const MapRequest& map_request = request.request().map_request();
- std::vector<HloInstruction*> operands;
- for (const ComputationDataHandle& handle : map_request.operands()) {
- HloInstruction* operand = lookup_instruction(handle);
- operands.push_back(operand);
- }
- CHECK_EQ(1, request.embedded_computation_versions_size());
- VersionedComputationHandle::Version map_version =
- request.embedded_computation_versions(0);
- HloComputation* map_computation =
- ResolveComputation(map_request.to_apply(), map_version);
- hlo_instruction = add_instruction(HloInstruction::CreateMap(
- request.output_shape(), operands, map_computation));
- break;
- }
-
- case OpRequest::kReduceRequest: {
- const ReduceRequest& reduce_request = request.request().reduce_request();
- HloInstruction* operand = lookup_instruction(reduce_request.operand());
- HloInstruction* init_value =
- lookup_instruction(reduce_request.init_value());
- CHECK_EQ(1, request.embedded_computation_versions_size());
- VersionedComputationHandle::Version reduce_version =
- request.embedded_computation_versions(0);
- HloComputation* reduce_computation =
- ResolveComputation(reduce_request.to_apply(), reduce_version);
- hlo_instruction = add_instruction(HloInstruction::CreateReduce(
- request.output_shape(), operand, init_value,
- AsInt64Slice(reduce_request.dimensions()), reduce_computation));
- break;
- }
-
- case OpRequest::kReduceWindowRequest: {
- const ReduceWindowRequest& reduce_window_request =
- request.request().reduce_window_request();
- HloInstruction* operand =
- lookup_instruction(reduce_window_request.operand());
- HloInstruction* init_value =
- lookup_instruction(reduce_window_request.init_value());
- CHECK_EQ(1, request.embedded_computation_versions_size());
- VersionedComputationHandle::Version reduce_window_version =
- request.embedded_computation_versions(0);
- HloComputation* reduce_window_computation = ResolveComputation(
- reduce_window_request.to_apply(), reduce_window_version);
- hlo_instruction = add_instruction(HloInstruction::CreateReduceWindow(
- request.output_shape(), operand, init_value,
- reduce_window_request.window(), reduce_window_computation));
- break;
- }
-
- case OpRequest::kSelectAndScatterRequest: {
- const SelectAndScatterRequest& select_and_scatter_request =
- request.request().select_and_scatter_request();
- HloInstruction* operand =
- lookup_instruction(select_and_scatter_request.operand());
- HloInstruction* source =
- lookup_instruction(select_and_scatter_request.source());
- HloInstruction* init_value =
- lookup_instruction(select_and_scatter_request.init_value());
- CHECK_EQ(2, request.embedded_computation_versions_size());
- VersionedComputationHandle::Version select_version =
- request.embedded_computation_versions(0);
- VersionedComputationHandle::Version scatter_version =
- request.embedded_computation_versions(1);
- HloComputation* select_computation = ResolveComputation(
- select_and_scatter_request.select(), select_version);
- HloComputation* scatter_computation = ResolveComputation(
- select_and_scatter_request.scatter(), scatter_version);
- hlo_instruction = add_instruction(HloInstruction::CreateSelectAndScatter(
- request.output_shape(), operand, select_computation,
- select_and_scatter_request.window(), source, init_value,
- scatter_computation));
- break;
- }
-
- case OpRequest::kBatchNormTrainingRequest: {
- const BatchNormTrainingRequest& batch_norm_training_request =
- request.request().batch_norm_training_request();
- HloInstruction* operand =
- lookup_instruction(batch_norm_training_request.operand());
- HloInstruction* scale =
- lookup_instruction(batch_norm_training_request.scale());
- HloInstruction* offset =
- lookup_instruction(batch_norm_training_request.offset());
-
- hlo_instruction = add_instruction(HloInstruction::CreateBatchNormTraining(
- request.output_shape(), operand, scale, offset,
- batch_norm_training_request.epsilon(),
- batch_norm_training_request.feature_index()));
- break;
- }
-
- case OpRequest::kBatchNormInferenceRequest: {
- const BatchNormInferenceRequest& batch_norm_inference_request =
- request.request().batch_norm_inference_request();
- HloInstruction* operand =
- lookup_instruction(batch_norm_inference_request.operand());
- HloInstruction* scale =
- lookup_instruction(batch_norm_inference_request.scale());
- HloInstruction* offset =
- lookup_instruction(batch_norm_inference_request.offset());
- HloInstruction* mean =
- lookup_instruction(batch_norm_inference_request.mean());
- HloInstruction* variance =
- lookup_instruction(batch_norm_inference_request.variance());
-
- hlo_instruction =
- add_instruction(HloInstruction::CreateBatchNormInference(
- request.output_shape(), operand, scale, offset, mean, variance,
- batch_norm_inference_request.epsilon(),
- batch_norm_inference_request.feature_index()));
- break;
- }
-
- case OpRequest::kBatchNormGradRequest: {
- const BatchNormGradRequest& batch_norm_grad_request =
- request.request().batch_norm_grad_request();
-
- HloInstruction* operand =
- lookup_instruction(batch_norm_grad_request.operand());
- HloInstruction* scale =
- lookup_instruction(batch_norm_grad_request.scale());
- HloInstruction* mean = lookup_instruction(batch_norm_grad_request.mean());
- HloInstruction* variance =
- lookup_instruction(batch_norm_grad_request.variance());
- HloInstruction* grad_output =
- lookup_instruction(batch_norm_grad_request.grad_output());
-
- hlo_instruction = add_instruction(HloInstruction::CreateBatchNormGrad(
- request.output_shape(), operand, scale, mean, variance, grad_output,
- batch_norm_grad_request.epsilon(),
- batch_norm_grad_request.feature_index()));
- break;
- }
-
- case OpRequest::kBroadcastRequest: {
- const BroadcastRequest& broadcast_request =
- request.request().broadcast_request();
- HloInstruction* operand = lookup_instruction(broadcast_request.operand());
- std::vector<int64> broadcast_dimensions;
- // The client-level broadcast instruction appends dimensions on the
- // left (i.e. adds the lowest-numbered dimensions). The HLO broadcast op
- // is more flexible and can add new dimensions anywhere. Since
- // broadcast_dimensions maps operand dimensions to dimensions in the
- // broadcast output, appending dimensions on the left means the
- // broadcast_dimensions should be the n highest dimension numbers of the
- // output shape, where n is the number of operand dimensions.
- broadcast_dimensions.reserve(ShapeUtil::Rank(operand->shape()));
- for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) {
- broadcast_dimensions.push_back(i +
- ShapeUtil::Rank(request.output_shape()) -
- ShapeUtil::Rank(operand->shape()));
- }
- hlo_instruction = add_instruction(HloInstruction::CreateBroadcast(
- request.output_shape(), operand, broadcast_dimensions));
- break;
- }
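// [Editorial note, not in the original file.] Worked example of the mapping
// above: broadcasting an operand of shape f32[3,4] (rank 2) into an output
// of shape f32[7,8,3,4] (rank 4) appends the two new dimensions on the left,
// so broadcast_dimensions = {0 + 4 - 2, 1 + 4 - 2} = {2, 3}: operand dim 0
// maps to output dim 2 and operand dim 1 maps to output dim 3.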
-
- case OpRequest::kReshapeRequest: {
- const ReshapeRequest& reshape_request =
- request.request().reshape_request();
- HloInstruction* operand = lookup_instruction(reshape_request.operand());
- HloInstruction* transposed;
- if (IsIdentityPermutation(AsInt64Slice(reshape_request.dimensions()))) {
- transposed = operand;
- } else {
- transposed = add_instruction(HloInstruction::CreateTranspose(
- ShapeUtil::PermuteDimensions(
- InversePermutation(AsInt64Slice(reshape_request.dimensions())),
- operand->shape()),
- operand, AsInt64Slice(reshape_request.dimensions())));
- }
- hlo_instruction = add_instruction(
- HloInstruction::CreateReshape(request.output_shape(), transposed));
- break;
- }
-
- case OpRequest::kTransposeRequest: {
- const TransposeRequest& transpose_request =
- request.request().transpose_request();
- HloInstruction* operand = lookup_instruction(transpose_request.operand());
- hlo_instruction = add_instruction(HloInstruction::CreateTranspose(
- ShapeUtil::PermuteDimensions(
- InversePermutation(AsInt64Slice(transpose_request.dimensions())),
- operand->shape()),
- operand, AsInt64Slice(transpose_request.dimensions())));
- break;
- }
-
- case OpRequest::kReverseRequest: {
- const ReverseRequest& reverse_request =
- request.request().reverse_request();
- HloInstruction* operand = lookup_instruction(reverse_request.operand());
- hlo_instruction = add_instruction(HloInstruction::CreateReverse(
- request.output_shape(), operand,
- AsInt64Slice(reverse_request.dimensions())));
- break;
- }
-
- case OpRequest::kPadRequest: {
- const PadRequest& pad_request = request.request().pad_request();
- HloInstruction* operand = lookup_instruction(pad_request.operand());
- HloInstruction* padding_value =
- lookup_instruction(pad_request.padding_value());
- hlo_instruction = add_instruction(HloInstruction::CreatePad(
- request.output_shape(), operand, padding_value,
- pad_request.padding_config()));
- break;
- }
-
- case OpRequest::kRecvRequest: {
- const RecvRequest& recv_request = request.request().recv_request();
- HloInstruction* recv = add_instruction(HloInstruction::CreateRecv(
- request.output_shape(), recv_request.channel_handle().handle()));
- hlo_instruction = add_instruction(HloInstruction::CreateRecvDone(recv));
- break;
- }
-
- case OpRequest::kParameterRequest: {
- const ParameterRequest& parameter_request =
- request.request().parameter_request();
- hlo_instruction = add_instruction(HloInstruction::CreateParameter(
- parameter_request.parameter(), request.output_shape(),
- parameter_request.name()));
- break;
- }
-
- case OpRequest::kConvertRequest: {
- const ConvertRequest& convert_request =
- request.request().convert_request();
- HloInstruction* operand = lookup_instruction(convert_request.operand());
- hlo_instruction = add_instruction(
- HloInstruction::CreateConvert(request.output_shape(), operand));
- break;
- }
-
- case OpRequest::kBitcastConvertRequest: {
- const ConvertRequest& convert_request =
- request.request().bitcast_convert_request();
- HloInstruction* operand = lookup_instruction(convert_request.operand());
- hlo_instruction = add_instruction(HloInstruction::CreateBitcastConvert(
- request.output_shape(), operand));
- break;
- }
-
- case OpRequest::kWhileRequest: {
- const WhileRequest& while_request = request.request().while_request();
- CHECK_EQ(2, request.embedded_computation_versions_size());
- VersionedComputationHandle::Version condition_version =
- request.embedded_computation_versions(0);
- HloComputation* condition =
- ResolveComputation(while_request.condition(), condition_version);
- VersionedComputationHandle::Version body_version =
- request.embedded_computation_versions(1);
- HloComputation* body =
- ResolveComputation(while_request.body(), body_version);
- HloInstruction* init = lookup_instruction(while_request.init());
- hlo_instruction = add_instruction(HloInstruction::CreateWhile(
- request.output_shape(), condition, body, init));
- break;
- }
-
- case OpRequest::kConditionalRequest: {
- const ConditionalRequest& conditional_request =
- request.request().conditional_request();
- CHECK_EQ(2, request.embedded_computation_versions_size());
- VersionedComputationHandle::Version true_computation_version =
- request.embedded_computation_versions(0);
- HloComputation* true_computation = ResolveComputation(
- conditional_request.true_computation(), true_computation_version);
- VersionedComputationHandle::Version false_computation_version =
- request.embedded_computation_versions(1);
- HloComputation* false_computation = ResolveComputation(
- conditional_request.false_computation(), false_computation_version);
- HloInstruction* predicate =
- lookup_instruction(conditional_request.predicate());
- HloInstruction* true_operand =
- lookup_instruction(conditional_request.true_operand());
- HloInstruction* false_operand =
- lookup_instruction(conditional_request.false_operand());
- hlo_instruction = add_instruction(HloInstruction::CreateConditional(
- request.output_shape(), predicate, true_operand, true_computation,
- false_operand, false_computation));
- break;
- }
-
- case OpRequest::kTernaryOpRequest: {
- const TernaryOpRequest& ternary_op_request =
- request.request().ternary_op_request();
- HloInstruction* lhs = lookup_instruction(ternary_op_request.lhs());
- HloInstruction* rhs = lookup_instruction(ternary_op_request.rhs());
- HloInstruction* ehs = lookup_instruction(ternary_op_request.ehs());
- auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop());
- if (debug_options_.xla_eliminate_hlo_implicit_broadcast() &&
- !ShapeUtil::IsTuple(request.output_shape())) {
- if (!ShapeUtil::IsTuple(lhs->shape()) &&
- !ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) {
- // lhs side is being implicitly broadcast. Change to explicit.
- lhs =
- ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape());
- }
-
- if (!ShapeUtil::IsTuple(rhs->shape()) &&
- !ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) {
- rhs =
- ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape());
- }
-
- if (!ShapeUtil::IsTuple(ehs->shape()) &&
- !ShapeUtil::SameDimensions(request.output_shape(), ehs->shape())) {
- ehs =
- ImplicitBroadcastToExplicitBroadcast(ehs, request.output_shape());
- }
- }
-
- hlo_instruction = add_instruction(HloInstruction::CreateTernary(
- request.output_shape(), hlo_opcode, lhs, rhs, ehs));
- break;
- }
-
- case OpRequest::kVariadicOpRequest: {
- const VariadicOpRequest& variadic_op_request =
- request.request().variadic_op_request();
- std::vector<HloInstruction*> operands;
- for (const ComputationDataHandle& handle :
- variadic_op_request.operands()) {
- HloInstruction* operand = lookup_instruction(handle);
- operands.push_back(operand);
- }
- auto hlo_opcode =
- VariadicOperationToHloOpcode(variadic_op_request.varop());
- hlo_instruction = add_instruction(HloInstruction::CreateVariadic(
- request.output_shape(), hlo_opcode, operands));
- break;
- }
-
- case OpRequest::kCallRequest: {
- const CallRequest& call_request = request.request().call_request();
- std::vector<HloInstruction*> operands;
- for (const ComputationDataHandle& handle : call_request.operands()) {
- operands.push_back(lookup_instruction(handle));
- }
- CHECK_EQ(1, request.embedded_computation_versions_size());
- VersionedComputationHandle::Version call_version =
- request.embedded_computation_versions(0);
- HloComputation* call_computation =
- ResolveComputation(call_request.to_apply(), call_version);
- hlo_instruction = add_instruction(HloInstruction::CreateCall(
- request.output_shape(), operands, call_computation));
- break;
- }
-
- case OpRequest::kCustomCallRequest: {
- const CustomCallRequest& cc_request =
- request.request().custom_call_request();
- std::vector<HloInstruction*> operands;
- for (const ComputationDataHandle& operand : cc_request.operands()) {
- operands.push_back(lookup_instruction(operand));
- }
- hlo_instruction = add_instruction(HloInstruction::CreateCustomCall(
- cc_request.shape(), operands, cc_request.call_target_name()));
- break;
- }
-
- case OpRequest::kHostComputeRequest: {
- const HostComputeRequest& host_compute_request =
- request.request().host_compute_request();
- std::vector<HloInstruction*> operands;
- for (const ComputationDataHandle& operand :
- host_compute_request.operands()) {
- operands.push_back(lookup_instruction(operand));
- }
- auto output_shape = host_compute_request.shape();
- auto channel_name = host_compute_request.channel_name();
- auto cost_estimate_ns = host_compute_request.cost_estimate_ns();
- hlo_instruction = add_instruction(HloInstruction::CreateHostCompute(
- output_shape, operands, channel_name, cost_estimate_ns));
- break;
- }
-
- case OpRequest::kUnaryOpRequest: {
- const UnaryOpRequest& unary_op_request =
- request.request().unary_op_request();
- HloInstruction* operand = lookup_instruction(unary_op_request.operand());
- auto hlo_opcode = UnaryOperationToHloOpcode(unary_op_request.unop());
- hlo_instruction = add_instruction(HloInstruction::CreateUnary(
- request.output_shape(), hlo_opcode, operand));
- break;
- }
-
- case OpRequest::kBinaryOpRequest: {
- const BinaryOpRequest& binary_op_request =
- request.request().binary_op_request();
- HloInstruction* lhs = lookup_instruction(binary_op_request.lhs());
- HloInstruction* rhs = lookup_instruction(binary_op_request.rhs());
- auto hlo_opcode = BinaryOperationToHloOpcode(binary_op_request.binop());
- if (binary_op_request.broadcast_dimensions_size() > 0 &&
- ShapeUtil::Rank(lhs->shape()) != ShapeUtil::Rank(rhs->shape())) {
- // Emit a broadcast instruction to perform the "broadcast in dimension"
- // operation.
- HloInstruction* operand_to_broadcast =
- ShapeUtil::Rank(lhs->shape()) < ShapeUtil::Rank(rhs->shape()) ? lhs
- : rhs;
- CHECK_EQ(ShapeUtil::Rank(operand_to_broadcast->shape()),
- binary_op_request.broadcast_dimensions().size());
-
- // Construct the bounds of the shape of the kBroadcast instruction
- // responsible for the in-dimension broadcast.
- std::vector<int64> output_dimensions;
- for (int64 size : request.output_shape().dimensions()) {
- output_dimensions.push_back(size);
- }
- for (int64 operand_dim = 0;
- operand_dim < ShapeUtil::Rank(operand_to_broadcast->shape());
- ++operand_dim) {
- int64 output_dim =
- binary_op_request.broadcast_dimensions()[operand_dim];
- output_dimensions[output_dim] =
- operand_to_broadcast->shape().dimensions(operand_dim);
- }
-
- Shape broadcast_shape = ShapeUtil::MakeShape(
- operand_to_broadcast->shape().element_type(), output_dimensions);
-
- // The broadcast semantics of a client-level binary op are identical to the
- // HLO broadcast semantics, so the broadcast_dimensions field can be passed
- // to the instruction builder as-is.
- HloInstruction* broadcasted_operand =
- add_instruction(HloInstruction::CreateBroadcast(
- broadcast_shape, operand_to_broadcast,
- AsInt64Slice(binary_op_request.broadcast_dimensions())));
-
- lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs;
- rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs;
- }
- if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) {
- if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) {
- // lhs side is being implicitly broadcast. Change to explicit.
- lhs =
- ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape());
- }
-
- if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) {
- rhs =
- ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape());
- }
- }
- hlo_instruction = add_instruction(HloInstruction::CreateBinary(
- request.output_shape(), hlo_opcode, lhs, rhs));
- break;
- }
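// [Editorial note, not in the original file.] Worked example of the
// in-dimension broadcast above: for f32[2,3] + f32[3] with
// broadcast_dimensions = {1}, the rank-1 operand is the one to broadcast.
// output_dimensions starts as {2, 3}; operand dim 0 writes its size into
// output dim 1, so broadcast_shape = f32[2,3], and the kBroadcast maps
// operand dim 0 to output dim 1 before the element-wise add runs on two
// f32[2,3] shapes.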
-
- case OpRequest::kReducePrecisionRequest: {
- const ReducePrecisionRequest& reduce_precision_request =
- request.request().reduce_precision_request();
- HloInstruction* operand =
- lookup_instruction(reduce_precision_request.operand());
- auto exponent_bits = reduce_precision_request.exponent_bits();
- auto mantissa_bits = reduce_precision_request.mantissa_bits();
- hlo_instruction = add_instruction(HloInstruction::CreateReducePrecision(
- request.output_shape(), operand, exponent_bits, mantissa_bits));
- break;
- }
-
- case OpRequest::kTraceRequest: {
- const TraceRequest& trace_request = request.request().trace_request();
- HloInstruction* operand = lookup_instruction(trace_request.operand());
- hlo_instruction = add_instruction(
- HloInstruction::CreateTrace(trace_request.tag(), operand));
- break;
- }
-
- case OpRequest::kSendRequest: {
- const SendRequest& send_request = request.request().send_request();
- HloInstruction* operand = lookup_instruction(send_request.operand());
- HloInstruction* send = add_instruction(HloInstruction::CreateSend(
- operand, send_request.channel_handle().handle()));
- hlo_instruction = add_instruction(HloInstruction::CreateSendDone(send));
- break;
- }
-
- case OpRequest::kGatherRequest: {
- const GatherRequest& gather_request = request.request().gather_request();
- HloInstruction* input_operand =
- lookup_instruction(gather_request.input());
- HloInstruction* gather_indices_operand =
- lookup_instruction(gather_request.gather_indices());
- std::vector<int64> window_bounds;
- c_copy(gather_request.window_bounds(), std::back_inserter(window_bounds));
- hlo_instruction = add_instruction(HloInstruction::CreateGather(
- request.output_shape(), input_operand, gather_indices_operand,
- gather_request.dimension_numbers(), window_bounds));
- break;
- }
-
- case OpRequest::OP_NOT_SET:
- LOG(FATAL) << "OperationRequest doesn't contain a request";
-
- default:
- LOG(FATAL) << "Unexpected request type: " << request.request().op_case();
- }
- (*instructions)[handle.handle()] = hlo_instruction;
-} // NOLINT(readability/fn_size)
-
-} // namespace
-
-StatusOr<std::unique_ptr<HloComputation>> UserComputation::BuildHloComputation(
- VersionedComputationHandle::Version version,
- HloComputationResolver hlo_resolver, const DebugOptions& debug_options,
- bool include_unreachable_instructions) const {
- tensorflow::mutex_lock lock(mutex_);
-
- VLOG(2) << "Building HloComputation from UserComputation " << name_
- << " at version " << version;
- XLA_VLOG_LINES(3, session_computation_.DebugString());
-
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<HloComputation> hlo_computation,
- ComputationLowerer::Lower(
- tensorflow::strings::StrCat(name(), ".v", version),
- session_computation_, version, std::move(hlo_resolver), debug_options,
- include_unreachable_instructions));
-
- return std::move(hlo_computation);
-}
-
-} // namespace xla
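
Aside on the removed lowering above: the subtlest step in the kBinaryOpRequest case is deriving the shape of the kBroadcast instruction that implements an "in-dimension" broadcast. Below is a minimal standalone sketch of that shape computation, with plain C++ vectors and hypothetical shapes standing in for XLA's Shape/ShapeUtil; it is an illustration, not XLA code.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical binary op: lhs f32[2,3,4] + rhs f32[3,4] with
  // broadcast_dimensions = {1, 2}; rhs is the lower-rank operand.
  std::vector<int64_t> output_dims = {2, 3, 4};    // request.output_shape()
  std::vector<int64_t> operand_dims = {3, 4};      // operand_to_broadcast
  std::vector<int64_t> broadcast_dimensions = {1, 2};

  // Start from the output bounds, then pin each operand dimension onto the
  // output dimension it maps to -- the same loop as the removed code above.
  std::vector<int64_t> broadcast_shape = output_dims;
  for (std::size_t operand_dim = 0; operand_dim < operand_dims.size();
       ++operand_dim) {
    int64_t output_dim = broadcast_dimensions[operand_dim];
    broadcast_shape[output_dim] = operand_dims[operand_dim];
  }

  for (int64_t dim : broadcast_shape) std::cout << dim << " ";  // 2 3 4
  std::cout << "\n";
  return 0;
}
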
diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h
deleted file mode 100644
index 5544c868fe..0000000000
--- a/tensorflow/compiler/xla/service/user_computation.h
+++ /dev/null
@@ -1,413 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_
-
-#include <functional>
-#include <map>
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/session.pb.h"
-#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla.pb.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/thread_annotations.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace xla {
-
-// A UserComputation is the built-up computation that users create via the
-// XLA Service interface.
-//
-// The XLA service adds instructions to a user computation via this
-// interface. The state of the computation is stored as a SessionComputation
-// proto which holds a record of all operation-building requests received by the
-// XLA service.
-//
-// UserComputations are lowered to HloComputations which are passed to the high
-// level compiler interface.
-class UserComputation {
- public:
- // Factory used when restoring a computation from serialized session
- // computation (computation snapshot) data. Remaps any references to
-  // computation handles via the old_to_new mapping.
- //
- // An error will occur if the old_to_new mapping cannot resolve a reference to
- // a computation that is present in session_computation.
- static StatusOr<std::unique_ptr<UserComputation>> MakeWithRemapping(
- const SessionComputation& session_computation,
- const ComputationHandle& handle,
- const std::map<int64, ComputationHandle>& old_to_new);
-
- // Creates an empty computation with the given name and computation handle.
- explicit UserComputation(const string& name, const ComputationHandle& handle);
-
- // Enqueues a parameter-retrieving instruction onto this user computation.
- // Returns an error status if the parameter number is already registered with
- // different values.
- StatusOr<ComputationDataHandle> AddParameterInstruction(
- const ParameterRequest& parameter_request);
-
- // Enqueues a pad instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddPadInstruction(
- const PadRequest& pad_request);
-
- // Enqueues a tracing instruction onto this user computation.
- // Returns an error status if the operand cannot be resolved.
- Status AddTraceInstruction(const TraceRequest& trace_request);
-
- // Enqueues a random number generation instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddRngInstruction(
- const RngRequest& rng_request);
-
- // Enqueues a unary instruction onto this user computation.
- // Returns an error status if the operand index is out of bounds.
- StatusOr<ComputationDataHandle> AddUnaryInstruction(
- const UnaryOpRequest& unary_request);
-
- // Enqueues a batch norm training instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddBatchNormTrainingInstruction(
- const BatchNormTrainingRequest& batch_norm_training_request);
-
- // Enqueues a batch norm inference instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddBatchNormInferenceInstruction(
- const BatchNormInferenceRequest& batch_norm_inference_request);
-
- // Enqueues a batch norm grad instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddBatchNormGradInstruction(
- const BatchNormGradRequest& batch_norm_grad_request);
-
- // Enqueues a binary instruction onto this user computation.
- // Returns an error status if the operand indices are out of bounds.
- StatusOr<ComputationDataHandle> AddBinaryInstruction(
- const BinaryOpRequest& binary_request);
-
- // Enqueues a ternary instruction onto this user computation.
- // Returns an error status if the operand indices are out of bounds.
- StatusOr<ComputationDataHandle> AddTernaryInstruction(
- const TernaryOpRequest& ternary_request);
-
- // Enqueues a variadic instruction onto this user computation.
- // Returns an error status if the operand indices are out of bounds.
- StatusOr<ComputationDataHandle> AddVariadicInstruction(
- const VariadicOpRequest& variadic_request);
-
- // Enqueues a constant instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddConstantInstruction(
- const ConstantRequest& constant_request);
-
- // Enqueues a get tuple element instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddGetTupleElementInstruction(
- const GetTupleElementRequest& get_tuple_element_request);
-
- // Enqueues a map instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddMapInstruction(
- const MapRequest& map_request,
- const UserComputation& to_apply_computation);
-
- // Enqueues a reduce-precision instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddReducePrecisionInstruction(
- const ReducePrecisionRequest& reduce_precision_request);
-
- // Enqueues a convolution instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddConvolveInstruction(
- const ConvolveRequest& convolve_request);
-
- // Enqueues an FFT instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddFftInstruction(
- const FftRequest& fft_request);
-
- // Enqueues a cross replica sum instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddCrossReplicaSumInstruction(
- const CrossReplicaSumRequest& cross_replica_sum_request);
-
- // Enqueues an infeed instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddInfeedInstruction(
- const InfeedRequest& infeed_request);
-
- // Enqueues an outfeed instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddOutfeedInstruction(
- const OutfeedRequest& outfeed_request);
-
- // Enqueues a host compute instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddHostComputeInstruction(
- const HostComputeRequest& host_compute_request);
-
- // Enqueues a call instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddCallInstruction(
- const CallRequest& call_request,
- const UserComputation& to_apply_computation);
-
- // Enqueues a custom call instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddCustomCallInstruction(
- const CustomCallRequest& custom_call_request);
-
- // Enqueues a dot instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddDotInstruction(
- const DotRequest& dot_request);
-
- // Enqueues a broadcast instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddBroadcastInstruction(
- const BroadcastRequest& broadcast_request);
-
- // Enqueues a reshape instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddReshapeInstruction(
- const ReshapeRequest& reshape_request);
-
- // Enqueues a transpose instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddTransposeInstruction(
- const TransposeRequest& transpose_request);
-
- // Enqueues a slice instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddSliceInstruction(
- const SliceRequest& slice_request);
-
- // Enqueues a dynamic slice instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddDynamicSliceInstruction(
- const DynamicSliceRequest& dynamic_slice_request);
-
- // Enqueues a dynamic update slice instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddDynamicUpdateSliceInstruction(
- const DynamicUpdateSliceRequest& dynamic_update_slice_request);
-
- // Enqueues a concatenate instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddConcatenateInstruction(
- const ConcatenateRequest& concatenate_request);
-
- // Enqueues a convert instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddConvertInstruction(
- const ConvertRequest& convert_request);
-
-  // Enqueues a bitcast convert instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddBitcastConvertInstruction(
- const ConvertRequest& convert_request);
-
- // Enqueues a reduce instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddReduceInstruction(
- const ReduceRequest& reduce_request,
- const UserComputation& to_apply_computation);
-
- // Enqueues a windowed reduce instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddReduceWindowInstruction(
- const ReduceWindowRequest& reduce_window_request,
- const UserComputation& to_apply_computation);
-
- // Enqueues a select-and-scatter instruction onto this user
- // computation.
- StatusOr<ComputationDataHandle> AddSelectAndScatterInstruction(
- const SelectAndScatterRequest& select_and_scatter_request,
- const UserComputation& select_computation,
- const UserComputation& scatter_computation);
-
- // Enqueues a reverse instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddReverseInstruction(
- const ReverseRequest& reverse_request);
-
- // Enqueues a while instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddWhileInstruction(
- const WhileRequest& while_request,
- const UserComputation& condition_computation,
- const UserComputation& body_computation);
-
- // Enqueues a conditional instruction on this user computation.
- StatusOr<ComputationDataHandle> AddConditionalInstruction(
- const ConditionalRequest& conditional_request,
- const UserComputation& true_computation,
- const UserComputation& false_computation);
-
- // Enqueues a Send instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddSendInstruction(
- const SendRequest& send_request);
-
- // Enqueues a Recv instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddRecvInstruction(
- const RecvRequest& recv_request);
-
- // Enqueues a Gather instruction onto this user computation.
- StatusOr<ComputationDataHandle> AddGatherInstruction(
- const GatherRequest& gather_request);
-
- // Returns the user-provided name of this user computation, which is provided
- // via the XLA computation-building API.
- const string& name() const { return name_; }
-
- // Subsequent executions of this computation will compute the value
- // represented by handle, rather than the last expression enqueued
- // on the computation.
- Status SetReturnValue(const ComputationDataHandle& handle);
-
- // Return a versioned handle for this computation.
- VersionedComputationHandle GetVersionedHandle() const;
-
- // Return a versioned handle for this computation with a version equal to the
- // point at which given operation was added to the computation.
- VersionedComputationHandle GetVersionedHandleAtOperation(
- const ComputationDataHandle& operation) const;
-
- // Return a version value representing the current state of the
- // computation.
- VersionedComputationHandle::Version version() const;
-
- // Computes and returns the program shape for the user computation -- gathers
- // parameters and result type into a single proto. A shared_ptr is used
- // because the returned pointer refers to an internally cached value which may
-  // be discarded by the UserComputation object. This avoids unnecessary copies.
- //
- // If the parameter space is not dense (i.e. there are holes in the parameter
- // numbers provided) then an error status is returned.
- StatusOr<std::shared_ptr<const ProgramShape>> ComputeProgramShape(
- VersionedComputationHandle::Version version) const;
-
- // Returns true if the given data handle does not depend on any parameter with
-  // index higher than num_parameters. That is, the value can be computed at
- // compile time if we know the first num_parameters arguments.
- StatusOr<bool> IsConstant(const ComputationDataHandle& handle,
- int64 num_parameters);
-
- // Returns the output shape of the operation indicated by the given handle.
- StatusOr<Shape> GetShape(const ComputationDataHandle& handle);
-
- // Sets metadata on the Hlo instruction referenced by the given handle.
- Status SetOpMetadata(const ComputationDataHandle& handle,
- const OpMetadata& metadata);
-
- // Sets the device assignment on the Hlo instruction referenced by 'handle'.
- Status SetOpSharding(const ComputationDataHandle& handle,
- const OpSharding& sharding);
-
- // Builds a HLO computation from the UserComputation. The parameter "resolver"
- // is a function which returns a pointer to the HloComputation corresponding
- // to the given ComputationHandle at the given version. The resolver is used
- // for operations, such as map, which call other computations and need a
- // pointer to the called HloComputation to construct the respective HLO
- // instructions. If include_unreachable_instructions is true, then
- // instructions which are not reachable from the root are lowered into
- // HloInstructions.
- using HloComputationResolver =
- std::function<HloComputation*(const VersionedComputationHandle& handle)>;
- StatusOr<std::unique_ptr<HloComputation>> BuildHloComputation(
- VersionedComputationHandle::Version version,
- HloComputationResolver hlo_resolver, const DebugOptions& debug_options,
- bool include_unreachable_instructions = true) const;
-
- // Return a vector containing the embedded computations used by this
- // UserComputation. Only embedded computations which are called directly by
- // this UserComputation are included. That is, the transitive closure of
- // embedded computations is not included.
- std::vector<VersionedComputationHandle> GetEmbeddedComputations(
- VersionedComputationHandle::Version version) const;
-
- // Returns the number of OperationRequest objects in this UserComputation.
- // The 'version' of a computation is identical to the number of
- // OperationRequests in the UserComputation.
- int64 request_count(VersionedComputationHandle::Version version) const {
- return version;
- }
-
- // Returns a copy of the internal session state for this computation -- this
- // is useful for serializing the guts of a user computation, though references
- // to other handles (e.g. referred-to computations) must be handled with care
- // in the serialization / de-serialization process.
- SessionComputation CloneSessionComputation(
- VersionedComputationHandle::Version version) const;
-
- // Warning: typically we don't want to look up computation data handles until
- // the computation is finished being built, for consistency purposes. We
- // expose this routine for error reporting purposes so that we can provide
- // more meaningful error messages from the XLA service layer.
- //
- // Returns the operation request that the handle comes from.
- StatusOr<const OperationRequest*> LookUpRequestForErrorReporting(
- const ComputationDataHandle& handle) const;
-
- // Retrieves the parameter metadata for the given parameter number.
- //
- // If the parameter number is invalid for this computation, nullopt is
-  // returned. When the return value has_value(), the held value is never
-  // nullptr.
- tensorflow::gtl::optional<const OpMetadata*> ParameterMetadata(
- int parameter_number) const;
-
- private:
- // Warning: dangerous mutating operation that doesn't respect versioning.
- // This is only used at initialization time when constructing from a
- // SessionComputation a la MakeWithRemapping.
- //
- // Remaps references to old computations (with handle values in the keys of
- // old_to_new) to the computation handle given in the values. This is useful
- // when loading computations from snapshots, to finish initialization, before
- // the user computation is released into the wild.
- Status RemapEmbeddedComputations(
- const std::map<int64, ComputationHandle>& old_to_new)
- EXCLUSIVE_LOCKS_REQUIRED(mutex_);
-
- // Returns the OperationRequest corresponding to the given handle.
- StatusOr<const OperationRequest*> LookUpRequest(
- const ComputationDataHandle& handle) const
- EXCLUSIVE_LOCKS_REQUIRED(mutex_);
-
- // Creates a new ComputationDataHandle with the next available handle value.
- ComputationDataHandle CreateComputationDataHandle()
- EXCLUSIVE_LOCKS_REQUIRED(mutex_);
-
- // Checks whether the parameter numbers of the parameter operations are
- // contiguous starting from zero. Returns appropriate error status if not.
- Status CheckParametersAreContiguous(
- VersionedComputationHandle::Version version) const
- EXCLUSIVE_LOCKS_REQUIRED(mutex_);
-
- VersionedComputationHandle GetVersionedHandleInternal() const
- EXCLUSIVE_LOCKS_REQUIRED(mutex_);
-
- // Name of the computation.
- string name_;
-
- mutable tensorflow::mutex mutex_;
-
- // State of the computation as a record of all operation-building requests.
- SessionComputation session_computation_ GUARDED_BY(mutex_);
-
- // Mapping from parameter number to operation request containing the
- // respective ParameterRequest.
- std::map<int64, OperationRequest*> parameters_ GUARDED_BY(mutex_);
-
- // The next ComputationDataHandle value to assign. Handle values are assigned
- // sequentially.
- int64 next_handle_value_ GUARDED_BY(mutex_);
-
- // If handle_to_return_.has_handle() then an Execution of this Computation
- // will compute the value represented by handle_to_return_, otherwise it will
- // compute the value of (next_handle_value_ - 1).
- ComputationDataHandle handle_to_return_ GUARDED_BY(mutex_);
-
- // Memoized ProgramShape and its version. A shared_ptr is used because
- // references to this object are returned by ComputeProgramShape.
- mutable int64 program_shape_version_ GUARDED_BY(mutex_) = 0;
- mutable std::shared_ptr<const ProgramShape> program_shape_ GUARDED_BY(mutex_);
-
- TF_DISALLOW_COPY_AND_ASSIGN(UserComputation);
-};
-
-} // namespace xla
-
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_
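
One detail of the deleted header worth preserving for readers: ComputeProgramShape hands out shared_ptr<const ProgramShape> values backed by the program_shape_/program_shape_version_ members, so callers keep a valid shape even after the cache is recomputed for a newer version. A standalone sketch of that memoization pattern, with simplified stand-in types (ShapeCache and this ProgramShape are made up for illustration, not the XLA classes):

#include <cstdint>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>

struct ProgramShape {
  std::string repr;
};

class ShapeCache {
 public:
  std::shared_ptr<const ProgramShape> Get(int64_t version) {
    std::lock_guard<std::mutex> lock(mutex_);
    if (cached_ == nullptr || version != cached_version_) {
      // Recompute and swap in a fresh value. shared_ptrs handed out earlier
      // remain valid even after the cache moves on to another version.
      cached_ = std::make_shared<const ProgramShape>(
          ProgramShape{"shape@v" + std::to_string(version)});
      cached_version_ = version;
    }
    return cached_;
  }

 private:
  std::mutex mutex_;
  int64_t cached_version_ = 0;
  std::shared_ptr<const ProgramShape> cached_;
};

int main() {
  ShapeCache cache;
  auto v1 = cache.Get(1);
  auto v2 = cache.Get(2);         // replaces the cached entry
  std::cout << v1->repr << "\n";  // shape@v1 -- still valid
  std::cout << v2->repr << "\n";  // shape@v2
  return 0;
}
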
diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc
deleted file mode 100644
index 2fa163953f..0000000000
--- a/tensorflow/compiler/xla/service/user_computation_test.cc
+++ /dev/null
@@ -1,340 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/user_computation.h"
-
-#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_matchers.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-
-namespace op = xla::testing::opcode_matchers;
-
-namespace xla {
-namespace {
-
-using UserComputationTest = ::testing::Test;
-
-TEST_F(UserComputationTest, SimpleComputation) {
- const Shape kScalarShape = ShapeUtil::MakeShape(F32, {});
- const Shape kVectorShape = ShapeUtil::MakeShape(F32, {2});
-
-  // Build a simple three-operation computation:
- //
- // %constant = Constant({123, 42})
- // %param = Param(0)
- // %outfeed = Outfeed(%constant)
- //
- // Build the computation at two different versions and check invariants.
- ComputationHandle handle;
- handle.set_handle(123);
- UserComputation computation("TheComputation", handle);
-
- ConstantRequest constant_request;
- *constant_request.mutable_literal() =
- Literal::CreateR1<float>({123.0f, 42.0f})->ToProto();
- TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle constant_handle,
- computation.AddConstantInstruction(constant_request));
-
- ParameterRequest param_request;
- *param_request.mutable_shape() = kScalarShape;
- param_request.set_parameter(0);
- param_request.set_name("param0");
- TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle param_handle,
- computation.AddParameterInstruction(param_request));
- OpMetadata metadata;
- metadata.set_op_name("meta");
- TF_ASSERT_OK(computation.SetOpMetadata(param_handle, metadata));
-
- OutfeedRequest outfeed_request;
- *outfeed_request.mutable_operand() = constant_handle;
- *outfeed_request.mutable_shape() = kVectorShape;
- outfeed_request.set_outfeed_config("abc");
- TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle outfeed_handle,
- computation.AddOutfeedInstruction(outfeed_request));
-
- auto hlo_resolver = [](const VersionedComputationHandle& handle) {
- return nullptr;
- };
- {
- // Test the computation at the latest version. In this case, the most
- // recently added operation is an outfeed. However, the outfeed is not the
- // root because outfeeds cannot be the root of a computation.
- VersionedComputationHandle latest_version =
- computation.GetVersionedHandle();
-
- // Program shape should have a single scalar parameter and scalar
- // result. The outfeed instruction should not affect the program shape.
- TF_ASSERT_OK_AND_ASSIGN(
- std::shared_ptr<const ProgramShape> program_shape,
- computation.ComputeProgramShape(latest_version.version));
- ASSERT_EQ(1, program_shape->parameters_size());
- EXPECT_TRUE(
- ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0)));
- EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result()));
-
- // Build the HLO computation.
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<HloComputation> hlo_computation,
- computation.BuildHloComputation(latest_version.version, hlo_resolver,
- DebugOptions()));
- // There should be one HloInstruction per UserComputation operation.
- EXPECT_EQ(3, hlo_computation->instruction_count());
-    // The root of the computation should be the parameter instruction (not
-    // the outfeed).
- EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter());
- }
-
- {
- // Test the computation at the version right after the parameter instruction
- // is added.
- VersionedComputationHandle version_at_param =
- computation.GetVersionedHandleAtOperation(param_handle);
-
- // Program shape should have a single scalar parameter, and scalar result.
- TF_ASSERT_OK_AND_ASSIGN(
- std::shared_ptr<const ProgramShape> program_shape,
- computation.ComputeProgramShape(version_at_param.version));
- ASSERT_EQ(1, program_shape->parameters_size());
- EXPECT_TRUE(
- ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0)));
- EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result()));
-
- // There should be two instructions, one for the constant and one for the
- // parameter. The outfeed instruction should not be included.
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<HloComputation> hlo_computation,
- computation.BuildHloComputation(version_at_param.version, hlo_resolver,
- DebugOptions()));
- EXPECT_EQ(2, hlo_computation->instruction_count());
- EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter());
- }
- {
- // Test the computation at the latest version, but lowered with
- // include_unreachable_instructions set to false.
- VersionedComputationHandle latest_version =
- computation.GetVersionedHandle();
-
- // Build the HLO computation.
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<HloComputation> hlo_computation,
- computation.BuildHloComputation(
- latest_version.version, hlo_resolver, DebugOptions(),
- /*include_unreachable_instructions=*/false));
- // There is only one reachable instruction, the parameter.
- EXPECT_EQ(1, hlo_computation->instruction_count());
-    // The root of the computation should be the parameter instruction (not
-    // the outfeed).
- EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter());
- EXPECT_EQ(hlo_computation->root_instruction()->metadata().op_name(),
- "meta");
- }
-}
-
-TEST_F(UserComputationTest, EliminateScalarBroadcast) {
- auto debug_options = DebugOptions();
- debug_options.set_xla_eliminate_hlo_implicit_broadcast(true);
-
- // Build a binary computation with scalar broadcast.
- //
- // %a = Constant({123, 42})
- // %b = Constant(1)
- // %add = Add(%a, %b)
- ComputationHandle handle;
- handle.set_handle(123);
- UserComputation computation("TheComputation", handle);
-
- ConstantRequest a_request;
- *a_request.mutable_literal() =
- Literal::CreateR1<float>({123.0f, 42.0f})->ToProto();
- TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle,
- computation.AddConstantInstruction(a_request));
-
- ConstantRequest b_request;
- *b_request.mutable_literal() = Literal::CreateR0<float>(1.0f)->ToProto();
- TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle,
- computation.AddConstantInstruction(b_request));
-
- BinaryOpRequest add;
- add.set_binop(BINOP_ADD);
- *add.mutable_lhs() = a_handle;
- *add.mutable_rhs() = b_handle;
- TF_ASSERT_OK(computation.AddBinaryInstruction(add).status());
-
- auto hlo_resolver = [](const VersionedComputationHandle& handle) {
- return nullptr;
- };
- VersionedComputationHandle latest_version = computation.GetVersionedHandle();
-
- // Build the HLO computation.
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<HloComputation> hlo_computation,
- computation.BuildHloComputation(latest_version.version, hlo_resolver,
- debug_options));
-  // The binary operation has an implicit scalar broadcast and should be
-  // converted to an explicit broadcast instruction plus a binary instruction.
- EXPECT_EQ(4, hlo_computation->instruction_count());
- EXPECT_THAT(hlo_computation->root_instruction(), op::Add());
- LOG(INFO) << hlo_computation->root_instruction()->ToString();
- const auto& operands = hlo_computation->root_instruction()->operands();
- ASSERT_EQ(2, operands.size());
- EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast ||
- operands[1]->opcode() == HloOpcode::kBroadcast);
-}
-
-TEST_F(UserComputationTest, CheckImplicitBroadcastToExplicitBroadcast) {
- auto debug_options = DebugOptions();
- debug_options.set_xla_eliminate_hlo_implicit_broadcast(true);
-
- // Build a binary computation with degenerate broadcast.
- //
- // %a = Param({1, 2, 3});
- // %b = Param({1, 2, 1});
- // %add = Add(%a, %b, {});
- ComputationHandle handle;
- handle.set_handle(123);
- UserComputation computation("TheComputation", handle);
-
- ParameterRequest a_request;
- *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 3});
- a_request.set_name("a");
- a_request.set_parameter(0);
- TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle,
- computation.AddParameterInstruction(a_request));
-
- ParameterRequest b_request;
- *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 1});
- b_request.set_name("b");
- b_request.set_parameter(1);
- TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle,
- computation.AddParameterInstruction(b_request));
-
- const int64 kDevice = 7;
- OpSharding sharding;
- sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
- sharding.add_tile_assignment_dimensions(1);
- sharding.add_tile_assignment_devices(kDevice);
-
- TF_EXPECT_OK(computation.SetOpSharding(b_handle, sharding));
-
- BinaryOpRequest add;
- add.set_binop(BINOP_ADD);
- *add.mutable_lhs() = a_handle;
- *add.mutable_rhs() = b_handle;
- TF_ASSERT_OK(computation.AddBinaryInstruction(add).status());
-
- auto hlo_resolver = [](const VersionedComputationHandle& handle) {
- return nullptr;
- };
- VersionedComputationHandle latest_version = computation.GetVersionedHandle();
-
- // Build the HLO computation.
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<HloComputation> hlo_computation,
- computation.BuildHloComputation(latest_version.version, hlo_resolver,
- debug_options));
-
- // b a
- // | |
- // reshape |
- // | |
- // broadcast |
- // \ /
- // add
- EXPECT_EQ(5, hlo_computation->instruction_count());
- ASSERT_THAT(
- hlo_computation->root_instruction(),
- op::Add(op::Parameter(), op::Broadcast(op::Reshape(op::Parameter()))));
-
- const HloInstruction* broadcast =
- hlo_computation->root_instruction()->operand(1);
- EXPECT_TRUE(broadcast->has_sharding());
-
- const HloInstruction* reshape = broadcast->operand(0);
- EXPECT_TRUE(reshape->has_sharding());
-}
-
-TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) {
- auto debug_options = DebugOptions();
- debug_options.set_xla_eliminate_hlo_implicit_broadcast(true);
-
- // Build a binary computation with in-dim broadcast and degenerate broadcast.
- //
- // %a = Param({2, 3});
- // %b = Param({2, 1, 4});
- // %add = Add(%a, %b, {0, 1});
- ComputationHandle handle;
- handle.set_handle(123);
- UserComputation computation("TheComputation", handle);
-
- ParameterRequest a_request;
- *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 3});
- a_request.set_name("a");
- a_request.set_parameter(0);
- TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle,
- computation.AddParameterInstruction(a_request));
-
- ParameterRequest b_request;
- *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 1, 4});
- b_request.set_name("b");
- b_request.set_parameter(1);
- TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle,
- computation.AddParameterInstruction(b_request));
-
- BinaryOpRequest add;
- add.set_binop(BINOP_ADD);
- *add.mutable_lhs() = a_handle;
- *add.mutable_rhs() = b_handle;
- add.add_broadcast_dimensions(0);
- add.add_broadcast_dimensions(1);
- TF_ASSERT_OK(computation.AddBinaryInstruction(add).status());
-
- auto hlo_resolver = [](const VersionedComputationHandle& handle) {
- return nullptr;
- };
- VersionedComputationHandle latest_version = computation.GetVersionedHandle();
-
- // Build the HLO computation.
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<HloComputation> hlo_computation,
- computation.BuildHloComputation(latest_version.version, hlo_resolver,
- debug_options));
-
-  // The binary operation has an in-dim broadcast and a degenerate broadcast;
-  // lowering should first perform the in-dim broadcast and then convert the
-  // degenerate broadcast into a reshape and a broadcast.
- //
- // b a
- // | |
- // broadcast reshape
- // | |
- // | broadcast
- // \ /
- // add
- EXPECT_EQ(6, hlo_computation->instruction_count());
- EXPECT_THAT(hlo_computation->root_instruction(), op::Add());
- const auto& operands = hlo_computation->root_instruction()->operands();
- ASSERT_EQ(2, operands.size());
- EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast &&
- operands[1]->opcode() == HloOpcode::kBroadcast);
-}
-
-} // namespace
-} // namespace xla
diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc
index ed20b36292..473eab2ea8 100644
--- a/tensorflow/compiler/xla/service/while_util.cc
+++ b/tensorflow/compiler/xla/service/while_util.cc
@@ -117,9 +117,13 @@ WhileUtil::MakeInstructionsLiveIn(
HloInstruction* new_while = containing_computation->AddInstruction(
HloInstruction::CreateWhile(new_while_shape, new_while_condition,
new_while_body, new_while_init));
- TF_RETURN_IF_ERROR(containing_computation->ReplaceInstruction(
- while_instr, TupleUtil::ExtractPrefix(
- new_while, while_instr->shape().tuple_shapes_size())));
+
+ // We want to get rid of the old while instruction even if it has side
+ // effecting operations so we do a manual HloComputation::RemoveInstruction
+ // instead of relying on HloComputation::ReplaceInstruction.
+ TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(TupleUtil::ExtractPrefix(
+ new_while, while_instr->shape().tuple_shapes_size())));
+ TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr));
HloInstruction* while_body_param = new_while_body->parameter_instruction(0);
std::vector<HloInstruction*> live_in_instructions;
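
The comment above is the heart of this change: HloComputation::ReplaceInstruction refuses to drop an instruction it considers non-removable (the old while is side-effecting here, e.g. when its condition infeeds), whereas ReplaceAllUsesWith followed by RemoveInstruction removes it unconditionally. A toy model of that contract in plain C++ -- the types and the bool return are made up for illustration; real XLA would fail a CHECK rather than return false:

#include <algorithm>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Instr {
  std::string name;
  bool has_side_effect = false;
  std::vector<Instr*> operands;
};

struct Computation {
  std::vector<std::unique_ptr<Instr>> instrs;

  // Analogue of the checked path: bails out when the old instruction is not
  // safely removable.
  bool ReplaceInstruction(Instr* old_i, Instr* new_i) {
    if (old_i->has_side_effect) return false;
    ReplaceAllUsesWith(old_i, new_i);
    Remove(old_i);
    return true;
  }

  void ReplaceAllUsesWith(Instr* old_i, Instr* new_i) {
    for (auto& instr : instrs) {
      std::replace(instr->operands.begin(), instr->operands.end(), old_i,
                   new_i);
    }
  }

  void Remove(Instr* target) {
    instrs.erase(std::remove_if(instrs.begin(), instrs.end(),
                                [&](const std::unique_ptr<Instr>& i) {
                                  return i.get() == target;
                                }),
                 instrs.end());
  }
};

int main() {
  Computation comp;
  comp.instrs.push_back(std::make_unique<Instr>(Instr{"old_while", true}));
  comp.instrs.push_back(std::make_unique<Instr>(Instr{"new_while", true}));
  Instr* old_w = comp.instrs[0].get();
  Instr* new_w = comp.instrs[1].get();

  // The checked path refuses: the old while has side effects.
  std::cout << comp.ReplaceInstruction(old_w, new_w) << "\n";  // 0

  // The manual path taken by the patch: rewire users, then remove anyway.
  comp.ReplaceAllUsesWith(old_w, new_w);
  comp.Remove(old_w);
  std::cout << comp.instrs.size() << "\n";  // 1
}
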
diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h
index 322d27b88c..e67636d80f 100644
--- a/tensorflow/compiler/xla/service/while_util.h
+++ b/tensorflow/compiler/xla/service/while_util.h
@@ -38,17 +38,21 @@ class WhileUtil {
};
// Replaces `while_instr` with a new while instruction that is equivalent to
- // `while_instr`, except that it has all of the HLO instructions in
+ // `while_instr` except that it has all of the HLO instructions in
// `instructions` as live-in, loop invariant values. These new live in values
// are represented as new elements appended to the parameter of the while
// loop, which must be of tuple shape. GetTupleElement instructions computing
  // each new live in value are returned in the `while_body_live_in_values`
// vector.
//
- // Precondition: `while_instr` must have a tuple shaped state.
+ // Deletes `while_instr` after replacing it.
//
- // Every instruction in `instructions` must be contained in the computation
- // that contains `while_instr`.
+ // Preconditions:
+ //
+ // `while_instr` must have a tuple shaped state.
+ //
+ // Every instruction in `instructions` must be contained in the computation
+ // that contains `while_instr`.
static StatusOr<MakeInstructionsLiveInResult> MakeInstructionsLiveIn(
HloInstruction* while_instr,
tensorflow::gtl::ArraySlice<HloInstruction*> instructions);
diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc
index 974bc542a3..bcc545c61d 100644
--- a/tensorflow/compiler/xla/service/while_util_test.cc
+++ b/tensorflow/compiler/xla/service/while_util_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+#include "tensorflow/compiler/xla/util.h"
namespace xla {
namespace {
@@ -163,5 +164,47 @@ ENTRY main {
ASSERT_EQ(gte_list.size(), 1);
EXPECT_EQ((*gte_list.begin())->name(), "gte.0");
}
+
+TEST(WhileUtilTest, AlwaysRemovePreviousWhileBody) {
+ const char* const hlo_string = R"(
+HloModule WhileWithSideEffects
+
+body {
+ param.b = (s32[], s32[]) parameter(0)
+ gte.0 = s32[] get-tuple-element(param.b), index=0
+ gte.1 = s32[] get-tuple-element(param.b), index=1
+ add = s32[] add(gte.0, gte.1)
+ ROOT tuple = (s32[], s32[]) tuple(gte.0, add)
+}
+
+cond {
+ param.c = (s32[], s32[]) parameter(0)
+ ROOT condition = pred[] infeed()
+}
+
+ENTRY main {
+ init = (s32[], s32[]) parameter(0)
+ to_make_live_in = f32[100] parameter(1)
+ ROOT while = (s32[], s32[]) while(init), condition=cond, body=body
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ tools::Parse(hlo_string));
+
+ HloComputation* main = module->GetComputationWithName("main");
+ HloInstruction* while_instr = main->root_instruction();
+ HloInstruction* to_make_live_in = main->parameter_instruction(1);
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ WhileUtil::MakeInstructionsLiveInResult make_live_in_result,
+ WhileUtil::MakeInstructionsLiveIn(while_instr,
+ /*instructions=*/{to_make_live_in}));
+
+ auto is_while = [](const HloInstruction* instr) {
+ return instr->opcode() == HloOpcode::kWhile;
+ };
+ EXPECT_EQ(c_count_if(main->instructions(), is_while), 1);
+}
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h
index 141347a792..14c35e7b84 100644
--- a/tensorflow/compiler/xla/service_interface.h
+++ b/tensorflow/compiler/xla/service_interface.h
@@ -47,41 +47,22 @@ class ServiceInterface {
virtual Status ResetDevice(const ResetDeviceRequest* arg,
ResetDeviceResponse* result) = 0;
- virtual Status LoadComputationSnapshot(
- const LoadComputationSnapshotRequest* request,
- LoadComputationSnapshotResponse* result) = 0;
-
- virtual Status Execute(const ExecuteRequest* arg,
- ExecuteResponse* result) = 0;
-
virtual Status ExecuteGraph(const ExecuteGraphRequest* arg,
ExecuteResponse* result) = 0;
- virtual Status ExecuteParallel(const ExecuteParallelRequest* arg,
- ExecuteParallelResponse* result) = 0;
-
virtual Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
ExecuteParallelResponse* result) = 0;
- virtual Status ExecuteAsync(const ExecuteAsyncRequest* arg,
- ExecuteAsyncResponse* result) = 0;
-
virtual Status WaitForExecution(const WaitForExecutionRequest* arg,
WaitForExecutionResponse* result) = 0;
virtual Status DeconstructTuple(const DeconstructTupleRequest* arg,
DeconstructTupleResponse* result) = 0;
- virtual Status GetComputationStats(const ComputationStatsRequest* arg,
- ComputationStatsResponse* result) = 0;
-
virtual Status GetComputationGraphStats(
const ComputationGraphStatsRequest* arg,
ComputationStatsResponse* result) = 0;
- virtual Status GetComputationShape(const GetComputationShapeRequest* arg,
- GetComputationShapeResponse* result) = 0;
-
virtual Status GetShape(const GetShapeRequest* arg,
GetShapeResponse* result) = 0;
@@ -91,31 +72,9 @@ class ServiceInterface {
virtual Status GetDeviceHandles(const GetDeviceHandlesRequest* arg,
GetDeviceHandlesResponse* result) = 0;
- // Methods used by ComputationBuilder.
- virtual Status Computation(const ComputationRequest* arg,
- ComputationResponse* result) = 0;
-
- virtual Status Op(const OpRequest* arg, OpResponse* result) = 0;
-
- virtual Status GetLocalShape(const GetLocalShapeRequest* arg,
- GetLocalShapeResponse* result) = 0;
-
- virtual Status SetReturnValue(const SetReturnValueRequest* arg,
- SetReturnValueResponse* results) = 0;
-
- virtual Status IsConstant(const IsConstantRequest* arg,
- IsConstantResponse* result) = 0;
-
- virtual Status ComputeConstant(const ComputeConstantRequest* arg,
- ComputeConstantResponse* result) = 0;
-
virtual Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
ComputeConstantResponse* result) = 0;
- // Methods used by Computation.
- virtual Status SnapshotComputation(const SnapshotComputationRequest* ag,
- SnapshotComputationResponse* result) = 0;
-
// Methods used by GlobalData.
virtual Status Unregister(const UnregisterRequest* arg,
UnregisterResponse* result) = 0;
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index 15b9cd4265..d73bcdaf82 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -164,7 +164,6 @@ tf_cc_binary(
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service",
- "//tensorflow/compiler/xla/service:computation_tracker",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:interpreter_plugin",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
index b815bbf854..5dd5150be3 100644
--- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/service/computation_tracker.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/service.h"
#include "tensorflow/compiler/xla/statusor.h"
diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD
index 0fa4b98d0a..76f35afd53 100644
--- a/tensorflow/compiler/xla/tools/parser/BUILD
+++ b/tensorflow/compiler/xla/tools/parser/BUILD
@@ -65,6 +65,7 @@ tf_cc_test(
srcs = ["hlo_parser_test.cc"],
deps = [
":hlo_parser",
+ "//tensorflow/compiler/xla:window_util",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index 134978d21f..ef10ca4bff 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -56,10 +56,10 @@ class HloParser {
// Returns the error information.
string GetError() const { return Join(error_, "\n"); }
- // Stand alone parsing for sharding. The parser string is supposed to
- // contain the body of the sharding, i.e. just the rhs of the "sharding={...}"
- // attribute string.
+ // Stand alone parsing utils for various aggregate data types.
StatusOr<HloSharding> ParseShardingOnly();
+ StatusOr<Window> ParseWindowOnly();
+ StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
private:
// ParseXXX returns false if an error occurred.
@@ -169,7 +169,9 @@ class HloParser {
bool ParseComputationName(HloComputation** value);
// Parses a list of names and finds the corresponding hlo instructions.
bool ParseInstructionNames(std::vector<HloInstruction*>* instructions);
- bool ParseWindow(Window* window);
+ // Pass expect_outer_curlies == true when parsing a Window in the context of a
+ // larger computation. Pass false when parsing a stand-alone Window string.
+ bool ParseWindow(Window* window, bool expect_outer_curlies);
bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
bool ParsePaddingConfig(PaddingConfig* padding);
bool ParseMetadata(OpMetadata* metadata);
@@ -1125,7 +1127,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
instruction->set_metadata(*metadata);
}
if (backend_config) {
- instruction->set_backend_config(std::move(*backend_config));
+ instruction->set_raw_backend_config_string(std::move(*backend_config));
}
return AddInstruction(name, instruction, name_loc);
} // NOLINT(readability/fn_size)
@@ -1933,7 +1935,7 @@ bool HloParser::ParseAttributeHelper(
}
case AttrTy::kWindow: {
Window result;
- if (!ParseWindow(&result)) {
+ if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) {
return false;
}
static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
@@ -2051,9 +2053,10 @@ bool HloParser::ParseComputationName(HloComputation** value) {
// ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}'
// The subattributes can appear in any order. 'size=' is required, others are
// optional.
-bool HloParser::ParseWindow(Window* window) {
+bool HloParser::ParseWindow(Window* window, bool expect_outer_curlies) {
LocTy loc = lexer_.GetLoc();
- if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
+ if (expect_outer_curlies &&
+ !ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
return false;
}
@@ -2063,7 +2066,9 @@ bool HloParser::ParseWindow(Window* window) {
std::vector<int64> lhs_dilate;
std::vector<int64> rhs_dilate;
std::vector<int64> rhs_reversal;
- while (lexer_.GetKind() != TokKind::kRbrace) {
+ const auto end_token =
+ expect_outer_curlies ? TokKind::kRbrace : TokKind::kEof;
+ while (lexer_.GetKind() != end_token) {
LocTy attr_loc = lexer_.GetLoc();
string field_name;
if (!ParseAttributeName(&field_name)) {
@@ -2127,7 +2132,8 @@ bool HloParser::ParseWindow(Window* window) {
window->mutable_dimensions(i)->set_window_reversal(
rhs_reversal.empty() ? false : (rhs_reversal[i] == 1));
}
- return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
+ return !expect_outer_curlies ||
+ ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
}
// This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
@@ -2692,6 +2698,32 @@ StatusOr<HloSharding> HloParser::ParseShardingOnly() {
return HloSharding::FromProto(op_sharding);
}
+StatusOr<Window> HloParser::ParseWindowOnly() {
+ lexer_.Lex();
+ Window window;
+ if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) {
+ return InvalidArgument("Syntax error:\n%s", GetError().c_str());
+ }
+ if (lexer_.GetKind() != TokKind::kEof) {
+ return InvalidArgument("Syntax error:\nExtra content after window");
+ }
+ return window;
+}
+
+StatusOr<ConvolutionDimensionNumbers>
+HloParser::ParseConvolutionDimensionNumbersOnly() {
+ lexer_.Lex();
+ ConvolutionDimensionNumbers dnums;
+ if (!ParseConvolutionDimensionNumbers(&dnums)) {
+ return InvalidArgument("Syntax error:\n%s", GetError().c_str());
+ }
+ if (lexer_.GetKind() != TokKind::kEof) {
+ return InvalidArgument(
+ "Syntax error:\nExtra content after convolution dnums");
+ }
+ return dnums;
+}
+
} // namespace
StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str,
@@ -2714,5 +2746,18 @@ StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str) {
return parser.ParseShardingOnly();
}
+StatusOr<Window> ParseWindow(tensorflow::StringPiece str) {
+ HloModuleConfig config;
+ HloParser parser(str, config);
+ return parser.ParseWindowOnly();
+}
+
+StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
+ tensorflow::StringPiece str) {
+ HloModuleConfig config;
+ HloParser parser(str, config);
+ return parser.ParseConvolutionDimensionNumbersOnly();
+}
+
} // namespace tools
} // namespace xla
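
All three *Only entry points share the same idiom: prime the lexer, parse a single value, then insist on end-of-input so trailing garbage is reported rather than silently accepted. A standalone sketch of the idiom, with std::istringstream standing in for the HLO lexer (ParseIntOnly is a hypothetical helper, not part of the parser):

#include <iostream>
#include <sstream>
#include <string>

// Parses exactly one integer from `str`; any trailing token is an error,
// mirroring the lexer_.GetKind() != TokKind::kEof checks above.
bool ParseIntOnly(const std::string& str, int* out) {
  std::istringstream in(str);
  if (!(in >> *out)) return false;   // syntax error
  std::string trailing;
  if (in >> trailing) return false;  // extra content after the value
  return true;
}

int main() {
  int value = 0;
  std::cout << ParseIntOnly("42", &value) << "\n";       // 1 (ok)
  std::cout << ParseIntOnly("42 junk", &value) << "\n";  // 0 (trailing junk)
  return 0;
}
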
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.h b/tensorflow/compiler/xla/tools/parser/hlo_parser.h
index f7854f403e..902c45cebc 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.h
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.h
@@ -36,10 +36,17 @@ StatusOr<std::unique_ptr<HloModule>> Parse(tensorflow::StringPiece str,
// format, parses the string and creates a HloModule with default config.
StatusOr<std::unique_ptr<HloModule>> Parse(tensorflow::StringPiece str);
-// Parse sharding from str. str is supposed to contain the body of the
-// sharding, i.e. just the rhs of the "sharding={...}" attribute string.
+// Parses the result of HloSharding::ToString(), e.g. "{replicated}".
StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str);
+// Parses the result of window_util::ToString(const Window&).
+StatusOr<Window> ParseWindow(tensorflow::StringPiece str);
+
+// Parses the result of ConvolutionDimensionNumbersToString(), e.g.
+// "b0f_0io->b0f".
+StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
+ tensorflow::StringPiece str);
+
} // namespace tools
} // namespace xla
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index 183b1121cd..3c5957b96a 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include <string>
+#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -1024,7 +1025,7 @@ ENTRY %configuration_test() -> s32[] {
EXPECT_EQ("foo bar", result.ValueOrDie()
->entry_computation()
->root_instruction()
- ->backend_config());
+ ->raw_backend_config_string());
}
TEST_F(HloParserTest, LiteralDimensionsMismatch_1) {
@@ -1349,6 +1350,26 @@ ENTRY entry {
"was parsing 8:39: error: instruction does not exist: aparam");
}
+TEST_F(HloParserTest, ParseSharding) {
+ const string original = "{maximal device=42}";
+ TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original));
+ EXPECT_EQ(sharding.ToString(), original);
+}
+
+TEST_F(HloParserTest, ParseWindow) {
+ Window original = window_util::MakeWindow({1, 2, 3});
+ TF_ASSERT_OK_AND_ASSIGN(Window parsed,
+                          ParseWindow(window_util::ToString(original)));
+ EXPECT_EQ(window_util::ToString(original), window_util::ToString(parsed));
+}
+
+TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) {
+ const string original = "b0f_0io->b0f";
+ TF_ASSERT_OK_AND_ASSIGN(ConvolutionDimensionNumbers dnums,
+ ParseConvolutionDimensionNumbers(original));
+ EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums));
+}
+
} // namespace
} // namespace tools
} // namespace xla
diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index 2349fa919e..be094b7890 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -68,7 +68,6 @@ struct Options {
bool use_fake_data = false;
bool print_result = true;
int num_runs = 1;
- bool xla_hlo_profile_last_run = false;
};
// Invokes the given computation passing arbitrary data for every (unbound)
@@ -80,21 +79,35 @@ struct Options {
//
// If generate_fake_infeed is false and no fake_infeed_shape is provided, no
// infeed is performed.
-StatusOr<std::unique_ptr<Literal>> ReplayComputation(const HloSnapshot& module,
- Client* client,
- const Options& opts) {
- TF_ASSIGN_OR_RETURN(auto computation, client->LoadSnapshot(module));
+StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
+ LocalClient* client, const Options& opts) {
+ XlaComputation computation(module.hlo().hlo_module());
- std::vector<std::unique_ptr<GlobalData>> arguments;
+  // Build the `argument_ptrs` vector, which holds ShapedBuffer* pointers to
+  // our arguments. This is a bit involved, because we may have to convert from
+ // GlobalData to ShapedBuffer*, and we have to manage the lifetime of all our
+ // objects.
+ std::vector<ScopedShapedBuffer> scoped_shaped_buffer_arguments;
+ std::vector<std::unique_ptr<GlobalData>> global_data_arguments;
+ std::vector<const ShapedBuffer*> argument_ptrs;
if (opts.use_fake_data) {
- arguments = MakeFakeArgumentsOrDie(computation, client);
+ global_data_arguments = MakeFakeArgumentsOrDie(computation, client);
+ for (const auto& data : global_data_arguments) {
+ argument_ptrs.push_back(
+ client->GlobalDataToShapedBuffer(data->handle(), /*device_ordinal=*/0)
+ .ValueOrDie());
+ }
} else { // use recorded data if available
for (const auto& proto : module.arguments()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
Literal::CreateFromProto(proto));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data,
- client->TransferToServer(*literal));
- arguments.push_back(std::move(data));
+ TF_ASSIGN_OR_RETURN(
+ ScopedShapedBuffer data,
+ client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0));
+ scoped_shaped_buffer_arguments.push_back(std::move(data));
+ }
+ for (const auto& argument : scoped_shaped_buffer_arguments) {
+ argument_ptrs.push_back(&argument);
}
}
@@ -149,55 +162,59 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(const HloSnapshot& module,
});
}
- std::vector<GlobalData*> execute_arguments;
- execute_arguments.reserve(arguments.size());
- for (auto& argument : arguments) {
- execute_arguments.push_back(argument.get());
+ std::vector<const Shape*> argument_layouts;
+ for (const auto& param : computation.proto().program_shape().parameters()) {
+ argument_layouts.push_back(&param);
}
+ std::unique_ptr<LocalExecutable> executable =
+ client->Compile(computation, argument_layouts, ExecutableBuildOptions())
+ .ValueOrDie();
// Run the computation num_runs times, and return the result from the last
// execution.
- std::unique_ptr<Literal> result;
+ StreamExecutorMemoryAllocator allocator(
+ client->platform(),
+ {client->platform()->ExecutorForDevice(0).ValueOrDie()});
+ tensorflow::gtl::optional<ScopedShapedBuffer> result;
for (int i = 0; i < opts.num_runs; ++i) {
ExecutionProfile profile;
- ExecutionOptions execution_options = CreateDefaultExecutionOptions();
- if (opts.xla_hlo_profile_last_run && i == opts.num_runs - 1) {
- execution_options.mutable_debug_options()->set_xla_hlo_profile(true);
- }
+ ExecutableRunOptions run_options;
+ run_options.set_execution_profile(&profile);
+ run_options.set_allocator(&allocator);
- if (opts.print_result) {
- TF_ASSIGN_OR_RETURN(
- result, client->ExecuteAndTransfer(computation, execute_arguments,
- &execution_options, &profile));
- } else {
- // If we're not printing the result, execute the computation but don't
- // bother retrieving the result. This can be a significant speedup.
- TF_RETURN_IF_ERROR(client
- ->Execute(computation, execute_arguments,
- &execution_options, &profile)
- .status());
- }
+ TF_ASSIGN_OR_RETURN(result, executable->Run(argument_ptrs, run_options));
LOG(INFO) << "Execution took "
<< static_cast<double>(profile.compute_time_ns()) / 1e9 << "s";
}
- return std::move(result);
+ // Check that --num_runs > 0, otherwise *result below will fail with an
+ // unhelpful error (because the loop didn't run any iterations).
+ CHECK_GT(opts.num_runs, 0) << "--num_runs must be > 0";
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result_literal,
+ client->ShapedBufferToLiteral(*result));
+ return std::move(*result_literal);
}
int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
- Client* client = ClientLibrary::LocalClientOrDie();
+ LocalClient* client = ClientLibrary::LocalClientOrDie();
tensorflow::Env* env = tensorflow::Env::Default();
int exit_status = EXIT_SUCCESS;
for (char* arg : args) {
HloSnapshot snapshot;
auto status = tensorflow::ReadBinaryProto(env, arg, &snapshot);
if (!status.ok()) {
- fprintf(stderr, "%s: is not HloSnapshot: %s.\n", arg,
- status.ToString().c_str());
- continue;
+ fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n", arg);
+ status = tensorflow::ReadBinaryProto(env, arg, snapshot.mutable_hlo());
+ if (!status.ok()) {
+ fprintf(stderr, "%s: is not HloSnapshot or HloProto: %s.\n", arg,
+ status.ToString().c_str());
+ continue;
+ }
+ CHECK(opts.use_fake_data)
+ << "HloProto input must be handled with --use_fake_data";
}
- StatusOr<std::unique_ptr<Literal>> result_status =
- ReplayComputation(snapshot, client, opts);
+
+ StatusOr<Literal> result_status = ReplayComputation(snapshot, client, opts);
if (!result_status.ok()) {
fprintf(stderr, "%s: error: %s\n", arg,
result_status.status().ToString().c_str());
@@ -205,12 +222,12 @@ int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
continue;
}
- std::unique_ptr<Literal> result = result_status.ConsumeValueOrDie();
- if (result != nullptr) {
+ if (opts.print_result) {
+ Literal result = std::move(result_status).ValueOrDie();
fprintf(stdout, "%s: %s :: %s:%s\n", arg,
snapshot.hlo().hlo_module().name().c_str(),
- ShapeUtil::HumanString(result->shape()).c_str(),
- result->ToString().c_str());
+ ShapeUtil::HumanString(result.shape()).c_str(),
+ result.ToString().c_str());
if (snapshot.has_result()) {
std::unique_ptr<Literal> literal =
Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie();
@@ -243,9 +260,6 @@ int main(int argc, char** argv) {
tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed,
"Whether a fake infeed shape should be generated "
"derived from the computation"),
- tensorflow::Flag(
- "xla_hlo_profile_last_run", &opts.xla_hlo_profile_last_run,
- "Pass --xla_hlo_profile the last time we run the computation."),
};
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index 7303640726..b4f45cc972 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -526,6 +526,13 @@ typename std::decay<T>::type c_accumulate(const Sequence& sequence, T&& init,
std::forward<BinaryOp>(binary_op));
}
+template <typename C, typename Pred>
+typename std::iterator_traits<
+ decltype(std::begin(std::declval<C>()))>::difference_type
+c_count_if(const C& c, Pred&& pred) {
+ return std::count_if(std::begin(c), std::end(c), std::forward<Pred>(pred));
+}
+
template <typename C, typename Value>
int64 FindIndex(const C& c, Value&& value) {
auto it = c_find(c, std::forward<Value>(value));
diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py
index 5b7508c9a5..775d92c1d9 100644
--- a/tensorflow/contrib/autograph/converters/break_statements.py
+++ b/tensorflow/contrib/autograph/converters/break_statements.py
@@ -32,14 +32,6 @@ CONTROL_VAR_NAME = 'control_var_name'
class BreakStatementTransformer(transformer.Base):
"""Canonicalizes break statements into additional conditionals."""
- def _track_body(self, nodes, break_var):
- self.enter_local_scope()
- self.set_local(CONTROL_VAR_NAME, break_var)
- nodes = self.visit_block(nodes)
- break_used = self.get_local(BREAK_USED, False)
- self.exit_local_scope()
- return nodes, break_used
-
def visit_Break(self, node):
self.set_local(BREAK_USED, True)
var_name = self.get_local(CONTROL_VAR_NAME)
@@ -65,6 +57,14 @@ class BreakStatementTransformer(transformer.Base):
block=block)
return node
+ def _track_body(self, nodes, break_var):
+ self.enter_local_scope()
+ self.set_local(CONTROL_VAR_NAME, break_var)
+ nodes = self.visit_block(nodes)
+ break_used = self.get_local(BREAK_USED, False)
+ self.exit_local_scope()
+ return nodes, break_used
+
def visit_While(self, node):
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
break_var = self.context.namer.new_symbol('break_', scope.referenced)
diff --git a/tensorflow/contrib/autograph/converters/continue_statements.py b/tensorflow/contrib/autograph/converters/continue_statements.py
index 4299a8a9d5..0417817a77 100644
--- a/tensorflow/contrib/autograph/converters/continue_statements.py
+++ b/tensorflow/contrib/autograph/converters/continue_statements.py
@@ -24,103 +24,115 @@ from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
-class ContinueCanonicalizationTransformer(transformer.Base):
- """Canonicalizes continue statements into additional conditionals."""
+# Tags for local state.
+CONTROL_VAR_NAME = 'control_var_name'
+CONTINUE_USED = 'continue_used'
+GUARD_CREATED = 'guard_created'
+CREATE_GUARD_NEXT = 'create_guard_next'
- def __init__(self, context):
- super(ContinueCanonicalizationTransformer, self).__init__(context)
- # This is a stack structure, to correctly process nested loops.
- self.continuation_uses = []
- def _create_continuation_check(self):
- template = """
- if not var_name:
- pass
- """
- cond, = templates.replace(template, var_name=self.continuation_uses[-1][1])
- cond.body = []
- return cond
+class ContinueCanonicalizationTransformer(transformer.Base):
+ """Canonicalizes continue statements into additional conditionals."""
- def _create_continuation_trigger(self):
+ def visit_Continue(self, node):
+ self.set_local(CONTINUE_USED, True)
template = """
var_name = True
"""
- assign, = templates.replace(
- template, var_name=self.continuation_uses[-1][1])
- return assign
-
- def _create_continuation_init(self):
- template = """
- var_name = False
- """
- assign, = templates.replace(
- template, var_name=self.continuation_uses[-1][1])
- return assign
-
- def _visit_and_reindent_if_necessary(self, nodes):
- reorganized_nodes = []
- current_dest = reorganized_nodes
- continue_used_in_block = False
- for i, n in enumerate(nodes):
- # TODO(mdan): This could be optimized if control structures are simple.
- self.continuation_uses[-1][0] = False
- n = self.visit(n)
- current_dest.append(n)
- if self.continuation_uses[-1][0]:
- continue_used_in_block = True
- if i < len(nodes) - 1: # Last statement in block needs no protection.
- cond = self._create_continuation_check()
- current_dest.append(cond)
- current_dest = cond.body
- self.continuation_uses[-1][0] = continue_used_in_block
- return reorganized_nodes
-
- def _process_loop_block(self, block, scope):
- cont_var = self.context.namer.new_symbol('cont_requested', scope.referenced)
- self.continuation_uses.append([False, cont_var])
- block = self._visit_and_reindent_if_necessary(block)
- if self.continuation_uses[-1][0]:
- block.insert(0, self._create_continuation_init())
- self.continuation_uses.pop()
- return block
+ return templates.replace(
+ template, var_name=self.get_local(CONTROL_VAR_NAME))
+
+ def _postprocess_statement(self, node):
+ # Example of how the state machine below works:
+ #
+ # 1| stmt # State: CONTINUE_USED = False
+ # | # Action: none
+ # 2| if cond:
+ # 3| continue # State: CONTINUE_USED = True,
+ # | # GUARD_CREATED = False,
+ # | # CREATE_GUARD_NEXT = False
+ # | # Action: set CREATE_GUARD_NEXT = True
+ # 4| stmt # State: CONTINUE_USED = True,
+ # | # GUARD_CREATED = False,
+ # | # CREATE_GUARD_NEXT = True
+ # | # Action: create `if not continue_used`,
+ # | # set GUARD_CREATED = True
+ # 5| stmt # State: CONTINUE_USED = True, GUARD_CREATED = True
+ # | # Action: none (will be wrapped under previously
+ # | # created if node)
+
+ if self.get_local(CONTINUE_USED, False):
+ if self.get_local(GUARD_CREATED, False):
+ return node, None
+
+ elif not self.get_local(CREATE_GUARD_NEXT, False):
+ self.set_local(CREATE_GUARD_NEXT, True)
+ return node, None
+
+ else:
+ self.set_local(GUARD_CREATED, True)
+ template = """
+ if not var_name:
+ original_node
+ """
+ cond, = templates.replace(
+ template,
+ var_name=self.get_local(CONTROL_VAR_NAME),
+ original_node=node)
+ return cond, cond.body
+ return node, None
+
+ def _visit_loop_body(self, node, nodes):
+ self.enter_local_scope()
+ scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
+ continue_var = self.context.namer.new_symbol('continue_', scope.referenced)
+ self.set_local(CONTROL_VAR_NAME, continue_var)
+
+ nodes = self.visit_block(nodes, after_visit=self._postprocess_statement)
+
+ if self.get_local(CONTINUE_USED, False):
+ template = """
+ var_name = False
+ """
+ control_var_init = templates.replace(template, var_name=continue_var)
+ nodes = control_var_init + nodes
+
+ self.exit_local_scope()
+ return nodes
+
+ def _visit_non_loop_body(self, nodes):
+ self.enter_local_scope(inherit=(CONTROL_VAR_NAME,))
+ nodes = self.visit_block(nodes, after_visit=self._postprocess_statement)
+ continue_used = self.get_local(CONTINUE_USED, False)
+ self.exit_local_scope(keep=(CONTINUE_USED,))
+ return nodes, continue_used
def visit_While(self, node):
- self.generic_visit(node.test)
- node.body = self._process_loop_block(node.body,
- anno.getanno(node,
- NodeAnno.BODY_SCOPE))
- for n in node.orelse:
- self.generic_visit(n)
+ node.test = self.visit(node.test)
+ node.body = self._visit_loop_body(node, node.body)
+ # A continue in the else clause applies to the containing scope.
+ node.orelse, _ = self._visit_non_loop_body(node.orelse)
return node
def visit_For(self, node):
- self.generic_visit(node.target)
- self.generic_visit(node.iter)
- node.body = self._process_loop_block(node.body,
- anno.getanno(node,
- NodeAnno.BODY_SCOPE))
- for n in node.orelse:
- self.generic_visit(n)
+ node.target = self.generic_visit(node.target)
+ node.iter = self.generic_visit(node.iter)
+ node.body = self._visit_loop_body(node, node.body)
+ # A continue in the else clause applies to the containing scope.
+ node.orelse, _ = self._visit_non_loop_body(node.orelse)
return node
def visit_If(self, node):
- if self.continuation_uses:
- self.generic_visit(node.test)
- node.body = self._visit_and_reindent_if_necessary(node.body)
- continue_used_in_body = self.continuation_uses[-1][0]
- node.orelse = self._visit_and_reindent_if_necessary(node.orelse)
- self.continuation_uses[-1][0] = (
- continue_used_in_body or self.continuation_uses[-1][0])
- else:
- node = self.generic_visit(node)
+ node.test = self.generic_visit(node.test)
+ node.body, continue_used_body = self._visit_non_loop_body(node.body)
+ node.orelse, continue_used_orelse = self._visit_non_loop_body(node.orelse)
+ self.set_local(CONTINUE_USED, continue_used_body or continue_used_orelse)
return node
- def visit_Continue(self, node):
- self.continuation_uses[-1][0] = True
- return self._create_continuation_trigger()
-
- def visit_Break(self, node):
- assert False, 'break statement should be desugared at this point'
+ def visit_With(self, node):
+ node.items = self.visit_block(node.items)
+ node.body, _ = self._visit_non_loop_body(node.body)
+ return node
def transform(node, namer):
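For intuition, a hand-written sketch of the rewrite this state machine performs (illustrative names, not actual converter output; the converter generates the control variable via namer.new_symbol):

    # Original loop body:
    #   for i in range(n):
    #     if i % 2 == 0:
    #       continue
    #     total += i
    n, total = 10, 0
    for i in range(n):
        continue_ = False        # control-variable init, inserted at block start
        if i % 2 == 0:
            continue_ = True     # visit_Continue replaces the continue statement
        if not continue_:        # guard wraps the statements that followed
            total += i
    assert total == 25           # matches the original loop with `continue`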
diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD
index 18bfec5d9c..0c6ab65505 100644
--- a/tensorflow/contrib/autograph/operators/BUILD
+++ b/tensorflow/contrib/autograph/operators/BUILD
@@ -22,7 +22,7 @@ py_library(
"__init__.py",
"control_flow.py",
"data_structures.py",
- "dispatch_context.py",
+ "slices.py",
],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
@@ -52,3 +52,13 @@ py_test(
"//tensorflow/python:client_testlib",
],
)
+
+py_test(
+ name = "slices_test",
+ srcs = ["slices_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":operators",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py
index 38b761d97d..c900fd6af2 100644
--- a/tensorflow/contrib/autograph/operators/__init__.py
+++ b/tensorflow/contrib/autograph/operators/__init__.py
@@ -28,6 +28,10 @@ closures for the body.
# - the names used in the Python docs, if the operator is a function (e.g.
# list_ and x for append, see
# https://docs.python.org/3.7/tutorial/datastructures.html)
+#
+# All operators may accept a final argument named "opts", of a type that
+# subclasses namedtuple and contains any arguments that are only required
+# for some specializations of the operator.
from __future__ import absolute_import
from __future__ import division
@@ -35,3 +39,12 @@ from __future__ import print_function
from tensorflow.contrib.autograph.operators.control_flow import for_stmt
from tensorflow.contrib.autograph.operators.control_flow import while_stmt
+from tensorflow.contrib.autograph.operators.data_structures import list_append
+from tensorflow.contrib.autograph.operators.data_structures import list_pop
+from tensorflow.contrib.autograph.operators.data_structures import list_stack
+from tensorflow.contrib.autograph.operators.data_structures import ListPopOpts
+from tensorflow.contrib.autograph.operators.data_structures import ListStackOpts
+from tensorflow.contrib.autograph.operators.data_structures import new_list
+from tensorflow.contrib.autograph.operators.slices import get_item
+from tensorflow.contrib.autograph.operators.slices import GetItemOpts
+from tensorflow.contrib.autograph.operators.slices import set_item
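To make the "opts" convention above concrete, a minimal sketch (MyOpOpts and my_op are hypothetical names, not part of this package):

    import collections

    # Hypothetical opts container following the namedtuple convention above.
    MyOpOpts = collections.namedtuple('MyOpOpts', ('element_dtype',))

    def my_op(target, opts):
      # Specializations that need element_dtype read it from opts; other
      # specializations may ignore opts entirely.
      assert isinstance(opts, MyOpOpts)
      return target

    my_op([1, 2, 3], MyOpOpts(element_dtype=None))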
diff --git a/tensorflow/contrib/autograph/operators/data_structures.py b/tensorflow/contrib/autograph/operators/data_structures.py
index c862306baa..06d8727b0f 100644
--- a/tensorflow/contrib/autograph/operators/data_structures.py
+++ b/tensorflow/contrib/autograph/operators/data_structures.py
@@ -18,39 +18,250 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import list_ops
from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.ops import variables
+
+
+# TODO(mdan): Once control flow supports objects, repackage as a class.
+
+
+def new_list(iterable=None):
+ """The list constructor.
+
+ Args:
+ iterable: Optional elements to fill the list with.
+
+ Returns:
+ A list-like object. The exact return value depends on the initial elements.
+ """
+ if iterable:
+ elements = tuple(iterable)
+ else:
+ elements = ()
+
+ # TODO(mdan): Extend these criteria.
+ if any(isinstance(el, variables.Variable) for el in elements):
+ return _py_list_new(elements)
+ return _tf_tensor_list_new(elements)
-# TODO(mdan): Add support for TensorList once functional.
-# TODO(mdan): Add primitives for empty list, list with elements.
+def _tf_tensor_list_new(elements):
+ """Overload of new_list that stages a Tensor list creation."""
+ elements = tuple(ops.convert_to_tensor(el) for el in elements)
+ all_dtypes = set(el.dtype for el in elements)
+ if len(all_dtypes) == 1:
+ element_dtype = tuple(all_dtypes)[0]
+ else:
+ # Heterogeneous lists are ok.
+ element_dtype = dtypes.variant
+
+ # TODO(mdan): This may fail for elements of variable shapes.
+ all_shapes = set(tuple(el.shape.as_list()) for el in elements)
+ if len(all_shapes) == 1:
+ element_shape = array_ops.shape(elements[0])
+ else:
+ # Heterogeneous lists are ok.
+ element_shape = constant_op.constant(-1) # unknown shape, by convention
+
+ l = list_ops.empty_tensor_list(
+ element_shape=element_shape, element_dtype=element_dtype)
+ for el in elements:
+ l = list_ops.tensor_list_push_back(l, el)
+ return l
-def append(target, element):
+
+def _py_list_new(elements):
+ """Overload of new_list that creates a Python list."""
+ return list(elements)
+
+
+def list_append(list_, x):
"""The list append function.
- Note: it is unspecified where target will be mutated or not. If target is
- a TensorFlow entity, it will not be typically mutated. If target is a plain
- list, it will be. In general, if the target is mutated then the return value
+ Note: it is unspecified whether list_ will be mutated or not. If list_ is
+ a TensorFlow entity, it will typically not be mutated. If list_ is a plain
+ list, it will be. In general, if the list is mutated then the return value
should point to the original entity.
Args:
- target: An entity that supports append semantics.
- element: The element to append.
+ list_: An entity that supports append semantics.
+ x: The element to append.
Returns:
- Same as target, after the append was performed.
+ Same as list_, after the append was performed.
+
+ Raises:
+ ValueError: if list_ is not of a known list-like type.
"""
- if isinstance(target, tensor_array_ops.TensorArray):
- return _tf_tensorarray_append(target, element)
+ if isinstance(list_, tensor_array_ops.TensorArray):
+ return _tf_tensorarray_append(list_, x)
+ elif tensor_util.is_tensor(list_):
+ if list_.dtype == dtypes.variant:
+ return _tf_tensor_list_append(list_, x)
+ else:
+ raise ValueError(
+ 'tensor lists are expected to be Tensors with dtype=tf.variant,'
+ ' instead found %s' % list_)
else:
- return _py_append(target, element)
+ return _py_list_append(list_, x)
+
+
+def _tf_tensor_list_append(list_, x):
+ """Overload of list_append that stages a Tensor list write."""
+ def empty_list_of_elements_like_x():
+ tensor_x = ops.convert_to_tensor(x)
+ return list_ops.empty_tensor_list(
+ element_shape=array_ops.shape(tensor_x),
+ element_dtype=tensor_x.dtype)
+
+ list_ = control_flow_ops.cond(
+ list_ops.tensor_list_length(list_) > 0,
+ lambda: list_,
+ empty_list_of_elements_like_x,
+ )
+ return list_ops.tensor_list_push_back(list_, x)
+
+
+def _tf_tensorarray_append(list_, x):
+ """Overload of list_append that stages a TensorArray write."""
+ return list_.write(list_.size(), x)
+
+
+def _py_list_append(list_, x):
+ """Overload of list_append that executes a Python list append."""
+ # Revert to the original call.
+ list_.append(x)
+ return list_
+
+
+class ListPopOpts(
+ collections.namedtuple('ListPopOpts', ('element_dtype', 'element_shape'))):
+ pass
+
+
+def list_pop(list_, i, opts):
+ """The list pop function.
+
+ Note: it is unspecified whether list_ will be mutated or not. If list_ is
+ a TensorFlow entity, it will typically not be mutated. If list_ is a plain
+ list, it will be. In general, if the list is mutated then the return value
+ should point to the original entity.
+
+ Args:
+ list_: An entity that supports pop semantics.
+ i: Optional index to pop from. May be None.
+ opts: A ListPopOpts.
+
+ Returns:
+ Tuple (out_list_, x):
+ out_list_: same as list_, after the removal was performed.
+ x: the removed element value.
+
+ Raises:
+ ValueError: if list_ is not of a known list-like type or the operation is
+ not supported for that type.
+ """
+ assert isinstance(opts, ListPopOpts)
+
+ if isinstance(list_, tensor_array_ops.TensorArray):
+ raise ValueError('TensorArray does not support item removal')
+ elif tensor_util.is_tensor(list_):
+ if list_.dtype == dtypes.variant:
+ return _tf_tensor_list_pop(list_, i, opts)
+ else:
+ raise ValueError(
+ 'tensor lists are expected to be Tensors with dtype=tf.variant,'
+ ' instead found %s' % list_)
+ else:
+ return _py_list_pop(list_, i)
+
+
+def _tf_tensor_list_pop(list_, i, opts):
+ """Overload of list_pop that stages a Tensor list pop."""
+ if i is not None:
+ raise NotImplementedError('tensor lists only support removing from the end')
+
+ if opts.element_dtype is None:
+ raise ValueError('cannot pop from a list without knowing its element '
+ 'type; use set_element_type to annotate it')
+ if opts.element_shape is None:
+ raise ValueError('cannot pop from a list without knowing its element '
+ 'shape; use set_element_type to annotate it')
+ list_out, x = list_ops.tensor_list_pop_back(
+ list_, element_dtype=opts.element_dtype)
+ x.set_shape(opts.element_shape)
+ return list_out, x
+
+
+def _py_list_pop(list_, i):
+ """Overload of list_pop that executes a Python list append."""
+ if i is None:
+ x = list_.pop()
+ else:
+ x = list_.pop(i)
+ return list_, x
+
+
+# TODO(mdan): Look into reducing duplication between all these containers.
+class ListStackOpts(
+ collections.namedtuple('ListStackOpts',
+ ('element_dtype', 'original_call'))):
+ pass
+
+
+def list_stack(list_, opts):
+ """The list stack function.
+
+ This does not have a direct counterpart in Python. The closest idioms are
+ tf.stack and np.stack. It's different from those in the sense that it
+ accepts a Tensor list, rather than a list of tensors. It can also accept a
+ TensorArray. When the target is anything else, the dispatcher will rely on
+ opts.original_call as a fallback.
+
+ Args:
+ list_: An entity that supports append semantics.
+ opts: A ListStackOpts object.
+
+ Returns:
+ The output of the stack operation, typically a Tensor.
+ """
+ assert isinstance(opts, ListStackOpts)
+
+ if isinstance(list_, tensor_array_ops.TensorArray):
+ return _tf_tensorarray_stack(list_)
+ elif tensor_util.is_tensor(list_):
+ if list_.dtype == dtypes.variant:
+ return _tf_tensor_list_stack(list_, opts)
+ else:
+ # No-op for primitive Tensor arguments.
+ return list_
+ else:
+ return _py_list_stack(list_, opts)
+
+
+def _tf_tensorarray_stack(list_):
+ """Overload of list_stack that stages a TensorArray stack."""
+ return list_.stack()
-def _tf_tensorarray_append(target, element):
- """Overload of append that stages a TensorArray write at the last position."""
- return target.write(target.size(), element)
+def _tf_tensor_list_stack(list_, opts):
+ """Overload of list_stack that stages a Tensor list write."""
+ if opts.element_dtype is None:
+ raise ValueError('cannot stack a list without knowing its element type;'
+ ' use set_element_type to annotate it')
+ return list_ops.tensor_list_stack(list_, element_dtype=opts.element_dtype)
-def _py_append(target, element):
- """Overload of append that executes a Python list append."""
- target.append(element)
- return target
+def _py_list_stack(list_, opts):
+ """Overload of list_stack that executes a Python list append."""
+ # Revert to the original call.
+ return opts.original_call(list_)
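A condensed usage sketch of the operators added above (graph mode; evaluating the results requires a session, as in the accompanying tests):

    from tensorflow.contrib.autograph.operators import data_structures
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes

    l = data_structures.new_list()                      # empty tf.variant tensor list
    l = data_structures.list_append(l, constant_op.constant([1, 2, 3]))

    pop_opts = data_structures.ListPopOpts(
        element_dtype=dtypes.int32, element_shape=(3,))
    l, x = data_structures.list_pop(l, None, pop_opts)  # x holds [1, 2, 3]

    stack_opts = data_structures.ListStackOpts(
        element_dtype=dtypes.int32, original_call=None)
    t = data_structures.list_stack(l, stack_opts)       # stacks the (now empty) list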
diff --git a/tensorflow/contrib/autograph/operators/data_structures_test.py b/tensorflow/contrib/autograph/operators/data_structures_test.py
index 577d28c34d..8bbb52d6c1 100644
--- a/tensorflow/contrib/autograph/operators/data_structures_test.py
+++ b/tensorflow/contrib/autograph/operators/data_structures_test.py
@@ -19,25 +19,98 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.autograph.operators import data_structures
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import list_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test
-class AppendTest(test.TestCase):
+class ListTest(test.TestCase):
- def test_tf_tensorarray(self):
+ def test_new_list_empty(self):
+ l = data_structures.new_list()
+ # Can't evaluate an empty list.
+ # TODO(mdan): sess.run should allow tf.variant maybe?
+ self.assertTrue(isinstance(l, ops.Tensor))
+
+ def test_new_list_tensor(self):
+ l = data_structures.new_list([3, 4, 5])
+ t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(t), [3, 4, 5])
+
+ def test_append_tensor_list(self):
+ l = data_structures.new_list()
+ x = constant_op.constant([1, 2, 3])
+ l = data_structures.list_append(l, x)
+
+ t = list_ops.tensor_list_stack(l, element_dtype=x.dtype)
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(t), [[1, 2, 3]])
+
+ def test_append_tensorarray(self):
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
- l1 = data_structures.append(l, 1)
- l2 = data_structures.append(l1, 2)
+ l1 = data_structures.list_append(l, 1)
+ l2 = data_structures.list_append(l1, 2)
with self.test_session() as sess:
self.assertAllEqual(sess.run(l1.stack()), [1])
self.assertAllEqual(sess.run(l2.stack()), [1, 2])
- def test_python(self):
+ def test_append_python(self):
l = []
- self.assertAllEqual(data_structures.append(l, 1), [1])
- self.assertAllEqual(data_structures.append(l, 2), [1, 2])
+ self.assertAllEqual(data_structures.list_append(l, 1), [1])
+ self.assertAllEqual(data_structures.list_append(l, 2), [1, 2])
+
+ def test_pop_tensor_list(self):
+ initial_list = constant_op.constant([[1, 2], [3, 4]])
+ elem_shape = constant_op.constant([2])
+ l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape)
+
+ opts = data_structures.ListPopOpts(
+ element_dtype=initial_list.dtype,
+ element_shape=(2,))
+
+ with self.assertRaises(NotImplementedError):
+ data_structures.list_pop(l, 0, opts)
+
+ with self.test_session() as sess:
+ l, x = data_structures.list_pop(l, None, opts)
+ self.assertAllEqual(sess.run(x), [3, 4])
+
+ t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
+ self.assertAllEqual(sess.run(t), [[1, 2]])
+
+ def test_pop_python(self):
+ l = [1, 2, 3]
+ opts = data_structures.ListPopOpts(element_dtype=None, element_shape=())
+ self.assertAllEqual(data_structures.list_pop(l, None, opts), ([1, 2], 3))
+ self.assertAllEqual(data_structures.list_pop(l, None, opts), ([1], 2))
+
+ def test_stack_tensor_list(self):
+ initial_list = constant_op.constant([[1, 2], [3, 4]])
+ elem_shape = constant_op.constant([2])
+ l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape)
+
+ opts = data_structures.ListStackOpts(
+ element_dtype=initial_list.dtype, original_call=None)
+
+ with self.test_session() as sess:
+ t = data_structures.list_stack(l, opts)
+ self.assertAllEqual(sess.run(t), sess.run(initial_list))
+
+ def test_stack_fallback(self):
+
+ def dummy_function(l):
+ # Lazy person's mock: just transform the argument in a way that lets us
+ # check that this function was indeed called.
+ return [x * 2 for x in l]
+
+ opts = data_structures.ListStackOpts(
+ element_dtype=None, original_call=dummy_function)
+
+ self.assertAllEqual(data_structures.list_stack([1, 2], opts), [2, 4])
if __name__ == '__main__':
diff --git a/tensorflow/contrib/autograph/operators/slices.py b/tensorflow/contrib/autograph/operators/slices.py
new file mode 100644
index 0000000000..04fbeb2f6e
--- /dev/null
+++ b/tensorflow/contrib/autograph/operators/slices.py
@@ -0,0 +1,133 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operators specific to slicing operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import tensor_array_ops
+
+
+# TODO(mdan): Support extended slices.
+
+
+class GetItemOpts(collections.namedtuple('GetItemOpts', ('element_dtype',))):
+ pass
+
+
+def get_item(target, i, opts):
+ """The slice read operator (i.e. __getitem__).
+
+ Note: it is unspecified whether target will be mutated or not. In general,
+ if target is mutable (like Python lists), it will be mutated.
+
+ Args:
+ target: An entity that supports getitem semantics.
+ i: Index to read from.
+ opts: A GetItemOpts object.
+
+ Returns:
+ The read element.
+
+ Raises:
+ ValueError: if target is not of a supported type.
+ """
+ assert isinstance(opts, GetItemOpts)
+
+ if isinstance(target, tensor_array_ops.TensorArray):
+ return _tf_tensorarray_get_item(target, i)
+ elif tensor_util.is_tensor(target):
+ if target.dtype == dtypes.variant:
+ return _tf_tensor_list_get_item(target, i, opts)
+ else:
+ return _tf_tensor_get_item(target, i)
+ else:
+ return _py_get_item(target, i)
+
+
+def _tf_tensorarray_get_item(target, i):
+ """Overload of get_item that stages a TensorArray read."""
+ return target.read(i)
+
+
+def _tf_tensor_list_get_item(target, i, opts):
+ """Overload of get_item that stages a Tensor list read."""
+ if opts.element_dtype is None:
+ raise ValueError('cannot retrieve from a list without knowing its '
+ 'element type; use set_element_type to annotate it')
+ x = list_ops.tensor_list_get_item(target, i, element_dtype=opts.element_dtype)
+ return x
+
+
+def _tf_tensor_get_item(target, i):
+ """Overload of get_item that stages a Tensor (not Tensor list) read."""
+ return target[i]
+
+
+def _py_get_item(target, i):
+ """Overload of get_item that executes a Python list modification."""
+ return target[i]
+
+
+def set_item(target, i, x):
+ """The slice write operator (i.e. __setitem__).
+
+ Note: it is unspecified whether target will be mutated or not. In general,
+ if target is mutable (like Python lists), it will be mutated.
+
+ Args:
+ target: An entity that supports setitem semantics.
+ i: Index to modify.
+ x: The new element value.
+
+ Returns:
+ Same as target, after the update was performed.
+
+ Raises:
+ ValueError: if target is not of a supported type.
+ """
+ if isinstance(target, tensor_array_ops.TensorArray):
+ return _tf_tensorarray_set_item(target, i, x)
+ elif tensor_util.is_tensor(target):
+ if target.dtype == dtypes.variant:
+ return _tf_tensor_list_set_item(target, i, x)
+ else:
+ raise ValueError(
+ 'tensor lists are expected to be Tensors with dtype=tf.variant,'
+ ' instead found %s' % target)
+ else:
+ return _py_set_item(target, i, x)
+
+
+def _tf_tensorarray_set_item(target, i, x):
+ """Overload of set_item that stages a TensorArray write."""
+ return target.write(i, x)
+
+
+def _tf_tensor_list_set_item(target, i, x):
+ """Overload of set_item that stages a Tensor list update."""
+ return list_ops.tensor_list_set_item(target, i, x)
+
+
+def _py_set_item(target, i, x):
+ """Overload of set_item that executes a Python list modification."""
+ target[i] = x
+ return target
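A condensed usage sketch of the new slice operators, mirroring slices_test.py (evaluating the results requires a session):

    from tensorflow.contrib.autograph.operators import slices
    from tensorflow.python.framework import constant_op
    from tensorflow.python.ops import list_ops

    initial = constant_op.constant([[1, 2], [3, 4]])
    l = list_ops.tensor_list_from_tensor(
        initial, element_shape=constant_op.constant([2]))

    l = slices.set_item(l, 0, [5, 6])   # staged tensor-list update
    x = slices.get_item(
        l, 1, slices.GetItemOpts(element_dtype=initial.dtype))  # reads [3, 4]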
diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/contrib/autograph/operators/slices_test.py
new file mode 100644
index 0000000000..d4aacb9d20
--- /dev/null
+++ b/tensorflow/contrib/autograph/operators/slices_test.py
@@ -0,0 +1,51 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slices module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.operators import slices
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import list_ops
+from tensorflow.python.platform import test
+
+
+class SlicesTest(test.TestCase):
+
+ def test_set_item_tensor_list(self):
+ initial_list = constant_op.constant([[1, 2], [3, 4]])
+ elem_shape = constant_op.constant([2])
+ l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape)
+ l = slices.set_item(l, 0, [5, 6])
+
+ with self.test_session() as sess:
+ t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
+ self.assertAllEqual(sess.run(t), [[5, 6], [3, 4]])
+
+ def test_get_item_tensor_list(self):
+ initial_list = constant_op.constant([[1, 2], [3, 4]])
+ elem_shape = constant_op.constant([2])
+ l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape)
+ t = slices.get_item(
+ l, 1, slices.GetItemOpts(element_dtype=initial_list.dtype))
+
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(t), [3, 4])
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py
index 4c65edb6de..60bca8b38d 100644
--- a/tensorflow/contrib/autograph/pyct/transformer.py
+++ b/tensorflow/contrib/autograph/pyct/transformer.py
@@ -70,14 +70,40 @@ class Base(gast.NodeTransformer):
return tuple(self._enclosing_entities)
@property
- def locel_scope_level(self):
+ def local_scope_level(self):
return len(self._local_scope_state)
- def enter_local_scope(self):
- self._local_scope_state.append({})
+ def enter_local_scope(self, inherit=None):
+ """Marks entry into a new local scope.
- def exit_local_scope(self):
- return self._local_scope_state.pop()
+ Args:
+ inherit: Optional iterable of variable names to copy from the
+ parent scope.
+ """
+ scope_entered = {}
+ if inherit:
+ this_scope = self._local_scope_state[-1]
+ for name in inherit:
+ if name in this_scope:
+ scope_entered[name] = this_scope[name]
+ self._local_scope_state.append(scope_entered)
+
+ def exit_local_scope(self, keep=None):
+ """Marks exit from the current local scope.
+
+ Args:
+ keep: Optional iterable of variable names to copy into the
+ parent scope.
+ Returns:
+ A dict containing the scope that has just been exited.
+ """
+ scope_left = self._local_scope_state.pop()
+ if keep:
+ this_scope = self._local_scope_state[-1]
+ for name in keep:
+ if name in scope_left:
+ this_scope[name] = scope_left[name]
+ return scope_left
def set_local(self, name, value):
self._local_scope_state[-1][name] = value
@@ -91,16 +117,76 @@ class Base(gast.NodeTransformer):
print(pretty_printer.fmt(node))
return node
- def visit_block(self, nodes):
- """Helper equivalent to generic_visit, but for node lists."""
+ def visit_block(self, nodes, before_visit=None, after_visit=None):
+ """A more powerful version of generic_visit for statement blocks.
+
+ An example of a block is the body of an if statement.
+
+ This function allows specifying a postprocessing callback (the
+ after_visit argument) which can be used to move nodes to a new
+ destination. after_visit does this by returning a non-null
+ second return value, e.g. return new_node, new_destination.
+
+ For example, a transformer could perform the following move:
+
+ foo()
+ bar()
+ baz()
+
+ foo()
+ if cond:
+ bar()
+ baz()
+
+ The above could be done with a postprocessor of this kind:
+
+ def after_visit(node):
+ if node_is_function_call(bar):
+ new_container_node = build_cond()
+ new_container_node.body.append(node)
+ return new_container_node, new_container_node.body
+ else:
+ # Once we set a new destination, all subsequent items will be
+ # moved to it, so we don't need to explicitly handle baz.
+ return node, None
+
+ Args:
+ nodes: iterable of AST node objects
+ before_visit: optional callable that is called before visiting each item
+ in nodes
+ after_visit: optional callable that takes in an AST node and
+ returns a tuple (new_node, new_destination). It is called after
+ visiting each item in nodes. It is used in the same way as the
+ visit_* methods: new_node will replace the node; if not None,
+ new_destination must be a list, and subsequent nodes will be placed
+ in this list instead of the list returned by visit_block.
+ Returns:
+ A list of AST node objects containing the transformed items from nodes,
+ except those nodes that have been relocated using after_visit.
+ """
results = []
+ node_destination = results
for node in nodes:
+ if before_visit:
+ # TODO(mdan): We can modify node here too, if ever needed.
+ before_visit()
+
replacement = self.visit(node)
+
+ if after_visit and replacement:
+ replacement, new_destination = after_visit(replacement)
+ else:
+ new_destination = None
+
if replacement:
if isinstance(replacement, (list, tuple)):
- results.extend(replacement)
+ node_destination.extend(replacement)
else:
- results.append(replacement)
+ node_destination.append(replacement)
+
+ # Allow the postprocessor to reroute the remaining nodes to a new list.
+ if new_destination is not None:
+ node_destination = new_destination
return results
# TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
@@ -155,22 +241,39 @@ class Base(gast.NodeTransformer):
source_code = self.context.source_code
source_file = self.context.source_file
did_enter_function = False
- local_scope_state_size = len(self._local_scope_state)
+ local_scope_size_at_entry = len(self._local_scope_state)
try:
if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)):
- self._enclosing_entities.append(node)
did_enter_function = True
+ if did_enter_function:
+ self._enclosing_entities.append(node)
+
if source_code and hasattr(node, 'lineno'):
self._lineno = node.lineno
self._col_offset = node.col_offset
- if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
- return node
- return super(Base, self).visit(node)
- except (ValueError, AttributeError, KeyError, NotImplementedError,
- AssertionError) as e:
+ if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
+ result = super(Base, self).visit(node)
+
+ # On exception, the local scope integrity is not guaranteed.
+ if did_enter_function:
+ self._enclosing_entities.pop()
+
+ if local_scope_size_at_entry != len(self._local_scope_state):
+ raise AssertionError(
+ 'Inconsistent local scope stack. Before entering node %s, the'
+ ' stack had length %d, after exit it has length %d. This'
+ ' indicates enter_local_scope and exit_local_scope are not'
+ ' well paired.' % (
+ node,
+ local_scope_size_at_entry,
+ len(self._local_scope_state)
+ ))
+ return result
+
+ except (ValueError, AttributeError, KeyError, NotImplementedError) as e:
msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % (
e.__class__.__name__, str(e), try_ast_to_source(node),
pretty_printer.fmt(node, color=False))
@@ -178,18 +281,11 @@ class Base(gast.NodeTransformer):
line = source_code.splitlines()[self._lineno - 1]
else:
line = '<no source available>'
+ # TODO(mdan): Avoid the printing of the original exception.
+ # In other words, we need to find how to suppress the "During handling
+ # of the above exception, another exception occurred" message.
six.reraise(AutographParseError,
AutographParseError(
msg,
(source_file, self._lineno, self._col_offset + 1, line)),
sys.exc_info()[2])
- finally:
- if did_enter_function:
- self._enclosing_entities.pop()
-
- if local_scope_state_size != len(self._local_scope_state):
- raise AssertionError(
- 'Inconsistent local scope stack. Before entering node %s, the'
- ' stack had length %d, after exit it has length %d. This'
- ' indicates enter_local_scope and exit_local_scope are not'
- ' well paired.')
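A short sketch of the new inherit/keep semantics (runs inside a transformer.Base subclass that already has one scope on the stack; names are illustrative):

    # Parent scope holds the control variable name.
    self.set_local('control_var_name', 'continue_1')

    # Child scope starts with a copy of control_var_name.
    self.enter_local_scope(inherit=('control_var_name',))
    self.set_local('continue_used', True)

    # continue_used is propagated back to the parent scope; the inherited
    # copy of control_var_name is discarded with the child scope.
    self.exit_local_scope(keep=('continue_used',))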
diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/contrib/autograph/pyct/transformer_test.py
index 1f1adf4fbd..f110e79605 100644
--- a/tensorflow/contrib/autograph/pyct/transformer_test.py
+++ b/tensorflow/contrib/autograph/pyct/transformer_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import gast
+
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import context
from tensorflow.contrib.autograph.pyct import parser
@@ -27,7 +29,7 @@ from tensorflow.python.platform import test
class TransformerTest(test.TestCase):
- def _context_for_nodetesting(self):
+ def _context_for_testing(self):
return context.EntityContext(
namer=None,
source_code=None,
@@ -53,7 +55,7 @@ class TransformerTest(test.TestCase):
anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
return self.generic_visit(node)
- tr = TestTransformer(self._context_for_nodetesting())
+ tr = TestTransformer(self._context_for_testing())
def test_function():
a = 0
@@ -116,7 +118,7 @@ class TransformerTest(test.TestCase):
def visit_For(self, node):
return self._annotate_result(node)
- tr = TestTransformer(self._context_for_nodetesting())
+ tr = TestTransformer(self._context_for_testing())
def test_function(a):
"""Docstring."""
@@ -155,7 +157,7 @@ class TransformerTest(test.TestCase):
self.exit_local_scope()
return node
- tr = TestTransformer(self._context_for_nodetesting())
+ tr = TestTransformer(self._context_for_testing())
def no_exit(a):
if a > 0:
@@ -174,6 +176,38 @@ class TransformerTest(test.TestCase):
with self.assertRaises(AssertionError):
tr.visit(node)
+ def test_visit_block_postprocessing(self):
+
+ class TestTransformer(transformer.Base):
+
+ def _process_body_item(self, node):
+ if isinstance(node, gast.Assign) and (node.value.id == 'y'):
+ if_node = gast.If(gast.Name('x', gast.Load(), None), [node], [])
+ return if_node, if_node.body
+ return node, None
+
+ def visit_FunctionDef(self, node):
+ node.body = self.visit_block(
+ node.body, after_visit=self._process_body_item)
+ return node
+
+ def test_function(x, y):
+ z = x
+ z = y
+ return z
+
+ tr = TestTransformer(self._context_for_testing())
+
+ node, _ = parser.parse_entity(test_function)
+ node = tr.visit(node)
+ node = node.body[0]
+
+ self.assertEqual(len(node.body), 2)
+ self.assertTrue(isinstance(node.body[0], gast.Assign))
+ self.assertTrue(isinstance(node.body[1], gast.If))
+ self.assertTrue(isinstance(node.body[1].body[0], gast.Assign))
+ self.assertTrue(isinstance(node.body[1].body[1], gast.Return))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py
index 9ab124ae72..8c8c5acb31 100644
--- a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py
+++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py
@@ -53,6 +53,12 @@ class BlockCacheParams(object):
class ConfigureGcsHook(training.SessionRunHook):
"""ConfigureGcsHook configures GCS when used with Estimator/TPUEstimator.
+ Warning: GCS `credentials` may be transmitted over the network unencrypted.
+ Please ensure that the network is trusted before using this hook. For
+ users running code entirely within Google Cloud, your data is protected by
+ encryption in between data centers. For more information, please take a look
+ at https://cloud.google.com/security/encryption-in-transit/.
+
Example:
```
@@ -135,6 +141,12 @@ class ConfigureGcsHook(training.SessionRunHook):
def configure_gcs(session, credentials=None, block_cache=None, device=None):
"""Configures the GCS file system for a given a session.
+ Warning: GCS `credentials` may be transmitted over the network unencrypted.
+ Please ensure that the network is trusted before using this function. For
+ users running code entirely within Google Cloud, your data is protected by
+ encryption in between data centers. For more information, please take a look
+ at https://cloud.google.com/security/encryption-in-transit/.
+
Args:
session: A `tf.Session` session that should be used to configure the GCS
file system.
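A minimal usage sketch of configure_gcs, configuring only the block cache (credentials omitted per the warning above; the BlockCacheParams arguments are assumptions, not confirmed by this diff):

    import tensorflow as tf
    from tensorflow.contrib.cloud.python.ops import gcs_config_ops

    with tf.Session() as sess:
      gcs_config_ops.configure_gcs(
          sess,
          block_cache=gcs_config_ops.BlockCacheParams(max_bytes=128 * 1024 * 1024))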
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 03a937cd7f..1959ad028a 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -725,7 +725,7 @@ endif()
########################################################
# Parse tensorflow/tools/api/generator/BUILD to get list of generated files.
-FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/BUILD api_generator_BUILD_text)
+FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text)
STRING(REGEX MATCH "# BEGIN GENERATED FILES.*# END GENERATED FILES" api_init_files_text ${api_generator_BUILD_text})
string(REPLACE "# BEGIN GENERATED FILES" "" api_init_files_text ${api_init_files_text})
string(REPLACE "# END GENERATED FILES" "" api_init_files_text ${api_init_files_text})
@@ -736,7 +736,7 @@ foreach(api_init_file ${api_init_files_list})
string(STRIP "${api_init_file}" api_init_file)
if(api_init_file)
string(REPLACE "\"" "" api_init_file "${api_init_file}") # Remove quotes
- list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/${api_init_file}")
+ list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/${api_init_file}")
endif()
endforeach(api_init_file)
set(api_init_list_file "${tensorflow_source_dir}/api_init_files_list.txt")
@@ -749,18 +749,14 @@ add_custom_command(
# tensorflow/__init__.py depends on files generated in this step. So, remove it while
# this step is running since the files aren't there yet.
- COMMAND ${CMAKE_COMMAND} -E rename ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
- ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/final.__init__.py
- COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
+ COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
# Run create_python_api.py to generate API init files.
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE}
- "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" "${api_init_list_file}"
-
- # Re-add tensorflow/__init__.py back.
- COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
- COMMAND ${CMAKE_COMMAND} -E rename ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/final.__init__.py
- ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
+ "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py"
+ "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py"
+ "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow"
+ "${api_init_list_file}"
COMMENT "Generating __init__.py files for Python API."
WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python"
diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake
index 5942ff3363..eb9482dc25 100644
--- a/tensorflow/contrib/cmake/tf_tests.cmake
+++ b/tensorflow/contrib/cmake/tf_tests.cmake
@@ -212,6 +212,10 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/gmm_test.py"
# Disable following manual tag in BUILD.
"${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py"
+ # These tests depend on a .so file
+ "${tensorflow_source_dir}/tensorflow/python/kernel_tests/duplicate_op_test.py"
+ "${tensorflow_source_dir}/tensorflow/python/kernel_tests/invalid_op_test.py"
+ "${tensorflow_source_dir}/tensorflow/python/kernel_tests/ackermann_test.py"
)
if (WIN32)
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
index 76e54a284e..97cc0bc6c9 100644
--- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
@@ -133,7 +133,7 @@ class CSVDatasetOp : public DatasetOpKernel {
delim_(delim),
na_value_(std::move(na_value)) {}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::CSV")}));
@@ -145,7 +145,7 @@ class CSVDatasetOp : public DatasetOpKernel {
return output_shapes_;
}
- string DebugString() override { return "CSVDatasetOp::Dataset"; }
+ string DebugString() const override { return "CSVDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
index 48d3734162..6a12ca06f4 100644
--- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
@@ -91,7 +91,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
}
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::DirectedInterleave")}));
@@ -105,7 +105,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
return output_shapes_;
}
- string DebugString() override {
+ string DebugString() const override {
return strings::StrCat("DirectedInterleaveDatasetOp::Dataset");
}
@@ -130,15 +130,21 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- selector_input_impl_(params.dataset->selector_input_->MakeIterator(
- params.prefix + ".selector")),
- num_active_inputs_(params.dataset->data_inputs_.size()) {
- data_input_impls_.reserve(params.dataset->data_inputs_.size());
- for (size_t i = 0; i < params.dataset->data_inputs_.size(); ++i) {
- const DatasetBase* data_input = params.dataset->data_inputs_[i];
- data_input_impls_.push_back(data_input->MakeIterator(
- strings::StrCat(params.prefix, "[", i, "]")));
+ num_active_inputs_(params.dataset->data_inputs_.size()) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator(
+ ctx, strings::StrCat(prefix(), ".selector"),
+ &selector_input_impl_));
+ data_input_impls_.resize(dataset()->data_inputs_.size());
+ for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+ const DatasetBase* data_input = dataset()->data_inputs_[i];
+ TF_RETURN_IF_ERROR(data_input->MakeIterator(
+ ctx, strings::StrCat(prefix(), "[", i, "]"),
+ &data_input_impls_[i]));
}
+ return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx,
diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
index bb29df60e8..bbec50681c 100644
--- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
@@ -44,7 +44,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::IgnoreErrors")}));
@@ -57,7 +57,9 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
return input_->output_shapes();
}
- string DebugString() override { return "IgnoreErrorsDatasetOp::Dataset"; }
+ string DebugString() const override {
+ return "IgnoreErrorsDatasetOp::Dataset";
+ }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
@@ -72,8 +74,11 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
index 63e19ae3f8..3dfc3741c2 100644
--- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
@@ -127,7 +127,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
threadpool_->Unref();
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::ThreadPool")}));
@@ -140,7 +140,9 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
return input_->output_shapes();
}
- string DebugString() override { return "ThreadPoolDatasetOp::Dataset"; }
+ string DebugString() const override {
+ return "ThreadPoolDatasetOp::Dataset";
+ }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
@@ -154,8 +156,11 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/contrib/data/kernels/unique_dataset_op.cc
index 69fbb0fcdc..67c237799c 100644
--- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/unique_dataset_op.cc
@@ -56,7 +56,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Unique")}));
@@ -70,7 +70,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel {
return input_->output_shapes();
}
- string DebugString() override {
+ string DebugString() const override {
return strings::StrCat("UniqueDatasetOp::Dataset");
}
@@ -87,8 +87,11 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const typename Iterator::Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 285c77dea9..c483a43769 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -8,7 +8,7 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test", "tf_py_test")
py_test(
name = "batch_dataset_op_test",
- size = "large",
+ size = "medium",
srcs = ["batch_dataset_op_test.py"],
srcs_version = "PY2AND3",
tags = [
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index e309d611e1..b5fbc45ad3 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -553,14 +553,14 @@ class BatchDatasetTest(test.TestCase):
sess.run(next_element)
def testMapAndBatchParallelGetNext(self):
- iterator = (dataset_ops.Dataset.range(500000)
+ iterator = (dataset_ops.Dataset.range(50000)
.apply(batching.map_and_batch(lambda x: x, batch_size=100))
.make_one_shot_iterator())
elements = []
for _ in range(100):
elements.append(iterator.get_next())
with self.test_session() as sess:
- for i in range(50):
+ for i in range(5):
got = sess.run(elements)
got.sort(key=lambda x: x[0])
expected = []
@@ -572,7 +572,7 @@ class BatchDatasetTest(test.TestCase):
def testMapAndBatchParallelGetNextDropRemainder(self):
iterator = (
- dataset_ops.Dataset.range(499999).apply(
+ dataset_ops.Dataset.range(49999).apply(
batching.map_and_batch(
lambda x: x, batch_size=100, drop_remainder=True))
.make_one_shot_iterator())
@@ -580,7 +580,7 @@ class BatchDatasetTest(test.TestCase):
for _ in range(100):
elements.append(iterator.get_next())
with self.test_session() as sess:
- for i in range(49):
+ for i in range(4):
got = sess.run(elements)
got.sort(key=lambda x: x[0])
expected = []
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb
index d9a9bffbb4..84f1d031d4 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb
@@ -54,11 +54,41 @@
"source": [
"## Variables\n",
"\n",
- "Neural networks are characterized by a set of parameters (sometimes called \"weights\", sometimes called \"variables\") with fixed shapes and types, where the actual values are computed and adjusted during the training process. The `tfe.Variable` object encapsulates such parameters.\n",
- "\n",
- "Recall that `Tensor` objects are immutable, i.e., the underlying value of the `Tensor` cannot be changed. `Variable` objects act like `Tensor`s but are mutable via calls to `assign`, `assign_add` etc.\n",
+ "Tensors in TensorFlow are immutable stateless objects. Machine learning models, however, need to have changing state: as your model trains, the same code to compute predictions should behave differently over time (hopefully with a lower loss!). To represent this state which needs to change over the course of your computation, you can choose to rely on the fact that Python is a stateful programming language:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "VkJwtLS_Jbn8"
+ },
+ "outputs": [],
+ "source": [
+ "# Using python state\n",
+ "x = tf.zeros([10, 10])\n",
+ "x += 2 # This is equivalent to x = x + 2, which does not mutate the original\n",
+ " # value of x\n",
+ "print(x)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "wfneTXy7JcUz"
+ },
+ "source": [
+ "TensorFlow, however, has stateful operations built in, and these are often more pleasant to use than low-level Python representations of your state. To represent weights in a model, for example, it's often convenient and efficient to use TensorFlow variables.\n",
"\n",
- "For example:"
+ "A Variable is an object which stores a value and, when used in a TensorFlow computation, will implicitly read from this stored value. There are operations (`tf.assign_sub`, `tf.scatter_update`, etc) which manipulate the value stored in a TensorFlow variable."
]
},
{
@@ -92,6 +122,18 @@
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
+ "id": "-paSaeq1JzwC"
+ },
+ "source": [
+ "Computations using Variables are automatically traced when computing gradients. For Variables representing embeddings TensorFlow will do sparse updates by default, which are more computation and memory efficient.\n",
+ "\n",
+ "Using Variables is also a way to quickly let a reader of your code know that this piece of state is mutable."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
"id": "BMiFcDzE7Qu3"
},
"source": [
@@ -228,7 +270,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
@@ -331,7 +373,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb
new file mode 100644
index 0000000000..4fe3a0e3f3
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb
@@ -0,0 +1,551 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "pwX7Fii1rwsJ"
+ },
+ "outputs": [],
+ "source": [
+ "import tensorflow as tf\n",
+ "tf.enable_eager_execution()\n",
+ "tfe = tf.contrib.eager\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "UEu3q4jmpKVT"
+ },
+ "source": [
+ "# High level API\n",
+ "\n",
+ "We recommend using `tf.keras` as a high-level API for building neural networks. That said, most TensorFlow APIs are usable with eager execution.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "zSFfVVjkrrsI"
+ },
+ "source": [
+ "## Layers: common sets of useful operations\n",
+ "\n",
+ "Most of the time when writing code for machine learning models you want to operate at a higher level of abstraction than individual operations and manipulation of individual variables.\n",
+ "\n",
+ "Many machine learning models are expressible as the composition and stacking of relatively simple layers, and TensorFlow provides both a set of many common layers as a well as easy ways for you to write your own application-specific layers either from scratch or as the composition of existing layers.\n",
+ "\n",
+ "TensorFlow includes the full [Keras](https://keras.io) API in the tf.keras package, and the Keras layers are very useful when building your own models.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "8PyXlPl-4TzQ"
+ },
+ "outputs": [],
+ "source": [
+ "# In the tf.keras.layers package, layers are objects. To construct a layer,\n",
+ "# simply construct the object. Most layers take as a first argument the number\n",
+ "# of output dimensions / channels.\n",
+ "layer = tf.keras.layers.Dense(100)\n",
+ "# The number of input dimensionss is often unnecessary, as it can be inferred\n",
+ "# the first time the layer is used, but it can be provided if you want to \n",
+ "# specify it manually, which is useful in some complex models.\n",
+ "layer = tf.keras.layers.Dense(10, input_shape=(None, 5))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "Fn69xxPO5Psr"
+ },
+ "source": [
+ "The full list of pre-existing layers can be seen in [the documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers). It includes Dense (a fully-connected layer),\n",
+ "Conv2D, LSTM, BatchNormalization, Dropout, and many others."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ },
+ "height": 204
+ },
+ "colab_type": "code",
+ "executionInfo": {
+ "elapsed": 244,
+ "status": "ok",
+ "timestamp": 1527783641557,
+ "user": {
+ "displayName": "",
+ "photoUrl": "",
+ "userId": ""
+ },
+ "user_tz": 420
+ },
+ "id": "E3XKNknP5Mhb",
+ "outputId": "c5d52434-d980-4488-efa7-5660819d0207"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\u003ctf.Tensor: id=30, shape=(10, 10), dtype=float32, numpy=\n",
+ "array([[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
+ " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
+ " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
+ " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
+ " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
+ " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
+ " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
+ " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
+ " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
+ " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)\u003e"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# To use a layer, simply call it.\n",
+ "layer(tf.zeros([10, 5]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ },
+ "height": 221
+ },
+ "colab_type": "code",
+ "executionInfo": {
+ "elapsed": 320,
+ "status": "ok",
+ "timestamp": 1527783642457,
+ "user": {
+ "displayName": "",
+ "photoUrl": "",
+ "userId": ""
+ },
+ "user_tz": 420
+ },
+ "id": "Wt_Nsv-L5t2s",
+ "outputId": "f0d96dce-0128-4080-bfe2-0ee6fbc0ad90"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[\u003ctf.Variable 'dense_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n",
+ " array([[ 0.43788117, -0.62099844, -0.30525017, -0.59352523, 0.1783089 ,\n",
+ " 0.47078604, -0.23620895, -0.30482283, 0.01366901, -0.1288507 ],\n",
+ " [ 0.18407935, -0.56550485, 0.54180616, -0.42254075, 0.3702994 ,\n",
+ " 0.36705834, -0.29678228, 0.36660975, 0.36717761, 0.46269661],\n",
+ " [ 0.1709305 , -0.11529458, 0.32710236, 0.46300393, -0.62802851,\n",
+ " 0.51641601, 0.39624029, 0.26918125, -0.25196898, 0.21353298],\n",
+ " [ 0.35752094, 0.44161648, 0.61500639, -0.12653333, 0.41629118,\n",
+ " 0.36193585, 0.066082 , -0.59253877, 0.47318751, 0.17115968],\n",
+ " [-0.22554061, -0.17727301, 0.5525015 , 0.3678053 , -0.00454676,\n",
+ " 0.24066836, -0.53640735, 0.13792562, -0.10727292, 0.59708995]], dtype=float32)\u003e,\n",
+ " \u003ctf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)\u003e]"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Layers have many useful methods. For example, you can inspect all variables\n",
+ "# in a layer by calling layer.variables. In this case a fully-connected layer\n",
+ "# will have variables for weights and biases.\n",
+ "layer.variables"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ },
+ "height": 221
+ },
+ "colab_type": "code",
+ "executionInfo": {
+ "elapsed": 226,
+ "status": "ok",
+ "timestamp": 1527783643252,
+ "user": {
+ "displayName": "",
+ "photoUrl": "",
+ "userId": ""
+ },
+ "user_tz": 420
+ },
+ "id": "6ilvKjz8_4MQ",
+ "outputId": "f647fced-c2d7-41a3-c237-242036784665"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(\u003ctf.Variable 'dense_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n",
+ " array([[ 0.43788117, -0.62099844, -0.30525017, -0.59352523, 0.1783089 ,\n",
+ " 0.47078604, -0.23620895, -0.30482283, 0.01366901, -0.1288507 ],\n",
+ " [ 0.18407935, -0.56550485, 0.54180616, -0.42254075, 0.3702994 ,\n",
+ " 0.36705834, -0.29678228, 0.36660975, 0.36717761, 0.46269661],\n",
+ " [ 0.1709305 , -0.11529458, 0.32710236, 0.46300393, -0.62802851,\n",
+ " 0.51641601, 0.39624029, 0.26918125, -0.25196898, 0.21353298],\n",
+ " [ 0.35752094, 0.44161648, 0.61500639, -0.12653333, 0.41629118,\n",
+ " 0.36193585, 0.066082 , -0.59253877, 0.47318751, 0.17115968],\n",
+ " [-0.22554061, -0.17727301, 0.5525015 , 0.3678053 , -0.00454676,\n",
+ " 0.24066836, -0.53640735, 0.13792562, -0.10727292, 0.59708995]], dtype=float32)\u003e,\n",
+ " \u003ctf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)\u003e)"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# The variables are also accessible through nice accessors\n",
+ "layer.kernel, layer.bias"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "O0kDbE54-5VS"
+ },
+ "source": [
+ "## Implementing custom layers\n",
+ "The best way to implement your own layer is extending the tf.keras.Layer class and implementing:\n",
+ " * `__init__` , where you can do all input-independent initialization\n",
+ " * `build`, where you know the shapes of the input tensors and can do the rest of the initialization\n",
+ " * `call`, where you do the forward computation\n",
+ "\n",
+ "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`. However, the advantage of creating them in `build` is that it enables late variable creation based on the shape of the inputs the layer will operate on. On the other hand, creating variables in `__init__` would mean that shapes requires to create the variables will need to be explicitly specified."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ },
+ "height": 391
+ },
+ "colab_type": "code",
+ "executionInfo": {
+ "elapsed": 251,
+ "status": "ok",
+ "timestamp": 1527783661512,
+ "user": {
+ "displayName": "",
+ "photoUrl": "",
+ "userId": ""
+ },
+ "user_tz": 420
+ },
+ "id": "5Byl3n1k5kIy",
+ "outputId": "6e7f9285-649a-4132-82ce-73ea92f15862"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tf.Tensor(\n",
+ "[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
+ " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
+ " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
+ " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
+ " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
+ " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
+ " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
+ " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
+ " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
+ " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(10, 10), dtype=float32)\n",
+ "[\u003ctf.Variable 'my_dense_layer_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n",
+ "array([[-0.4011991 , 0.22458655, -0.33237562, -0.25117266, 0.33528614,\n",
+ " -0.01392961, 0.58580834, -0.16346583, 0.28465688, -0.47191954],\n",
+ " [-0.52922136, 0.22416979, -0.58209574, -0.60914612, 0.05226624,\n",
+ " -0.18325993, 0.5591442 , -0.24718609, 0.37148207, 0.40475875],\n",
+ " [ 0.16912812, -0.47618777, -0.38989353, 0.30105609, -0.08085585,\n",
+ " 0.44758242, 0.545829 , 0.51421839, 0.11063248, 0.20159996],\n",
+ " [ 0.34073615, -0.59835428, 0.06498981, -0.44489855, -0.34302285,\n",
+ " 0.20969599, 0.35527444, -0.03173476, -0.22227573, 0.09303057],\n",
+ " [ 0.41764337, -0.06435019, -0.52509922, -0.39957345, 0.56811184,\n",
+ " 0.23481232, -0.61666459, 0.31144124, -0.11532354, -0.42421889]], dtype=float32)\u003e]\n"
+ ]
+ }
+ ],
+ "source": [
+ "class MyDenseLayer(tf.keras.layers.Layer):\n",
+ " def __init__(self, num_outputs):\n",
+ " super(MyDenseLayer, self).__init__()\n",
+ " self.num_outputs = num_outputs\n",
+ " \n",
+ " def build(self, input_shape):\n",
+ " self.kernel = self.add_variable(\"kernel\", \n",
+ " shape=[input_shape[-1].value, \n",
+ " self.num_outputs])\n",
+ " \n",
+ " def call(self, input):\n",
+ " return tf.matmul(input, self.kernel)\n",
+ " \n",
+ "layer = MyDenseLayer(10)\n",
+ "print(layer(tf.zeros([10, 5])))\n",
+ "print(layer.variables)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "tk8E2vY0-z4Z"
+ },
+ "source": [
+ "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`.\n",
+ "\n",
+ "Overall code is easier to read and maintain if it uses standard layers whenever possible, as other readers will be familiar with the behavior of standard layers. If you want to use a layer which is not present in tf.keras.layers or tf.contrib.layers, consider filing a [github issue](http://github.com/tensorflow/tensorflow/issues/new) or, even better, sending us a pull request!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "Qhg4KlbKrs3G"
+ },
+ "source": [
+ "## Models: composing layers\n",
+ "\n",
+ "Many interesting layer-like things in machine learning models are implemented by composing existing layers. For example, each residual block in a resnet is a composition of convolutions, batch normalizations, and a shortcut.\n",
+ "\n",
+ "The main class used when creating a layer-like thing which contains other layers is tf.keras.Model. Implementing one is done by inheriting from tf.keras.Model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ },
+ "height": 190
+ },
+ "colab_type": "code",
+ "executionInfo": {
+ "elapsed": 420,
+ "status": "ok",
+ "timestamp": 1527783698512,
+ "user": {
+ "displayName": "",
+ "photoUrl": "",
+ "userId": ""
+ },
+ "user_tz": 420
+ },
+ "id": "N30DTXiRASlb",
+ "outputId": "a8b23a8e-5cf9-4bbf-f93b-6c763d74e2b3"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tf.Tensor(\n",
+ "[[[[ 0. 0. 0.]\n",
+ " [ 0. 0. 0.]\n",
+ " [ 0. 0. 0.]]\n",
+ "\n",
+ " [[ 0. 0. 0.]\n",
+ " [ 0. 0. 0.]\n",
+ " [ 0. 0. 0.]]]], shape=(1, 2, 3, 3), dtype=float32)\n",
+ "['resnet_identity_block_1/conv2d_3/kernel:0', 'resnet_identity_block_1/conv2d_3/bias:0', 'resnet_identity_block_1/batch_normalization_3/gamma:0', 'resnet_identity_block_1/batch_normalization_3/beta:0', 'resnet_identity_block_1/conv2d_4/kernel:0', 'resnet_identity_block_1/conv2d_4/bias:0', 'resnet_identity_block_1/batch_normalization_4/gamma:0', 'resnet_identity_block_1/batch_normalization_4/beta:0', 'resnet_identity_block_1/conv2d_5/kernel:0', 'resnet_identity_block_1/conv2d_5/bias:0', 'resnet_identity_block_1/batch_normalization_5/gamma:0', 'resnet_identity_block_1/batch_normalization_5/beta:0', 'resnet_identity_block_1/batch_normalization_3/moving_mean:0', 'resnet_identity_block_1/batch_normalization_3/moving_variance:0', 'resnet_identity_block_1/batch_normalization_4/moving_mean:0', 'resnet_identity_block_1/batch_normalization_4/moving_variance:0', 'resnet_identity_block_1/batch_normalization_5/moving_mean:0', 'resnet_identity_block_1/batch_normalization_5/moving_variance:0']\n"
+ ]
+ }
+ ],
+ "source": [
+ "class ResnetIdentityBlock(tf.keras.Model):\n",
+ " def __init__(self, kernel_size, filters):\n",
+ " super(ResnetIdentityBlock, self).__init__(name='')\n",
+ " filters1, filters2, filters3 = filters\n",
+ "\n",
+ " self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))\n",
+ " self.bn2a = tf.keras.layers.BatchNormalization()\n",
+ "\n",
+ " self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')\n",
+ " self.bn2b = tf.keras.layers.BatchNormalization()\n",
+ "\n",
+ " self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))\n",
+ " self.bn2c = tf.keras.layers.BatchNormalization()\n",
+ "\n",
+ " def call(self, input_tensor, training=False):\n",
+ " x = self.conv2a(input_tensor)\n",
+ " x = self.bn2a(x, training=training)\n",
+ " x = tf.nn.relu(x)\n",
+ "\n",
+ " x = self.conv2b(x)\n",
+ " x = self.bn2b(x, training=training)\n",
+ " x = tf.nn.relu(x)\n",
+ "\n",
+ " x = self.conv2c(x)\n",
+ " x = self.bn2c(x, training=training)\n",
+ "\n",
+ " x += input_tensor\n",
+ " return tf.nn.relu(x)\n",
+ "\n",
+ " \n",
+ "block = ResnetIdentityBlock(1, [1, 2, 3])\n",
+ "print(block(tf.zeros([1, 2, 3, 3])))\n",
+ "print([x.name for x in block.variables])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "wYfucVw65PMj"
+ },
+ "source": [
+ "Much of the time, however, models which compose many layers simply call one layer after the other. This can be done in very little code using tf.keras.Sequential"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ },
+ "base_uri": "https://localhost:8080/",
+ "height": 153
+ },
+ "colab_type": "code",
+ "executionInfo": {
+ "elapsed": 361,
+ "status": "ok",
+ "timestamp": 1526674830777,
+ "user": {
+ "displayName": "Alexandre Passos",
+ "photoUrl": "//lh4.googleusercontent.com/-kmTTWXEgAPw/AAAAAAAAAAI/AAAAAAAAAC0/q_DoOzKGwds/s50-c-k-no/photo.jpg",
+ "userId": "108023195365833072773"
+ },
+ "user_tz": 420
+ },
+ "id": "L9frk7Ur4uvJ",
+ "outputId": "882e9076-b6d9-4380-bb1e-7c6b57d54c39"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\u003ctf.Tensor: id=1423, shape=(1, 2, 3, 3), dtype=float32, numpy=\n",
+ "array([[[[0., 0., 0.],\n",
+ " [0., 0., 0.],\n",
+ " [0., 0., 0.]],\n",
+ "\n",
+ " [[0., 0., 0.],\n",
+ " [0., 0., 0.],\n",
+ " [0., 0., 0.]]]], dtype=float32)\u003e"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ " my_seq = tf.keras.Sequential([tf.keras.layers.Conv2D(1, (1, 1)),\n",
+ " tf.keras.layers.BatchNormalization(),\n",
+ " tf.keras.layers.Conv2D(2, 1, \n",
+ " padding='same'),\n",
+ " tf.keras.layers.BatchNormalization(),\n",
+ " tf.keras.layers.Conv2D(3, (1, 1)),\n",
+ " tf.keras.layers.BatchNormalization()])\n",
+ "my_seq(tf.zeros([1, 2, 3, 3]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "c5YwYcnuK-wc"
+ },
+ "source": [
+ "# Next steps\n",
+ "\n",
+ "Now you can go back to the previous notebook and adapt the linear regression example to use layers and models to be better structured."
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "default_view": {},
+ "name": "4 - High level API - TensorFlow Eager.ipynb",
+ "provenance": [],
+ "version": "0.3.2",
+ "views": {}
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
index 09745e2de5..8f73274c2a 100644
--- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
@@ -197,7 +197,8 @@ class WALSModel(object):
row_weights=1,
col_weights=1,
use_factors_weights_cache=True,
- use_gramian_cache=True):
+ use_gramian_cache=True,
+ use_scoped_vars=False):
"""Creates model for WALS matrix factorization.
Args:
@@ -239,6 +240,8 @@ class WALSModel(object):
weights cache to take effect.
use_gramian_cache: When True, the Gramians will be cached on the workers
before the updates start. Defaults to True.
+ use_scoped_vars: When True, the factor and weight vars will also be nested
+ in a tf.name_scope.
"""
self._input_rows = input_rows
self._input_cols = input_cols
@@ -251,18 +254,36 @@ class WALSModel(object):
regularization * linalg_ops.eye(self._n_components)
if regularization is not None else None)
assert (row_weights is None) == (col_weights is None)
- self._row_weights = WALSModel._create_weights(
- row_weights, self._input_rows, self._num_row_shards, "row_weights")
- self._col_weights = WALSModel._create_weights(
- col_weights, self._input_cols, self._num_col_shards, "col_weights")
self._use_factors_weights_cache = use_factors_weights_cache
self._use_gramian_cache = use_gramian_cache
- self._row_factors = self._create_factors(
- self._input_rows, self._n_components, self._num_row_shards, row_init,
- "row_factors")
- self._col_factors = self._create_factors(
- self._input_cols, self._n_components, self._num_col_shards, col_init,
- "col_factors")
+
+ if use_scoped_vars:
+ with ops.name_scope("row_weights"):
+ self._row_weights = WALSModel._create_weights(
+ row_weights, self._input_rows, self._num_row_shards, "row_weights")
+ with ops.name_scope("col_weights"):
+ self._col_weights = WALSModel._create_weights(
+ col_weights, self._input_cols, self._num_col_shards, "col_weights")
+ with ops.name_scope("row_factors"):
+ self._row_factors = self._create_factors(
+ self._input_rows, self._n_components, self._num_row_shards,
+ row_init, "row_factors")
+ with ops.name_scope("col_factors"):
+ self._col_factors = self._create_factors(
+ self._input_cols, self._n_components, self._num_col_shards,
+ col_init, "col_factors")
+ else:
+ self._row_weights = WALSModel._create_weights(
+ row_weights, self._input_rows, self._num_row_shards, "row_weights")
+ self._col_weights = WALSModel._create_weights(
+ col_weights, self._input_cols, self._num_col_shards, "col_weights")
+ self._row_factors = self._create_factors(
+ self._input_rows, self._n_components, self._num_row_shards, row_init,
+ "row_factors")
+ self._col_factors = self._create_factors(
+ self._input_cols, self._n_components, self._num_col_shards, col_init,
+ "col_factors")
+
self._row_gramian = self._create_gramian(self._n_components, "row_gramian")
self._col_gramian = self._create_gramian(self._n_components, "col_gramian")
with ops.name_scope("row_prepare_gramian"):
@@ -313,37 +334,36 @@ class WALSModel(object):
@classmethod
def _create_factors(cls, rows, cols, num_shards, init, name):
"""Helper function to create row and column factors."""
- with ops.name_scope(name):
- if callable(init):
- init = init()
- if isinstance(init, list):
- assert len(init) == num_shards
- elif isinstance(init, str) and init == "random":
- pass
- elif num_shards == 1:
- init = [init]
- sharded_matrix = []
- sizes = cls._shard_sizes(rows, num_shards)
- assert len(sizes) == num_shards
-
- def make_initializer(i, size):
-
- def initializer():
- if init == "random":
- return random_ops.random_normal([size, cols])
- else:
- return init[i]
+ if callable(init):
+ init = init()
+ if isinstance(init, list):
+ assert len(init) == num_shards
+ elif isinstance(init, str) and init == "random":
+ pass
+ elif num_shards == 1:
+ init = [init]
+ sharded_matrix = []
+ sizes = cls._shard_sizes(rows, num_shards)
+ assert len(sizes) == num_shards
+
+ def make_initializer(i, size):
- return initializer
+ def initializer():
+ if init == "random":
+ return random_ops.random_normal([size, cols])
+ else:
+ return init[i]
- for i, size in enumerate(sizes):
- var_name = "%s_shard_%d" % (name, i)
- var_init = make_initializer(i, size)
- sharded_matrix.append(
- variable_scope.variable(
- var_init, dtype=dtypes.float32, name=var_name))
+ return initializer
- return sharded_matrix
+ for i, size in enumerate(sizes):
+ var_name = "%s_shard_%d" % (name, i)
+ var_init = make_initializer(i, size)
+ sharded_matrix.append(
+ variable_scope.variable(
+ var_init, dtype=dtypes.float32, name=var_name))
+
+ return sharded_matrix
@classmethod
def _create_weights(cls, wt_init, num_wts, num_shards, name):
@@ -384,26 +404,25 @@ class WALSModel(object):
sizes = cls._shard_sizes(num_wts, num_shards)
assert len(sizes) == num_shards
- with ops.name_scope(name):
- def make_wt_initializer(i, size):
+ def make_wt_initializer(i, size):
- def initializer():
- if init_mode == "scalar":
- return wt_init * array_ops.ones([size])
- else:
- return wt_init[i]
+ def initializer():
+ if init_mode == "scalar":
+ return wt_init * array_ops.ones([size])
+ else:
+ return wt_init[i]
- return initializer
+ return initializer
- sharded_weight = []
- for i, size in enumerate(sizes):
- var_name = "%s_shard_%d" % (name, i)
- var_init = make_wt_initializer(i, size)
- sharded_weight.append(
- variable_scope.variable(
- var_init, dtype=dtypes.float32, name=var_name))
+ sharded_weight = []
+ for i, size in enumerate(sizes):
+ var_name = "%s_shard_%d" % (name, i)
+ var_init = make_wt_initializer(i, size)
+ sharded_weight.append(
+ variable_scope.variable(
+ var_init, dtype=dtypes.float32, name=var_name))
- return sharded_weight
+ return sharded_weight
@staticmethod
def _create_gramian(n_components, name):
diff --git a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc
index a4cd4a2cc4..2638b25ec4 100644
--- a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc
+++ b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc
@@ -64,7 +64,7 @@ class KafkaDatasetOp : public DatasetOpKernel {
eof_(eof),
timeout_(timeout) {}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Kafka")}));
@@ -81,7 +81,7 @@ class KafkaDatasetOp : public DatasetOpKernel {
return *shapes;
}
- string DebugString() override { return "KafkaDatasetOp::Dataset"; }
+ string DebugString() const override { return "KafkaDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 0fdbe8f630..b56a88659b 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -284,6 +284,7 @@ py_test(
tags = [
"manual",
"noasan", # times out
+ "optonly", # test is flaky without optimization.
],
deps = [
":learn",
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index c8820ab29b..b9e40cc50c 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -239,6 +239,7 @@ def generated_test_models():
"softmax",
"space_to_batch_nd",
"space_to_depth",
+ "sparse_to_dense",
"split",
"squeeze",
"strided_slice",
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index 8660c653ae..52ab9ee640 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -236,6 +236,10 @@ typedef struct {
int stride_height;
} TfLiteTransposeConvParams;
+typedef struct {
+ bool validate_indices;
+} TfLiteSparseToDenseParams;
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index 24a9b0f6b8..c797e3589a 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -93,6 +93,7 @@ typedef enum {
kTfLiteBuiltinSlice = 65,
kTfLiteBuiltinSin = 66,
kTfLiteBuiltinTransposeConv = 67,
+ kTfLiteBuiltinSparseToDense = 68,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
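Taken together, the two additions above define everything a parser needs to describe the new op: a parameter block (`TfLiteSparseToDenseParams`) and a builtin code (`kTfLiteBuiltinSparseToDense = 68`). A hedged sketch of how a dispatcher might consume them follows; `ResolveParams` is a hypothetical helper, not a TF Lite API.

```
#include <cstdio>

// Copied from the hunks above.
typedef struct {
  bool validate_indices;
} TfLiteSparseToDenseParams;

enum { kTfLiteBuiltinSparseToDense = 68 };

// Hypothetical helper: build the per-op parameter block for a builtin code.
void* ResolveParams(int builtin_code) {
  switch (builtin_code) {
    case kTfLiteBuiltinSparseToDense: {
      static TfLiteSparseToDenseParams params;
      params.validate_indices = true;  // in practice, read from the model
      return &params;
    }
    default:
      return nullptr;
  }
}

int main() {
  auto* params = static_cast<TfLiteSparseToDenseParams*>(
      ResolveParams(kTfLiteBuiltinSparseToDense));
  std::printf("validate_indices=%d\n", params->validate_indices);
}
```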
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
index d8c46e6331..b2f6444e9e 100644
--- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -607,6 +607,21 @@ Outputs {
}
```
+**SPARSE_TO_DENSE**
+
+```
+Inputs {
+ 0: 0D or 1D or 2D tensor
+ 1: 1D tensor
+ 2: 0D or 1D tensor
+ 3: 0D tensor
+ 4: a boolean value
+}
+Outputs {
+ 0: Dense Tensor of shape output_shape. Has the same type as sparse_values.
+}
+```
+
**SPLIT**
```
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index b7291dd379..0af659b5ca 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -170,6 +170,7 @@ cc_library(
"slice.cc",
"space_to_batch_nd.cc",
"space_to_depth.cc",
+ "sparse_to_dense.cc",
"split.cc",
"squeeze.cc",
"strided_slice.cc",
@@ -934,6 +935,19 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "sparse_to_dense_test",
+ size = "small",
+ srcs = ["sparse_to_dense_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc
index 7dc0c5656d..c09b15b3d2 100644
--- a/tensorflow/contrib/lite/kernels/basic_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc
@@ -36,7 +36,7 @@ constexpr int kOutputTensor = 1;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* scratch_tensor_index = new int;
- context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index);
+ context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index);
return scratch_tensor_index;
}
@@ -91,7 +91,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (input->type == kTfLiteFloat32 && input_weights->type == kTfLiteUInt8) {
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(2);
+ node->temporaries = TfLiteIntArrayCreate(3);
node->temporaries->data[0] = *scratch_tensor_index;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
input_quantized->type = kTfLiteUInt8;
@@ -114,6 +114,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ResizeTensor(context, hidden_state_quantized,
hidden_state_quantized_size));
}
+ node->temporaries->data[2] = *scratch_tensor_index + 2;
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = batch_size;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
}
return kTfLiteOk;
@@ -145,14 +155,14 @@ TfLiteStatus EvalFloat(const TfLiteTensor* input,
return kTfLiteOk;
}
-TfLiteStatus EvalQuantized(const TfLiteTensor* input,
- const TfLiteTensor* input_weights,
- const TfLiteTensor* recurrent_weights,
- const TfLiteTensor* bias,
- const TfLiteRNNParams* params,
- TfLiteTensor* input_scratch,
- TfLiteTensor* hidden_state_scratch,
- TfLiteTensor* hidden_state, TfLiteTensor* output) {
+TfLiteStatus EvalHybrid(const TfLiteTensor* input,
+ const TfLiteTensor* input_weights,
+ const TfLiteTensor* recurrent_weights,
+ const TfLiteTensor* bias, const TfLiteRNNParams* params,
+ TfLiteTensor* input_scratch,
+ TfLiteTensor* hidden_state_scratch,
+ TfLiteTensor* scaling_factors,
+ TfLiteTensor* hidden_state, TfLiteTensor* output) {
const int batch_size = input->dims->data[0];
const int num_units = input_weights->dims->data[0];
const int input_size = input->dims->data[1];
@@ -176,12 +186,14 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input,
reinterpret_cast<int8_t*>(input_scratch->data.uint8);
int8_t* quantized_hidden_state_ptr =
reinterpret_cast<int8_t*>(hidden_state_scratch->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
kernel_utils::RnnBatchStep(
input_ptr_batch, input_weights_ptr, input_weights_scale,
recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
num_units, batch_size, params->activation, quantized_input_ptr,
- quantized_hidden_state_ptr, hidden_state_ptr_batch, output_ptr_batch);
+ quantized_hidden_state_ptr, scaling_factors_ptr, hidden_state_ptr_batch,
+ output_ptr_batch);
return kTfLiteOk;
}
@@ -205,9 +217,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// TODO(mirkov): implement eval with quantized inputs as well.
TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
- return EvalQuantized(input, input_weights, recurrent_weights, bias,
- params, input_quantized, hidden_state_quantized,
- hidden_state, output);
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, 2);
+ return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
+ input_quantized, hidden_state_quantized,
+ scaling_factors, hidden_state, output);
}
default:
context->ReportError(context, "Type %d not currently supported.",
diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h
index ede95dfee0..b86ca49c11 100644
--- a/tensorflow/contrib/lite/kernels/internal/common.h
+++ b/tensorflow/contrib/lite/kernels/internal/common.h
@@ -87,12 +87,12 @@ float ActivationFunction(float x) {
output_activation_max);
}
-inline int32 MultiplyByQuantizedMultiplierSmallerThanOne(
- int32 x, int32 quantized_multiplier, int right_shift) {
+inline int32 MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ int32 x, int32 quantized_multiplier, int left_shift) {
using gemmlowp::RoundingDivideByPOT;
using gemmlowp::SaturatingRoundingDoublingHighMul;
return RoundingDivideByPOT(
- SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift);
+ SaturatingRoundingDoublingHighMul(x, quantized_multiplier), -left_shift);
}
inline int32 MultiplyByQuantizedMultiplierGreaterThanOne(
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
index 3bbaaa6a9d..67e3810479 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
@@ -52,7 +52,8 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
TfLiteFusedActivation activation,
int8_t* quantized_input_ptr_batch,
int8_t* quantized_hidden_state_ptr_batch,
- float* hidden_state_ptr_batch, float* output_ptr_batch) {
+ float* scaling_factors, float* hidden_state_ptr_batch,
+ float* output_ptr_batch) {
// Output = bias
tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
output_ptr_batch);
@@ -62,7 +63,6 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
// Quantize input from float to uint8 + quantization params (scaling
// factor).
float unused_min, unused_max;
- float* scaling_factors = new float[batch_size];
for (int b = 0; b < batch_size; ++b) {
const int offset = b * input_size;
tensor_utils::SymmetricQuantizeFloats(
@@ -76,7 +76,6 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_weights_ptr, num_units, input_size, quantized_input_ptr_batch,
scaling_factors, batch_size, output_ptr_batch, /*result_stride=*/1);
- delete[] scaling_factors;
}
// Save quantization and matmul computation for all zero input.
@@ -84,7 +83,6 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
batch_size * num_units)) {
// Quantize hidden_state
float unused_min, unused_max;
- float* scaling_factors = new float[batch_size];
for (int b = 0; b < batch_size; ++b) {
const int offset = b * num_units;
tensor_utils::SymmetricQuantizeFloats(
@@ -99,7 +97,6 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
recurrent_weights_ptr, num_units, num_units,
quantized_hidden_state_ptr_batch, scaling_factors, batch_size,
output_ptr_batch, /*result_stride=*/1);
- delete[] scaling_factors;
}
// Output = activation(Output) and update hidden_state
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
index cbfbcbeefc..f3f42f0840 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
@@ -41,6 +41,9 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
// values of hidden_state_ptr_batch and input_ptr_batch, respectively.
// These temporary storages are expected to be preallocated to the same size as
// the respective pointers.
+// An additional preallocated temporary storage 'scaling_factors' (of size
+// batch_size) is used to store the per-batch scaling factors of the
+// quantization (used for recovery of the float values).
// {input,recurrent}_weights_scale params are used for dequantization/recovery.
void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
float input_weights_scale,
@@ -50,7 +53,8 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
TfLiteFusedActivation activation,
int8_t* quantized_input_ptr_batch,
int8_t* quantized_hidden_state_ptr_batch,
- float* hidden_state_ptr_batch, float* output_ptr_batch);
+ float* scaling_factors, float* hidden_state_ptr_batch,
+ float* output_ptr_batch);
// Performs an LSTM batch inference step for input specified by input_ptr_batch.
// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
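The updated contract above means `RnnBatchStep` no longer allocates its scaling factors internally (the deleted `new float[batch_size]` / `delete[]` pairs in kernel_utils.cc); the caller preallocates one float per batch entry and passes it in. A minimal sketch of that ownership arrangement, with the quantization itself reduced to a placeholder:

```
#include <cstdint>
#include <vector>

// Placeholder for the real symmetric quantization: records one scaling
// factor per batch entry (assumption: the actual signature differs).
void QuantizeBatch(const float* input, int batch_size, int input_size,
                   int8_t* quantized, float* scaling_factors) {
  for (int b = 0; b < batch_size; ++b) {
    scaling_factors[b] = 1.0f;  // placeholder per-batch scale
    for (int i = 0; i < input_size; ++i) {
      quantized[b * input_size + i] =
          static_cast<int8_t>(input[b * input_size + i]);
    }
  }
}

int main() {
  const int batch_size = 4, input_size = 8;
  std::vector<float> input(batch_size * input_size, 0.5f);
  std::vector<int8_t> quantized(batch_size * input_size);
  // Preallocated once by the caller (in TF Lite, as an arena-backed
  // temporary tensor) instead of new/delete on every step.
  std::vector<float> scaling_factors(batch_size);
  QuantizeBatch(input.data(), batch_size, input_size, quantized.data(),
                scaling_factors.data());
}
```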
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index d48178d608..f7011b28fd 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -51,6 +51,13 @@ using reference_ops::LessEqual;
using reference_ops::RankOneSelect;
using reference_ops::Select;
+// TODO(b/80247582) Remove this constant.
+// This will be phased out as the shifts are revised with more thought. Use of a
+// constant enables us to track progress on this work.
+//
+// Used mainly to convert from old-style shifts (right) to new-style (left).
+static constexpr int kReverseShift = -1;
+
// Make a local VectorMap typedef allowing to map a float array
// as a Eigen vector expression. The std::conditional here is to
// construct the suitable Eigen type for the constness of the
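The renamed helper takes a left-shift exponent, so legacy call sites that still hold right shifts multiply by `kReverseShift` to flip the sign. Below is a self-contained sketch of just the final rounding-divide step (the `SaturatingRoundingDoublingHighMul` multiplier stage is omitted); the divide is assumed to mirror gemmlowp's round-to-nearest behavior.

```
#include <cassert>
#include <cstdint>

// Round-to-nearest division by 2^exponent (assumption: matches
// gemmlowp::RoundingDivideByPOT for the inputs exercised here).
int32_t RoundingDivideByPOT(int32_t x, int exponent) {
  assert(exponent >= 0 && exponent < 31);
  const int32_t mask = (int32_t{1} << exponent) - 1;
  const int32_t remainder = x & mask;
  const int32_t threshold = (mask >> 1) + ((x < 0) ? 1 : 0);
  return (x >> exponent) + ((remainder > threshold) ? 1 : 0);
}

static constexpr int kReverseShift = -1;

// New-style helper: the exponent is a *left* shift, so a fractional
// multiplier corresponds to a negative value.
int32_t ApplyLeftShift(int32_t x, int left_shift) {
  return RoundingDivideByPOT(x, -left_shift);
}

int main() {
  const int old_right_shift = 3;
  // Negating an old right shift via kReverseShift reproduces the old result.
  assert(ApplyLeftShift(100, kReverseShift * old_right_shift) ==
         RoundingDivideByPOT(100, old_right_shift));  // both 13 (100/8 = 12.5)
  return 0;
}
```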
@@ -2417,8 +2424,8 @@ inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
for (int c = 0; c < depth; c++) {
int32 diff = *input_data - input_zero_point;
- int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne(
- 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
+ int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ 128 * diff, inv_l2norm_multiplier, kReverseShift * inv_l2norm_shift);
int32 unclamped_output_val = 128 + rescaled_diff;
int32 output_val = std::min(255, std::max(0, unclamped_output_val));
*output_data = static_cast<uint8>(output_val);
@@ -2560,14 +2567,19 @@ inline void AddElementwise(int size, int left_shift, const uint8* input1_data,
const int32 input2_val = input2_offset + input2_data[i];
const int32 shifted_input1_val = input1_val * (1 << left_shift);
const int32 shifted_input2_val = input2_val * (1 << left_shift);
- const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input1_val, input1_multiplier, input1_shift);
- const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input2_val, input2_multiplier, input2_shift);
+ const int32 scaled_input1_val =
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input1_val, input1_multiplier,
+ kReverseShift * input1_shift);
+ const int32 scaled_input2_val =
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input2_val, input2_multiplier,
+ kReverseShift * input2_shift);
const int32 raw_sum = scaled_input1_val + scaled_input2_val;
- const int32 raw_output = MultiplyByQuantizedMultiplierSmallerThanOne(
- raw_sum, output_multiplier, output_shift) +
- output_offset;
+ const int32 raw_output =
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ raw_sum, output_multiplier, kReverseShift * output_shift) +
+ output_offset;
const int32 clamped_output = std::min(
output_activation_max, std::max(output_activation_min, raw_output));
output_data[i] = static_cast<uint8>(clamped_output);
@@ -2786,15 +2798,17 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data,
const int32 shifted_input1_val = input1_val * (1 << left_shift);
const int32 shifted_input2_val = input2_val * (1 << left_shift);
const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input1_val, input1_multiplier, input1_shift);
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input1_val, input1_multiplier,
+ kReverseShift * input1_shift);
const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input2_val, input2_multiplier, input2_shift);
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input2_val, input2_multiplier,
+ kReverseShift * input2_shift);
const int32 raw_sum = scaled_input1_val + scaled_input2_val;
const int32 raw_output =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- raw_sum, output_multiplier, output_shift) +
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ raw_sum, output_multiplier, kReverseShift * output_shift) +
output_offset;
const int32 clamped_output =
std::min(output_activation_max,
@@ -3135,9 +3149,9 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
const int32 input2_val =
input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
const int32 unclamped_result =
- output_offset +
- MultiplyByQuantizedMultiplierSmallerThanOne(
- input1_val * input2_val, output_multiplier, output_shift);
+ output_offset + MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ input1_val * input2_val, output_multiplier,
+ kReverseShift * output_shift);
const int32 clamped_output =
std::min(output_activation_max,
std::max(output_activation_min, unclamped_result));
@@ -3319,15 +3333,17 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data,
const int32 shifted_input1_val = input1_val * (1 << left_shift);
const int32 shifted_input2_val = input2_val * (1 << left_shift);
const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input1_val, input1_multiplier, input1_shift);
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input1_val, input1_multiplier,
+ kReverseShift * input1_shift);
const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input2_val, input2_multiplier, input2_shift);
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input2_val, input2_multiplier,
+ kReverseShift * input2_shift);
const int32 raw_sub = scaled_input1_val - scaled_input2_val;
const int32 raw_output =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- raw_sub, output_multiplier, output_shift) +
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ raw_sub, output_multiplier, kReverseShift * output_shift) +
output_offset;
const int32 clamped_output =
std::min(output_activation_max,
@@ -4782,9 +4798,9 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
fixed_log_sum_of_exps + std::numeric_limits<int32>::lowest();
const int adjusted_diff_min =
std::max(diff_min - 1, // Note use of > below instead of >= above.
- MultiplyByQuantizedMultiplierSmallerThanOne(
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
rescaled_diff_min, reverse_scaling_divisor,
- reverse_scaling_right_shift));
+ kReverseShift * reverse_scaling_right_shift));
for (int c = 0; c < depth; ++c) {
int32 input_diff = static_cast<int32>(block_input_data[c]) - max_in_row;
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 48a96f7db0..bebc97309e 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -98,20 +98,12 @@ gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
namespace reference_ops {
-inline int32 MultiplyByQuantizedMultiplierSmallerThanOne(
- int32 x, int32 quantized_multiplier, int right_shift) {
- using gemmlowp::RoundingDivideByPOT;
- using gemmlowp::SaturatingRoundingDoublingHighMul;
- return RoundingDivideByPOT(
- SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift);
-}
-
-inline int32 MultiplyByQuantizedMultiplierGreaterThanOne(
- int32 x, int32 quantized_multiplier, int left_shift) {
- using gemmlowp::SaturatingRoundingDoublingHighMul;
- return SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
- quantized_multiplier);
-}
+// TODO(b/80247582) Remove this constant.
+// This will be phased out as the shifts are revised with more thought. Use of a
+// constant enables us to track progress on this work.
+//
+// Used mainly to convert from old-style shifts (right) to new-style (left).
+static constexpr int kReverseShift = -1;
template <typename T>
int CountLeadingZeros(T integer_input) {
@@ -422,8 +414,8 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
if (bias_data) {
acc += bias_data[Offset(bias_dims, out_channel, 0, 0, 0)];
}
- acc = MultiplyByQuantizedMultiplierSmallerThanOne(
- acc, output_multiplier, output_shift);
+ acc = MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ acc, output_multiplier, kReverseShift * output_shift);
acc += output_offset;
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
@@ -646,8 +638,8 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
if (bias_data) {
acc += bias_data[Offset(bias_dims, out_c, 0, 0, 0)];
}
- acc = MultiplyByQuantizedMultiplierSmallerThanOne(acc, output_multiplier,
- output_shift);
+ acc = MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ acc, output_multiplier, kReverseShift * output_shift);
acc += output_offset;
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
@@ -1041,8 +1033,8 @@ inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
for (int c = 0; c < depth; c++) {
int32 diff =
input_data[Offset(input_dims, c, i, 0, 0)] - input_zero_point;
- int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne(
- 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
+ int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ 128 * diff, inv_l2norm_multiplier, kReverseShift * inv_l2norm_shift);
int32 unclamped_output_val = 128 + rescaled_diff;
int32 output_val = std::min(255, std::max(0, unclamped_output_val));
output_data[Offset(output_dims, c, i, 0, 0)] =
@@ -1113,15 +1105,17 @@ inline void Add(int left_shift, const uint8* input1_data,
const int32 shifted_input1_val = input1_val * (1 << left_shift);
const int32 shifted_input2_val = input2_val * (1 << left_shift);
const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input1_val, input1_multiplier, input1_shift);
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input1_val, input1_multiplier,
+ kReverseShift * input1_shift);
const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input2_val, input2_multiplier, input2_shift);
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input2_val, input2_multiplier,
+ kReverseShift * input2_shift);
const int32 raw_sum = scaled_input1_val + scaled_input2_val;
const int32 raw_output =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- raw_sum, output_multiplier, output_shift) +
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ raw_sum, output_multiplier, kReverseShift * output_shift) +
output_offset;
const int32 clamped_output =
std::min(output_activation_max,
@@ -1267,15 +1261,17 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data,
const int32 shifted_input1_val = input1_val * (1 << left_shift);
const int32 shifted_input2_val = input2_val * (1 << left_shift);
const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input1_val, input1_multiplier, input1_shift);
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input1_val, input1_multiplier,
+ kReverseShift * input1_shift);
const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input2_val, input2_multiplier, input2_shift);
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input2_val, input2_multiplier,
+ kReverseShift * input2_shift);
const int32 raw_sum = scaled_input1_val + scaled_input2_val;
const int32 raw_output =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- raw_sum, output_multiplier, output_shift) +
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ raw_sum, output_multiplier, kReverseShift * output_shift) +
output_offset;
const int32 clamped_output =
std::min(output_activation_max,
@@ -1320,15 +1316,17 @@ inline void BroadcastAddFivefold(
const int32 shifted_input1_val = input1_val * (1 << left_shift);
const int32 shifted_input2_val = input2_val * (1 << left_shift);
const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input1_val, input1_multiplier, input1_shift);
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input1_val, input1_multiplier,
+ kReverseShift * input1_shift);
const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input2_val, input2_multiplier, input2_shift);
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input2_val, input2_multiplier,
+ kReverseShift * input2_shift);
const int32 raw_sum = scaled_input1_val + scaled_input2_val;
const int32 raw_output =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- raw_sum, output_multiplier, output_shift) +
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ raw_sum, output_multiplier, kReverseShift * output_shift) +
output_offset;
const int32 clamped_output =
std::min(output_activation_max,
@@ -1508,9 +1506,9 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
const int32 input2_val =
input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
const int32 unclamped_result =
- output_offset +
- MultiplyByQuantizedMultiplierSmallerThanOne(
- input1_val * input2_val, output_multiplier, output_shift);
+ output_offset + MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ input1_val * input2_val, output_multiplier,
+ kReverseShift * output_shift);
const int32 clamped_output =
std::min(output_activation_max,
std::max(output_activation_min, unclamped_result));
@@ -1724,15 +1722,17 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data,
const int32 shifted_input1_val = input1_val * (1 << left_shift);
const int32 shifted_input2_val = input2_val * (1 << left_shift);
const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input1_val, input1_multiplier, input1_shift);
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input1_val, input1_multiplier,
+ kReverseShift * input1_shift);
const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input2_val, input2_multiplier, input2_shift);
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input2_val, input2_multiplier,
+ kReverseShift * input2_shift);
const int32 raw_sub = scaled_input1_val - scaled_input2_val;
const int32 raw_output =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- raw_sub, output_multiplier, output_shift) +
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ raw_sub, output_multiplier, kReverseShift * output_shift) +
output_offset;
const int32 clamped_output =
std::min(output_activation_max,
@@ -2944,9 +2944,9 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
fixed_log_sum_of_exps + std::numeric_limits<int32>::lowest();
const int adjusted_diff_min =
std::max(diff_min - 1, // Note use of > below instead of >= above.
- MultiplyByQuantizedMultiplierSmallerThanOne(
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
rescaled_diff_min, reverse_scaling_divisor,
- reverse_scaling_right_shift));
+ kReverseShift * reverse_scaling_right_shift));
for (int c = 0; c < depth; ++c) {
int32 input_diff =
@@ -3850,10 +3850,14 @@ inline void Comparison(int left_shift, const T* input1_data,
const int32 input2_val = input2_offset + input2_data[i];
const int32 shifted_input1_val = input1_val * (1 << left_shift);
const int32 shifted_input2_val = input2_val * (1 << left_shift);
- const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input1_val, input1_multiplier, input1_shift);
- const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input2_val, input2_multiplier, input2_shift);
+ const int32 scaled_input1_val =
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input1_val, input1_multiplier,
+ kReverseShift * input1_shift);
+ const int32 scaled_input2_val =
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input2_val, input2_multiplier,
+ kReverseShift * input2_shift);
output_data[i] = F(scaled_input1_val, scaled_input2_val);
}
}
@@ -3902,11 +3906,13 @@ inline void BroadcastComparison(int left_shift, const T* input1_data,
const int32 shifted_input1_val = input1_val * (1 << left_shift);
const int32 shifted_input2_val = input2_val * (1 << left_shift);
const int32 scaled_input1_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input1_val, input1_multiplier, input1_shift);
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input1_val, input1_multiplier,
+ kReverseShift * input1_shift);
const int32 scaled_input2_val =
- MultiplyByQuantizedMultiplierSmallerThanOne(
- shifted_input2_val, input2_multiplier, input2_shift);
+ MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ shifted_input2_val, input2_multiplier,
+ kReverseShift * input2_shift);
output_data[Offset(output_dims, c, x, y, b)] =
F(scaled_input1_val, scaled_input2_val);
}
@@ -4000,6 +4006,42 @@ inline void RankOneSelect(const D* input_condition_data,
}
}
+// For ease of implementation, the indices are always a vector of size-4 vectors.
+template <typename T, typename I>
+inline void SparseToDense(const std::vector<std::vector<I>>& indices,
+ const T* values, T default_value, T* output_data,
+ const Dims<4>& output_dims, bool value_is_scalar) {
+ const int value_count = indices.size();
+
+ // First fill the output_data with default value.
+ const int num_elements = FlatSize(output_dims);
+ for (int i = 0; i < num_elements; ++i) {
+ output_data[i] = default_value;
+ }
+
+  // Special-case a scalar value to avoid checking the boolean condition
+  // within the loop on every iteration.
+ if (value_is_scalar) {
+ for (int i = 0; i < value_count; ++i) {
+ const std::vector<I>& index = indices[i];
+ TFLITE_DCHECK_EQ(index.size(), 4);
+ const T value = *values; // just use the first value.
+ output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] =
+ value;
+ }
+ return;
+ }
+
+ // Go through the values and indices to fill the sparse values.
+ for (int i = 0; i < value_count; ++i) {
+ const std::vector<I>& index = indices[i];
+ TFLITE_DCHECK_EQ(index.size(), 4);
+ const T value = values[i];
+ output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] =
+ value;
+ }
+}
+
} // namespace reference_ops
} // namespace tflite
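Stripped of the `Dims<4>`/`Offset` machinery, the reference op above is a fill-then-scatter: write the default value everywhere, then place each sparse value at its flattened 4-D index (broadcasting a single value in the scalar case). A simplified standalone rendering, under the assumption of plain row-major flattening rather than TF Lite's reversed stride order:

```
#include <array>
#include <cstdio>
#include <vector>

// Simplified scatter mirroring the reference SparseToDense above.
template <typename T>
void SparseToDenseSketch(const std::vector<std::array<int, 4>>& indices,
                         const std::vector<T>& values, T default_value,
                         const std::array<int, 4>& shape,
                         std::vector<T>* output) {
  // First fill the output with the default value.
  output->assign(shape[0] * shape[1] * shape[2] * shape[3], default_value);
  for (size_t i = 0; i < indices.size(); ++i) {
    const auto& ix = indices[i];
    const int flat =
        ((ix[0] * shape[1] + ix[1]) * shape[2] + ix[2]) * shape[3] + ix[3];
    // values.size() == 1 corresponds to the scalar broadcast case above.
    (*output)[flat] = values.size() == 1 ? values[0] : values[i];
  }
}

int main() {
  std::vector<float> dense;
  SparseToDenseSketch<float>({{0, 0, 0, 1}, {0, 0, 1, 0}}, {7.f, 9.f}, 0.f,
                             {1, 1, 2, 2}, &dense);
  for (float v : dense) std::printf("%g ", v);  // prints: 0 7 9 0
}
```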
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index d5293edd56..fc8ed753c5 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -15,6 +15,9 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
+#include <cstring>
+#include <iterator>
+
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
namespace tflite {
@@ -44,6 +47,101 @@ struct Dims {
int strides[N];
};
+class RuntimeShape {
+ public:
+ // Shapes with dimensions up to 4 are stored directly in the structure, while
+ // larger shapes are separately allocated.
+ static constexpr int kMaxSmallSize = 4;
+
+ RuntimeShape() : size_(0) {}
+
+ explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {
+ if (dimensions_count > kMaxSmallSize) {
+ dims_pointer_ = new int32[dimensions_count];
+ }
+ }
+
+ RuntimeShape(int dimensions_count, const int32* dims_data) : size_(0) {
+ ReplaceWith(dimensions_count, dims_data);
+ }
+
+ ~RuntimeShape() {
+ if (size_ > kMaxSmallSize) {
+ delete[] dims_pointer_;
+ }
+ }
+
+ inline int32 DimensionsCount() const { return size_; }
+ inline int32 Dims(int i) const {
+ TFLITE_DCHECK_GE(i, 0);
+ TFLITE_DCHECK_LT(i, size_);
+ return size_ > kMaxSmallSize ? dims_pointer_[i] : dims_[i];
+ }
+ inline void SetDim(int i, int32 val) {
+ TFLITE_DCHECK_GE(i, 0);
+ TFLITE_DCHECK_LT(i, size_);
+ if (size_ > kMaxSmallSize) {
+ dims_pointer_[i] = val;
+ } else {
+ dims_[i] = val;
+ }
+ }
+ inline int32* DimsData() {
+ return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
+ }
+ inline const int32* DimsData() const {
+ return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
+ }
+
+ inline void Resize(int dimensions_count) {
+ if (size_ > kMaxSmallSize) {
+ delete[] dims_pointer_;
+ }
+ size_ = dimensions_count;
+ if (dimensions_count > kMaxSmallSize) {
+ dims_pointer_ = new int32[dimensions_count];
+ }
+ }
+
+ inline void ReplaceWith(int dimensions_count, const int32* dims_data) {
+ Resize(dimensions_count);
+ int32* dst_dims = DimsData();
+ std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32));
+ }
+
+ template <typename T>
+ inline void BuildFrom(const T& src_iterable) {
+ const int dimensions_count =
+ std::distance(src_iterable.begin(), src_iterable.end());
+ Resize(dimensions_count);
+ int32* data = DimsData();
+ for (auto it : src_iterable) {
+ *data = it;
+ ++data;
+ }
+ }
+
+ // Returns the total count of elements, that is, the size when flattened
+ // into a vector.
+ inline int FlatSize() const {
+ int buffer_size = 1;
+ const int* dims_data = DimsData();
+ for (int i = 0; i < size_; i++) {
+ const int dim = dims_data[i];
+ TFLITE_DCHECK_GE(dim, 1);
+ buffer_size *= dim;
+ }
+ return buffer_size;
+ }
+
+ private:
+ int32 size_;
+ union {
+ int32 dims_[kMaxSmallSize];
+ int32* dims_pointer_;
+ };
+};
+
// Gets next index to iterate through a multidimensional array.
inline bool NextIndex(const int num_dims, const int* dims, int* current) {
TFLITE_DCHECK_GT(num_dims, 0);
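RuntimeShape keeps up to kMaxSmallSize (4) dimensions inline in a union and only heap-allocates larger shapes, so the common 4-D case never touches the allocator. A rough Python analogue of the API surface (a sketch; Python lists hide the inline-versus-heap distinction the C++ class optimizes for):

    class RuntimeShapeSketch:
        def __init__(self, dims=()):
            self._dims = list(dims)

        def dimensions_count(self):
            return len(self._dims)

        def dims(self, i):
            return self._dims[i]

        def flat_size(self):
            # Total element count, i.e. the size when flattened into a vector.
            size = 1
            for d in self._dims:
                assert d >= 1
                size *= d
            return size

    assert RuntimeShapeSketch([2, 3, 4]).flat_size() == 24

Note that the C++ class as written relies on the implicitly generated copy operations: copying a heap-backed instance would double-delete dims_pointer_, so instances should be passed by reference or pointer.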
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 21cc185e9f..4eea9921b2 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -90,6 +90,7 @@ TfLiteRegistration* Register_SELECT();
TfLiteRegistration* Register_SLICE();
TfLiteRegistration* Register_SIN();
TfLiteRegistration* Register_TRANSPOSE_CONV();
+TfLiteRegistration* Register_SPARSE_TO_DENSE();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -161,6 +162,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE());
AddBuiltin(BuiltinOperator_SIN, Register_SIN());
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV());
+ AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
new file mode 100644
index 0000000000..404c32ad9c
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
@@ -0,0 +1,275 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <unistd.h>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <iostream>
+#include <limits>
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+#include "tensorflow/contrib/lite/kernels/padding.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace sparse_to_dense {
+
+constexpr int kIndicesTensor = 0;
+constexpr int kOutputShapeTensor = 1;
+constexpr int kValueInputTensor = 2;
+constexpr int kDefaultValueTensor = 3;
+constexpr int kOutputTensor = 0;
+
+constexpr int kMaxDimensions = 4;
+
+template <typename T>
+TfLiteStatus Resize(TfLiteContext* context, const TfLiteTensor* output_shape,
+ TfLiteTensor* output) {
+ const int output_dimensions = NumElements(output_shape);
+ TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(output_dimensions);
+ for (int i = 0; i < output_dimensions; ++i) {
+ output_shape_array->data[i] = GetTensorData<T>(output_shape)[i];
+ }
+
+ return context->ResizeTensor(context, output, output_shape_array);
+}
+
+TfLiteStatus CheckDimensionsMatch(TfLiteContext* context,
+ const TfLiteTensor* indices,
+ const TfLiteTensor* output_shape,
+ const TfLiteTensor* values) {
+ switch (NumDimensions(indices)) {
+ case 0:
+ case 1: {
+ if (NumDimensions(values) == 0) {
+ TF_LITE_ENSURE_EQ(context, NumElements(indices), NumElements(values));
+ }
+ TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 1);
+ break;
+ }
+ case 2: {
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 1),
+ NumElements(output_shape));
+ if (NumDimensions(values) == 0)
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0),
+ NumElements(values));
+ break;
+ }
+ default:
+ context->ReportError(
+ context, "Wrong indices dimensions %d, should be less than 3.",
+ NumDimensions(indices));
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+// Convert indices into a vector of 4-d vectors.
+// TODO(renjieliu): Revisit here to improve the performance, since multiple
+// allocations of std::vectors will be quite slow on phones.
+template <typename T>
+TfLiteStatus GetIndicesVector(TfLiteContext* context,
+ const TfLiteTensor* indices,
+ const int num_indices,
+ std::vector<std::vector<T>>* indices_vector) {
+ // Note: because TFLite reverses the dimensions, pad zeros up front.
+ switch (NumDimensions(indices)) {
+ case 0:
+ case 1: {
+ const auto indices_data = GetTensorData<T>(indices);
+ for (int i = 0; i < num_indices; ++i) {
+ std::vector<T> index({0, 0, 0, indices_data[i]});
+ indices_vector->push_back(index);
+ }
+ break;
+ }
+ case 2: {
+ const int true_dimensions = SizeOfDimension(indices, 1);
+ TF_LITE_ENSURE(context, true_dimensions <= kMaxDimensions);
+ for (int i = 0; i < num_indices; ++i) {
+ std::vector<T> index;
+ index.reserve(kMaxDimensions);
+ // Zero-pad the first kMaxDimensions - true_dimensions entries of the
+ // index to produce a full 4-dimensional index.
+ for (int j = 0; j < kMaxDimensions - true_dimensions; ++j) {
+ index.push_back(0);
+ }
+ for (int j = 0; j < true_dimensions; ++j) {
+ index.push_back(GetTensorData<T>(indices)[i * true_dimensions + j]);
+ }
+
+ indices_vector->push_back(index);
+ }
+ break;
+ }
+ default:
+ context->ReportError(context,
+ "Indices dimensions problem, got %d dimensions",
+ NumDimensions(indices));
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus ResizeOutputShape(TfLiteContext* context,
+ const TfLiteTensor* output_shape,
+ TfLiteTensor* output) {
+ if (output_shape->type == kTfLiteInt32) {
+ return Resize<int32_t>(context, output_shape, output);
+ } else if (output_shape->type == kTfLiteInt64) {
+ return Resize<int64_t>(context, output_shape, output);
+ } else {
+ context->ReportError(context, "Dense shape type %d not supported.",
+ output_shape->type);
+ return kTfLiteError;
+ }
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor);
+ const TfLiteTensor* output_shape =
+ GetInput(context, node, kOutputShapeTensor);
+ const TfLiteTensor* values = GetInput(context, node, kValueInputTensor);
+ const TfLiteTensor* default_value =
+ GetInput(context, node, kDefaultValueTensor);
+
+ // TODO(renjieliu): Handle validate_indices.
+
+ // Indices can be 0-D, 1-D or 2-D.
+ TF_LITE_ASSERT(NumDimensions(indices) >= 0);
+ TF_LITE_ENSURE(context, NumDimensions(indices) < 3);
+ TF_LITE_ASSERT(NumDimensions(output_shape) >= 0);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
+ // Values can be 0-D or 1-D.
+ TF_LITE_ASSERT(NumDimensions(values) >= 0);
+ TF_LITE_ENSURE(context, NumDimensions(values) < 2);
+
+ TF_LITE_ENSURE_EQ(context, NumElements(default_value), 1);
+
+ TF_LITE_ENSURE(
+ context, indices->type == kTfLiteInt32 || indices->type == kTfLiteInt64);
+ TF_LITE_ENSURE(context, output_shape->type == kTfLiteInt32 ||
+ output_shape->type == kTfLiteInt64);
+ TF_LITE_ENSURE_EQ(context, values->type, default_value->type);
+
+ // Ensure dimensions match.
+ TF_LITE_ENSURE_OK(
+ context, CheckDimensionsMatch(context, indices, output_shape, values));
+
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
+
+ if (!IsConstantTensor(output_shape)) {
+ SetTensorToDynamic(output);
+ return kTfLiteOk;
+ }
+ return ResizeOutputShape(context, output_shape, output);
+}
+
+template <typename T, typename I>
+TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor);
+ const TfLiteTensor* output_shape =
+ GetInput(context, node, kOutputShapeTensor);
+ const TfLiteTensor* values = GetInput(context, node, kValueInputTensor);
+ const TfLiteTensor* default_value =
+ GetInput(context, node, kDefaultValueTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ if (IsDynamicTensor(output)) {
+ TF_LITE_ENSURE_OK(context,
+ ResizeOutputShape(context, output_shape, output));
+ }
+
+ const int num_indices = SizeOfDimension(indices, 0);
+ const bool value_is_scalar = NumDimensions(values) == 0;
+ std::vector<std::vector<I>> indices_vector;
+ indices_vector.reserve(num_indices);
+ TF_LITE_ENSURE_OK(context, GetIndicesVector<I>(context, indices, num_indices,
+ &indices_vector));
+ reference_ops::SparseToDense(indices_vector, GetTensorData<T>(values),
+ *GetTensorData<T>(default_value),
+ GetTensorData<T>(output), GetTensorDims(output),
+ value_is_scalar);
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor);
+ const TfLiteTensor* values = GetInput(context, node, kValueInputTensor);
+
+ // Currently only supports float32 and int32.
+ switch (values->type) {
+ case kTfLiteFloat32: {
+ switch (indices->type) {
+ case kTfLiteInt32: {
+ return SparseToDenseImpl<float, int32_t>(context, node);
+ }
+ case kTfLiteInt64: {
+ return SparseToDenseImpl<float, int64_t>(context, node);
+ }
+ default:
+ context->ReportError(
+ context, "Type %d is currently not supported by sparse to dense.",
+ indices->type);
+ return kTfLiteError;
+ }
+ break;
+ }
+ case kTfLiteInt32: {
+ switch (indices->type) {
+ case kTfLiteInt32: {
+ return SparseToDenseImpl<int32_t, int32_t>(context, node);
+ }
+ case kTfLiteInt64: {
+ return SparseToDenseImpl<int32_t, int64_t>(context, node);
+ }
+ default:
+ context->ReportError(
+ context, "Type %d is currently not supported by sparse to dense.",
+ indices->type);
+ return kTfLiteError;
+ }
+ break;
+ }
+ default:
+ context->ReportError(
+ context, "Type %d is currently not supported by sparse to dense.",
+ values->type);
+ return kTfLiteError;
+ }
+}
+
+} // namespace sparse_to_dense
+
+TfLiteRegistration* Register_SPARSE_TO_DENSE() {
+ static TfLiteRegistration r = {nullptr, nullptr, sparse_to_dense::Prepare,
+ sparse_to_dense::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
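GetIndicesVector normalizes every index to exactly kMaxDimensions (4) entries by zero-padding on the left, matching TFLite's reversed dimension order. A hypothetical Python rendering of that padding rule:

    def pad_indices(raw_indices, true_dimensions, k_max_dimensions=4):
        assert true_dimensions <= k_max_dimensions
        padded = []
        for index in raw_indices:
            # Left-pad with zeros so every index has k_max_dimensions entries.
            padded.append([0] * (k_max_dimensions - true_dimensions) + list(index))
        return padded

    assert pad_indices([[3]], 1) == [[0, 0, 0, 3]]
    assert pad_indices([[1, 2]], 2) == [[0, 0, 1, 2]]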
diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc
new file mode 100644
index 0000000000..a51ec17afc
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc
@@ -0,0 +1,155 @@
+
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdarg>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+template <typename T>
+class SparseToDenseOpModel : public SingleOpModel {
+ public:
+ SparseToDenseOpModel(std::initializer_list<int> indices_shape,
+ std::initializer_list<int> output_shape_shape,
+ std::initializer_list<int> values_shape, T default_value,
+ TensorType tensor_index_type,
+ TensorType tensor_input_type) {
+ indices_ = AddInput(tensor_index_type);
+ output_shape_ = AddInput(TensorType_INT32);
+ values_ = AddInput(tensor_input_type);
+ default_value_ = AddInput(tensor_input_type);
+ output_ = AddOutput(tensor_input_type);
+
+ SetBuiltinOp(BuiltinOperator_SPARSE_TO_DENSE,
+ BuiltinOptions_SparseToDenseOptions,
+ CreateSparseToDenseOptions(builder_, false).Union());
+ BuildInterpreter({indices_shape, output_shape_shape, values_shape, {1}});
+
+ PopulateTensor<T>(default_value_, {default_value});
+ }
+
+ int indices() { return indices_; }
+ int output_shape() { return output_shape_; }
+ int values() { return values_; }
+
+ std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int indices_;
+ int output_shape_;
+ int values_;
+ int default_value_;
+ int output_;
+};
+
+TEST(SparseToDenseOpModelTest, ZeroDimensionTest) {
+ SparseToDenseOpModel<float> m({1}, {1}, {1}, 0, TensorType_INT32,
+ TensorType_FLOAT32);
+ m.PopulateTensor<int32_t>(m.indices(), {3});
+ m.PopulateTensor<int32_t>(m.output_shape(), {5});
+ m.PopulateTensor<float>(m.values(), {7});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 7, 0}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({5}));
+}
+
+TEST(SparseToDenseOpModelTest, OneDimensionTest) {
+ SparseToDenseOpModel<float> m({3}, {1}, {3}, 0, TensorType_INT32,
+ TensorType_FLOAT32);
+ m.PopulateTensor<int32_t>(m.indices(), {1, 3, 5});
+ m.PopulateTensor<int32_t>(m.output_shape(), {7});
+ m.PopulateTensor<float>(m.values(), {2, 4, 6});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 0, 4, 0, 6, 0}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({7}));
+}
+
+TEST(SparseToDenseOpModelTest, TwoDimensionsTest) {
+ SparseToDenseOpModel<float> m({3, 3}, {3}, {3}, 0, TensorType_INT32,
+ TensorType_FLOAT32);
+ m.PopulateTensor<int32_t>(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1});
+ m.PopulateTensor<int32_t>(m.output_shape(), {3, 3, 3});
+ m.PopulateTensor<float>(m.values(), {2, 4, 6});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 4, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3}));
+}
+
+TEST(SparseToDenseOpModelTest, DefaultValueTest) {
+ SparseToDenseOpModel<float> m({3, 3}, {3}, {3}, -1, TensorType_INT32,
+ TensorType_FLOAT32);
+ m.PopulateTensor<int32_t>(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1});
+ m.PopulateTensor<int32_t>(m.output_shape(), {3, 3, 3});
+ m.PopulateTensor<float>(m.values(), {2, 4, 6});
+ m.Invoke();
+
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
+ -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3}));
+}
+
+TEST(SparseToDenseOpModelTest, IntegerValueTest) {
+ SparseToDenseOpModel<int32_t> m({3, 3}, {3}, {3}, -1, TensorType_INT32,
+ TensorType_INT32);
+ m.PopulateTensor<int32_t>(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1});
+ m.PopulateTensor<int32_t>(m.output_shape(), {3, 3, 3});
+ m.PopulateTensor<int32_t>(m.values(), {2, 4, 6});
+ m.Invoke();
+
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
+ -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3}));
+}
+
+TEST(SparseToDenseOpModelTest, Int64IndexTest) {
+ SparseToDenseOpModel<float> m({3, 3}, {3}, {3}, -1, TensorType_INT64,
+ TensorType_FLOAT32);
+ m.PopulateTensor<int64_t>(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1});
+ m.PopulateTensor<int32_t>(m.output_shape(), {3, 3, 3});
+ m.PopulateTensor<float>(m.values(), {2, 4, 6});
+ m.Invoke();
+
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
+ -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
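As a worked check of TwoDimensionsTest: with output shape {3, 3, 3} and row-major layout, index (1, 2, 1) lands at flat offset 1*9 + 2*3 + 1 = 16 and (2, 0, 1) at offset 19, which is exactly where 4 and 6 sit in the expected array. In Python:

    def flat(index, shape=(3, 3, 3)):
        # Row-major flattening of a 3-D index.
        return (index[0] * shape[1] + index[1]) * shape[2] + index[2]

    assert [flat(i) for i in ([0, 0, 0], [1, 2, 1], [2, 0, 1])] == [0, 16, 19]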
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
index 8429dba54b..164a0cbd08 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
@@ -41,7 +41,7 @@ constexpr int kOutputTensor = 1;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* scratch_tensor_index = new int;
- context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index);
+ context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index);
return scratch_tensor_index;
}
@@ -102,7 +102,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (input->type == kTfLiteFloat32 && input_weights->type == kTfLiteUInt8) {
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(2);
+ node->temporaries = TfLiteIntArrayCreate(3);
node->temporaries->data[0] = *scratch_tensor_index;
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0);
input_quantized->type = kTfLiteUInt8;
@@ -125,6 +125,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context->ResizeTensor(context, hidden_state_quantized,
hidden_state_quantized_size));
}
+ node->temporaries->data[2] = *scratch_tensor_index + 2;
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = batch_size;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
}
return kTfLiteOk;
}
@@ -187,14 +197,12 @@ TfLiteStatus EvalFloat(const TfLiteTensor* input,
return kTfLiteOk;
}
-TfLiteStatus EvalQuantized(const TfLiteTensor* input,
- const TfLiteTensor* input_weights,
- const TfLiteTensor* recurrent_weights,
- const TfLiteTensor* bias,
- const TfLiteSequenceRNNParams* params,
- TfLiteTensor* input_scratch,
- TfLiteTensor* hidden_state_scratch,
- TfLiteTensor* hidden_state, TfLiteTensor* output) {
+TfLiteStatus EvalHybrid(
+ const TfLiteTensor* input, const TfLiteTensor* input_weights,
+ const TfLiteTensor* recurrent_weights, const TfLiteTensor* bias,
+ const TfLiteSequenceRNNParams* params, TfLiteTensor* input_scratch,
+ TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors,
+ TfLiteTensor* hidden_state, TfLiteTensor* output) {
const bool time_major = params->time_major;
const int batch_size =
(time_major) ? input->dims->data[1] : input->dims->data[0];
@@ -218,6 +226,7 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input,
reinterpret_cast<int8_t*>(input_scratch->data.uint8);
int8_t* quantized_hidden_state_ptr =
reinterpret_cast<int8_t*>(hidden_state_scratch->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
if (time_major) {
// Initialize the pointer to hidden state.
@@ -233,7 +242,8 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input,
input_ptr_batch, input_weights_ptr, input_weights_scale,
recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
num_units, batch_size, params->activation, quantized_input_ptr,
- quantized_hidden_state_ptr, hidden_state_ptr_batch, output_ptr_batch);
+ quantized_hidden_state_ptr, scaling_factors_ptr,
+ hidden_state_ptr_batch, output_ptr_batch);
}
} else {
// For each batch
@@ -252,7 +262,7 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input,
recurrent_weights_ptr, recurrent_weights_scale, bias_ptr,
input_size, num_units, /*batch_size=*/1, params->activation,
quantized_input_ptr, quantized_hidden_state_ptr,
- hidden_state_ptr_batch, output_ptr_batch);
+ scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch);
}
}
}
@@ -278,9 +288,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// TODO(mirkov): implement eval with quantized inputs as well.
TfLiteTensor* input_quantized = GetTemporary(context, node, 0);
TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1);
- return EvalQuantized(input, input_weights, recurrent_weights, bias,
- params, input_quantized, hidden_state_quantized,
- hidden_state, output);
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, 2);
+ return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
+ input_quantized, hidden_state_quantized,
+ scaling_factors, hidden_state, output);
}
default:
context->ReportError(context, "Type %d not currently supported.",
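The rename from EvalQuantized to EvalHybrid matches what the path actually does: float activations against uint8 weights. The new scaling_factors temporary holds one float per batch row. A sketch of the per-row symmetric quantization hybrid kernels of this kind typically perform (an assumption for illustration; the real scheme lives in the kernel_utils implementation, which is not part of this diff):

    def quantize_per_batch(input_rows):
        # Symmetrically quantize each batch row to int8, recording one
        # scale per row, as the scaling_factors buffer does.
        quantized, scaling_factors = [], []
        for row in input_rows:
            max_abs = max(abs(v) for v in row) or 1.0
            scale = max_abs / 127.0
            quantized.append([int(round(v / scale)) for v in row])
            scaling_factors.append(scale)
        return quantized, scaling_factors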
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 80fcb28bc7..6ac41a94bd 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -699,6 +699,16 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
+ case BuiltinOperator_SPARSE_TO_DENSE: {
+ TfLiteSparseToDenseParams* params =
+ MallocPOD<TfLiteSparseToDenseParams>();
+ if (auto* sparse_to_dense_params =
+ op->builtin_options_as_SparseToDenseOptions()) {
+ params->validate_indices = sparse_to_dense_params->validate_indices();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
case BuiltinOperator_DELEGATE: {
// TODO(ycling): Revisit when supporting saving delegated models.
error_reporter->Report("DELEGATE op shouldn't exist in model.");
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index eed57d412b..fad08bbfe6 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -491,6 +491,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_SLICE:
case tflite::BuiltinOperator_SIN:
case tflite::BuiltinOperator_TRANSPOSE_CONV:
+ case tflite::BuiltinOperator_SPARSE_TO_DENSE:
FATAL("Op code %d is currently not delegated to NNAPI", builtin);
nn_op_type = -1; // set to invalid
break;
diff --git a/tensorflow/contrib/lite/op_resolver.h b/tensorflow/contrib/lite/op_resolver.h
index 38a2706942..9d7e3f2085 100644
--- a/tensorflow/contrib/lite/op_resolver.h
+++ b/tensorflow/contrib/lite/op_resolver.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/util.h"
namespace tflite {
@@ -55,8 +56,7 @@ struct OperatorKeyHasher {
size_t operator()(const T& x) const {
size_t a = ValueHasher<typename T::first_type>()(x.first);
size_t b = ValueHasher<typename T::second_type>()(x.second);
- // Hash combinator used by TensorFlow core.
- return a ^ (b + 0x9e3779b97f4a7800ULL + (a << 10) + (a >> 4));
+ return CombineHashes({a, b});
}
};
} // namespace op_resolver_hasher
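The open-coded combinator moves behind util.h's CombineHashes. A Python sketch of the fold, assuming CombineHashes applies the same TensorFlow-core combinator pairwise across the list:

    def combine_hashes(hashes):
        result = 0
        for h in hashes:
            result ^= h + 0x9E3779B97F4A7800 + (result << 10) + (result >> 4)
            result &= (1 << 64) - 1  # emulate uint64 wraparound
        return result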
diff --git a/tensorflow/contrib/lite/profiling/BUILD b/tensorflow/contrib/lite/profiling/BUILD
index c86be65ca7..c31189f2b1 100644
--- a/tensorflow/contrib/lite/profiling/BUILD
+++ b/tensorflow/contrib/lite/profiling/BUILD
@@ -29,6 +29,13 @@ cc_library(
name = "profile_buffer",
hdrs = ["profile_buffer.h"],
copts = common_copts,
+ deps = [":time"],
+)
+
+cc_library(
+ name = "time",
+ srcs = ["time.cc"],
+ hdrs = ["time.h"],
)
cc_library(
diff --git a/tensorflow/contrib/lite/profiling/profile_buffer.h b/tensorflow/contrib/lite/profiling/profile_buffer.h
index 299b2a9cad..65d86dce47 100644
--- a/tensorflow/contrib/lite/profiling/profile_buffer.h
+++ b/tensorflow/contrib/lite/profiling/profile_buffer.h
@@ -18,6 +18,8 @@ limitations under the License.
#include <cstddef>
#include <cstdint>
+#include "tensorflow/contrib/lite/profiling/time.h"
+
namespace tflite {
namespace profiling {
@@ -74,7 +76,7 @@ class ProfileBuffer {
if (!enabled_) {
return kInvalidEventHandle;
}
- uint64_t timestamp = NowMicros();
+ uint64_t timestamp = time::NowMicros();
int index = current_index_ % event_buffer_.size();
event_buffer_[index].tag = tag;
event_buffer_[index].event_type = event_type;
@@ -103,7 +105,7 @@ class ProfileBuffer {
}
int event_index = event_handle % max_size;
- event_buffer_[event_index].end_timestamp_us = NowMicros();
+ event_buffer_[event_index].end_timestamp_us = time::NowMicros();
}
// Returns the size of the buffer.
@@ -134,12 +136,6 @@ class ProfileBuffer {
}
private:
- static uint64_t NowMicros() {
- // TODO(shashishekhar): Refactor this to a separate file.
- struct timeval tv;
- gettimeofday(&tv, nullptr);
- return static_cast<uint64_t>(tv.tv_sec) * 1000000 + tv.tv_usec;
- }
bool enabled_;
uint32_t current_index_;
std::vector<ProfileEvent> event_buffer_;
diff --git a/tensorflow/contrib/lite/profiling/time.cc b/tensorflow/contrib/lite/profiling/time.cc
new file mode 100644
index 0000000000..446660bb74
--- /dev/null
+++ b/tensorflow/contrib/lite/profiling/time.cc
@@ -0,0 +1,29 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/profiling/time.h"
+
+#include <sys/time.h>
+
+namespace tflite {
+namespace profiling {
+namespace time {
+uint64_t NowMicros() {
+ struct timeval tv;
+ gettimeofday(&tv, nullptr);
+ return static_cast<uint64_t>(tv.tv_sec) * 1000000 + tv.tv_usec;
+}
+} // namespace time
+} // namespace profiling
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/profiling/time.h b/tensorflow/contrib/lite/profiling/time.h
new file mode 100644
index 0000000000..cc2ec319b8
--- /dev/null
+++ b/tensorflow/contrib/lite/profiling/time.h
@@ -0,0 +1,27 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_PROFILING_TIME_H_
+#define TENSORFLOW_CONTRIB_LITE_PROFILING_TIME_H_
+
+#include <cstdint>
+
+namespace tflite {
+namespace profiling {
+namespace time {
+uint64_t NowMicros();
+} // namespace time
+} // namespace profiling
+} // namespace tflite
+#endif // TENSORFLOW_CONTRIB_LITE_PROFILING_TIME_H_
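For reference, the extracted helper's behavior in Python (wall-clock microseconds since the epoch, matching gettimeofday):

    import time

    def now_micros():
        return int(time.time() * 1000000)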
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index a40e512045..7e6ff6c0a8 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -36,6 +36,16 @@ py_test(
],
)
+py_binary(
+ name = "tflite_convert",
+ srcs = ["tflite_convert.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":lite",
+ ],
+)
+
py_library(
name = "lite",
srcs = ["lite.py"],
@@ -125,6 +135,7 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
+ ":convert",
"//tensorflow/contrib/saved_model:saved_model_py",
"//tensorflow/python:graph_util",
"//tensorflow/python:platform",
@@ -164,11 +175,3 @@ py_test(
"//tensorflow/python/saved_model",
],
)
-
-# Transitive dependencies of this target will be included in the pip package.
-py_library(
- name = "tf_lite_py_pip",
- deps = [
- ":convert_saved_model",
- ],
-)
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index c0926d2f33..0819475240 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -115,11 +115,15 @@ def toco_convert(input_data,
input_tensors,
output_tensors,
inference_type=lite_constants.FLOAT,
+ inference_input_type=None,
input_format=lite_constants.TENSORFLOW_GRAPHDEF,
output_format=lite_constants.TFLITE,
quantized_input_stats=None,
+ default_ranges_stats=None,
drop_control_dependency=True,
- allow_custom_ops=False):
+ reorder_across_fake_quant=False,
+ allow_custom_ops=False,
+ change_concat_input_ranges=False):
"""Convert a model using TOCO from `input_format` to `output_format`.
Typically this is to convert from TensorFlow GraphDef to TFLite, in which
@@ -130,18 +134,41 @@ def toco_convert(input_data,
input_tensors: List of input tensors. Type and shape are computed using
`foo.get_shape()` and `foo.dtype`.
output_tensors: List of output tensors (only .name is used from this).
- inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`.
- input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF).
- output_format: Type of data to write (currently must be TFLITE or
- GRAPHVIZ_DOT)
- quantized_input_stats: For each member of input_tensors the mean and
- std deviation of training data. Only needed if `inference_type` is
- `QUANTIZED_UINT8`.
- drop_control_dependency: Drops control dependencies silently. This is due
- to tf lite not supporting control dependencies.
+ inference_type: Target data type of arrays in the output file. Currently
+ must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT)
+ inference_input_type: Target data type of input arrays. Allows for a
+ different type for input arrays in the case of quantization. Currently
+ must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`)
+ input_format: Type of data to read. Currently must be
+ `{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
+ output_format: Output file format. Currently must be `{TFLITE,
+ GRAPHVIZ_DOT}`. (default TFLITE)
+ quantized_input_stats: List of tuples of integers representing the mean
+ and standard deviation of the training data, one tuple per member of
+ input_tensors (e.g., [(0., 1.)]). Only needed if `inference_type` is
+ `QUANTIZED_UINT8`. (default None)
+ default_ranges_stats: Tuple of integers representing (min, max) range values
+ for all arrays without a specified range. Intended for experimenting with
+ quantization via "dummy quantization". (default None)
+ drop_control_dependency: Boolean indicating whether to drop control
+ dependencies silently. This is due to TFLite not supporting control
+ dependencies. (default True)
+ reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
+ nodes in unexpected locations. Used when the location of the FakeQuant
+ nodes is preventing graph transformations necessary to convert the graph.
+ Results in a graph that differs from the quantized training graph,
+ potentially causing differing arithmetic behavior. (default False)
+ change_concat_input_ranges: Boolean to change the behavior of min/max
+ ranges for the inputs and outputs of the concat operator for quantized
+ models. When true, the ranges of the concat operator's inputs and
+ outputs are changed so that they overlap. (default False)
+ allow_custom_ops: Boolean indicating whether to allow custom operations.
+ When false, any unknown operation is an error. When true, custom ops are
+ created for any op that is unknown. The developer will need to provide
+ these to the TensorFlow Lite runtime with a custom resolver.
+ (default False)
Returns:
- The converted data. For example if tflite was the destination, then
+ The converted data. For example, if TFLite was the destination, then
this will be a tflite flatbuffer in a bytes array.
Raises:
@@ -152,10 +179,18 @@ def toco_convert(input_data,
toco = _toco_flags_pb2.TocoFlags()
toco.input_format = input_format
toco.output_format = output_format
- toco.drop_control_dependency = drop_control_dependency
- model = _model_flags_pb2.ModelFlags()
toco.inference_type = inference_type
+ if inference_input_type:
+ toco.inference_input_type = inference_input_type
+ toco.drop_control_dependency = drop_control_dependency
+ toco.reorder_across_fake_quant = reorder_across_fake_quant
toco.allow_custom_ops = allow_custom_ops
+ if default_ranges_stats:
+ toco.default_ranges_min = default_ranges_stats[0]
+ toco.default_ranges_max = default_ranges_stats[1]
+
+ model = _model_flags_pb2.ModelFlags()
+ model.change_concat_input_ranges = change_concat_input_ranges
for idx, input_tensor in enumerate(input_tensors):
if input_tensor.dtype == _dtypes.float32:
tflite_input_type = lite_constants.FLOAT
@@ -163,6 +198,8 @@ def toco_convert(input_data,
tflite_input_type = lite_constants.INT32
elif input_tensor.dtype == _dtypes.int64:
tflite_input_type = lite_constants.INT64
+ elif input_tensor.dtype == _dtypes.uint8:
+ tflite_input_type = lite_constants.QUANTIZED_UINT8
# TODO(aselle): Insert strings when they are available
else:
raise ValueError("Tensors %s not known type %r" % (input_tensor.name,
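default_ranges_stats gives every array without a recorded (min, max) range a fallback range, which is what makes "dummy quantization" experiments possible. A sketch of the standard asymmetric uint8 quantization parameters such a range implies (illustrative math only, not TOCO's code):

    def quantization_params(range_min, range_max, num_bits=8):
        qmin, qmax = 0, 2 ** num_bits - 1
        scale = (range_max - range_min) / (qmax - qmin)
        zero_point = int(round(qmin - range_min / scale))
        return scale, zero_point

    # A (0, 6) default range quantizes with scale 6/255 and zero point 0.
    assert quantization_params(0.0, 6.0) == (6.0 / 255, 0)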
diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py
index 54fec9d61f..b952a72aab 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model.py
@@ -18,31 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.lite.python.convert import tensor_name
from tensorflow.contrib.saved_model.python.saved_model import reader
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.core.framework import types_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.framework import ops
-from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import loader
-from tensorflow.python.saved_model import signature_constants
-from tensorflow.python.saved_model import tag_constants
-
-
-def _write_and_flush_file(file_path, data_str):
- """Writes data to file path.
-
- Args:
- file_path: Full path of the file to store data in.
- data_str: Data represented as a string.
-
- Returns: None.
- """
- with gfile.Open(file_path, "wb") as data_file:
- data_file.write(data_str)
- data_file.flush()
def _log_tensor_details(tensor_info):
@@ -167,29 +151,10 @@ def _get_tensors(graph, signature_def_tensor_names=None,
"""
tensors = []
if user_tensor_names:
- # Get the list of all of the tensors with and without the tensor index.
- all_tensor_names = [
- tensor.name for op in graph.get_operations() for tensor in op.outputs
- ]
- all_tensor_names_only = [name.split(":")[0] for name in all_tensor_names]
-
# Sort the tensor names.
user_tensor_names = sorted(user_tensor_names)
- # Get the tensors associated with the tensor names.
- tensors = []
- invalid_tensors = []
- for name in user_tensor_names:
- if name not in all_tensor_names_only:
- invalid_tensors.append(name)
- else:
- idx = all_tensor_names_only.index(name)
- tensors.append(graph.get_tensor_by_name(all_tensor_names[idx]))
-
- # Throw ValueError if any user input names are not valid tensors.
- if invalid_tensors:
- raise ValueError("Invalid tensors '{}' were found.".format(
- ",".join(invalid_tensors)))
+ tensors = get_tensors_from_tensor_names(graph, user_tensor_names)
elif signature_def_tensor_names:
tensors = [
graph.get_tensor_by_name(name)
@@ -204,6 +169,58 @@ def _get_tensors(graph, signature_def_tensor_names=None,
return tensors
+def get_tensors_from_tensor_names(graph, tensor_names):
+ """Gets the Tensors associated with the `tensor_names` in the provided graph.
+
+ Args:
+ graph: TensorFlow Graph.
+ tensor_names: List of strings that represent names of tensors in the graph.
+
+ Returns:
+ A list of Tensor objects in the same order the names are provided.
+
+ Raises:
+ ValueError:
+ tensor_names contains an invalid tensor name.
+ """
+ # Get the list of all of the tensors.
+ tensor_name_to_tensor = {
+ tensor_name(tensor): tensor for op in graph.get_operations()
+ for tensor in op.values()
+ }
+
+ # Get the tensors associated with tensor_names.
+ tensors = []
+ invalid_tensors = []
+ for name in tensor_names:
+ tensor = tensor_name_to_tensor.get(name)
+ if tensor is None:
+ invalid_tensors.append(name)
+ else:
+ tensors.append(tensor)
+
+ # Throw ValueError if any user input names are not valid tensors.
+ if invalid_tensors:
+ raise ValueError("Invalid tensors '{}' were found.".format(
+ ",".join(invalid_tensors)))
+ return tensors
+
+
+def set_tensor_shapes(tensors, shapes):
+ """Sets Tensor shape for each tensor if the shape is defined.
+
+ Args:
+ tensors: TensorFlow ops.Tensor.
+ shapes: Dict of strings representing input tensor names mapped to lists
+ of integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
+ """
+ if shapes:
+ for tensor in tensors:
+ shape = shapes.get(tensor.name)
+ if shape is not None:
+ tensor.set_shape(shapes[tensor.name])
+
+
def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
output_arrays, tag_set, signature_key):
"""Converts a SavedModel to a frozen graph.
@@ -211,15 +228,14 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
Args:
saved_model_dir: SavedModel directory to convert.
input_arrays: List of input tensors to freeze graph with. Uses input arrays
- from SignatureDef when none are provided. (default None)
- input_shapes: Map of strings representing input tensor names to list of
+ from SignatureDef when none are provided.
+ input_shapes: Dict of strings representing input tensor names to list of
integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}).
Automatically determined when input shapes is None (e.g., {"foo" : None}).
- (default None)
output_arrays: List of output tensors to freeze graph with. Uses output
- arrays from SignatureDef when none are provided. (default None)
+ arrays from SignatureDef when none are provided.
tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
- analyze. All tags in the tag set must be present. (default "serve")
+ analyze. All tags in the tag set must be present.
signature_key: Key identifying SignatureDef containing inputs and outputs.
Returns:
@@ -233,14 +249,7 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
signature_key is not in the MetaGraphDef.
input_shapes does not match the length of input_arrays.
input_arrays or output_arrays are not valid.
- Unable to load Session.
"""
- # Set default values for inputs if they are set to None.
- if signature_key is None:
- signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
- if tag_set is None:
- tag_set = set([tag_constants.SERVING])
-
# Read SignatureDef.
meta_graph = _get_meta_graph_def(saved_model_dir, tag_set)
signature_def = _get_signature_def(meta_graph, signature_key)
@@ -255,19 +264,10 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
# TODO(zhixianyan): Use TFLite supported Op list to filter outputs.
in_tensors = _get_tensors(graph, inputs, input_arrays)
out_tensors = _get_tensors(graph, outputs, output_arrays)
-
- # Gets fully defined tensor shape.
- for tensor in in_tensors:
- if (input_shapes and tensor.name in input_shapes and
- input_shapes[tensor.name] is not None):
- shape = input_shapes[tensor.name]
- else:
- shape = tensor.get_shape().as_list()
- tensor.set_shape(shape)
+ set_tensor_shapes(in_tensors, input_shapes)
output_names = [node.split(":")[0] for node in outputs]
frozen_graph_def = tf_graph_util.convert_variables_to_constants(
sess, graph.as_graph_def(), output_names)
return frozen_graph_def, in_tensors, out_tensors
- raise ValueError("Unable to load Session.")
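Hypothetical usage of the two new public helpers (graph and names invented for illustration). Per the tests below, get_tensors_from_tensor_names takes names without the ":0" output suffix, while set_tensor_shapes keys on the full tensor name:

    from tensorflow.contrib.lite.python import convert_saved_model
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import array_ops

    graph = ops.Graph()
    with graph.as_default():
        array_ops.placeholder(shape=[None, 16], dtype=dtypes.float32,
                              name="input")

    tensors = convert_saved_model.get_tensors_from_tensor_names(graph, ["input"])
    convert_saved_model.set_tensor_shapes(tensors, {"input:0": [1, 16]})
    assert tensors[0].shape.as_list() == [1, 16]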
diff --git a/tensorflow/contrib/lite/python/convert_saved_model_test.py b/tensorflow/contrib/lite/python/convert_saved_model_test.py
index f69381d0e6..80e5dc6e46 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model_test.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model_test.py
@@ -41,9 +41,58 @@ from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import training as train
+class TensorFunctionsTest(test_util.TensorFlowTestCase):
+
+ def testGetTensorsValid(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ tensors = convert_saved_model.get_tensors_from_tensor_names(
+ sess.graph, ["Placeholder"])
+ self.assertEqual("Placeholder:0", tensors[0].name)
+
+ def testGetTensorsInvalid(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ with self.assertRaises(ValueError) as error:
+ convert_saved_model.get_tensors_from_tensor_names(sess.graph,
+ ["invalid-input"])
+ self.assertEqual("Invalid tensors 'invalid-input' were found.",
+ str(error.exception))
+
+ def testSetTensorShapeValid(self):
+ tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
+ self.assertEqual([None, 3, 5], tensor.shape.as_list())
+
+ convert_saved_model.set_tensor_shapes([tensor],
+ {"Placeholder:0": [5, 3, 5]})
+ self.assertEqual([5, 3, 5], tensor.shape.as_list())
+
+ def testSetTensorShapeInvalid(self):
+ tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
+ self.assertEqual([None, 3, 5], tensor.shape.as_list())
+
+ convert_saved_model.set_tensor_shapes([tensor],
+ {"invalid-input": [5, 3, 5]})
+ self.assertEqual([None, 3, 5], tensor.shape.as_list())
+
+ def testSetTensorShapeEmpty(self):
+ tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
+ self.assertEqual([None, 3, 5], tensor.shape.as_list())
+
+ convert_saved_model.set_tensor_shapes([tensor], {})
+ self.assertEqual([None, 3, 5], tensor.shape.as_list())
+
+
class FreezeSavedModelTest(test_util.TensorFlowTestCase):
def _createSimpleSavedModel(self, shape):
@@ -93,6 +142,10 @@ class FreezeSavedModelTest(test_util.TensorFlowTestCase):
output_arrays=None,
tag_set=None,
signature_key=None):
+ if tag_set is None:
+ tag_set = set([tag_constants.SERVING])
+ if signature_key is None:
+ signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
graph_def, in_tensors, out_tensors = convert_saved_model.freeze_saved_model(
saved_model_dir=saved_model_dir,
input_arrays=input_arrays,
@@ -390,7 +443,7 @@ class FreezeSavedModelTestTrainGraph(test_util.TensorFlowTestCase):
input_arrays=None,
input_shapes=None,
output_arrays=["Softmax"],
- tag_set=None,
+ tag_set=set([tag_constants.SERVING]),
signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
self.assertTrue(result)
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index f7f2d40a02..d595415b63 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -33,15 +33,24 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from six import PY3
+
+from google.protobuf import text_format as _text_format
+from google.protobuf.message import DecodeError
from tensorflow.contrib.lite.python import lite_constants as constants
from tensorflow.contrib.lite.python.convert import tensor_name
from tensorflow.contrib.lite.python.convert import toco_convert
from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import
from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model
+from tensorflow.contrib.lite.python.convert_saved_model import get_tensors_from_tensor_names
+from tensorflow.contrib.lite.python.convert_saved_model import set_tensor_shapes
from tensorflow.contrib.lite.python.interpreter import Interpreter # pylint: disable=unused-import
from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
from tensorflow.contrib.lite.python.op_hint import OpHint # pylint: disable=unused-import
+from tensorflow.core.framework import graph_pb2 as _graph_pb2
+from tensorflow.python.client import session as _session
from tensorflow.python.framework import graph_util as tf_graph_util
+from tensorflow.python.framework.importer import import_graph_def
from tensorflow.python.ops.variables import global_variables_initializer
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
@@ -55,26 +64,50 @@ class TocoConverter(object):
Attributes:
- inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`.
- (default FLOAT)
- output_format: Type of data to write (currently must be TFLITE or
- GRAPHVIZ_DOT). (default TFLITE)
- quantized_input_stats: The mean and std deviation of training data for each
- input tensor. Only needed if `inference_type` is `QUANTIZED_UINT8`.
- (default None)
+ inference_type: Target data type of arrays in the output file. Currently
+ must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT)
+ inference_input_type: Target data type of input arrays. Allows for a
+ different type for input arrays in the case of quantization. Currently
+ must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`)
+ output_format: Output file format. Currently must be `{TFLITE,
+ GRAPHVIZ_DOT}`. (default TFLITE)
+ quantized_input_stats: Dict of strings representing input tensor names
+ mapped to tuples of integers representing the mean and standard deviation
+ of the training data (e.g., {"foo" : (0., 1.)}). Only needed if
+ `inference_type` is `QUANTIZED_UINT8`. (default {})
+ default_ranges_stats: Tuple of integers representing (min, max) range values
+ for all arrays without a specified range. Intended for experimenting with
+ quantization via "dummy quantization". (default None)
drop_control_dependency: Boolean indicating whether to drop control
dependencies silently. This is due to TFLite not supporting control
dependencies. (default True)
+ reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant
+ nodes in unexpected locations. Used when the location of the FakeQuant
+ nodes is preventing graph transformations necessary to convert the graph.
+ Results in a graph that differs from the quantized training graph,
+ potentially causing differing arithmetic behavior. (default False)
+ change_concat_input_ranges: Boolean to change the behavior of min/max
+ ranges for the inputs and outputs of the concat operator for quantized
+ models. When true, the ranges of the concat operator's inputs and
+ outputs are changed so that they overlap. (default False)
allow_custom_ops: Boolean indicating whether to allow custom operations.
+ When false, any unknown operation is an error. When true, custom ops are
+ created for any op that is unknown. The developer will need to provide
+ these to the TensorFlow Lite runtime with a custom resolver.
(default False)
Example usage:
- # Converting a frozen graph.
+ # Converting a GraphDef from session.
converter = lite.TocoConverter.from_session(sess, in_tensors, out_tensors)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
+ # Converting a GraphDef from file.
+ converter = lite.TocoConverter.from_frozen_graph(
+ graph_def_file, input_arrays, output_arrays)
+ tflite_model = converter.convert()
+ open("converted_model.tflite", "wb").write(tflite_model)
+
# Converting a SavedModel.
converter = lite.TocoConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
@@ -94,17 +127,17 @@ class TocoConverter(object):
self._input_tensors = input_tensors
self._output_tensors = output_tensors
self.inference_type = constants.FLOAT
+ self.inference_input_type = None
self.output_format = constants.TFLITE
- self.quantized_input_stats = None
+ self.quantized_input_stats = {}
+ self.default_ranges_stats = None
self.drop_control_dependency = True
+ self.reorder_across_fake_quant = False
+ self.change_concat_input_ranges = False
self.allow_custom_ops = False
@classmethod
- def from_session(cls,
- sess,
- input_tensors,
- output_tensors,
- freeze_variables=False):
+ def from_session(cls, sess, input_tensors, output_tensors):
"""Creates a TocoConverter class from a TensorFlow Session.
Args:
@@ -112,56 +145,108 @@ class TocoConverter(object):
input_tensors: List of input tensors. Type and shape are computed using
`foo.get_shape()` and `foo.dtype`.
output_tensors: List of output tensors (only .name is used from this).
- freeze_variables: Boolean indicating whether the variables need to be
- converted into constants via the freeze_graph.py script.
- (default False)
Returns:
TocoConverter class.
"""
+ graph_def = _freeze_graph(sess, output_tensors)
+ return cls(graph_def, input_tensors, output_tensors)
- # Get GraphDef.
- if freeze_variables:
+ @classmethod
+ def from_frozen_graph(cls,
+ graph_def_file,
+ input_arrays,
+ output_arrays,
+ input_shapes=None):
+ """Creates a TocoConverter class from a file containing a frozen GraphDef.
+
+ Args:
+ graph_def_file: Full filepath of file containing TensorFlow GraphDef.
+ input_arrays: List of input tensors to freeze graph with.
+ output_arrays: List of output tensors to freeze graph with.
+ input_shapes: Dict of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
+ Automatically determined when an input shape is None (e.g., {"foo" :
+ None}). (default None)
+
+ Returns:
+ TocoConverter class.
+
+ Raises:
+ ValueError:
+ Unable to parse input file.
+ The graph is not frozen.
+ input_arrays or output_arrays contains an invalid tensor name.
+ """
+ with _session.Session() as sess:
sess.run(global_variables_initializer())
- output_arrays = [tensor_name(tensor) for tensor in output_tensors]
- graph_def = tf_graph_util.convert_variables_to_constants(
- sess, sess.graph_def, output_arrays)
- else:
- graph_def = sess.graph_def
- # Create TocoConverter class.
- return cls(graph_def, input_tensors, output_tensors)
+ # Read GraphDef from file.
+ graph_def = _graph_pb2.GraphDef()
+ with open(graph_def_file, "rb") as f:
+ file_content = f.read()
+ try:
+ graph_def.ParseFromString(file_content)
+ except (_text_format.ParseError, DecodeError):
+ try:
+ print("Ignore 'tcmalloc: large alloc' warnings.")
+
+ if not isinstance(file_content, str):
+ if PY3:
+ file_content = file_content.decode('utf-8')
+ else:
+ file_content = file_content.encode('utf-8')
+ _text_format.Merge(file_content, graph_def)
+ except (_text_format.ParseError, DecodeError):
+ raise ValueError(
+ "Unable to parse input file '{}'.".format(graph_def_file))
+ sess.graph.as_default()
+ import_graph_def(graph_def, name="")
+
+ # Get input and output tensors.
+ input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays)
+ output_tensors = get_tensors_from_tensor_names(sess.graph, output_arrays)
+ set_tensor_shapes(input_tensors, input_shapes)
+
+ # Check if graph is frozen.
+ if not _is_frozen_graph(sess):
+ raise ValueError("Please freeze the graph using freeze_graph.py")
+
+ # Create TocoConverter class.
+ return cls(sess.graph_def, input_tensors, output_tensors)
@classmethod
- def from_saved_model(
- cls,
- saved_model_dir,
- input_arrays=None,
- input_shapes=None,
- output_arrays=None,
- tag_set=None,
- signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY):
+ def from_saved_model(cls,
+ saved_model_dir,
+ input_arrays=None,
+ input_shapes=None,
+ output_arrays=None,
+ tag_set=None,
+ signature_key=None):
"""Creates a TocoConverter class from a SavedModel.
Args:
saved_model_dir: SavedModel directory to convert.
input_arrays: List of input tensors to freeze graph with. Uses input
arrays from SignatureDef when none are provided. (default None)
- input_shapes: Map of strings representing input tensor names to list of
- integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}).
+ input_shapes: Dict of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
Automatically determined when input shapes is None (e.g., {"foo" :
None}). (default None)
output_arrays: List of output tensors to freeze graph with. Uses output
arrays from SignatureDef when none are provided. (default None)
tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
- analyze. All tags in the tag set must be present. (default "serve")
+ analyze. All tags in the tag set must be present. (default set("serve"))
signature_key: Key identifying SignatureDef containing inputs and outputs.
+ (default DEFAULT_SERVING_SIGNATURE_DEF_KEY)
Returns:
TocoConverter class.
"""
if tag_set is None:
tag_set = set([tag_constants.SERVING])
+ if signature_key is None:
+ signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
result = freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
output_arrays, tag_set, signature_key)
@@ -189,16 +274,39 @@ class TocoConverter(object):
elif shape[0] is None:
self._set_batch_size(batch_size=1)
+ # Get quantization stats. Ensures there is one stat per name if the stats
+ # are specified.
+ if self.quantized_input_stats:
+ quantized_stats = []
+ invalid_stats = []
+ for tensor in self._input_tensors:
+ name = tensor_name(tensor)
+ if name in self.quantized_input_stats:
+ quantized_stats.append(self.quantized_input_stats[name])
+ else:
+ invalid_stats.append(name)
+
+ if invalid_stats:
+ raise ValueError("Quantization input stats are not available for input "
+ "tensors '{0}'.".format(",".join(invalid_stats)))
+ else:
+ quantized_stats = None
+
# Converts model.
result = toco_convert(
input_data=self._graph_def,
input_tensors=self._input_tensors,
output_tensors=self._output_tensors,
inference_type=self.inference_type,
+ inference_input_type=self.inference_input_type,
input_format=constants.TENSORFLOW_GRAPHDEF,
output_format=self.output_format,
- quantized_input_stats=self.quantized_input_stats,
- drop_control_dependency=self.drop_control_dependency)
+ quantized_input_stats=quantized_stats,
+ default_ranges_stats=self.default_ranges_stats,
+ drop_control_dependency=self.drop_control_dependency,
+ reorder_across_fake_quant=self.reorder_across_fake_quant,
+ change_concat_input_ranges=self.change_concat_input_ranges,
+ allow_custom_ops=self.allow_custom_ops)
return result
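
Since `quantized_input_stats` is now a dict keyed by input tensor name, a quantized conversion is configured as in the sketch below (tensor names are placeholders mirroring the unit tests in this change):

```python
# Sketch of the dict-based quantized_input_stats; "inputA"/"inputB" are
# placeholder names mirroring the unit tests in this change.
import tensorflow as tf

in_a = tf.placeholder(tf.float32, [1, 16, 16, 3], name="inputA")
in_b = tf.placeholder(tf.float32, [1, 16, 16, 3], name="inputB")
out = tf.fake_quant_with_min_max_args(in_a + in_b, min=0., max=1.)

with tf.Session() as sess:
  converter = tf.contrib.lite.TocoConverter.from_session(
      sess, [in_a, in_b], [out])
  converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8
  # One (mean, std_dev) entry per input name; a missing entry now raises
  # the ValueError constructed above.
  converter.quantized_input_stats = {"inputA": (0., 1.), "inputB": (0., 1.)}
  tflite_model = converter.convert()
```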
def _set_batch_size(self, batch_size):
@@ -212,3 +320,43 @@ class TocoConverter(object):
shape = tensor.get_shape().as_list()
shape[0] = batch_size
tensor.set_shape(shape)
+
+
+def _is_frozen_graph(sess):
+ """Determines if the graph is frozen.
+
+ Determines if a graph has previously been frozen by checking for any
+ operations of type Variable*. If variables are found, the graph is not frozen.
+
+ Args:
+ sess: TensorFlow Session.
+
+ Returns:
+ Bool.
+ """
+ for op in sess.graph.get_operations():
+ if op.type.startswith("Variable"):
+ return False
+ return True
+
+
+def _freeze_graph(sess, output_tensors):
+ """Returns a frozen GraphDef.
+
+ Freezes a graph with Variables in it. Otherwise the existing GraphDef is
+ returned.
+
+ Args:
+ sess: TensorFlow Session.
+ output_tensors: List of output tensors (only .name is used from this).
+
+ Returns:
+ Frozen GraphDef.
+ """
+ if not _is_frozen_graph(sess):
+ sess.run(global_variables_initializer())
+ output_arrays = [tensor_name(tensor) for tensor in output_tensors]
+ return tf_graph_util.convert_variables_to_constants(sess, sess.graph_def,
+ output_arrays)
+ else:
+ return sess.graph_def
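
A quick sketch of how the two helpers behave; `_is_frozen_graph` and `_freeze_graph` are internal to this module, imported here only for illustration:

```python
# Illustrative use of the internal freeze helpers defined above.
import tensorflow as tf
from tensorflow.contrib.lite.python.lite import _freeze_graph, _is_frozen_graph

img = tf.placeholder(tf.float32, [1, 4], name="img")
w = tf.get_variable("w", shape=[4, 2], dtype=tf.float32)
out = tf.matmul(img, w, name="out")

with tf.Session() as sess:
  print(_is_frozen_graph(sess))  # False: the graph contains Variable* ops
  frozen = _freeze_graph(sess, [out])  # initializes and folds variables
  print(any(n.op.startswith("Variable") for n in frozen.node))  # False
```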
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index 2f3105f3e6..53d1878293 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -29,8 +29,10 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
+from tensorflow.python.training.training_util import write_graph
class FromSessionTest(test_util.TensorFlowTestCase):
@@ -65,16 +67,22 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertEqual((0., 0.), output_details[0]['quantization'])
def testQuantization(self):
- in_tensor = array_ops.placeholder(
- shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input')
+ in_tensor_1 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+ in_tensor_2 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
out_tensor = array_ops.fake_quant_with_min_max_args(
- in_tensor + in_tensor, min=0., max=1., name='output')
+ in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TocoConverter.from_session(
+ sess, [in_tensor_1, in_tensor_2], [out_tensor])
converter.inference_type = lite_constants.QUANTIZED_UINT8
- converter.quantized_input_stats = [(0., 1.)] # mean, std_dev
+ converter.quantized_input_stats = {
+ 'inputA': (0., 1.),
+ 'inputB': (0., 1.)
+ } # mean, std_dev
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -83,13 +91,19 @@ class FromSessionTest(test_util.TensorFlowTestCase):
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
- self.assertEqual(1, len(input_details))
- self.assertEqual('input', input_details[0]['name'])
+ self.assertEqual(2, len(input_details))
+ self.assertEqual('inputA', input_details[0]['name'])
self.assertEqual(np.uint8, input_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
self.assertEqual((1., 0.),
input_details[0]['quantization']) # scale, zero_point
+ self.assertEqual('inputB', input_details[1]['name'])
+ self.assertEqual(np.uint8, input_details[1]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
+ self.assertEqual((1., 0.),
+ input_details[1]['quantization']) # scale, zero_point
+
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('output', output_details[0]['name'])
@@ -97,6 +111,26 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertTrue(output_details[0]['quantization'][0] > 0) # scale
+ def testQuantizationInvalid(self):
+ in_tensor_1 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+ in_tensor_2 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
+ out_tensor = array_ops.fake_quant_with_min_max_args(
+ in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(
+ sess, [in_tensor_1, in_tensor_2], [out_tensor])
+ converter.inference_type = lite_constants.QUANTIZED_UINT8
+ converter.quantized_input_stats = {'inputA': (0., 1.)} # mean, std_dev
+ with self.assertRaises(ValueError) as error:
+ converter.convert()
+ self.assertEqual(
+ 'Quantization input stats are not available for input tensors '
+ '\'inputB\'.', str(error.exception))
+
def testBatchSizeInvalid(self):
in_tensor = array_ops.placeholder(
shape=[None, 16, 16, 3], dtype=dtypes.float32)
@@ -152,8 +186,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(
- sess, [in_tensor], [out_tensor], freeze_variables=True)
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -187,6 +220,196 @@ class FromSessionTest(test_util.TensorFlowTestCase):
graphviz_output = converter.convert()
self.assertTrue(graphviz_output)
+ def testInferenceInputType(self):
+ in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.uint8)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter.inference_input_type = lite_constants.QUANTIZED_UINT8
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('Placeholder', input_details[0]['name'])
+ self.assertEqual(np.uint8, input_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('add', output_details[0]['name'])
+ self.assertEqual(np.uint8, output_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ def testDefaultRangesStats(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter.inference_type = lite_constants.QUANTIZED_UINT8
+ converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev
+ converter.default_ranges_stats = (0, 6) # min, max
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('Placeholder', input_details[0]['name'])
+ self.assertEqual(np.uint8, input_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+ self.assertEqual((1., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('add', output_details[0]['name'])
+ self.assertEqual(np.uint8, output_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+ self.assertTrue(output_details[0]['quantization'][0] > 0) # scale
+
+
+class FromFlatbufferFile(test_util.TensorFlowTestCase):
+
+ def testFloat(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Write graph to file.
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
+ write_graph(sess.graph_def, '', graph_def_file, False)
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
+ ['Placeholder'], ['add'])
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('Placeholder', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('add', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testFloatWithShapesArray(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Write graph to file.
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
+ write_graph(sess.graph_def, '', graph_def_file, False)
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_frozen_graph(
+ graph_def_file, ['Placeholder'], ['add'],
+ input_shapes={'Placeholder': [1, 16, 16, 3]})
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+
+ def testFreezeGraph(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ var = variable_scope.get_variable(
+ 'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + var
+ sess = session.Session()
+
+ # Write graph to file.
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
+ write_graph(sess.graph_def, '', graph_def_file, False)
+
+ # Ensure the graph with variables cannot be converted.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
+ ['add'])
+ self.assertEqual('Please freeze the graph using freeze_graph.py',
+ str(error.exception))
+
+ def testPbtxt(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Write graph to file.
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt')
+ write_graph(sess.graph_def, '', graph_def_file, True)
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
+ ['Placeholder'], ['add'])
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('Placeholder', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('add', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testInvalidFile(self):
+ graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file')
+ with gfile.Open(graph_def_file, 'wb') as temp_file:
+ temp_file.write('bad data')
+ temp_file.flush()
+
+ # Attempts to convert the invalid model.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
+ ['add'])
+ self.assertEqual(
+ 'Unable to parse input file \'{}\'.'.format(graph_def_file),
+ str(error.exception))
+
class FromSavedModelTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
new file mode 100644
index 0000000000..337f05785e
--- /dev/null
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -0,0 +1,324 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python command line interface for running TOCO."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import sys
+
+from tensorflow.contrib.lite.python import lite
+from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
+from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2
+from tensorflow.python.platform import app
+
+
+def _parse_array(values):
+ if values:
+ return values.split(",")
+
+
+def _parse_int_array(values):
+ if values:
+ return [int(val) for val in values.split(",")]
+
+
+def _parse_set(values):
+ if values:
+ return set(values.split(","))
+
+
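
For reference, these helpers decompose the comma/colon-separated flag strings as in the sketch below (values are hypothetical):

```python
# Hypothetical flag values run through the parsing helpers above.
input_arrays = _parse_array("img,label")  # ['img', 'label']
shapes = [_parse_int_array(s) for s in "1,224,224,3:1,5".split(":")]
input_shapes = dict(zip(input_arrays, shapes))
# {'img': [1, 224, 224, 3], 'label': [1, 5]}
tags = _parse_set("serve,gpu")  # {'serve', 'gpu'}
```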
+def _get_toco_converter(flags):
+ """Makes a TocoConverter object based on the flags provided.
+
+ Args:
+ flags: argparse.Namespace object containing TFLite flags.
+
+ Returns:
+ TocoConverter object.
+ """
+ # Parse input and output arrays.
+ input_arrays = _parse_array(flags.input_arrays)
+ input_shapes = None
+ if flags.input_shapes:
+ input_shapes_list = [
+ _parse_int_array(shape) for shape in flags.input_shapes.split(":")
+ ]
+ input_shapes = dict(zip(input_arrays, input_shapes_list))
+ output_arrays = _parse_array(flags.output_arrays)
+
+ converter_kwargs = {
+ "input_arrays": input_arrays,
+ "input_shapes": input_shapes,
+ "output_arrays": output_arrays
+ }
+
+ # Create TocoConverter.
+ if flags.graph_def_file:
+ converter_fn = lite.TocoConverter.from_frozen_graph
+ converter_kwargs["graph_def_file"] = flags.graph_def_file
+ elif flags.saved_model_dir:
+ converter_fn = lite.TocoConverter.from_saved_model
+ converter_kwargs["saved_model_dir"] = flags.saved_model_dir
+ converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set)
+ converter_kwargs["signature_key"] = flags.saved_model_signature_key
+
+ return converter_fn(**converter_kwargs)
+
+
+def _convert_model(flags):
+ """Calls function to convert the TensorFlow model into a TFLite model.
+
+ Args:
+ flags: argparse.Namespace object.
+ """
+ # Create converter.
+ converter = _get_toco_converter(flags)
+ if flags.inference_type:
+ converter.inference_type = _types_pb2.IODataType.Value(flags.inference_type)
+ if flags.inference_input_type:
+ converter.inference_input_type = _types_pb2.IODataType.Value(
+ flags.inference_input_type)
+ if flags.output_format:
+ converter.output_format = _toco_flags_pb2.FileFormat.Value(
+ flags.output_format)
+
+ if flags.mean_values and flags.std_dev_values:
+ input_arrays = _parse_array(flags.input_arrays)
+ std_dev_values = _parse_int_array(flags.std_dev_values)
+ mean_values = _parse_int_array(flags.mean_values)
+ quant_stats = zip(mean_values, std_dev_values)
+ converter.quantized_input_stats = dict(zip(input_arrays, quant_stats))
+ if flags.default_ranges_min and flags.default_ranges_max:
+ converter.default_ranges_stats = (flags.default_ranges_min,
+ flags.default_ranges_max)
+
+ if flags.drop_control_dependency:
+ converter.drop_control_dependency = flags.drop_control_dependency
+ if flags.reorder_across_fake_quant:
+ converter.reorder_across_fake_quant = flags.reorder_across_fake_quant
+ if flags.change_concat_input_ranges:
+ converter.change_concat_input_ranges = flags.change_concat_input_ranges
+ if flags.allow_custom_ops:
+ converter.allow_custom_ops = flags.allow_custom_ops
+
+ # Convert model.
+ output_data = converter.convert()
+ with open(flags.output_file, "wb") as f:
+ f.write(output_data)
+
+
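
`_convert_model` can also be driven directly with an `argparse.Namespace`, mirroring what `run_main` does after flag checking. A sketch with hypothetical paths; every flag attribute must exist (as `None`) because the function reads each one:

```python
# Sketch of invoking _convert_model programmatically; paths are hypothetical.
import argparse

flags = argparse.Namespace(
    graph_def_file="frozen_graph.pb", saved_model_dir=None,
    output_file="model.tflite", output_format=None,
    input_arrays="input", input_shapes="1,224,224,3",
    output_arrays="MobilenetV1/Predictions/Softmax",
    inference_type=None, inference_input_type=None,
    std_dev_values=None, mean_values=None,
    default_ranges_min=None, default_ranges_max=None,
    drop_control_dependency=None, reorder_across_fake_quant=None,
    change_concat_input_ranges=None, allow_custom_ops=None,
    saved_model_tag_set=None, saved_model_signature_key=None)
_convert_model(flags)  # writes model.tflite
```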
+def _check_flags(flags, unparsed):
+ """Checks the parsed and unparsed flags to ensure they are valid.
+
+  Raises an error if previously supported unparsed flags are found. Raises an
+ error for parsed flags that don't meet the required conditions.
+
+ Args:
+ flags: argparse.Namespace object containing TFLite flags.
+ unparsed: List of unparsed flags.
+
+ Raises:
+ ValueError: Invalid flags.
+ """
+
+ # Check unparsed flags for common mistakes based on previous TOCO.
+ def _get_message_unparsed(flag, orig_flag, new_flag):
+ if flag.startswith(orig_flag):
+ return "\n Use {0} instead of {1}".format(new_flag, orig_flag)
+ return ""
+
+ if unparsed:
+ output = ""
+ for flag in unparsed:
+ output += _get_message_unparsed(flag, "--input_file", "--graph_def_file")
+ output += _get_message_unparsed(flag, "--std_value", "--std_dev_values")
+ output += _get_message_unparsed(flag, "--batch_size", "--input_shapes")
+ raise ValueError(output)
+
+ # Check that flags are valid.
+ if flags.graph_def_file and (not flags.input_arrays or
+ not flags.output_arrays):
+ raise ValueError("--input_arrays and --output_arrays are required with "
+ "--graph_def_file")
+
+ if flags.input_shapes:
+ if not flags.input_arrays:
+ raise ValueError("--input_shapes must be used with --input_arrays")
+ if flags.input_shapes.count(":") != flags.input_arrays.count(","):
+ raise ValueError("--input_shapes and --input_arrays must have the same "
+ "number of items")
+
+ if flags.std_dev_values or flags.mean_values:
+ if bool(flags.std_dev_values) != bool(flags.mean_values):
+ raise ValueError("--std_dev_values and --mean_values must be used "
+ "together")
+ if not flags.input_arrays:
+ raise ValueError("--std_dev_values and --mean_values must be used with "
+ "--input_arrays")
+ if (flags.std_dev_values.count(",") != flags.mean_values.count(",") or
+ flags.std_dev_values.count(",") != flags.input_arrays.count(",")):
+ raise ValueError("--std_dev_values, --mean_values, and --input_arrays "
+ "must have the same number of items")
+
+ if bool(flags.default_ranges_min) != bool(flags.default_ranges_max):
+ raise ValueError("--default_ranges_min and --default_ranges_max must be "
+ "used together")
+
+
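
A sketch of the pairing check firing; the `Namespace` below carries only the fields `_check_flags` reads:

```python
# std_dev_values without mean_values trips the pairing check above.
import argparse

bad = argparse.Namespace(
    graph_def_file="g.pb", input_arrays="img", output_arrays="out",
    input_shapes=None, std_dev_values="127", mean_values=None,
    default_ranges_min=None, default_ranges_max=None)
try:
  _check_flags(bad, unparsed=[])
except ValueError as e:
  print(e)  # --std_dev_values and --mean_values must be used together
```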
+def run_main(_):
+  """Main function for tflite_convert.py."""
+ parser = argparse.ArgumentParser(
+ description=("Command line tool to run TensorFlow Lite Optimizing "
+ "Converter (TOCO)."))
+
+ # Output file flag.
+ parser.add_argument(
+ "--output_file",
+ type=str,
+ help="Full filepath of the output file.",
+ required=True)
+
+ # Input file flags.
+ input_file_group = parser.add_mutually_exclusive_group(required=True)
+ input_file_group.add_argument(
+ "--graph_def_file",
+ type=str,
+ help="Full filepath of file containing TensorFlow GraphDef.")
+ input_file_group.add_argument(
+ "--saved_model_dir",
+ type=str,
+ help="Full filepath of directory containing the SavedModel.")
+
+ # Model format flags.
+ parser.add_argument(
+ "--output_format",
+ type=str,
+ choices=["TFLITE", "GRAPHVIZ_DOT"],
+ help="Output file format.")
+ parser.add_argument(
+ "--inference_type",
+ type=str,
+ choices=["FLOAT", "QUANTIZED_UINT8"],
+ help="Target data type of arrays in the output file.")
+ parser.add_argument(
+ "--inference_input_type",
+ type=str,
+ choices=["FLOAT", "QUANTIZED_UINT8"],
+ help=("Target data type of input arrays. Allows for a different type for "
+ "input arrays in the case of quantization."))
+
+ # Input and output arrays flags.
+ parser.add_argument(
+ "--input_arrays",
+ type=str,
+      help="Names of the input arrays, comma-separated.")
+ parser.add_argument(
+ "--input_shapes",
+ type=str,
+ help="Shapes corresponding to --input_arrays, colon-separated.")
+ parser.add_argument(
+ "--output_arrays",
+ type=str,
+ help="Names of the output arrays, comma-separated.")
+
+ # SavedModel related flags.
+ parser.add_argument(
+ "--saved_model_tag_set",
+ type=str,
+ help=("Comma-separated set of tags identifying the MetaGraphDef within "
+ "the SavedModel to analyze. All tags must be present. "
+ "(default \"serve\")"))
+ parser.add_argument(
+ "--saved_model_signature_key",
+ type=str,
+ help=("Key identifying the SignatureDef containing inputs and outputs. "
+ "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)"))
+
+ # Quantization flags.
+ parser.add_argument(
+ "--std_dev_values",
+ type=str,
+ help=("Standard deviation of training data for each input tensor, "
+ "comma-separated. Used for quantization. (default None)"))
+ parser.add_argument(
+ "--mean_values",
+ type=str,
+ help=("Mean of training data for each input tensor, comma-separated. "
+ "Used for quantization. (default None)"))
+ parser.add_argument(
+ "--default_ranges_min",
+ type=int,
+ help=("Default value for min bound of min/max range values used for all "
+            "arrays without a specified range. Intended for experimenting with "
+ "quantization via \"dummy quantization\". (default None)"))
+ parser.add_argument(
+ "--default_ranges_max",
+ type=int,
+ help=("Default value for max bound of min/max range values used for all "
+            "arrays without a specified range. Intended for experimenting with "
+ "quantization via \"dummy quantization\". (default None)"))
+
+ # Graph manipulation flags.
+ parser.add_argument(
+ "--drop_control_dependency",
+ type=bool,
+ help=("Boolean indicating whether to drop control dependencies silently. "
+            "This is due to TensorFlow Lite not supporting control dependencies. "
+ "(default True)"))
+ parser.add_argument(
+ "--reorder_across_fake_quant",
+ type=bool,
+ help=("Boolean indicating whether to reorder FakeQuant nodes in "
+ "unexpected locations. Used when the location of the FakeQuant "
+ "nodes is preventing graph transformations necessary to convert "
+ "the graph. Results in a graph that differs from the quantized "
+ "training graph, potentially causing differing arithmetic "
+ "behavior. (default False)"))
+ parser.add_argument(
+ "--change_concat_input_ranges",
+ type=bool,
+ help=("Boolean to change behavior of min/max ranges for inputs and "
+            "outputs of the concat operator for quantized models. When true, "
+            "the ranges of the concat operator's inputs and outputs are made "
+            "to overlap. (default False)"))
+ parser.add_argument(
+ "--allow_custom_ops",
+ type=bool,
+      help=("Boolean indicating whether to allow custom operations. When false, "
+            "any unknown operation is an error. When true, custom ops are "
+ "created for any op that is unknown. The developer will need to "
+ "provide these to the TensorFlow Lite runtime with a custom "
+ "resolver. (default False)"))
+
+ tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:])
+ try:
+ _check_flags(tflite_flags, unparsed)
+ except ValueError as e:
+ parser.print_usage()
+ file_name = os.path.basename(sys.argv[0])
+ sys.stderr.write("{0}: error: {1}\n".format(file_name, str(e)))
+ sys.exit(1)
+ _convert_model(tflite_flags)
+
+
+def main():
+ app.run(main=run_main, argv=sys.argv[:1])
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 8bdeb035f5..522eac25b3 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -145,6 +145,7 @@ enum BuiltinOperator : byte {
SLICE = 65,
SIN = 66,
TRANSPOSE_CONV = 67,
+ SPARSE_TO_DENSE = 68,
}
// Options for the builtin operators.
@@ -198,6 +199,7 @@ union BuiltinOptions {
SelectOptions,
SliceOptions,
TransposeConvOptions,
+ SparseToDenseOptions,
}
enum Padding : byte { SAME, VALID }
@@ -450,6 +452,10 @@ table TransposeConvOptions {
stride_h:int;
}
+table SparseToDenseOptions {
+ validate_indices:bool;
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 35c34f53a6..746dd26796 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -178,6 +178,9 @@ struct SliceOptionsT;
struct TransposeConvOptions;
struct TransposeConvOptionsT;
+struct SparseToDenseOptions;
+struct SparseToDenseOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -305,11 +308,12 @@ enum BuiltinOperator {
BuiltinOperator_SLICE = 65,
BuiltinOperator_SIN = 66,
BuiltinOperator_TRANSPOSE_CONV = 67,
+ BuiltinOperator_SPARSE_TO_DENSE = 68,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_TRANSPOSE_CONV
+ BuiltinOperator_MAX = BuiltinOperator_SPARSE_TO_DENSE
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[67] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[68] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -377,7 +381,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[67] {
BuiltinOperator_SELECT,
BuiltinOperator_SLICE,
BuiltinOperator_SIN,
- BuiltinOperator_TRANSPOSE_CONV
+ BuiltinOperator_TRANSPOSE_CONV,
+ BuiltinOperator_SPARSE_TO_DENSE
};
return values;
}
@@ -452,6 +457,7 @@ inline const char **EnumNamesBuiltinOperator() {
"SLICE",
"SIN",
"TRANSPOSE_CONV",
+ "SPARSE_TO_DENSE",
nullptr
};
return names;
@@ -513,11 +519,12 @@ enum BuiltinOptions {
BuiltinOptions_SelectOptions = 47,
BuiltinOptions_SliceOptions = 48,
BuiltinOptions_TransposeConvOptions = 49,
+ BuiltinOptions_SparseToDenseOptions = 50,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_TransposeConvOptions
+ BuiltinOptions_MAX = BuiltinOptions_SparseToDenseOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[50] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[51] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -568,7 +575,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[50] {
BuiltinOptions_LessEqualOptions,
BuiltinOptions_SelectOptions,
BuiltinOptions_SliceOptions,
- BuiltinOptions_TransposeConvOptions
+ BuiltinOptions_TransposeConvOptions,
+ BuiltinOptions_SparseToDenseOptions
};
return values;
}
@@ -625,6 +633,7 @@ inline const char **EnumNamesBuiltinOptions() {
"SelectOptions",
"SliceOptions",
"TransposeConvOptions",
+ "SparseToDenseOptions",
nullptr
};
return names;
@@ -835,6 +844,10 @@ template<> struct BuiltinOptionsTraits<TransposeConvOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_TransposeConvOptions;
};
+template<> struct BuiltinOptionsTraits<SparseToDenseOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_SparseToDenseOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1258,6 +1271,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_TransposeConvOptions ?
reinterpret_cast<const TransposeConvOptionsT *>(value) : nullptr;
}
+ SparseToDenseOptionsT *AsSparseToDenseOptions() {
+ return type == BuiltinOptions_SparseToDenseOptions ?
+ reinterpret_cast<SparseToDenseOptionsT *>(value) : nullptr;
+ }
+ const SparseToDenseOptionsT *AsSparseToDenseOptions() const {
+ return type == BuiltinOptions_SparseToDenseOptions ?
+ reinterpret_cast<const SparseToDenseOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -4543,6 +4564,60 @@ inline flatbuffers::Offset<TransposeConvOptions> CreateTransposeConvOptions(
flatbuffers::Offset<TransposeConvOptions> CreateTransposeConvOptions(flatbuffers::FlatBufferBuilder &_fbb, const TransposeConvOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct SparseToDenseOptionsT : public flatbuffers::NativeTable {
+ typedef SparseToDenseOptions TableType;
+ bool validate_indices;
+ SparseToDenseOptionsT()
+ : validate_indices(false) {
+ }
+};
+
+struct SparseToDenseOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef SparseToDenseOptionsT NativeTableType;
+ enum {
+ VT_VALIDATE_INDICES = 4
+ };
+ bool validate_indices() const {
+ return GetField<uint8_t>(VT_VALIDATE_INDICES, 0) != 0;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_VALIDATE_INDICES) &&
+ verifier.EndTable();
+ }
+ SparseToDenseOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(SparseToDenseOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<SparseToDenseOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct SparseToDenseOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_validate_indices(bool validate_indices) {
+ fbb_.AddElement<uint8_t>(SparseToDenseOptions::VT_VALIDATE_INDICES, static_cast<uint8_t>(validate_indices), 0);
+ }
+ explicit SparseToDenseOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ SparseToDenseOptionsBuilder &operator=(const SparseToDenseOptionsBuilder &);
+ flatbuffers::Offset<SparseToDenseOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<SparseToDenseOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<SparseToDenseOptions> CreateSparseToDenseOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ bool validate_indices = false) {
+ SparseToDenseOptionsBuilder builder_(_fbb);
+ builder_.add_validate_indices(validate_indices);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<SparseToDenseOptions> CreateSparseToDenseOptions(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -4821,6 +4896,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const TransposeConvOptions *builtin_options_as_TransposeConvOptions() const {
return builtin_options_type() == BuiltinOptions_TransposeConvOptions ? static_cast<const TransposeConvOptions *>(builtin_options()) : nullptr;
}
+ const SparseToDenseOptions *builtin_options_as_SparseToDenseOptions() const {
+ return builtin_options_type() == BuiltinOptions_SparseToDenseOptions ? static_cast<const SparseToDenseOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -5043,6 +5121,10 @@ template<> inline const TransposeConvOptions *Operator::builtin_options_as<Trans
return builtin_options_as_TransposeConvOptions();
}
+template<> inline const SparseToDenseOptions *Operator::builtin_options_as<SparseToDenseOptions>() const {
+ return builtin_options_as_SparseToDenseOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -6862,6 +6944,32 @@ inline flatbuffers::Offset<TransposeConvOptions> CreateTransposeConvOptions(flat
_stride_h);
}
+inline SparseToDenseOptionsT *SparseToDenseOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new SparseToDenseOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void SparseToDenseOptions::UnPackTo(SparseToDenseOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = validate_indices(); _o->validate_indices = _e; };
+}
+
+inline flatbuffers::Offset<SparseToDenseOptions> SparseToDenseOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateSparseToDenseOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<SparseToDenseOptions> CreateSparseToDenseOptions(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SparseToDenseOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _validate_indices = _o->validate_indices;
+ return tflite::CreateSparseToDenseOptions(
+ _fbb,
+ _validate_indices);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -7244,6 +7352,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const TransposeConvOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_SparseToDenseOptions: {
+ auto ptr = reinterpret_cast<const SparseToDenseOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -7458,6 +7570,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const TransposeConvOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_SparseToDenseOptions: {
+ auto ptr = reinterpret_cast<const SparseToDenseOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -7660,6 +7776,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const TransposeConvOptionsT *>(value);
return CreateTransposeConvOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_SparseToDenseOptions: {
+ auto ptr = reinterpret_cast<const SparseToDenseOptionsT *>(value);
+ return CreateSparseToDenseOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -7862,6 +7982,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new TransposeConvOptionsT(*reinterpret_cast<TransposeConvOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_SparseToDenseOptions: {
+ value = new SparseToDenseOptionsT(*reinterpret_cast<SparseToDenseOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -8114,6 +8238,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_SparseToDenseOptions: {
+ auto ptr = reinterpret_cast<SparseToDenseOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 13fafebd1d..ae66bd858b 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -146,8 +146,9 @@ def toco_options(data_types,
" --inference_type=%s" % inference_type +
" --input_format=TENSORFLOW_GRAPHDEF" + " --output_format=TFLITE" +
" --input_arrays=%s" % ",".join(input_arrays) +
- " --input_shapes=%s" % shape_str +
" --output_arrays=%s" % ",".join(output_arrays))
+ if shape_str:
+ s += (" --input_shapes=%s" % shape_str)
if extra_toco_options.drop_control_dependency:
s += " --drop_control_dependency"
if extra_toco_options.allow_custom_ops:
@@ -238,6 +239,19 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
return value.astype(dtype)
+def create_scalar_data(dtype, min_value=-100, max_value=100):
+  """Build a scalar with a random value between min_value and max_value."""
+
+ if dtype in _TF_TYPE_INFO:
+ dtype = _TF_TYPE_INFO[dtype][0]
+
+ if dtype in (tf.float32, tf.float16):
+ value = (max_value - min_value) * np.random.random() + min_value
+ elif dtype in (tf.int32, tf.uint8, tf.int64):
+ value = np.random.randint(min_value, max_value + 1)
+ return np.array(value, dtype=dtype)
+
+
def freeze_graph(session, outputs):
"""Freeze the current graph.
@@ -2485,6 +2499,67 @@ def make_transpose_conv_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_sparse_to_dense_tests(zip_path):
+ """Make a set of tests to do sparse to dense."""
+
+ test_parameters = [{
+ "value_dtype": [tf.float32, tf.int32],
+ "index_dtype": [tf.int32, tf.int64],
+ "value_count": [1, 3, 6, 8],
+ "dense_shape": [[15], [3, 10], [4, 4, 4, 4], [7, 10, 9]],
+ "default_value": [0, -1],
+ "value_is_scalar": [True, False],
+ }]
+
+ # Return a single value for 1-D dense shape, but a tuple for other shapes.
+ def generate_index(dense_shape):
+ if len(dense_shape) == 1:
+ return np.random.randint(dense_shape[0])
+ else:
+ index = []
+ for shape in dense_shape:
+ index.append(np.random.randint(shape))
+ return tuple(index)
+
+ def build_graph(parameters):
+ """Build the sparse_to_dense op testing graph."""
+ dense_shape = parameters["dense_shape"]
+
+    # Special handling for the value_is_scalar case; value_count must be 1.
+ if parameters["value_is_scalar"] and parameters["value_count"] == 1:
+ value = tf.placeholder(
+ name="value", dtype=parameters["value_dtype"], shape=())
+ else:
+ value = tf.placeholder(
+ name="value",
+ dtype=parameters["value_dtype"],
+ shape=[parameters["value_count"]])
+ indices = set()
+ while len(indices) < parameters["value_count"]:
+ indices.add(generate_index(dense_shape))
+ indices = tf.constant(tuple(indices), dtype=parameters["index_dtype"])
+ # TODO(renjieliu): Add test for validate_indices case.
+ out = tf.sparse_to_dense(
+ indices,
+ dense_shape,
+ value,
+ parameters["default_value"],
+ validate_indices=False)
+
+ return [value], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ if parameters["value_is_scalar"] and parameters["value_count"] == 1:
+ input_value = create_scalar_data(parameters["value_dtype"])
+ else:
+ input_value = create_tensor_data(parameters["value_dtype"],
+ [parameters["value_count"]])
+ return [input_value], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
# Toco binary path provided by the generate rule.
bin_path = None
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index f5157149af..99f0c81a1b 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -1728,6 +1728,25 @@ void ConvertComparisonOperator(const Model& model, const Operator& src_op,
(*comparison_op->mutable_attr())["T"].set_type(data_type);
}
+void ConvertSparseToDenseOperator(const Model& model,
+ const SparseToDenseOperator& src_op,
+ const char* op_name,
+ GraphDef* tensorflow_graph) {
+ auto* sparse_to_dense_op = tensorflow_graph->add_node();
+ sparse_to_dense_op->set_op(op_name);
+ sparse_to_dense_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 4);
+ for (int i = 0; i < 4; ++i) {
+ *sparse_to_dense_op->add_input() = src_op.inputs[i];
+ }
+ const auto data_type = GetTensorFlowDataType(model, src_op.inputs[3]);
+ (*sparse_to_dense_op->mutable_attr())["T"].set_type(data_type);
+ const auto index_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*sparse_to_dense_op->mutable_attr())["Tindices"].set_type(index_type);
+  (*sparse_to_dense_op->mutable_attr())["validate_indices"].set_b(
+      src_op.validate_indices);
+}
+
void ConvertOperator(const Model& model, const Operator& src_op,
GraphDef* tensorflow_graph) {
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md
index 29a83bd26f..5071361bfd 100644
--- a/tensorflow/contrib/lite/toco/g3doc/python_api.md
+++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md
@@ -12,8 +12,8 @@ Table of contents:
* [High-level overview](#high-level-overview)
* [API](#api)
* [Basic examples](#basic)
- * [Exporting a GraphDef with constants](#basic-graphdef-const)
- * [Exporting a GraphDef with variables](#basic-graphdef-var)
+ * [Exporting a GraphDef from tf.Session](#basic-graphdef-sess)
+ * [Exporting a GraphDef from file](#basic-graphdef-file)
* [Exporting a SavedModel](#basic-savedmodel)
* [Complex examples](#complex)
* [Exporting a quantized GraphDef](#complex-quant)
@@ -50,17 +50,17 @@ possible.
The following section shows examples of how to convert a basic floating point
model from each of the supported data formats into a TensorFlow Lite FlatBuffer.
-### Exporting a GraphDef with constants <a name="basic-graphdef-const"></a>
+### Exporting a GraphDef from tf.Session <a name="basic-graphdef-sess"></a>
-The following example shows how to convert a TensorFlow GraphDef with constants
-into a TensorFlow Lite FlatBuffer.
+The following example shows how to convert a TensorFlow GraphDef into a
+TensorFlow Lite FlatBuffer from a `tf.Session` object.
```python
import tensorflow as tf
img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
-const = tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
-val = img + const
+var = tf.get_variable("weights", dtype=tf.float32, shape=(1, 64, 64, 3))
+val = img + var
out = tf.identity(val, name="out")
with tf.Session() as sess:
@@ -69,25 +69,28 @@ with tf.Session() as sess:
open("converted_model.tflite", "wb").write(tflite_model)
```
-### Exporting a GraphDef with variables <a name="basic-graphdef-var"></a>
+### Exporting a GraphDef from file <a name="basic-graphdef-file"></a>
-If a model has variables, they need to be turned into constants through a
-process known as freezing. It can be accomplished by setting `freeze_variables`
-to `True` as shown in the example below.
+The following example shows how to convert a TensorFlow GraphDef into a
+TensorFlow Lite FlatBuffer when the GraphDef is stored in a file. Both `.pb` and
+`.pbtxt` files are accepted.
+
+The example uses
+[Mobilenet_1.0_224](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz).
+The function only supports GraphDefs frozen via
+[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py).
```python
import tensorflow as tf
-img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
-var = tf.get_variable("weights", dtype=tf.float32, shape=(1, 64, 64, 3))
-val = img + var
-out = tf.identity(val, name="out")
+graph_def_file = "/path/to/Downloads/mobilenet_v1_1.0_224/frozen_graph.pb"
+input_arrays = ["input"]
+output_arrays = ["MobilenetV1/Predictions/Softmax"]
-with tf.Session() as sess:
- converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out],
- freeze_variables=True)
- tflite_model = converter.convert()
- open("converted_model.tflite", "wb").write(tflite_model)
+converter = tf.contrib.lite.TocoConverter.from_frozen_graph(
+ graph_def_file, input_arrays, output_arrays)
+tflite_model = converter.convert()
+open("converted_model.tflite", "wb").write(tflite_model)
```
### Exporting a SavedModel <a name="basic-savedmodel"></a>
@@ -111,8 +114,8 @@ available by running `help(tf.contrib.lite.TocoConverter)`.
## Complex examples <a name="complex"></a>
For models where the default value of the attributes is not sufficient, the
-variables values should be set before calling `convert()`. In order to call any
-constants use `tf.contrib.lite.constants.<CONSTANT_NAME>` as seen below with
+attribute values should be set before calling `convert()`. Constants are
+referenced via `tf.contrib.lite.constants.<CONSTANT_NAME>`, as seen below with
`QUANTIZED_UINT8`. Run `help(tf.contrib.lite.TocoConverter)` in the Python
terminal for detailed documentation on the attributes.
@@ -135,7 +138,7 @@ out = tf.fake_quant_with_min_max_args(val, min=0., max=1., name="output")
with tf.Session() as sess:
converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out])
converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8
- converter.quantized_input_stats = [(0., 1.)] # mean, std_dev
+ converter.quantized_input_stats = {"img" : (0., 1.)} # mean, std_dev
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
```
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 6342cf3e8a..64096fb069 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -163,6 +163,16 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
SetDataTypeForAllOutputs(model, op, data_type_x);
break;
}
+ case OperatorType::kSparseToDense: {
+      // SparseToDense produces outputs with the same type as its 3rd input
+ CHECK_EQ(op->inputs.size(), 4);
+ const ArrayDataType data_type = model->GetArray(op->inputs[2]).data_type;
+ const ArrayDataType data_type_default =
+ model->GetArray(op->inputs[3]).data_type;
+ CHECK(data_type == data_type_default);
+ SetDataTypeForAllOutputs(model, op, data_type);
+ break;
+ }
default: {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 9d1d27f3ef..adb241da32 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1477,6 +1477,34 @@ void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) {
*output_array.mutable_shape()->mutable_dims() = output_dims;
}
+void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) {
+ CHECK_EQ(op->inputs.size(), 4);
+
+ const Array& output_shape_array = model->GetArray(op->inputs[1]);
+ if (!output_shape_array.has_shape()) return;
+ CHECK_EQ(output_shape_array.shape().dimensions_count(), 1);
+
+ // Output should not go over four dimensions.
+ CHECK_LE(output_shape_array.shape().dims(0), 4);
+
+ const string& output_name = op->outputs[0];
+ Array& output_array = model->GetArray(output_name);
+ if (output_array.has_shape()) return;
+
+ CHECK(output_shape_array.data_type == ArrayDataType::kInt32 ||
+ output_shape_array.data_type == ArrayDataType::kInt64);
+ if (output_shape_array.data_type == ArrayDataType::kInt32) {
+ *output_array.mutable_shape()->mutable_dims() =
+ output_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
+ } else {
+ const std::vector<int64>& output_shape_data =
+ output_shape_array.GetBuffer<ArrayDataType::kInt64>().data;
+ std::copy(
+ output_shape_data.begin(), output_shape_data.end(),
+ std::back_inserter(*output_array.mutable_shape()->mutable_dims()));
+ }
+}
+
} // namespace
bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
@@ -1700,6 +1728,10 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
CHECK_EQ(op->inputs.size(), 1);
ProcessOpWithShapeInput(model, op);
break;
+ case OperatorType::kSparseToDense:
+ ProcessSparseToDenseOperator(model,
+ static_cast<SparseToDenseOperator*>(op));
+ break;
default:
// Unimplemented, another graph transformation should drop it.
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index ea051bb84a..83dce66df1 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -2133,6 +2133,24 @@ void ConvertDynamicStitchOperator(const NodeDef& node,
model->operators.emplace_back(op.release());
}
+void ConvertSparseToDenseOperator(const NodeDef& node,
+ const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "SparseToDense");
+ CheckInputsCount(node, tf_import_flags, 4);
+
+ auto* op = new SparseToDenseOperator;
+ for (const string& input : node.input()) {
+ op->inputs.push_back(input);
+ }
+ op->outputs.push_back(node.name());
+
+ op->validate_indices = HasAttr(node, "validate_indices")
+ ? GetBoolAttr(node, "validate_indices")
+ : true;
+ model->operators.emplace_back(op);
+}
+
} // namespace
namespace internal {
@@ -2314,6 +2332,8 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node,
ConvertSinOperator(node, tf_import_flags, model);
} else if (node.op() == "Select") {
ConvertSelectOperator(node, tf_import_flags, model);
+ } else if (node.op() == "SparseToDense") {
+ ConvertSparseToDenseOperator(node, tf_import_flags, model);
} else {
ConvertUnsupportedOperator(node, tf_import_flags, model);
}
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index d878ac54e4..9062c03c73 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -135,6 +135,7 @@ enum class OperatorType {
// special nodes in the graph to shuffle axes.
kReorderAxes,
kSelect,
+ kSparseToDense,
};
// Helper to deal with TensorFlow arrays using a different ordering of
@@ -1598,6 +1599,19 @@ struct DynamicStitchOperator : Operator {
int num_partitions;
};
+// SparseToDense operator:
+//
+// Inputs:
+// Inputs[0]: required: sparse_indices.
+// Inputs[1]: required: output_shape.
+//   Inputs[2]: required: sparse_values.
+//   Inputs[3]: required: default_value.
+//
+// TensorFlow equivalent: SparseToDense.
+struct SparseToDenseOperator : Operator {
+ SparseToDenseOperator() : Operator(OperatorType::kSparseToDense) {}
+ bool validate_indices;
+};
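
As a semantic reference for the four inputs documented above, the TensorFlow op this operator mirrors scatters `sparse_values` into a `default_value`-filled tensor of shape `output_shape`:

```python
# Semantics of SparseToDense via the TensorFlow op it mirrors.
import tensorflow as tf

indices = tf.constant([[0, 1], [2, 3]], dtype=tf.int64)  # sparse_indices
dense = tf.sparse_to_dense(indices, output_shape=[3, 4],
                           sparse_values=[5, 7], default_value=0,
                           validate_indices=False)
with tf.Session() as sess:
  print(sess.run(dense))
  # [[0 5 0 0]
  #  [0 0 0 0]
  #  [0 0 0 7]]
```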
+
// Alloc's are used for transient arrays only. An Alloc specifies which interval
// of the "transient_data" workspace buffer passed to inference functions, is to
// be used for the transient array at hand. The 'start' and 'end' values are
diff --git a/tensorflow/contrib/lite/toco/python/BUILD b/tensorflow/contrib/lite/toco/python/BUILD
index 8cac568bd7..a954f1d6ba 100644
--- a/tensorflow/contrib/lite/toco/python/BUILD
+++ b/tensorflow/contrib/lite/toco/python/BUILD
@@ -41,12 +41,6 @@ py_binary(
],
)
-py_binary(
- name = "toco_wrapper",
- srcs = ["toco_wrapper.py"],
- srcs_version = "PY2AND3",
-)
-
tf_py_test(
name = "toco_from_protos_test",
srcs = ["toco_from_protos_test.py"],
diff --git a/tensorflow/contrib/lite/toco/python/toco_wrapper.py b/tensorflow/contrib/lite/toco/python/toco_wrapper.py
deleted file mode 100644
index 6d6b500d7e..0000000000
--- a/tensorflow/contrib/lite/toco/python/toco_wrapper.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Wrapper for runninmg toco binary embedded in pip site-package.
-
-NOTE: this mainly exists since PIP setup.py cannot install binaries to bin/.
-It can only install Python "console-scripts." This will work as a console
-script. See tools/pip_package/setup.py (search for CONSOLE_SCRIPTS).
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import sys
-
-
-def main():
- # Pip installs the binary in aux-bin off of main site-package install.
- # Just find it and exec, passing all arguments in the process.
- # TODO(aselle): it is unfortunate to use all of tensorflow to lookup binary.
- print("""TOCO from pip install is currently not working on command line.
-Please use the python TOCO API or use
-bazel run tensorflow/contrib/lite:toco -- <args> from a TensorFlow source dir.
-""")
- sys.exit(1)
- # TODO(aselle): Replace this when we find a way to run toco without
- # blowing up executable size.
- # binary = os.path.join(tf.__path__[0], 'aux-bin/toco')
- # os.execvp(binary, sys.argv)
diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h
index 90abfb94d8..098d2163e6 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.h
+++ b/tensorflow/contrib/lite/toco/tflite/export.h
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tflite/operator.h"
+#include "tensorflow/contrib/lite/util.h"
namespace toco {
@@ -72,22 +73,10 @@ struct OperatorKey {
struct Hash {
size_t operator()(const OperatorKey& key) const {
- return CombineHashes({std::hash<size_t>()(static_cast<size_t>(key.type)),
- std::hash<std::string>()(key.custom_code),
- std::hash<int>()(key.version)});
- }
-
- private:
- // TODO(ycling): Refactoring and extract this function into a common
- // utility module.
- static size_t CombineHashes(std::initializer_list<size_t> hashes) {
- size_t result = 0;
- // Hash combiner used by TensorFlow core.
- for (size_t hash : hashes) {
- result = result ^ (hash + 0x9e3779b97f4a7800ULL + (result << 10) +
- (result >> 4));
- }
- return result;
+ return ::tflite::CombineHashes(
+ {std::hash<size_t>()(static_cast<size_t>(key.type)),
+ std::hash<std::string>()(key.custom_code),
+ std::hash<int>()(key.version)});
}
};
};
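
The combiner now called via `::tflite::CombineHashes` is the same TensorFlow-core mixing step the deleted lines show; in Python terms, with 64-bit wraparound made explicit:

```python
# The xor/shift hash combiner from the removed code, ported to Python.
MASK64 = (1 << 64) - 1

def combine_hashes(hashes):
  result = 0
  for h in hashes:
    result ^= (h + 0x9e3779b97f4a7800 +
               ((result << 10) & MASK64) + (result >> 4)) & MASK64
  return result

print(hex(combine_hashes([1, 2, 3])))
```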
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 6922e5055a..8f0f2e24db 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -794,6 +794,27 @@ class TransposeConv
int GetVersion(const Operator& op) const override { return 1; }
};
+class SparseToDense
+ : public BuiltinOperator<SparseToDenseOperator,
+ ::tflite::SparseToDenseOptions,
+ ::tflite::BuiltinOptions_SparseToDenseOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateSparseToDenseOptions(*builder, op.validate_indices);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->validate_indices = options.validate_indices();
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
class TensorFlowUnsupported : public BaseOperator {
public:
using BaseOperator::BaseOperator;
@@ -978,6 +999,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
new ArgMax(::tflite::BuiltinOperator_ARG_MAX, OperatorType::kArgMax));
ops.emplace_back(new TransposeConv(::tflite::BuiltinOperator_TRANSPOSE_CONV,
OperatorType::kTransposeConv));
+ ops.emplace_back(new SparseToDense(::tflite::BuiltinOperator_SPARSE_TO_DENSE,
+ OperatorType::kSparseToDense));
// Custom Operators.
ops.emplace_back(
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index fe594c6da9..d63c99a5f9 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -420,6 +420,15 @@ TEST_F(OperatorTest, BuiltinTransposeConv) {
EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
}
+TEST_F(OperatorTest, BuiltinSparseToDense) {
+ SparseToDenseOperator op;
+ op.validate_indices = false;
+ std::unique_ptr<toco::SparseToDenseOperator> output_toco_op =
+ SerializeAndDeserialize(
+ GetOperator("SPARSE_TO_DENSE", OperatorType::kSparseToDense), op);
+ EXPECT_EQ(op.validate_indices, output_toco_op->validate_indices);
+}
+
TEST_F(OperatorTest, TensorFlowUnsupported) {
TensorFlowUnsupportedOperator op;
op.tensorflow_op = "MyCustomUnsupportedOp";
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 1e6314f2dc..fe7bed885d 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -393,6 +393,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(DynamicPartition)
HANDLE_OPERATORTYPENAME_CASE(DynamicStitch)
HANDLE_OPERATORTYPENAME_CASE(Select)
+ HANDLE_OPERATORTYPENAME_CASE(SparseToDense)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD
index 824a164651..7fb7517600 100644
--- a/tensorflow/contrib/lite/tools/BUILD
+++ b/tensorflow/contrib/lite/tools/BUILD
@@ -7,6 +7,8 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
+common_copts = ["-Wall"]
+
py_binary(
name = "visualize",
srcs = ["visualize.py"],
@@ -30,7 +32,11 @@ tf_cc_binary(
tf_cc_binary(
name = "benchmark_model",
- srcs = ["benchmark_model.cc"],
+ srcs = [
+ "benchmark_main.cc",
+ "logging.h",
+ ],
+ copts = common_copts,
linkopts = select({
"//tensorflow:android": [
"-pie",
@@ -42,18 +48,67 @@ tf_cc_binary(
"//conditions:default": [],
}),
deps = [
+ ":benchmark_tflite_model_lib",
+ "//tensorflow/core:stats_calculator_portable",
+ ],
+)
+
+cc_library(
+ name = "command_line_flags",
+ srcs = ["command_line_flags.cc"],
+ hdrs = ["command_line_flags.h"],
+ copts = common_copts,
+ visibility = ["//visibility:private"],
+)
+
+cc_test(
+ name = "command_line_flags_test",
+ srcs = ["command_line_flags_test.cc"],
+ copts = common_copts,
+ visibility = ["//visibility:private"],
+ deps = [
+ ":command_line_flags",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "benchmark_tflite_model_lib",
+ srcs = [
+ "benchmark_tflite_model.cc",
+ "logging.h",
+ ],
+ hdrs = ["benchmark_tflite_model.h"],
+ copts = common_copts,
+ deps = [
+ ":benchmark_model_lib",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
"//tensorflow/contrib/lite/kernels:builtin_ops",
- ] + select({
- "//tensorflow:android": [
- "//tensorflow/core:android_tensorflow_lib",
- ],
- "//conditions:default": [
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- ],
- }),
+ "//tensorflow/contrib/lite/profiling:profile_summarizer",
+ "//tensorflow/contrib/lite/profiling:profiler",
+ ],
+)
+
+cc_library(
+ name = "benchmark_model_lib",
+ srcs = [
+ "benchmark_model.cc",
+ "logging.h",
+ ],
+ hdrs = ["benchmark_model.h"],
+ copts = common_copts,
+ deps = [
+ ":command_line_flags",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/profiling:profile_summarizer",
+ "//tensorflow/contrib/lite/profiling:profiler",
+ "//tensorflow/contrib/lite/profiling:time",
+ "//tensorflow/core:stats_calculator_portable",
+ ],
)
cc_library(
diff --git a/tensorflow/contrib/lite/tools/benchmark_main.cc b/tensorflow/contrib/lite/tools/benchmark_main.cc
new file mode 100644
index 0000000000..1325385e32
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark_main.cc
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/benchmark_tflite_model.h"
+#include "tensorflow/contrib/lite/tools/logging.h"
+
+namespace tflite {
+namespace benchmark {
+
+int Main(int argc, char** argv) {
+#ifdef TFLITE_CUSTOM_OPS_HEADER
+ TFLITE_LOG(INFO) << "STARTING with custom ops!";
+#else
+ TFLITE_LOG(INFO) << "STARTING!";
+#endif
+ BenchmarkTfLiteModel benchmark;
+ BenchmarkLoggingListener listener;
+ benchmark.AddListener(&listener);
+ benchmark.Run(argc, argv);
+ return 0;
+}
+} // namespace benchmark
+} // namespace tflite
+
+int main(int argc, char** argv) { return tflite::benchmark::Main(argc, argv); }
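
The new benchmark_main.cc above only wires a logging listener into the benchmark before running it; any number of listeners can observe the run. A hedged sketch of a custom listener using just the interfaces declared in benchmark_model.h (the RunCountingListener name is hypothetical, not part of this change):

    #include "tensorflow/contrib/lite/tools/benchmark_model.h"
    #include "tensorflow/contrib/lite/tools/logging.h"

    namespace tflite {
    namespace benchmark {

    // Counts timed runs and reports the total at the end of the benchmark.
    class RunCountingListener : public BenchmarkListener {
     public:
      void OnSingleRunEnd() override { ++runs_; }
      void OnBenchmarkEnd(const BenchmarkResults& results) override {
        TFLITE_LOG(INFO) << "Observed " << runs_ << " runs (warmup + regular).";
      }

     private:
      int runs_ = 0;
    };

    }  // namespace benchmark
    }  // namespace tflite

It would be attached exactly like BenchmarkLoggingListener: call benchmark.AddListener(&listener) before benchmark.Run(argc, argv).
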
diff --git a/tensorflow/contrib/lite/tools/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark_model.cc
index 869c531b3e..550994c662 100644
--- a/tensorflow/contrib/lite/tools/benchmark_model.cc
+++ b/tensorflow/contrib/lite/tools/benchmark_model.cc
@@ -13,463 +13,127 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <cstdarg>
-#include <cstdlib>
-#include <iostream>
-#include <memory>
-#include <string>
-#include <unordered_set>
-#include <vector>
-
-#include "tensorflow/contrib/lite/kernels/register.h"
-#include "tensorflow/contrib/lite/model.h"
-#include "tensorflow/contrib/lite/op_resolver.h"
-#include "tensorflow/contrib/lite/string_util.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/init_main.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-#ifdef TFLITE_CUSTOM_OPS_HEADER
-void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
-#endif
-
-namespace tflite {
-
-using ::tensorflow::Env;
-using ::tensorflow::str_util::Split;
-using ::tensorflow::str_util::SplitAndParseAsFloats;
-using ::tensorflow::str_util::SplitAndParseAsInts;
-
-struct InputLayerInfo {
- string name;
- TfLiteType data_type;
- std::vector<int> shape;
- // Note that initialization_values is currently unused.
- std::vector<float> initialization_values;
-};
-
-template <typename T>
-void FillRandomValue(T* ptr, const std::vector<int>& sizes,
- const std::function<T()>& random_func) {
- int num_elements = 1;
- for (int dim : sizes) {
- num_elements *= dim;
- }
- for (int i = 0; i < num_elements; ++i) {
- *ptr++ = random_func();
- }
-}
-
-void FillRandomString(tflite::DynamicBuffer* buffer,
- const std::vector<int>& sizes,
- const std::function<string()>& random_func) {
- int num_elements = 1;
- for (int dim : sizes) {
- num_elements *= dim;
- }
- for (int i = 0; i < num_elements; ++i) {
- auto str = random_func();
- buffer->AddString(str.data(), str.length());
- }
-}
-
-TfLiteType TfLiteTypeFromString(const string& input_layer_type) {
- if (input_layer_type == "string")
- return kTfLiteString;
- else if (input_layer_type == "float")
- return kTfLiteFloat32;
- else if (input_layer_type == "uint8")
- return kTfLiteUInt8;
- else if (input_layer_type == "int32")
- return kTfLiteInt32;
- else if (input_layer_type == "int64")
- return kTfLiteInt64;
- else
- return kTfLiteNoType;
-}
-
-std::vector<int> ShapeFromTfLiteTensor(TfLiteTensor* t) {
- std::vector<int> result;
- result.reserve(t->dims->size);
- for (int i = 0; i < t->dims->size; ++i) {
- result.push_back(t->dims->data[i]);
- }
- CHECK(!result.empty()) << "Found no shapes in model";
- return result;
-}
-
-bool CreateInterpreter(const string& graph,
- std::unique_ptr<FlatBufferModel>* model,
- std::unique_ptr<Interpreter>* interpreter) {
- *model = tflite::FlatBufferModel::BuildFromFile(graph.c_str());
- if (!model) {
- std::cerr << "Failed to load model " << graph << std::endl;
- return false;
- }
-
-#ifdef TFLITE_CUSTOM_OPS_HEADER
- tflite::MutableOpResolver resolver;
- RegisterSelectedOps(&resolver);
-#else
- tflite::ops::builtin::BuiltinOpResolver resolver;
-#endif
-
- tflite::InterpreterBuilder(*(model->get()), resolver)(interpreter);
- if (!(*interpreter)) {
- std::cerr << "Failed to construct interpreter" << std::endl;
- return false;
- }
-
- return true;
-}
-
-bool PrepareInterpreter(const std::vector<InputLayerInfo> inputs,
- int num_threads, bool use_nnapi,
- Interpreter* interpreter) {
- if (num_threads != -1) {
- interpreter->SetNumThreads(num_threads);
- }
-
- interpreter->UseNNAPI(use_nnapi);
-
- // Check that all names and types match
- for (const InputLayerInfo& input : inputs) {
- for (int i : interpreter->inputs()) {
- TfLiteTensor* t = interpreter->tensor(i);
- CHECK_EQ(t->name, input.name)
- << "Tensor # " << i << " is named " << t->name
- << " but flags call it " << input.name;
- CHECK_EQ(t->type, input.data_type)
- << "Could not match the type of input tensor " << t->name;
- }
- }
-
- // Resize all non-string tensors.
- for (const InputLayerInfo& input : inputs) {
- for (int i : interpreter->inputs()) {
- TfLiteTensor* t = interpreter->tensor(i);
- if (t->type != kTfLiteString) {
- interpreter->ResizeInputTensor(i, input.shape);
- }
- }
- }
-
- if (interpreter->AllocateTensors() != kTfLiteOk) {
- std::cerr << "Failed to allocate tensors!" << std::endl;
- return false;
- }
-
- // Set the values of the input tensors.
- for (int i : interpreter->inputs()) {
- TfLiteTensor* t = interpreter->tensor(i);
- std::vector<int> sizes = ShapeFromTfLiteTensor(t);
-
- // TODO(ahentz): below we ignore the O-th dimension (number of batches).
- if (t->type == kTfLiteFloat32) {
- FillRandomValue<float>(
- interpreter->typed_tensor<float>(i),
- std::vector<int>(sizes.begin() + 1, sizes.end()),
- []() { return static_cast<float>(rand()) / RAND_MAX - 0.5f; });
- } else if (t->type == kTfLiteUInt8) {
- FillRandomValue<uint8_t>(
- interpreter->typed_tensor<uint8_t>(i),
- std::vector<int>(sizes.begin() + 1, sizes.end()),
- []() { return static_cast<uint8_t>(rand()) % 255; });
- } else if (t->type == kTfLiteString) {
- tflite::DynamicBuffer buffer;
- FillRandomString(&buffer, sizes, []() {
- return "we're have some friends over saturday to hang out in the yard";
- });
- buffer.WriteToTensor(interpreter->tensor(i));
- } else {
- std::cerr << "Don't know how to populate tensor " << t->name
- << " of type " << t->type << std::endl;
- return false;
- }
- }
- return true;
-}
-
-bool PopulateInputLayerInfo(const string& names_string,
- const string& shapes_string,
- const string& types_string,
- const string& values_string,
- std::vector<InputLayerInfo>* info) {
- std::vector<string> names = Split(names_string, ',');
- std::vector<string> shapes = Split(shapes_string, ':');
- std::vector<string> types = Split(types_string, ',');
- std::vector<string> values = Split(values_string, ':');
-
- if (names.size() != shapes.size()) {
- LOG(ERROR) << "The number of items in"
- << " --input_layer_shape (" << shapes_string << ", with "
- << shapes.size() << " items)"
- << " must match the number of items in"
- << " --input_layer (" << names_string << ", with "
- << names.size() << " items)."
- << " For example --input_layer=input1,input2"
- << " --input_layer_shape=1,224,224,4:1,20";
- return false;
- }
- if (names.size() != types.size()) {
- LOG(ERROR) << "The number of items in"
- << " --input_layer_type (" << types_string << ", with "
- << types.size() << " items)"
- << " must match the number of items in"
- << " --input_layer (" << names_string << ", with "
- << names.size() << " items)."
- << " For example --input_layer=input1,input2"
- << " --input_layer_type=float,int";
- return false;
- }
-
- for (int i = 0; i < names.size(); ++i) {
- info->push_back(InputLayerInfo());
- InputLayerInfo& input = info->back();
+#include "tensorflow/contrib/lite/tools/benchmark_model.h"
- input.name = names[i];
+#include <time.h>
- input.data_type = TfLiteTypeFromString(types[i]);
- CHECK(input.data_type != kTfLiteNoType)
- << types[i] << " was an invalid type";
-
- CHECK(SplitAndParseAsInts(shapes[i], ',', &input.shape))
- << "Incorrect size string specified: " << shapes[i];
- for (int dim : input.shape) {
- if (dim == -1) {
- LOG(ERROR) << "Any unknown sizes in the shapes (-1's) must be replaced"
- << " with the size you want to benchmark with.";
- return false;
- }
- }
-
- if (i < values.size()) {
- CHECK(SplitAndParseAsFloats(values[i], ',', &input.initialization_values))
- << "Incorrect initialization values string specified: " << values[i];
- }
- }
-
- return true;
-}
-
-bool RunBenchmark(Interpreter* interpreter, int64_t* inference_time_us) {
- const int64_t start_time = Env::Default()->NowMicros();
-
- if (interpreter->Invoke() != kTfLiteOk) {
- std::cerr << "Failed to invoke!";
- return false;
- }
-
- const int64_t end_time = Env::Default()->NowMicros();
- *inference_time_us = end_time - start_time;
- return true;
-}
-
-class Latencies {
- public:
- void AddMeasurement(int64_t time_us) {
- max_ = std::max(time_us, max_);
- min_ = std::min(time_us, min_);
- ++count_;
- sum_ += time_us;
- squared_sum_ += static_cast<double>(time_us) * time_us;
- }
-
- double avg() const {
- if (count_ == 0) return std::numeric_limits<int64_t>::quiet_NaN();
- return static_cast<double>(sum_) / count_;
- }
+#include <iostream>
+#include <sstream>
- int64_t std_deviation() const {
- if (count_ == 0 || min_ == max_) return 0;
- return sqrt(squared_sum_ / count_ - avg() * avg());
- }
+#include "tensorflow/contrib/lite/profiling/time.h"
+#include "tensorflow/contrib/lite/tools/logging.h"
- void OutputToStream(std::ostream* stream) const {
- *stream << "count=" << count_;
- if (count_ == 0) return;
- *stream << " min=" << min_ << " max=" << max_;
- *stream << " avg=" << avg() << " std=" << std_deviation();
+namespace {
+void SleepForSeconds(double sleep_seconds) {
+ if (sleep_seconds <= 0.0) {
+ return;
}
-
- private:
- int64_t count_ = 0;
- int64_t min_ = std::numeric_limits<int64_t>::max();
- int64_t max_ = std::numeric_limits<int64_t>::min();
- int64_t sum_ = 0;
- double squared_sum_ = 0;
-};
-
-bool TimeMultipleRuns(Interpreter* interpreter, double sleep_seconds,
- int num_runs, int64* total_time_us) {
// Convert the run_delay string into a timespec.
timespec req;
req.tv_sec = static_cast<time_t>(sleep_seconds);
req.tv_nsec = (sleep_seconds - req.tv_sec) * 1000000000;
-
- *total_time_us = 0;
-
- std::cout << "Running benchmark for " << num_runs
- << " iterations: " << std::endl;
-
- Latencies latencies;
- for (int i = 0; i < num_runs; ++i) {
- int64_t time_us;
- bool run_status = RunBenchmark(interpreter, &time_us);
- latencies.AddMeasurement(time_us);
- *total_time_us += time_us;
- if (!run_status) {
- std::cout << "Failed on run " << i << std::endl;
- return false;
- }
-
- // If requested, sleep between runs for an arbitrary amount of time.
- // This can be helpful to determine the effect of mobile processor
- // scaling and thermal throttling.
- if (sleep_seconds > 0.0) {
+ // If requested, sleep between runs for an arbitrary amount of time.
+ // This can be helpful to determine the effect of mobile processor
+ // scaling and thermal throttling.
#ifdef PLATFORM_WINDOWS
- Sleep(sleep_seconds * 1000);
+ Sleep(sleep_seconds * 1000);
#else
- nanosleep(&req, nullptr);
+ nanosleep(&req, nullptr);
#endif
- }
- }
- latencies.OutputToStream(&std::cout);
- std::cout << std::endl;
-
- return true;
}
-int Main(int argc, char** argv) {
- using tensorflow::Flag;
- using tensorflow::Flags;
+} // namespace
- string graph; // e.g.: /data/local/tmp/tfl_inception-v1_model.fb
- string input_layer_string; // e.g.: input
- string input_layer_shape_string; // e.g.: 1,224,224,3
- string input_layer_type_string; // e.g.: float
- string input_layer_values_string;
- string output_layer_string; // e.g.: output
- int num_runs = 50;
- string run_delay = "-1.0";
- int num_threads = 1;
- string benchmark_name = "";
- string output_prefix = "";
- int warmup_runs = 1;
- bool use_nnapi = false;
+namespace tflite {
+namespace benchmark {
+using tensorflow::Stat;
+
+void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults &results) {
+ auto inference_us = results.inference_time_us();
+ auto init_us = results.startup_latency_us();
+ auto warmup_us = results.warmup_time_us();
+ TFLITE_LOG(INFO) << "Average inference timings in us: "
+ << "Warmup: " << warmup_us.avg() << ", "
+ << "Init: " << init_us << ", "
+ << "no stats: " << inference_us.avg();
+}
- std::vector<Flag> flag_list = {
- Flag("graph", &graph, "graph file name"),
- // All the following flags are optional, but can be used in order
- // to benchmark different input shapes.
- Flag("input_layer", &input_layer_string, "input layer names"),
- Flag("input_layer_shape", &input_layer_shape_string, "input layer shape"),
- Flag("input_layer_type", &input_layer_type_string, "input layer type"),
- Flag("input_layer_values", &input_layer_values_string,
- "values to initialize the inputs with"),
- Flag("output_layer", &output_layer_string, "output layer name"),
- Flag("num_runs", &num_runs, "number of runs"),
- Flag("run_delay", &run_delay, "delay between runs in seconds"),
- Flag("num_threads", &num_threads, "number of threads"),
- Flag("benchmark_name", &benchmark_name, "benchmark name"),
- Flag("output_prefix", &output_prefix, "benchmark output prefix"),
- Flag("warmup_runs", &warmup_runs, "how many runs to initialize model"),
- Flag("use_nnapi", &use_nnapi, "use nnapi api"),
+std::vector<Flag> BenchmarkModel::GetFlags() {
+ return {
+ Flag("num_runs", &params_.num_runs, "number of runs"),
+ Flag("run_delay", &params_.run_delay, "delay between runs in seconds"),
+ Flag("num_threads", &params_.num_threads, "number of threads"),
+ Flag("benchmark_name", &params_.benchmark_name, "benchmark name"),
+ Flag("output_prefix", &params_.output_prefix, "benchmark output prefix"),
+ Flag("warmup_runs", &params_.warmup_runs,
+ "how many runs to initialize model"),
};
- string usage = Flags::Usage(argv[0], flag_list);
- const bool parse_result = Flags::Parse(&argc, argv, flag_list);
- tensorflow::port::InitMain(argv[0], &argc, &argv);
+}
- if (!parse_result) {
- std::cerr << usage << std::endl;
- return -1;
- }
+void BenchmarkModel::LogFlags() {
+ TFLITE_LOG(INFO) << "Num runs: [" << params_.num_runs << "]";
+ TFLITE_LOG(INFO) << "Inter-run delay (seconds): [" << params_.run_delay
+ << "]";
+ TFLITE_LOG(INFO) << "Num threads: [" << params_.num_threads << "]";
+ TFLITE_LOG(INFO) << "Benchmark name: [" << params_.benchmark_name << "]";
+ TFLITE_LOG(INFO) << "Output prefix: [" << params_.output_prefix << "]";
+ TFLITE_LOG(INFO) << "Warmup runs: [" << params_.warmup_runs << "]";
+}
- std::cout << "Graph: [" << graph << "]" << std::endl;
- if (!input_layer_string.empty()) {
- std::cout << "Input layers: [" << input_layer_string << "]" << std::endl;
- std::cout << "Input shapes: [" << input_layer_shape_string << "]"
- << std::endl;
- std::cout << "Input types: [" << input_layer_type_string << "]"
- << std::endl;
- }
- if (!output_layer_string.empty()) {
- std::cout << "Output layers: [" << output_layer_string << "]" << std::endl;
- }
- std::cout << "Num runs: [" << num_runs << "]" << std::endl;
- std::cout << "Inter-run delay (seconds): [" << run_delay << "]" << std::endl;
- std::cout << "Num threads: [" << num_threads << "]" << std::endl;
- if (!benchmark_name.empty()) {
- std::cout << "Benchmark name: [" << benchmark_name << "]" << std::endl;
- std::cout << "Output prefix: [" << output_prefix << "]" << std::endl;
- }
- std::cout << "Warmup runs: [" << warmup_runs << "]" << std::endl;
- std::cout << "Use nnapi : [" << use_nnapi << "]" << std::endl;
+Stat<int64_t> BenchmarkModel::Run(int num_times, RunType run_type) {
+ Stat<int64_t> run_stats;
+ TFLITE_LOG(INFO) << "Running benchmark for " << num_times << " iterations ";
+ for (int run = 0; run < num_times; run++) {
+ listeners_.OnSingleRunStart(run_type);
+ int64_t start_us = profiling::time::NowMicros();
+ RunImpl();
+ int64_t end_us = profiling::time::NowMicros();
+ listeners_.OnSingleRunEnd();
- if (graph.empty()) {
- std::cout
- << "Please specify the name of your TF Lite input file with --graph"
- << std::endl;
- return -1;
+ run_stats.UpdateStat(end_us - start_us);
+ SleepForSeconds(params_.run_delay);
}
- std::vector<InputLayerInfo> inputs;
- if (!PopulateInputLayerInfo(input_layer_string, input_layer_shape_string,
- input_layer_type_string,
- input_layer_values_string, &inputs)) {
- return -1;
- }
+ std::stringstream stream;
+ run_stats.OutputToStream(&stream);
+ TFLITE_LOG(INFO) << stream.str() << std::endl;
- int64 initialization_start_us = Env::Default()->NowMicros();
+ return run_stats;
+}
- std::unique_ptr<tflite::FlatBufferModel> model;
- std::unique_ptr<tflite::Interpreter> interpreter;
- if (!CreateInterpreter(graph, &model, &interpreter)) {
- return -1;
+void BenchmarkModel::Run(int argc, char **argv) {
+ if (!ParseFlags(argc, argv)) {
+ return;
}
- if (!PrepareInterpreter(inputs, num_threads, use_nnapi, interpreter.get())) {
- return -1;
- }
-
- int64 initialization_end_us = Env::Default()->NowMicros();
- const double initialization_time_s =
- (initialization_end_us - initialization_start_us) / 1000000.0f;
- std::cout << "Initialized session in " << initialization_time_s << "s"
- << std::endl;
+ LogFlags();
- const double sleep_seconds = std::strtod(run_delay.c_str(), nullptr);
+ listeners_.OnBenchmarkStart(params_);
+ int64_t initialization_start_us = profiling::time::NowMicros();
+ Init();
+ int64_t initialization_end_us = profiling::time::NowMicros();
+ int64_t startup_latency_us = initialization_end_us - initialization_start_us;
+ TFLITE_LOG(INFO) << "Initialized session in " << startup_latency_us / 1e3
+ << "ms";
- // If requested, run through the graph first to preinitialize everything
- // before the benchmarking runs.
- int64 warmup_time_us = 0;
- if (warmup_runs > 0) {
- if (!TimeMultipleRuns(interpreter.get(), sleep_seconds, warmup_runs,
- &warmup_time_us)) {
- std::cerr << "Warmup failed" << std::endl;
- return -1;
- }
- }
+ uint64_t input_bytes = ComputeInputBytes();
+ Stat<int64_t> warmup_time_us = Run(params_.warmup_runs, WARMUP);
+ Stat<int64_t> inference_time_us = Run(params_.num_runs, REGULAR);
+ listeners_.OnBenchmarkEnd(
+ {startup_latency_us, input_bytes, warmup_time_us, inference_time_us});
+}
- // Capture overall inference time without stat logging overhead. This is the
- // timing data that can be compared to other libraries.
- int64 no_stat_time_us = 0;
- if (!TimeMultipleRuns(interpreter.get(), sleep_seconds, num_runs,
- &no_stat_time_us)) {
- std::cerr << "Timing failed." << std::endl;
- return -1;
+bool BenchmarkModel::ParseFlags(int argc, char **argv) {
+ auto flag_list = GetFlags();
+ const bool parse_result =
+ Flags::Parse(&argc, const_cast<const char **>(argv), flag_list);
+ if (!parse_result) {
+ std::string usage = Flags::Usage(argv[0], flag_list);
+ TFLITE_LOG(ERROR) << usage;
+ return false;
}
-
- std::cout << "Average inference timings in us: " << no_stat_time_us / num_runs
- << " , Warmup: "
- << (warmup_runs > 0 ? warmup_time_us / warmup_runs : 0) << ", "
- << std::endl;
-
- return 0;
+ return ValidateFlags();
}
+} // namespace benchmark
} // namespace tflite
-
-int main(int argc, char** argv) { return ::tflite::Main(argc, argv); }
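
The rewritten Run() loop above accumulates per-iteration latencies in a tensorflow::Stat<int64_t> instead of the removed Latencies class. A minimal stand-in showing the same accumulate-then-summarize pattern (LatencyStat is a hand-rolled sketch, not the real stats_calculator.h type):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <limits>

    struct LatencyStat {
      void UpdateStat(int64_t us) {
        min_us = std::min(min_us, us);
        max_us = std::max(max_us, us);
        sum_us += us;
        ++count;
      }
      double avg() const {
        return count ? static_cast<double>(sum_us) / count : 0.0;
      }

      int64_t min_us = std::numeric_limits<int64_t>::max();
      int64_t max_us = std::numeric_limits<int64_t>::min();
      int64_t sum_us = 0;
      int64_t count = 0;
    };

    int main() {
      LatencyStat stats;
      // In BenchmarkModel::Run the sample is end_us - start_us per iteration.
      for (int64_t sample : {1200, 1100, 1350, 1180}) stats.UpdateStat(sample);
      std::cout << "count=" << stats.count << " min=" << stats.min_us
                << " max=" << stats.max_us << " avg=" << stats.avg() << "\n";
      return 0;
    }
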
diff --git a/tensorflow/contrib/lite/tools/benchmark_model.h b/tensorflow/contrib/lite/tools/benchmark_model.h
new file mode 100644
index 0000000000..ef8d6a7d1e
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark_model.h
@@ -0,0 +1,161 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_MODEL_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_MODEL_H_
+
+#include <cmath>
+#include <limits>
+#include <ostream>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/contrib/lite/tools//command_line_flags.h"
+#include "tensorflow/core/util/stats_calculator.h"
+
+namespace tflite {
+namespace benchmark {
+
+enum RunType {
+ WARMUP,
+ REGULAR,
+};
+
+class BenchmarkResults {
+ public:
+ BenchmarkResults(int64_t startup_latency_us, uint64_t input_bytes,
+ tensorflow::Stat<int64_t> warmup_time_us,
+ tensorflow::Stat<int64_t> inference_time_us)
+ : startup_latency_us_(startup_latency_us),
+ input_bytes_(input_bytes),
+ warmup_time_us_(warmup_time_us),
+ inference_time_us_(inference_time_us) {}
+
+ tensorflow::Stat<int64_t> inference_time_us() const {
+ return inference_time_us_;
+ }
+ tensorflow::Stat<int64_t> warmup_time_us() const { return warmup_time_us_; }
+ int64_t startup_latency_us() const { return startup_latency_us_; }
+ uint64_t input_bytes() const { return input_bytes_; }
+ double throughput_MB_per_second() const {
+ double bytes_per_sec = (input_bytes_ * inference_time_us_.count() * 1e6) /
+ inference_time_us_.sum();
+ return bytes_per_sec / (1024.0 * 1024.0);
+ }
+
+ private:
+ int64_t startup_latency_us_;
+ uint64_t input_bytes_;
+ tensorflow::Stat<int64_t> warmup_time_us_;
+ tensorflow::Stat<int64_t> inference_time_us_;
+};
+
+struct BenchmarkParams {
+ BenchmarkParams()
+ : num_runs(50), warmup_runs(1), run_delay(-1.0), num_threads(1) {}
+ int num_runs;
+ int warmup_runs;
+ float run_delay;
+ int num_threads;
+ std::string benchmark_name;
+ std::string output_prefix;
+};
+
+class BenchmarkListener {
+ public:
+ virtual void OnBenchmarkStart(const BenchmarkParams& params) {}
+ virtual void OnSingleRunStart(RunType runType) {}
+ virtual void OnSingleRunEnd() {}
+ virtual void OnBenchmarkEnd(const BenchmarkResults& results) {}
+ virtual ~BenchmarkListener() {}
+};
+
+// A listener that forwards its method calls to a collection of listeners.
+class BenchmarkListeners : public BenchmarkListener {
+ public:
+ // Adds a listener to the listener collection.
+ // |listener| is not owned by the instance of |BenchmarkListeners|.
+ // |listener| should not be null and should outlast the instance of
+ // |BenchmarkListeners|.
+ void AddListener(BenchmarkListener* listener) {
+ listeners_.push_back(listener);
+ }
+
+ void OnBenchmarkStart(const BenchmarkParams& params) override {
+ for (auto listener : listeners_) {
+ listener->OnBenchmarkStart(params);
+ }
+ }
+
+ void OnSingleRunStart(RunType runType) override {
+ for (auto listener : listeners_) {
+ listener->OnSingleRunStart(runType);
+ }
+ }
+
+ void OnSingleRunEnd() override {
+ for (auto listener : listeners_) {
+ listener->OnSingleRunEnd();
+ }
+ }
+
+ void OnBenchmarkEnd(const BenchmarkResults& results) override {
+ for (auto listener : listeners_) {
+ listener->OnBenchmarkEnd(results);
+ }
+ }
+
+ ~BenchmarkListeners() {}
+
+ private:
+ // Use vector so listeners are invoked in the order they are added.
+ std::vector<BenchmarkListener*> listeners_;
+};
+
+// Benchmark listener that just logs the results of the benchmark run.
+class BenchmarkLoggingListener : public BenchmarkListener {
+ void OnBenchmarkEnd(const BenchmarkResults& results) override;
+};
+
+// Benchmarks a model.
+//
+// Subclasses need to implement initialization and running of the model.
+// The results can be collected by adding BenchmarkListener(s).
+class BenchmarkModel {
+ public:
+ virtual ~BenchmarkModel() {}
+ bool ParseFlags(int argc, char** argv);
+ virtual void Init() = 0;
+ void Run(int argc, char** argv);
+ void AddListener(BenchmarkListener* listener) {
+ listeners_.AddListener(listener);
+ }
+
+ protected:
+ virtual void LogFlags();
+ virtual bool ValidateFlags() { return true; }
+ virtual std::vector<Flag> GetFlags();
+ virtual uint64_t ComputeInputBytes() = 0;
+ virtual tensorflow::Stat<int64_t> Run(int num_times, RunType run_type);
+ virtual void RunImpl() = 0;
+ BenchmarkParams params_;
+ BenchmarkListeners listeners_;
+};
+
+} // namespace benchmark
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_MODEL_H_
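
BenchmarkModel leaves Init(), RunImpl(), and ComputeInputBytes() to subclasses; BenchmarkTfLiteModel below is the real implementation. A hedged sketch of the minimal subclass contract (NoOpBenchmark is hypothetical):

    #include "tensorflow/contrib/lite/tools/benchmark_model.h"

    namespace tflite {
    namespace benchmark {

    // Does no real work; exists only to show which members must be overridden.
    class NoOpBenchmark : public BenchmarkModel {
     public:
      void Init() override {}  // load and prepare the workload here

     protected:
      uint64_t ComputeInputBytes() override { return 0; }
      void RunImpl() override {}  // one timed iteration goes here
    };

    }  // namespace benchmark
    }  // namespace tflite
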
diff --git a/tensorflow/contrib/lite/tools/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark_tflite_model.cc
new file mode 100644
index 0000000000..be8f46f599
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark_tflite_model.cc
@@ -0,0 +1,352 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/benchmark_tflite_model.h"
+
+#include <cstdarg>
+#include <cstdlib>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/op_resolver.h"
+#include "tensorflow/contrib/lite/string_util.h"
+#include "tensorflow/contrib/lite/tools/logging.h"
+
+#ifdef TFLITE_CUSTOM_OPS_HEADER
+void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
+#endif
+
+namespace tflite {
+namespace benchmark {
+
+void ProfilingListener::SetInterpreter(tflite::Interpreter* interpreter) {
+ TFLITE_BENCHMARK_CHECK(interpreter);
+ interpreter_ = interpreter;
+ interpreter_->SetProfiler(&profiler_);
+}
+
+void ProfilingListener::OnSingleRunStart(RunType run_type) {
+ if (run_type == REGULAR) {
+ profiler_.Reset();
+ profiler_.StartProfiling();
+ }
+}
+
+void ProfilingListener::OnBenchmarkEnd(const BenchmarkResults& results) {
+ if (has_profiles_) {
+ TFLITE_LOG(INFO) << summarizer_.GetOutputString();
+ }
+}
+
+void ProfilingListener::OnSingleRunEnd() {
+ profiler_.StopProfiling();
+ auto profile_events = profiler_.GetProfileEvents();
+ has_profiles_ = !profile_events.empty();
+ summarizer_.ProcessProfiles(profile_events, *interpreter_);
+}
+
+namespace {
+
+std::vector<std::string> Split(const std::string& str, const char delim) {
+ std::istringstream input(str);
+ std::vector<std::string> results;
+ std::string item;
+ while (std::getline(input, item, delim)) {
+ results.push_back(item);
+ }
+ return results;
+}
+
+template <typename T>
+bool SplitAndParse(const std::string& str, char delim, std::vector<T>* values) {
+ std::istringstream input(str);
+ bool first = true;
+ while (!input.eof()) {
+ if (!first) {
+ char c;
+ input >> c;
+ if (c != delim) {
+ return false;
+ }
+ } else {
+ first = false;
+ }
+ T val;
+ input >> val;
+ if (!input.eof() && !input.good()) {
+ return false;
+ }
+ values->push_back(val);
+ }
+ return true;
+}
+
+template <typename T>
+void FillRandomValue(T* ptr, const std::vector<int>& sizes,
+ const std::function<T()>& random_func) {
+ int num_elements = 1;
+ for (int dim : sizes) {
+ num_elements *= dim;
+ }
+ for (int i = 0; i < num_elements; ++i) {
+ *ptr++ = random_func();
+ }
+}
+
+void FillRandomString(tflite::DynamicBuffer* buffer,
+ const std::vector<int>& sizes,
+ const std::function<string()>& random_func) {
+ int num_elements = 1;
+ for (int dim : sizes) {
+ num_elements *= dim;
+ }
+ for (int i = 0; i < num_elements; ++i) {
+ auto str = random_func();
+ buffer->AddString(str.data(), str.length());
+ }
+}
+
+TfLiteType TfLiteTypeFromString(const string& input_layer_type) {
+ if (input_layer_type == "string")
+ return kTfLiteString;
+ else if (input_layer_type == "float")
+ return kTfLiteFloat32;
+ else if (input_layer_type == "uint8")
+ return kTfLiteUInt8;
+ else if (input_layer_type == "int32")
+ return kTfLiteInt32;
+ else if (input_layer_type == "int64")
+ return kTfLiteInt64;
+ else
+ return kTfLiteNoType;
+}
+
+bool PopulateInputLayerInfo(
+ const string& names_string, const string& shapes_string,
+ const string& types_string, const string& values_string,
+ std::vector<BenchmarkTfLiteModel::InputLayerInfo>* info) {
+ std::vector<std::string> names = Split(names_string, ',');
+ std::vector<std::string> shapes = Split(shapes_string, ':');
+ std::vector<std::string> types = Split(types_string, ',');
+ std::vector<std::string> values = Split(values_string, ':');
+
+ if (names.size() != shapes.size()) {
+ TFLITE_LOG(ERROR) << "The number of items in"
+ << " --input_layer_shape (" << shapes_string << ", with "
+ << shapes.size() << " items)"
+ << " must match the number of items in"
+ << " --input_layer (" << names_string << ", with "
+ << names.size() << " items)."
+ << " For example --input_layer=input1,input2"
+ << " --input_layer_shape=1,224,224,4:1,20";
+ return false;
+ }
+ if (names.size() != types.size()) {
+ TFLITE_LOG(ERROR) << "The number of items in"
+ << " --input_layer_type (" << types_string << ", with "
+ << types.size() << " items)"
+ << " must match the number of items in"
+ << " --input_layer (" << names_string << ", with "
+ << names.size() << " items)."
+ << " For example --input_layer=input1,input2"
+ << " --input_layer_type=float,int";
+ return false;
+ }
+
+ for (int i = 0; i < names.size(); ++i) {
+ info->push_back(BenchmarkTfLiteModel::InputLayerInfo());
+ BenchmarkTfLiteModel::InputLayerInfo& input = info->back();
+
+ input.name = names[i];
+
+ input.data_type = TfLiteTypeFromString(types[i]);
+ TFLITE_BENCHMARK_CHECK(input.data_type != kTfLiteNoType)
+ << types[i] << " was an invalid type";
+
+ TFLITE_BENCHMARK_CHECK(SplitAndParse(shapes[i], ',', &input.shape))
+ << "Incorrect size string specified: " << shapes[i];
+ for (int dim : input.shape) {
+ if (dim == -1) {
+ TFLITE_LOG(ERROR)
+ << "Any unknown sizes in the shapes (-1's) must be replaced"
+ << " with the size you want to benchmark with.";
+ return false;
+ }
+ }
+
+ if (i < values.size()) {
+ TFLITE_BENCHMARK_CHECK(
+ SplitAndParse(values[i], ',', &input.initialization_values))
+ << "Incorrect initialization values string specified: " << values[i];
+ }
+ }
+
+ return true;
+}
+
+} // namespace
+
+std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
+ std::vector<Flag> flags = BenchmarkModel::GetFlags();
+ std::vector<Flag> specific_flags = {
+ Flag("graph", &graph, "graph file name"),
+ Flag("input_layer", &input_layer_string, "input layer names"),
+ Flag("input_layer_shape", &input_layer_shape_string, "input layer shape"),
+ Flag("input_layer_type", &input_layer_type_string, "input layer type"),
+ Flag("input_layer_values", &input_layer_values_string,
+ "values to initialize the inputs with"),
+ Flag("output_layer", &output_layer_string, "output layer name"),
+ Flag("use_nnapi", &use_nnapi, "use nnapi api")};
+
+ flags.insert(flags.end(), specific_flags.begin(), specific_flags.end());
+ return flags;
+}
+
+void BenchmarkTfLiteModel::LogFlags() {
+ BenchmarkModel::LogFlags();
+ TFLITE_LOG(INFO) << "Graph: [" << graph << "]";
+ TFLITE_LOG(INFO) << "Input layers: [" << input_layer_string << "]";
+ TFLITE_LOG(INFO) << "Input shapes: [" << input_layer_shape_string << "]";
+ TFLITE_LOG(INFO) << "Input types: [" << input_layer_type_string << "]";
+ TFLITE_LOG(INFO) << "Output layers: [" << output_layer_string << "]";
+ TFLITE_LOG(INFO) << "Use nnapi : [" << use_nnapi << "]";
+}
+
+bool BenchmarkTfLiteModel::ValidateFlags() {
+ if (graph.empty()) {
+ TFLITE_LOG(ERROR)
+ << "Please specify the name of your TF Lite input file with --graph";
+ return false;
+ }
+ return PopulateInputLayerInfo(input_layer_string, input_layer_shape_string,
+ input_layer_type_string,
+ input_layer_values_string, &inputs);
+}
+
+uint64_t BenchmarkTfLiteModel::ComputeInputBytes() {
+ TFLITE_BENCHMARK_CHECK(interpreter);
+ uint64_t total_input_bytes = 0;
+ for (int input : interpreter->inputs()) {
+ auto* t = interpreter->tensor(input);
+ total_input_bytes += t->bytes;
+ }
+ return total_input_bytes;
+}
+
+void BenchmarkTfLiteModel::Init() {
+ model = tflite::FlatBufferModel::BuildFromFile(graph.c_str());
+ if (!model) {
+ TFLITE_LOG(FATAL) << "Failed to mmap model " << graph;
+ }
+ TFLITE_LOG(INFO) << "Loaded model " << graph;
+ model->error_reporter();
+ TFLITE_LOG(INFO) << "resolved reporter";
+
+#ifdef TFLITE_CUSTOM_OPS_HEADER
+ tflite::MutableOpResolver resolver;
+ RegisterSelectedOps(&resolver);
+#else
+ tflite::ops::builtin::BuiltinOpResolver resolver;
+#endif
+
+ tflite::InterpreterBuilder(*model, resolver)(&interpreter);
+ if (!interpreter) {
+ TFLITE_LOG(FATAL) << "Failed to construct interpreter";
+ }
+ profiling_listener_.SetInterpreter(interpreter.get());
+
+ if (params_.num_threads != -1) {
+ interpreter->SetNumThreads(params_.num_threads);
+ }
+
+ interpreter->UseNNAPI(use_nnapi);
+ auto interpreter_inputs = interpreter->inputs();
+
+ if (!inputs.empty()) {
+ TFLITE_BENCHMARK_CHECK_EQ(inputs.size(), interpreter_inputs.size())
+ << "Inputs mismatch: Model inputs #:" << interpreter_inputs.size()
+ << " expected: " << inputs.size();
+ }
+
+ // Check that all names and types match.
+ for (int j = 0; j < inputs.size(); ++j) {
+ const InputLayerInfo& input = inputs[j];
+ int i = interpreter_inputs[j];
+ TfLiteTensor* t = interpreter->tensor(i);
+ TFLITE_BENCHMARK_CHECK_EQ(t->name, input.name)
+ << "Tensor # " << i << " is named " << t->name << " but flags call it "
+ << input.name;
+ TFLITE_BENCHMARK_CHECK_EQ(t->type, input.data_type)
+ << "Could not match the type of input tensor " << t->name;
+ }
+
+ // Resize all non-string tensors.
+ for (int j = 0; j < inputs.size(); ++j) {
+ const InputLayerInfo& input = inputs[j];
+ int i = interpreter_inputs[j];
+ TfLiteTensor* t = interpreter->tensor(i);
+ if (t->type != kTfLiteString) {
+ interpreter->ResizeInputTensor(i, input.shape);
+ }
+ }
+
+ if (interpreter->AllocateTensors() != kTfLiteOk) {
+ TFLITE_LOG(FATAL) << "Failed to allocate tensors!";
+ }
+
+ // Set the values of the input tensors.
+ for (int j = 0; j < inputs.size(); ++j) {
+ const InputLayerInfo& input = inputs[j];
+ int i = interpreter_inputs[j];
+ TfLiteTensor* t = interpreter->tensor(i);
+ std::vector<int> sizes = input.shape;
+
+ // TODO(ahentz): below we ignore the 0-th dimension (number of batches).
+ if (t->type == kTfLiteFloat32) {
+ FillRandomValue<float>(
+ interpreter->typed_tensor<float>(i),
+ std::vector<int>(sizes.begin() + 1, sizes.end()),
+ []() { return static_cast<float>(rand()) / RAND_MAX - 0.5f; });
+ } else if (t->type == kTfLiteUInt8) {
+ FillRandomValue<uint8_t>(
+ interpreter->typed_tensor<uint8_t>(i),
+ std::vector<int>(sizes.begin() + 1, sizes.end()),
+ []() { return static_cast<uint8_t>(rand()) % 255; });
+ } else if (t->type == kTfLiteString) {
+ tflite::DynamicBuffer buffer;
+ FillRandomString(&buffer, sizes, []() {
+ return "we're have some friends over saturday to hang out in the yard";
+ });
+ buffer.WriteToTensor(interpreter->tensor(i));
+ } else {
+ TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name
+ << " of type " << t->type;
+ }
+ }
+}
+
+void BenchmarkTfLiteModel::RunImpl() {
+ if (interpreter->Invoke() != kTfLiteOk) {
+ TFLITE_LOG(FATAL) << "Failed to invoke!";
+ }
+}
+
+} // namespace benchmark
+} // namespace tflite
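
Flag parsing in this file relies on the two file-local helpers Split() and SplitAndParse() in place of tensorflow::str_util. A condensed, runnable copy of Split() with a usage example (the real helpers live in an anonymous namespace above and are not exported):

    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    std::vector<std::string> SplitSketch(const std::string& str, char delim) {
      std::istringstream input(str);
      std::vector<std::string> results;
      std::string item;
      while (std::getline(input, item, delim)) results.push_back(item);
      return results;
    }

    int main() {
      // --input_layer_shape=1,224,224,3:1,20 is split per input (':'), then
      // per dimension (',') before SplitAndParse converts dims to ints.
      for (const auto& shape : SplitSketch("1,224,224,3:1,20", ':')) {
        std::cout << "shape spec: " << shape << "\n";
        for (const auto& dim : SplitSketch(shape, ','))
          std::cout << "  dim " << dim << "\n";
      }
      return 0;
    }
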
diff --git a/tensorflow/contrib/lite/tools/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark_tflite_model.h
new file mode 100644
index 0000000000..e6d03d5211
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark_tflite_model.h
@@ -0,0 +1,90 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_TFLITE_MODEL_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_TFLITE_MODEL_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/profiling/profile_summarizer.h"
+#include "tensorflow/contrib/lite/tools/benchmark_model.h"
+
+namespace tflite {
+namespace benchmark {
+
+// Dumps profiling events if profiling is enabled
+class ProfilingListener : public BenchmarkListener {
+ public:
+ ProfilingListener() : interpreter_(nullptr), has_profiles_(false) {}
+
+ void SetInterpreter(Interpreter* interpreter);
+
+ void OnSingleRunStart(RunType run_type) override;
+
+ void OnSingleRunEnd() override;
+
+ void OnBenchmarkEnd(const BenchmarkResults& results) override;
+
+ private:
+ Interpreter* interpreter_;
+ profiling::Profiler profiler_;
+ profiling::ProfileSummarizer summarizer_;
+ bool has_profiles_;
+};
+
+// Benchmarks a TFLite model by running tflite interpreter.
+class BenchmarkTfLiteModel : public BenchmarkModel {
+ public:
+ BenchmarkTfLiteModel() : use_nnapi(false) {
+ AddListener(&profiling_listener_);
+ }
+
+ std::vector<Flag> GetFlags() override;
+ void LogFlags() override;
+ bool ValidateFlags() override;
+ uint64_t ComputeInputBytes() override;
+ void Init() override;
+ void RunImpl() override;
+ virtual ~BenchmarkTfLiteModel() {}
+
+ struct InputLayerInfo {
+ std::string name;
+ TfLiteType data_type;
+ std::vector<int> shape;
+ // Note that initialization_values is currently unused.
+ std::vector<float> initialization_values;
+ };
+
+ private:
+ std::unique_ptr<tflite::FlatBufferModel> model;
+ std::unique_ptr<tflite::Interpreter> interpreter;
+ std::string graph;
+ std::string input_layer_string;
+ std::string input_layer_type_string;
+ std::string input_layer_shape_string;
+ std::string input_layer_values_string;
+ std::string output_layer_string;
+ std::vector<InputLayerInfo> inputs;
+ bool use_nnapi;
+ ProfilingListener profiling_listener_;
+};
+
+} // namespace benchmark
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_TFLITE_MODEL_H_
diff --git a/tensorflow/contrib/lite/tools/command_line_flags.cc b/tensorflow/contrib/lite/tools/command_line_flags.cc
new file mode 100644
index 0000000000..ba72f40689
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/command_line_flags.cc
@@ -0,0 +1,189 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/command_line_flags.h"
+
+#include <cinttypes>
+#include <cstdio>
+#include <cstring>
+#include <sstream>
+#include <string>
+#include <vector>
+
+namespace tflite {
+namespace {
+
+bool ParseFlag(const std::string& arg, const std::string& flag,
+ const std::function<bool(const std::string&)>& parse_func,
+ bool* value_parsing_ok) {
+ *value_parsing_ok = true;
+ std::string flag_prefix = "--" + flag + "=";
+ if (arg.find(flag_prefix) != 0) {
+ return false;
+ }
+ bool has_value = (arg.size() >= flag_prefix.size() + 1);
+ *value_parsing_ok = has_value;
+ if (has_value) {
+ *value_parsing_ok = parse_func(arg.substr(flag_prefix.size()));
+ }
+ return true;
+}
+
+bool ParseInt32Flag(const std::string& flag_value, int32_t* value) {
+ char extra;
+ return sscanf(flag_value.data(), "%d%c", value, &extra) == 1;
+}
+
+bool ParseInt64Flag(const std::string& flag_value, int64_t* value) {
+ char extra;
+ // %ld only matches int64_t where it is typedef'd to long; SCNd64 is portable.
+ return sscanf(flag_value.data(), "%" SCNd64 "%c", value, &extra) == 1;
+}
+
+bool ParseBoolFlag(const std::string& flag_value, bool* value) {
+ if (flag_value != "true" && flag_value != "false") {
+ return false;
+ }
+
+ *value = (flag_value == "true");
+ return true;
+}
+
+bool ParseFloatFlag(const std::string& flag_value, float* value) {
+ char extra;
+ return sscanf(flag_value.data(), "%f%c", value, &extra) == 1;
+}
+
+bool ParseStringFlag(const std::string& flag_value, std::string* value) {
+ *value = flag_value;
+ return true;
+}
+
+} // namespace
+
+Flag::Flag(const char* name, int32_t* dst, const std::string& usage_text)
+ : name_(name),
+ type_(TYPE_INT32),
+ value_hook_([dst](const std::string& flag_value) {
+ return ParseInt32Flag(flag_value, dst);
+ }),
+ default_for_display_(std::to_string(*dst)),
+ usage_text_(usage_text) {}
+
+Flag::Flag(const char* name, int64_t* dst, const std::string& usage_text)
+ : name_(name),
+ type_(TYPE_INT64),
+ value_hook_([dst](const std::string& flag_value) {
+ return ParseInt64Flag(flag_value, dst);
+ }),
+ default_for_display_(std::to_string(*dst)),
+ usage_text_(usage_text) {}
+
+Flag::Flag(const char* name, float* dst, const std::string& usage_text)
+ : name_(name),
+ type_(TYPE_FLOAT),
+ value_hook_([dst](const std::string& flag_value) {
+ return ParseFloatFlag(flag_value, dst);
+ }),
+ default_for_display_(std::to_string(*dst)),
+ usage_text_(usage_text) {}
+
+Flag::Flag(const char* name, bool* dst, const std::string& usage_text)
+ : name_(name),
+ type_(TYPE_BOOL),
+ value_hook_([dst](const std::string& flag_value) {
+ return ParseBoolFlag(flag_value, dst);
+ }),
+ default_for_display_((*dst) ? "true" : "false"),
+ usage_text_(usage_text) {}
+
+Flag::Flag(const char* name, std::string* dst, const std::string& usage_text)
+ : name_(name),
+ type_(TYPE_STRING),
+ value_hook_([dst](const std::string& flag_value) {
+ return ParseStringFlag(flag_value, dst);
+ }),
+ default_for_display_(*dst),
+ usage_text_(usage_text) {}
+
+bool Flag::Parse(const std::string& arg, bool* value_parsing_ok) const {
+ return ParseFlag(arg, name_, value_hook_, value_parsing_ok);
+}
+
+std::string Flag::GetTypeName() const {
+ switch (type_) {
+ case TYPE_INT32:
+ return "int32";
+ case TYPE_INT64:
+ return "int64";
+ case TYPE_FLOAT:
+ return "float";
+ case TYPE_BOOL:
+ return "bool";
+ case TYPE_STRING:
+ return "string";
+ }
+
+ return "unknown";
+}
+
+/*static*/ bool Flags::Parse(int* argc, const char** argv,
+ const std::vector<Flag>& flag_list) {
+ bool result = true;
+ std::vector<const char*> unknown_flags;
+ for (int i = 1; i < *argc; ++i) {
+ if (std::string(argv[i]) == "--") {
+ while (i < *argc) {
+ unknown_flags.push_back(argv[i]);
+ ++i;
+ }
+ break;
+ }
+
+ bool was_found = false;
+ for (const Flag& flag : flag_list) {
+ bool value_parsing_ok;
+ was_found = flag.Parse(argv[i], &value_parsing_ok);
+ if (!value_parsing_ok) {
+ result = false;
+ }
+ if (was_found) {
+ break;
+ }
+ }
+ if (!was_found) {
+ unknown_flags.push_back(argv[i]);
+ }
+ }
+ int dst = 1; // Skip argv[0]
+ for (auto f : unknown_flags) {
+ argv[dst++] = f;
+ }
+ argv[dst++] = nullptr;
+ *argc = unknown_flags.size() + 1;
+ return result && (*argc < 2 || strcmp(argv[1], "--help") != 0);
+}
+
+/*static*/ std::string Flags::Usage(const std::string& cmdline,
+ const std::vector<Flag>& flag_list) {
+ std::ostringstream usage_text;
+ usage_text << "usage: " << cmdline << "\n";
+ if (!flag_list.empty()) {
+ usage_text << "Flags:\n";
+ }
+
+ for (const Flag& flag : flag_list) {
+ auto type_name = flag.GetTypeName();
+ usage_text << "\t";
+ usage_text << "--" << flag.name_ << "=" << flag.default_for_display_;
+ usage_text << "\t" << type_name << "\t" << flag.usage_text_ << "\n";
+ }
+ return usage_text.str();
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/tools/command_line_flags.h b/tensorflow/contrib/lite/tools/command_line_flags.h
new file mode 100644
index 0000000000..0605d3c9d4
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/command_line_flags.h
@@ -0,0 +1,112 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_COMMAND_LINE_FLAGS_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_COMMAND_LINE_FLAGS_H_
+
+#include <functional>
+#include <string>
+#include <vector>
+
+namespace tflite {
+// A simple command-line argument parsing module.
+// Dependency free simplified port of core/util/command_line_flags.
+// This class is written for benchmarks and uses inefficient string
+// concatenation. This was written to avoid dependency on tensorflow/core/util
+// which transitively brings in a lot of other dependencies that are not
+// necessary for tflite benchmarking code.
+// The recommended way of using it is with local variables and an initializer
+// list of Flag objects, for example:
+//
+// int some_int = 10;
+// bool some_switch = false;
+// std::string some_name = "something";
+// std::vector<tflite::Flag> flag_list = {
+// Flag("some_int", &some_int, "an integer that affects X"),
+// Flag("some_switch", &some_switch, "a bool that affects Y"),
+// Flag("some_name", &some_name, "a std::string that affects Z")
+// };
+// // Get usage message before Flags::Parse() to capture default values.
+// std::string usage = Flags::Usage(argv[0], flag_list);
+// bool parsed_values_ok = Flags::Parse(&argc, argv, flag_list);
+//
+// if (argc != 1 || !parsed_values_ok) {
+// ...output usage and error message...
+// }
+//
+// The argc and argv values are adjusted by the Parse function so all that
+// remains is the program name (at argv[0]) and any unknown arguments fill the
+// rest of the array. This means you can check for flags that weren't understood
+// by seeing if argc is greater than 1.
+// The result indicates if there were any errors parsing the values that were
+// passed to the command-line switches. For example, --some_int=foo would return
+// false because the argument is expected to be an integer.
+//
+// NOTE: Unlike gflags-style libraries, this library is intended to be
+// used in the `main()` function of your binary. It does not handle
+// flag definitions that are scattered around the source code.
+
+// A description of a single command line flag, holding its name, type, usage
+// text, and a pointer to the corresponding variable.
+class Flag {
+ public:
+ Flag(const char* name, int32_t* dst, const std::string& usage_text);
+ Flag(const char* name, int64_t* dst, const std::string& usage_text);
+ Flag(const char* name, bool* dst, const std::string& usage_text);
+ Flag(const char* name, std::string* dst, const std::string& usage_text);
+ Flag(const char* name, float* dst, const std::string& usage_text);
+
+ private:
+ friend class Flags;
+
+ bool Parse(const std::string& arg, bool* value_parsing_ok) const;
+
+ std::string name_;
+ enum {
+ TYPE_INT32,
+ TYPE_INT64,
+ TYPE_BOOL,
+ TYPE_STRING,
+ TYPE_FLOAT,
+ } type_;
+
+ std::string GetTypeName() const;
+
+ std::function<bool(const std::string&)> value_hook_;
+ std::string default_for_display_;
+
+ std::string usage_text_;
+};
+
+class Flags {
+ public:
+ // Parse the command line represented by argv[0, ..., (*argc)-1] to find flag
+ // instances matching flags in flaglist[]. Update the variables associated
+ // with matching flags, and remove the matching arguments from (*argc, argv).
+ // Return true iff all recognized flag values were parsed correctly, and the
+ // first remaining argument is not "--help".
+ static bool Parse(int* argc, const char** argv,
+ const std::vector<Flag>& flag_list);
+
+ // Return a usage message with command line cmdline, and the
+ // usage_text strings in flag_list[].
+ static std::string Usage(const std::string& cmdline,
+ const std::vector<Flag>& flag_list);
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_COMMAND_LINE_FLAGS_H_
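
Putting the header's usage comment into compilable form; a sketch assuming only that the binary links against the :command_line_flags target defined in tools/BUILD above:

    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    #include "tensorflow/contrib/lite/tools/command_line_flags.h"

    int main(int argc, char** argv) {
      int32_t num_runs = 50;
      std::string graph;
      std::vector<tflite::Flag> flag_list = {
          tflite::Flag("num_runs", &num_runs, "number of runs"),
          tflite::Flag("graph", &graph, "graph file name"),
      };
      // Capture usage before parsing so defaults are shown, then parse.
      std::string usage = tflite::Flags::Usage(argv[0], flag_list);
      bool ok = tflite::Flags::Parse(&argc, const_cast<const char**>(argv),
                                     flag_list);
      // Parse() compacts argv; anything left past argv[0] was not recognized.
      if (!ok || argc != 1) {
        std::cerr << usage;
        return 1;
      }
      std::cout << "num_runs=" << num_runs << " graph=[" << graph << "]\n";
      return 0;
    }
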
diff --git a/tensorflow/contrib/lite/tools/command_line_flags_test.cc b/tensorflow/contrib/lite/tools/command_line_flags_test.cc
new file mode 100644
index 0000000000..463647bec9
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/command_line_flags_test.cc
@@ -0,0 +1,153 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/command_line_flags.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/testing/util.h"
+
+namespace tflite {
+namespace {
+
+TEST(CommandLineFlagsTest, BasicUsage) {
+ int some_int32 = 10;
+ int64_t some_int64 = 21474836470; // max int32 is 2147483647
+ bool some_switch = false;
+ std::string some_name = "something_a";
+ float some_float = -23.23f;
+ const char* argv_strings[] = {"program_name",
+ "--some_int32=20",
+ "--some_int64=214748364700",
+ "--some_switch=true",
+ "--some_name=somethingelse",
+ "--some_float=42.0"};
+ int argc = 6;
+ bool parsed_ok =
+ Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
+ {
+ Flag("some_int32", &some_int32, "some int32"),
+ Flag("some_int64", &some_int64, "some int64"),
+ Flag("some_switch", &some_switch, "some switch"),
+ Flag("some_name", &some_name, "some name"),
+ Flag("some_float", &some_float, "some float"),
+ });
+
+ EXPECT_EQ(true, parsed_ok);
+ EXPECT_EQ(20, some_int32);
+ EXPECT_EQ(214748364700, some_int64);
+ EXPECT_EQ(true, some_switch);
+ EXPECT_EQ("somethingelse", some_name);
+ EXPECT_NEAR(42.0f, some_float, 1e-5f);
+ EXPECT_EQ(argc, 1);
+}
+
+TEST(CommandLineFlagsTest, BadIntValue) {
+ int some_int = 10;
+ int argc = 2;
+ const char* argv_strings[] = {"program_name", "--some_int=notanumber"};
+ bool parsed_ok =
+ Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
+ {Flag("some_int", &some_int, "some int")});
+
+ EXPECT_EQ(false, parsed_ok);
+ EXPECT_EQ(10, some_int);
+ EXPECT_EQ(argc, 1);
+}
+
+TEST(CommandLineFlagsTest, BadBoolValue) {
+ bool some_switch = false;
+ int argc = 2;
+ const char* argv_strings[] = {"program_name", "--some_switch=notabool"};
+ bool parsed_ok =
+ Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
+ {Flag("some_switch", &some_switch, "some switch")});
+
+ EXPECT_EQ(false, parsed_ok);
+ EXPECT_EQ(false, some_switch);
+ EXPECT_EQ(argc, 1);
+}
+
+TEST(CommandLineFlagsTest, BadFloatValue) {
+ float some_float = -23.23f;
+ int argc = 2;
+ const char* argv_strings[] = {"program_name", "--some_float=notanumber"};
+ bool parsed_ok =
+ Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
+ {Flag("some_float", &some_float, "some float")});
+
+ EXPECT_EQ(false, parsed_ok);
+ EXPECT_NEAR(-23.23f, some_float, 1e-5f);
+ EXPECT_EQ(argc, 1);
+}
+
+// Return whether str==pat, but allowing any whitespace in pat
+// to match zero or more whitespace characters in str.
+static bool MatchWithAnyWhitespace(const std::string& str,
+ const std::string& pat) {
+ bool matching = true;
+ int pat_i = 0;
+ for (int str_i = 0; str_i != str.size() && matching; str_i++) {
+ if (isspace(str[str_i])) {
+ matching = (pat_i != pat.size() && isspace(pat[pat_i]));
+ } else {
+ while (pat_i != pat.size() && isspace(pat[pat_i])) {
+ pat_i++;
+ }
+ matching = (pat_i != pat.size() && str[str_i] == pat[pat_i++]);
+ }
+ }
+ while (pat_i != pat.size() && isspace(pat[pat_i])) {
+ pat_i++;
+ }
+ return (matching && pat_i == pat.size());
+}
+
+TEST(CommandLineFlagsTest, UsageString) {
+ int some_int = 10;
+ int64_t some_int64 = 21474836470; // max int32 is 2147483647
+ bool some_switch = false;
+ std::string some_name = "something";
+ // Don't test float in this case, because precision is hard to predict and
+  // match against, and we don't want a flaky test.
+  const std::string tool_name = "some_tool_name";
+  std::string usage = Flags::Usage(tool_name + " <flags>",
+ {Flag("some_int", &some_int, "some int"),
+ Flag("some_int64", &some_int64, "some int64"),
+ Flag("some_switch", &some_switch, "some switch"),
+ Flag("some_name", &some_name, "some name")});
+ // Match the usage message, being sloppy about whitespace.
+ const char* expected_usage =
+ " usage: some_tool_name <flags>\n"
+ "Flags:\n"
+ "--some_int=10\tint32\tsome int\n"
+ "--some_int64=21474836470\tint64\tsome int64\n"
+ "--some_switch=false\tbool\tsome switch\n"
+ "--some_name=something\tstring\tsome name\n";
+ ASSERT_EQ(MatchWithAnyWhitespace(usage, expected_usage), true) << usage;
+
+ // Again but with no flags.
+ usage = Flags::Usage(tool_name, {});
+ ASSERT_EQ(MatchWithAnyWhitespace(usage, " usage: some_tool_name\n"), true)
+ << usage;
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
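
A hedged sketch of how a tool might wire these flags into its own main(); the flag names and defaults are illustrative, not flags defined anywhere in this patch:

#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

#include "tensorflow/contrib/lite/tools/command_line_flags.h"

int main(int argc, char** argv) {
  std::string model_file = "model.tflite";  // Hypothetical flag.
  int32_t num_runs = 50;                    // Hypothetical flag.
  std::vector<tflite::Flag> flag_list = {
      tflite::Flag("model_file", &model_file, "path to the model"),
      tflite::Flag("num_runs", &num_runs, "number of inference runs"),
  };
  if (!tflite::Flags::Parse(&argc, const_cast<const char**>(argv), flag_list)) {
    std::fprintf(stderr, "%s", tflite::Flags::Usage(argv[0], flag_list).c_str());
    return 1;
  }
  // argv now holds only the arguments Parse() did not consume.
  return 0;
}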
diff --git a/tensorflow/contrib/lite/tools/logging.h b/tensorflow/contrib/lite/tools/logging.h
new file mode 100644
index 0000000000..aa1fa5b827
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/logging.h
@@ -0,0 +1,75 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_LOGGING_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_LOGGING_H_
+
+// LOG and CHECK macros for benchmarks.
+
+#include <iostream>
+#include <sstream>
+
+namespace tflite {
+namespace logging {
+// A wrapper that logs to stderr.
+//
+// Used for TFLITE_LOG and TFLITE_BENCHMARK_CHECK macros.
+class LoggingWrapper {
+ public:
+ enum class LogSeverity : int {
+ INFO = 0,
+ WARN = 1,
+ ERROR = 2,
+ FATAL = 3,
+ };
+ LoggingWrapper(LogSeverity severity)
+ : severity_(severity), should_log_(true) {}
+ LoggingWrapper(LogSeverity severity, bool log)
+ : severity_(severity), should_log_(log) {}
+ std::stringstream& Stream() { return stream_; }
+ ~LoggingWrapper() {
+ if (should_log_) {
+ std::cerr << stream_.str() << std::endl;
+ if (severity_ == LogSeverity::FATAL) {
+ std::flush(std::cerr);
+ std::abort();
+ }
+ }
+ }
+
+ private:
+ std::stringstream stream_;
+ LogSeverity severity_;
+ bool should_log_;
+};
+
+} // namespace logging
+
+} // namespace tflite
+
+#define TFLITE_LOG(severity) \
+ tflite::logging::LoggingWrapper( \
+ tflite::logging::LoggingWrapper::LogSeverity::severity) \
+ .Stream()
+
+#define TFLITE_BENCHMARK_CHECK(condition) \
+ tflite::logging::LoggingWrapper( \
+ tflite::logging::LoggingWrapper::LogSeverity::FATAL, \
+      !(condition))                                      \
+ .Stream()
+
+#define TFLITE_BENCHMARK_CHECK_EQ(a, b) TFLITE_BENCHMARK_CHECK((a) == (b))
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_LOGGING_H_
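
A minimal usage sketch for the macros above, with invented values; a failed TFLITE_BENCHMARK_CHECK streams its message and then aborts via the FATAL path in ~LoggingWrapper():

#include "tensorflow/contrib/lite/tools/logging.h"

int main() {
  const int warmup_runs = 5;
  TFLITE_LOG(INFO) << "warmup runs: " << warmup_runs;
  // The wrapper only logs (and aborts) when the condition is false.
  TFLITE_BENCHMARK_CHECK(warmup_runs > 0) << "need at least one warmup run";
  TFLITE_BENCHMARK_CHECK_EQ(warmup_runs % 5, 0) << "expected a multiple of 5";
  return 0;
}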
diff --git a/tensorflow/contrib/lite/util.cc b/tensorflow/contrib/lite/util.cc
index fb4af07d06..8ccb65c24f 100644
--- a/tensorflow/contrib/lite/util.cc
+++ b/tensorflow/contrib/lite/util.cc
@@ -38,4 +38,14 @@ bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size,
return true;
}
+size_t CombineHashes(std::initializer_list<size_t> hashes) {
+ size_t result = 0;
+ // Hash combiner used by TensorFlow core.
+ for (size_t hash : hashes) {
+ result = result ^
+ (hash + 0x9e3779b97f4a7800ULL + (result << 10) + (result >> 4));
+ }
+ return result;
+}
+
} // namespace tflite
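
The new helper composes naturally with std::hash; the wrapper below is an illustrative caller, not a call site from this patch:

#include <cstddef>
#include <functional>
#include <string>

#include "tensorflow/contrib/lite/util.h"

size_t HashNameAndShape(const std::string& name, int height, int width) {
  // Combine independently computed hashes with the same combiner
  // TensorFlow core uses.
  return tflite::CombineHashes({std::hash<std::string>()(name),
                                std::hash<int>()(height),
                                std::hash<int>()(width)});
}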
diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h
index a34db35823..89d9b4f5cf 100644
--- a/tensorflow/contrib/lite/util.h
+++ b/tensorflow/contrib/lite/util.h
@@ -35,6 +35,8 @@ TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int rank, const int* dims);
bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size,
const int* b);
+size_t CombineHashes(std::initializer_list<size_t> hashes);
+
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_UTIL_H_
diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc
index 2de3923b06..f5b2d258d7 100644
--- a/tensorflow/contrib/tensorrt/segment/segment_test.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc
@@ -275,13 +275,13 @@ TEST_F(SegmentTest, Multiple) {
// Expect two subgraphs
EXPECT_EQ(segments.size(), 2);
- std::vector<string> expected0{"add0", "add1", "add2", "add3"};
+ std::vector<string> expected0{"add6", "add8"};
for (const auto& ex : expected0) {
EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end())
<< "Missing expected node " << ex;
}
- std::vector<string> expected1{"add6", "add8"};
+ std::vector<string> expected1{"add0", "add1", "add2", "add3"};
for (const auto& ex : expected1) {
EXPECT_TRUE(segments[1].first.find(ex) != segments[1].first.end())
<< "Missing expected node " << ex;
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index aeb7ba536f..4465833f88 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -46,6 +46,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -2748,7 +2749,8 @@ class _Inputs(object):
"""
iterator = self._dataset.make_initializable_iterator()
# pylint: disable=protected-access
- hook = estimator_lib._DatasetInitializerHook(iterator)
+ hook = estimator_util._DatasetInitializerHook(iterator)
+ # pylint: enable=protected-access
self._iterator = iterator
return hook
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 10109e5ac1..c976079350 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -101,42 +101,43 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_only_cc_test")
# For platform specific build config
load(
"//tensorflow/core:platform/default/build_config.bzl",
- "tf_platform_hdrs",
- "tf_platform_srcs",
- "tf_proto_library",
- "tf_proto_library_cc",
"tf_additional_all_protos",
+ "tf_additional_cloud_kernel_deps",
+ "tf_additional_cloud_op_deps",
"tf_additional_core_deps",
+ "tf_additional_cupti_wrapper_deps",
+ "tf_additional_device_tracer_cuda_deps",
+ "tf_additional_device_tracer_deps",
+ "tf_additional_device_tracer_srcs",
+ "tf_additional_gdr_lib_defines",
+ "tf_additional_human_readable_json_deps",
"tf_additional_lib_defines",
"tf_additional_lib_deps",
+ "tf_additional_libdevice_data",
+ "tf_additional_libdevice_deps",
+ "tf_additional_libdevice_srcs",
"tf_additional_lib_hdrs",
"tf_additional_lib_srcs",
"tf_additional_minimal_lib_srcs",
+ "tf_additional_mpi_lib_defines",
"tf_additional_proto_hdrs",
"tf_additional_proto_srcs",
- "tf_additional_cupti_wrapper_deps",
- "tf_additional_libdevice_data",
- "tf_additional_libdevice_deps",
- "tf_additional_libdevice_srcs",
"tf_additional_test_deps",
"tf_additional_test_srcs",
- "tf_kernel_tests_linkstatic",
- "tf_additional_cloud_op_deps",
- "tf_additional_cloud_kernel_deps",
- "tf_lib_proto_parsing_deps",
"tf_additional_verbs_lib_defines",
- "tf_additional_mpi_lib_defines",
- "tf_additional_gdr_lib_defines",
- "tf_additional_device_tracer_srcs",
- "tf_additional_device_tracer_deps",
- "tf_additional_device_tracer_cuda_deps",
- "tf_pyclif_proto_library",
"tf_jspb_proto_library",
+ "tf_kernel_tests_linkstatic",
+ "tf_lib_proto_parsing_deps",
"tf_nano_proto_library",
+ "tf_platform_hdrs",
+ "tf_platform_srcs",
+ "tf_proto_library",
+ "tf_proto_library_cc",
"tf_protos_all",
"tf_protos_all_impl",
"tf_protos_grappler",
"tf_protos_grappler_impl",
+ "tf_pyclif_proto_library",
)
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
@@ -401,6 +402,7 @@ cc_library(
"protobuf.cc",
]) + [
"platform/protobuf_util.cc",
+ "lib/core/status.h",
],
hdrs = [
":platform_protobuf_hdrs",
@@ -417,6 +419,18 @@ cc_library(
],
)
+cc_library(
+ name = "human_readable_json",
+ srcs = tf_platform_srcs(["human_readable_json.cc"]),
+ hdrs = ["platform/human_readable_json.h"],
+ copts = tf_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ ":lib",
+ ":lib_internal",
+ ] + tf_additional_human_readable_json_deps(),
+)
+
filegroup(
name = "platform_env_hdrs",
srcs = [
@@ -832,7 +846,6 @@ tf_cuda_library(
"util/sparse/sparse_tensor.h",
"util/stat_summarizer.h",
"util/stat_summarizer_options.h",
- "util/stats_calculator.h",
"util/stream_executor_util.h",
"util/strided_slice_op.h",
"util/tensor_format.h",
@@ -859,9 +872,11 @@ tf_cuda_library(
cc_library(
name = "stats_calculator_portable",
- srcs = ["util/stats_calculator.cc"],
- hdrs = [
+ srcs = [
"util/stat_summarizer_options.h",
+ "util/stats_calculator.cc",
+ ],
+ hdrs = [
"util/stats_calculator.h",
],
deps = [":platform_base"],
@@ -2016,6 +2031,7 @@ cc_library(
"platform/**/cuda_libdevice_path.cc",
"platform/**/device_tracer.cc",
"platform/**/logging.cc",
+ "platform/**/human_readable_json.cc",
"platform/abi.cc",
],
) + tf_additional_lib_srcs(
@@ -2028,6 +2044,7 @@ cc_library(
"platform/**/env_time.cc",
"platform/**/device_tracer.cc",
"platform/**/logging.cc",
+ "platform/**/human_readable_json.cc",
"platform/abi.cc",
] +
# Protobuf deps already included through the ":lib_proto_parsing"
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 61b2f0e60f..f4f5198396 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -845,7 +845,7 @@ TEST_F(FunctionLibraryRuntimeTest, ManySwapsNodeDef) {
ASSERT_TRUE(g != nullptr);
OptimizeGraph(flr0_, &g);
const char* e0 = R"P(
-(n3:float, n2:float) -> (n3:float) {
+(n2:float, n3:float) -> (n2:float) {
}
)P";
EXPECT_EQ(e0, DebugString(g.get()));
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 8624af9bf5..23dc903caf 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -351,6 +351,10 @@ class IteratorBase {
// in the outputs of this iterator.
virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
+ // Performs initialization that needs to happen outside of a constructor to
+ // properly propagate errors.
+ virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }
+
// Saves the state of this iterator.
virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) {
return SaveInternal(writer);
@@ -402,12 +406,13 @@ class DatasetBase : public core::RefCounted {
// iterator will traverse all elements in this dataset from the
// start.
//
- // Ownership of the created iterator will be transferred to the caller.
- //
// The prefix identifies the sequence of iterators leading up to the newly
// created iterator.
- virtual std::unique_ptr<IteratorBase> MakeIterator(
- const string& prefix) const = 0;
+ Status MakeIterator(IteratorContext* ctx, const string& prefix,
+ std::unique_ptr<IteratorBase>* iterator) const {
+ *iterator = MakeIteratorInternal(prefix);
+ return (*iterator)->Initialize(ctx);
+ }
// Returns a vector of DataType values, representing the respective
// element types of each tuple component in the outputs of this
@@ -420,7 +425,7 @@ class DatasetBase : public core::RefCounted {
virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
// A human-readable debug string for this dataset.
- virtual string DebugString() = 0;
+ virtual string DebugString() const = 0;
// Serializes the dataset and writes it to the `writer`.
virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) const {
@@ -451,6 +456,9 @@ class DatasetBase : public core::RefCounted {
Node** node) const {
return errors::Unimplemented("AsGraphDefInternal");
}
+
+ virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const = 0;
};
// Base-class for datasets that are built by ops.
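
The net effect of the MakeIterator() change, sketched from the caller's side (the helper name is illustrative): construction and fallible initialization now happen in one call, and errors surface as a Status instead of escaping a constructor:

#include <memory>

// Hypothetical caller: MakeIterator() now constructs the iterator and runs
// its Initialize(ctx), so any construction-time failure is returned here.
Status MakeAndInitIterator(const DatasetBase* dataset, IteratorContext* ctx,
                           std::unique_ptr<IteratorBase>* out) {
  return dataset->MakeIterator(ctx, /*prefix=*/"Iterator", out);
}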
diff --git a/tensorflow/core/framework/variable.proto b/tensorflow/core/framework/variable.proto
index 93ae423bab..66ba4cba7d 100644
--- a/tensorflow/core/framework/variable.proto
+++ b/tensorflow/core/framework/variable.proto
@@ -26,6 +26,9 @@ message VariableDef {
// Whether to represent this as a ResourceVariable.
bool is_resource = 5;
+
+ // Whether this variable should be trained.
+ bool trainable = 7;
}
message SaveSliceInfoDef {
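
For reference, the new field through the generated C++ accessors; the accessors follow mechanically from the proto definition, while the helper itself is illustrative:

#include <string>

#include "tensorflow/core/framework/variable.pb.h"

tensorflow::VariableDef MakeTrainableVariableDef(const std::string& name) {
  tensorflow::VariableDef def;
  def.set_variable_name(name);  // Existing field 1.
  def.set_trainable(true);      // New field 7 introduced by this change.
  return def;
}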
diff --git a/tensorflow/core/graph/algorithm_test.cc b/tensorflow/core/graph/algorithm_test.cc
index 99ced0c0f5..f67d5a2fd2 100644
--- a/tensorflow/core/graph/algorithm_test.cc
+++ b/tensorflow/core/graph/algorithm_test.cc
@@ -144,8 +144,8 @@ TEST(AlgorithmTest, ReversePostOrderStable) {
std::vector<Node*> order;
// Test reverse post order generates expected ordering.
- GetReversePostOrder(g, &order, /*stable_comparator=*/NodeComparatorID());
- EXPECT_TRUE(ExpectBefore({{"t3", "t2"}}, order, &error));
+ GetReversePostOrder(g, &order, /*stable_comparator=*/NodeComparatorName());
+ EXPECT_TRUE(ExpectBefore({{"t2", "t3"}}, order, &error));
}
}
} // namespace
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 2fd32c0bd4..0967492d92 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -278,8 +278,9 @@ class GraphConstructor {
// name, the value is the new unique name.
std::unordered_map<string, string> uniquified_names_;
- // Index of NodeDefs in node_defs_ with all inputs already converted.
- std::vector<int> ready_;
+  // Indices of NodeDefs in node_defs_ with all inputs already converted. We
+  // use a (sorted) set so nodes are created in the order defined in the
+  // GraphDef.
+ std::set<int> ready_;
// Mapping between index within node_defs_ and the number of inputs that
// still need to be converted.
@@ -520,7 +521,7 @@ Status GraphConstructor::InitFromEdges() {
}
}
if (pending_count == 0) {
- ready_.push_back(n);
+ ready_.insert(n);
}
pending_count_.push_back(pending_count);
}
@@ -884,12 +885,12 @@ namespace {
void UpdatePendingCountAndReady(
const std::vector<gtl::InlinedVector<int, 4>>& outputs, int o,
- std::vector<int>* pending_count, std::vector<int>* ready) {
+ std::vector<int>* pending_count, std::set<int>* ready) {
for (size_t i = 0; i < outputs[o].size(); ++i) {
const int output = outputs[o][i];
(*pending_count)[output]--;
if ((*pending_count)[output] == 0) {
- ready->push_back(output);
+ ready->insert(output);
}
}
}
@@ -913,8 +914,8 @@ Status GraphConstructor::Convert() {
// inputs, pending_counts_ with the number of inputs for each node and
// outputs_ with the outputs of each node).
while (!ready_.empty()) {
- int o = ready_.back();
- ready_.pop_back();
+ int o = *ready_.begin();
+ ready_.erase(ready_.begin());
++processed;
inputs.clear();
bool has_data_back_edge = false;
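
A self-contained illustration of why the container swap matters: draining a std::set from begin() visits node indices in ascending (GraphDef) order, whereas vector::pop_back() visited whatever became ready last:

#include <iostream>
#include <set>

int main() {
  std::set<int> ready = {7, 2, 5};  // Nodes whose inputs are all converted.
  while (!ready.empty()) {
    int o = *ready.begin();  // Always the smallest pending index.
    ready.erase(ready.begin());
    std::cout << o << ' ';   // Prints: 2 5 7
  }
  std::cout << '\n';
  return 0;
}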
diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc
index 83b24cafe2..f44ed47a6e 100644
--- a/tensorflow/core/graph/graph_partition_test.cc
+++ b/tensorflow/core/graph/graph_partition_test.cc
@@ -329,11 +329,11 @@ TEST_F(GraphPartitionTest, CrossDeviceControl_MultiUse) {
string b = "/job:a/replica:0/task:0/cpu:1";
a1 = FloatInput(scope_a_.WithOpName("A1"));
auto c = Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(a1), {});
- _Send(scope_a_.WithOpName("A1/_1"), c, "edge_1_A1", a, 82, b);
+ _Send(scope_a_.WithOpName("A1/_1"), c, "edge_3_A1", a, 82, b);
ExpectMatchA();
auto recv =
- _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_1_A1", a, 82, b);
+ _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_3_A1", a, 82, b);
auto id = Identity(scope_b_.WithOpName("A1/_3"), recv);
b1 = FloatInput(scope_b_.WithOpName("B1"));
Combine(scope_b_.WithOpName("B2").WithControlDependencies(id), b1, b1);
@@ -353,18 +353,18 @@ TEST_F(GraphPartitionTest, CrossDevice_DataControl) {
string a = "/job:a/replica:0/task:0/cpu:0";
string b = "/job:a/replica:0/task:0/cpu:1";
a1 = FloatInput(scope_a_.WithOpName("A1"));
- auto c = Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(a1), {});
+ _Send(scope_a_.WithOpName("A1/_0"), a1, "edge_1_A1", a, 82, b);
+ auto c = Const(scope_a_.WithOpName("A1/_2").WithControlDependencies(a1), {});
// NOTE: Send 0 A1/_1 -> A1/_2 is not necessarily needed. We could
// use A1/_0 -> A1/_4 as the control as a minor optimization.
- _Send(scope_a_.WithOpName("A1/_1"), c, "edge_1_A1", a, 82, b);
- _Send(scope_a_.WithOpName("A1/_4"), a1, "edge_2_A1", a, 82, b);
+ _Send(scope_a_.WithOpName("A1/_3"), c, "edge_3_A1", a, 82, b);
ExpectMatchA();
auto recv1 =
- _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_1_A1", a, 82, b);
- auto id1 = Identity(scope_b_.WithOpName("A1/_3"), recv1);
+ _Recv(scope_b_.WithOpName("A1/_4"), DT_FLOAT, "edge_3_A1", a, 82, b);
+ auto id1 = Identity(scope_b_.WithOpName("A1/_5"), recv1);
auto recv2 =
- _Recv(scope_b_.WithOpName("A1/_5"), DT_FLOAT, "edge_2_A1", a, 82, b);
+ _Recv(scope_b_.WithOpName("A1/_1"), DT_FLOAT, "edge_1_A1", a, 82, b);
b1 = FloatInput(scope_b_.WithOpName("B1"));
Combine(scope_b_.WithOpName("B2"), recv2, b1);
FloatInput(scope_b_.WithOpName("B3").WithControlDependencies(id1));
diff --git a/tensorflow/core/graph/optimizer_cse_test.cc b/tensorflow/core/graph/optimizer_cse_test.cc
index 21a63662cf..c1f93ce05a 100644
--- a/tensorflow/core/graph/optimizer_cse_test.cc
+++ b/tensorflow/core/graph/optimizer_cse_test.cc
@@ -115,8 +115,8 @@ TEST_F(OptimizerCSETest, Simple) {
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoCSE(),
- "A(Input);B(Input);D(Mul)|"
- "A->D;B->D:1");
+ "A(Input);B(Input);C(Mul)|"
+ "A->C;B->C:1");
}
TEST_F(OptimizerCSETest, Simple_ThreeEquivalent) {
@@ -130,8 +130,8 @@ TEST_F(OptimizerCSETest, Simple_ThreeEquivalent) {
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoCSE(),
- "A(Input);B(Input);E(Mul)|"
- "A->E;B->E:1");
+ "A(Input);B(Input);C(Mul)|"
+ "A->C;B->C:1");
}
TEST_F(OptimizerCSETest, Simple_WithFixups) {
@@ -145,8 +145,8 @@ TEST_F(OptimizerCSETest, Simple_WithFixups) {
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoCSE(),
- "A(Input);B(Input);D(Mul);E(Mul)|"
- "A->D;B->D:1;D->E;D->E:1");
+ "A(Input);B(Input);C(Mul);E(Mul)|"
+ "A->C;B->C:1;C->E;C->E:1");
}
TEST_F(OptimizerCSETest, Simple_Commutative) {
@@ -158,8 +158,8 @@ TEST_F(OptimizerCSETest, Simple_Commutative) {
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'A'] }");
EXPECT_EQ(DoCSE(),
- "A(Input);B(Input);D(Mul)|"
- "A->D:1;B->D");
+ "A(Input);B(Input);C(Mul)|"
+ "A->C;B->C:1");
}
static bool IsNotMultiply(const Node* n) { return n->type_string() != "Mul"; }
@@ -210,8 +210,8 @@ TEST_F(OptimizerCSETest, Simple_SameOps_SameAttrs1) {
" input: ['A', 'B'] attr { key: 'shape'"
" value { shape: { dim: { size: 37 name: 'SAME_NAME' } } } } }");
EXPECT_EQ(DoCSE(),
- "A(Input);B(Input);D(Mul)|"
- "A->D;B->D:1");
+ "A(Input);B(Input);C(Mul)|"
+ "A->C;B->C:1");
}
TEST_F(OptimizerCSETest, Simple_SameOps_SameAttrs2) {
@@ -229,8 +229,8 @@ TEST_F(OptimizerCSETest, Simple_SameOps_SameAttrs2) {
" attr { key: 't' value { type: DT_INT32 } }"
" attr { key: 'a' value { i: 3 } } }");
EXPECT_EQ(DoCSE(),
- "A(Input);B(Input);D(Mul)|"
- "A->D;B->D:1");
+ "A(Input);B(Input);C(Mul)|"
+ "A->C;B->C:1");
}
TEST_F(OptimizerCSETest, SameConstants) {
@@ -249,8 +249,8 @@ TEST_F(OptimizerCSETest, SameConstants) {
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_INT32 } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoCSE(),
- "B(Const);D(Mul)|"
- "B->D;B->D:1");
+ "A(Const);D(Mul)|"
+ "A->D;A->D:1");
}
TEST_F(OptimizerCSETest, DifferentConstants) {
@@ -338,8 +338,8 @@ TEST_F(OptimizerCSETest, Constant_Dedup) {
"n/_0(Const);n/_1(Const);n/_2(Const);n/_3(Const);"
"n/_4(Const);n/_5(Const);n/_6(Const);n/_7(Const)|");
// In theory, there are 2^4 possible correct output of CSE. In this
- // test, it happens to eliminate the first 4 nodes.
- EXPECT_EQ(DoCSE(), "n/_4(Const);n/_5(Const);n/_6(Const);n/_7(Const)|");
+ // test, it happens to eliminate the last 4 nodes.
+ EXPECT_EQ(DoCSE(), "n/_0(Const);n/_1(Const);n/_2(Const);n/_3(Const)|");
}
static void BM_CSE(int iters, int op_nodes) {
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 69b7594735..d9a08d42db 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
@@ -425,6 +426,13 @@ class SymbolicShapeRefiner {
return it->second.inference_context.get();
}
+ // Forward the shapes from the function's fanin to the function body,
+ // then call PropagateShapes.
+ // Returns an error if 'node' is not a function node.
+ Status UpdateFunction(const NodeDef* node, bool* refined) {
+ return UpdateNode(node, refined);
+ }
+
Status UpdateNode(const NodeDef* node, bool* refined) {
NodeContext* node_context = GetNodeContext(node);
if (node_context == nullptr) {
@@ -677,10 +685,16 @@ class SymbolicShapeRefiner {
return true;
}
+ Status AddFunction(const NodeDef* node) { return Status::OK(); }
+
Status AddNode(const NodeDef* node) {
NodeContext& node_ctx = node_to_context_[node];
TF_RETURN_IF_ERROR(function_library_.LookUp(node->op(), &node_ctx.op_data));
+ if (node_ctx.op_data->is_function_op) {
+ TF_RETURN_IF_ERROR(AddFunction(node));
+ }
+
TF_RETURN_IF_ERROR(InOutTypesForNode(*node, node_ctx.op_data->op_def,
&node_ctx.input_types,
&node_ctx.output_types));
@@ -1069,8 +1083,13 @@ Status GraphProperties::UpdateShapes(
TF_RETURN_IF_ERROR(
UpdateEnqueue(n, resource_handles, shape_refiner, new_shapes));
} else {
- // Rely on regular TF shape refinement for all the other nodes.
- TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
+ auto c = shape_refiner->GetNodeContext(n);
+ if (c && c->op_data && c->op_data->is_function_op) {
+ TF_RETURN_IF_ERROR(shape_refiner->UpdateFunction(n, new_shapes));
+ } else {
+ // Rely on regular TF shape refinement for all the other nodes.
+ TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
+ }
}
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 9c18c45f18..ca3f84a81d 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -194,57 +194,6 @@ void SetSourceDataType(DataType dtype, NodeDef* node) {
SetDataTypeToAttr(dtype, SourceDataTypeAttrName(*node), node);
}
-bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); }
-
-// Returns whether `reshape` is an identity op. The tensor that `reshape`
-// reshapes is the `output_pos`-th output of node `input`.
-bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input,
- const int output_pos,
- const GraphProperties& graph_properties) {
- const std::vector<OpInfo::TensorProperties>& reshape_props =
- graph_properties.GetOutputProperties(reshape.name());
- const std::vector<OpInfo::TensorProperties>& input_props =
- graph_properties.GetOutputProperties(input.name());
- if (reshape_props.empty() || input_props.size() <= output_pos) {
- return false;
- }
-
- const PartialTensorShape& src_shape = input_props[output_pos].shape();
- const PartialTensorShape& dst_shape = reshape_props[0].shape();
-
- if (src_shape.unknown_rank() || dst_shape.unknown_rank()) {
- return false;
- }
-
- if (!dst_shape.IsCompatibleWith(src_shape)) {
- return false;
- }
-
- // Returns false when src_shape or dst_shape has >=2 dimensions with unknown
- // sizes.
- auto num_unknown_dim_sizes = [](const PartialTensorShape& partial_shape) {
- auto dim_sizes = partial_shape.dim_sizes();
- return std::count_if(dim_sizes.begin(), dim_sizes.end(),
- [](int dim) { return dim < 0; });
- };
- int src_num_unknown_dim_sizes = num_unknown_dim_sizes(src_shape);
- int dst_num_unknown_dim_sizes = num_unknown_dim_sizes(dst_shape);
- if (src_num_unknown_dim_sizes > 1 || dst_num_unknown_dim_sizes > 1) {
- return false;
- }
-
- // If dst_num_unknown_dim_sizes != src_num_unknown_dim_sizes we would weaken
- // shape inference in subsequent passes if we removed this reshape.
- if (src_num_unknown_dim_sizes != dst_num_unknown_dim_sizes) {
- return false;
- }
-
- // Remove the reshape if both are fully defined or partially defined and the
- // unknown or symbolic shape appears on the same dimension, i.e., if
- // IsIdenticalTo returns true.
- return dst_shape.IsIdenticalTo(src_shape);
-}
-
NodeDef* GetTailOfValuePreservingChain(
const NodeDef& node, const NodeMap& node_map,
const std::unordered_set<string>& nodes_to_preserve) {
@@ -1856,6 +1805,159 @@ class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage {
}
};
+// Bypass redundant reshape nodes:
+//
+// Reshape Reshape <-+
+// ^ |
+// | |
+// Reshape becomes Reshape |
+// ^ |
+// | |
+// input input ---+
+class RemoveRedundantReshape : public ArithmeticOptimizerStage {
+ public:
+ explicit RemoveRedundantReshape(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("RemoveRedundantReshape", ctx, ctx_ext) {}
+ ~RemoveRedundantReshape() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsReshape(*node);
+ }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ NodeDef* input;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
+
+ // 1. Bypass reshape followed by reshape.
+ if (IsReshape(*input) && !HasControlInputs(*input)) {
+ node->set_input(0, input->input(0));
+ ctx().node_map->UpdateInput(node->name(), input->name(), input->input(0));
+ *simplified_node_name = node->name();
+ AddToOptimizationQueue(node);
+ return Status::OK();
+ }
+
+ // 2. If the reshape is a no-op, forward its input to its consumers, unless
+ // it anchors a control dependency since we want to make sure that control
+ // dependency is triggered.
+ if (ReshapeIsIdentity(*node) && !HasControlInputs(*node)) {
+ *simplified_node_name = node->input(0);
+ return Status::OK();
+ }
+
+ return Status::OK();
+ }
+
+ private:
+ // Returns whether `reshape` is an identity op.
+ bool ReshapeIsIdentity(const NodeDef& reshape) {
+ OpInfo::TensorProperties reshape_props;
+ OpInfo::TensorProperties input_props;
+
+ if (!GetTensorProperties(reshape.name(), &reshape_props).ok() ||
+ !GetTensorProperties(reshape.input(0), &input_props).ok()) {
+ return false;
+ }
+
+ return ShapesSymbolicallyEqual(input_props.shape(), reshape_props.shape());
+ }
+};
+
+// Reorder Cast and Transpose if beneficial.
+//
+// A common pattern after the layout optimizer is casting a uint8 NHWC
+// image to float before transposing it to NCHW. It is beneficial to reorder
+// the cast and the transpose so that the transpose processes a smaller
+// amount of data. This optimization converts
+// Transpose(Cast(image, dst_type), perm)
+// to
+// Cast(Transpose(image, perm), dst_type)
+// when sizeof(image.type) < sizeof(dst_type).
+//
+// TODO(jingyue): This optimization can be generalized to a cast followed by
+// a chain of ops that merely reorder elements (e.g. Reshape and
+// DepthToSpace).
+class ReorderCastAndTranspose : public ArithmeticOptimizerStage {
+ public:
+ explicit ReorderCastAndTranspose(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("ReorderCastAndTranspose", ctx, ctx_ext) {}
+ ~ReorderCastAndTranspose() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsTranspose(*node) && NodeIsOnCpuOrGpu(node);
+ }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ const NodeDef* transpose = node;
+
+ // Verify that input to Transpose is the Cast op.
+ NodeDef* cast;
+ TF_RETURN_IF_ERROR(GetInputNode(transpose->input(0), &cast));
+ if (!IsCast(*cast)) return Status::OK();
+
+ // Input to the Cast-Transpose chain.
+ NodeDef* input;
+ TF_RETURN_IF_ERROR(GetInputNode(cast->input(0), &input));
+
+ const DataType src_type = GetSourceDataType(*cast);
+ const DataType dst_type = GetDestinationDataType(*cast);
+
+ const string src_type_name = DataTypeString(src_type);
+ const string dst_type_name = DataTypeString(dst_type);
+
+ // Check if nodes were not already optimized.
+ const string optimized_cast_name =
+ OptimizedNodeName(ParseNodeScopeAndName(cast->name()), dst_type_name);
+ const string optimized_transpose_name = OptimizedNodeName(
+ ParseNodeScopeAndName(transpose->name()), src_type_name);
+
+ bool is_already_optimized =
+ ctx().node_map->NodeExists(optimized_transpose_name) ||
+ ctx().node_map->NodeExists(optimized_cast_name);
+
+ if (IsNumberType(src_type) && IsNumberType(dst_type) &&
+ DataTypeSize(src_type) < DataTypeSize(dst_type) &&
+ !is_already_optimized) {
+ NodeDef* new_transpose = AddCopyNode(optimized_transpose_name, transpose);
+ (*new_transpose->mutable_attr())["T"].set_type(src_type);
+ new_transpose->set_input(0, cast->input(0));
+
+ ctx().node_map->AddOutput(input->name(), new_transpose->name());
+ ctx().node_map->AddOutput(NodeName(new_transpose->input(1)),
+ new_transpose->name());
+
+ NodeDef* new_cast = AddCopyNode(optimized_cast_name, cast);
+ new_cast->set_input(0, new_transpose->name());
+ ctx().node_map->AddOutput(new_transpose->name(), new_cast->name());
+
+ AddToOptimizationQueue(new_transpose);
+ ForwardControlDependencies(new_transpose, {cast, node});
+
+ *simplified_node_name = new_cast->name();
+ }
+
+ return Status::OK();
+ }
+
+ private:
+ // This optimization can be dangerous on devices other than CPU and
+ // GPU. The transpose might not be implemented for image.type, or
+ // might be slower with image.type than with dst_type.
+ bool NodeIsOnCpuOrGpu(const NodeDef* node) const {
+ using str_util::StrContains;
+
+ string task;
+ string device;
+
+ return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
+ (StrContains(device, DEVICE_CPU) || StrContains(device, DEVICE_GPU));
+ }
+
+ bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); }
+};
+
} // namespace
class UniqueNodes {
@@ -2108,99 +2210,6 @@ void ArithmeticOptimizer::ForwardControlDependencies(
// ArithmeticOptimizerStage
string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
const NodeDef* node, SetVector<NodeDef*>* nodes_to_simplify) {
-
- if (node->op() == "Reshape") {
- // Reshape
- // ^
- // |
- // Reshape
- // ^
- // |
- // input
- //
- // becomes
- //
- // Reshape <-+
- // |
- // Reshape |
- // ^ |
- // | |
- // input ---+
- NodeDef* reshape = const_cast<NodeDef*>(node);
- int output_pos = 0;
- string input_node_name = ParseNodeName(reshape->input(0), &output_pos);
- const NodeDef* input = node_map_->GetNode(input_node_name);
- if (input->op() == "Reshape" && !HasControlInputs(*input)) {
- reshape->set_input(0, input->input(0));
- node_map_->UpdateInput(reshape->name(), input->name(), input->input(0));
- nodes_to_simplify->PushBack(reshape);
- return reshape->name();
- }
-
- // If the reshape is a no-op, forward its input to its consumers, unless it
- // anchors a control dependency since we want to make sure that control
- // dependency is triggered.
- if (ReshapeIsIdentity(*reshape, *input, output_pos, *graph_properties_) &&
- !HasControlInputs(*reshape)) {
- return reshape->input(0);
- }
- }
-
- if (node->op() == "Transpose") {
- // Reorder Cast and Transpose if beneficial.
- //
- // A common pattern after the layout optimizer is casting an uint8 NHWC
- // image to float before transposing it to NCHW. It is beneficial to reorder
- // the cast and the transpose to make the transpose process smaller amount
- // of data. This optimization converts
- // Transpose(Cast(image, dst_type), perm)
- // to
- // Cast(Transpose(image, perm), dst_type)
- // when sizeof(image.type) < sizeof(dst_type).
- //
- // TODO(jingyue): This optimization can be generalized to a cast followed by
- // a chain of ops that merely reorder elements (e.g. Reshape and
- // DepthToSpace).
- const NodeDef* transpose = node;
- string dontcare;
- string device;
- // This optimization can be dangerous on devices other than CPU and GPU. The
- // transpose might not be implemented for image.type, or might be slower
- // with image.type than with dst_type.
- if (DeviceNameUtils::SplitDeviceName(transpose->device(), &dontcare,
- &device) &&
- (str_util::StrContains(device, DEVICE_CPU) ||
- str_util::StrContains(device, DEVICE_GPU))) {
- const NodeDef* cast = node_map_->GetNode(transpose->input(0));
- if (cast->op() == "Cast") {
- const NodeDef* input = node_map_->GetNode(cast->input(0));
- const DataType src_type = GetSourceDataType(*cast);
- const DataType dst_type = GetDestinationDataType(*cast);
- if (IsNumberType(src_type) && IsNumberType(dst_type) &&
- DataTypeSize(src_type) < DataTypeSize(dst_type) &&
- !OptimizedNodeExists(*cast, DataTypeString(dst_type)) &&
- !OptimizedNodeExists(*transpose, DataTypeString(src_type))) {
- NodeDef* new_transpose = AddNode(*transpose, DataTypeString(src_type),
- /*copy_node=*/true);
- (*new_transpose->mutable_attr())["T"].set_type(src_type);
- new_transpose->set_input(0, cast->input(0));
- node_map_->AddOutput(input->name(), new_transpose->name());
- node_map_->AddOutput(NodeName(new_transpose->input(1)),
- new_transpose->name());
-
- NodeDef* new_cast =
- AddNode(*cast, DataTypeString(dst_type), /*copy_node=*/true);
- new_cast->set_input(0, new_transpose->name());
- node_map_->AddOutput(new_transpose->name(), new_cast->name());
-
- nodes_to_simplify->PushBack(new_transpose);
- ForwardControlDependencies(new_transpose, {cast, node});
- return new_cast->name();
- }
- }
- }
- }
-
// Fold a multiply of a scalar into the following convolution. This folding
// can jump across nodes that merely reorders data (such as reshape and
// transpose). For example, we can optimize
@@ -2483,10 +2492,14 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
if (options_.remove_redundant_cast)
pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext);
+ if (options_.remove_redundant_reshape)
+ pipeline.AddStage<RemoveRedundantReshape>(ctx, ctx_ext);
if (options_.remove_negation)
pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
if (options_.remove_logical_not)
pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext);
+ if (options_.reorder_cast_and_transpose)
+ pipeline.AddStage<ReorderCastAndTranspose>(ctx, ctx_ext);
if (options_.hoist_cwise_unary_chains)
pipeline.AddStage<HoistCWiseUnaryChainsStage>(ctx, ctx_ext);
if (options_.convert_sqrt_div_to_rsqrt_mul)
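
A back-of-envelope check of the DataTypeSize(src) < DataTypeSize(dst) heuristic behind ReorderCastAndTranspose, using the uint8-to-float NHWC batch from the tests below; the type sizes are standard, the program itself is illustrative:

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  const std::size_t elems = 8 * 28 * 28 * 3;  // 8 uint8 NHWC images.
  // Transposing after the Cast moves float data; before it, uint8 data.
  const std::size_t bytes_after_cast = elems * sizeof(float);
  const std::size_t bytes_before_cast = elems * sizeof(uint8_t);
  std::printf("transpose moves %zu bytes instead of %zu\n",
              bytes_before_cast, bytes_after_cast);
  return 0;
}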
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 962399119d..0fce23a40a 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -71,6 +71,8 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool remove_negation = true;
bool remove_redundant_bitcast = true;
bool remove_redundant_cast = true;
+ bool remove_redundant_reshape = true;
+ bool reorder_cast_and_transpose = true;
// Choose which arithmetic optimizer stages will be enabled for a given
// optimization level by default.
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index a908416e45..02f76df025 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -97,12 +97,22 @@ class ArithmeticOptimizerTest : public GrapplerTest {
}
// Run ArithmeticOptimizer twice to make sure the rewrite is idempotent.
+ // Optionally run a constant folding pass before pruning.
void OptimizeTwiceAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item,
- GraphDef* output) {
+ GraphDef* output, bool const_folding = false) {
TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
+
item->graph.Swap(output);
output->Clear();
TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
+
+ if (const_folding) {
+ item->graph.Swap(output);
+ output->Clear();
+ TF_EXPECT_OK(ConstantFolding(/*cpu_device=*/nullptr)
+ .Optimize(nullptr, *item, output));
+ }
+
item->graph.Swap(output);
output->Clear();
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
@@ -124,8 +134,10 @@ class ArithmeticOptimizerTest : public GrapplerTest {
options.remove_idempotent = false;
options.remove_redundant_bitcast = false;
options.remove_redundant_cast = false;
+ options.remove_redundant_reshape = false;
options.remove_negation = false;
options.remove_logical_not = false;
+ options.reorder_cast_and_transpose = false;
optimizer->options_ = options;
}
@@ -168,11 +180,21 @@ class ArithmeticOptimizerTest : public GrapplerTest {
optimizer->options_.remove_redundant_cast = true;
}
+ void EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.remove_redundant_reshape = true;
+ }
+
void EnableOnlyRemoveNegation(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_negation = true;
}
+ void EnableOnlyReorderCastAndTranspose(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.reorder_cast_and_transpose = true;
+ }
+
void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.hoist_cwise_unary_chains = true;
@@ -955,7 +977,7 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
test::ExpectTensorNear<complex64>(tensors_expected[0], tensors[0], 1e-6);
}
-TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
+TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_IdentityReshape) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =
ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28}));
@@ -977,11 +999,11 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
auto tensors_expected =
EvaluateNodes(item.graph, item.fetch, {{"Placeholder", x_t}});
EXPECT_EQ(1, tensors_expected.size());
- GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveRedundantReshape(&optimizer);
+ OptimizeTwiceAndPrune(&optimizer, &item, &output);
EXPECT_EQ(0, CountOpNodes(output, "Reshape"));
auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", x_t}});
@@ -989,7 +1011,49 @@ TEST_F(ArithmeticOptimizerTest, IdentityReshape) {
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
-TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) {
+TEST_F(ArithmeticOptimizerTest,
+ RemoveRedundantReshape_IdentityReshapeBetweenSymbolicShapes) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output inputs =
+ ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, -1, -1}));
+ Output inputs_shape = ops::Shape(s, inputs);
+ // The target shape of the reshape is the concatenation of `batch_size`, 3,
+  // `height`, and `width`.
+ Output batch_size = ops::Slice(s, inputs_shape, ops::Const(s, {0}, {1}),
+ ops::Const(s, {1}, {1}));
+ Output height = ops::Slice(s, inputs_shape, ops::Const(s, {2}, {1}),
+ ops::Const(s, {1}, {1}));
+ Output width = ops::Slice(s, inputs_shape, ops::Const(s, {3}, {1}),
+ ops::Const(s, {1}, {1}));
+ Output target_shape =
+ ops::Concat(s.WithOpName("target_shape"),
+ {batch_size, ops::Const(s, {3}, {1}), height, width},
+ ops::Const(s, {0}, {}));
+ Output reshape = ops::Reshape(s, inputs, target_shape);
+ Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);
+
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 3, 28, 28}));
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ item.feed = {{"Placeholder", x_t}};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors_expected.size());
+
+ GraphDef output;
+ // Assume valid feed shape in aggressive mode.
+ ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ EnableOnlyRemoveRedundantReshape(&optimizer);
+ OptimizeTwiceAndPrune(&optimizer, &item, &output);
+
+ EXPECT_EQ(0, CountOpNodes(output, "Reshape"));
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+}
+
+TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_NotAssumeValidFeeds) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =
ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28}));
@@ -1007,10 +1071,9 @@ TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) {
EXPECT_EQ(1, tensors_expected.size());
GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
-
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveRedundantReshape(&optimizer);
+ OptimizeTwiceAndPrune(&optimizer, &item, &output);
// The reshape is preserved because the shape of the placeholder can be
// different from the shape of the actual feed.
@@ -1021,7 +1084,8 @@ TEST_F(ArithmeticOptimizerTest, NotAssumeValidFeeds) {
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
-TEST_F(ArithmeticOptimizerTest, AssumeValidFeedsInAggressiveMode) {
+TEST_F(ArithmeticOptimizerTest,
+ RemoveRedundantReshape_AssumeValidFeedsInAggressiveMode) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =
ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3, 28, 28}));
@@ -1037,12 +1101,11 @@ TEST_F(ArithmeticOptimizerTest, AssumeValidFeedsInAggressiveMode) {
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
EXPECT_EQ(1, tensors_expected.size());
- GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer(RewriterConfig::AGGRESSIVE)
- .Optimize(nullptr, item, &output));
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ GraphDef output;
+ ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+ EnableOnlyRemoveRedundantReshape(&optimizer);
+ OptimizeTwiceAndPrune(&optimizer, &item, &output);
EXPECT_EQ(0, CountOpNodes(output, "Reshape"));
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
@@ -1050,7 +1113,7 @@ TEST_F(ArithmeticOptimizerTest, AssumeValidFeedsInAggressiveMode) {
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
-TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
+TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_NotIdentityReshape) {
// Reshape from [-1,3,28,28] to [8,-1,28,28] is not identity, because it can
// be from [4,3,28,28] to [8,6,28,28].
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
@@ -1066,11 +1129,11 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
item.feed = {{"Placeholder", x_t}};
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
EXPECT_EQ(1, tensors_expected.size());
- GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveRedundantReshape(&optimizer);
+ OptimizeTwiceAndPrune(&optimizer, &item, &output);
EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
@@ -1078,7 +1141,8 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
-TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) {
+TEST_F(ArithmeticOptimizerTest,
+ RemoveRedundantReshape_NotIdentityReshapeTooManyUnknownDimSizes) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =
ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3}));
@@ -1088,16 +1152,16 @@ TEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) {
GrapplerItem item;
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveRedundantReshape(&optimizer);
+ OptimizeTwiceAndPrune(&optimizer, &item, &output);
EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
}
-TEST_F(ArithmeticOptimizerTest, CombineReshapes) {
+TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshape_CombineReshapes) {
// Converts an NCHW_VECT_C tensor to NHWC and then flattens it to 2D. The two
// reshapes should be combined.
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
@@ -1122,11 +1186,11 @@ TEST_F(ArithmeticOptimizerTest, CombineReshapes) {
item.feed = {{"nchw_vect_c", x_t}};
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
EXPECT_EQ(1, tensors_expected.size());
- GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveRedundantReshape(&optimizer);
+ OptimizeTwiceAndPrune(&optimizer, &item, &output);
EXPECT_EQ(1, CountOpNodes(output, "Reshape"));
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
@@ -1492,6 +1556,7 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
// =>
// Conv2D(Cast(Transpose(I)), W*S)
tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/gpu:0");
+
Output inputs =
ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
Output cast = ops::Cast(s, inputs, DT_FLOAT);
@@ -1509,28 +1574,28 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
-
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
+ ArithmeticOptimizer optimizer;
+ OptimizeTwiceAndPrune(&optimizer, &item, &output, /*const_folding=*/true);
- item.graph.Swap(&output);
- TF_EXPECT_OK(
- ConstantFolding(/*cpu_device=*/nullptr).Optimize(nullptr, item, &output));
+ NodeMap node_map(&output);
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ // Expected names for the optimized nodes.
+ const string p = "ArithmeticOptimizer/ReorderCastAndTranspose_";
+ const string optimized_cast_name = strings::StrCat(p, "float_Cast");
+ const string optimized_transpose_name = strings::StrCat(p, "uint8_Transpose");
- NodeMap node_map(&output);
- const NodeDef* inputs_node = CHECK_NOTNULL(node_map.GetNode("Placeholder"));
- const NodeDef* transpose_node =
- CHECK_NOTNULL(node_map.GetNode(OptimizedName("Transpose_uint8")));
- const NodeDef* cast_node =
- CHECK_NOTNULL(node_map.GetNode(OptimizedName("Cast_float")));
+ const NodeDef* inputs_node = node_map.GetNode("Placeholder");
+ const NodeDef* transpose_node = node_map.GetNode(optimized_transpose_name);
+ const NodeDef* cast_node = node_map.GetNode(optimized_cast_name);
const NodeDef* weights_node =
- CHECK_NOTNULL(node_map.GetNode(OptimizedName("weights_scaled_Conv2D")));
- const NodeDef* conv_node = CHECK_NOTNULL(node_map.GetNode("Conv2D"));
+ node_map.GetNode(OptimizedName("weights_scaled_Conv2D"));
+ const NodeDef* conv_node = node_map.GetNode("Conv2D");
+
+ ASSERT_TRUE(inputs_node != nullptr);
+ ASSERT_TRUE(transpose_node != nullptr);
+ ASSERT_TRUE(cast_node != nullptr);
+ ASSERT_TRUE(weights_node != nullptr);
+ ASSERT_TRUE(conv_node != nullptr);
EXPECT_EQ(output.node_size(), 7);
EXPECT_EQ(transpose_node->input(0), inputs_node->name());
diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD
index de05c647d6..e292ff200a 100644
--- a/tensorflow/core/kernels/batching_util/BUILD
+++ b/tensorflow/core/kernels/batching_util/BUILD
@@ -127,6 +127,27 @@ tf_cc_test(
)
cc_library(
+ name = "serial_device_batch_scheduler",
+ hdrs = ["serial_device_batch_scheduler.h"],
+ deps = [
+ ":batch_scheduler",
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cc_test(
+ name = "serial_device_batch_scheduler_test",
+ srcs = ["serial_device_batch_scheduler_test.cc"],
+ deps = [
+ ":fake_clock_env",
+ ":serial_device_batch_scheduler",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
name = "basic_batch_scheduler",
hdrs = ["basic_batch_scheduler.h"],
deps = [
diff --git a/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h b/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h
new file mode 100644
index 0000000000..518f2ff8a9
--- /dev/null
+++ b/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h
@@ -0,0 +1,548 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SERIAL_DEVICE_BATCH_SCHEDULER_H_
+#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SERIAL_DEVICE_BATCH_SCHEDULER_H_
+
+#include <algorithm>
+#include <functional>
+#include <memory>
+#include <random>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace serving {
+namespace internal {
+template <typename TaskType>
+class SDBSBatch;
+
+template <typename TaskType>
+class SDBSQueue;
+} // namespace internal
+
+// EXPERIMENTAL: API MAY BE SUBJECT TO SUDDEN CHANGES.
+//
+// Shared batch scheduler designed for batches which are processed by a serial
+// device (e.g. GPU, TPU). When batch processing involves a mix of
+// parallelizable cpu work and non-parallelizable on-device work, overall
+// latency can be minimized by producing batches at a (load dependent) rate
+// which keeps the serial device uniformly busy.
+//
+// SerialDeviceBatchScheduler (SDBS) controls the batching rate by limiting the
+// allowed number of concurrently processed batches. Too large a limit causes
+// batches to pile up behind the serial device, adding to the overall batch
+// latency. Too small a limit underutilizes the serial device and harms latency
+// by forcing batches to wait longer to be processed. Feedback from the device
+// (i.e. avg number of batches directly pending on the device) is used to set
+// the correct limit.
+//
+// SDBS groups requests into per-model batches, which are processed when a batch
+// processing thread becomes available. SDBS prioritizes batches primarily by
+// age (i.e. the batch's oldest request) along with a configurable preference
+// for scheduling larger batches first.
+
+template <typename TaskType>
+class SerialDeviceBatchScheduler : public std::enable_shared_from_this<
+ SerialDeviceBatchScheduler<TaskType>> {
+ public:
+ ~SerialDeviceBatchScheduler();
+
+ struct Options {
+ // The name to use for the pool of batch threads.
+ string thread_pool_name = {"batch_threads"};
+ // Maximum number of batch processing threads.
+ int64 num_batch_threads = port::NumSchedulableCPUs();
+ // Although batch selection is primarily based on age, this parameter
+ // specifies a preference for larger batches. A full batch will be
+ // scheduled before an older, nearly empty batch as long as the age gap is
+ // less than full_batch_scheduling_boost_micros. The optimal value for this
+    // parameter should be on the order of the batch processing latency, but
+    // must be chosen carefully, as too large a value will harm tail latency.
+ int64 full_batch_scheduling_boost_micros = 0;
+ // The environment to use (typically only overridden by test code).
+ Env* env = Env::Default();
+ // Initial limit for number of batches being concurrently processed.
+ int64 initial_in_flight_batches_limit = 3;
+ // Returns the current number of batches directly waiting to be processed
+ // by the serial device (i.e. GPU, TPU).
+ std::function<int64()> get_pending_on_serial_device;
+    // Desired average number of batches directly waiting to be processed by
+    // the serial device. Small values on the order of 1 should deliver the
+    // best latency.
+ double target_pending = 2;
+ // Number of batches between potential adjustments of
+ // in_flight_batches_limit. Larger numbers will reduce noise, but will be
+ // less responsive to sudden changes in workload.
+ int64 batches_to_average_over = 1000;
+ };
+
+ // Ownership is shared between the caller of Create() and any queues created
+ // via AddQueue().
+ static Status Create(
+ const Options& options,
+ std::shared_ptr<SerialDeviceBatchScheduler<TaskType>>* scheduler);
+
+ struct QueueOptions {
+ // Maximum size of each batch.
+ int max_batch_size = 1000;
+ // Maximum number of enqueued (i.e. non-scheduled) batches.
+ int max_enqueued_batches = 10;
+ };
+
+ using BatchProcessor = std::function<void(std::unique_ptr<Batch<TaskType>>)>;
+
+ // Adds queue (and its callback) to be managed by this scheduler.
+ Status AddQueue(const QueueOptions& options,
+ BatchProcessor process_batch_callback,
+ std::unique_ptr<BatchScheduler<TaskType>>* queue);
+
+ double in_flight_batches_limit() {
+ mutex_lock l(mu_);
+ return in_flight_batches_limit_;
+ }
+
+ double recent_low_traffic_ratio() {
+ mutex_lock l(mu_);
+ return recent_low_traffic_ratio_;
+ }
+
+ private:
+  // Grants access to AddBatch(), RemoveQueue(), and env().
+ friend class internal::SDBSQueue<TaskType>;
+
+ explicit SerialDeviceBatchScheduler(const Options& options);
+
+ // Continuously retrieves and processes batches.
+ void ProcessBatches();
+
+ // Notifies scheduler of non-empty batch which is eligible for processing.
+ void AddBatch(const internal::SDBSBatch<TaskType>* batch);
+
+ // Removes queue from scheduler.
+ void RemoveQueue(const internal::SDBSQueue<TaskType>* queue);
+
+ Env* env() const { return options_.env; }
+
+ const Options options_;
+
+ // Collection of batches added by AddBatch. Owned by scheduler until they are
+ // released for processing.
+ std::vector<const internal::SDBSBatch<TaskType>*> batches_ GUARDED_BY(mu_);
+
+ // Unowned queues and callbacks added by AddQueue.
+ std::unordered_map<const internal::SDBSQueue<TaskType>*, BatchProcessor>
+ queues_and_callbacks_ GUARDED_BY(mu_);
+
+ // Responsible for running the batch processing callbacks.
+ std::unique_ptr<thread::ThreadPool> batch_thread_pool_;
+
+ // Limit on number of batches which can be concurrently processed.
+ int64 in_flight_batches_limit_ GUARDED_BY(mu_);
+
+ // Number of batch processing threads.
+ int64 processing_threads_ GUARDED_BY(mu_) = 0;
+
+ // Number of batches processed since the last in_flight_batches_limit_
+ // adjustment.
+ int64 batch_count_ GUARDED_BY(mu_) = 0;
+
+ // Number of times since the last in_flight_batches_limit_ adjustment when a
+ // processing thread was available but there were no batches to process.
+ int64 no_batch_count_ GUARDED_BY(mu_) = 0;
+
+ // Sum of batches pending on the serial device since the last
+ // in_flight_batches_limit_ adjustment.
+ int64 pending_sum_ = 0;
+
+ // Sum of batch latencies since the last in_flight_batches_limit_ adjustment.
+ int64 batch_latency_sum_ = 0;
+
+  // Average period between the start times of two consecutive batches.
+ int64 batch_period_micros_ = 0;
+
+ // Moving average tracking the fraction of recent in_flight_batches_limit_
+ // adjustments where the external traffic was not high enough to provide
+ // useful feedback for an adjustment.
+ double recent_low_traffic_ratio_ = 0;
+
+ mutex mu_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SerialDeviceBatchScheduler);
+};
+
+//////////////////////////////////////////////////////////
+// Implementation details follow. API users need not read.
+
+namespace internal {
+// Consolidates tasks into batches, passing them off to the
+// SerialDeviceBatchScheduler for processing.
+template <typename TaskType>
+class SDBSQueue : public BatchScheduler<TaskType> {
+ public:
+ using QueueOptions =
+ typename SerialDeviceBatchScheduler<TaskType>::QueueOptions;
+
+ SDBSQueue(std::shared_ptr<SerialDeviceBatchScheduler<TaskType>> scheduler,
+ const QueueOptions& options);
+
+ ~SDBSQueue() override;
+
+ // Adds task to current batch. Fails if the task size is larger than the batch
+ // size or if the current batch is full and this queue's number of outstanding
+ // batches is at its maximum.
+ Status Schedule(std::unique_ptr<TaskType>* task) override;
+
+ // Number of tasks waiting to be scheduled.
+ size_t NumEnqueuedTasks() const override;
+
+ // Number of size 1 tasks which could currently be scheduled without failing.
+ size_t SchedulingCapacity() const override;
+
+ // Notifies queue that a batch is about to be scheduled; the queue should not
+ // place any more tasks in this batch.
+ void ReleaseBatch(const SDBSBatch<TaskType>* batch);
+
+ size_t max_task_size() const override { return options_.max_batch_size; }
+
+ private:
+ std::shared_ptr<SerialDeviceBatchScheduler<TaskType>> scheduler_;
+ const QueueOptions options_;
+ // Owned by scheduler_.
+ SDBSBatch<TaskType>* current_batch_ GUARDED_BY(mu_) = nullptr;
+ int64 num_enqueued_batches_ GUARDED_BY(mu_) = 0;
+ int64 num_enqueued_tasks_ GUARDED_BY(mu_) = 0;
+ mutable mutex mu_;
+ TF_DISALLOW_COPY_AND_ASSIGN(SDBSQueue);
+};
+
+// Batch which remembers when and by whom it was created.
+template <typename TaskType>
+class SDBSBatch : public Batch<TaskType> {
+ public:
+ SDBSBatch(SDBSQueue<TaskType>* queue, int64 creation_time_micros)
+ : queue_(queue), creation_time_micros_(creation_time_micros) {}
+
+ ~SDBSBatch() override {}
+
+ SDBSQueue<TaskType>* queue() const { return queue_; }
+
+ int64 creation_time_micros() const { return creation_time_micros_; }
+
+ private:
+ SDBSQueue<TaskType>* queue_;
+ const int64 creation_time_micros_;
+ TF_DISALLOW_COPY_AND_ASSIGN(SDBSBatch);
+};
+} // namespace internal
+
+// ---------------- SerialDeviceBatchScheduler ----------------
+
+template <typename TaskType>
+Status SerialDeviceBatchScheduler<TaskType>::Create(
+ const Options& options,
+ std::shared_ptr<SerialDeviceBatchScheduler<TaskType>>* scheduler) {
+ if (options.num_batch_threads < 1) {
+ return errors::InvalidArgument("num_batch_threads must be positive; was ",
+ options.num_batch_threads);
+ }
+ if (options.initial_in_flight_batches_limit < 1) {
+ return errors::InvalidArgument(
+ "initial_in_flight_batches_limit must be positive; was ",
+ options.initial_in_flight_batches_limit);
+ }
+ if (options.initial_in_flight_batches_limit > options.num_batch_threads) {
+ return errors::InvalidArgument(
+ "initial_in_flight_batches_limit (",
+ options.initial_in_flight_batches_limit,
+ ") should not be larger than num_batch_threads (",
+ options.num_batch_threads, ")");
+ }
+ if (options.full_batch_scheduling_boost_micros < 0) {
+ return errors::InvalidArgument(
+ "full_batch_scheduling_boost_micros can't be negative; was ",
+ options.full_batch_scheduling_boost_micros);
+ }
+ if (options.batches_to_average_over < 1) {
+ return errors::InvalidArgument(
+ "batches_to_average_over should be "
+ "greater than or equal to 1; was ",
+ options.batches_to_average_over);
+ }
+ if (options.target_pending <= 0) {
+ return errors::InvalidArgument(
+ "target_pending should be larger than zero; was ",
+ options.target_pending);
+ }
+ if (!options.get_pending_on_serial_device) {
+ return errors::InvalidArgument(
+ "get_pending_on_serial_device must be "
+ "specified");
+ }
+ scheduler->reset(new SerialDeviceBatchScheduler<TaskType>(options));
+ return Status::OK();
+}
+
+template <typename TaskType>
+SerialDeviceBatchScheduler<TaskType>::SerialDeviceBatchScheduler(
+ const Options& options)
+ : options_(options),
+ in_flight_batches_limit_(options.initial_in_flight_batches_limit),
+ processing_threads_(options.initial_in_flight_batches_limit) {
+ batch_thread_pool_.reset(new thread::ThreadPool(
+ env(), options.thread_pool_name, options.num_batch_threads));
+ for (int i = 0; i < processing_threads_; i++) {
+ batch_thread_pool_->Schedule(
+ std::bind(&SerialDeviceBatchScheduler<TaskType>::ProcessBatches, this));
+ }
+}
+
+template <typename TaskType>
+SerialDeviceBatchScheduler<TaskType>::~SerialDeviceBatchScheduler() {
+ // Signal processing threads to exit.
+ {
+ mutex_lock l(mu_);
+ processing_threads_ = 0;
+ }
+  // Blocks until all threads finish.
+ batch_thread_pool_.reset();
+}
+
+template <typename TaskType>
+Status SerialDeviceBatchScheduler<TaskType>::AddQueue(
+ const QueueOptions& options, BatchProcessor process_batch_callback,
+ std::unique_ptr<BatchScheduler<TaskType>>* queue) {
+ if (options.max_batch_size <= 0) {
+ return errors::InvalidArgument("max_batch_size must be positive; was ",
+ options.max_batch_size);
+ }
+ if (options.max_enqueued_batches <= 0) {
+ return errors::InvalidArgument(
+ "max_enqueued_batches must be positive; was ",
+ options.max_enqueued_batches);
+ }
+ internal::SDBSQueue<TaskType>* SDBS_queue_raw;
+ queue->reset(SDBS_queue_raw = new internal::SDBSQueue<TaskType>(
+ this->shared_from_this(), options));
+ mutex_lock l(mu_);
+ queues_and_callbacks_[SDBS_queue_raw] = process_batch_callback;
+ return Status::OK();
+}
+
+template <typename TaskType>
+void SerialDeviceBatchScheduler<TaskType>::AddBatch(
+ const internal::SDBSBatch<TaskType>* batch) {
+ mutex_lock l(mu_);
+ batches_.push_back(batch);
+}
+
+template <typename TaskType>
+void SerialDeviceBatchScheduler<TaskType>::RemoveQueue(
+ const internal::SDBSQueue<TaskType>* queue) {
+ mutex_lock l(mu_);
+ queues_and_callbacks_.erase(queue);
+}
+
+template <typename TaskType>
+void SerialDeviceBatchScheduler<TaskType>::ProcessBatches() {
+ const int64 kIdleThreadSleepTimeMicros = 1000;
+ const double kMaxNoBatchRatio = .1;
+ const double kLowTrafficMovingAverageFactor = .1;
+ for (;;) {
+ mu_.lock();
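+    // Exit if the scheduler is shutting down (processing_threads_ was set to
+    // zero) or if the limit was lowered and this thread is now surplus.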
+ if (processing_threads_ < 1 ||
+ processing_threads_ > in_flight_batches_limit_) {
+ processing_threads_--;
+ mu_.unlock();
+ break;
+ }
+ if (batches_.empty()) {
+ no_batch_count_++;
+ int64 sleep_time = batch_period_micros_ ? batch_period_micros_
+ : kIdleThreadSleepTimeMicros;
+ mu_.unlock();
+ env()->SleepForMicroseconds(sleep_time);
+ continue;
+ }
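+    // Pick the batch with the lowest score: its creation time minus the
+    // scheduling boost scaled by its fullness. This favors older batches,
+    // with a configurable tilt toward fuller ones.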
+ auto best_it = batches_.begin();
+ double best_score =
+ (*best_it)->creation_time_micros() -
+ options_.full_batch_scheduling_boost_micros * (*best_it)->size() /
+ static_cast<double>((*best_it)->queue()->max_task_size());
+ for (auto it = batches_.begin() + 1; it != batches_.end(); it++) {
+ const double score =
+ (*it)->creation_time_micros() -
+ options_.full_batch_scheduling_boost_micros * (*it)->size() /
+ static_cast<double>((*it)->queue()->max_task_size());
+ if (score < best_score) {
+ best_score = score;
+ best_it = it;
+ }
+ }
+ const internal::SDBSBatch<TaskType>* batch = *best_it;
+ batches_.erase(best_it);
+ // Queue may destroy itself after ReleaseBatch is called.
+ batch->queue()->ReleaseBatch(batch);
+ auto callback = queues_and_callbacks_[batch->queue()];
+ mu_.unlock();
+ int64 start_time = env()->NowMicros();
+ callback(std::unique_ptr<Batch<TaskType>>(
+ const_cast<internal::SDBSBatch<TaskType>*>(batch)));
+ int64 end_time = env()->NowMicros();
+ mu_.lock();
+ batch_count_++;
+ batch_latency_sum_ += end_time - start_time;
+ pending_sum_ += options_.get_pending_on_serial_device();
+ if (batch_count_ == options_.batches_to_average_over) {
+ recent_low_traffic_ratio_ *= (1 - kLowTrafficMovingAverageFactor);
+ // Only adjust in_flight_batches_limit_ if external load is large enough
+ // to consistently provide batches. Otherwise we would (mistakenly) assume
+ // that the device is underutilized because in_flight_batches_limit_ is
+ // too small.
+ if (no_batch_count_ < kMaxNoBatchRatio * batch_count_) {
+ double avg_pending = pending_sum_ / static_cast<double>(batch_count_);
+        // Average processing time divided by the number of concurrent batches
+        // gives the average period between the start times of two consecutive
+        // batches. Used to set a reasonable sleep time for idle batch
+        // processing threads.
+ batch_period_micros_ =
+ batch_latency_sum_ / batch_count_ / in_flight_batches_limit_;
+ // When the processing pipeline is consistently busy, the average number
+ // of pending batches differs from in_flight_batches_limit_ by a
+        // load-dependent offset. Adjust in_flight_batches_limit_ to maintain
+ // the desired target pending.
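+        // For example, with target_pending = 2 and avg_pending = 3.4 the
+        // limit changes by std::round(2 - 3.4) = -1, i.e. drops by one batch.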
+ in_flight_batches_limit_ +=
+ std::round(options_.target_pending - avg_pending);
+ in_flight_batches_limit_ = std::max(in_flight_batches_limit_, 1LL);
+ in_flight_batches_limit_ =
+ std::min(in_flight_batches_limit_, options_.num_batch_threads);
+ // Add extra processing threads if necessary.
+ if (processing_threads_ > 0 &&
+ processing_threads_ < in_flight_batches_limit_) {
+ int extra_threads = in_flight_batches_limit_ - processing_threads_;
+ for (int i = 0; i < extra_threads; i++) {
+ batch_thread_pool_->Schedule(std::bind(
+ &SerialDeviceBatchScheduler<TaskType>::ProcessBatches, this));
+ }
+ processing_threads_ = in_flight_batches_limit_;
+ }
+ } else {
+ recent_low_traffic_ratio_ += kLowTrafficMovingAverageFactor;
+ }
+ batch_count_ = 0;
+ no_batch_count_ = 0;
+ pending_sum_ = 0;
+ batch_latency_sum_ = 0;
+ }
+ mu_.unlock();
+ }
+}
+
+// ---------------- SDBSQueue ----------------
+
+namespace internal {
+template <typename TaskType>
+SDBSQueue<TaskType>::SDBSQueue(
+ std::shared_ptr<SerialDeviceBatchScheduler<TaskType>> scheduler,
+ const QueueOptions& options)
+ : scheduler_(scheduler), options_(options) {}
+
+template <typename TaskType>
+SDBSQueue<TaskType>::~SDBSQueue() {
+ // Wait until last batch has been scheduled.
+ const int kSleepMicros = 1000;
+ for (;;) {
+ {
+ mutex_lock l(mu_);
+ if (num_enqueued_batches_ == 0) {
+ break;
+ }
+ }
+ scheduler_->env()->SleepForMicroseconds(kSleepMicros);
+ }
+ scheduler_->RemoveQueue(this);
+}
+
+template <typename TaskType>
+Status SDBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
+ SDBSBatch<TaskType>* new_batch = nullptr;
+ size_t size = (*task)->size();
+ if (size > options_.max_batch_size) {
+ return errors::InvalidArgument("Task size ", size,
+ " is larger than maximum batch size ",
+ options_.max_batch_size);
+ }
+ {
+ mutex_lock l(mu_);
+    // The current batch cannot fit this task; close it and create another
+    // batch if allowed.
+ if (current_batch_ &&
+ current_batch_->size() + size > options_.max_batch_size) {
+ if (num_enqueued_batches_ >= options_.max_enqueued_batches) {
+ return errors::Unavailable("The batch scheduling queue is full");
+ }
+ current_batch_->Close();
+ current_batch_ = nullptr;
+ }
+ if (!current_batch_) {
+ num_enqueued_batches_++;
+ current_batch_ = new_batch =
+ new SDBSBatch<TaskType>(this, scheduler_->env()->NowMicros());
+ }
+ current_batch_->AddTask(std::move(*task));
+ num_enqueued_tasks_++;
+ }
+  // AddBatch must be called outside of the lock, since the scheduler may
+  // concurrently call ReleaseBatch, which also acquires this queue's lock.
+ if (new_batch != nullptr) scheduler_->AddBatch(new_batch);
+ return Status::OK();
+}
+
+template <typename TaskType>
+void SDBSQueue<TaskType>::ReleaseBatch(const SDBSBatch<TaskType>* batch) {
+ mutex_lock l(mu_);
+ num_enqueued_batches_--;
+ num_enqueued_tasks_ -= batch->num_tasks();
+ if (batch == current_batch_) {
+ current_batch_->Close();
+ current_batch_ = nullptr;
+ }
+}
+
+template <typename TaskType>
+size_t SDBSQueue<TaskType>::NumEnqueuedTasks() const {
+ mutex_lock l(mu_);
+ return num_enqueued_tasks_;
+}
+
+template <typename TaskType>
+size_t SDBSQueue<TaskType>::SchedulingCapacity() const {
+ mutex_lock l(mu_);
+ const int current_batch_capacity =
+ current_batch_ ? options_.max_batch_size - current_batch_->size() : 0;
+ const int spare_batches =
+ options_.max_enqueued_batches - num_enqueued_batches_;
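+  // E.g. with max_batch_size = 1000, max_enqueued_batches = 10, and a current
+  // batch of size 100, capacity is 9 * 1000 + 900 = 9900 size-1 tasks.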
+ return spare_batches * options_.max_batch_size + current_batch_capacity;
+}
+} // namespace internal
+} // namespace serving
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_SERIAL_DEVICE_BATCH_SCHEDULER_H_
diff --git a/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler_test.cc
new file mode 100644
index 0000000000..a2f8f9a03e
--- /dev/null
+++ b/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler_test.cc
@@ -0,0 +1,394 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h"
+
+#include "tensorflow/core/kernels/batching_util/fake_clock_env.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace serving {
+namespace anonymous {
+
+class FakeTask : public BatchTask {
+ public:
+ explicit FakeTask(size_t size) : size_(size) {}
+
+ ~FakeTask() override = default;
+
+ size_t size() const override { return size_; }
+
+ private:
+ const size_t size_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FakeTask);
+};
+
+// Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()' on
+// that task. Returns the resulting status.
+Status ScheduleTask(size_t task_size, BatchScheduler<FakeTask>* scheduler) {
+ std::unique_ptr<FakeTask> task(new FakeTask(task_size));
+ Status status = scheduler->Schedule(&task);
+ // Schedule() should have consumed 'task' iff it returned Status::OK.
+ CHECK_EQ(status.ok(), task == nullptr);
+ return status;
+}
+
+// Creates a thread that waits on 'start' and then advances the fake clock in
+// 'env' in a loop until 'stop' is notified. Useful for allowing objects that
+// use the clock to be destroyed.
+std::unique_ptr<Thread> CreateFakeClockAdvancerThread(
+ test_util::FakeClockEnv* env, Notification* start, Notification* stop) {
+ return std::unique_ptr<Thread>(Env::Default()->StartThread(
+ {}, "FakeClockAdvancerThread", [env, start, stop] {
+ start->WaitForNotification();
+ while (!stop->HasBeenNotified()) {
+ env->AdvanceByMicroseconds(10);
+ Env::Default()->SleepForMicroseconds(10);
+ }
+ }));
+}
+
+TEST(SerialDeviceBatchSchedulerTest, BadOptions) {
+ using Scheduler = SerialDeviceBatchScheduler<FakeTask>;
+ std::shared_ptr<Scheduler> scheduler;
+ Scheduler::Options default_options;
+ default_options.get_pending_on_serial_device = []() { return 0; };
+ Scheduler::Options options = default_options;
+ options.num_batch_threads = 0;
+ EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+ options = default_options;
+ options.initial_in_flight_batches_limit = 0;
+ EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+ options = default_options;
+ options.num_batch_threads = 5;
+ options.initial_in_flight_batches_limit = 8;
+ EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+ options = default_options;
+ options.batches_to_average_over = -5;
+ EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+ options = default_options;
+ options.target_pending = 0;
+ EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+ options = Scheduler::Options();
+ EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
+}
+
+TEST(SerialDeviceBatchSchedulerTest, InFlightBatchesLimit) {
+ SerialDeviceBatchScheduler<FakeTask>::Options options;
+ options.num_batch_threads = 3;
+ options.initial_in_flight_batches_limit = 2;
+ options.batches_to_average_over = 1000;
+ options.get_pending_on_serial_device = []() { return 0; };
+ mutex mu;
+ int processed_batches = 0;
+ Notification finish_processing;
+ auto queue_callback = [&mu, &processed_batches, &finish_processing](
+ std::unique_ptr<Batch<FakeTask>> batch) {
+ ASSERT_TRUE(batch->IsClosed());
+ EXPECT_GT(batch->num_tasks(), 0);
+ mu.lock();
+ int batch_num = ++processed_batches;
+ mu.unlock();
+ if (batch_num == 2) {
+ // Give third batch a chance to process if it's going to.
+ Env::Default()->SleepForMicroseconds(1000);
+ finish_processing.Notify();
+ }
+ if (batch_num == 3) {
+ ASSERT_TRUE(finish_processing.HasBeenNotified());
+ }
+ finish_processing.WaitForNotification();
+ };
+ std::shared_ptr<SerialDeviceBatchScheduler<FakeTask>> scheduler;
+ TF_ASSERT_OK(
+ SerialDeviceBatchScheduler<FakeTask>::Create(options, &scheduler));
+ std::unique_ptr<BatchScheduler<FakeTask>> queue1;
+ std::unique_ptr<BatchScheduler<FakeTask>> queue2;
+ std::unique_ptr<BatchScheduler<FakeTask>> queue3;
+ TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue1));
+ TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue2));
+ TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue3));
+ // Create 3 batches, only 2 should be processed concurrently.
+ TF_ASSERT_OK(ScheduleTask(100, queue1.get()));
+ TF_ASSERT_OK(ScheduleTask(100, queue2.get()));
+ TF_ASSERT_OK(ScheduleTask(100, queue3.get()));
+}
+
+TEST(SerialDeviceBatchSchedulerTest, PendingOnSerialDevice) {
+ mutex mu;
+ int pending;
+ SerialDeviceBatchScheduler<FakeTask>::Options options;
+ options.num_batch_threads = 3;
+ options.initial_in_flight_batches_limit = 1;
+ options.batches_to_average_over = 1;
+ options.target_pending = 3;
+ options.get_pending_on_serial_device = [&mu, &pending]() {
+ mutex_lock l(mu);
+ return pending;
+ };
+ std::shared_ptr<SerialDeviceBatchScheduler<FakeTask>> scheduler;
+ TF_ASSERT_OK(
+ SerialDeviceBatchScheduler<FakeTask>::Create(options, &scheduler));
+ // Make sure batch processing thread has gone to sleep.
+ Env::Default()->SleepForMicroseconds(1000);
+ int processed_batches = 0;
+ Notification start_processing;
+ auto queue_callback = [&mu, &processed_batches, &start_processing, &pending,
+ &scheduler](std::unique_ptr<Batch<FakeTask>> batch) {
+    // Be careful with mutex mu to avoid potential deadlock with mutex mu_
+    // held in ProcessBatches() and in_flight_batches_limit().
+ int batch_num;
+ {
+ mutex_lock l(mu);
+ batch_num = ++processed_batches;
+ }
+ switch (batch_num) {
+ case 1:
+ start_processing.WaitForNotification();
+ {
+ mutex_lock l(mu);
+ pending = 2;
+ }
+ break;
+ case 2:
+ // No batches initially --> low traffic --> no adjustment.
+ CHECK_EQ(scheduler->in_flight_batches_limit(), 1);
+ {
+ mutex_lock l(mu);
+ pending = 3;
+ }
+ break;
+ case 3:
+ // Pending at target --> no adjustment.
+ CHECK_EQ(scheduler->in_flight_batches_limit(), 1);
+ {
+ mutex_lock l(mu);
+ pending = 1;
+ }
+ break;
+ case 4:
+ // Small pending --> 2 additional threads added.
+ CHECK_EQ(scheduler->in_flight_batches_limit(), 3);
+ {
+ mutex_lock l(mu);
+ pending = 3;
+ }
+ break;
+ default:
+ break;
+ }
+ };
+ std::unique_ptr<BatchScheduler<FakeTask>> queue;
+ TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
+ // Create 4 batches.
+ for (int i = 0; i < 4; i++) {
+ TF_ASSERT_OK(ScheduleTask(800, queue.get()));
+ }
+ start_processing.Notify();
+}
+
+TEST(SerialDeviceBatchSchedulerTest, FullBatchSchedulingBoostMicros) {
+ test_util::FakeClockEnv env(Env::Default());
+ Notification start_teardown, stop_teardown;
+ std::unique_ptr<Thread> teardown_thread =
+ CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
+ {
+ SerialDeviceBatchScheduler<FakeTask>::Options options;
+ options.env = &env;
+ options.initial_in_flight_batches_limit = 1;
+ options.batches_to_average_over = 1000;
+ options.full_batch_scheduling_boost_micros = 10;
+ options.get_pending_on_serial_device = []() { return 0; };
+ mutex mu;
+ int processed_batches = 0;
+ auto queue_callback =
+ [&mu, &processed_batches](std::unique_ptr<Batch<FakeTask>> batch) {
+ ASSERT_TRUE(batch->IsClosed());
+ mutex_lock l(mu);
+ processed_batches++;
+ switch (processed_batches) {
+ case 1:
+ EXPECT_EQ(1000, batch->size());
+ break;
+ case 2:
+ EXPECT_EQ(100, batch->size());
+ break;
+ case 3:
+ EXPECT_EQ(80, batch->size());
+ break;
+ default:
+ EXPECT_TRUE(false) << "Should only have 3 batches";
+ }
+ };
+ std::shared_ptr<SerialDeviceBatchScheduler<FakeTask>> scheduler;
+ TF_ASSERT_OK(
+ SerialDeviceBatchScheduler<FakeTask>::Create(options, &scheduler));
+ // Make sure batch processing thread has gone to sleep.
+ Env::Default()->SleepForMicroseconds(1000);
+ SerialDeviceBatchScheduler<FakeTask>::QueueOptions queue_options;
+ std::unique_ptr<BatchScheduler<FakeTask>> queue1;
+ std::unique_ptr<BatchScheduler<FakeTask>> queue2;
+ std::unique_ptr<BatchScheduler<FakeTask>> queue3;
+ queue_options.max_batch_size = 1000;
+ TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue1));
+ queue_options.max_batch_size = 1000;
+ TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue2));
+ queue_options.max_batch_size = 100;
+ TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue3));
+
+ TF_ASSERT_OK(ScheduleTask(100, queue1.get()));
+ // First batch - creation time: 0, fullness: 0.1, sched score: -1
+ env.AdvanceByMicroseconds(3);
+ TF_ASSERT_OK(ScheduleTask(1000, queue2.get()));
+ // Second batch - creation time: 3, fullness: 1, sched score: -7
+ env.AdvanceByMicroseconds(5);
+ TF_ASSERT_OK(ScheduleTask(80, queue3.get()));
+ // Third batch - creation time: 8, fullness: .8, sched score: 0
+ // Release the batch processing thread.
+ env.AdvanceByMicroseconds(1000);
+ start_teardown.Notify();
+ }
+ stop_teardown.Notify();
+}
+
+TEST(SerialDeviceBatchSchedulerTest, DeleteQueue) {
+ SerialDeviceBatchScheduler<FakeTask>::Options options;
+ options.initial_in_flight_batches_limit = 1;
+ options.batches_to_average_over = 1000;
+ options.get_pending_on_serial_device = []() { return 0; };
+ mutex mu;
+ int processed_batches = 0;
+ Notification finish_processing;
+ auto queue_callback = [&mu, &processed_batches, &finish_processing](
+ std::unique_ptr<Batch<FakeTask>> batch) {
+ ASSERT_TRUE(batch->IsClosed());
+ EXPECT_GT(batch->num_tasks(), 0);
+ finish_processing.WaitForNotification();
+ mu.lock();
+ processed_batches++;
+ mu.unlock();
+ };
+ std::shared_ptr<SerialDeviceBatchScheduler<FakeTask>> scheduler;
+ TF_ASSERT_OK(
+ SerialDeviceBatchScheduler<FakeTask>::Create(options, &scheduler));
+ std::unique_ptr<BatchScheduler<FakeTask>> queue;
+ TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
+
+ // Enqueue 2 tasks, should result in 2 batches.
+ for (int i = 0; i < 2; i++) {
+ TF_ASSERT_OK(ScheduleTask(800, queue.get()));
+ }
+ std::unique_ptr<Thread> queue_deleter(Env::Default()->StartThread(
+ {}, "QueueDeleterThread", [&queue, &mu, &processed_batches] {
+ // Delete queue, should be kept alive until empty.
+ queue.reset();
+ mutex_lock l(mu);
+ EXPECT_EQ(processed_batches, 2);
+ }));
+ // Give queue_deleter thread time to delete queue.
+ Env::Default()->SleepForMicroseconds(1000);
+ finish_processing.Notify();
+}
+
+TEST(SerialDeviceBatchSchedulerTest, DeleteScheduler) {
+ SerialDeviceBatchScheduler<FakeTask>::Options options;
+ options.initial_in_flight_batches_limit = 1;
+ options.batches_to_average_over = 1000;
+ options.get_pending_on_serial_device = []() { return 0; };
+ mutex mu;
+ int processed_batches = 0;
+ Notification start_processing;
+ Notification finish_processing;
+ auto queue_callback =
+ [&mu, &processed_batches, &start_processing,
+ &finish_processing](std::unique_ptr<Batch<FakeTask>> batch) {
+ ASSERT_TRUE(batch->IsClosed());
+ EXPECT_GT(batch->num_tasks(), 0);
+ start_processing.WaitForNotification();
+ mutex_lock l(mu);
+ processed_batches++;
+ if (processed_batches == 2) {
+ finish_processing.Notify();
+ }
+ };
+
+ std::shared_ptr<SerialDeviceBatchScheduler<FakeTask>> scheduler;
+ TF_ASSERT_OK(
+ SerialDeviceBatchScheduler<FakeTask>::Create(options, &scheduler));
+ std::unique_ptr<BatchScheduler<FakeTask>> queue;
+ TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
+
+ // Enqueue 2 tasks, should result in 2 batches.
+ for (int i = 0; i < 2; i++) {
+ TF_ASSERT_OK(ScheduleTask(800, queue.get()));
+ }
+ // Delete scheduler, should be kept alive until queues are empty.
+ scheduler.reset();
+ start_processing.Notify();
+ finish_processing.WaitForNotification();
+}
+
+TEST(SerialDeviceBatchSchedulerTest, QueueCapacityInfo) {
+ SerialDeviceBatchScheduler<FakeTask>::Options options;
+ options.initial_in_flight_batches_limit = 1;
+ options.batches_to_average_over = 1000;
+ options.full_batch_scheduling_boost_micros = 1000;
+ options.get_pending_on_serial_device = []() { return 0; };
+ mutex mu;
+ int processed_batches = 0;
+ Notification finish_processing;
+ auto queue_callback = [&mu, &processed_batches, &finish_processing](
+ std::unique_ptr<Batch<FakeTask>> batch) {
+ ASSERT_TRUE(batch->IsClosed());
+ EXPECT_GT(batch->num_tasks(), 0);
+ mu.lock();
+ int batch_num = ++processed_batches;
+ mu.unlock();
+ if (batch_num == 1) {
+ finish_processing.WaitForNotification();
+ }
+ };
+ std::shared_ptr<SerialDeviceBatchScheduler<FakeTask>> scheduler;
+ TF_ASSERT_OK(
+ SerialDeviceBatchScheduler<FakeTask>::Create(options, &scheduler));
+ std::unique_ptr<BatchScheduler<FakeTask>> queue1;
+ std::unique_ptr<BatchScheduler<FakeTask>> queue2;
+ TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue1));
+ TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue2));
+
+ // Blocker task, should schedule first.
+ TF_ASSERT_OK(ScheduleTask(800, queue1.get()));
+ TF_ASSERT_OK(ScheduleTask(100, queue2.get()));
+
+ EXPECT_EQ(queue2->NumEnqueuedTasks(), 1);
+ EXPECT_EQ(queue2->SchedulingCapacity(), 9 * 1000 + 900);
+ // Enqueue 2 more tasks, should fall in same batch.
+ TF_ASSERT_OK(ScheduleTask(100, queue2.get()));
+ TF_ASSERT_OK(ScheduleTask(200, queue2.get()));
+ EXPECT_EQ(queue2->NumEnqueuedTasks(), 3);
+ EXPECT_EQ(queue2->SchedulingCapacity(), 9 * 1000 + 600);
+ // Enqueue 1 more task, should create new batch.
+ TF_ASSERT_OK(ScheduleTask(700, queue2.get()));
+ EXPECT_EQ(queue2->NumEnqueuedTasks(), 4);
+ EXPECT_EQ(queue2->SchedulingCapacity(), 8 * 1000 + 300);
+ finish_processing.Notify();
+}
+} // namespace anonymous
+} // namespace serving
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc
index 3618c75827..9a83c16f33 100644
--- a/tensorflow/core/kernels/data/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/batch_dataset_op.cc
@@ -61,7 +61,7 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
Iterator::Params{this, strings::StrCat(prefix, "::Batch")}));
@@ -75,7 +75,7 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
return output_shapes_;
}
- string DebugString() override {
+ string DebugString() const override {
return strings::StrCat("BatchDatasetOp(", batch_size_, ")::Dataset");
}
@@ -95,8 +95,11 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc
index 4b4728dab6..3673df6fa3 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -64,7 +64,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
~FileDataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
if (env_->FileExists(strings::StrCat(filename_, ".index")).ok()) {
return std::unique_ptr<IteratorBase>(new FileReaderIterator(
@@ -83,7 +83,9 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
return input_->output_shapes();
}
- string DebugString() override { return "CacheDatasetOp::FileDataset"; }
+ string DebugString() const override {
+ return "CacheDatasetOp::FileDataset";
+ }
private:
static size_t StringPaddingSize(size_t num_tensors) {
@@ -106,12 +108,15 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
explicit FileWriterIterator(const Params& params)
: DatasetIterator<FileDataset>(params),
cur_index_(0),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
writer_(params.dataset->env_, params.dataset->filename_),
lockfile_(strings::StrCat(params.dataset->filename_, ".lockfile")),
lockfile_created_(false),
iteration_completed_(false) {}
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
@@ -268,7 +273,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
~MemoryDataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
mutex_lock l(mu_);
if (cache_) {
@@ -292,7 +297,9 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
return input_->output_shapes();
}
- string DebugString() override { return "CacheDatasetOp::MemoryDataset"; }
+ string DebugString() const override {
+ return "CacheDatasetOp::MemoryDataset";
+ }
private:
// MemoryWriterIterator passes through and appends items from the input
@@ -305,7 +312,6 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
public:
explicit MemoryWriterIterator(const Params& params)
: DatasetIterator<MemoryDataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
cache_(new std::vector<std::vector<Tensor>>) {}
~MemoryWriterIterator() override {
@@ -323,6 +329,10 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
}
}
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
index f11abc62a6..0012a4769d 100644
--- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc
+++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
@@ -61,7 +61,7 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
to_concatenate_->Unref();
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Concatenate")}));
@@ -75,7 +75,9 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
return output_shapes_;
}
- string DebugString() override { return "ConcatenateDatasetOp::Dataset"; }
+ string DebugString() const override {
+ return "ConcatenateDatasetOp::Dataset";
+ }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
@@ -94,10 +96,12 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- i_(0),
- input_impl_(params.dataset->input_->MakeIterator(
- strings::StrCat(params.prefix, "[0]"))) {}
+ : DatasetIterator<Dataset>(params), i_(0) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(
+ ctx, strings::StrCat(prefix(), "[0]"), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
@@ -114,8 +118,8 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
return Status::OK();
}
if (++i_ < 2) {
- input_impl_ = dataset()->to_concatenate_->MakeIterator(
- strings::StrCat(prefix(), "[1]"));
+ TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
+ ctx, strings::StrCat(prefix(), "[1]"), &input_impl_));
}
}
*end_of_sequence = true;
@@ -147,8 +151,8 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
if (!TF_PREDICT_TRUE(i_ >= 0 && i_ <= 2))
return errors::InvalidArgument("i_ must be in range [0, 2].");
if (i_ == 1) {
- input_impl_ = dataset()->to_concatenate_->MakeIterator(
- strings::StrCat(prefix(), "[1]"));
+ TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
+ ctx, strings::StrCat(prefix(), "[1]"), &input_impl_));
} else if (i_ == 2) {
input_impl_.reset();
}
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index c608f9e1c6..d85ef1cbab 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -41,9 +41,8 @@ Status MakeIteratorFromInputElement(
GetDatasetFromVariantTensor(return_values[0], &returned_dataset));
// Create an iterator for the dataset that was returned by `f`.
- *out_iterator = returned_dataset->MakeIterator(
- strings::StrCat(prefix, "[", thread_index, "]"));
- return Status::OK();
+ return returned_dataset->MakeIterator(
+ ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator);
}
} // namespace dataset
diff --git a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
index 132808a5f1..91b9279427 100644
--- a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
@@ -94,7 +94,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::DenseToSparseBatch")}));
@@ -109,7 +109,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
return output_shapes_;
}
- string DebugString() override {
+ string DebugString() const override {
return strings::StrCat("DenseToSparseBatchDatasetOp(", batch_size_,
")::Dataset");
}
@@ -137,8 +137,12 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset<T>> {
public:
explicit Iterator(const typename Iterator::Params& params)
- : DatasetIterator<Dataset<T>>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset<T>>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return DatasetIterator<Dataset<T>>::dataset()->input_->MakeIterator(
+ ctx, DatasetIterator<Dataset<T>>::prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index 186b1e1c6c..6d6c44552d 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -93,7 +93,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
~FilterDatasetBase() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Filter")}));
@@ -106,7 +106,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
return input_->output_shapes();
}
- string DebugString() override { return "FilterDatasetOp::Dataset"; }
+ string DebugString() const override { return "FilterDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
@@ -145,8 +145,11 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<FilterDatasetBase> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<FilterDatasetBase>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<FilterDatasetBase>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
index 77a48a2aa9..baca022f1e 100644
--- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
@@ -74,7 +74,7 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::FlatMap")}));
@@ -88,7 +88,7 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
return output_shapes_;
}
- string DebugString() override { return "FlatMapDatasetOp::Dataset"; }
+ string DebugString() const override { return "FlatMapDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
@@ -125,8 +125,11 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
@@ -202,7 +205,8 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
current_element_iterator_.reset();
captured_func_inputs_.clear();
if (!reader->Contains(full_name("exhausted"))) {
- input_impl_ = dataset()->input_->MakeIterator(prefix());
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
{
int64 temp;
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index 3f1e441b91..aae62ad2fe 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -99,7 +99,7 @@ class GeneratorDatasetOp : public DatasetOpKernel {
output_types_(output_types),
output_shapes_(output_shapes) {}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Generator")}));
@@ -112,7 +112,9 @@ class GeneratorDatasetOp : public DatasetOpKernel {
return output_shapes_;
}
- string DebugString() override { return "GeneratorDatasetOp::Dataset"; }
+ string DebugString() const override {
+ return "GeneratorDatasetOp::Dataset";
+ }
private:
class Iterator : public DatasetIterator<Dataset> {
diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
index c8aeaab9cb..03abae79d2 100644
--- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
@@ -88,7 +88,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::GroupByReducer")}));
@@ -101,7 +101,9 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
return output_shapes_;
}
- string DebugString() override { return "GroupByReducerDatasetOp::Dataset"; }
+ string DebugString() const override {
+ return "GroupByReducerDatasetOp::Dataset";
+ }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
@@ -183,8 +185,11 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
index 03f847ce9c..23d769e1ab 100644
--- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
@@ -118,7 +118,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::GroupByWindow")}));
@@ -131,7 +131,9 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
return output_shapes_;
}
- string DebugString() override { return "GroupByWindowDatasetOp::Dataset"; }
+ string DebugString() const override {
+ return "GroupByWindowDatasetOp::Dataset";
+ }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
@@ -198,8 +200,11 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
@@ -484,8 +489,8 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
GetDatasetFromVariantTensor(return_values[0], &returned_dataset));
// Create an iterator for the dataset that was returned by `f`.
- current_group_iterator_ = returned_dataset->MakeIterator(prefix());
- return Status::OK();
+ return returned_dataset->MakeIterator(ctx, prefix(),
+ &current_group_iterator_);
}
mutex mu_;
diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc
index bce3f28d62..0765e63993 100644
--- a/tensorflow/core/kernels/data/interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc
@@ -96,7 +96,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Interleave")}));
@@ -109,7 +109,9 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
return output_shapes_;
}
- string DebugString() override { return "InterleaveDatasetOp::Dataset"; }
+ string DebugString() const override {
+ return "InterleaveDatasetOp::Dataset";
+ }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
@@ -149,10 +151,13 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
current_elements_(params.dataset->cycle_length_),
args_list_(params.dataset->cycle_length_) {}
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
void AdvanceToNextInCycle() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
block_index_ = 0;
cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
@@ -294,7 +299,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
}
mutex mu_;
- const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
std::vector<std::unique_ptr<IteratorBase>> current_elements_
GUARDED_BY(mu_);
std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 87bc8ebefe..9d9e74adba 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -158,7 +158,10 @@ class IteratorResource : public ResourceBase {
graph_runner.Run(&graph, lib, {}, {output_node}, &outputs));
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
- TF_RETURN_IF_ERROR(set_iterator(dataset->MakeIterator("Iterator")));
+ IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
+ std::unique_ptr<IteratorBase> iterator;
+ TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, "Iterator", &iterator));
+ TF_RETURN_IF_ERROR(set_iterator(std::move(iterator)));
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
if (captured_iterator) {
@@ -657,8 +660,12 @@ class MakeIteratorOp : public OpKernel {
OP_REQUIRES_OK(
ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
core::ScopedUnref unref(iterator_resource);
- OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(
- dataset->MakeIterator("Iterator")));
+
+ IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
+ std::unique_ptr<IteratorBase> iterator;
+ OP_REQUIRES_OK(ctx,
+ dataset->MakeIterator(&iter_ctx, "Iterator", &iterator));
+ OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator)));
}
};
@@ -680,9 +687,12 @@ class ToSingleElementOp : public AsyncOpKernel {
DatasetBase* dataset;
OP_REQUIRES_OK_ASYNC(
ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
- auto iterator = dataset->MakeIterator("SingleElementIterator");
-
IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
+ std::unique_ptr<IteratorBase> iterator;
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ dataset->MakeIterator(&iter_ctx, "SingleElementIterator", &iterator),
+ done);
std::vector<Tensor> components;
components.reserve(dataset->output_dtypes().size());
bool end_of_sequence;
@@ -866,8 +876,10 @@ class OneShotIteratorOp : public AsyncOpKernel {
// factory function.
DatasetBase* dataset;
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
- TF_RETURN_IF_ERROR(
- (*iterator)->set_iterator(dataset->MakeIterator("Iterator")));
+ IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
+ std::unique_ptr<IteratorBase> iter;
+ TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, "Iterator", &iter));
+ TF_RETURN_IF_ERROR((*iterator)->set_iterator(std::move(iter)));
(*iterator)->Ref();
return Status::OK();
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index f41a810b07..703ef194a1 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -125,7 +125,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::MapAndBatch")}));
@@ -139,7 +139,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
return output_shapes_;
}
- string DebugString() override { return "MapAndBatchDatasetOp::Dataset"; }
+ string DebugString() const override {
+ return "MapAndBatchDatasetOp::Dataset";
+ }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
@@ -188,7 +190,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
batch_results_((params.dataset->num_parallel_calls_ +
params.dataset->batch_size_ - 1) /
params.dataset->batch_size_) {
@@ -208,6 +209,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
}
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
@@ -647,7 +652,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
int64 num_calls_ GUARDED_BY(mu_) = 0;
// Counts the total number of calls.
int64 call_counter_ GUARDED_BY(mu_) = 0;
- const std::unique_ptr<IteratorBase> input_impl_;
+ std::unique_ptr<IteratorBase> input_impl_;
// Identifies the next batch to be read by the caller.
int64 input_batch_ GUARDED_BY(mu_) = 0;
// Identifies the next batch to create.
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index 89360d1cd9..aa530aea19 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -73,7 +73,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Map")}));
@@ -86,7 +86,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
return output_shapes_;
}
- string DebugString() override { return "MapDatasetOp::Dataset"; }
+ string DebugString() const override { return "MapDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
@@ -123,8 +123,11 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
@@ -167,7 +170,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
}
private:
- const std::unique_ptr<IteratorBase> input_impl_;
+ std::unique_ptr<IteratorBase> input_impl_;
};
const DatasetBase* const input_;
diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
index e41800a806..d9e43ace39 100644
--- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
@@ -119,7 +119,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::PaddedBatch")}));
@@ -133,7 +133,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
return output_shapes_;
}
- string DebugString() override {
+ string DebugString() const override {
return strings::StrCat("PaddedBatchDatasetOp(", batch_size_,
")::Dataset");
}
@@ -186,8 +186,11 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
@@ -325,7 +328,8 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
if (reader->Contains(full_name("exhausted"))) {
input_impl_.reset();
} else {
- input_impl_ = dataset()->input_->MakeIterator(prefix());
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
}
return Status::OK();
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index fa33867ec1..6292b4536e 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -116,7 +116,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::ParallelInterleave")}));
@@ -129,7 +129,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
return output_shapes_;
}
- string DebugString() override {
+ string DebugString() const override {
return "ParallelInterleaveDatasetOp::Dataset";
}
@@ -236,7 +236,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
workers_(dataset()->num_threads()),
worker_thread_states_(dataset()->num_threads()) {}
@@ -249,6 +248,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
}
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
// It is implemented so that it matches the deterministic interleave
// unless getting the next element would block and we are allowed to be
// sloppy.
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index 7e373f2568..3fa6b0d3a9 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -85,7 +85,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::ParallelMap")}));
@@ -99,7 +99,9 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
return output_shapes_;
}
- string DebugString() override { return "ParallelMapDatasetOp::Dataset"; }
+ string DebugString() const override {
+ return "ParallelMapDatasetOp::Dataset";
+ }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
@@ -150,7 +152,6 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
invocation_results_(params.dataset->num_parallel_calls_) {}
~Iterator() override {
@@ -169,6 +170,10 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
}
}
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
diff --git a/tensorflow/core/kernels/data/prefetch_autotuner_test.cc b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc
index 2f573dfb35..29a8cc50cd 100644
--- a/tensorflow/core/kernels/data/prefetch_autotuner_test.cc
+++ b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc
@@ -33,7 +33,7 @@ TEST(PrefetchAutotuner, Disabled) {
TEST(PrefetchAutotuner, Enabled) {
PrefetchAutotuner t(PrefetchAutotuner::kAutoTune);
EXPECT_EQ(1, t.buffer_limit());
- t.RecordConsumption(0); // Expect buffer limit to increase.
+ t.RecordConsumption(0); // Expect buffer limit to stay the same.
EXPECT_EQ(1, t.buffer_limit());
t.RecordConsumption(1);
EXPECT_EQ(1, t.buffer_limit());
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 536de81fd8..e2b6aa590e 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -55,7 +55,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Prefetch")}));
@@ -68,7 +68,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
return input_->output_shapes();
}
- string DebugString() override { return "PrefetchDatasetOp::Dataset"; }
+ string DebugString() const override { return "PrefetchDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
@@ -87,7 +87,6 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
auto_tuner_(params.dataset->buffer_size_) {}
~Iterator() override {
@@ -106,6 +105,10 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
}
}
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
@@ -327,7 +330,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
// accessing the parent iterator. We keep this separate from `mu_` to
// allow prefetching to run in parallel with GetNext calls.
mutex parent_mu_ ACQUIRED_BEFORE(mu_);
- const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_);
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_);
condition_variable cond_var_;
PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_);
std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/kernels/data/random_dataset_op.cc b/tensorflow/core/kernels/data/random_dataset_op.cc
index 210b9ad1b8..ff166c3be7 100644
--- a/tensorflow/core/kernels/data/random_dataset_op.cc
+++ b/tensorflow/core/kernels/data/random_dataset_op.cc
@@ -54,7 +54,7 @@ class RandomDatasetOp : public DatasetOpKernel {
Dataset(OpKernelContext* ctx, int64 seed, int64 seed2)
: GraphDatasetBase(ctx), seed_(seed), seed2_(seed2) {}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Random")}));
@@ -71,7 +71,7 @@ class RandomDatasetOp : public DatasetOpKernel {
return *shapes;
}
- string DebugString() override {
+ string DebugString() const override {
return strings::StrCat("RandomDatasetOp(", seed_, ", ", seed2_,
")::Dataset");
}
diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc
index b57518e678..0b5c814767 100644
--- a/tensorflow/core/kernels/data/range_dataset_op.cc
+++ b/tensorflow/core/kernels/data/range_dataset_op.cc
@@ -48,7 +48,7 @@ class RangeDatasetOp : public DatasetOpKernel {
Dataset(OpKernelContext* ctx, int64 start, int64 stop, int64 step)
: GraphDatasetBase(ctx), start_(start), stop_(stop), step_(step) {}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Range")}));
@@ -65,7 +65,7 @@ class RangeDatasetOp : public DatasetOpKernel {
return *shapes;
}
- string DebugString() override {
+ string DebugString() const override {
return strings::StrCat("RangeDatasetOp(", start_, ", ", stop_, ", ",
step_, ")::Dataset");
}
diff --git a/tensorflow/core/kernels/data/reader_dataset_ops.cc b/tensorflow/core/kernels/data/reader_dataset_ops.cc
index 34d7d9f914..29654b9bca 100644
--- a/tensorflow/core/kernels/data/reader_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/reader_dataset_ops.cc
@@ -89,7 +89,7 @@ class TextLineDatasetOp : public DatasetOpKernel {
use_compression_(!compression_type.empty()),
options_(options) {}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::TextLine")}));
@@ -106,7 +106,7 @@ class TextLineDatasetOp : public DatasetOpKernel {
return *shapes;
}
- string DebugString() override { return "TextLineDatasetOp::Dataset"; }
+ string DebugString() const override { return "TextLineDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
@@ -323,7 +323,7 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel {
footer_bytes_(footer_bytes),
buffer_size_(buffer_size) {}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::FixedLengthRecord")}));
@@ -340,7 +340,7 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel {
return *shapes;
}
- string DebugString() override {
+ string DebugString() const override {
return "FixedLengthRecordDatasetOp::Dataset";
}
@@ -543,7 +543,7 @@ class TFRecordDatasetOp : public DatasetOpKernel {
}
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::TFRecord")}));
@@ -560,7 +560,7 @@ class TFRecordDatasetOp : public DatasetOpKernel {
return *shapes;
}
- string DebugString() override { return "TFRecordDatasetOp::Dataset"; }
+ string DebugString() const override { return "TFRecordDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc
index d37086541d..6b3f4ed27b 100644
--- a/tensorflow/core/kernels/data/repeat_dataset_op.cc
+++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc
@@ -48,7 +48,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
if (count_ < 0) {
return std::unique_ptr<IteratorBase>(new ForeverIterator(
@@ -69,7 +69,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
return input_->output_shapes();
}
- string DebugString() override { return "RepeatDatasetOp::Dataset"; }
+ string DebugString() const override { return "RepeatDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
@@ -108,9 +108,11 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
class FiniteIterator : public DatasetIterator<Dataset> {
public:
explicit FiniteIterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- i_(0),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params), i_(0) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
@@ -127,7 +129,8 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
++i_;
- input_impl_ = dataset()->input_->MakeIterator(prefix());
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
}
*end_of_sequence = true;
input_impl_.reset();
@@ -178,7 +181,8 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
bool first_call = false;
if (!input_impl_) {
first_call = true;
- input_impl_ = dataset()->input_->MakeIterator(prefix());
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
}
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
@@ -214,7 +218,8 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
if (reader->Contains(full_name("uninitialized"))) {
input_impl_.reset();
} else {
- input_impl_ = dataset()->input_->MakeIterator(prefix());
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
}
return Status::OK();
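
Note: the Repeat hunks also show the pattern used wherever an input iterator is re-created mid-stream (starting a new epoch, restoring from a checkpoint): the bare assignment becomes a TF_RETURN_IF_ERROR-wrapped call so construction failures propagate. A condensed sketch, assuming the enclosing function returns Status as in the hunks:

    // The old form silently dropped any error from iterator construction:
    //   input_impl_ = dataset()->input_->MakeIterator(prefix());
    // The new form returns early from the caller on a non-OK Status:
    TF_RETURN_IF_ERROR(
        dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
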
diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc
index 5dd6ff848e..a3b20016a8 100644
--- a/tensorflow/core/kernels/data/scan_dataset_op.cc
+++ b/tensorflow/core/kernels/data/scan_dataset_op.cc
@@ -90,7 +90,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Scan")}));
@@ -103,7 +103,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
return output_shapes_;
}
- string DebugString() override { return "ScanDatasetOp::Dataset"; }
+ string DebugString() const override { return "ScanDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
@@ -149,9 +149,12 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
state_(params.dataset->initial_state_) {}
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
@@ -250,7 +253,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
private:
mutex mu_;
- const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
std::vector<Tensor> state_ GUARDED_BY(mu_);
};
diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
index 2f6bf83da5..3438199ebd 100644
--- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
@@ -85,7 +85,8 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
bool first_call = false;
if (!input_impl_ && epoch_ == 0) {
first_call = true;
- input_impl_ = dataset()->input_->MakeIterator(prefix());
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
}
while (input_impl_ && num_elements_ < dataset()->buffer_size_) {
if (ctx->env()->NowMicros() >
@@ -114,7 +115,8 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
epoch_++;
int64 n = slices_.back()->end;
slices_.emplace_back(new Slice{n, n});
- input_impl_ = dataset()->input_->MakeIterator(prefix());
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
}
if (!end_of_input_sequence) {
buffer_[slices_.back()->end % dataset()->buffer_size_] =
@@ -211,7 +213,8 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
// Restore the input iterator if it wasn't already exhausted.
if (!reader->Contains(full_name("end_of_input_sequence"))) {
- input_impl_ = dataset()->input_->MakeIterator(prefix());
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
} else {
input_impl_.reset();
@@ -356,12 +359,12 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {
parent_generator_(seed, seed2),
generator_(&parent_generator_) {}
- string DebugString() override {
+ string DebugString() const override {
return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_,
", ", seed2_, ")::ReshufflingDataset");
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
int64 iterator_seed;
int64 iterator_seed2;
@@ -375,6 +378,23 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {
iterator_seed2));
}
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ return errors::Unimplemented(
+ "Checkpointing ShufflingDataset with reshuffle_each_iteration=true "
+ "is not supported.\n"
+ "If you have a ds.shuffle(buffer_size).repeat(count) in your input "
+ "pipeline, replace it with "
+ "ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size, count)).\n"
+ "If you iterate over your dataset once, change shuffle(buffer_size) "
+ "to shuffle(buffer_size, reshuffle_each_iteration=False).\n"
+ "If you are using Dataset.list_files(pattern), change it to "
+ "Dataset.list_files(pattern, shuffle=False) and manually shuffle "
+ "the list of files using shuffle_and_repeat as above or using "
+ "ds.shuffle with reshuffle_each_iteration=False.");
+ }
+
private:
const int64 seed_;
const int64 seed2_;
@@ -394,12 +414,12 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {
seed_(seed),
seed2_(seed) {}
- string DebugString() override {
+ string DebugString() const override {
return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_,
", ", seed2_, ")::FixedSeedDataset");
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new ShuffleDatasetBase::Iterator(
{this, strings::StrCat(prefix, "::Shuffle")}, seed_, seed2_));
@@ -477,12 +497,12 @@ class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase {
seed_(seed),
seed2_(seed2) {}
- string DebugString() override {
+ string DebugString() const override {
return strings::StrCat("ShuffleAndRepeatDatasetOp(", buffer_size_, ", ",
seed_, ", ", seed2_, ", ", count_, ")::Dataset");
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new ShuffleDatasetBase::Iterator(
{this, strings::StrCat(prefix, "::ShuffleAndRepeat")}, seed_,
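
Note: the new AsGraphDefInternal override in the reshuffling variant above makes checkpointing fail loudly instead of serializing iterator state that cannot be restored. The guard reduces to this shape (condensed; the real message, quoted in the hunk, also spells out the shuffle_and_repeat and reshuffle_each_iteration=False workarounds):

    Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
                              Node** output) const override {
      // Fail at save time rather than produce an unrestorable graph.
      return errors::Unimplemented(
          "Checkpointing ShufflingDataset with reshuffle_each_iteration=true "
          "is not supported.");
    }
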
diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc
index d636c37afe..b84afa3e33 100644
--- a/tensorflow/core/kernels/data/skip_dataset_op.cc
+++ b/tensorflow/core/kernels/data/skip_dataset_op.cc
@@ -47,14 +47,11 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
if (count_ < 0) {
return std::unique_ptr<IteratorBase>(
new EmptyIterator({this, strings::StrCat(prefix, "::EmptySkip")}));
- } else if (count_ == 0) {
- // Pass through.
- return input_->MakeIterator(prefix);
} else {
return std::unique_ptr<IteratorBase>(new FiniteIterator(
{this, strings::StrCat(prefix, "::FiniteSkip")}));
@@ -68,7 +65,7 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
return input_->output_shapes();
}
- string DebugString() override { return "SkipDatasetOp::Dataset"; }
+ string DebugString() const override { return "SkipDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
@@ -108,9 +105,11 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
class FiniteIterator : public DatasetIterator<Dataset> {
public:
explicit FiniteIterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- i_(0),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params), i_(0) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
diff --git a/tensorflow/core/kernels/data/slide_dataset_op.cc b/tensorflow/core/kernels/data/slide_dataset_op.cc
index 78c8363f91..48776cbf61 100644
--- a/tensorflow/core/kernels/data/slide_dataset_op.cc
+++ b/tensorflow/core/kernels/data/slide_dataset_op.cc
@@ -33,10 +33,9 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
DatasetBase** output) override {
int64 window_size = 0;
int64 stride = 1;
- OP_REQUIRES_OK(ctx,
- ParseScalarArgument<int64>(ctx, "window_size", &window_size));
- OP_REQUIRES_OK(ctx,
- ParseScalarArgument<int64>(ctx, "stride", &stride));
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size));
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "stride", &stride));
OP_REQUIRES(
ctx, window_size > 0,
errors::InvalidArgument("Window size must be greater than zero."));
@@ -50,8 +49,12 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
private:
class Dataset : public GraphDatasetBase {
public:
- Dataset(OpKernelContext* ctx, int64 window_size, int64 stride, const DatasetBase* input)
- : GraphDatasetBase(ctx), window_size_(window_size), stride_(stride), input_(input) {
+ Dataset(OpKernelContext* ctx, int64 window_size, int64 stride,
+ const DatasetBase* input)
+ : GraphDatasetBase(ctx),
+ window_size_(window_size),
+ stride_(stride),
+ input_(input) {
input_->Ref();
const auto& input_shapes = input_->output_shapes();
@@ -64,7 +67,7 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
Iterator::Params{this, strings::StrCat(prefix, "::Slide")}));
@@ -78,8 +81,9 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
return output_shapes_;
}
- string DebugString() override {
- return strings::StrCat("SlideDatasetOp(", window_size_, ", ", stride_, ")::Dataset");
+ string DebugString() const override {
+ return strings::StrCat("SlideDatasetOp(", window_size_, ", ", stride_,
+ ")::Dataset");
}
protected:
@@ -101,8 +105,11 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
index fcf17ad68b..2604822cc9 100644
--- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
@@ -39,7 +39,7 @@ class Dataset : public GraphDatasetBase {
{-1},
{sparse_tensor.dims() - 1}}) {}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::SparseTensorSlice")}));
@@ -50,7 +50,7 @@ class Dataset : public GraphDatasetBase {
return shapes_;
}
- string DebugString() override {
+ string DebugString() const override {
return "SparseTensorSliceDatasetOp::Dataset";
}
diff --git a/tensorflow/core/kernels/data/sql_dataset_ops.cc b/tensorflow/core/kernels/data/sql_dataset_ops.cc
index 634b3c280f..16652e792c 100644
--- a/tensorflow/core/kernels/data/sql_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/sql_dataset_ops.cc
@@ -88,7 +88,7 @@ class SqlDatasetOp : public DatasetOpKernel {
output_types_(output_types),
output_shapes_(output_shapes) {}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Sql")}));
@@ -102,7 +102,7 @@ class SqlDatasetOp : public DatasetOpKernel {
return output_shapes_;
}
- string DebugString() override { return "SqlDatasetOp::Dataset"; }
+ string DebugString() const override { return "SqlDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
index eb96b8a872..2ff90d7b10 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
@@ -53,7 +53,7 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
stats_aggregator_resource_->Unref();
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::SetStatsAggregator")}));
@@ -66,7 +66,7 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
return input_->output_shapes();
}
- string DebugString() override {
+ string DebugString() const override {
return "SetStatsAggregatorDatasetOp::Dataset";
}
@@ -82,8 +82,11 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc
index 633cd85451..7370a24b38 100644
--- a/tensorflow/core/kernels/data/stats_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc
@@ -56,7 +56,7 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::LatencyStats")}));
@@ -69,7 +69,9 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel {
return input_->output_shapes();
}
- string DebugString() override { return "LatencyStatsDatasetOp::Dataset"; }
+ string DebugString() const override {
+ return "LatencyStatsDatasetOp::Dataset";
+ }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
@@ -86,8 +88,11 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
@@ -150,7 +155,7 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::BytesProducedStats")}));
@@ -163,7 +168,7 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
return input_->output_shapes();
}
- string DebugString() override {
+ string DebugString() const override {
return "BytesProducedStatsDatasetOp::Dataset";
}
@@ -182,8 +187,11 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc
index 3bea46a747..3d29221f3e 100644
--- a/tensorflow/core/kernels/data/take_dataset_op.cc
+++ b/tensorflow/core/kernels/data/take_dataset_op.cc
@@ -47,12 +47,9 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- if (count_ < 0) {
- // Pass through
- return input_->MakeIterator(prefix);
- } else if (count_ == 0) {
+ if (count_ == 0) {
return std::unique_ptr<IteratorBase>(
new EmptyIterator({this, strings::StrCat(prefix, "::EmptyTake")}));
} else {
@@ -69,7 +66,7 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
return input_->output_shapes();
}
- string DebugString() override { return "TakeDatasetOp::Dataset"; }
+ string DebugString() const override { return "TakeDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
@@ -109,9 +106,11 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
class FiniteIterator : public DatasetIterator<Dataset> {
public:
explicit FiniteIterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- i_(0),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params), i_(0) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
@@ -121,7 +120,7 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
*end_of_sequence = true;
return Status::OK();
}
- while (i_ < dataset()->count_) {
+ while (dataset()->count_ < 0 || i_ < dataset()->count_) {
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
if (!*end_of_sequence) {
diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc
index 8c8994b1c3..36fc434d8f 100644
--- a/tensorflow/core/kernels/data/tensor_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc
@@ -53,7 +53,7 @@ class TensorDatasetOp : public DatasetOpKernel {
}
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::FromTensor")}));
@@ -64,7 +64,7 @@ class TensorDatasetOp : public DatasetOpKernel {
return shapes_;
}
- string DebugString() override { return "TensorDatasetOp::Dataset"; }
+ string DebugString() const override { return "TensorDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
diff --git a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
index e271a42b2a..29b4c9053e 100644
--- a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
@@ -81,7 +81,7 @@ class PrependFromQueueAndPaddedBatchDataset : public GraphDatasetBase {
~PrependFromQueueAndPaddedBatchDataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::PrependFromQueueAndPaddedBatch")}));
@@ -94,7 +94,7 @@ class PrependFromQueueAndPaddedBatchDataset : public GraphDatasetBase {
return batched_shapes_with_queue_;
}
- string DebugString() override {
+ string DebugString() const override {
return "PrependFromQueueAndPaddedBatchDatasetOp::Dataset";
}
@@ -152,15 +152,19 @@ class PrependFromQueueAndPaddedBatchDataset : public GraphDatasetBase {
: public DatasetIterator<PrependFromQueueAndPaddedBatchDataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<PrependFromQueueAndPaddedBatchDataset>(params),
- queue_(new TensorQueue(/*input_impl*/
- params.dataset->input_->MakeIterator(
- params.prefix),
- params.dataset->dtypes_,
- params.dataset->shapes_)) {}
+ : DatasetIterator<PrependFromQueueAndPaddedBatchDataset>(params) {}
~Iterator() override { queue_->Unref(); }
+ Status Initialize(IteratorContext* ctx) override {
+ std::unique_ptr<IteratorBase> iterator;
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &iterator));
+ queue_ = new TensorQueue(std::move(iterator), dataset()->dtypes_,
+ dataset()->shapes_);
+ return Status::OK();
+ }
+
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
@@ -372,7 +376,8 @@ class PrependFromQueueAndPaddedBatchDataset : public GraphDatasetBase {
if (reader->Contains(iter->full_name("input_exhausted"))) {
input_impl_.reset();
} else {
- input_impl_ = iter->dataset_input()->MakeIterator(iter->prefix());
+ TF_RETURN_IF_ERROR(iter->dataset_input()->MakeIterator(
+ ctx, iter->prefix(), &input_impl_));
TF_RETURN_IF_ERROR(iter->RestoreParent(ctx, reader, input_impl_));
}
entries_.clear();
@@ -469,7 +474,7 @@ class PrependFromQueueAndPaddedBatchDataset : public GraphDatasetBase {
};
private:
- TensorQueue* const queue_;
+ TensorQueue* queue_;
};
private:
diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
index 95708cc01c..68ce324081 100644
--- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
@@ -70,7 +70,7 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
}
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::TensorSlice")}));
@@ -81,7 +81,9 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
return shapes_;
}
- string DebugString() override { return "TensorSliceDatasetOp::Dataset"; }
+ string DebugString() const override {
+ return "TensorSliceDatasetOp::Dataset";
+ }
protected:
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
diff --git a/tensorflow/core/kernels/data/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/unbatch_dataset_op.cc
index 2b383e5097..2aec9fb090 100644
--- a/tensorflow/core/kernels/data/unbatch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/unbatch_dataset_op.cc
@@ -49,7 +49,7 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel {
}
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Unbatch")}));
@@ -62,7 +62,7 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel {
return shapes_;
}
- string DebugString() override { return "UnbatchDatasetOp::Dataset"; }
+ string DebugString() const override { return "UnbatchDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
@@ -80,9 +80,12 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params),
current_index_(0),
current_batch_size_(0),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
shapes_(params.dataset->output_shapes().size()) {}
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
diff --git a/tensorflow/core/kernels/data/window_dataset.cc b/tensorflow/core/kernels/data/window_dataset.cc
index e24bdea4ac..668b461374 100644
--- a/tensorflow/core/kernels/data/window_dataset.cc
+++ b/tensorflow/core/kernels/data/window_dataset.cc
@@ -26,7 +26,7 @@ class WindowDataset : public DatasetBase {
output_types_(std::move(output_types)),
output_shapes_(std::move(output_shapes)) {}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Window")}));
@@ -38,7 +38,7 @@ class WindowDataset : public DatasetBase {
return output_shapes_;
}
- string DebugString() override { return "WindowDataset"; }
+ string DebugString() const override { return "WindowDataset"; }
private:
class Iterator : public DatasetIterator<WindowDataset> {
diff --git a/tensorflow/core/kernels/data/writer_ops.cc b/tensorflow/core/kernels/data/writer_ops.cc
index 656fee1e85..80d9a5b867 100644
--- a/tensorflow/core/kernels/data/writer_ops.cc
+++ b/tensorflow/core/kernels/data/writer_ops.cc
@@ -70,9 +70,13 @@ class ToTFRecordOp : public AsyncOpKernel {
DatasetBase* dataset;
OP_REQUIRES_OK_ASYNC(
ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
- auto iterator = dataset->MakeIterator("ToTFRecordOpIterator");
-
IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
+ std::unique_ptr<IteratorBase> iterator;
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ dataset->MakeIterator(&iter_ctx, "ToTFRecordOpIterator", &iterator),
+ done);
+
std::vector<Tensor> components;
components.reserve(dataset->output_dtypes().size());
bool end_of_sequence;
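
Note: the writer_ops hunk shows the same migration from the call-site side. The public DatasetBase::MakeIterator entry point now takes an IteratorContext* and an output parameter and returns Status, so ToTFRecordOp builds the context first and checks the result. A minimal sketch of the new call-site contract, assuming a synchronous Status-returning caller (the "MyIterator" prefix is illustrative):

    IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
    std::unique_ptr<IteratorBase> iterator;
    // Check the Status before touching `iterator`; it is only valid on OK.
    TF_RETURN_IF_ERROR(
        dataset->MakeIterator(&iter_ctx, "MyIterator", &iterator));
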
diff --git a/tensorflow/core/kernels/data/zip_dataset_op.cc b/tensorflow/core/kernels/data/zip_dataset_op.cc
index 0f79eac947..00705236f9 100644
--- a/tensorflow/core/kernels/data/zip_dataset_op.cc
+++ b/tensorflow/core/kernels/data/zip_dataset_op.cc
@@ -60,7 +60,7 @@ class ZipDatasetOp : public DatasetOpKernel {
}
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Zip")}));
@@ -74,7 +74,7 @@ class ZipDatasetOp : public DatasetOpKernel {
return output_shapes_;
}
- string DebugString() override { return "ZipDatasetOp::Dataset"; }
+ string DebugString() const override { return "ZipDatasetOp::Dataset"; }
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
@@ -95,13 +95,16 @@ class ZipDatasetOp : public DatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {
- input_impls_.reserve(params.dataset->inputs_.size());
- size_t idx = 0;
- for (const auto& input : params.dataset->inputs_) {
- input_impls_.emplace_back(input->MakeIterator(
- strings::StrCat(params.prefix, "[", idx++, "]")));
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ input_impls_.resize(dataset()->inputs_.size());
+ for (size_t i = 0; i < input_impls_.size(); ++i) {
+ TF_RETURN_IF_ERROR(dataset()->inputs_[i]->MakeIterator(
+ ctx, strings::StrCat(prefix(), "[", i, "]"), &input_impls_[i]));
}
+ return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx,
diff --git a/tensorflow/core/kernels/inplace_ops.cc b/tensorflow/core/kernels/inplace_ops.cc
index ef6ce0546b..8f51cc3819 100644
--- a/tensorflow/core/kernels/inplace_ops.cc
+++ b/tensorflow/core/kernels/inplace_ops.cc
@@ -476,6 +476,7 @@ REGISTER_EMPTY(string, CPU)
REGISTER_EMPTY(int32, CPU)
REGISTER_EMPTY(int64, CPU)
REGISTER_EMPTY(bool, CPU)
+REGISTER_EMPTY(uint8, CPU)
#if GOOGLE_CUDA
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 632bb32063..22ae6121e0 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -64,6 +64,10 @@ constexpr uint64 HTTP_CODE_RESUME_INCOMPLETE = 308;
// The environment variable that overrides the size of the readahead buffer.
// DEPRECATED. Use GCS_BLOCK_SIZE_MB instead.
constexpr char kReadaheadBufferSize[] = "GCS_READAHEAD_BUFFER_SIZE_BYTES";
+// The environment variable that disables the GCS block cache for reads.
+// This is the explicit alternative to setting BLOCK_SIZE or MAX_SIZE to 0, and
+// takes precedence over either of those environment variables.
+constexpr char kReadCacheDisabled[] = "GCS_READ_CACHE_DISABLED";
// The environment variable that overrides the block size for aligned reads from
// GCS. Specified in MB (e.g. "16" = 16 x 1024 x 1024 = 16777216 bytes).
constexpr char kBlockSize[] = "GCS_READ_CACHE_BLOCK_SIZE_MB";
@@ -129,9 +133,6 @@ constexpr char kInitialTokens[] = "GCS_INITIAL_TOKENS";
// TODO: DO NOT use a hardcoded path
Status GetTmpFilename(string* filename) {
- if (!filename) {
- return errors::Internal("'filename' cannot be nullptr.");
- }
#ifndef _WIN32
char buffer[] = "/tmp/gcs_filesystem_XXXXXX";
int fd = mkstemp(buffer);
@@ -158,9 +159,6 @@ Status GetTmpFilename(string* filename) {
/// object is empty.
Status ParseGcsPath(StringPiece fname, bool empty_object_ok, string* bucket,
string* object) {
- if (!bucket || !object) {
- return errors::Internal("bucket and object cannot be null.");
- }
StringPiece scheme, bucketp, objectp;
io::ParseURI(fname, &scheme, &bucketp, &objectp);
if (scheme != "gs") {
@@ -448,9 +446,6 @@ class GcsWritableFile : public WritableFile {
}
Status GetCurrentFileSize(uint64* size) {
- if (size == nullptr) {
- return errors::Internal("'size' cannot be nullptr");
- }
const auto tellp = outfile_.tellp();
if (tellp == static_cast<std::streampos>(-1)) {
return errors::Internal(
@@ -462,9 +457,6 @@ class GcsWritableFile : public WritableFile {
/// Initiates a new resumable upload session.
Status CreateNewUploadSession(string* session_uri) {
- if (session_uri == nullptr) {
- return errors::Internal("'session_uri' cannot be nullptr.");
- }
uint64 file_size;
TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size));
@@ -498,9 +490,6 @@ class GcsWritableFile : public WritableFile {
/// uploaded size in bytes.
Status RequestUploadSessionStatus(const string& session_uri, bool* completed,
uint64* uploaded) {
- if (completed == nullptr || uploaded == nullptr) {
- return errors::Internal("'completed' and 'uploaded' cannot be nullptr.");
- }
uint64 file_size;
TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size));
@@ -638,6 +627,10 @@ GcsFileSystem::GcsFileSystem()
if (GetEnvVar(kMaxStaleness, strings::safe_strtou64, &value)) {
max_staleness = value;
}
+ if (std::getenv(kReadCacheDisabled)) {
+ // Setting either to 0 disables the cache; set both for good measure.
+ block_size = max_bytes = 0;
+ }
file_block_cache_ = MakeFileBlockCache(block_size, max_bytes, max_staleness);
// Apply overrides for the stat cache max age and max entries, if provided.
uint64 stat_cache_max_age = kStatCacheDefaultMaxAge;
@@ -965,11 +958,16 @@ Status GcsFileSystem::FileExists(const string& fname) {
return Status::OK();
}
}
- bool result;
- TF_RETURN_IF_ERROR(ObjectExists(fname, bucket, object, &result));
- if (result) {
- return Status::OK();
+
+ // Check if the object exists.
+ GcsFileStat stat;
+ const Status status = StatForObject(fname, bucket, object, &stat);
+ if (status.code() != errors::Code::NOT_FOUND) {
+ return status;
}
+
+ // Check if the folder exists.
+ bool result;
TF_RETURN_IF_ERROR(FolderExists(fname, &result));
if (result) {
return Status::OK();
@@ -979,14 +977,11 @@ Status GcsFileSystem::FileExists(const string& fname) {
Status GcsFileSystem::ObjectExists(const string& fname, const string& bucket,
const string& object, bool* result) {
- if (!result) {
- return errors::Internal("'result' cannot be nullptr.");
- }
- GcsFileStat not_used_stat;
- const Status status = StatForObject(fname, bucket, object, &not_used_stat);
+ GcsFileStat stat;
+ const Status status = StatForObject(fname, bucket, object, &stat);
switch (status.code()) {
case errors::Code::OK:
- *result = true;
+ *result = !stat.base.is_directory;
return Status::OK();
case errors::Code::NOT_FOUND:
*result = false;
@@ -1040,15 +1035,19 @@ Status GcsFileSystem::UncachedStatForObject(const string& fname,
<< "; mtime_nsec: " << stat->base.mtime_nsec
<< "; updated: " << updated;
- stat->base.is_directory = false;
+ if (str_util::EndsWith(fname, "/")) {
+ // In GCS a path can be both a directory and a file, but this is uncommon for
+ // other file systems. To avoid the ambiguity, if a path ends with "/" in
+ // GCS, we always regard it as a directory mark or a virtual directory.
+ stat->base.is_directory = true;
+ } else {
+ stat->base.is_directory = false;
+ }
return Status::OK();
}
Status GcsFileSystem::StatForObject(const string& fname, const string& bucket,
const string& object, GcsFileStat* stat) {
- if (!stat) {
- return errors::Internal("'stat' cannot be nullptr.");
- }
if (object.empty()) {
return errors::InvalidArgument(strings::Printf(
"'object' must be a non-empty string. (File: %s)", fname.c_str()));
@@ -1059,18 +1058,10 @@ Status GcsFileSystem::StatForObject(const string& fname, const string& bucket,
[this, &bucket, &object](const string& fname, GcsFileStat* stat) {
return UncachedStatForObject(fname, bucket, object, stat);
}));
- if (stat->base.is_directory) {
- return errors::NotFound(fname, " is a directory.");
- } else {
- return Status::OK();
- }
+ return Status::OK();
}
Status GcsFileSystem::BucketExists(const string& bucket, bool* result) {
- if (!result) {
- return errors::Internal("'result' cannot be nullptr.");
- }
-
std::unique_ptr<HttpRequest> request;
TF_RETURN_IF_ERROR(CreateHttpRequest(&request));
request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket));
@@ -1089,9 +1080,6 @@ Status GcsFileSystem::BucketExists(const string& bucket, bool* result) {
}
Status GcsFileSystem::FolderExists(const string& dirname, bool* result) {
- if (!result) {
- return errors::Internal("'result' cannot be nullptr.");
- }
StatCache::ComputeFunc compute_func = [this](const string& dirname,
GcsFileStat* stat) {
std::vector<string> children;
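
Note: the new GCS_READ_CACHE_DISABLED variable gives an explicit off switch for the read block cache; as the hunk's comment says, it zeroes both the block size and the max cache bytes and takes precedence over the size variables. A hedged sketch of setting it from C++ for a test (in practice it would normally be exported in the shell before the process starts):

    #include <cstdlib>
    // Must be set before GcsFileSystem() reads its configuration.
    setenv("GCS_READ_CACHE_DISABLED", "1", /*overwrite=*/1);
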
diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
index 6a28d9162f..e791ae5a19 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
@@ -1137,6 +1137,28 @@ TEST(GcsFileSystemTest, FileExists_StatCache) {
}
}
+TEST(GcsFileSystemTest, FileExists_DirectoryMark) {
+ std::vector<HttpRequest*> requests({new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
+ "dir%2F?fields=size%2Cgeneration%2Cupdated\n"
+ "Auth Token: fake_token\n"
+ "Timeouts: 5 1 10\n",
+ strings::StrCat("{\"size\": \"5\",\"generation\": \"1\","
+ "\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 3600 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ kTestTimeoutConfig, nullptr /* gcs additional header */);
+
+ TF_EXPECT_OK(fs.FileExists("gs://bucket/dir/"));
+ TF_EXPECT_OK(fs.IsDirectory("gs://bucket/dir/"));
+}
+
TEST(GcsFileSystemTest, GetChildren_NoItems) {
std::vector<HttpRequest*> requests({new FakeHttpRequest(
"Uri: https://www.googleapis.com/storage/v1/b/bucket/o?"
@@ -2407,6 +2429,30 @@ TEST(GcsFileSystemTest, Stat_Cache_Flush) {
}
}
+TEST(GcsFileSystemTest, Stat_FilenameEndingWithSlash) {
+ std::vector<HttpRequest*> requests({new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
+ "dir%2F?fields=size%2Cgeneration%2Cupdated\n"
+ "Auth Token: fake_token\n"
+ "Timeouts: 5 1 10\n",
+ strings::StrCat("{\"size\": \"5\",\"generation\": \"1\","
+ "\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */,
+ 0 /* initial retry delay*/, kTestTimeoutConfig,
+ nullptr /* gcs additional header */);
+
+ FileStatistics stat;
+ TF_EXPECT_OK(fs.Stat("gs://bucket/dir/", &stat));
+ EXPECT_EQ(5, stat.length);
+ EXPECT_TRUE(stat.is_directory);
+}
+
TEST(GcsFileSystemTest, IsDirectory_NotFound) {
std::vector<HttpRequest*> requests(
{new FakeHttpRequest(
diff --git a/tensorflow/core/platform/cloud/ram_file_block_cache.h b/tensorflow/core/platform/cloud/ram_file_block_cache.h
index 2303f9caaa..46fb9a35b8 100644
--- a/tensorflow/core/platform/cloud/ram_file_block_cache.h
+++ b/tensorflow/core/platform/cloud/ram_file_block_cache.h
@@ -60,6 +60,8 @@ class RamFileBlockCache : public FileBlockCache {
pruning_thread_.reset(env_->StartThread(ThreadOptions(), "TF_prune_FBC",
[this] { Prune(); }));
}
+ VLOG(1) << "GCS file block cache is "
+ << (IsCacheEnabled() ? "enabled" : "disabled");
}
~RamFileBlockCache() override {
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index b9eb3d02c5..9e52ba344a 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -517,6 +517,9 @@ def tf_additional_proto_srcs():
"platform/default/protobuf.cc",
]
+def tf_additional_human_readable_json_deps():
+ return []
+
def tf_additional_all_protos():
return ["//tensorflow/core:protos_all"]
diff --git a/tensorflow/core/platform/default/human_readable_json.cc b/tensorflow/core/platform/default/human_readable_json.cc
new file mode 100644
index 0000000000..6bf2106f6e
--- /dev/null
+++ b/tensorflow/core/platform/default/human_readable_json.cc
@@ -0,0 +1,54 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/human_readable_json.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+
+Status ProtoToHumanReadableJson(const ::google::protobuf::Message& proto,
+ string* result) {
+ result->clear();
+
+ auto status = google::protobuf::util::MessageToJsonString(proto, result);
+ if (!status.ok()) {
+ // Convert error_msg google::protobuf::StringPiece to
+ // tensorflow::StringPiece.
+ auto error_msg = status.error_message();
+ return errors::Internal(
+ strings::StrCat("Could not convert proto to JSON string: ",
+ StringPiece(error_msg.data(), error_msg.length())));
+ }
+ return Status::OK();
+}
+
+Status HumanReadableJsonToProto(const string& str,
+ ::google::protobuf::Message* proto) {
+ proto->Clear();
+ auto status = google::protobuf::util::JsonStringToMessage(str, proto);
+ if (!status.ok()) {
+ // Convert error_msg google::protobuf::StringPiece to
+ // tensorflow::StringPiece.
+ auto error_msg = status.error_message();
+ return errors::Internal(
+ strings::StrCat("Could not convert JSON string to proto: ",
+ StringPiece(error_msg.data(), error_msg.length())));
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow
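
Note: a possible round trip through the new helpers (the ConfigProto message and the Example wrapper are illustrative, not part of this change):

    #include "tensorflow/core/lib/core/errors.h"
    #include "tensorflow/core/platform/human_readable_json.h"
    #include "tensorflow/core/protobuf/config.pb.h"

    tensorflow::Status Example(const tensorflow::ConfigProto& config) {
      // Proto -> human-readable JSON string.
      tensorflow::string json;
      TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(config, &json));
      // JSON string -> proto again; only strings produced by the helper
      // above are guaranteed to parse.
      tensorflow::ConfigProto parsed;
      TF_RETURN_IF_ERROR(tensorflow::HumanReadableJsonToProto(json, &parsed));
      return tensorflow::Status::OK();
    }
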
diff --git a/tensorflow/core/platform/human_readable_json.h b/tensorflow/core/platform/human_readable_json.h
new file mode 100644
index 0000000000..c759e801e9
--- /dev/null
+++ b/tensorflow/core/platform/human_readable_json.h
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_
+#define TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+// Converts a proto to a JSON-like string that's meant to be human-readable
+// but still machine-parseable.
+//
+// This string may not be strictly JSON-compliant, but it must be parseable by
+// HumanReadableJsonToProto.
+Status ProtoToHumanReadableJson(const protobuf::Message& proto, string* result);
+
+// Converts a string produced by ProtoToHumanReadableJson to a protobuf. Not
+// guaranteed to work for general JSON.
+Status HumanReadableJsonToProto(const string& str, protobuf::Message* proto);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PLATFORM_HUMAN_READABLE_JSON_H_
diff --git a/tensorflow/docs_src/extend/new_data_formats.md b/tensorflow/docs_src/extend/new_data_formats.md
index 2c33a6b6f7..d1d1f69766 100644
--- a/tensorflow/docs_src/extend/new_data_formats.md
+++ b/tensorflow/docs_src/extend/new_data_formats.md
@@ -45,7 +45,7 @@ Each of these implementations comprises three related classes:
* A `tensorflow::GraphDatasetBase` subclass (e.g. `TextLineDatasetOp::Dataset`),
which represents the *immutable* definition of the dataset itself, and tells
TensorFlow how to construct an iterator object over that dataset, in its
- `MakeIterator()` method.
+ `MakeIteratorInternal()` method.
* A `tensorflow::DatasetIterator<Dataset>` subclass (e.g.
`TextLineDatasetOp::Dataset::Iterator`), which represents the *mutable* state
@@ -103,7 +103,7 @@ class MyReaderDatasetOp : public DatasetOpKernel {
public:
Dataset(OpKernelContext* ctx) : GraphDatasetBase(ctx) {}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::MyReader")}));
@@ -124,7 +124,7 @@ class MyReaderDatasetOp : public DatasetOpKernel {
return *shapes;
}
- string DebugString() override { return "MyReaderDatasetOp::Dataset"; }
+ string DebugString() const override { return "MyReaderDatasetOp::Dataset"; }
protected:
// Optional: Implementation of `GraphDef` serialization for this dataset.
diff --git a/tensorflow/docs_src/performance/benchmarks.md b/tensorflow/docs_src/performance/benchmarks.md
index 20165a090e..a5fa551dd4 100644
--- a/tensorflow/docs_src/performance/benchmarks.md
+++ b/tensorflow/docs_src/performance/benchmarks.md
@@ -403,8 +403,6 @@ GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
This
[script](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks)
was run on the various platforms to generate the above results.
-@{$performance_models$High-Performance Models} details techniques in the script
-along with examples of how to execute the script.
In order to create results that are as repeatable as possible, each test was run
5 times and then the times were averaged together. GPUs are run in their default
diff --git a/tensorflow/docs_src/performance/index.md b/tensorflow/docs_src/performance/index.md
index 49343eaac7..131d28fa3e 100644
--- a/tensorflow/docs_src/performance/index.md
+++ b/tensorflow/docs_src/performance/index.md
@@ -1,19 +1,31 @@
# Performance
-Performance is often a significant issue when training a machine learning
-model. This section explains various ways to optimize performance. Start
-your investigation with the @{$performance_guide$Performance Guide} and then go
-deeper with techniques detailed in @{$performance_models$High-Performance Models}:
-
- * @{$performance_guide$Performance Guide}, which contains a collection of best
+Performance is an important consideration when training machine learning
+models. Better performance speeds up and scales research while also
+providing end users with near-instant predictions. This section covers
+the high-level APIs to use, best practices for building and training
+high-performance models, and how to quantize models for the lowest latency
+and highest throughput at inference time.
+
+ * @{$performance_guide$Performance Guide} contains a collection of best
practices for optimizing your TensorFlow code.
- * @{$performance_models$High-Performance Models}, which contains a collection
- of advanced techniques to build highly scalable models targeting different
- system types and network topologies.
+ * @{$datasets_performance$Data input pipeline guide} describes the tf.data
+ API for building efficient data input pipelines for TensorFlow.
+
+ * @{$performance/benchmarks$Benchmarks} contains a collection of
+ benchmark results for a variety of hardware configurations.
+
+ * For improving inference efficiency on mobile and
+ embedded hardware, see
+ @{$quantization$How to Quantize Neural Networks with TensorFlow}, which
+ explains how to use quantization to reduce model size, both in storage
+ and at runtime.
+
+  * For optimizing inference on GPUs, refer to [NVIDIA TensorRT™
+    integration with TensorFlow](
+    https://medium.com/tensorflow/speed-up-tensorflow-inference-on-gpus-with-tensorrt-13b49f3db3fa).
- * @{$performance/benchmarks$Benchmarks}, which contains a collection of
- benchmark results.
XLA (Accelerated Linear Algebra) is an experimental compiler for linear
algebra that optimizes TensorFlow computations. The following guides explore
@@ -36,10 +48,5 @@ XLA:
standalone tool that compiles TensorFlow graphs into executable code in
order to optimize performance.
-And finally, we offer the following guide:
- * @{$quantization$How to Quantize Neural Networks with TensorFlow}, which
- can explains how to use quantization to reduce model size, both in storage
- and at runtime. Quantization can improve performance, especially on
- mobile hardware.
diff --git a/tensorflow/docs_src/performance/leftnav_files b/tensorflow/docs_src/performance/leftnav_files
index 1f894c39fe..12e0dbd48a 100644
--- a/tensorflow/docs_src/performance/leftnav_files
+++ b/tensorflow/docs_src/performance/leftnav_files
@@ -1,7 +1,6 @@
index.md
performance_guide.md
datasets_performance.md
-performance_models.md
benchmarks.md
quantization.md
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 679ef93229..569403fa9a 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -71,6 +71,7 @@ py_library(
visibility = [
"//tensorflow:__pkg__",
"//tensorflow/python/tools:__pkg__",
+ "//tensorflow/tools/api/generator:__pkg__",
],
deps = [
":array_ops",
@@ -717,6 +718,38 @@ py_library(
)
py_library(
+ name = "function_def_to_graph",
+ srcs = ["framework/function_def_to_graph.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":framework",
+ ":function",
+ ":op_def_registry",
+ ":tensor_shape",
+ ":versions",
+ "//tensorflow/core:protos_all_py",
+ ],
+)
+
+py_test(
+ name = "function_def_to_graph_test",
+ size = "small",
+ srcs = ["framework/function_def_to_graph_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":array_ops",
+ ":client_testlib",
+ ":dtypes",
+ ":framework_ops",
+ ":function_def_to_graph",
+ ":graph_to_function_def",
+ ":math_ops",
+ ":test_ops",
+ ],
+)
+
+py_library(
name = "graph_util",
srcs = [
"framework/graph_util.py",
@@ -2699,7 +2732,6 @@ py_library(
":util",
":variables",
"//tensorflow/python/eager:context",
- "//tensorflow/python/estimator:util",
"@six_archive//:six",
],
)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 23d87fb394..559063d6ae 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -494,7 +494,7 @@ class GraphModeFunction(object):
def __call__(self, *args):
"""Executes the passed function in eager mode."""
for v in self._variables:
- if v._trainable: # pylint: disable=protected-access
+ if v.trainable:
tape.watch_variable(v)
tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py
index d9ffcbd203..760a148552 100644
--- a/tensorflow/python/eager/graph_callable.py
+++ b/tensorflow/python/eager/graph_callable.py
@@ -202,7 +202,7 @@ class _InitializingFunctionObject(object):
v.handle).numpy() for v in self._call_fn.variables]
if all(x for x in initialized):
for v in self._call_fn.variables:
- if v._trainable: # pylint: disable=protected-access
+ if v.trainable:
tape.watch_variable(v)
return self._call_fn(*args)
elif all(not x for x in initialized):
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 52b90504f3..e3ce0ef9d0 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1874,10 +1874,10 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
void MaybeWatchVariable(PyObject* input) {
DCHECK(CheckResourceVariable(input));
- DCHECK(PyObject_HasAttrString(input, "_trainable"));
+ DCHECK(PyObject_HasAttrString(input, "trainable"));
tensorflow::Safe_PyObjectPtr trainable(
- PyObject_GetAttrString(input, "_trainable"));
+ PyObject_GetAttrString(input, "trainable"));
if (trainable.get() == Py_False) return;
TFE_Py_TapeSetWatchVariable(input);
}
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 0754041f9e..9c4d58b177 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -446,7 +446,26 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
+ "//tensorflow/python:platform",
+ "//tensorflow/python:training",
"//tensorflow/python:util",
+ "//tensorflow/python/data",
+ ],
+)
+
+py_test(
+ name = "util_test",
+ srcs = ["util_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["notsan"], # b/67510291
+ deps = [
+ ":util",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:training",
+ "//tensorflow/python/data",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
],
)
@@ -598,6 +617,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
+ ":util",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 331ee7490e..4f57a4ef79 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -32,15 +32,17 @@ from tensorflow.core.framework import summary_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session as tf_session
-from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
+from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.estimator.export import export as export_helpers
from tensorflow.python.estimator.export import export_output
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import metrics as metrics_lib
@@ -964,17 +966,9 @@ class Estimator(object):
def _get_features_from_input_fn(self, input_fn, mode):
"""Extracts the `features` from return values of `input_fn`."""
result = self._call_input_fn(input_fn, mode)
- input_hooks = []
- if isinstance(result, dataset_ops.Dataset):
- iterator = result.make_initializable_iterator()
- input_hooks.append(_DatasetInitializerHook(iterator))
- result = iterator.get_next()
- if isinstance(result, (list, tuple)):
- # Unconditionally drop the label (the second element of result).
- result = result[0]
-
+ result, _, hooks = estimator_util.parse_input_fn_result(result)
self._validate_features_in_predict_input(result)
- return result, input_hooks
+ return result, hooks
def _validate_features_in_predict_input(self, result):
if not _has_dataset_or_queue_runner(result):
@@ -984,25 +978,13 @@ class Estimator(object):
def _get_features_and_labels_from_input_fn(self, input_fn, mode):
"""Extracts the `features` and labels from return values of `input_fn`."""
- input_hooks = []
if self._distribution is not None and mode == model_fn_lib.ModeKeys.TRAIN:
result = self._distribution.distribute_dataset(
lambda: self._call_input_fn(input_fn, mode))
- iterator = result.make_initializable_iterator()
- input_hooks.append(_DatasetInitializerHook(iterator))
- result = iterator.get_next()
else:
result = self._call_input_fn(input_fn, mode)
- if isinstance(result, dataset_ops.Dataset):
- iterator = result.make_initializable_iterator()
- input_hooks.append(_DatasetInitializerHook(iterator))
- result = iterator.get_next()
- if isinstance(result, (list, tuple)):
- if len(result) != 2:
- raise ValueError(
- 'input_fn should return (features, labels) as a len 2 tuple.')
- return result[0], result[1], input_hooks
- return result, None, input_hooks
+
+ return estimator_util.parse_input_fn_result(result)
def _extract_batch_length(self, preds_evaluated):
"""Extracts batch length of predictions."""
@@ -1067,9 +1049,15 @@ class Estimator(object):
mode: ModeKeys
Returns:
- Either features or (features, labels) where features and labels are:
- features - `Tensor` or dictionary of string feature name to `Tensor`.
- labels - `Tensor` or dictionary of `Tensor` with labels.
+ The return value of the passed input_fn, which should be one of:
+
+      * A `tf.data.Dataset` object: Outputs of the `Dataset` must be a
+        (features, labels) tuple with the same constraints as below.
+      * A (features, labels) tuple, where `features` is a `Tensor` or a
+        dictionary of string feature name to `Tensor` and `labels` is a
+        `Tensor` or a dictionary of string label name to `Tensor`. Both
+        `features` and `labels` are consumed by `model_fn` and should
+        satisfy its input expectations.
Raises:
ValueError: if input_fn takes invalid arguments.
@@ -1397,10 +1385,18 @@ class Estimator(object):
hooks=all_hooks,
config=self._session_config)
+ current_global_step = eval_results[ops.GraphKeys.GLOBAL_STEP]
+
_write_dict_to_summary(
output_dir=output_dir,
dictionary=eval_results,
- current_global_step=eval_results[ops.GraphKeys.GLOBAL_STEP])
+ current_global_step=current_global_step)
+
+ if checkpoint_path:
+ _write_checkpoint_path_to_summary(
+ output_dir=output_dir,
+ checkpoint_path=checkpoint_path,
+ current_global_step=current_global_step)
return eval_results
@@ -1599,6 +1595,30 @@ def _write_dict_to_summary(output_dir,
summary_writer.flush()
+def _write_checkpoint_path_to_summary(output_dir, checkpoint_path,
+ current_global_step):
+ """Writes `checkpoint_path` into summary file in the given output directory.
+
+ Args:
+ output_dir: `str`, directory to write the summary file in.
+ checkpoint_path: `str`, checkpoint file path to be written to summary file.
+ current_global_step: `int`, the current global step.
+ """
+
+ checkpoint_path_tag = 'checkpoint_path'
+
+ logging.info('Saving \'%s\' summary for global step %d: %s',
+ checkpoint_path_tag, current_global_step, checkpoint_path)
+ summary_proto = summary_pb2.Summary()
+ summary_proto.value.add(
+ tag=checkpoint_path_tag,
+ tensor=tensor_util.make_tensor_proto(
+ checkpoint_path, dtype=dtypes.string))
+ summary_writer = writer_cache.FileWriterCache.get(output_dir)
+ summary_writer.add_summary(summary_proto, current_global_step)
+ summary_writer.flush()
+
+
def _has_dataset_or_queue_runner(maybe_tensor):
"""Returns True if TF dataset or QueueRunner has been used."""
# Check TF dataset first. Here, we use a simple algorithm to check the top
@@ -1610,19 +1630,6 @@ def _has_dataset_or_queue_runner(maybe_tensor):
# Now, check queue.
return ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS)
-
-class _DatasetInitializerHook(training.SessionRunHook):
-
- def __init__(self, iterator):
- self._iterator = iterator
-
- def begin(self):
- self._initializer = self._iterator.initializer
-
- def after_create_session(self, session, coord):
- del coord
- session.run(self._initializer)
-
VocabInfo = warm_starting_util.VocabInfo # pylint: disable=invalid-name
tf_export('estimator.VocabInfo', allow_multiple_exports=True)(VocabInfo)
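
For reference, a minimal input_fn satisfying the contract documented above
might look like the following sketch (the feature name, dtypes, and batch
size are illustrative, not part of this change):

    import numpy as np
    import tensorflow as tf

    def train_input_fn():
      # A tf.data.Dataset whose outputs are a (features, labels) tuple.
      features = {"x": np.arange(100, dtype=np.float32)}
      labels = np.arange(100, dtype=np.float32)
      return tf.data.Dataset.from_tensor_slices((features, labels)).batch(8)

An input_fn may also return the (features, labels) tuple of tensors
directly; `parse_input_fn_result` (added to estimator/util.py below)
handles both cases.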
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index a9f20f7fa4..9c0d0f7390 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -39,6 +39,7 @@ from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
from tensorflow.python.layers import layers
from tensorflow.python.lib.io import file_io
@@ -81,21 +82,22 @@ def dummy_model_fn(features, labels, params):
_, _, _ = features, labels, params
-def check_eventfile_for_keyword(keyword, dir_):
- """Checks event files for the keyword."""
+def summaries_with_matching_keyword(keyword, dir_):
+ """Yields summary protos matching given keyword from event file."""
writer_cache.FileWriterCache.clear()
- # Get last Event written.
event_paths = glob.glob(os.path.join(dir_, 'events*'))
- last_event = None
- for last_event in summary_iterator.summary_iterator(event_paths[-1]):
- if last_event.summary is not None:
- for value in last_event.summary.value:
+ for event in summary_iterator.summary_iterator(event_paths[-1]):
+ if event.summary is not None:
+ for value in event.summary.value:
if keyword in value.tag:
- return True
+ yield event.summary
+
- return False
+def check_eventfile_for_keyword(keyword, dir_):
+ """Checks event files for the keyword."""
+ return any(summaries_with_matching_keyword(keyword, dir_))
class EstimatorInheritanceConstraintTest(test.TestCase):
@@ -1398,6 +1400,19 @@ class EstimatorEvaluateTest(test.TestCase):
check_eventfile_for_keyword(key, est.eval_dir()),
'{} should be part of reported summaries.'.format(key))
+ # Verify that evaluated checkpoint path is written to event file.
+ checkpoint_path_tag = 'checkpoint_path'
+ self.assertTrue(
+ check_eventfile_for_keyword(checkpoint_path_tag, est.eval_dir()),
+ '{} should be part of reported summaries.'.format(checkpoint_path_tag))
+
+ expected_tensor_proto = tensor_util.make_tensor_proto(
+ est.latest_checkpoint(), dtype=dtypes.string)
+ summaries = summaries_with_matching_keyword(checkpoint_path_tag,
+ est.eval_dir())
+ self.assertProtoEquals(expected_tensor_proto,
+ next(summaries).value[0].tensor)
+
class EstimatorPredictTest(test.TestCase):
diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py
index 48ae8cd497..ff19a0a7f4 100644
--- a/tensorflow/python/estimator/export/export.py
+++ b/tensorflow/python/estimator/export/export.py
@@ -404,6 +404,42 @@ def build_raw_supervised_input_receiver_fn(features,
return supervised_input_receiver_fn
+def build_supervised_input_receiver_fn_from_input_fn(input_fn, **input_fn_args):
+ """Get a function that returns a SupervisedInputReceiver matching an input_fn.
+
+ Note that this function calls the input_fn in a local graph in order to
+ extract features and labels. Placeholders are then created from those
+ features and labels in the default graph.
+
+ Args:
+ input_fn: An Estimator input_fn, which is a function that returns one of:
+
+    * A `tf.data.Dataset` object: Outputs of the `Dataset` must be a
+      (features, labels) tuple with the same constraints as below.
+    * A (features, labels) tuple, where `features` is a `Tensor` or a
+      dictionary of string feature name to `Tensor` and `labels` is a
+      `Tensor` or a dictionary of string label name to `Tensor`. Both
+      `features` and `labels` are consumed by `model_fn` and should
+      satisfy its input expectations.
+
+    **input_fn_args: set of kwargs to be passed to the input_fn. Note that
+      these are not checked or validated here, and any errors raised by
+      the input_fn propagate to the caller.
+
+ Returns:
+ A function taking no arguments that, when called, returns a
+ SupervisedInputReceiver. This function can be passed in as part of the
+ input_receiver_map when exporting SavedModels from Estimator with multiple
+ modes.
+ """
+  # Wrap the input_fn call in a temporary graph to avoid polluting the
+  # default graph's namespace.
+ with ops.Graph().as_default():
+ result = input_fn(**input_fn_args)
+ features, labels, _ = util.parse_input_fn_result(result)
+ # Placeholders are created back in the default graph.
+ return build_raw_supervised_input_receiver_fn(features, labels)
+
+
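+# A condensed usage sketch of the new helper (reusing the hypothetical
+# train_input_fn sketched earlier; names here are illustrative, not part
+# of this change):
+#
+#   from tensorflow.python.estimator.export import export
+#
+#   input_receiver_fn = (
+#       export.build_supervised_input_receiver_fn_from_input_fn(
+#           train_input_fn))
+#   receiver = input_receiver_fn()  # a SupervisedInputReceiver
+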
### Below utilities are specific to SavedModel exports.
diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py
index 0af587f2a8..a7074712c2 100644
--- a/tensorflow/python/estimator/export/export_test.py
+++ b/tensorflow/python/estimator/export/export_test.py
@@ -459,6 +459,41 @@ class ExportTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
export.build_raw_supervised_input_receiver_fn(features, labels)
+ def test_build_supervised_input_receiver_fn_from_input_fn(self):
+ def dummy_input_fn():
+ return ({"x": constant_op.constant([[1], [1]]),
+ "y": constant_op.constant(["hello", "goodbye"])},
+ constant_op.constant([[1], [1]]))
+
+ input_receiver_fn = export.build_supervised_input_receiver_fn_from_input_fn(
+ dummy_input_fn)
+
+ with ops.Graph().as_default():
+ input_receiver = input_receiver_fn()
+ self.assertEqual(set(["x", "y"]),
+ set(input_receiver.features.keys()))
+ self.assertIsInstance(input_receiver.labels, ops.Tensor)
+ self.assertEqual(set(["x", "y", "label"]),
+ set(input_receiver.receiver_tensors.keys()))
+
+ def test_build_supervised_input_receiver_fn_from_input_fn_args(self):
+ def dummy_input_fn(feature_key="x"):
+ return ({feature_key: constant_op.constant([[1], [1]]),
+ "y": constant_op.constant(["hello", "goodbye"])},
+ {"my_label": constant_op.constant([[1], [1]])})
+
+ input_receiver_fn = export.build_supervised_input_receiver_fn_from_input_fn(
+ dummy_input_fn, feature_key="z")
+
+ with ops.Graph().as_default():
+ input_receiver = input_receiver_fn()
+ self.assertEqual(set(["z", "y"]),
+ set(input_receiver.features.keys()))
+ self.assertEqual(set(["my_label"]),
+ set(input_receiver.labels.keys()))
+ self.assertEqual(set(["z", "y", "my_label"]),
+ set(input_receiver.receiver_tensors.keys()))
+
def test_build_all_signature_defs_without_receiver_alternatives(self):
receiver_tensor = array_ops.placeholder(dtypes.string)
output_1 = constant_op.constant([1.])
diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py
index e4e1d37f74..924ca309ff 100644
--- a/tensorflow/python/estimator/util.py
+++ b/tensorflow/python/estimator/util.py
@@ -24,6 +24,7 @@ import time
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import training
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
@@ -72,3 +73,59 @@ def get_timestamped_dir(dir_base):
result_dir, attempts, MAX_DIRECTORY_CREATION_ATTEMPTS))
raise RuntimeError('Failed to obtain a unique export directory name after '
'{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS))
+
+
+def parse_input_fn_result(result):
+ """Gets features, labels, and hooks from the result of an Estimator input_fn.
+
+ Args:
+ result: output of an input_fn to an estimator, which should be one of:
+
+      * A `tf.data.Dataset` object: Outputs of the `Dataset` must be a
+        (features, labels) tuple with the same constraints as below.
+      * A (features, labels) tuple, where `features` is a `Tensor` or a
+        dictionary of string feature name to `Tensor` and `labels` is a
+        `Tensor` or a dictionary of string label name to `Tensor`. Both
+        `features` and `labels` are consumed by `model_fn` and should
+        satisfy its input expectations.
+
+ Returns:
+    Tuple of features, labels, and input_hooks, where features and labels
+    are as described above (labels may be None), and input_hooks is a list
+    of SessionRunHooks to be included when running.
+
+ Raises:
+ ValueError: if the result is a list or tuple of length != 2.
+ """
+ input_hooks = []
+ try:
+ # We can't just check whether this is a tf.data.Dataset instance here,
+    # as it may be a PerDeviceDataset. Try treating it as a dataset first.
+ iterator = result.make_initializable_iterator()
+ except AttributeError:
+    # Not a dataset or dataset-like object; treat the result as-is.
+ pass
+ else:
+ input_hooks.append(_DatasetInitializerHook(iterator))
+ result = iterator.get_next()
+
+ if isinstance(result, (list, tuple)):
+ if len(result) != 2:
+ raise ValueError(
+ 'input_fn should return (features, labels) as a len 2 tuple.')
+ return result[0], result[1], input_hooks
+ return result, None, input_hooks
+
+
+class _DatasetInitializerHook(training.SessionRunHook):
+ """Creates a SessionRunHook that initializes the passed iterator."""
+
+ def __init__(self, iterator):
+ self._iterator = iterator
+
+ def begin(self):
+ self._initializer = self._iterator.initializer
+
+ def after_create_session(self, session, coord):
+ del coord
+ session.run(self._initializer)
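
The new util_test.py below exercises `parse_input_fn_result` directly; for
orientation, a condensed sketch of the calling pattern (reusing the
hypothetical `train_input_fn` from earlier):

    from tensorflow.python.estimator import util
    from tensorflow.python.training import training

    features, labels, hooks = util.parse_input_fn_result(train_input_fn())
    # For a Dataset input, `hooks` contains a _DatasetInitializerHook that
    # runs the iterator initializer when the session is created.
    with training.MonitoredSession(hooks=hooks) as sess:
      batch_features, batch_labels = sess.run([features, labels])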
diff --git a/tensorflow/python/estimator/util_test.py b/tensorflow/python/estimator/util_test.py
new file mode 100644
index 0000000000..d7e0610779
--- /dev/null
+++ b/tensorflow/python/estimator/util_test.py
@@ -0,0 +1,102 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for util.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.estimator import util
+from tensorflow.python.framework import constant_op
+from tensorflow.python.platform import test
+from tensorflow.python.training import training
+
+
+class UtilTest(test.TestCase):
+ """Tests for miscellaneous Estimator utils."""
+
+ def test_parse_input_fn_result_tuple(self):
+ def _input_fn():
+ features = constant_op.constant(np.arange(100))
+ labels = constant_op.constant(np.arange(100, 200))
+ return features, labels
+
+ features, labels, hooks = util.parse_input_fn_result(_input_fn())
+
+ with self.test_session() as sess:
+ vals = sess.run([features, labels])
+
+ self.assertAllEqual(vals[0], np.arange(100))
+ self.assertAllEqual(vals[1], np.arange(100, 200))
+ self.assertEqual(hooks, [])
+
+ def test_parse_input_fn_result_dataset(self):
+ def _input_fn():
+ features = np.expand_dims(np.arange(100), 0)
+ labels = np.expand_dims(np.arange(100, 200), 0)
+ return dataset_ops.Dataset.from_tensor_slices((features, labels))
+
+ features, labels, hooks = util.parse_input_fn_result(_input_fn())
+
+ with training.MonitoredSession(hooks=hooks) as sess:
+ vals = sess.run([features, labels])
+
+ self.assertAllEqual(vals[0], np.arange(100))
+ self.assertAllEqual(vals[1], np.arange(100, 200))
+ self.assertIsInstance(hooks[0], util._DatasetInitializerHook)
+
+ def test_parse_input_fn_result_features_only(self):
+ def _input_fn():
+ return constant_op.constant(np.arange(100))
+
+ features, labels, hooks = util.parse_input_fn_result(_input_fn())
+
+ with self.test_session() as sess:
+ vals = sess.run([features])
+
+ self.assertAllEqual(vals[0], np.arange(100))
+ self.assertEqual(labels, None)
+ self.assertEqual(hooks, [])
+
+ def test_parse_input_fn_result_features_only_dataset(self):
+ def _input_fn():
+ features = np.expand_dims(np.arange(100), 0)
+ return dataset_ops.Dataset.from_tensor_slices(features)
+
+ features, labels, hooks = util.parse_input_fn_result(_input_fn())
+
+ with training.MonitoredSession(hooks=hooks) as sess:
+ vals = sess.run([features])
+
+ self.assertAllEqual(vals[0], np.arange(100))
+ self.assertEqual(labels, None)
+ self.assertIsInstance(hooks[0], util._DatasetInitializerHook)
+
+ def test_parse_input_fn_result_invalid(self):
+ def _input_fn():
+ features = np.expand_dims(np.arange(100), 0)
+ labels = np.expand_dims(np.arange(100, 200), 0)
+ return dataset_ops.Dataset.from_tensor_slices((features, labels, labels))
+
+ with self.assertRaisesRegexp(ValueError, 'input_fn should return'):
+ util.parse_input_fn_result(_input_fn())
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 0675222016..259cab6699 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -718,8 +718,12 @@ class _FuncGraph(ops.Graph):
tensor.dtype, shape=tensor.get_shape(), name=name)
# pylint: disable=protected-access
if ops._USE_C_SHAPES:
- handle_data = c_api.GetResourceHandleShapeAndType(tensor.graph._c_graph,
- tensor._as_tf_output())
+ if isinstance(tensor, ops.EagerTensor):
+ handle_data = tensor._handle_data
+ else:
+ handle_data = c_api.GetResourceHandleShapeAndType(
+ tensor.graph._c_graph, tensor._as_tf_output())
+
if handle_data:
c_api.SetResourceHandleShapeAndType(ph.graph._c_graph,
ph._as_tf_output(),
diff --git a/tensorflow/python/framework/function_def_to_graph.py b/tensorflow/python/framework/function_def_to_graph.py
new file mode 100644
index 0000000000..4fecc41343
--- /dev/null
+++ b/tensorflow/python/framework/function_def_to_graph.py
@@ -0,0 +1,189 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Utlity to convert FunctionDef to GraphDef and Graph."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.framework import types_pb2
+from tensorflow.core.framework import versions_pb2
+from tensorflow.python.framework import function
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import op_def_registry
+from tensorflow.python.framework import versions
+
+
+def function_def_to_graph(fdef, input_shapes=None):
+ """Converts a FunctionDef to a function._FuncGraph (sub-class Graph).
+
+ The returned _FuncGraph's `name`, `inputs` and `outputs` fields will be set.
+ The input tensors are represented as placeholders.
+
+  Note: `_FuncGraph._captured` is not populated here and may be set by the
+  caller.
+
+ Args:
+ fdef: FunctionDef.
+ input_shapes: Optional. A list of TensorShape objects of the shapes of
+ function inputs. If specified, its length must match length of
+ `fdef.signature.input_arg`. If a shape is None, the corresponding input
+ placeholder will have unknown shape.
+
+ Returns:
+ A _FuncGraph.
+ """
+ func_graph = function._FuncGraph(fdef.signature.name, capture_by_value=False) # pylint: disable=protected-access
+ graph_def, nested_to_flat_tensor_name = function_def_to_graph_def(
+ fdef, input_shapes)
+
+ with func_graph.as_default():
+ # Add all function nodes to the graph.
+ importer.import_graph_def(graph_def, name="")
+
+ # Initialize fields specific to _FuncGraph.
+
+ # inputs
+ input_tensor_names = [
+ nested_to_flat_tensor_name[arg.name] for arg in fdef.signature.input_arg
+ ]
+ func_graph.inputs = [
+ func_graph.get_tensor_by_name(name) for name in input_tensor_names
+ ]
+
+ # outputs
+ output_tensor_names = [
+ nested_to_flat_tensor_name[fdef.ret[arg.name]]
+ for arg in fdef.signature.output_arg
+ ]
+ func_graph.outputs = [
+ func_graph.get_tensor_by_name(name) for name in output_tensor_names
+ ]
+
+ return func_graph
+
+
+def function_def_to_graph_def(fdef, input_shapes=None):
+ """Convert a FunctionDef to a GraphDef.
+
+ Steps:
+ 1. Creates placeholder nodes corresponding to inputs in
+ `FunctionDef.signature.input_arg`.
+ 2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`.
+ 3. Renames inputs of all nodes to use the convention of GraphDef instead of
+ FunctionDef. See comment on `FunctionDef.node_def` on how the tensor naming
+ in FunctionDefs is different from GraphDefs.
+
+ Args:
+ fdef: FunctionDef.
+ input_shapes: Optional. A list of TensorShape objects of the shapes of
+ function inputs. If specified, its length must match length of
+ `fdef.signature.input_arg`. If a shape is None, the corresponding input
+ placeholder will have unknown shape.
+
+ Returns:
+ A tuple of (GraphDef, dict<string, string>). The dict contains a mapping
+ from nested tensor names (in FunctionDef) to flattened names (in GraphDef).
+
+ Raises:
+ ValueError: If the length of input_shapes does not match the number of
+ input_args or if the FunctionDef is invalid.
+ """
+ graph_def = graph_pb2.GraphDef()
+ graph_def.versions.CopyFrom(
+ versions_pb2.VersionDef(
+ producer=versions.GRAPH_DEF_VERSION,
+ min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))
+
+ if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
+ raise ValueError("Length of input_shapes must match the number of " +
+ "input_args. len(input_shapes): {} len(input_arg): {}".
+ format(len(input_shapes), len(fdef.signature.input_arg)))
+
+ # 1. Create placeholders for input nodes.
+ for i, arg_def in enumerate(fdef.signature.input_arg):
+ node_def = graph_def.node.add()
+ node_def.name = arg_def.name
+ node_def.op = "Placeholder"
+ node_def.attr["dtype"].type = arg_def.type
+ if input_shapes and input_shapes[i] is not None:
+ node_def.attr["shape"].shape.CopyFrom(input_shapes[i].as_proto())
+
+ # 2. Copy all body NodeDefs to the GraphDef.
+ graph_def.node.extend(fdef.node_def)
+
+ # 3. Perform the renaming.
+
+ # Build the tensor name mapping then flatten the tensor names.
+ # See comment on `FunctionDef.node_def` on how the tensor naming in
+ # FunctionDefs is different from GraphDefs.
+ nested_to_flat_tensor_name = {}
+
+ for arg_def in fdef.signature.input_arg:
+ nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name)
+
+ for node_def in fdef.node_def:
+ op_def = op_def_registry.get_registered_ops().get(node_def.op)
+ if not op_def:
+ # TODO(b/80470245): Support functions which refer other functions.
+ raise NotImplementedError(
+ "No op registered for {},".format(node_def.op) +
+ " it may be a function. function_def_to_graph_def " +
+ "currently does not support converting functions with " +
+ "references to other graph functions.")
+
+ for attr in op_def.attr:
+ if attr.type in ("func", "list(func)"):
+ # TODO(b/80470245): Support functions which refer other functions.
+ raise NotImplementedError("Unsupported attr {} ".format(attr.name) +
+ " with type {}".format(attr.type) +
+ " in op {}. ".format(op_def.name) +
+ "function_def_to_graph_def currently does " +
+ "not support converting functions with " +
+ "references to other graph functions.")
+
+ # Iterate over output_args in op_def to build the map.
+ # Index of the output tensor in the flattened list of *all* output
+ # tensors of the op.
+ flattened_index = 0
+ for arg_def in op_def.output_arg:
+ num_args = _get_num_args(arg_def, node_def)
+ for i in range(num_args):
+ # Map tensor names from "node_name:output_arg_name:index" to
+ # "node_name:flattened_index".
+ nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i)
+ flat_name = "{}:{}".format(node_def.name, flattened_index)
+ nested_to_flat_tensor_name[nested_name] = flat_name
+ flattened_index += 1
+
+ # Update inputs of all nodes in graph.
+ for node_def in graph_def.node:
+ for i in range(len(node_def.input)):
+ node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]]
+
+ return graph_def, nested_to_flat_tensor_name
+
+
+# Based on implementation in core/framework/node_def_util.cc::ComputeArgRange.
+def _get_num_args(arg_def, node_def):
+ if arg_def.number_attr:
+ return node_def.attr[arg_def.number_attr].i
+ elif arg_def.type_list_attr:
+ return len(node_def.attr[arg_def.type_list_attr].list.type)
+ elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID:
+ return 1
+ else:
+ raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def)))
diff --git a/tensorflow/python/framework/function_def_to_graph_test.py b/tensorflow/python/framework/function_def_to_graph_test.py
new file mode 100644
index 0000000000..0f4e6ef54f
--- /dev/null
+++ b/tensorflow/python/framework/function_def_to_graph_test.py
@@ -0,0 +1,184 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.python.framework.function_def_to_graph."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function_def_to_graph
+from tensorflow.python.framework import graph_to_function_def
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class FunctionDefToGraphTest(test.TestCase):
+
+ def _build_function_def(self):
+ with ops.Graph().as_default() as g:
+ # Inputs
+ x = array_ops.placeholder(dtypes.float32, name="x")
+ y = array_ops.placeholder(dtypes.float32, name="y")
+
+ # Outputs
+ sum_squares = math_ops.add_n(
+ [math_ops.pow(x, 2), math_ops.pow(y, 2)], name="sum_squares")
+ sum_cubes = math_ops.add_n(
+ [math_ops.pow(x, 3), math_ops.pow(y, 3)], name="sum_cubes")
+ fdef = graph_to_function_def.graph_to_function_def(
+ g,
+ g.get_operations(),
+ [x, y], # Inputs
+ [sum_squares, sum_cubes]) # Outputs.
+ fdef.signature.name = "_whats_in_a_name"
+ return fdef
+
+ def testInputsAndOutputs(self):
+ fdef = self._build_function_def()
+ g = function_def_to_graph.function_def_to_graph(fdef)
+ self.assertEqual(g.name, "_whats_in_a_name")
+ with self.test_session(graph=g) as sess:
+ inputs = sess.run(g.inputs, feed_dict={"x:0": 2, "y:0": 3})
+ self.assertSequenceEqual(inputs, [2.0, 3.0])
+ outputs = sess.run(g.outputs, feed_dict={"x:0": 2, "y:0": 3})
+ self.assertSequenceEqual(outputs, [13.0, 35.0])
+
+ def testShapes(self):
+ fdef = self._build_function_def()
+
+ g = function_def_to_graph.function_def_to_graph(fdef)
+ self.assertIsNone(g.inputs[0].shape.dims) # Unknown dims.
+ self.assertIsNone(g.inputs[1].shape.dims) # Unknown dims.
+ self.assertIsNone(g.outputs[0].shape.dims) # Unknown dims.
+ self.assertIsNone(g.outputs[1].shape.dims) # Unknown dims.
+
+ g = function_def_to_graph.function_def_to_graph(
+ fdef, input_shapes=[tensor_shape.vector(5),
+ tensor_shape.vector(5)])
+ self.assertSequenceEqual(g.inputs[0].shape.dims, [5])
+ self.assertSequenceEqual(g.inputs[1].shape.dims, [5])
+ self.assertSequenceEqual(g.outputs[0].shape.dims, [5])
+ self.assertSequenceEqual(g.outputs[1].shape.dims, [5])
+
+ g = function_def_to_graph.function_def_to_graph(
+ fdef, input_shapes=[None, tensor_shape.matrix(5, 7)])
+ self.assertIsNone(g.inputs[0].shape.dims)
+ self.assertSequenceEqual(g.inputs[1].shape.dims, [5, 7])
+ self.assertSequenceEqual(g.outputs[0].shape.dims, [5, 7])
+ self.assertSequenceEqual(g.outputs[1].shape.dims, [5, 7])
+
+ # Should raise a ValueError if the length of input_shapes does not match
+ # the number of input args in FunctionDef.signature.input_arg.
+ with self.assertRaises(ValueError):
+ g = function_def_to_graph.function_def_to_graph(
+ fdef, input_shapes=[tensor_shape.matrix(5, 7)])
+
+
+class FunctionDefToGraphDefTest(test.TestCase):
+
+ def _build_function_def(self):
+ with ops.Graph().as_default() as g:
+ # Inputs: x y z
+ # |\ | /
+ # | \ | /
+ # | foo_1 list_output
+ # | / \ / \
+ # | d_1 e_1 a:1 a:0
+ # | \ | / |
+ # | \ | / |
+ # | foo_2 |
+ # | / \ |
+ # Outputs: x d_2 e_2 a:0
+
+ x = array_ops.placeholder(dtypes.float32, name="x")
+ y = array_ops.placeholder(dtypes.int32, name="y")
+ z = array_ops.placeholder(dtypes.int32, name="z")
+
+ d_1, e_1 = test_ops._op_def_lib.apply_op(
+ "Foo1", name="foo_1", a=x, b=y, c=z)
+
+ list_output0, list_output1 = test_ops.list_output(
+ T=[dtypes.int32, dtypes.int32], name="list_output")
+
+ d_2, e_2 = test_ops.foo1(a=d_1, b=e_1, c=list_output1, name="foo_2")
+
+ fdef = graph_to_function_def.graph_to_function_def(
+ g,
+ g.get_operations(),
+ [x, y, z], # Inputs
+ [x, d_2, e_2, list_output0]) # Outputs.
+
+ # Assert that the FunctionDef was correctly built.
+ assert len(fdef.node_def) == 3 # 2 Foo1 nodes and 1 ListOutput node.
+ assert fdef.node_def[0].op == "Foo1"
+ assert fdef.node_def[0].input == ["x", "y", "z"]
+ assert fdef.node_def[1].op == "ListOutput"
+ assert not fdef.node_def[1].input
+ assert fdef.node_def[2].op == "Foo1"
+ assert fdef.node_def[2].input == [
+ "foo_1:d:0", "foo_1:e:0", "list_output:a:1"
+ ]
+ return fdef
+
+ def testTensorNames(self):
+ fdef = self._build_function_def()
+ g, tensor_name_map = function_def_to_graph.function_def_to_graph_def(fdef)
+
+ # Verify that inputs of body nodes are correctly renamed.
+ # foo_1
+ self.assertSequenceEqual(g.node[3].input, ["x:0", "y:0", "z:0"])
+ # foo_2
+ self.assertSequenceEqual(g.node[5].input,
+ ["foo_1:0", "foo_1:1", "list_output:1"])
+
+ # Verify that the `tensor_name_map` has the correct mapping.
+ self.assertDictEqual(
+ tensor_name_map, {
+ "x": "x:0",
+ "y": "y:0",
+ "z": "z:0",
+ "foo_1:d:0": "foo_1:0",
+ "foo_1:e:0": "foo_1:1",
+ "list_output:a:0": "list_output:0",
+ "list_output:a:1": "list_output:1",
+ "foo_2:d:0": "foo_2:0",
+ "foo_2:e:0": "foo_2:1",
+ })
+
+ def testShapes(self):
+ fdef = self._build_function_def()
+ g, _ = function_def_to_graph.function_def_to_graph_def(
+ fdef,
+ input_shapes=[tensor_shape.scalar(),
+ tensor_shape.vector(5), None])
+ self.assertEqual("shape" in g.node[0].attr, True)
+ self.assertSequenceEqual(
+ tensor_shape.TensorShape(g.node[0].attr["shape"].shape).as_list(), [])
+ self.assertEqual(g.node[0].attr["shape"].shape.unknown_rank, False)
+ self.assertEqual("shape" in g.node[1].attr, True)
+ self.assertSequenceEqual(
+ tensor_shape.TensorShape(g.node[1].attr["shape"].shape).as_list(), [5])
+ self.assertEqual(g.node[0].attr["shape"].shape.unknown_rank, False)
+ self.assertFalse("shape" in g.node[2].attr)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 6b031fe99b..a19a72c881 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -59,11 +59,9 @@ from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-# Temporary global switch determining if we should enable the work-in-progress
-# calls to the C API. Currently disabled by default but can be manually enabled
-# in code or via the environment variable. This will be removed once all
-# functionality is supported and there's no performance penalty with it enabled.
-_USE_C_API = os.getenv("TF_C_API_GRAPH_CONSTRUCTION", "1") is not "0"
+# Temporary global switches determining if we should enable the work-in-progress
+# calls to the C API. These will be removed once all functionality is supported.
+_USE_C_API = True
_USE_C_SHAPES = os.getenv("TF_C_API_GRAPH_CONSTRUCTION_SHAPES", "0") is not "0"
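
Since `_USE_C_SHAPES` is still read from the environment when ops.py is
first loaded, the variable must be set before TensorFlow is imported; a
sketch:

    import os
    os.environ["TF_C_API_GRAPH_CONSTRUCTION_SHAPES"] = "1"
    import tensorflow  # ops.py reads the flag at first import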
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 2d6925d1a8..af5d709f7e 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -1389,7 +1389,7 @@ class LayoutOptimizerTest(test.TestCase):
expected_num_transposes = 3
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_nhwc_to_nchw('map/while/Conv2D-0', nodes)
- self._assert_trans_nchw_to_nhwc('map/while/Add-0-2', nodes)
+ self._assert_trans_nchw_to_nhwc('map/while/Add_1-0-2', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
def testLoopWithVecAnd4D(self):
@@ -1413,7 +1413,7 @@ class LayoutOptimizerTest(test.TestCase):
expected_num_transposes = 2
self.assertEqual(expected_num_transposes, num_transposes)
self._assert_trans_nhwc_to_nchw('map/while/Conv2D-0', nodes)
- self._assert_trans_nchw_to_nhwc('map/while/Add-0-2', nodes)
+ self._assert_trans_nchw_to_nhwc('map/while/Add_1-0-2', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
def testBinaryOpSecondPort(self):
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index a6b5940e2f..9dbf94a276 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -36,9 +36,10 @@ from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import saving
from tensorflow.python.keras.utils import generic_utils
+from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
-from tensorflow.python.keras.utils.layer_utils import print_summary as print_layer_summary
+from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.checkpointable import data_structures_base
@@ -94,6 +95,11 @@ class Network(base_layer.Layer):
self.trainable = True
self._is_compiled = False
self._expects_training_arg = False
+ # A list of "extra" variables assigned to attributes of this class, included
+ # in self.weights and self.variables. Always empty for graph networks (but
+ # included in base_init to avoid excessive special casing when retrieving
+ # the value).
+ self._extra_variables = []
self.supports_masking = False
if not hasattr(self, 'optimizer'):
@@ -347,11 +353,22 @@ class Network(base_layer.Layer):
# layers). Therefore Model tracks Checkpointable objects itself.
self._track_checkpointable(
checkpointable=value, name=name, overwrite=True)
+ if ( # For subclassed models only, users may add extra weights/variables
+ # simply by assigning them to attributes.
+ not self._is_graph_network
+ and isinstance(value, variables.Variable)):
+ self._extra_variables.append(value)
super(Network, self).__setattr__(name, value)
def add_variable(self, name, shape, dtype=None, initializer=None,
regularizer=None, trainable=True, constraint=None):
- raise NotImplementedError('`add_variable` is not supported on Networks.')
+ if self._is_graph_network:
+ raise NotImplementedError('`add_variable` is not supported on Networks.')
+ else:
+ raise NotImplementedError(
+ '`add_variable` is not supported on Networks. However, you may '
+ 'assign variables to attributes and they will show up in the weights '
+ 'and variables properties.')
def add_loss(self, *args, **kwargs):
if context.executing_eagerly():
@@ -589,24 +606,17 @@ class Network(base_layer.Layer):
@property
def trainable_weights(self):
- if not self.trainable:
- return []
- weights = []
- for layer in self.layers:
- weights += layer.trainable_weights
- return weights
+ return layer_utils.gather_trainable_weights(
+ trainable=self.trainable,
+ sub_layers=self.layers,
+ extra_variables=self._extra_variables)
@property
def non_trainable_weights(self):
- weights = []
- for layer in self.layers:
- weights += layer.non_trainable_weights
- if not self.trainable:
- trainable_weights = []
- for layer in self.layers:
- trainable_weights += layer.trainable_weights
- return trainable_weights + weights
- return weights
+ return layer_utils.gather_non_trainable_weights(
+ trainable=self.trainable,
+ sub_layers=self.layers,
+ extra_variables=self._extra_variables)
@property
def input_spec(self):
@@ -1437,10 +1447,10 @@ class Network(base_layer.Layer):
'have not yet been created, so no summary can be '
'displayed. Build the model first '
'(e.g. by calling it on some data).')
- print_layer_summary(self,
- line_length=line_length,
- positions=positions,
- print_fn=print_fn)
+ layer_utils.print_summary(self,
+ line_length=line_length,
+ positions=positions,
+ print_fn=print_fn)
def get_source_inputs(tensor, layer=None, node_index=None):
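
With this change, a subclassed model picks up extra variables assigned
directly to attributes; a minimal sketch mirroring the new test below (the
class and attribute names are illustrative, not part of this change):

    from tensorflow.python import keras
    from tensorflow.python.ops import resource_variable_ops

    class TinyModel(keras.Model):

      def __init__(self):
        super(TinyModel, self).__init__()
        self.dense = keras.layers.Dense(1)
        # Collected via _extra_variables on assignment.
        self.scale = resource_variable_ops.ResourceVariable(2.0)

      def call(self, inputs):
        return self.dense(inputs * self.scale)

    m = TinyModel()
    # m.scale now appears in m.variables and m.trainable_variables.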
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index c0dc5220f1..7743d00c0f 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -574,28 +574,26 @@ class BatchNormalization(Layer):
lambda: variance,
lambda: moving_variance)
+ if self.virtual_batch_size is not None:
+ # This isn't strictly correct since in ghost batch norm, you are
+ # supposed to sequentially update the moving_mean and moving_variance
+ # with each sub-batch. However, since the moving statistics are only
+ # used during evaluation, it is more efficient to just update in one
+ # step and should not make a significant difference in the result.
+ new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
+ new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True)
+ else:
+ new_mean, new_variance = mean, variance
+
if self.renorm:
r, d, new_mean, new_variance = self._renorm_correction_and_moments(
- mean, variance, training)
+ new_mean, new_variance, training)
# When training, the normalized values (say, x) will be transformed as
# x * gamma + beta without renorm, and (x * r + d) * gamma + beta
# = x * (r * gamma) + (d * gamma + beta) with renorm.
r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
scale, offset = _compose_transforms(r, d, scale, offset)
- else:
- new_mean, new_variance = mean, variance
-
- if self.virtual_batch_size is not None:
- # This isn't strictly correct since in ghost batch norm, you are
- # supposed to sequentially update the moving_mean and moving_variance
- # with each sub-batch. However, since the moving statistics are only
- # used during evaluation, it is more efficient to just update in one
- # step and should not make a significant difference in the result.
- new_mean = math_ops.reduce_mean(new_mean,
- axis=1, keepdims=True)
- new_variance = math_ops.reduce_mean(new_variance,
- axis=1, keepdims=True)
def _do_update(var, value):
if in_eager_mode and not self.trainable:
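
Moving the virtual-batch averaging ahead of the renorm correction means
renorm now operates on the already-averaged statistics. The averaging
itself is a plain reduction over the ghost-batch axis; a numpy sketch with
illustrative values:

    import numpy as np

    # Per-sub-batch means for 4 ghost batches, shape [1, 4].
    sub_batch_means = np.array([[0.9, 1.1, 1.0, 1.2]])
    new_mean = sub_batch_means.mean(axis=1, keepdims=True)  # [[1.05]]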
diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py
index 558854ab97..86f7e20bec 100644
--- a/tensorflow/python/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/model_subclassing_test.py
@@ -622,6 +622,51 @@ class ModelSubclassingTest(test.TestCase):
self.assertIs(m.isdep, m._checkpoint_dependencies[0].ref)
self.assertEqual('notdep_var:0', m.notdep_var.name)
+ def test_extra_variable(self):
+
+ class ExtraVar(keras.Model):
+
+ def __init__(self):
+ super(ExtraVar, self).__init__()
+ self.dense = keras.layers.Dense(1)
+ self.var = resource_variable_ops.ResourceVariable(1.)
+ self.not_trainable_var = resource_variable_ops.ResourceVariable(
+ 2., trainable=False)
+
+ def call(self, inputs):
+ return self.dense(inputs + self.var)
+
+ m = ExtraVar()
+ self.assertTrue(m.trainable)
+ self.assertEqual([m.dense], m.layers)
+ self.assertEqual([m.var, m.not_trainable_var], m.variables)
+ self.assertEqual([m.var], m.trainable_variables)
+ self.assertEqual([m.not_trainable_var], m.non_trainable_variables)
+ m.trainable = False
+ self.assertEqual([m.var, m.not_trainable_var], m.variables)
+ self.assertEqual([], m.trainable_variables)
+ self.assertEqual([m.var, m.not_trainable_var], m.non_trainable_variables)
+ m.trainable = True
+
+ m(array_ops.ones([1, 1]))
+
+ self.assertEqual([m.dense.kernel, m.dense.bias], m.dense.variables)
+ self.assertEqual([m.dense.kernel, m.dense.bias], m.dense.weights)
+
+ self.assertEqual([m.dense.kernel, m.dense.bias, m.var, m.not_trainable_var],
+ m.variables)
+ self.assertEqual([m.dense.kernel, m.dense.bias, m.var],
+ m.trainable_variables)
+ self.assertEqual([m.not_trainable_var], m.non_trainable_variables)
+
+ m.dense.trainable = False
+ self.assertEqual(
+ [m.var, m.dense.kernel, m.dense.bias, m.not_trainable_var],
+ m.variables)
+ self.assertEqual([m.var], m.trainable_variables)
+ self.assertEqual([m.dense.kernel, m.dense.bias, m.not_trainable_var],
+ m.non_trainable_variables)
+
class CustomCallModel(keras.Model):
diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py
index 01fb41b8ee..c616d8f24f 100644
--- a/tensorflow/python/keras/models_test.py
+++ b/tensorflow/python/keras/models_test.py
@@ -18,10 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+
import numpy as np
from tensorflow.python import keras
+from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
+from tensorflow.python.training import adam
class TestModelCloning(test.TestCase):
@@ -123,5 +127,22 @@ class TestModelCloning(test.TestCase):
keras.models._clone_sequential_model(seq_model, input_tensors=y)
+class CheckpointingTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_optimizer_dependency(self):
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(1, input_shape=(4,)))
+ opt = adam.AdamOptimizer(0.01)
+ model.compile(optimizer=opt, loss='mse')
+ model.fit(x=np.array([[1., 2., 3., 4.]]), y=[1.], epochs=2)
+ save_prefix = os.path.join(self.get_temp_dir(), 'ckpt')
+ beta1_power, _ = opt._get_beta_accumulators()
+ self.evaluate(beta1_power.assign(12.))
+ model.save_weights(save_prefix)
+ self.evaluate(beta1_power.assign(13.))
+ model.load_weights(save_prefix)
+ self.assertEqual(12., self.evaluate(beta1_power))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/optimizers.py b/tensorflow/python/keras/optimizers.py
index febbda4df6..f58aeaea1a 100644
--- a/tensorflow/python/keras/optimizers.py
+++ b/tensorflow/python/keras/optimizers.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training import training_util
+from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util.tf_export import tf_export
@@ -718,7 +719,7 @@ class Nadam(Optimizer):
return dict(list(base_config.items()) + list(config.items()))
-class TFOptimizer(Optimizer):
+class TFOptimizer(Optimizer, checkpointable.Checkpointable):
"""Wrapper class for native TensorFlow optimizers.
"""
diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py
index bd61f8e9cc..88daff0461 100644
--- a/tensorflow/python/keras/utils/layer_utils.py
+++ b/tensorflow/python/keras/utils/layer_utils.py
@@ -201,6 +201,61 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
print_fn('_' * line_length)
+def gather_trainable_weights(trainable, sub_layers, extra_variables):
+ """Lists the trainable weights for an object with sub-layers.
+
+ Args:
+ trainable: Whether the object collecting the variables is trainable.
+ sub_layers: A flat list of Layer objects owned by this object, to collect
+ variables from.
+ extra_variables: Any extra variables to include. Their `.trainable` property
+ is used to categorize them.
+
+ Returns:
+ A list of collected trainable weights/variables.
+ """
+ if not trainable:
+ return []
+ weights = []
+ for layer in sub_layers:
+ weights += layer.trainable_weights
+ trainable_extra_variables = [
+ v for v in extra_variables if v.trainable]
+ return weights + trainable_extra_variables
+
+
+def gather_non_trainable_weights(trainable, sub_layers, extra_variables):
+ """Lists the non-trainable weights for an object with sub-layers.
+
+ Args:
+ trainable: Whether the object collecting the variables is trainable.
+ sub_layers: A flat list of Layer objects owned by this object, to collect
+ variables from.
+ extra_variables: Any extra variables to include. Their `.trainable` property
+ is used to categorize them.
+
+ Returns:
+ A list of collected non-trainable weights/variables.
+ """
+ trainable_extra_variables = []
+ non_trainable_extra_variables = []
+ for v in extra_variables:
+ if v.trainable:
+ trainable_extra_variables.append(v)
+ else:
+ non_trainable_extra_variables.append(v)
+ weights = []
+ for layer in sub_layers:
+ weights += layer.non_trainable_weights
+ if not trainable:
+ trainable_weights = []
+ for layer in sub_layers:
+ trainable_weights += layer.trainable_weights
+ return (trainable_weights + trainable_extra_variables
+ + weights + non_trainable_extra_variables)
+ return weights + non_trainable_extra_variables
+
+
@tf_export('keras.utils.convert_all_kernels_in_model')
def convert_all_kernels_in_model(model):
"""Converts all convolution kernels in a model from Theano to TensorFlow.
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 3dfad9c130..5d29c2e5f8 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -9,6 +9,7 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "sycl_py_test")
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
# CPU only tests should use tf_py_test, GPU tests use cuda_py_test
# Please avoid the py_tests and cuda_py_tests (plural) while we
@@ -3029,3 +3030,60 @@ tf_py_test(
"//tensorflow/python/eager:tape",
],
)
+
+# Custom op tests
+tf_custom_op_library(
+ name = "ackermann_op.so",
+ srcs = ["ackermann_op.cc"],
+)
+
+tf_py_test(
+ name = "ackermann_test",
+ size = "small",
+ srcs = ["ackermann_test.py"],
+ additional_deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:platform",
+ ],
+ data = [":ackermann_op.so"],
+ tags = ["no_pip"],
+)
+
+tf_custom_op_library(
+ name = "duplicate_op.so",
+ srcs = ["duplicate_op.cc"],
+)
+
+tf_py_test(
+ name = "duplicate_op_test",
+ size = "small",
+ srcs = ["duplicate_op_test.py"],
+ additional_deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ ],
+ data = [":duplicate_op.so"],
+ tags = ["no_pip"],
+)
+
+tf_custom_op_library(
+ name = "invalid_op.so",
+ srcs = ["invalid_op.cc"],
+)
+
+tf_py_test(
+ name = "invalid_op_test",
+ size = "small",
+ srcs = ["invalid_op_test.py"],
+ additional_deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:platform",
+ ],
+ data = [":invalid_op.so"],
+ tags = ["no_pip"],
+)
diff --git a/tensorflow/user_ops/ackermann_op.cc b/tensorflow/python/kernel_tests/ackermann_op.cc
index d42ca6f662..d42ca6f662 100644
--- a/tensorflow/user_ops/ackermann_op.cc
+++ b/tensorflow/python/kernel_tests/ackermann_op.cc
diff --git a/tensorflow/user_ops/ackermann_test.py b/tensorflow/python/kernel_tests/ackermann_test.py
index 257de49808..5e0d87c783 100644
--- a/tensorflow/user_ops/ackermann_test.py
+++ b/tensorflow/python/kernel_tests/ackermann_test.py
@@ -17,17 +17,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os.path
+import os
-import tensorflow as tf
+from tensorflow.python.framework import load_library
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import test
-class AckermannTest(tf.test.TestCase):
+class AckermannTest(test.TestCase):
def testBasic(self):
- library_filename = os.path.join(tf.resource_loader.get_data_files_path(),
+ library_filename = os.path.join(resource_loader.get_data_files_path(),
'ackermann_op.so')
- ackermann = tf.load_op_library(library_filename)
+ ackermann = load_library.load_op_library(library_filename)
self.assertEqual(len(ackermann.OP_LIST.op), 1)
self.assertEqual(ackermann.OP_LIST.op[0].name, 'Ackermann')
@@ -37,4 +39,4 @@ class AckermannTest(tf.test.TestCase):
if __name__ == '__main__':
- tf.test.main()
+ test.main()
diff --git a/tensorflow/user_ops/duplicate_op.cc b/tensorflow/python/kernel_tests/duplicate_op.cc
index 9f622e4db5..9f622e4db5 100644
--- a/tensorflow/user_ops/duplicate_op.cc
+++ b/tensorflow/python/kernel_tests/duplicate_op.cc
diff --git a/tensorflow/user_ops/duplicate_op_test.py b/tensorflow/python/kernel_tests/duplicate_op_test.py
index b61e68d75e..529d3dd0b3 100644
--- a/tensorflow/user_ops/duplicate_op_test.py
+++ b/tensorflow/python/kernel_tests/duplicate_op_test.py
@@ -17,23 +17,26 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os.path
+import os
-import tensorflow as tf
+from tensorflow.python.framework import load_library
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import test
-class DuplicateOpTest(tf.test.TestCase):
+class DuplicateOpTest(test.TestCase):
def testBasic(self):
- library_filename = os.path.join(tf.resource_loader.get_data_files_path(),
+ library_filename = os.path.join(resource_loader.get_data_files_path(),
'duplicate_op.so')
- duplicate = tf.load_op_library(library_filename)
+ duplicate = load_library.load_op_library(library_filename)
self.assertEqual(len(duplicate.OP_LIST.op), 0)
with self.test_session():
- self.assertEqual(tf.add(1, 41).eval(), 42)
+ self.assertEqual(math_ops.add(1, 41).eval(), 42)
if __name__ == '__main__':
- tf.test.main()
+ test.main()
diff --git a/tensorflow/python/kernel_tests/inplace_ops_test.py b/tensorflow/python/kernel_tests/inplace_ops_test.py
index 0f95e13187..6e894365af 100644
--- a/tensorflow/python/kernel_tests/inplace_ops_test.py
+++ b/tensorflow/python/kernel_tests/inplace_ops_test.py
@@ -166,7 +166,8 @@ class InplaceOpsTest(test_util.TensorFlowTestCase):
def testEmpty(self):
for dtype in [
- dtypes.float32, dtypes.float64, dtypes.int32, dtypes.int64, dtypes.bool
+ dtypes.float32, dtypes.float64, dtypes.int32, dtypes.int64, dtypes.bool,
+ dtypes.uint8
]:
with self.test_session(use_gpu=True):
test_shapes = [(), (1,), (2, 3), (0, 2), (2, 3, 5), (2, 0, 5)]
@@ -187,11 +188,12 @@ class InplaceOpsTest(test_util.TensorFlowTestCase):
self.assertEqual(val.dtype, dtype.as_numpy_dtype)
self.assertAllEqual(val, np.zeros(shape, dtype.as_numpy_dtype))
- val = inplace_ops.empty((1, 2), dtypes.string, init=True).eval()
- self.assertEqual(val.tolist(), [[b"", b""]])
+ with self.test_session(use_gpu=True):
+ val = inplace_ops.empty((1, 2), dtypes.string, init=True).eval()
+ self.assertEqual(val.tolist(), [[b"", b""]])
- val = inplace_ops.empty((1, 2), dtypes.string, init=False).eval()
- self.assertEqual(val.tolist(), [[b"", b""]])
+ val = inplace_ops.empty((1, 2), dtypes.string, init=False).eval()
+ self.assertEqual(val.tolist(), [[b"", b""]])
if __name__ == "__main__":
diff --git a/tensorflow/user_ops/invalid_op.cc b/tensorflow/python/kernel_tests/invalid_op.cc
index 51431660f2..51431660f2 100644
--- a/tensorflow/user_ops/invalid_op.cc
+++ b/tensorflow/python/kernel_tests/invalid_op.cc
diff --git a/tensorflow/user_ops/invalid_op_test.py b/tensorflow/python/kernel_tests/invalid_op_test.py
index c90a00ce58..238299a895 100644
--- a/tensorflow/user_ops/invalid_op_test.py
+++ b/tensorflow/python/kernel_tests/invalid_op_test.py
@@ -17,19 +17,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os.path
+import os
-import tensorflow as tf
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import load_library
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import test
-class InvalidOpTest(tf.test.TestCase):
+class InvalidOpTest(test.TestCase):
def testBasic(self):
- library_filename = os.path.join(tf.resource_loader.get_data_files_path(),
+ library_filename = os.path.join(resource_loader.get_data_files_path(),
'invalid_op.so')
- with self.assertRaises(tf.errors.InvalidArgumentError):
- tf.load_op_library(library_filename)
+ with self.assertRaises(errors.InvalidArgumentError):
+ load_library.load_op_library(library_filename)
if __name__ == '__main__':
- tf.test.main()
+ test.main()
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 846231fe81..00d517e64e 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -119,6 +119,13 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
dtype=dtypes.int32, shape=[1], name="foo")
self.assertGreater(len(handle.eval()), 0)
+ def testCachedValueReadBeforeWrite(self):
+ with self.test_session() as sess:
+ v = resource_variable_ops.ResourceVariable(0.0, caching_device="cpu:0")
+ sess.run(v.initializer)
+ value, _ = sess.run([v, v.assign_add(1.0)])
+ self.assertAllEqual(value, 0.0)
+
def testAssignVariableDtypeMismatchEager(self):
with context.eager_mode():
handle = resource_variable_ops.var_handle_op(
@@ -531,6 +538,25 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
sess.run(v.initialized_value())
+ def testTrainableInProto(self):
+ with ops.Graph().as_default():
+ non_trainable_variable = resource_variable_ops.ResourceVariable(
+ trainable=False,
+ initial_value=constant_op.constant(10.0))
+ self.assertEqual(
+ False,
+ resource_variable_ops.ResourceVariable(
+ variable_def=non_trainable_variable.to_proto())
+ .trainable)
+ trainable_variable = resource_variable_ops.ResourceVariable(
+ trainable=True,
+ initial_value=constant_op.constant(10.0))
+ self.assertEqual(
+ True,
+ resource_variable_ops.ResourceVariable(
+ variable_def=trainable_variable.to_proto())
+ .trainable)
+
@test_util.run_in_graph_and_eager_modes()
def testSparseRead(self):
with self.test_session():
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 27599868b7..62d596da91 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -496,6 +496,23 @@ class VariablesTestCase(test.TestCase):
with self.assertRaises(ValueError):
sess.run(v.initialized_value())
+ def testTrainableInProto(self):
+ with ops.Graph().as_default():
+ non_trainable_variable = variables.Variable(
+ trainable=False,
+ initial_value=constant_op.constant(10.0))
+ self.assertEqual(
+ False,
+ variables.Variable(variable_def=non_trainable_variable.to_proto())
+ .trainable)
+ trainable_variable = variables.Variable(
+ trainable=True,
+ initial_value=constant_op.constant(10.0))
+ self.assertEqual(
+ True,
+ variables.Variable(variable_def=trainable_variable.to_proto())
+ .trainable)
+
def testLoad(self):
with self.test_session():
var = variables.Variable(np.zeros((5, 5), np.float32))
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index 394ad0b1a2..30413f289a 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -455,7 +455,8 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
lambda i, _: i < n, compute, (i, accs_ta),
parallel_iterations=parallel_iterations,
back_prop=back_prop,
- swap_memory=swap_memory)
+ swap_memory=swap_memory,
+ maximum_iterations=n)
results_flat = [r.stack() for r in r_a]
n_static = elems_flat[0].get_shape().with_rank_at_least(1)[0]
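The new `maximum_iterations=n` argument gives the underlying `while_loop` a declared bound on its trip count in addition to the data-dependent condition. A minimal standalone sketch of the same pattern:

    from tensorflow.python.framework import constant_op
    from tensorflow.python.ops import control_flow_ops

    n = constant_op.constant(3)
    # Counts 0, 1, 2; the loop also carries a static iteration bound.
    result = control_flow_ops.while_loop(
        lambda i: i < n,
        lambda i: i + 1,
        [constant_op.constant(0)],
        maximum_iterations=n)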
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index e5b80200c0..7061b32808 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -551,6 +551,7 @@ class ResourceVariable(variables.Variable):
import_scope=import_scope))
else:
self._initial_value = None
+ self._trainable = getattr(variable_def, "trainable", True)
if variable_def.snapshot_name:
snapshot = g.as_graph_element(
ops.prepend_name_scope(
@@ -576,6 +577,21 @@ class ResourceVariable(variables.Variable):
self._constraint = None
self._cached_shape_as_list = None
+ @contextlib.contextmanager
+ def _assign_dependencies(self):
+ """Makes assignments depend on the cached value, if any.
+
+ This prevents undefined behavior with reads not ordered with respect to writes.
+
+ Yields:
+ None.
+ """
+ if self._cached_value is not None:
+ with ops.control_dependencies([self._cached_value]):
+ yield
+ else:
+ yield
+
def __nonzero__(self):
return self.__bool__()
@@ -720,7 +736,7 @@ class ResourceVariable(variables.Variable):
return self._save_slice_info
def _read_variable_op(self):
- if hasattr(self, "_trainable") and self._trainable:
+ if self.trainable:
tape.watch_variable(self)
return gen_resource_variable_ops.read_variable_op(self._handle,
self._dtype)
@@ -745,7 +761,7 @@ class ResourceVariable(variables.Variable):
def sparse_read(self, indices, name=None):
"""Reads the value of this variable sparsely, using `gather`."""
with ops.name_scope("Gather" if name is None else name) as name:
- if self._trainable:
+ if self.trainable:
tape.watch_variable(self)
value = gen_resource_variable_ops.resource_gather(
self._handle, indices, dtype=self._dtype, name=name)
@@ -786,6 +802,7 @@ class ResourceVariable(variables.Variable):
var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name,
export_scope)
var_def.is_resource = True
+ var_def.trainable = self.trainable
if self._save_slice_info:
var_def.save_slice_info_def.MergeFrom(
self._save_slice_info.to_proto(export_scope=export_scope))
@@ -865,7 +882,7 @@ class ResourceVariable(variables.Variable):
# TODO(apassos): this here and below is not atomic. Consider making it
# atomic if there's a way to do so without a performance cost for those who
# don't need it.
- with _handle_graph(self.handle):
+ with _handle_graph(self.handle), self._assign_dependencies():
assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(
self.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
name=name)
@@ -889,7 +906,7 @@ class ResourceVariable(variables.Variable):
it will return the `Operation` that does the assignment, and when in eager
mode it will return `None`.
"""
- with _handle_graph(self.handle):
+ with _handle_graph(self.handle), self._assign_dependencies():
assign_add_op = gen_resource_variable_ops.assign_add_variable_op(
self.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
name=name)
@@ -898,7 +915,7 @@ class ResourceVariable(variables.Variable):
return assign_add_op
def _lazy_read(self, op):
- if hasattr(self, "_trainable") and self._trainable:
+ if self.trainable:
tape.watch_variable(self)
return _UnreadVariable(
self._handle, self.dtype, self._shape, self._in_graph_mode,
@@ -921,6 +938,8 @@ class ResourceVariable(variables.Variable):
it will return the `Operation` that does the assignment, and when in eager
mode it will return `None`.
"""
+ # Note: not depending on the cached value here since this can be used to
+ # initialize the variable.
with _handle_graph(self.handle):
value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
self._shape.assert_is_compatible_with(value_tensor.shape)
@@ -933,7 +952,7 @@ class ResourceVariable(variables.Variable):
def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
end_mask, ellipsis_mask, new_axis_mask,
shrink_axis_mask):
- with _handle_graph(self.handle):
+ with _handle_graph(self.handle), self._assign_dependencies():
return self._lazy_read(
gen_array_ops.resource_strided_slice_assign(
ref=self.handle,
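The ordering device in `_assign_dependencies` is just a `control_dependencies` edge from the cached read to any assignment built inside the block; the `testCachedValueReadBeforeWrite` case added earlier in this patch observes the effect. A standalone sketch of the pattern (not the ResourceVariable internals):

    import contextlib

    from tensorflow.python.framework import ops

    @contextlib.contextmanager
    def ordered_after(cached_value):
      # If a cached read exists, ops created inside the block are forced
      # to run after it; otherwise the manager is a no-op.
      if cached_value is not None:
        with ops.control_dependencies([cached_value]):
          yield
      else:
        yield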
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 8d93d24b14..fa34774622 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -1261,13 +1261,13 @@ class EagerVariableStore(object):
def trainable_variables(self):
# pylint: disable=protected-access
- return sorted([x for x in self._store._vars.values() if x._trainable],
+ return sorted([x for x in self._store._vars.values() if x.trainable],
key=lambda x: x.name)
# pylint: enable=protected-access
def non_trainable_variables(self):
# pylint: disable=protected-access
- return sorted([x for x in self._store._vars.values() if not x._trainable],
+ return sorted([x for x in self._store._vars.values() if not x.trainable],
key=lambda x: x.name)
# pylint: enable=protected-access
@@ -1296,7 +1296,7 @@ class EagerVariableStore(object):
new_var = resource_variable_ops.ResourceVariable(
var.read_value(),
name=stripped_var_name,
- trainable=var._trainable)
+ trainable=var.trainable)
new_store._store._vars[key] = new_var
return new_store
# pylint: enable=protected-access
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index d88fd836f5..4be9f5eb68 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -341,6 +341,7 @@ class Variable(checkpointable.CheckpointableBase):
self._update_uid = initial_value.checkpoint_position.restore_uid
initial_value = initial_value.wrapped_value
+ self._trainable = trainable
if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
with ops.init_scope():
@@ -450,6 +451,7 @@ class Variable(checkpointable.CheckpointableBase):
import_scope=import_scope))
else:
self._initial_value = None
+ self._trainable = getattr(variable_def, "trainable", True)
self._snapshot = g.as_graph_element(
ops.prepend_name_scope(variable_def.snapshot_name,
import_scope=import_scope))
@@ -543,6 +545,10 @@ class Variable(checkpointable.CheckpointableBase):
self._ref().set_shape(shape)
self.value().set_shape(shape)
+ @property
+ def trainable(self):
+ return self._trainable
+
def eval(self, session=None):
"""In a session, computes and returns the value of this variable.
@@ -1050,6 +1056,7 @@ class Variable(checkpointable.CheckpointableBase):
# For backwards compatibility.
var_def.initial_value_name = ops.strip_name_scope(
self._initial_value.name, export_scope)
+ var_def.trainable = self.trainable
var_def.initializer_name = ops.strip_name_scope(
self.initializer.name, export_scope)
var_def.snapshot_name = ops.strip_name_scope(
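Together with the `ResourceVariable` change above, these hunks round-trip `trainable` through `VariableDef`, which the new `testTrainableInProto` cases exercise. Condensed into a sketch (graph mode assumed):

    from tensorflow.python.framework import ops
    from tensorflow.python.ops import variables

    with ops.Graph().as_default():
      v = variables.Variable(10.0, trainable=False)
      restored = variables.Variable(variable_def=v.to_proto())
      assert not restored.trainable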
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py
index 62cefa4f20..69ed253fb2 100644
--- a/tensorflow/python/training/checkpointable/data_structures.py
+++ b/tensorflow/python/training/checkpointable/data_structures.py
@@ -22,6 +22,8 @@ import collections
import six
from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.utils import layer_utils
+from tensorflow.python.ops import variables
from tensorflow.python.training.checkpointable import base as checkpointable_lib
from tensorflow.python.training.checkpointable import data_structures_base
@@ -41,11 +43,14 @@ class CheckpointableDataStructure(
def __init__(self):
self._layers = []
self.trainable = True
+ self._extra_variables = []
def _track_value(self, value, name):
"""Add a dependency on `value`."""
if isinstance(value, checkpointable_lib.CheckpointableBase):
self._track_checkpointable(value, name=name)
+ if isinstance(value, variables.Variable):
+ self._extra_variables.append(value)
else:
raise ValueError(
("Only checkpointable objects (such as Layers or Optimizers) may be "
@@ -67,30 +72,31 @@ class CheckpointableDataStructure(
@property
def trainable_weights(self):
- if not self.trainable:
- return []
- weights = []
- for layer in self.layers:
- weights += layer.trainable_weights
- return weights
+ return layer_utils.gather_trainable_weights(
+ trainable=self.trainable,
+ sub_layers=self.layers,
+ extra_variables=self._extra_variables)
@property
def non_trainable_weights(self):
- weights = []
- for layer in self.layers:
- weights += layer.non_trainable_weights
- if not self.trainable:
- trainable_weights = []
- for layer in self.layers:
- trainable_weights += layer.trainable_weights
- return trainable_weights + weights
- return weights
+ return layer_utils.gather_non_trainable_weights(
+ trainable=self.trainable,
+ sub_layers=self.layers,
+ extra_variables=self._extra_variables)
@property
def weights(self):
return self.trainable_weights + self.non_trainable_weights
@property
+ def trainable_variables(self):
+ return self.trainable_weights
+
+ @property
+ def non_trainable_variables(self):
+ return self.non_trainable_weights
+
+ @property
def variables(self):
return self.weights
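Condensing the `List` test added below: variables tracked directly, rather than through a `Layer`, now land in `_extra_variables` and are categorized by their own `.trainable` flag, with the new `trainable_variables`/`non_trainable_variables` properties aliasing the corresponding `*_weights` lists. A minimal sketch:

    from tensorflow.python.ops import resource_variable_ops
    from tensorflow.python.training.checkpointable import data_structures

    v = resource_variable_ops.ResourceVariable(1., trainable=False)
    l = data_structures.List([v])
    assert l.non_trainable_variables == [v]
    assert l.trainable_variables == []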
diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py
index 31a0e8b622..b05b3a8800 100644
--- a/tensorflow/python/training/checkpointable/data_structures_test.py
+++ b/tensorflow/python/training/checkpointable/data_structures_test.py
@@ -139,6 +139,25 @@ class ListTests(test.TestCase):
outer.variables[0],
resource_variable_ops.ResourceVariable)
+ def testNonLayerVariables(self):
+ v = resource_variable_ops.ResourceVariable([1.])
+ l = data_structures.List([v])
+ self.assertTrue(l.trainable)
+ self.assertEqual([], l.layers)
+ self.assertEqual([v], l.variables)
+ self.assertEqual([v], l.trainable_weights)
+ self.assertEqual([], l.non_trainable_variables)
+ l.trainable = False
+ self.assertEqual([v], l.variables)
+ self.assertEqual([], l.trainable_variables)
+ self.assertEqual([v], l.non_trainable_variables)
+ l.trainable = True
+ v2 = resource_variable_ops.ResourceVariable(1., trainable=False)
+ l.append(v2)
+ self.assertEqual([v, v2], l.weights)
+ self.assertEqual([v], l.trainable_weights)
+ self.assertEqual([v2], l.non_trainable_weights)
+
def testHashing(self):
has_sequences = set([data_structures.List(),
data_structures.List()])
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index 61fc828a84..60cc54c264 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -344,6 +344,11 @@ class ExponentialMovingAverage(object):
self._name = name
self._averages = {}
+ @property
+ def name(self):
+ """The name of this ExponentialMovingAverage object."""
+ return self._name
+
def apply(self, var_list=None):
"""Maintains moving averages of variables.
@@ -394,7 +399,7 @@ class ExponentialMovingAverage(object):
if isinstance(var, variables.Variable):
avg = slot_creator.create_slot(var,
var.initialized_value(),
- self._name,
+ self.name,
colocate_with_primary=True)
# NOTE(mrry): We only add `tf.Variable` objects to the
# `MOVING_AVERAGE_VARIABLES` collection.
@@ -402,7 +407,7 @@ class ExponentialMovingAverage(object):
else:
avg = slot_creator.create_zeros_slot(
var,
- self._name,
+ self.name,
colocate_with_primary=(var.op.type in ["Variable",
"VariableV2",
"VarHandleOp"]))
@@ -410,7 +415,7 @@ class ExponentialMovingAverage(object):
zero_debias_true.add(avg)
self._averages[var] = avg
- with ops.name_scope(self._name) as scope:
+ with ops.name_scope(self.name) as scope:
decay = ops.convert_to_tensor(self._decay, name="decay")
if self._num_updates is not None:
num_updates = math_ops.cast(self._num_updates,
@@ -462,7 +467,7 @@ class ExponentialMovingAverage(object):
if var in self._averages:
return self._averages[var].op.name
return ops.get_default_graph().unique_name(
- var.op.name + "/" + self._name, mark_as_used=False)
+ var.op.name + "/" + self.name, mark_as_used=False)
def variables_to_restore(self, moving_avg_variables=None):
"""Returns a map of names to `Variables` to restore.
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index 6717811bbb..3e85e6bfa7 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -263,6 +263,7 @@ class ExponentialMovingAverageTest(test.TestCase):
tensor2 = v0 + v1
ema = moving_averages.ExponentialMovingAverage(
0.25, zero_debias=zero_debias, name="foo")
+ self.assertEqual("foo", ema.name)
self.assertEqual("v0/foo", ema.average_name(v0))
self.assertEqual("v1/foo", ema.average_name(v1))
self.assertEqual("add/foo", ema.average_name(tensor2))
diff --git a/tensorflow/python/util/stat_summarizer.i b/tensorflow/python/util/stat_summarizer.i
index f423553faa..73fa85494b 100644
--- a/tensorflow/python/util/stat_summarizer.i
+++ b/tensorflow/python/util/stat_summarizer.i
@@ -88,9 +88,4 @@ def NewStatSummarizer(unused):
def DeleteStatSummarizer(stat_summarizer):
_DeleteStatSummarizer(stat_summarizer)
-
-NewStatSummarizer._tf_api_names = ["contrib.stat_summarizer.NewStatSummarizer"]
-DeleteStatSummarizer._tf_api_names = [
- "contrib.stat_summarizer.DeleteStatSummarizer"]
-StatSummarizer._tf_api_names = ["contrib.stat_summarizer.StatSummarizer"]
%}
diff --git a/tensorflow/security/advisory/tfsa-2018-001.md b/tensorflow/security/advisory/tfsa-2018-001.md
index e62757fb5f..bb97543a21 100644
--- a/tensorflow/security/advisory/tfsa-2018-001.md
+++ b/tensorflow/security/advisory/tfsa-2018-001.md
@@ -21,8 +21,8 @@ TensorFlow 1.3.0, 1.3.1, 1.4.0, 1.4.1, 1.5.0, 1.5.1, 1.6.0
### Mitigation
-We have patched the vulnerability in GitHub commits
-[https://github.com/tensorflow/tensorflow/commit/49f73c55d56edffebde4bca4a407ad69c1cae4333c55](49f73c55).
+We have patched the vulnerability in GitHub commit
+[49f73c55](https://github.com/tensorflow/tensorflow/commit/49f73c55d56edffebde4bca4a407ad69c1cae4333c55).
If users are running TensorFlow in production or on untrusted data, they are
encouraged to apply this patch.
diff --git a/tensorflow/security/advisory/tfsa-2018-002.md b/tensorflow/security/advisory/tfsa-2018-002.md
index baf3fb418e..fad7fdd40f 100644
--- a/tensorflow/security/advisory/tfsa-2018-002.md
+++ b/tensorflow/security/advisory/tfsa-2018-002.md
@@ -21,7 +21,7 @@ TensorFlow 1.0.0, 1.0.1, 1.1.0, 1.2.0, 1.2.1, 1.3.0, 1.3.1, 1.4.0, 1.4.1, 1.5.0, 1.5.
### Mitigation
We have patched the vulnerability in GitHub commit
-[https://github.com/tensorflow/tensorflow/commit/c48431588e7cf8aff61d4c299231e3e925144df8](c4843158).
+[c4843158](https://github.com/tensorflow/tensorflow/commit/c48431588e7cf8aff61d4c299231e3e925144df8).
If users are running TensorFlow in production or on untrusted data, they are
encouraged to apply this patch.
diff --git a/tensorflow/security/advisory/tfsa-2018-003.md b/tensorflow/security/advisory/tfsa-2018-003.md
index e20e358f29..747d37064c 100644
--- a/tensorflow/security/advisory/tfsa-2018-003.md
+++ b/tensorflow/security/advisory/tfsa-2018-003.md
@@ -35,8 +35,8 @@ TensorFlow 1.5.0, 1.5.1, 1.6.0, 1.7.0
### Mitigation
-We have patched the vulnerability in GitHub commits [https://github.com/tensorflow/tensorflow/commit/41335abb46f80ca644b5738550daef6136ba5476](41335abb) and
-[https://github.com/tensorflow/tensorflow/commit/41335abb46f80ca644b5738550daef6136ba5476](41335abb) and
+We have patched the vulnerability in GitHub commits [41335abb](https://github.com/tensorflow/tensorflow/commit/41335abb46f80ca644b5738550daef6136ba5476) and
+[8badd11d](https://github.com/tensorflow/tensorflow/commit/8badd11d875a826bd318ed439909d5c47a7fb811).
If users are running the TensorFlow TFLite TOCO compiler in production or on
untrusted data, they are encouraged to apply this patch.
diff --git a/tensorflow/security/advisory/tfsa-2018-004.md b/tensorflow/security/advisory/tfsa-2018-004.md
index d172247288..3af28defa1 100644
--- a/tensorflow/security/advisory/tfsa-2018-004.md
+++ b/tensorflow/security/advisory/tfsa-2018-004.md
@@ -22,7 +22,7 @@ TensorFlow 1.0.0, 1.0.1, 1.1.0, 1.2.0, 1.2.1, 1.3.0, 1.3.1, 1.4.0, 1.4.1, 1.5.0,
### Mitigation
We have patched the vulnerability in GitHub commit
-[https://github.com/tensorflow/tensorflow/commit/d107fee1e4a9a4462f01564798d345802acc2aef](d107fee1).
+[d107fee1](https://github.com/tensorflow/tensorflow/commit/d107fee1e4a9a4462f01564798d345802acc2aef).
If users are running TensorFlow on untrusted meta checkpoints, such as those
downloaded from the Internet, in production or on untrusted data, they are
encouraged to apply this patch.
diff --git a/tensorflow/security/advisory/tfsa-2018-005.md b/tensorflow/security/advisory/tfsa-2018-005.md
index 1c91567db5..c0f339fd97 100644
--- a/tensorflow/security/advisory/tfsa-2018-005.md
+++ b/tensorflow/security/advisory/tfsa-2018-005.md
@@ -22,7 +22,7 @@ TensorFlow 1.1.0, 1.2.0, 1.2.1, 1.3.0, 1.3.1, 1.4.0, 1.4.1, 1.5.0, 1.5.1, 1.6.0,
### Mitigation
We have patched the vulnerability in GitHub commit
-[https://github.com/tensorflow/tensorflow/commit/dfa9921e6343727b05f42f8d4a918b19528ff994](dfa9921e)
+[dfa9921e](https://github.com/tensorflow/tensorflow/commit/dfa9921e6343727b05f42f8d4a918b19528ff994)
by upgrading the version of the snappy library used by TensorFlow to v1.1.7.
If users are loading untrusted checkpoints in TensorFlow, we encourage users to
diff --git a/tensorflow/security/advisory/tfsa-2018-006.md b/tensorflow/security/advisory/tfsa-2018-006.md
index a1d1a9f3d1..17f514d8d2 100644
--- a/tensorflow/security/advisory/tfsa-2018-006.md
+++ b/tensorflow/security/advisory/tfsa-2018-006.md
@@ -21,7 +21,7 @@ TensorFlow 1.1.0, 1.2.0, 1.2.1, 1.3.0, 1.3.1, 1.4.0, 1.4.1, 1.5.0, 1.5.1, 1.6.0,
### Mitigation
We have patched the vulnerability in GitHub commit
-[https://github.com/tensorflow/tensorflow/commit/c89ab82a82585cdaa90bf4911980e9e845909e78](c89ab82a).
+[c89ab82a](https://github.com/tensorflow/tensorflow/commit/c89ab82a82585cdaa90bf4911980e9e845909e78).
If users are loading untrusted configurations in TensorFlow, we encourage users
to apply the patch to upgrade snappy or upgrade the version of TensorFlow they
diff --git a/tensorflow/security/index.md b/tensorflow/security/index.md
index c1f9f1da74..44f51ad07b 100644
--- a/tensorflow/security/index.md
+++ b/tensorflow/security/index.md
@@ -8,11 +8,11 @@ in [https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md](SECURITY.m
| Advisory Number | Type | Versions affected | Reported by | Additional Information |
|-----------------|--------------------|:-----------------:|-----------------------|-----------------------------|
-| TFSA-2018-006 | Crafted Configuration File results in Invalid Memory Access | <= 1.7 | Blade Team of Tencent | |
-| TFSA-2018-005 | Old Snappy Library Usage Resulting in Memcpy Parameter Overlap | <= 1.7 | Blade Team of Tencent | |
-| TFSA-2018-004 | Checkpoint Meta File Out-of-Bounds Read | <= 1.7 | Blade Team of Tencent | |
-| TFSA-2018-003 | TensorFlow Lite TOCO FlatBuffer Parsing Vulnerability | <= 1.7 | Blade Team of Tencent | |
-| TFSA-2018-002 | GIF File Parsing Null Pointer Dereference Error | <= 1.5 | Blade Team of Tencent | |
-| TFSA-2018-001 | BMP File Parser Out-of-bounds Read | <= 1.6 | Blade Team of Tencent | |
+| [TFSA-2018-006](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-006.md) | Crafted Configuration File results in Invalid Memory Access | <= 1.7 | Blade Team of Tencent | |
+| [TFSA-2018-005](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-005.md) | Old Snappy Library Usage Resulting in Memcpy Parameter Overlap | <= 1.7 | Blade Team of Tencent | |
+| [TFSA-2018-004](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-004.md) | Checkpoint Meta File Out-of-Bounds Read | <= 1.7 | Blade Team of Tencent | |
+| [TFSA-2018-003](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-003.md) | TensorFlow Lite TOCO FlatBuffer Parsing Vulnerability | <= 1.7 | Blade Team of Tencent | |
+| [TFSA-2018-002](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-002.md) | GIF File Parsing Null Pointer Dereference Error | <= 1.5 | Blade Team of Tencent | |
+| [TFSA-2018-001](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-001.md) | BMP File Parser Out-of-bounds Read | <= 1.6 | Blade Team of Tencent | |
| - | Out Of Bounds Read | <=1.4 | Blade Team of Tencent | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) |
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index eed93efc8d..5315d1f3da 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -407,6 +407,8 @@ string FilterDescriptor::ToShortString() const {
switch (layout_) {
case FilterLayout::kOutputInputYX:
return port::StrCat(od, id, spatial);
+ case FilterLayout::kOutputYXInput:
+ return port::StrCat(od, spatial, id);
case FilterLayout::kOutputInputYX4:
return port::StrCat(od, id, spatial, "(VECT_C)");
case FilterLayout::kInputYXOutput:
diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD
index f46bb4b5fc..f0c5877a90 100644
--- a/tensorflow/tools/api/generator/BUILD
+++ b/tensorflow/tools/api/generator/BUILD
@@ -9,8 +9,9 @@ py_binary(
name = "create_python_api",
srcs = ["create_python_api.py"],
srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
deps = [
- "//tensorflow/python",
+ "//tensorflow/python:no_contrib",
],
)
@@ -23,116 +24,3 @@ py_test(
"//tensorflow/python:client_testlib",
],
)
-
-genrule(
- name = "python_api_gen",
- # List of API files. This list should include file name for
- # every module exported using tf_export. For e.g. if an op is decorated with
- # @tf_export('module1.module2', 'module3'). Then, outs should include
- # api/module1/module2/__init__.py and api/module3/__init__.py.
- # keep sorted
- outs = [
- # BEGIN GENERATED FILES
- "api/__init__.py",
- "api/app/__init__.py",
- "api/bitwise/__init__.py",
- "api/compat/__init__.py",
- "api/contrib/__init__.py",
- "api/contrib/stat_summarizer/__init__.py",
- "api/data/__init__.py",
- "api/distributions/__init__.py",
- "api/distributions/bijectors/__init__.py",
- "api/errors/__init__.py",
- "api/estimator/__init__.py",
- "api/estimator/export/__init__.py",
- "api/estimator/inputs/__init__.py",
- "api/feature_column/__init__.py",
- "api/gfile/__init__.py",
- "api/graph_util/__init__.py",
- "api/image/__init__.py",
- "api/initializers/__init__.py",
- "api/keras/__init__.py",
- "api/keras/activations/__init__.py",
- "api/keras/applications/__init__.py",
- "api/keras/applications/densenet/__init__.py",
- "api/keras/applications/inception_resnet_v2/__init__.py",
- "api/keras/applications/inception_v3/__init__.py",
- "api/keras/applications/mobilenet/__init__.py",
- "api/keras/applications/nasnet/__init__.py",
- "api/keras/applications/resnet50/__init__.py",
- "api/keras/applications/vgg16/__init__.py",
- "api/keras/applications/vgg19/__init__.py",
- "api/keras/applications/xception/__init__.py",
- "api/keras/backend/__init__.py",
- "api/keras/callbacks/__init__.py",
- "api/keras/constraints/__init__.py",
- "api/keras/datasets/__init__.py",
- "api/keras/datasets/boston_housing/__init__.py",
- "api/keras/datasets/cifar10/__init__.py",
- "api/keras/datasets/cifar100/__init__.py",
- "api/keras/datasets/fashion_mnist/__init__.py",
- "api/keras/datasets/imdb/__init__.py",
- "api/keras/datasets/mnist/__init__.py",
- "api/keras/datasets/reuters/__init__.py",
- "api/keras/estimator/__init__.py",
- "api/keras/initializers/__init__.py",
- "api/keras/layers/__init__.py",
- "api/keras/losses/__init__.py",
- "api/keras/metrics/__init__.py",
- "api/keras/models/__init__.py",
- "api/keras/optimizers/__init__.py",
- "api/keras/preprocessing/__init__.py",
- "api/keras/preprocessing/image/__init__.py",
- "api/keras/preprocessing/sequence/__init__.py",
- "api/keras/preprocessing/text/__init__.py",
- "api/keras/regularizers/__init__.py",
- "api/keras/utils/__init__.py",
- "api/keras/wrappers/__init__.py",
- "api/keras/wrappers/scikit_learn/__init__.py",
- "api/layers/__init__.py",
- "api/linalg/__init__.py",
- "api/logging/__init__.py",
- "api/losses/__init__.py",
- "api/manip/__init__.py",
- "api/math/__init__.py",
- "api/metrics/__init__.py",
- "api/nn/__init__.py",
- "api/nn/rnn_cell/__init__.py",
- "api/profiler/__init__.py",
- "api/python_io/__init__.py",
- "api/resource_loader/__init__.py",
- "api/strings/__init__.py",
- "api/saved_model/__init__.py",
- "api/saved_model/builder/__init__.py",
- "api/saved_model/constants/__init__.py",
- "api/saved_model/loader/__init__.py",
- "api/saved_model/main_op/__init__.py",
- "api/saved_model/signature_constants/__init__.py",
- "api/saved_model/signature_def_utils/__init__.py",
- "api/saved_model/tag_constants/__init__.py",
- "api/saved_model/utils/__init__.py",
- "api/sets/__init__.py",
- "api/sparse/__init__.py",
- "api/spectral/__init__.py",
- "api/summary/__init__.py",
- "api/sysconfig/__init__.py",
- "api/test/__init__.py",
- "api/train/__init__.py",
- "api/train/queue_runner/__init__.py",
- "api/user_ops/__init__.py",
- # END GENERATED FILES
- ],
- cmd = "$(location create_python_api) $(OUTS)",
- tools = ["create_python_api"],
-)
-
-py_library(
- name = "python_api",
- srcs = [":python_api_gen"],
- srcs_version = "PY2AND3",
- visibility = ["//tensorflow:__subpackages__"],
- deps = [
- "//tensorflow/contrib:contrib_py", # keep
- "//tensorflow/python", # keep
- ],
-)
diff --git a/tensorflow/tools/api/generator/api_gen.bzl b/tensorflow/tools/api/generator/api_gen.bzl
new file mode 100644
index 0000000000..fe3e4d1434
--- /dev/null
+++ b/tensorflow/tools/api/generator/api_gen.bzl
@@ -0,0 +1,125 @@
+"""Targets for generating TensorFlow Python API __init__.py files."""
+
+# keep sorted
+TENSORFLOW_API_INIT_FILES = [
+ # BEGIN GENERATED FILES
+ "__init__.py",
+ "app/__init__.py",
+ "bitwise/__init__.py",
+ "compat/__init__.py",
+ "data/__init__.py",
+ "distributions/__init__.py",
+ "distributions/bijectors/__init__.py",
+ "errors/__init__.py",
+ "estimator/__init__.py",
+ "estimator/export/__init__.py",
+ "estimator/inputs/__init__.py",
+ "feature_column/__init__.py",
+ "gfile/__init__.py",
+ "graph_util/__init__.py",
+ "image/__init__.py",
+ "initializers/__init__.py",
+ "keras/__init__.py",
+ "keras/activations/__init__.py",
+ "keras/applications/__init__.py",
+ "keras/applications/densenet/__init__.py",
+ "keras/applications/inception_resnet_v2/__init__.py",
+ "keras/applications/inception_v3/__init__.py",
+ "keras/applications/mobilenet/__init__.py",
+ "keras/applications/nasnet/__init__.py",
+ "keras/applications/resnet50/__init__.py",
+ "keras/applications/vgg16/__init__.py",
+ "keras/applications/vgg19/__init__.py",
+ "keras/applications/xception/__init__.py",
+ "keras/backend/__init__.py",
+ "keras/callbacks/__init__.py",
+ "keras/constraints/__init__.py",
+ "keras/datasets/__init__.py",
+ "keras/datasets/boston_housing/__init__.py",
+ "keras/datasets/cifar10/__init__.py",
+ "keras/datasets/cifar100/__init__.py",
+ "keras/datasets/fashion_mnist/__init__.py",
+ "keras/datasets/imdb/__init__.py",
+ "keras/datasets/mnist/__init__.py",
+ "keras/datasets/reuters/__init__.py",
+ "keras/estimator/__init__.py",
+ "keras/initializers/__init__.py",
+ "keras/layers/__init__.py",
+ "keras/losses/__init__.py",
+ "keras/metrics/__init__.py",
+ "keras/models/__init__.py",
+ "keras/optimizers/__init__.py",
+ "keras/preprocessing/__init__.py",
+ "keras/preprocessing/image/__init__.py",
+ "keras/preprocessing/sequence/__init__.py",
+ "keras/preprocessing/text/__init__.py",
+ "keras/regularizers/__init__.py",
+ "keras/utils/__init__.py",
+ "keras/wrappers/__init__.py",
+ "keras/wrappers/scikit_learn/__init__.py",
+ "layers/__init__.py",
+ "linalg/__init__.py",
+ "logging/__init__.py",
+ "losses/__init__.py",
+ "manip/__init__.py",
+ "math/__init__.py",
+ "metrics/__init__.py",
+ "nn/__init__.py",
+ "nn/rnn_cell/__init__.py",
+ "profiler/__init__.py",
+ "python_io/__init__.py",
+ "resource_loader/__init__.py",
+ "strings/__init__.py",
+ "saved_model/__init__.py",
+ "saved_model/builder/__init__.py",
+ "saved_model/constants/__init__.py",
+ "saved_model/loader/__init__.py",
+ "saved_model/main_op/__init__.py",
+ "saved_model/signature_constants/__init__.py",
+ "saved_model/signature_def_utils/__init__.py",
+ "saved_model/tag_constants/__init__.py",
+ "saved_model/utils/__init__.py",
+ "sets/__init__.py",
+ "sparse/__init__.py",
+ "spectral/__init__.py",
+ "summary/__init__.py",
+ "sysconfig/__init__.py",
+ "test/__init__.py",
+ "train/__init__.py",
+ "train/queue_runner/__init__.py",
+ "user_ops/__init__.py",
+ # END GENERATED FILES
+]
+
+# Creates a genrule that generates a directory structure with __init__.py
+# files that import all exported modules (i.e. modules with tf_export
+# decorators).
+#
+# Args:
+# name: name of genrule to create.
+# output_files: List of __init__.py files that should be generated.
+# This list should include file name for every module exported using
+ # tf_export. For example, if an op is decorated with
+# @tf_export('module1.module2', 'module3'). Then, output_files should
+# include module1/module2/__init__.py and module3/__init__.py.
+# root_init_template: Python init file that should be used as template for
+# root __init__.py file. "# API IMPORTS PLACEHOLDER" comment inside this
+# template will be replaced with root imports collected by this genrule.
+# srcs: genrule sources. If passing root_init_template, the template file
+# must be included in sources.
+def gen_api_init_files(name,
+ output_files=TENSORFLOW_API_INIT_FILES,
+ root_init_template=None,
+ srcs=[]):
+ root_init_template_flag = ""
+ if root_init_template:
+ root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
+ native.genrule(
+ name = name,
+ outs = output_files,
+ cmd = (
+ "$(location //tensorflow/tools/api/generator:create_python_api) " +
+ root_init_template_flag + " --apidir=$(@D) $(OUTS)"),
+ srcs = srcs,
+ tools = ["//tensorflow/tools/api/generator:create_python_api"],
+ )
diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/tools/api/generator/create_python_api.py
index 18182090da..9f210ad42b 100644
--- a/tensorflow/tools/api/generator/create_python_api.py
+++ b/tensorflow/tools/api/generator/create_python_api.py
@@ -29,9 +29,13 @@ from tensorflow.python.util import tf_decorator
_API_CONSTANTS_ATTR = '_tf_api_constants'
_API_NAMES_ATTR = '_tf_api_names'
-_API_DIR = '/api/'
_DEFAULT_PACKAGE = 'tensorflow.python'
-_OUTPUT_MODULE = 'tensorflow.tools.api.generator.api'
+_GENFILES_DIR_SUFFIX = 'genfiles/'
+_SYMBOLS_TO_SKIP_EXPLICITLY = {
+ # FLAGS overrides __getattr__, so unwrapping tf_decorator
+ # on it would have side effects.
+ 'tensorflow.python.platform.flags.FLAGS'
+}
_GENERATED_FILE_HEADER = """\"\"\"Imports for Python API.
This file is MACHINE GENERATED! Do not edit.
@@ -147,8 +151,8 @@ class _ModuleInitCodeBuilder(object):
# the script outputs.
module_text_map[''] = module_text_map.get('', '') + '''
_names_with_underscore = [%s]
-__all__ = [s for s in dir() if not s.startswith('_')]
-__all__.extend([s for s in _names_with_underscore])
+__all__ = [_s for _s in dir() if not _s.startswith('_')]
+__all__.extend([_s for _s in _names_with_underscore])
__all__.remove('print_function')
''' % underscore_names_str
@@ -182,6 +186,9 @@ def get_api_init_text(package):
continue
for module_contents_name in dir(module):
+ if (module.__name__ + '.' + module_contents_name
+ in _SYMBOLS_TO_SKIP_EXPLICITLY):
+ continue
attr = getattr(module, module_contents_name)
# If attr is _tf_api_constants attribute, then add the constants.
@@ -194,7 +201,11 @@ def get_api_init_text(package):
-1, dest_module, module.__name__, value, names[-1])
continue
- _, attr = tf_decorator.unwrap(attr)
+ try:
+ _, attr = tf_decorator.unwrap(attr)
+ except Exception as e:
+ print('Failed to unwrap tf_decorator on %s.%s' % (module.__name__, module_contents_name), file=sys.stderr)
+ raise e
# If attr is a symbol with _tf_api_names attribute, then
# add import for it.
if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__:
@@ -209,6 +220,7 @@ def get_api_init_text(package):
# For e.g. if we import 'foo.bar.Value'. Then, we also
# import 'bar' in 'foo'.
imported_modules = set(module_code_builder.module_imports.keys())
+ import_from = '.'
for module in imported_modules:
if not module:
continue
@@ -216,11 +228,9 @@ def get_api_init_text(package):
parent_module = '' # we import submodules in their parent_module
for submodule_index in range(len(module_split)):
- import_from = _OUTPUT_MODULE
if submodule_index > 0:
parent_module += ('.' + module_split[submodule_index-1] if parent_module
else module_split[submodule_index-1])
- import_from += '.' + parent_module
module_code_builder.add_import(
-1, parent_module, import_from,
module_split[submodule_index], module_split[submodule_index])
@@ -228,7 +238,24 @@ def get_api_init_text(package):
return module_code_builder.build()
-def create_api_files(output_files, package):
+def get_module(dir_path, relative_to_dir):
+ """Get module that corresponds to path relative to relative_to_dir.
+
+ Args:
+ dir_path: Path to directory.
+ relative_to_dir: Get module relative to this directory.
+
+ Returns:
+ module that corresponds to the given directory.
+ """
+ dir_path = dir_path[len(relative_to_dir):]
+ # Convert path separators to '/' for easier parsing below.
+ dir_path = dir_path.replace(os.sep, '/')
+ return dir_path.replace('/', '.').strip('.')
+
+
+def create_api_files(
+ output_files, package, root_init_template, output_dir):
"""Creates __init__.py files for the Python API.
Args:
@@ -236,6 +263,10 @@ def create_api_files(output_files, package):
Each file must be under api/ directory.
package: Base python package containing python with target tf_export
decorators.
+ root_init_template: Template for top-level __init__.py file.
+ "#API IMPORTS PLACEHOLDER" comment in the template file will be replaced
+ with imports.
+ output_dir: output API root directory.
Raises:
ValueError: if an output file is not under api/ directory,
@@ -243,18 +274,7 @@ def create_api_files(output_files, package):
"""
module_name_to_file_path = {}
for output_file in output_files:
- # Convert path separators to '/' for easier parsing below.
- normalized_output_file = output_file.replace(os.sep, '/')
- if _API_DIR not in output_file:
- raise ValueError(
- 'Output files must be in api/ directory, found %s.' % output_file)
- # Get the module name that corresponds to output_file.
- # First get module directory under _API_DIR.
- module_dir = os.path.dirname(
- normalized_output_file[
- normalized_output_file.rfind(_API_DIR)+len(_API_DIR):])
- # Convert / to .
- module_name = module_dir.replace('/', '.').strip('.')
+ module_name = get_module(os.path.dirname(output_file), output_dir)
module_name_to_file_path[module_name] = os.path.normpath(output_file)
# Create file for each expected output in genrule.
@@ -270,12 +290,20 @@ def create_api_files(output_files, package):
for module, text in module_text_map.items():
# Make sure genrule output file list is in sync with API exports.
if module not in module_name_to_file_path:
- module_file_path = '"api/%s/__init__.py"' % (
+ module_file_path = '"%s/__init__.py"' % (
module.replace('.', '/'))
missing_output_files.append(module_file_path)
continue
+ contents = ''
+ if module or not root_init_template:
+ contents = _GENERATED_FILE_HEADER + text + _GENERATED_FILE_FOOTER
+ else:
+ # Read base init file
+ with open(root_init_template, 'r') as root_init_template_file:
+ contents = root_init_template_file.read()
+ contents = contents.replace('# API IMPORTS PLACEHOLDER', text)
with open(module_name_to_file_path[module], 'w') as fp:
- fp.write(_GENERATED_FILE_HEADER + text + _GENERATED_FILE_FOOTER)
+ fp.write(contents)
if missing_output_files:
raise ValueError(
@@ -297,6 +325,16 @@ def main():
'--package', default=_DEFAULT_PACKAGE, type=str,
help='Base package that imports modules containing the target tf_export '
'decorators.')
+ parser.add_argument(
+ '--root_init_template', default='', type=str,
+ help='Template for top level __init__.py file. '
+ '"#API IMPORTS PLACEHOLDER" comment will be replaced with imports.')
+ parser.add_argument(
+ '--apidir', type=str, required=True,
+ help='Directory where generated output files are placed. '
+ 'gendir should be a prefix of apidir. Also, apidir '
+ 'should be a prefix of every directory in outputs.')
+
args = parser.parse_args()
if len(args.outputs) == 1:
@@ -309,7 +347,8 @@ def main():
# Populate `sys.modules` with modules containing tf_export().
importlib.import_module(args.package)
- create_api_files(outputs, args.package)
+ create_api_files(
+ outputs, args.package, args.root_init_template, args.apidir)
if __name__ == '__main__':
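The path handling above reduces to `get_module`, which strips the `--apidir` prefix and turns separators into dots. A small sketch with hypothetical paths:

    from tensorflow.tools.api.generator import create_python_api

    name = create_python_api.get_module('bazel-genfiles/api/keras/layers',
                                        'bazel-genfiles/api')
    assert name == 'keras.layers'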
diff --git a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt
index 8c8912dfab..23b552cc38 100644
--- a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt
@@ -43,6 +43,10 @@ tf_class {
name: "shape"
mtype: "<type \'property\'>"
}
+ member {
+ name: "trainable"
+ mtype: "<type \'property\'>"
+ }
member_method {
name: "__init__"
argspec: "args=[\'self\', \'initial_value\', \'trainable\', \'collections\', \'validate_shape\', \'caching_device\', \'name\', \'variable_def\', \'dtype\', \'expected_shape\', \'import_scope\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-exponential-moving-average.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-exponential-moving-average.pbtxt
index 737acbe07c..c9fe136e68 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-exponential-moving-average.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-exponential-moving-average.pbtxt
@@ -2,6 +2,10 @@ path: "tensorflow.train.ExponentialMovingAverage"
tf_class {
is_instance: "<class \'tensorflow.python.training.moving_averages.ExponentialMovingAverage\'>"
is_instance: "<type \'object\'>"
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
member_method {
name: "__init__"
argspec: "args=[\'self\', \'decay\', \'num_updates\', \'zero_debias\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'ExponentialMovingAverage\'], "
diff --git a/tensorflow/tools/api/lib/api_objects.proto b/tensorflow/tools/api/lib/api_objects.proto
index 7dcde0bbc3..7207b9c5a9 100644
--- a/tensorflow/tools/api/lib/api_objects.proto
+++ b/tensorflow/tools/api/lib/api_objects.proto
@@ -27,6 +27,10 @@ message TFAPIClass {
};
message TFAPIProto {
+ // Suppress generation of the proto API's descriptor() method lest it
+ // conflict with the standard accessor for the field having the same name.
+ option no_standard_descriptor_accessor = true;
+
optional google.protobuf.DescriptorProto descriptor = 1;
};
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 677ea65edd..e113565f45 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -173,9 +173,7 @@ sh_binary(
"//conditions:default": COMMON_PIP_DEPS + [
":simple_console",
"//tensorflow/contrib/lite/python:interpreter_test_data",
- "//tensorflow/contrib/lite/python:tf_lite_py_pip",
- "//tensorflow/contrib/lite/toco:toco",
- "//tensorflow/contrib/lite/toco/python:toco_wrapper",
+ "//tensorflow/contrib/lite/python:tflite_convert",
"//tensorflow/contrib/lite/toco/python:toco_from_protos",
],
}) + if_mkl(["//third_party/mkl:intel_binary_blob"]) + if_tensorrt([
diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh
index 41e714b1c1..f7e42ce536 100755
--- a/tensorflow/tools/pip_package/build_pip_package.sh
+++ b/tensorflow/tools/pip_package/build_pip_package.sh
@@ -112,9 +112,7 @@ function prepare_src() {
fi
mkdir "${TMPDIR}/tensorflow/aux-bin"
# Install toco as a binary in aux-bin.
- # TODO(aselle): Re-enable this when we find a way to do it without doubling
- # the whl size (over the limit).
- # cp bazel-bin/tensorflow/contrib/lite/toco/toco ${TMPDIR}/tensorflow/aux-bin/
+ cp bazel-bin/tensorflow/contrib/lite/python/tflite_convert ${TMPDIR}/tensorflow/aux-bin/
fi
# protobuf pip package doesn't ship with header files. Copy the headers
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 70e6662763..d25a9e77b1 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -95,7 +95,8 @@ if sys.version_info < (3, 4):
CONSOLE_SCRIPTS = [
'freeze_graph = tensorflow.python.tools.freeze_graph:run_main',
'toco_from_protos = tensorflow.contrib.lite.toco.python.toco_from_protos:main',
- 'toco = tensorflow.contrib.lite.toco.python.toco_wrapper:main',
+ 'tflite_convert = tensorflow.contrib.lite.python.tflite_convert:main',
+ 'toco = tensorflow.contrib.lite.python.tflite_convert:main',
'saved_model_cli = tensorflow.python.tools.saved_model_cli:main',
# We need to keep the TensorBoard command, even though the console script
# is now declared by the tensorboard pip package. If we remove the
diff --git a/tensorflow/user_ops/BUILD b/tensorflow/user_ops/BUILD
deleted file mode 100644
index 71443cc41e..0000000000
--- a/tensorflow/user_ops/BUILD
+++ /dev/null
@@ -1,52 +0,0 @@
-# Description:
-# An example for custom op and kernel defined as a TensorFlow plugin.
-
-package(
- default_visibility = ["//tensorflow:internal"],
-)
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-load("//tensorflow:tensorflow.bzl", "tf_py_test")
-load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
-
-tf_custom_op_library(
- name = "ackermann_op.so",
- srcs = ["ackermann_op.cc"],
-)
-
-tf_py_test(
- name = "ackermann_test",
- size = "small",
- srcs = ["ackermann_test.py"],
- additional_deps = ["//tensorflow:tensorflow_py"],
- data = [":ackermann_op.so"],
-)
-
-tf_custom_op_library(
- name = "duplicate_op.so",
- srcs = ["duplicate_op.cc"],
-)
-
-tf_py_test(
- name = "duplicate_op_test",
- size = "small",
- srcs = ["duplicate_op_test.py"],
- additional_deps = ["//tensorflow:tensorflow_py"],
- data = [":duplicate_op.so"],
-)
-
-tf_custom_op_library(
- name = "invalid_op.so",
- srcs = ["invalid_op.cc"],
-)
-
-tf_py_test(
- name = "invalid_op_test",
- size = "small",
- srcs = ["invalid_op_test.py"],
- additional_deps = ["//tensorflow:tensorflow_py"],
- data = [":invalid_op.so"],
-)
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 86c2b50827..50a69598a1 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -452,11 +452,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/d3b4e8171138b4d39106fb3bea1b9b8d2bbd4001.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/d3b4e8171138b4d39106fb3bea1b9b8d2bbd4001.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/bf13d093f13a295d71080614c3036ada591201d5.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/bf13d093f13a295d71080614c3036ada591201d5.tar.gz",
],
- sha256 = "03db53e502dd4fbdbbf1c470776315eeff665180ade32859cfb6c1e996bbf2a5",
- strip_prefix = "llvm-d3b4e8171138b4d39106fb3bea1b9b8d2bbd4001",
+ sha256 = "3c5b4538a4df95090693bf6b758e861afc5b8c599592368f9dc57901f7560bd0",
+ strip_prefix = "llvm-bf13d093f13a295d71080614c3036ada591201d5",
build_file = clean_dep("//third_party/llvm:llvm.BUILD"),
)