author     Michael Case <mikecase@google.com>  2018-06-29 13:50:37 -0700
committer  Michael Case <mikecase@google.com>  2018-06-29 13:50:37 -0700
commit     7be4245d629510ed3d1c2edd7a2598167017f33b (patch)
tree       6585d143160249ae282044790d4589145a79efa2
parent     01c36c3d7b3e230c865e71d67e138a8dc765e7a6 (diff)
parent     79dab9ced650d69bdf3f312bd902bd52de5bdad8 (diff)
Merge commit for internal changes
-rw-r--r--  configure.py | 4
-rw-r--r--  tensorflow/BUILD | 10
-rw-r--r--  tensorflow/c/c_api.cc | 6
-rw-r--r--  tensorflow/compiler/jit/BUILD | 2
-rw-r--r--  tensorflow/compiler/jit/xla_device_ops.h | 29
-rw-r--r--  tensorflow/compiler/tests/BUILD | 14
-rw-r--r--  tensorflow/compiler/tests/adagrad_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/adam_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/binary_ops_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/bucketize_op_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/categorical_op_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/cholesky_op_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/clustering_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/concat_ops_test.py | 8
-rw-r--r--  tensorflow/compiler/tests/conv2d_test.py | 8
-rw-r--r--  tensorflow/compiler/tests/conv3d_test.py | 6
-rw-r--r--  tensorflow/compiler/tests/depthwise_conv_op_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/dynamic_slice_ops_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/dynamic_stitch_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/eager_test.py | 8
-rw-r--r--  tensorflow/compiler/tests/extract_image_patches_op_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/fake_quant_ops_test.py | 10
-rw-r--r--  tensorflow/compiler/tests/fft_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/fifo_queue_test.py | 201
-rw-r--r--  tensorflow/compiler/tests/ftrl_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/function_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/fused_batchnorm_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/gather_nd_op_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/gather_test.py | 14
-rw-r--r--  tensorflow/compiler/tests/image_ops_test.py | 12
-rw-r--r--  tensorflow/compiler/tests/lrn_ops_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/matrix_band_part_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/matrix_triangular_solve_op_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/momentum_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/nary_ops_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/nullary_ops_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/placeholder_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/pooling_ops_3d_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/pooling_ops_test.py | 6
-rw-r--r--  tensorflow/compiler/tests/random_ops_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/reduce_ops_test.py | 6
-rw-r--r--  tensorflow/compiler/tests/reduce_window_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/reverse_ops_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/reverse_sequence_op_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/rmsprop_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/scan_ops_test.py | 6
-rw-r--r--  tensorflow/compiler/tests/scatter_nd_op_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/slice_ops_test.py | 6
-rw-r--r--  tensorflow/compiler/tests/spacetobatch_op_test.py | 6
-rw-r--r--  tensorflow/compiler/tests/stack_ops_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/stateless_random_ops_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/ternary_ops_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/unary_ops_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/variable_ops_test.py | 6
-rw-r--r--  tensorflow/compiler/tests/while_test.py | 4
-rw-r--r--  tensorflow/compiler/tests/xla_device_test.py | 4
-rw-r--r--  tensorflow/compiler/tf2xla/BUILD | 1
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/BUILD | 1
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc | 5
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/cholesky_op.cc | 7
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/conv_ops.cc | 12
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/diag_op.cc | 5
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc | 5
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc | 11
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc | 9
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc | 7
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc | 8
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/random_ops.cc | 4
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc | 6
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc | 3
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/topk_op.cc | 4
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/unary_ops.cc | 4
-rw-r--r--  tensorflow/compiler/tf2xla/lib/batch_dot.cc | 166
-rw-r--r--  tensorflow/compiler/tf2xla/lib/batch_dot.h | 7
-rw-r--r--  tensorflow/compiler/tf2xla/lib/cholesky.cc | 308
-rw-r--r--  tensorflow/compiler/tf2xla/lib/cholesky.h | 3
-rw-r--r--  tensorflow/compiler/tf2xla/lib/triangular_solve.cc | 1101
-rw-r--r--  tensorflow/compiler/tf2xla/lib/triangular_solve.h | 22
-rw-r--r--  tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc | 104
-rw-r--r--  tensorflow/compiler/tf2xla/lib/util.cc | 219
-rw-r--r--  tensorflow/compiler/tf2xla/lib/util.h | 51
-rw-r--r--  tensorflow/compiler/tf2xla/lib/util_test.cc | 22
-rw-r--r--  tensorflow/compiler/tf2xla/xla_helpers.cc | 29
-rw-r--r--  tensorflow/compiler/tf2xla/xla_helpers.h | 4
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.cc | 15
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.h | 8
-rw-r--r--  tensorflow/compiler/xla/client/lib/BUILD | 32
-rw-r--r--  tensorflow/compiler/xla/client/lib/numeric.cc | 71
-rw-r--r--  tensorflow/compiler/xla/client/lib/numeric.h | 30
-rw-r--r--  tensorflow/compiler/xla/client/lib/numeric_test.cc | 37
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/BUILD | 3
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/xla_builder.cc | 71
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/xla_builder.h | 477
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/copy_insertion.cc | 26
-rw-r--r--  tensorflow/compiler/xla/service/copy_insertion.h | 8
-rw-r--r--  tensorflow/compiler/xla/service/copy_insertion_test.cc | 33
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc | 29
-rw-r--r--  tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.cc | 13
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 168
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc | 25
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h | 5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc | 22
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser_test.cc | 25
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc | 11
-rw-r--r--  tensorflow/compiler/xla/service/multi_output_fusion.cc | 25
-rw-r--r--  tensorflow/compiler/xla/service/multi_output_fusion.h | 8
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc | 10
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD | 1
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.h | 14
-rw-r--r--  tensorflow/compiler/xla/tests/compute_constant_test.cc | 38
-rw-r--r--  tensorflow/compiler/xla/tests/dot_operation_test.cc | 53
-rw-r--r--  tensorflow/compiler/xla/tests/half_test.cc | 89
-rw-r--r--  tensorflow/compiler/xla/tests/pred_test.cc | 22
-rw-r--r--  tensorflow/compiler/xla/tests/scalar_computations_test.cc | 112
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.cc | 3
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils_test.cc | 19
-rw-r--r--  tensorflow/contrib/BUILD | 1
-rw-r--r--  tensorflow/contrib/autograph/pyct/BUILD | 12
-rw-r--r--  tensorflow/contrib/autograph/pyct/cfg.py | 733
-rw-r--r--  tensorflow/contrib/autograph/pyct/cfg_test.py | 790
-rw-r--r--  tensorflow/contrib/autograph/pyct/static_analysis/cfg.py | 2
-rw-r--r--  tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py | 2
-rw-r--r--  tensorflow/contrib/bigtable/BUILD | 196
-rw-r--r--  tensorflow/contrib/bigtable/README.md | 10
-rw-r--r--  tensorflow/contrib/bigtable/__init__.py | 39
-rw-r--r--  tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc | 313
-rw-r--r--  tensorflow/contrib/bigtable/kernels/bigtable_lib.cc | 45
-rw-r--r--  tensorflow/contrib/bigtable/kernels/bigtable_lib.h | 138
-rw-r--r--  tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc | 220
-rw-r--r--  tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc | 103
-rw-r--r--  tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc | 111
-rw-r--r--  tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc | 214
-rw-r--r--  tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc | 367
-rw-r--r--  tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h | 87
-rw-r--r--  tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc | 77
-rw-r--r--  tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc | 279
-rw-r--r--  tensorflow/contrib/bigtable/ops/bigtable_ops.cc | 88
-rw-r--r--  tensorflow/contrib/bigtable/ops/bigtable_test_ops.cc | 27
-rw-r--r--  tensorflow/contrib/bigtable/python/kernel_tests/__init__.py | 20
-rw-r--r--  tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py | 132
-rw-r--r--  tensorflow/contrib/bigtable/python/ops/__init__.py | 20
-rw-r--r--  tensorflow/contrib/bigtable/python/ops/bigtable_api.py | 480
-rw-r--r--  tensorflow/contrib/cloud/BUILD | 1
-rw-r--r--  tensorflow/contrib/cloud/README.md | 18
-rw-r--r--  tensorflow/contrib/cloud/__init__.py | 13
-rw-r--r--  tensorflow/contrib/cmake/python_modules.txt | 2
-rw-r--r--  tensorflow/contrib/cmake/tf_core_framework.cmake | 94
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/README.md | 45
-rw-r--r--  tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java | 30
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h | 1
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 13
-rw-r--r--  tensorflow/contrib/lite/model.cc | 2
-rw-r--r--  tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h | 2
-rw-r--r--  tensorflow/contrib/lite/python/lite.py | 4
-rw-r--r--  tensorflow/contrib/lite/python/tflite_convert.py | 2
-rw-r--r--  tensorflow/contrib/lite/toco/README.md | 13
-rw-r--r--  tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md | 26
-rw-r--r--  tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md | 4
-rw-r--r--  tensorflow/contrib/lite/toco/g3doc/python_api.md | 12
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/quantize.cc | 16
-rw-r--r--  tensorflow/contrib/makefile/tf_op_files.txt | 1
-rw-r--r--  tensorflow/contrib/opt/__init__.py | 7
-rw-r--r--  tensorflow/contrib/opt/python/training/weight_decay_optimizers.py | 44
-rw-r--r--  tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py | 8
-rw-r--r--  tensorflow/contrib/seq2seq/python/ops/decoder.py | 29
-rw-r--r--  tensorflow/contrib/slim/python/slim/evaluation_test.py | 3
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_graph.cc | 7
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.h | 4
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc | 20
-rw-r--r--  tensorflow/contrib/tensorrt/trt_conversion.i | 10
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 10
-rw-r--r--  tensorflow/core/BUILD | 1
-rw-r--r--  tensorflow/core/api_def/api_test.cc | 2
-rw-r--r--  tensorflow/core/common_runtime/direct_session.cc | 9
-rw-r--r--  tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc | 14
-rw-r--r--  tensorflow/core/distributed_runtime/BUILD | 2
-rw-r--r--  tensorflow/core/distributed_runtime/master_test.cc | 2
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/BUILD | 14
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc | 2
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc | 164
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h | 224
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc | 2
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc | 4
-rw-r--r--  tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h | 3
-rw-r--r--  tensorflow/core/framework/resource_op_kernel.h | 25
-rw-r--r--  tensorflow/core/framework/stats_aggregator.h | 4
-rw-r--r--  tensorflow/core/grappler/op_types.cc | 2
-rw-r--r--  tensorflow/core/grappler/op_types.h | 1
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 171
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.h | 1
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 42
-rw-r--r--  tensorflow/core/grappler/optimizers/data/graph_utils_test.cc | 26
-rw-r--r--  tensorflow/core/kernels/BUILD | 15
-rw-r--r--  tensorflow/core/kernels/constant_op.cc | 3
-rw-r--r--  tensorflow/core/kernels/data/slide_dataset_op.cc | 5
-rw-r--r--  tensorflow/core/kernels/data/stats_aggregator_ops.cc | 29
-rw-r--r--  tensorflow/core/kernels/data/stats_dataset_ops.cc | 13
-rw-r--r--  tensorflow/core/kernels/deserialize_sparse_variant_op.cc | 372
-rw-r--r--  tensorflow/core/kernels/fifo_queue.cc | 15
-rw-r--r--  tensorflow/core/kernels/fifo_queue.h | 23
-rw-r--r--  tensorflow/core/kernels/fifo_queue_op.cc | 39
-rw-r--r--  tensorflow/core/kernels/mkl_conv_ops.cc | 186
-rw-r--r--  tensorflow/core/kernels/queue_op.cc | 367
-rw-r--r--  tensorflow/core/kernels/queue_op.h | 233
-rw-r--r--  tensorflow/core/kernels/queue_ops.cc | 395
-rw-r--r--  tensorflow/core/kernels/segment_reduction_ops.h | 6
-rw-r--r--  tensorflow/core/kernels/serialize_sparse_op.cc | 12
-rw-r--r--  tensorflow/core/kernels/tensor_array_ops.cc | 1
-rw-r--r--  tensorflow/core/kernels/variable_ops.cc | 3
-rw-r--r--  tensorflow/core/util/mkl_util.h | 2
-rw-r--r--  tensorflow/docs_src/api_guides/python/spectral_ops.md | 1
-rw-r--r--  tensorflow/docs_src/get_started/index.md | 29
-rw-r--r--  tensorflow/docs_src/performance/xla/operation_semantics.md | 26
-rw-r--r--  tensorflow/go/attrs_test.go | 4
-rw-r--r--  tensorflow/java/src/gen/cc/op_generator.cc | 29
-rw-r--r--  tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java | 296
-rw-r--r--  tensorflow/python/compat/BUILD | 10
-rw-r--r--  tensorflow/python/compat/compat.py | 81
-rw-r--r--  tensorflow/python/debug/BUILD | 13
-rw-r--r--  tensorflow/python/debug/examples/debug_keras.py | 89
-rwxr-xr-x  tensorflow/python/debug/examples/examples_test.sh | 7
-rw-r--r--  tensorflow/python/debug/wrappers/framework.py | 87
-rw-r--r--  tensorflow/python/debug/wrappers/grpc_wrapper.py | 6
-rw-r--r--  tensorflow/python/debug/wrappers/local_cli_wrapper.py | 2
-rw-r--r--  tensorflow/python/debug/wrappers/local_cli_wrapper_test.py | 118
-rw-r--r--  tensorflow/python/eager/function.py | 3
-rw-r--r--  tensorflow/python/eager/function_test.py | 12
-rw-r--r--  tensorflow/python/estimator/estimator_test.py | 26
-rw-r--r--  tensorflow/python/keras/datasets/mnist.py | 2
-rw-r--r--  tensorflow/python/keras/engine/saving.py | 11
-rw-r--r--  tensorflow/python/keras/engine/saving_test.py | 78
-rw-r--r--  tensorflow/python/keras/estimator/__init__.py | 2
-rw-r--r--  tensorflow/python/kernel_tests/dct_ops_test.py | 96
-rw-r--r--  tensorflow/python/lib/core/numpy.h | 2
-rw-r--r--  tensorflow/python/lib/core/py_util.cc | 2
-rw-r--r--  tensorflow/python/ops/image_ops_impl.py | 24
-rw-r--r--  tensorflow/python/ops/image_ops_test.py | 2
-rw-r--r--  tensorflow/python/ops/math_ops_test.py | 4
-rw-r--r--  tensorflow/python/ops/rnn.py | 19
-rw-r--r--  tensorflow/python/ops/special_math_ops.py | 4
-rw-r--r--  tensorflow/python/ops/spectral_ops.py | 125
-rw-r--r--  tensorflow/python/training/saver.py | 50
-rw-r--r--  tensorflow/python/training/saver_test.py | 54
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.spectral.pbtxt | 4
-rw-r--r--  tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le | 1
-rw-r--r--  tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le | 1
-rwxr-xr-x  tensorflow/tools/ci_build/ci_sanity.sh | 4
-rw-r--r--  tensorflow/tools/lib_package/BUILD | 2
-rw-r--r--  tensorflow/tools/pip_package/BUILD | 2
-rw-r--r--  tensorflow/tools/pip_package/setup.py | 2
-rw-r--r--  tensorflow/workspace.bzl | 71
-rw-r--r--  third_party/googleapis.BUILD | 45
256 files changed, 9727 insertions, 3255 deletions
diff --git a/configure.py b/configure.py
index ad585fa52e..5243e09b24 100644
--- a/configure.py
+++ b/configure.py
@@ -1134,7 +1134,9 @@ def set_tf_nccl_install_path(environ_cp):
nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path)
nccl_hdr_path = os.path.join(nccl_install_path, 'include/nccl.h')
- if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path):
+ nccl_license_path = os.path.join(nccl_install_path, 'NCCL-SLA.txt')
+ if os.path.exists(nccl_lib_path) and os.path.exists(
+ nccl_hdr_path) and os.path.exists(nccl_license_path):
# Set NCCL_INSTALL_PATH
environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path
write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path)
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index f362900387..67749ec04e 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -603,3 +603,13 @@ py_library(
visibility = ["//visibility:public"],
deps = ["//tensorflow/python:no_contrib"],
)
+
+cc_library(
+ name = "grpc",
+ deps = ["@grpc"],
+)
+
+cc_library(
+ name = "grpc++",
+ deps = ["@grpc//:grpc++"],
+)
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index a8ad8e4b94..5c218d3f25 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -2068,7 +2068,8 @@ TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
TF_Graph* graph, const TF_Buffer* graph_def,
const TF_ImportGraphDefOptions* options, TF_Status* status) {
GraphDef def;
- if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, graph_def->length)) {
+ if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
+ graph_def->length)) {
status->status = InvalidArgument("Invalid GraphDef");
return nullptr;
}
@@ -2098,7 +2099,8 @@ void TF_GraphImportGraphDefWithReturnOutputs(
return;
}
GraphDef def;
- if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data, graph_def->length)) {
+ if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
+ graph_def->length)) {
status->status = InvalidArgument("Invalid GraphDef");
return;
}
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index d976f8296c..c2245b8eae 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -176,9 +176,11 @@ cc_library(
"//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:control_flow_ops",
+ "//tensorflow/core/kernels:fifo_queue",
"//tensorflow/core/kernels:identity_n_op",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:no_op",
+ "//tensorflow/core/kernels:queue_op",
"//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:sendrecv_ops",
"//tensorflow/core/kernels:shape_ops",
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 11e45d2823..a605335a94 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -23,9 +23,11 @@ limitations under the License.
#include "tensorflow/core/kernels/cast_op.h"
#include "tensorflow/core/kernels/constant_op.h"
#include "tensorflow/core/kernels/control_flow_ops.h"
+#include "tensorflow/core/kernels/fifo_queue.h"
#include "tensorflow/core/kernels/identity_n_op.h"
#include "tensorflow/core/kernels/identity_op.h"
#include "tensorflow/core/kernels/no_op.h"
+#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/sendrecv_ops.h"
#include "tensorflow/core/kernels/shape_ops.h"
@@ -145,7 +147,32 @@ class XlaAssignVariableOp : public AsyncOpKernel {
.Device(DEVICE) \
.HostMemory("input") \
.HostMemory("output"), \
- LoopCondOp);
+ LoopCondOp); \
+ \
+ REGISTER_KERNEL_BUILDER( \
+ Name("QueueEnqueueV2").Device(DEVICE).HostMemory("handle"), EnqueueOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("QueueDequeueV2").Device(DEVICE).HostMemory("handle"), DequeueOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("QueueCloseV2").Device(DEVICE).HostMemory("handle"), QueueCloseOp); \
+ REGISTER_KERNEL_BUILDER(Name("QueueSizeV2") \
+ .Device(DEVICE) \
+ .HostMemory("size") \
+ .HostMemory("handle"), \
+ QueueSizeOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("QueueIsClosedV2").Device(DEVICE).HostMemory("handle"), \
+ QueueIsClosedOp); \
+ \
+ REGISTER_KERNEL_BUILDER( \
+ Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp);
+
+// TODO(phawkins): currently we do not register the QueueEnqueueMany,
+// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read
+// and write the tensors they access in order to concatenate them into a batch.
+// We would need either to call out to an XLA computation to perform the
+// concatenation, or we would need to refactor those kernels so the splitting
+// or merging is done in a separate operator that can be compiled.
} // namespace tensorflow
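
[Editor's note, not part of the diff: the registrations above expose the per-element FIFO queue kernels (FIFOQueueV2, QueueEnqueueV2, QueueDequeueV2, QueueCloseV2, QueueSizeV2, QueueIsClosedV2) on XLA devices, while the batched *Many/UpTo variants stay unregistered per the TODO. A minimal, illustrative Python sketch of the public API calls those kernels back, mirroring the new tensorflow/compiler/tests/fifo_queue_test.py in this commit; running it on an XLA device additionally requires the xla_test scope used there.]

# Illustrative sketch only: per-element queue ops covered by the registrations
# above, written with the TF 1.x public API used by the new fifo_queue_test.py.
import tensorflow as tf

q = tf.FIFOQueue(capacity=10, dtypes=tf.float32, shapes=[()])  # FIFOQueueV2
enqueue_op = q.enqueue((10.0,))                                 # QueueEnqueueV2
dequeued = q.dequeue()                                          # QueueDequeueV2
size = q.size()                                                 # QueueSizeV2
close_op = q.close()                                            # QueueCloseV2

with tf.Session() as sess:
    sess.run(enqueue_op)
    print(sess.run(size))      # -> 1
    print(sess.run(dequeued))  # -> 10.0
    sess.run(close_op)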
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index c1f65416b4..366822f0b7 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -372,6 +372,20 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "fifo_queue_test",
+ size = "medium",
+ srcs = ["fifo_queue_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:data_flow_ops",
+ "//tensorflow/python:extra_py_tests_deps",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
name = "fft_test",
size = "medium",
srcs = ["fft_test.py"],
diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py
index 9a93b32164..d775850a80 100644
--- a/tensorflow/compiler/tests/adagrad_test.py
+++ b/tensorflow/compiler/tests/adagrad_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
@@ -28,7 +28,7 @@ from tensorflow.python.platform import test
from tensorflow.python.training import adagrad
-class AdagradOptimizerTest(XLATestCase):
+class AdagradOptimizerTest(xla_test.XLATestCase):
def testBasic(self):
for dtype in self.float_types:
diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py
index 3215dc36e5..03554d6933 100644
--- a/tensorflow/compiler/tests/adam_test.py
+++ b/tensorflow/compiler/tests/adam_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
@@ -48,7 +48,7 @@ def adam_update_numpy(param,
return param_t, m_t, v_t
-class AdamOptimizerTest(XLATestCase):
+class AdamOptimizerTest(xla_test.XLATestCase):
def testBasic(self):
for dtype in self.float_types:
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index afef36d9d2..9cb3d04546 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
@@ -32,7 +32,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
-class BinaryOpsTest(XLATestCase):
+class BinaryOpsTest(xla_test.XLATestCase):
"""Test cases for binary operators."""
def _testBinary(self, op, a, b, expected, equality_test=None):
diff --git a/tensorflow/compiler/tests/bucketize_op_test.py b/tensorflow/compiler/tests/bucketize_op_test.py
index fde9759a1c..ef4d5f6322 100644
--- a/tensorflow/compiler/tests/bucketize_op_test.py
+++ b/tensorflow/compiler/tests/bucketize_op_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import array_ops
@@ -26,7 +26,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class BucketizationOpTest(XLATestCase):
+class BucketizationOpTest(xla_test.XLATestCase):
def testInt(self):
with self.test_session() as sess:
diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py
index 035cdea178..a4e7f75081 100644
--- a/tensorflow/compiler/tests/categorical_op_test.py
+++ b/tensorflow/compiler/tests/categorical_op_test.py
@@ -22,7 +22,7 @@ import collections
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
@@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest
# TODO(srvasude): Merge this with
# third_party/tensorflow/python/kernel_tests/random/multinomial_op_test.py.
-class CategoricalTest(XLATestCase):
+class CategoricalTest(xla_test.XLATestCase):
"""Test cases for random-number generating operators."""
def output_dtypes(self):
diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py
index 1a8989d7c2..d2867278af 100644
--- a/tensorflow/compiler/tests/cholesky_op_test.py
+++ b/tensorflow/compiler/tests/cholesky_op_test.py
@@ -23,7 +23,7 @@ import unittest
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -32,7 +32,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class CholeskyOpTest(XLATestCase):
+class CholeskyOpTest(xla_test.XLATestCase):
# Cholesky defined for float64, float32, complex64, complex128
# (https://www.tensorflow.org/api_docs/python/tf/cholesky)
diff --git a/tensorflow/compiler/tests/clustering_test.py b/tensorflow/compiler/tests/clustering_test.py
index 574f82fc71..e42ebf8f9e 100644
--- a/tensorflow/compiler/tests/clustering_test.py
+++ b/tensorflow/compiler/tests/clustering_test.py
@@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest
CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
-class ClusteringTest(XLATestCase):
+class ClusteringTest(xla_test.XLATestCase):
def testAdd(self):
val1 = np.array([4, 3, 2, 1], dtype=np.float32)
diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py
index f10973e19f..d9ad428147 100644
--- a/tensorflow/compiler/tests/concat_ops_test.py
+++ b/tensorflow/compiler/tests/concat_ops_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -30,7 +30,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
-class ConcatTest(XLATestCase):
+class ConcatTest(xla_test.XLATestCase):
def testHStack(self):
with self.test_session():
@@ -292,7 +292,7 @@ class ConcatTest(XLATestCase):
array_ops.concat([scalar, scalar, scalar], dim)
-class ConcatOffsetTest(XLATestCase):
+class ConcatOffsetTest(xla_test.XLATestCase):
def testBasic(self):
with self.test_session() as sess:
@@ -306,7 +306,7 @@ class ConcatOffsetTest(XLATestCase):
self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
-class PackTest(XLATestCase):
+class PackTest(xla_test.XLATestCase):
def testBasic(self):
with self.test_session() as sess:
diff --git a/tensorflow/compiler/tests/conv2d_test.py b/tensorflow/compiler/tests/conv2d_test.py
index d12e1ff1e8..98d41ba7ed 100644
--- a/tensorflow/compiler/tests/conv2d_test.py
+++ b/tensorflow/compiler/tests/conv2d_test.py
@@ -26,7 +26,7 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.compiler.tests import test_utils
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
@@ -42,7 +42,7 @@ DATA_FORMATS = (
)
-class Conv2DTest(XLATestCase, parameterized.TestCase):
+class Conv2DTest(xla_test.XLATestCase, parameterized.TestCase):
def _VerifyValues(self,
input_sizes=None,
@@ -236,7 +236,7 @@ class Conv2DTest(XLATestCase, parameterized.TestCase):
expected=np.reshape([108, 128], [1, 1, 1, 2]))
-class Conv2DBackpropInputTest(XLATestCase, parameterized.TestCase):
+class Conv2DBackpropInputTest(xla_test.XLATestCase, parameterized.TestCase):
def _VerifyValues(self,
input_sizes=None,
@@ -534,7 +534,7 @@ class Conv2DBackpropInputTest(XLATestCase, parameterized.TestCase):
expected=[5, 0, 11, 0, 0, 0, 17, 0, 23])
-class Conv2DBackpropFilterTest(XLATestCase, parameterized.TestCase):
+class Conv2DBackpropFilterTest(xla_test.XLATestCase, parameterized.TestCase):
def _VerifyValues(self,
input_sizes=None,
diff --git a/tensorflow/compiler/tests/conv3d_test.py b/tensorflow/compiler/tests/conv3d_test.py
index 3bebf46511..31ee41f04f 100644
--- a/tensorflow/compiler/tests/conv3d_test.py
+++ b/tensorflow/compiler/tests/conv3d_test.py
@@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -33,7 +33,7 @@ from tensorflow.python.platform import googletest
# Test cloned from
# tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py
-class Conv3DBackpropFilterV2GradTest(XLATestCase):
+class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase):
def testGradient(self):
with self.test_session(), self.test_scope():
@@ -66,7 +66,7 @@ class Conv3DBackpropFilterV2GradTest(XLATestCase):
# Test cloned from tensorflow/python/kernel_tests/conv3d_transpose_test.py
-class Conv3DTransposeTest(XLATestCase):
+class Conv3DTransposeTest(xla_test.XLATestCase):
def testConv3DTransposeSingleStride(self):
with self.test_session(), self.test_scope():
diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py
index 03d96a2cd8..98dc73e189 100644
--- a/tensorflow/compiler/tests/depthwise_conv_op_test.py
+++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py
@@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -114,7 +114,7 @@ def CheckGradConfigsToTest():
yield i, f, o, s, p
-class DepthwiseConv2DTest(XLATestCase):
+class DepthwiseConv2DTest(xla_test.XLATestCase):
# This is testing that depthwise_conv2d and depthwise_conv2d_native
# produce the same results. It also tests that NCHW and NWHC
diff --git a/tensorflow/compiler/tests/dynamic_slice_ops_test.py b/tensorflow/compiler/tests/dynamic_slice_ops_test.py
index 6a46d2ec3e..154e36b10e 100644
--- a/tensorflow/compiler/tests/dynamic_slice_ops_test.py
+++ b/tensorflow/compiler/tests/dynamic_slice_ops_test.py
@@ -20,14 +20,14 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class DynamicUpdateSliceOpsTest(XLATestCase):
+class DynamicUpdateSliceOpsTest(xla_test.XLATestCase):
def _assertOpOutputMatchesExpected(self, op, args, expected):
with self.test_session() as session:
diff --git a/tensorflow/compiler/tests/dynamic_stitch_test.py b/tensorflow/compiler/tests/dynamic_stitch_test.py
index c109c27abe..edd78153b5 100644
--- a/tensorflow/compiler/tests/dynamic_stitch_test.py
+++ b/tensorflow/compiler/tests/dynamic_stitch_test.py
@@ -20,14 +20,14 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.platform import googletest
-class DynamicStitchTest(XLATestCase):
+class DynamicStitchTest(xla_test.XLATestCase):
def _AssertDynamicStitchResultIs(self, indices, data, expected):
with self.test_session() as session:
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index e438832a23..3524666499 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
@@ -40,7 +40,7 @@ from tensorflow.python.platform import googletest
from tensorflow.python.training import adam
-class EagerTest(XLATestCase):
+class EagerTest(xla_test.XLATestCase):
def testBasic(self):
with self.test_scope():
@@ -286,7 +286,7 @@ class EagerTest(XLATestCase):
[2.0, 2.0]], embedding_matrix.numpy())
-class EagerFunctionTest(XLATestCase):
+class EagerFunctionTest(xla_test.XLATestCase):
def testBasic(self):
with self.test_scope():
@@ -419,7 +419,7 @@ class EagerFunctionTest(XLATestCase):
self.assertAllEqual((2, 3, 4), dz.shape.as_list())
-class ExcessivePaddingTest(XLATestCase):
+class ExcessivePaddingTest(xla_test.XLATestCase):
"""Test that eager execution works with TPU flattened tensors.
Tensors that would normally be excessively padded when written
diff --git a/tensorflow/compiler/tests/extract_image_patches_op_test.py b/tensorflow/compiler/tests/extract_image_patches_op_test.py
index 0361702e7a..5529fdbb09 100644
--- a/tensorflow/compiler/tests/extract_image_patches_op_test.py
+++ b/tensorflow/compiler/tests/extract_image_patches_op_test.py
@@ -20,13 +20,13 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ExtractImagePatches(XLATestCase):
+class ExtractImagePatches(xla_test.XLATestCase):
"""Functional tests for ExtractImagePatches op."""
def _VerifyValues(self, image, ksizes, strides, rates, padding, patches):
diff --git a/tensorflow/compiler/tests/fake_quant_ops_test.py b/tensorflow/compiler/tests/fake_quant_ops_test.py
index dfe9400ef0..c48ab178bf 100644
--- a/tensorflow/compiler/tests/fake_quant_ops_test.py
+++ b/tensorflow/compiler/tests/fake_quant_ops_test.py
@@ -17,14 +17,14 @@ from __future__ import division
from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.platform import googletest
-class FakeQuantWithMinMaxArgsTest(XLATestCase):
+class FakeQuantWithMinMaxArgsTest(xla_test.XLATestCase):
"""Test cases for FakeQuantWithMinMaxArgs operation."""
# 8 bits, wide range.
@@ -122,7 +122,7 @@ class FakeQuantWithMinMaxArgsTest(XLATestCase):
result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03)
-class FakeQuantWithMinMaxArgsGradientTest(XLATestCase):
+class FakeQuantWithMinMaxArgsGradientTest(xla_test.XLATestCase):
"""Test cases for FakeQuantWithMinMaxArgsGradient operation."""
# 8 bits, wide range.
@@ -223,7 +223,7 @@ class FakeQuantWithMinMaxArgsGradientTest(XLATestCase):
bfloat16_rtol=0.03)
-class FakeQuantWithMinMaxVarsTest(XLATestCase):
+class FakeQuantWithMinMaxVarsTest(xla_test.XLATestCase):
"""Test cases for FakeQuantWithMinMaxVars operation."""
# 8 bits, wide range.
@@ -328,7 +328,7 @@ class FakeQuantWithMinMaxVarsTest(XLATestCase):
result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03)
-class FakeQuantWithMinMaxVarsGradientTest(XLATestCase):
+class FakeQuantWithMinMaxVarsGradientTest(xla_test.XLATestCase):
"""Test cases for FakeQuantWithMinMaxVarsGradient operation."""
# 8 bits, wide range.
diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py
index b2360dd009..c64ea249ec 100644
--- a/tensorflow/compiler/tests/fft_test.py
+++ b/tensorflow/compiler/tests/fft_test.py
@@ -23,7 +23,7 @@ import itertools
import numpy as np
import scipy.signal as sps
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.contrib.signal.python.ops import spectral_ops as signal
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -58,7 +58,7 @@ INNER_DIMS_2D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2))
INNER_DIMS_3D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2, POWS_OF_2))
-class FFTTest(XLATestCase):
+class FFTTest(xla_test.XLATestCase):
def _VerifyFftMethod(self, inner_dims, complex_to_input, input_to_expected,
tf_method):
diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py
new file mode 100644
index 0000000000..0f64cc87cd
--- /dev/null
+++ b/tensorflow/compiler/tests/fifo_queue_test.py
@@ -0,0 +1,201 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.data_flow_ops.FIFOQueue."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.platform import test
+
+
+class FIFOQueueTest(xla_test.XLATestCase):
+
+ def testEnqueue(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ enqueue_op = q.enqueue((10.0,))
+ enqueue_op.run()
+
+ def testEnqueueWithShape(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2))
+ enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
+ enqueue_correct_op.run()
+ with self.assertRaises(ValueError):
+ q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],))
+ self.assertEqual(1, q.size().eval())
+
+ def testMultipleDequeues(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q.enqueue([1]))
+ self.evaluate(q.enqueue([2]))
+ self.evaluate(q.enqueue([3]))
+ a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()])
+ self.assertAllEqual(set([1, 2, 3]), set([a, b, c]))
+
+ def testQueuesDontShare(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q.enqueue(1))
+ q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q2.enqueue(2))
+ self.assertAllEqual(self.evaluate(q2.dequeue()), 2)
+ self.assertAllEqual(self.evaluate(q.dequeue()), 1)
+
+ def testEnqueueDictWithoutNames(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ with self.assertRaisesRegexp(ValueError, "must have names"):
+ q.enqueue({"a": 12.0})
+
+ def testParallelEnqueue(self):
+ with self.test_session() as sess, self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ # Run one producer thread for each element in elems.
+ def enqueue(enqueue_op):
+ sess.run(enqueue_op)
+
+ threads = [
+ self.checkedThread(target=enqueue, args=(e,)) for e in enqueue_ops
+ ]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+
+ # Dequeue every element using a single thread.
+ results = []
+ for _ in xrange(len(elems)):
+ results.append(dequeued_t.eval())
+ self.assertItemsEqual(elems, results)
+
+ def testParallelDequeue(self):
+ with self.test_session() as sess, self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ # Enqueue every element using a single thread.
+ for enqueue_op in enqueue_ops:
+ enqueue_op.run()
+
+ # Run one consumer thread for each element in elems.
+ results = []
+
+ def dequeue():
+ results.append(sess.run(dequeued_t))
+
+ threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+ self.assertItemsEqual(elems, results)
+
+ def testDequeue(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ elems = [10.0, 20.0, 30.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ for enqueue_op in enqueue_ops:
+ enqueue_op.run()
+
+ for i in xrange(len(elems)):
+ vals = dequeued_t.eval()
+ self.assertEqual([elems[i]], vals)
+
+ def testEnqueueAndBlockingDequeue(self):
+ with self.test_session() as sess, self.test_scope():
+ q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32)
+ elems = [10.0, 20.0, 30.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ def enqueue():
+ # The enqueue_ops should run after the dequeue op has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ for enqueue_op in enqueue_ops:
+ sess.run(enqueue_op)
+
+ results = []
+
+ def dequeue():
+ for _ in xrange(len(elems)):
+ results.append(sess.run(dequeued_t))
+
+ enqueue_thread = self.checkedThread(target=enqueue)
+ dequeue_thread = self.checkedThread(target=dequeue)
+ enqueue_thread.start()
+ dequeue_thread.start()
+ enqueue_thread.join()
+ dequeue_thread.join()
+
+ for elem, result in zip(elems, results):
+ self.assertEqual([elem], result)
+
+ def testMultiEnqueueAndDequeue(self):
+ with self.test_session() as sess, self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32))
+ elems = [(5, 10.0), (10, 20.0), (15, 30.0)]
+ enqueue_ops = [q.enqueue((x, y)) for x, y in elems]
+ dequeued_t = q.dequeue()
+
+ for enqueue_op in enqueue_ops:
+ enqueue_op.run()
+
+ for i in xrange(len(elems)):
+ x_val, y_val = sess.run(dequeued_t)
+ x, y = elems[i]
+ self.assertEqual([x], x_val)
+ self.assertEqual([y], y_val)
+
+ def testQueueSizeEmpty(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ self.assertEqual([0], q.size().eval())
+
+ def testQueueSizeAfterEnqueueAndDequeue(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ enqueue_op = q.enqueue((10.0,))
+ dequeued_t = q.dequeue()
+ size = q.size()
+ self.assertEqual([], size.get_shape())
+
+ enqueue_op.run()
+ self.assertEqual(1, size.eval())
+ dequeued_t.op.run()
+ self.assertEqual(0, size.eval())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py
index 8e6407dffd..1da97fd512 100644
--- a/tensorflow/compiler/tests/ftrl_test.py
+++ b/tensorflow/compiler/tests/ftrl_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
@@ -30,7 +30,7 @@ from tensorflow.python.training import ftrl
from tensorflow.python.training import gradient_descent
-class FtrlOptimizerTest(XLATestCase):
+class FtrlOptimizerTest(xla_test.XLATestCase):
def initVariableAndGradient(self, dtype):
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py
index 8a3f4b0bdc..04fba44446 100644
--- a/tensorflow/compiler/tests/function_test.py
+++ b/tensorflow/compiler/tests/function_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
@@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
-class FunctionTest(XLATestCase):
+class FunctionTest(xla_test.XLATestCase):
def testFunction(self):
"""Executes a simple TensorFlow function."""
diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py
index 5782e76734..132e42ac7a 100644
--- a/tensorflow/compiler/tests/fused_batchnorm_test.py
+++ b/tensorflow/compiler/tests/fused_batchnorm_test.py
@@ -22,7 +22,7 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.compiler.tests import test_utils
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
@@ -30,7 +30,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.platform import test
-class FusedBatchNormTest(XLATestCase, parameterized.TestCase):
+class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
def _reference_training(self, x, scale, offset, epsilon, data_format):
if data_format != "NHWC":
diff --git a/tensorflow/compiler/tests/gather_nd_op_test.py b/tensorflow/compiler/tests/gather_nd_op_test.py
index 9378b1db72..23b0aed34f 100644
--- a/tensorflow/compiler/tests/gather_nd_op_test.py
+++ b/tensorflow/compiler/tests/gather_nd_op_test.py
@@ -20,13 +20,13 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class GatherNdTest(XLATestCase):
+class GatherNdTest(xla_test.XLATestCase):
def _runGather(self, params, indices):
with self.test_session():
diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py
index 1a8c451911..e9c8ef7c91 100644
--- a/tensorflow/compiler/tests/gather_test.py
+++ b/tensorflow/compiler/tests/gather_test.py
@@ -136,6 +136,20 @@ class GatherTest(xla_test.XLATestCase):
self.assertAllEqual(
[[7]], gather.eval(feed_dict={params: [4, 7, 2], indices: [[1]]}))
+ def testGatherPrecision(self):
+ with self.test_session() as session, self.test_scope():
+ data = np.array([[0, 0, 0, 0], [0, 2 * (1 + np.exp2(-8)), 0, 0],
+ [0, 0, 0, 0], [0.015789, 0.0985, 0.55789, 0.3842]])
+ indices = np.array([1, 2, 3, 1])
+ dtype = dtypes.float32
+ params_np = self._buildParams(data, dtype)
+ params = array_ops.placeholder(dtype=dtype)
+ indices_tf = constant_op.constant(indices)
+ gather_t = array_ops.gather(params, indices_tf)
+ gather_val = session.run(gather_t, feed_dict={params: params_np})
+ np_val = params_np[indices]
+ self.assertAllEqual(np_val, gather_val)
+
class GatherBenchmark(test.Benchmark):
"""Microbenchmarks for the gather op."""
diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py
index 7cf953ef25..8b01ef96db 100644
--- a/tensorflow/compiler/tests/image_ops_test.py
+++ b/tensorflow/compiler/tests/image_ops_test.py
@@ -25,7 +25,7 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -41,7 +41,7 @@ def GenerateNumpyRandomRGB(shape):
return np.random.randint(0, 256, shape) / 256.
-class RGBToHSVTest(XLATestCase):
+class RGBToHSVTest(xla_test.XLATestCase):
def testBatch(self):
# Build an arbitrary RGB image
@@ -104,7 +104,7 @@ class RGBToHSVTest(XLATestCase):
self.assertAllCloseAccordingToType(hsv_tf, hsv_np)
-class AdjustContrastTest(XLATestCase):
+class AdjustContrastTest(xla_test.XLATestCase):
def _testContrast(self, x_np, y_np, contrast_factor):
with self.test_session():
@@ -168,7 +168,7 @@ class AdjustContrastTest(XLATestCase):
self.assertAllClose(y_tf, y_np, rtol=1e-5, atol=1e-5)
-class AdjustHueTest(XLATestCase):
+class AdjustHueTest(xla_test.XLATestCase):
def testAdjustNegativeHue(self):
x_shape = [2, 2, 3]
@@ -303,7 +303,7 @@ class AdjustHueTest(XLATestCase):
self._adjustHueTf(x_np, delta_h)
-class AdjustSaturationTest(XLATestCase):
+class AdjustSaturationTest(xla_test.XLATestCase):
def _adjust_saturation(self, image, saturation_factor):
image = ops.convert_to_tensor(image, name="image")
@@ -403,7 +403,7 @@ class AdjustSaturationTest(XLATestCase):
self.assertAllClose(y_fused, y_baseline, rtol=2e-5, atol=1e-5)
-class ResizeBilinearTest(XLATestCase):
+class ResizeBilinearTest(xla_test.XLATestCase):
def _assertForwardOpMatchesExpected(self,
image_np,
diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py
index 69bd8f7230..253b45902f 100644
--- a/tensorflow/compiler/tests/lrn_ops_test.py
+++ b/tensorflow/compiler/tests/lrn_ops_test.py
@@ -22,7 +22,7 @@ import copy
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -36,7 +36,7 @@ CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
# Local response normalization tests. The forward tests are copied from
# tensorflow/python/kernel_tests/lrn_op_test.py
-class LRNTest(XLATestCase):
+class LRNTest(xla_test.XLATestCase):
def _LRN(self, input_image, lrn_depth_radius=5, bias=1.0, alpha=1.0,
beta=0.5):
diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py
index 29394f9ea5..0d9f99f8a6 100644
--- a/tensorflow/compiler/tests/matrix_band_part_test.py
+++ b/tensorflow/compiler/tests/matrix_band_part_test.py
@@ -19,14 +19,14 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class MatrixBandPartTest(XLATestCase):
+class MatrixBandPartTest(xla_test.XLATestCase):
def _testMatrixBandPart(self, dtype, shape):
with self.test_session():
diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
index 5819b2bf2b..2bb8a97bda 100644
--- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
+++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
@@ -22,7 +22,7 @@ import itertools
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -35,7 +35,7 @@ def MakePlaceholder(x):
return array_ops.placeholder(dtypes.as_dtype(x.dtype), shape=x.shape)
-class MatrixTriangularSolveOpTest(XLATestCase):
+class MatrixTriangularSolveOpTest(xla_test.XLATestCase):
# MatrixTriangularSolve defined for float64, float32, complex64, complex128
# (https://www.tensorflow.org/api_docs/python/tf/matrix_triangular_solve)
diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py
index af9394e7d7..c2592c54cf 100644
--- a/tensorflow/compiler/tests/momentum_test.py
+++ b/tensorflow/compiler/tests/momentum_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -30,7 +30,7 @@ from tensorflow.python.platform import test
from tensorflow.python.training import momentum as momentum_lib
-class MomentumOptimizerTest(XLATestCase):
+class MomentumOptimizerTest(xla_test.XLATestCase):
def _update_nesterov_momentum_numpy(self, var, accum, g, lr, momentum):
var += accum * lr * momentum
diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py
index e4843b169b..da08225e9f 100644
--- a/tensorflow/compiler/tests/nary_ops_test.py
+++ b/tensorflow/compiler/tests/nary_ops_test.py
@@ -22,14 +22,14 @@ import unittest
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
-class NAryOpsTest(XLATestCase):
+class NAryOpsTest(xla_test.XLATestCase):
def _testNAry(self, op, args, expected, equality_fn=None):
with self.test_session() as session:
diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py
index 6f588d8ab5..2f9122645d 100644
--- a/tensorflow/compiler/tests/nullary_ops_test.py
+++ b/tensorflow/compiler/tests/nullary_ops_test.py
@@ -20,13 +20,13 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import googletest
-class NullaryOpsTest(XLATestCase):
+class NullaryOpsTest(xla_test.XLATestCase):
def _testNullary(self, op, expected):
with self.test_session() as session:
diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py
index 5e6d1313bd..a75d99189b 100644
--- a/tensorflow/compiler/tests/placeholder_test.py
+++ b/tensorflow/compiler/tests/placeholder_test.py
@@ -18,14 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
-class PlaceholderTest(XLATestCase):
+class PlaceholderTest(xla_test.XLATestCase):
def test_placeholder_with_default_default(self):
with self.test_session() as sess, self.test_scope():
diff --git a/tensorflow/compiler/tests/pooling_ops_3d_test.py b/tensorflow/compiler/tests/pooling_ops_3d_test.py
index d9285186ba..17f860db61 100644
--- a/tensorflow/compiler/tests/pooling_ops_3d_test.py
+++ b/tensorflow/compiler/tests/pooling_ops_3d_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -41,7 +41,7 @@ def _AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding):
padding=padding)
-class Pooling3DTest(XLATestCase):
+class Pooling3DTest(xla_test.XLATestCase):
def _VerifyValues(self, pool_func, input_sizes, window, strides, padding,
expected):
diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py
index fe270af3d6..9fc94752ea 100644
--- a/tensorflow/compiler/tests/pooling_ops_test.py
+++ b/tensorflow/compiler/tests/pooling_ops_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -69,7 +69,7 @@ def GetTestConfigs():
return test_configs
-class PoolingTest(XLATestCase):
+class PoolingTest(xla_test.XLATestCase):
def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding,
data_format, expected):
@@ -288,7 +288,7 @@ class PoolingTest(XLATestCase):
expected=expected_output)
-class PoolGradTest(XLATestCase):
+class PoolGradTest(xla_test.XLATestCase):
CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index 2e71b00ba6..b880b2a3fe 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -22,7 +22,7 @@ import math
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -31,7 +31,7 @@ from tensorflow.python.ops.distributions import special_math
from tensorflow.python.platform import googletest
-class RandomOpsTest(XLATestCase):
+class RandomOpsTest(xla_test.XLATestCase):
"""Test cases for random-number generating operators."""
def _random_types(self):
diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py
index 7420724bdb..cea2ec816f 100644
--- a/tensorflow/compiler/tests/reduce_ops_test.py
+++ b/tensorflow/compiler/tests/reduce_ops_test.py
@@ -22,7 +22,7 @@ import functools
import itertools
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import array_ops
@@ -30,7 +30,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
-class ReduceOpsTest(XLATestCase):
+class ReduceOpsTest(xla_test.XLATestCase):
def _testReduction(self,
tf_reduce_fn,
@@ -156,7 +156,7 @@ class ReduceOpsTest(XLATestCase):
self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA)
-class ReduceOpPrecisionTest(XLATestCase):
+class ReduceOpPrecisionTest(xla_test.XLATestCase):
def _testReduceSum(self,
expected_result,
diff --git a/tensorflow/compiler/tests/reduce_window_test.py b/tensorflow/compiler/tests/reduce_window_test.py
index e78a63465b..c69b6837b0 100644
--- a/tensorflow/compiler/tests/reduce_window_test.py
+++ b/tensorflow/compiler/tests/reduce_window_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
@@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
-class ReduceWindowTest(XLATestCase):
+class ReduceWindowTest(xla_test.XLATestCase):
"""Test cases for xla.reduce_window."""
def _reduce_window(self, operand, init, reducer, **kwargs):
diff --git a/tensorflow/compiler/tests/reverse_ops_test.py b/tensorflow/compiler/tests/reverse_ops_test.py
index 18fabca28c..d01c676e7c 100644
--- a/tensorflow/compiler/tests/reverse_ops_test.py
+++ b/tensorflow/compiler/tests/reverse_ops_test.py
@@ -21,14 +21,14 @@ from __future__ import print_function
import itertools
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
-class ReverseOpsTest(XLATestCase):
+class ReverseOpsTest(xla_test.XLATestCase):
def testReverseOneDim(self):
shape = (7, 5, 9, 11)
diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py
index 1a5d05094e..ccfa630016 100644
--- a/tensorflow/compiler/tests/reverse_sequence_op_test.py
+++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py
@@ -20,13 +20,13 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ReverseSequenceTest(XLATestCase):
+class ReverseSequenceTest(xla_test.XLATestCase):
def _testReverseSequence(self,
x,
diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py
index ecdce4f052..9489fded32 100644
--- a/tensorflow/compiler/tests/rmsprop_test.py
+++ b/tensorflow/compiler/tests/rmsprop_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
@@ -28,7 +28,7 @@ from tensorflow.python.platform import test
from tensorflow.python.training import rmsprop
-class RmspropTest(XLATestCase):
+class RmspropTest(xla_test.XLATestCase):
def testBasic(self):
for dtype in self.float_types:
diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py
index 3260e63b23..4292352e76 100644
--- a/tensorflow/compiler/tests/scan_ops_test.py
+++ b/tensorflow/compiler/tests/scan_ops_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
@@ -69,7 +69,7 @@ def handle_options(func, x, axis, exclusive, reverse):
return x
-class CumsumTest(XLATestCase):
+class CumsumTest(xla_test.XLATestCase):
valid_dtypes = [np.float32]
@@ -147,7 +147,7 @@ class CumsumTest(XLATestCase):
math_ops.cumsum(input_tensor, [0]).eval()
-class CumprodTest(XLATestCase):
+class CumprodTest(xla_test.XLATestCase):
valid_dtypes = [np.float32]
diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py
index 638946e234..f606f88545 100644
--- a/tensorflow/compiler/tests/scatter_nd_op_test.py
+++ b/tensorflow/compiler/tests/scatter_nd_op_test.py
@@ -22,7 +22,7 @@ import functools
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -68,7 +68,7 @@ def _NumpyUpdate(indices, updates, shape):
return _NumpyScatterNd(ref, indices, updates, lambda p, u: u)
-class ScatterNdTest(XLATestCase):
+class ScatterNdTest(xla_test.XLATestCase):
def _VariableRankTest(self,
np_scatter,
diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py
index 305ca0c6b7..6c4890565d 100644
--- a/tensorflow/compiler/tests/slice_ops_test.py
+++ b/tensorflow/compiler/tests/slice_ops_test.py
@@ -18,14 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
-class SliceTest(XLATestCase):
+class SliceTest(xla_test.XLATestCase):
def test1D(self):
for dtype in self.numeric_types:
@@ -110,7 +110,7 @@ class SliceTest(XLATestCase):
self.assertAllEqual([[[1, 1, 1, 1], [6, 5, 4, 3]]], result)
-class StridedSliceTest(XLATestCase):
+class StridedSliceTest(xla_test.XLATestCase):
def test1D(self):
for dtype in self.numeric_types:
diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py
index f37c34156f..c685bc548f 100644
--- a/tensorflow/compiler/tests/spacetobatch_op_test.py
+++ b/tensorflow/compiler/tests/spacetobatch_op_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
@@ -68,7 +68,7 @@ def space_to_batch_direct(input_array, block_shape, paddings):
return permuted_reshaped_padded.reshape(output_shape)
-class SpaceToBatchTest(XLATestCase):
+class SpaceToBatchTest(xla_test.XLATestCase):
"""Tests input-output pairs for the SpaceToBatch and BatchToSpace ops."""
def _testPad(self, inputs, paddings, block_size, outputs):
@@ -149,7 +149,7 @@ class SpaceToBatchTest(XLATestCase):
self._testOne(x_np, block_size, x_out)
-class SpaceToBatchNDTest(XLATestCase):
+class SpaceToBatchNDTest(xla_test.XLATestCase):
"""Tests input-output pairs for the SpaceToBatchND and BatchToSpaceND ops."""
def _testPad(self, inputs, block_shape, paddings, outputs):
diff --git a/tensorflow/compiler/tests/stack_ops_test.py b/tensorflow/compiler/tests/stack_ops_test.py
index 94342f9567..b7dd787fef 100644
--- a/tensorflow/compiler/tests/stack_ops_test.py
+++ b/tensorflow/compiler/tests/stack_ops_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -28,7 +28,7 @@ from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.platform import test
-class StackOpTest(XLATestCase):
+class StackOpTest(xla_test.XLATestCase):
def testStackPushPop(self):
with self.test_session(), self.test_scope():
diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py
index abce190d83..d162675ef8 100644
--- a/tensorflow/compiler/tests/stateless_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateless_random_ops_test.py
@@ -22,7 +22,7 @@ import math
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.contrib import stateless
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -30,7 +30,7 @@ from tensorflow.python.ops.distributions import special_math
from tensorflow.python.platform import test
-class StatelessRandomOpsTest(XLATestCase):
+class StatelessRandomOpsTest(xla_test.XLATestCase):
"""Test cases for stateless random-number generator operators."""
def _random_types(self):
diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py
index ef047005b6..effa5a59fe 100644
--- a/tensorflow/compiler/tests/ternary_ops_test.py
+++ b/tensorflow/compiler/tests/ternary_ops_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
@@ -28,7 +28,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
-class TernaryOpsTest(XLATestCase):
+class TernaryOpsTest(xla_test.XLATestCase):
def _testTernary(self, op, a, b, c, expected):
with self.test_session() as session:
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index a24abd7547..6a7011aea6 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -23,7 +23,7 @@ import unittest
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
@@ -44,7 +44,7 @@ def nhwc_to_format(x, data_format):
raise ValueError("Unknown format {}".format(data_format))
-class UnaryOpsTest(XLATestCase):
+class UnaryOpsTest(xla_test.XLATestCase):
"""Test cases for unary operators."""
def _assertOpOutputMatchesExpected(self,
diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py
index bd616f2a20..dd2c252d38 100644
--- a/tensorflow/compiler/tests/variable_ops_test.py
+++ b/tensorflow/compiler/tests/variable_ops_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -37,7 +37,7 @@ from tensorflow.python.platform import googletest
from tensorflow.python.training.gradient_descent import GradientDescentOptimizer
-class VariableOpsTest(XLATestCase):
+class VariableOpsTest(xla_test.XLATestCase):
"""Test cases for resource variable operators."""
def testOneWriteOneOutput(self):
@@ -435,7 +435,7 @@ class StridedSliceAssignChecker(object):
self.test.assertAllEqual(val, valnp)
-class SliceAssignTest(XLATestCase):
+class SliceAssignTest(xla_test.XLATestCase):
def testSliceAssign(self):
for dtype in self.numeric_types:
diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py
index f79eb27435..b637cf31cf 100644
--- a/tensorflow/compiler/tests/while_test.py
+++ b/tensorflow/compiler/tests/while_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -29,7 +29,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class WhileTest(XLATestCase):
+class WhileTest(xla_test.XLATestCase):
def testSingletonLoopHandrolled(self):
# Define a function for the loop body
diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py
index f0b010fa67..06d977b93c 100644
--- a/tensorflow/compiler/tests/xla_device_test.py
+++ b/tensorflow/compiler/tests/xla_device_test.py
@@ -20,14 +20,14 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.platform import test
-class XlaDeviceTest(XLATestCase):
+class XlaDeviceTest(xla_test.XLATestCase):
def testCopies(self):
"""Tests that copies onto and off XLA devices work."""
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index a7b9cc6c81..aa9c0596d1 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -169,6 +169,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:numeric",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:core_cpu",
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 45657bb150..e6cbf2349d 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -121,6 +121,7 @@ tf_kernel_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/lib:numeric",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:image_ops_op_lib",
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc
index b0ba25b998..4cfe946b2e 100644
--- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc
@@ -28,11 +28,10 @@ class BatchMatMulOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- auto result = BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1),
+ auto result = BatchDot(ctx->Input(0), ctx->Input(1),
/*transpose_x=*/adj_x_, /*transpose_y=*/adj_y_,
/*conjugate_x=*/adj_x_, /*conjugate_y=*/adj_y_);
- OP_REQUIRES_OK(ctx, result.status());
- ctx->SetOutput(0, result.ValueOrDie());
+ ctx->SetOutput(0, result);
}
private:
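
For reference, a minimal sketch of the calling convention this hunk adopts. It assumes the new value-returning BatchDot declared in tensorflow/compiler/tf2xla/lib/batch_dot.h later in this diff, where errors are recorded on the ops' XlaBuilder rather than returned as StatusOr, so a kernel's Compile body becomes straight-line code. The kernel function name here is hypothetical:

#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"

namespace tensorflow {
// Hypothetical kernel body: chain the value-returning helper directly.
// Any rank or shape error surfaces when the XLA computation is built.
void CompileBatchedMatmul(XlaOpKernelContext* ctx) {
  xla::XlaOp product = BatchDot(ctx->Input(0), ctx->Input(1),
                                /*transpose_x=*/false, /*transpose_y=*/true);
  ctx->SetOutput(0, product);
}
}  // namespace tensorflow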
diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc
index fe6651793d..9fcbc86adc 100644
--- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc
@@ -24,12 +24,7 @@ class CholeskyOp : public XlaOpKernel {
public:
explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- auto result = Cholesky(ctx->builder(), ctx->Input(0));
- if (!result.ok()) {
- ctx->SetStatus(result.status());
- return;
- }
- ctx->SetOutput(0, result.ValueOrDie());
+ ctx->SetOutput(0, Cholesky(ctx->Input(0)));
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index 5d41fc708a..48ac4867ed 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/numeric_op.h"
@@ -96,14 +97,9 @@ xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape,
// Create a M sized linspace and an M*N sized linspace that will be
// broadcasted into perpendicular dimensions and compared.
- xla::XlaOp input_feature_iota;
- // DT_INT32 Iota will always return status::OK().
- TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature,
- &input_feature_iota));
- xla::XlaOp expanded_feature_iota;
- TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32,
- input_feature * depthwise_multiplier,
- &expanded_feature_iota));
+ xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature);
+ xla::XlaOp expanded_feature_iota =
+ xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier);
// Divide the M*N sized linspace by the depthwise_multiplier to create
// [0 0 1 1 2 2] in the example in the function comment.
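
For reference, a minimal sketch of the Iota migration this hunk (and several below) performs: the removed XlaHelpers::Iota returned a Status and wrote through an output parameter, while xla::Iota from client/lib/numeric.h returns the op directly. The helper name below is hypothetical:

#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"

// Hypothetical helper showing the before/after shape of the call.
xla::XlaOp MakeFeatureIota(xla::XlaBuilder* builder, int64_t num_features) {
  // Before:
  //   xla::XlaOp iota;
  //   TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, num_features, &iota));
  // After: a single expression, no status plumbing.
  return xla::Iota(builder, xla::S32, num_features);
}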
diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
index 17bf0c069c..378b62c0d6 100644
--- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -39,9 +40,7 @@ xla::StatusOr<xla::XlaOp> CreateDiagonal(
//
// This produces a predicate matrix of the right size, with "true" on the
// diagonal.
- xla::XlaOp iota;
- TF_RETURN_IF_ERROR(
- XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota));
+ xla::XlaOp iota = xla::Iota(builder, xla::S32, last_dim_size);
xla::XlaOp iota_broadcast = xla::Broadcast(iota, {last_dim_size});
xla::XlaOp mask = xla::Eq(iota_broadcast, iota, {0});
diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
index b2451236de..65d42a302f 100644
--- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/util/tensor_format.h"
@@ -111,9 +112,7 @@ class ExtractImagePatchesOp : public XlaOpKernel {
// Builds an identity matrix as a broadcast equality of iotas.
// iota = np.arange(np.prod(ksize), depth)
// filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32)
- xla::XlaOp iota;
- TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32,
- kernel_size * depth, &iota));
+ xla::XlaOp iota = xla::Iota(builder, xla::S32, kernel_size * depth);
auto lhs = xla::Reshape(iota, lhs_shape);
auto filter = xla::ConvertElementType(
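
The comment above builds an identity matrix as a broadcast equality of iotas; a small self-contained sketch of that trick follows (hypothetical helper name, float32 result assumed):

#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"

// Hypothetical helper: eye[i][j] == 1.0f iff i == j.
xla::XlaOp IdentityFromIotas(xla::XlaBuilder* builder, int64_t n) {
  xla::XlaOp iota = xla::Iota(builder, xla::S32, n);
  // Broadcast(iota, {n}) has shape [n, n] with value j at [i][j]; comparing it
  // against iota mapped onto dimension 0 (value i) is true only on the diagonal.
  xla::XlaOp eye = xla::Eq(xla::Broadcast(iota, {n}), iota,
                           /*broadcast_dimensions=*/{0});
  return xla::ConvertElementType(eye, xla::F32);
}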
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index de971ce4ac..d6bf92fb3d 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
@@ -128,10 +129,7 @@ const int64 kMax2DKernelSize = 16;
xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder,
gtl::ArraySlice<int64> kernel_size,
int64 channels) {
- xla::XlaOp channels_iota;
- // DT_INT32 Iota will always return status::OK().
- TF_CHECK_OK(
- XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota));
+ xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
auto diag = xla::ConvertElementType(
xla::Eq(xla::Broadcast(channels_iota, {2 * kernel_size[0] - 1,
@@ -149,10 +147,7 @@ xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder,
xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder,
gtl::ArraySlice<int64> kernel_size,
int64 channels, int64 dim) {
- xla::XlaOp channels_iota;
- // DT_INT32 Iota will always return status::OK().
- TF_CHECK_OK(
- XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota));
+ xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
auto diag = xla::ConvertElementType(
xla::Eq(
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
index 9d3575e331..e06c87db7a 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -51,6 +52,7 @@ class MatrixBandPartOp : public XlaOpKernel {
xla::XlaOp num_upper = context->Input(2);
DataType input_type = context->input_type(0);
DataType index_type = context->input_type(1);
+ xla::PrimitiveType index_xla_type = context->input_xla_type(1);
TensorShape batch_shape = input_shape;
batch_shape.RemoveLastDims(2);
@@ -59,11 +61,8 @@ class MatrixBandPartOp : public XlaOpKernel {
// Compute 'offset', which is how many diagonals we are above/below the
// diagonal.
- xla::XlaOp iota_m;
- OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m));
-
- xla::XlaOp iota_n;
- OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n));
+ xla::XlaOp iota_m = xla::Iota(builder, index_xla_type, m);
+ xla::XlaOp iota_n = xla::Iota(builder, index_xla_type, n);
auto offset = xla::Sub(xla::Broadcast(iota_n, {m}), iota_m,
/*broadcast_dimensions=*/{0});
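
A sketch of the offset matrix computed above, not part of the patch: broadcasting the length-n iota over m rows and subtracting the length-m iota along dimension 0 yields offset[i][j] = j - i, the signed diagonal index that MatrixBandPart presumably masks against downstream. The helper name is hypothetical and S32 stands in for index_xla_type:

#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"

// Hypothetical helper: offset[i][j] = j - i for an m x n matrix.
xla::XlaOp DiagonalOffsets(xla::XlaBuilder* builder, int64_t m, int64_t n) {
  xla::XlaOp iota_m = xla::Iota(builder, xla::S32, m);
  xla::XlaOp iota_n = xla::Iota(builder, xla::S32, n);
  // Broadcast(iota_n, {m}) has shape [m, n] with value j; subtracting iota_m
  // along dimension 0 subtracts the row index i.
  return xla::Sub(xla::Broadcast(iota_n, {m}), iota_m,
                  /*broadcast_dimensions=*/{0});
}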
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
index 7bf1894ea0..e2ab4b83cf 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
@@ -62,10 +63,8 @@ class MatrixSetDiagOp : public XlaOpKernel {
auto zero = XlaHelpers::Zero(builder, context->input_type(0));
// Create an indicator tensor that is true only on the diagonal.
- xla::XlaOp iota_m;
- OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m));
- xla::XlaOp iota_n;
- OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n));
+ xla::XlaOp iota_m = xla::Iota(builder, xla::S32, m);
+ xla::XlaOp iota_n = xla::Iota(builder, xla::S32, n);
auto indicator = xla::Eq(iota_m, xla::Broadcast(iota_n, {m}),
/*broadcast_dimensions=*/{0});
indicator = xla::Broadcast(indicator, batch_shape.dim_sizes());
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc
index eaed931464..f4def11d08 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc
@@ -30,13 +30,9 @@ class MatrixTriangularSolveOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
auto result = TriangularSolve(
- ctx->builder(), ctx->Input(0), ctx->Input(1), /*left_side=*/true,
+ ctx->Input(0), ctx->Input(1), /*left_side=*/true,
/*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_);
- if (!result.ok()) {
- ctx->SetStatus(result.status());
- return;
- }
- ctx->SetOutput(0, result.ValueOrDie());
+ ctx->SetOutput(0, result);
}
private:
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index 51f2cdc9f4..d5b645d70a 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@@ -84,8 +85,7 @@ class RandomShuffleOp : public XlaOpKernel {
xla::ConstantR0<int32>(builder, n), swaps_shape);
// Generate range(n) as the initial value for the indices to be swapped.
- xla::XlaOp indices;
- TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, n, &indices));
+ xla::XlaOp indices = xla::Iota(builder, xla::S32, n);
// Swap the indices at i and swaps[i].
auto swap_body_fn = [&](xla::XlaOp i,
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
index 16491002b4..c810456f94 100644
--- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -165,9 +166,8 @@ class ReverseSequenceOp : public XlaOpKernel {
auto output = xla::GetTupleElement(loop_output, 2);
// Mask out elements after the sequence length.
- xla::XlaOp iota;
- OP_REQUIRES_OK(
- context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota));
+ xla::XlaOp iota =
+ xla::Iota(builder, seq_lens_xla_shape.element_type(), max_seq_len);
std::vector<int64> dims(input_shape.dims(), 1);
dims[batch_dim_] = batch_size;
auto mask = xla::Lt(iota, xla::Reshape(seq_lens, dims), {seq_dim_});
diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
index 3b19f8d872..50a455b520 100644
--- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@@ -127,7 +128,7 @@ xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed,
// Fill the generator inputs with unique counter values.
ThreeFry2x32State inputs;
- TF_CHECK_OK(XlaHelpers::Iota(builder, DT_INT32, half_size, &inputs[0]));
+ inputs[0] = xla::Iota(builder, xla::S32, half_size);
inputs[1] = xla::Add(inputs[0], xla::ConstantR0<int32>(builder, half_size));
ThreeFry2x32State outputs = ThreeFry2x32(builder, inputs, key);
diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
index beb7cf263d..8a1377fc38 100644
--- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -62,8 +63,7 @@ class TopKOp : public XlaOpKernel {
k = input_shape.dim_size(0);
}
const xla::XlaOp input_bf16 = context->Input(0);
- xla::XlaOp iota_s32;
- OP_REQUIRES_OK(context, XlaHelpers::Iota(b, DT_INT32, n, &iota_s32));
+ xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, n);
// TODO(b/73891930): add a key-value sort to HLO, rather than using
// bit-packing tricks here.
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
index 3823f5c087..e996916461 100644
--- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -118,7 +118,7 @@ XLAJIT_MAKE_UNARY(Inv, xla::Div(XlaHelpers::One(b, input_type(0)), x));
XLAJIT_MAKE_UNARY(Reciprocal, xla::Div(XlaHelpers::One(b, input_type(0)), x));
XLAJIT_MAKE_UNARY(Log, xla::Log(x));
-XLAJIT_MAKE_UNARY(Log1p, b->Log1p(x));
+XLAJIT_MAKE_UNARY(Log1p, xla::Log1p(x));
XLAJIT_MAKE_UNARY(Invert, xla::Not(x));
XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x));
@@ -172,7 +172,7 @@ XLAJIT_MAKE_UNARY(Sinh,
// max(x, 0) + log1p(exp(-abs(x)))
XLAJIT_MAKE_UNARY(Softplus,
xla::Add(xla::Max(x, XlaHelpers::Zero(b, input_type(0))),
- b->Log1p(xla::Exp(xla::Neg(xla::Abs(x))))));
+ xla::Log1p(xla::Exp(xla::Neg(xla::Abs(x))))));
// softsign(x) = x / (abs(x) + 1)
XLAJIT_MAKE_UNARY(Softsign,
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
index dd29bafcd9..f9f3a8c8cf 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
@@ -26,92 +26,94 @@ limitations under the License.
namespace tensorflow {
-xla::StatusOr<xla::XlaOp> BatchDot(xla::XlaBuilder* builder, xla::XlaOp x,
- xla::XlaOp y, bool transpose_x,
- bool transpose_y, bool conjugate_x,
- bool conjugate_y) {
- TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x));
- TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y));
-
- // Check that both tensors have the same number of dimensions. There must be
- // at least two (the batch dimensions can be empty).
- if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) {
- return errors::InvalidArgument(
- "Arguments to BatchedDot have different ranks: ",
- xla::ShapeUtil::HumanString(x_shape), " vs. ",
- xla::ShapeUtil::HumanString(y_shape));
- }
- const int ndims = xla::ShapeUtil::Rank(x_shape);
- if (ndims < 2) {
- return errors::InvalidArgument(
- "Arguments to BatchedDot must have rank >= 2: ", ndims);
- }
-
- // The batch dimensions must be equal and the matrix dimensions must be
- // valid.
- std::vector<int64> batch_dimension_numbers;
- for (int i = 0; i < ndims - 2; ++i) {
- if (x_shape.dimensions(i) != y_shape.dimensions(i)) {
+xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x,
+ bool transpose_y, bool conjugate_x, bool conjugate_y) {
+ xla::XlaBuilder* builder = x.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x));
+ TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y));
+
+ // Check that both tensors have the same number of dimensions. There must be
+ // at least two (the batch dimensions can be empty).
+ if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) {
return errors::InvalidArgument(
- "Dimension ", i, " of inputs to BatchedDot must be equal: ",
- xla::ShapeUtil::HumanString(x_shape), " vs ",
+ "Arguments to BatchedDot have different ranks: ",
+ xla::ShapeUtil::HumanString(x_shape), " vs. ",
xla::ShapeUtil::HumanString(y_shape));
}
- batch_dimension_numbers.push_back(i);
- }
-
- int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1);
- int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2);
- if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) {
- return errors::InvalidArgument(
- "Dimensions ", x_inner_dim, " and ", y_inner_dim,
- " of arguments to BatchedDot must be equal: ",
- xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x,
- " vs. ", xla::ShapeUtil::HumanString(y_shape),
- " transpose: ", transpose_y);
- }
-
- // Check for zero lhs/rhs dim size.
- if (xla::ShapeUtil::IsZeroElementArray(x_shape) ||
- xla::ShapeUtil::IsZeroElementArray(y_shape)) {
- std::vector<int64> dimensions(batch_dimension_numbers.size());
- for (int i = 0; i < batch_dimension_numbers.size(); ++i) {
- dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]);
+ const int ndims = xla::ShapeUtil::Rank(x_shape);
+ if (ndims < 2) {
+ return errors::InvalidArgument(
+ "Arguments to BatchedDot must have rank >= 2: ", ndims);
+ }
+
+ // The batch dimensions must be equal and the matrix dimensions must be
+ // valid.
+ std::vector<int64> batch_dimension_numbers;
+ for (int i = 0; i < ndims - 2; ++i) {
+ if (x_shape.dimensions(i) != y_shape.dimensions(i)) {
+ return errors::InvalidArgument(
+ "Dimension ", i, " of inputs to BatchedDot must be equal: ",
+ xla::ShapeUtil::HumanString(x_shape), " vs ",
+ xla::ShapeUtil::HumanString(y_shape));
+ }
+ batch_dimension_numbers.push_back(i);
+ }
+
+ int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1);
+ int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2);
+ if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) {
+ return errors::InvalidArgument(
+ "Dimensions ", x_inner_dim, " and ", y_inner_dim,
+ " of arguments to BatchedDot must be equal: ",
+ xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x,
+ " vs. ", xla::ShapeUtil::HumanString(y_shape),
+ " transpose: ", transpose_y);
+ }
+
+ // Check for zero lhs/rhs dim size.
+ if (xla::ShapeUtil::IsZeroElementArray(x_shape) ||
+ xla::ShapeUtil::IsZeroElementArray(y_shape)) {
+ std::vector<int64> dimensions(batch_dimension_numbers.size());
+ for (int i = 0; i < batch_dimension_numbers.size(); ++i) {
+ dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]);
+ }
+ int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2);
+ int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1);
+ dimensions.push_back(x_shape.dimensions(x_outer_dim));
+ dimensions.push_back(y_shape.dimensions(y_outer_dim));
+ return xla::Broadcast(
+ xla::ConstantLiteral(builder,
+ xla::Literal::Zero(x_shape.element_type())),
+ dimensions);
+ }
+
+ if (x_shape.element_type() == xla::C64 && conjugate_x) {
+ x = xla::Conj(x);
+ }
+ if (y_shape.element_type() == xla::C64 && conjugate_y) {
+ y = xla::Conj(y);
+ }
+
+ // If there are no batch dimensions, use a regular Dot.
+ // TODO(b/69062148) Remove this code when Dot emitters can be passed
+ // dimensions to transpose directly (i.e. without requiring a Transpose
+ // HLO).
+ if (batch_dimension_numbers.empty()) {
+ auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x;
+ auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y;
+ return xla::Dot(lhs, rhs);
+ }
+
+ xla::DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(x_inner_dim);
+ dot_dnums.add_rhs_contracting_dimensions(y_inner_dim);
+ for (auto batch_dimension_number : batch_dimension_numbers) {
+ dot_dnums.add_lhs_batch_dimensions(batch_dimension_number);
+ dot_dnums.add_rhs_batch_dimensions(batch_dimension_number);
}
- int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2);
- int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1);
- dimensions.push_back(x_shape.dimensions(x_outer_dim));
- dimensions.push_back(y_shape.dimensions(y_outer_dim));
- return xla::Broadcast(
- xla::ConstantLiteral(builder,
- xla::Literal::Zero(x_shape.element_type())),
- dimensions);
- }
-
- if (x_shape.element_type() == xla::C64 && conjugate_x) {
- x = xla::Conj(x);
- }
- if (y_shape.element_type() == xla::C64 && conjugate_y) {
- y = xla::Conj(y);
- }
-
- // If there are no batch dimensions, use a regular Dot.
- // TODO(b/69062148) Remove this code when Dot emitters can be passed
- // dimensions to transpose directly (i.e. without requiring a Transpose HLO).
- if (batch_dimension_numbers.empty()) {
- auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x;
- auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y;
- return xla::Dot(lhs, rhs);
- }
-
- xla::DotDimensionNumbers dot_dnums;
- dot_dnums.add_lhs_contracting_dimensions(x_inner_dim);
- dot_dnums.add_rhs_contracting_dimensions(y_inner_dim);
- for (auto batch_dimension_number : batch_dimension_numbers) {
- dot_dnums.add_lhs_batch_dimensions(batch_dimension_number);
- dot_dnums.add_rhs_batch_dimensions(batch_dimension_number);
- }
- return xla::DotGeneral(x, y, dot_dnums);
+ return xla::DotGeneral(x, y, dot_dnums);
+ });
}
} // namespace tensorflow
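
The refactoring above is the general pattern this change applies to the tf2xla linear-algebra helpers: validation still happens inside a StatusOr-returning lambda, and XlaBuilder::ReportErrorOrReturn turns any failure into an error recorded on the builder, so the public signature can stay a plain xla::XlaOp. A minimal standalone sketch (hypothetical helper; it assumes the same ShapeUtil and TF_ASSIGN_OR_RETURN includes that batch_dot.cc already pulls in):

#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
// Hypothetical helper: pass x through unchanged, but report an error on the
// builder if its two minor dimensions are not equal.
xla::XlaOp RequireSquareMinorDims(xla::XlaOp x) {
  xla::XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
    TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
    if (xla::ShapeUtil::GetDimension(shape, -1) !=
        xla::ShapeUtil::GetDimension(shape, -2)) {
      return errors::InvalidArgument("Expected square minor dimensions: ",
                                     xla::ShapeUtil::HumanString(shape));
    }
    return x;
  });
}
}  // namespace tensorflow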
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h
index 1acc72033b..d07a9486f1 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.h
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h
@@ -43,10 +43,9 @@ namespace tensorflow {
// It is computed as:
//
// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
-xla::StatusOr<xla::XlaOp> BatchDot(xla::XlaBuilder* builder, xla::XlaOp x,
- xla::XlaOp y, bool transpose_x,
- bool transpose_y, bool conjugate_x = false,
- bool conjugate_y = false);
+xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x = false,
+ bool transpose_y = false, bool conjugate_x = false,
+ bool conjugate_y = false);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
index 397f0e3a72..a90178c7d9 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -48,173 +48,163 @@ namespace {
// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) /
// l[..., j, j]
// return l
-xla::StatusOr<xla::XlaOp> CholeskyUnblocked(xla::XlaBuilder* builder,
- const xla::XlaOp& a) {
- TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
- const int n_dims = xla::ShapeUtil::Rank(a_shape);
- const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(a_shape.dimensions()),
- /*pos=*/0,
- /*len=*/n_dims - 2);
-
- xla::XlaOp l = Zeros(builder, a_shape);
-
- // Construct the for loop body to iterate over rows.
- auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
- xla::XlaBuilder* body_builder)
- -> xla::StatusOr<std::vector<xla::XlaOp>> {
- xla::Shape col_shape;
- xla::Shape row_shape;
- for (int64 d : major_dims) {
- row_shape.add_dimensions(d);
- col_shape.add_dimensions(d);
- }
- row_shape.add_dimensions(1);
- row_shape.add_dimensions(n);
- row_shape.set_element_type(a_shape.element_type());
- auto mask_zeros_row = Zeros(body_builder, row_shape);
-
- col_shape.add_dimensions(n);
- col_shape.add_dimensions(1);
- col_shape.set_element_type(a_shape.element_type());
- auto mask_zeros_col = Zeros(body_builder, col_shape);
-
- std::vector<int32> mask_vector(n);
- std::iota(mask_vector.begin(), mask_vector.end(), 0);
- auto mask_range = xla::ConstantR1<int32>(body_builder, mask_vector);
- auto mask_range_row =
- xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims);
- auto mask_range_col =
- xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims);
- auto body_a = loop_vars[0];
- auto body_l = loop_vars[1];
-
- // row = l[..., i, :i]
- // select the whole i-th row, then mask out all columns past i-1
- auto zero = xla::ConstantR0<int32>(body_builder, 0);
- TF_ASSIGN_OR_RETURN(auto l_i, DynamicSliceInMinorDims(body_builder, body_l,
- {i, zero}, {1, n}));
- auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i);
- // a[..., i, i]
- TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(body_builder, body_a,
- {i, i}, {1, 1}));
- // np.dot(row, np.swapaxes(row, -1, -2))
- xla::XlaOp diag_dot;
- TF_ASSIGN_OR_RETURN(diag_dot, BatchDot(body_builder, row, row,
- /*transpose_x=*/false,
- /*transpose_y=*/true));
- // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row,
- // np.swapaxes(row, -1, -2)))
- auto l_ii =
- xla::Pow(xla::Sub(a_ii, diag_dot),
- FloatLiteral(body_builder, a_shape.element_type(), 0.5));
-
- // a[..., i+1:, i]
- // select the whole i-th column, then mask out all rows above i+1
+xla::XlaOp CholeskyUnblocked(xla::XlaOp a) {
+ xla::XlaBuilder* builder = a.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ const int n_dims = xla::ShapeUtil::Rank(a_shape);
+ const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+ gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(a_shape.dimensions()),
+ /*pos=*/0,
+ /*len=*/n_dims - 2);
+
+ xla::XlaOp l = Zeros(builder, a_shape);
+
+ // Construct the for loop body to iterate over rows.
+ auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
+ xla::XlaBuilder* body_builder)
+ -> xla::StatusOr<std::vector<xla::XlaOp>> {
+ xla::Shape col_shape;
+ xla::Shape row_shape;
+ for (int64 d : major_dims) {
+ row_shape.add_dimensions(d);
+ col_shape.add_dimensions(d);
+ }
+ row_shape.add_dimensions(1);
+ row_shape.add_dimensions(n);
+ row_shape.set_element_type(a_shape.element_type());
+ auto mask_zeros_row = Zeros(body_builder, row_shape);
+
+ col_shape.add_dimensions(n);
+ col_shape.add_dimensions(1);
+ col_shape.set_element_type(a_shape.element_type());
+ auto mask_zeros_col = Zeros(body_builder, col_shape);
+
+ std::vector<int32> mask_vector(n);
+ std::iota(mask_vector.begin(), mask_vector.end(), 0);
+ auto mask_range = xla::ConstantR1<int32>(body_builder, mask_vector);
+ auto mask_range_row =
+ xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims);
+ auto mask_range_col =
+ xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims);
+ auto body_a = loop_vars[0];
+ auto body_l = loop_vars[1];
+
+ // row = l[..., i, :i]
+ // select the whole i-th row, then mask out all columns past i-1
+ auto zero = xla::ConstantR0<int32>(body_builder, 0);
+ auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n});
+ auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i);
+ // a[..., i, i]
+ auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1});
+ // np.dot(row, np.swapaxes(row, -1, -2))
+ auto diag_dot = BatchDot(row, row,
+ /*transpose_x=*/false,
+ /*transpose_y=*/true);
+ // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row,
+ // np.swapaxes(row, -1, -2)))
+ auto l_ii =
+ xla::Pow(a_ii - diag_dot,
+ FloatLiteral(body_builder, a_shape.element_type(), 0.5));
+
+ // a[..., i+1:, i]
+ // select the whole i-th column, then mask out all rows above i+1
+ auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1});
+ auto a_ip1i =
+ xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i);
+
+ // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) /
+ // l[..., i, i]
+ // The columns in [i, n] are zeroed out in `row`, so we just have to
+ // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i],
+ // r.T)
+ auto dot = BatchDot(body_l, row,
+ /*transpose_x=*/false,
+ /*transpose_y=*/true);
+ // np.dot(l[..., i+1:, :i], r.T)
+ auto dot_ip1 =
+ xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot);
+
+ body_l =
+ DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i});
+ // Assign the diagonal after the rest of the column because otherwise the
+ // column assign will wrap around and overwrite the diagonal assign.
+ body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i});
+
+ return std::vector<xla::XlaOp>{body_a, body_l};
+ };
+
TF_ASSIGN_OR_RETURN(
- auto a_0i, DynamicSliceInMinorDims(body_builder, body_a, {i}, {1}));
- auto a_ip1i = xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i);
-
- // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) /
- // l[..., i, i]
- // The columns in [i, n] are zeroed out in `row`, so we just have to
- // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i],
- // r.T)
- TF_ASSIGN_OR_RETURN(auto dot, BatchDot(body_builder, body_l, row,
- /*transpose_x=*/false,
- /*transpose_y=*/true));
- // np.dot(l[..., i+1:, :i], r.T)
- auto dot_ip1 = xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot);
-
- auto col_update = xla::Div(xla::Sub(a_ip1i, dot_ip1), l_ii);
- TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims(
- body_builder, body_l, col_update, {i}));
- // Assign the diagonal after the rest of the column because otherwise the
- // column assign will wrap around and overwrite the diagonal assign.
- TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims(
- body_builder, body_l, l_ii, {i, i}));
-
- return std::vector<xla::XlaOp>{body_a, body_l};
- };
-
- TF_ASSIGN_OR_RETURN(
- auto cholesky_while,
- XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder));
-
- return cholesky_while[1];
+ auto cholesky_while,
+ XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder));
+
+ return cholesky_while[1];
+ });
}
} // namespace
-xla::StatusOr<xla::XlaOp> Cholesky(xla::XlaBuilder* builder, xla::XlaOp a,
- int64 block_size) {
- TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
- const int ndims = xla::ShapeUtil::Rank(a_shape);
- if (ndims < 2) {
- return errors::InvalidArgument(
- "Arguments to Cholesky must have rank >= 2: ", ndims);
- }
-
- const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
- if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) {
- return errors::InvalidArgument(
- "Arguments to Cholesky must be square matrices: ",
- xla::ShapeUtil::HumanString(a_shape));
- }
-
- if (block_size < 1) {
- return errors::InvalidArgument(
- "block_size argument to Cholesky must be >= 1; got ", block_size);
- }
-
- // Blocked left-looking Cholesky factorization.
- // Algorithm 1 from
- // Haidar, Azzam, et al. "High-performance Cholesky factorization for GPU-only
- // execution." Proceedings of General Purpose GPUs. ACM, 2017.
- xla::XlaOp l = Zeros(builder, a_shape);
- for (int64 i = 0; i < n; i += block_size) {
- int64 k = std::min(block_size, n - i);
- if (i > 0) {
- // TODO(phawkins): consider implementing SYRK for the diagonal part of
- // the panel.
- // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i]))
- TF_ASSIGN_OR_RETURN(auto lhs,
- SliceInMinorDims(builder, l, {i, 0}, {n, i}));
- TF_ASSIGN_OR_RETURN(auto rhs,
- SliceInMinorDims(builder, l, {i, 0}, {i + k, i}));
- TF_ASSIGN_OR_RETURN(auto delta,
- BatchDot(builder, lhs, rhs, /*transpose_x=*/false,
- /*transpose_y=*/true, /*conjugate_x=*/false,
- /*conjugate_y=*/false));
- TF_ASSIGN_OR_RETURN(auto before,
- SliceInMinorDims(builder, a, {i, i}, {n, i + k}));
- TF_ASSIGN_OR_RETURN(a, UpdateSliceInMinorDims(
- builder, a, xla::Sub(before, delta), {i, i}));
+xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) {
+ xla::XlaBuilder* builder = a.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ const int ndims = xla::ShapeUtil::Rank(a_shape);
+ if (ndims < 2) {
+ return errors::InvalidArgument(
+ "Arguments to Cholesky must have rank >= 2: ", ndims);
+ }
+
+ const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+ if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) {
+ return errors::InvalidArgument(
+ "Arguments to Cholesky must be square matrices: ",
+ xla::ShapeUtil::HumanString(a_shape));
+ }
+
+ if (block_size < 1) {
+ return errors::InvalidArgument(
+ "block_size argument to Cholesky must be >= 1; got ", block_size);
}
- // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k])
- TF_ASSIGN_OR_RETURN(auto x,
- SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
- TF_ASSIGN_OR_RETURN(auto factorized, CholeskyUnblocked(builder, x));
- TF_ASSIGN_OR_RETURN(l,
- UpdateSliceInMinorDims(builder, l, factorized, {i, i}));
-
- if (i + k < n) {
- // l[i+k:, i:i+k] = trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k])
- TF_ASSIGN_OR_RETURN(auto panel,
- SliceInMinorDims(builder, a, {i + k, i}, {n, i + k}));
- TF_ASSIGN_OR_RETURN(auto update,
- TriangularSolve(builder, factorized, panel,
- /*left_side=*/false,
- /*lower=*/true,
- /*transpose_a=*/true,
- /*conjugate_a=*/false,
- /*block_size=*/block_size));
- TF_ASSIGN_OR_RETURN(
- l, UpdateSliceInMinorDims(builder, l, update, {i + k, i}));
+ // Blocked left-looking Cholesky factorization.
+ // Algorithm 1 from
+ // Haidar, Azzam, et al. "High-performance Cholesky factorization for
+ // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
+ xla::XlaOp l = Zeros(builder, a_shape);
+ for (int64 i = 0; i < n; i += block_size) {
+ int64 k = std::min(block_size, n - i);
+ if (i > 0) {
+ // TODO(phawkins): consider implementing SYRK for the diagonal part of
+ // the panel.
+ // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i]))
+ auto lhs = SliceInMinorDims(l, {i, 0}, {n, i});
+ auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i});
+ auto delta = BatchDot(lhs, rhs, /*transpose_x=*/false,
+ /*transpose_y=*/true);
+ auto before = SliceInMinorDims(a, {i, i}, {n, i + k});
+ a = UpdateSliceInMinorDims(a, before - delta, {i, i});
+ }
+
+ // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k])
+ auto x = SliceInMinorDims(a, {i, i}, {i + k, i + k});
+ auto factorized = CholeskyUnblocked(x);
+ l = UpdateSliceInMinorDims(l, factorized, {i, i});
+
+ if (i + k < n) {
+ // l[i+k:, i:i+k] =
+ // trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k])
+ auto panel = SliceInMinorDims(a, {i + k, i}, {n, i + k});
+ auto update = TriangularSolve(factorized, panel,
+ /*left_side=*/false,
+ /*lower=*/true,
+ /*transpose_a=*/true,
+ /*conjugate_a=*/false,
+ /*block_size=*/block_size);
+ l = UpdateSliceInMinorDims(l, update, {i + k, i});
+ }
}
- }
- return l;
+ return l;
+ });
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h
index 20fca7969e..0f6e0e9d15 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.h
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.h
@@ -30,8 +30,7 @@ namespace tensorflow {
// TODO(phawkins): check for negative values on the diagonal and return an
// error, instead of silently yielding NaNs.
// TODO(znado): handle the complex Hermitian case
-xla::StatusOr<xla::XlaOp> Cholesky(xla::XlaBuilder* builder, xla::XlaOp a,
- int64 block_size = 256);
+xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256);
} // namespace tensorflow
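
A usage sketch, not part of the patch, showing how the value-returning Cholesky composes with the TriangularSolve refactored below: solve the symmetric positive-definite system A x = b by factoring A = L L^T and back-substituting twice. The header paths and helper name are indicative, and block_size is passed explicitly since only Cholesky's default is visible in this diff:

#include "tensorflow/compiler/tf2xla/lib/cholesky.h"
#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h"

namespace tensorflow {
// Hypothetical helper for a real symmetric positive-definite a.
xla::XlaOp SolveSpdSystem(xla::XlaOp a, xla::XlaOp b) {
  xla::XlaOp l = Cholesky(a);  // a = l * l^T; block_size defaults to 256
  // Forward substitution: l * y = b.
  xla::XlaOp y = TriangularSolve(l, b, /*left_side=*/true, /*lower=*/true,
                                 /*transpose_a=*/false, /*conjugate_a=*/false,
                                 /*block_size=*/256);
  // Back substitution: l^T * x = y.
  return TriangularSolve(l, y, /*left_side=*/true, /*lower=*/true,
                         /*transpose_a=*/true, /*conjugate_a=*/false,
                         /*block_size=*/256);
}
}  // namespace tensorflow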
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
index b9f695ac4b..0d3ce129c7 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -30,621 +30,564 @@ limitations under the License.
namespace tensorflow {
-xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
- const xla::XlaOp& a, xla::XlaOp b,
- bool left_side, bool lower,
- bool transpose_a, bool conjugate_a,
- int64 block_size) {
- TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
- TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
- if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) {
- return errors::InvalidArgument(
- "Arguments to TriangularSolve have different ranks: ",
- xla::ShapeUtil::HumanString(a_shape), " vs. ",
- xla::ShapeUtil::HumanString(b_shape));
- }
- const int ndims = xla::ShapeUtil::Rank(a_shape);
- if (ndims < 2) {
- return errors::InvalidArgument(
- "Arguments to TriangularSolve must have rank >= 2: ", ndims);
- }
- // The batch dimensions must be equal.
- std::vector<int64> batch_dimensions;
- for (int i = 0; i < ndims - 2; ++i) {
- int64 a_size = a_shape.dimensions(i);
- int64 b_size = b_shape.dimensions(i);
- if (a_size != b_size) {
+xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side,
+ bool lower, bool transpose_a, bool conjugate_a,
+ int64 block_size) {
+ xla::XlaBuilder* builder = a.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
+ if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) {
return errors::InvalidArgument(
- "Batch dimensions of arguments to TriangularSolve must be equal: ",
- xla::ShapeUtil::HumanString(a_shape), " vs ",
+ "Arguments to TriangularSolve have different ranks: ",
+ xla::ShapeUtil::HumanString(a_shape), " vs. ",
xla::ShapeUtil::HumanString(b_shape));
}
- batch_dimensions.push_back(a_size);
- }
-
- if (xla::ShapeUtil::GetDimension(a_shape, -1) !=
- xla::ShapeUtil::GetDimension(a_shape, -2)) {
- return errors::InvalidArgument(
- "The 'a' arguments to TriangularSolve must be square matrices: ",
- xla::ShapeUtil::HumanString(a_shape));
- }
- const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
- const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
- if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) {
- return errors::InvalidArgument(
- "Arguments to TriangularSolve have incompatible matrix shapes: ",
- xla::ShapeUtil::HumanString(a_shape), " vs ",
- xla::ShapeUtil::HumanString(b_shape));
- }
-
- if (block_size < 1) {
- return errors::InvalidArgument(
- "block_size argument to TriangularSolve must be >= 1; got ",
- block_size);
- }
-
- std::map<int, xla::XlaComputation> base_computations;
- auto get_base_triangular_solve =
- [&](int k) -> xla::StatusOr<xla::XlaComputation*> {
- xla::XlaComputation& computation = base_computations[k];
- if (computation.IsNull()) {
- std::unique_ptr<xla::XlaBuilder> sub = builder->CreateSubBuilder(
- tensorflow::strings::StrCat("trsm_base_", k));
-
- auto a_param = xla::Parameter(
- sub.get(), 0,
- xla::ShapeUtil::MakeShape(
- b_shape.element_type(),
- PrependMajorDims(sub.get(), batch_dimensions, {k, k})),
- "a");
-
- std::array<int64, 2> b_lastd;
- if (left_side) {
- b_lastd = {k, n};
- } else {
- b_lastd = {m, k};
- }
- auto b_param = xla::Parameter(
- sub.get(), 1,
- xla::ShapeUtil::MakeShape(
- b_shape.element_type(),
- PrependMajorDims(sub.get(), batch_dimensions, b_lastd)),
- "b");
-
- // We use a left-looking or right-looking subroutine on the block diagonal
- // in the lower=true cases, while falling back to a recursive call in
- // others. The left-looking and right-looking subroutines are written with
- // a While loop and so yields much faster compile times. Moreover, they
- // can give higher performance on smaller (sub)problems.
- if (left_side && lower) {
- TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param,
- b_param, transpose_a,
- conjugate_a)
- .status());
- } else if (!left_side && lower) {
- TF_RETURN_IF_ERROR(TriangularSolveRightLooking(sub.get(), a_param,
- b_param, transpose_a,
- conjugate_a)
- .status());
- } else {
- TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param,
- left_side, lower, transpose_a,
- conjugate_a,
- /*block_size=*/1)
- .status());
+ const int ndims = xla::ShapeUtil::Rank(a_shape);
+ if (ndims < 2) {
+ return errors::InvalidArgument(
+ "Arguments to TriangularSolve must have rank >= 2: ", ndims);
+ }
+ // The batch dimensions must be equal.
+ std::vector<int64> batch_dimensions;
+ for (int i = 0; i < ndims - 2; ++i) {
+ int64 a_size = a_shape.dimensions(i);
+ int64 b_size = b_shape.dimensions(i);
+ if (a_size != b_size) {
+ return errors::InvalidArgument(
+ "Batch dimensions of arguments to TriangularSolve must be equal: ",
+ xla::ShapeUtil::HumanString(a_shape), " vs ",
+ xla::ShapeUtil::HumanString(b_shape));
}
+ batch_dimensions.push_back(a_size);
+ }
- TF_ASSIGN_OR_RETURN(computation, sub->Build());
+ if (xla::ShapeUtil::GetDimension(a_shape, -1) !=
+ xla::ShapeUtil::GetDimension(a_shape, -2)) {
+ return errors::InvalidArgument(
+ "The 'a' arguments to TriangularSolve must be square matrices: ",
+ xla::ShapeUtil::HumanString(a_shape));
+ }
+ const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
+ const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
+ if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) {
+ return errors::InvalidArgument(
+ "Arguments to TriangularSolve have incompatible matrix shapes: ",
+ xla::ShapeUtil::HumanString(a_shape), " vs ",
+ xla::ShapeUtil::HumanString(b_shape));
}
- return &computation;
- };
-
- xla::XlaOp output = Zeros(builder, b_shape);
-
- // Right-looking blocked triangular solve.
- // For an explanation of the algorithm, see the TRSM discussion in:
- // Goto, Kazushige, and Robert Van De Geijn. "High-performance implementation
- // of the level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1
- // (2008): 4.
-
- // In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if
- // conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if
- // conjugate_a is True.
-
- if (!left_side && lower == transpose_a) {
- // for i in range(0, a.shape[-1], block_size):
- for (int64 i = 0; i < n; i += block_size) {
- int64 k = std::min(block_size, n - i);
-
- // output[..., :, i:i+k] = triangular_solve(
- // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1)
- TF_ASSIGN_OR_RETURN(auto a_slice,
- SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
- TF_ASSIGN_OR_RETURN(auto b_slice,
- SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
- xla::XlaOp update;
- if (k > 1) {
- TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
- get_base_triangular_solve(k));
- update = xla::Call(builder, *solve, {a_slice, b_slice});
- } else {
- TF_ASSIGN_OR_RETURN(auto a_slice_conj,
- MaybeConjugate(builder, a_slice, conjugate_a));
- update = xla::Div(b_slice, a_slice_conj);
- }
- TF_ASSIGN_OR_RETURN(
- output, UpdateSliceInMinorDims(builder, output, update, {0, i}));
-
- // if i + k < a.shape[-1]:
- // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:]
- // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
- // b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2)
- if (i + k < n) {
- xla::XlaOp a_slice_2;
- if (lower) {
- TF_ASSIGN_OR_RETURN(
- a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k}));
- } else {
- TF_ASSIGN_OR_RETURN(
- a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, n}));
- }
- TF_ASSIGN_OR_RETURN(auto b_update,
- BatchDot(builder, update, a_slice_2,
- /*transpose_x=*/false,
- /*transpose_y=*/transpose_a,
- /*conjugate_x=*/false,
- /*conjugate_y=*/conjugate_a));
- TF_ASSIGN_OR_RETURN(auto b_slice_2,
- SliceInMinorDims(builder, b, {0, i + k}, {m, n}));
- b_update = xla::Sub(b_slice_2, b_update);
- TF_ASSIGN_OR_RETURN(
- b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k}));
- }
+ if (block_size < 1) {
+ return errors::InvalidArgument(
+ "block_size argument to TriangularSolve must be >= 1; got ",
+ block_size);
}
- } else if (left_side && lower != transpose_a) {
- // for i in range(0, a.shape[-1], block_size):
- for (int64 i = 0; i < m; i += block_size) {
- int64 k = std::min(block_size, m - i);
-
- // output[..., i:i+k, :] = triangular_solve(
- // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1)
- TF_ASSIGN_OR_RETURN(auto a_slice,
- SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
- TF_ASSIGN_OR_RETURN(auto b_slice,
- SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
- xla::XlaOp update;
- if (k > 1) {
- TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
- get_base_triangular_solve(k));
- update = xla::Call(builder, *solve, {a_slice, b_slice});
- } else {
- TF_ASSIGN_OR_RETURN(auto a_slice_conj,
- MaybeConjugate(builder, a_slice, conjugate_a));
- update = xla::Div(b_slice, a_slice_conj);
- }
- TF_ASSIGN_OR_RETURN(
- output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
-
- // if i + k < a.shape[-1]:
- // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:]
- // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
- // b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :])
- if (i + k < m) {
- xla::XlaOp a_slice_2;
- if (lower) {
- TF_ASSIGN_OR_RETURN(
- a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k}));
+ std::map<int, xla::XlaComputation> base_computations;
+ auto get_base_triangular_solve =
+ [&](int k) -> xla::StatusOr<xla::XlaComputation*> {
+ xla::XlaComputation& computation = base_computations[k];
+ if (computation.IsNull()) {
+ std::unique_ptr<xla::XlaBuilder> sub = builder->CreateSubBuilder(
+ tensorflow::strings::StrCat("trsm_base_", k));
+
+ auto a_param = xla::Parameter(
+ sub.get(), 0,
+ xla::ShapeUtil::MakeShape(b_shape.element_type(),
+ ConcatVectors(batch_dimensions, {k, k})),
+ "a");
+
+ std::array<int64, 2> b_lastd;
+ if (left_side) {
+ b_lastd = {k, n};
+ } else {
+ b_lastd = {m, k};
+ }
+ auto b_param = xla::Parameter(
+ sub.get(), 1,
+ xla::ShapeUtil::MakeShape(b_shape.element_type(),
+ ConcatVectors(batch_dimensions, b_lastd)),
+ "b");
+
+ // We use a left-looking or right-looking subroutine on the block
+ // diagonal in the lower=true cases, while falling back to a recursive
+ // call in others. The left-looking and right-looking subroutines are
+      // written with a While loop and so yield much faster compile times.
+ // Moreover, they can give higher performance on smaller (sub)problems.
+ if (left_side && lower) {
+ TriangularSolveLeftLooking(a_param, b_param, transpose_a,
+ conjugate_a);
+ } else if (!left_side && lower) {
+ TriangularSolveRightLooking(a_param, b_param, transpose_a,
+ conjugate_a);
} else {
- TF_ASSIGN_OR_RETURN(
- a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, m}));
+ TriangularSolve(a_param, b_param, left_side, lower, transpose_a,
+ conjugate_a,
+ /*block_size=*/1);
}
- TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update,
- /*transpose_x=*/transpose_a,
- /*transpose_y=*/false,
- /*conjugate_x=*/conjugate_a,
- /*conjugate_y=*/false));
- TF_ASSIGN_OR_RETURN(auto b_slice_2,
- SliceInMinorDims(builder, b, {i + k, 0}, {m, n}));
- b_update = xla::Sub(b_slice_2, b_update);
- TF_ASSIGN_OR_RETURN(
- b, UpdateSliceInMinorDims(builder, b, b_update, {i + k, 0}));
- }
- }
- } else if (!left_side && lower != transpose_a) {
- // for i in reversed(range(0, a.shape[-1], block_size)):
- const int64 last_blk_ix = xla::RoundUpToNearest(n, block_size) - block_size;
- for (int64 i = last_blk_ix; i >= 0; i -= block_size) {
- int64 k = std::min(block_size, n - i);
-
- // output[..., :, i:i+k] triangular_solve(
- // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1)
- TF_ASSIGN_OR_RETURN(auto a_slice,
- SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
- TF_ASSIGN_OR_RETURN(auto b_slice,
- SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
- xla::XlaOp update;
- if (k > 1) {
- TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
- get_base_triangular_solve(k));
- update = xla::Call(builder, *solve, {a_slice, b_slice});
- } else {
- TF_ASSIGN_OR_RETURN(auto a_slice_conj,
- MaybeConjugate(builder, a_slice, conjugate_a));
- update = xla::Div(b_slice, a_slice_conj);
+ TF_ASSIGN_OR_RETURN(computation, sub->Build());
}
- TF_ASSIGN_OR_RETURN(
- output, UpdateSliceInMinorDims(builder, output, update, {0, i}));
-
- // if i - k >= 0:
- // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k]
- // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
- // b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2)
- if (i - k >= 0) {
- xla::XlaOp a_slice_2;
- if (lower) {
- TF_ASSIGN_OR_RETURN(a_slice_2,
- SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
+ return &computation;
+ };
+
+ xla::XlaOp output = Zeros(builder, b_shape);
+
+ // Right-looking blocked triangular solve.
+ // For an explanation of the algorithm, see the TRSM discussion in:
+ // Goto, Kazushige, and Robert Van De Geijn. "High-performance
+ // implementation of the level-3 BLAS." ACM Transactions on Mathematical
+ // Software (TOMS) 35.1 (2008): 4.
+
+ // In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if
+ // conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if
+ // conjugate_a is True.
+
+ if (!left_side && lower == transpose_a) {
+ // for i in range(0, a.shape[-1], block_size):
+ for (int64 i = 0; i < n; i += block_size) {
+ int64 k = std::min(block_size, n - i);
+
+ // output[..., :, i:i+k] = triangular_solve(
+ // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1)
+ auto a_slice = SliceInMinorDims(a, {i, i}, {i + k, i + k});
+ auto b_slice = SliceInMinorDims(b, {0, i}, {m, i + k});
+ xla::XlaOp update;
+ if (k > 1) {
+ TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
+ get_base_triangular_solve(k));
+ update = xla::Call(builder, *solve, {a_slice, b_slice});
} else {
- TF_ASSIGN_OR_RETURN(a_slice_2,
- SliceInMinorDims(builder, a, {0, i}, {i, i + k}));
+ auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a);
+ update = b_slice / a_slice_conj;
}
+ output = UpdateSliceInMinorDims(output, update, {0, i});
+
+ // if i + k < a.shape[-1]:
+ // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:]
+ // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
+ // b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2)
+ if (i + k < n) {
+ xla::XlaOp a_slice_2;
+ if (lower) {
+ a_slice_2 = SliceInMinorDims(a, {i + k, i}, {n, i + k});
+ } else {
+ a_slice_2 = SliceInMinorDims(a, {i, i + k}, {i + k, n});
+ }
+
+ auto b_update = BatchDot(update, a_slice_2,
+ /*transpose_x=*/false,
+ /*transpose_y=*/transpose_a,
+ /*conjugate_x=*/false,
+ /*conjugate_y=*/conjugate_a);
+ auto b_slice_2 = SliceInMinorDims(b, {0, i + k}, {m, n});
+ b = UpdateSliceInMinorDims(b, b_slice_2 - b_update, {0, i + k});
+ }
+ }
- TF_ASSIGN_OR_RETURN(auto b_update,
- BatchDot(builder, update, a_slice_2,
- /*transpose_x=*/false,
- /*transpose_y=*/transpose_a,
- /*conjugate_x=*/false,
- /*conjugate_y=*/conjugate_a));
- TF_ASSIGN_OR_RETURN(auto b_slice_2,
- SliceInMinorDims(builder, b, {0, 0}, {m, i}));
- b_update = xla::Sub(b_slice_2, b_update);
- TF_ASSIGN_OR_RETURN(
- b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0}));
+ } else if (left_side && lower != transpose_a) {
+ // for i in range(0, a.shape[-1], block_size):
+ for (int64 i = 0; i < m; i += block_size) {
+ int64 k = std::min(block_size, m - i);
+
+ // output[..., i:i+k, :] = triangular_solve(
+ // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1)
+ auto a_slice = SliceInMinorDims(a, {i, i}, {i + k, i + k});
+ auto b_slice = SliceInMinorDims(b, {i, 0}, {i + k, n});
+ xla::XlaOp update;
+ if (k > 1) {
+ TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
+ get_base_triangular_solve(k));
+ update = xla::Call(builder, *solve, {a_slice, b_slice});
+ } else {
+ auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a);
+ update = b_slice / a_slice_conj;
+ }
+ output = UpdateSliceInMinorDims(output, update, {i, 0});
+
+ // if i + k < a.shape[-1]:
+ // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:]
+ // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
+ // b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :])
+ if (i + k < m) {
+ xla::XlaOp a_slice_2;
+ if (lower) {
+ a_slice_2 = SliceInMinorDims(a, {i + k, i}, {m, i + k});
+ } else {
+ a_slice_2 = SliceInMinorDims(a, {i, i + k}, {i + k, m});
+ }
+
+ auto b_update = BatchDot(a_slice_2, update,
+ /*transpose_x=*/transpose_a,
+ /*transpose_y=*/false,
+ /*conjugate_x=*/conjugate_a,
+ /*conjugate_y=*/false);
+ auto b_slice_2 = SliceInMinorDims(b, {i + k, 0}, {m, n});
+ b = UpdateSliceInMinorDims(b, b_slice_2 - b_update, {i + k, 0});
+ }
}
- }
- } else { // left_side && lower == transpose_a
- // for i in reversed(range(0, a.shape[-1], block_size)):
- const int64 last_blk_ix = xla::RoundUpToNearest(m, block_size) - block_size;
- for (int64 i = last_blk_ix; i >= 0; i -= block_size) {
- int64 k = std::min(block_size, m - i);
-
- // output[..., i:i+k, :] triangular_solve(
- // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1)
- TF_ASSIGN_OR_RETURN(auto a_slice,
- SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
- TF_ASSIGN_OR_RETURN(auto b_slice,
- SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
- xla::XlaOp update;
- if (k > 1) {
- TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
- get_base_triangular_solve(k));
- update = xla::Call(builder, *solve, {a_slice, b_slice});
- } else {
- TF_ASSIGN_OR_RETURN(auto a_slice_conj,
- MaybeConjugate(builder, a_slice, conjugate_a));
- update = xla::Div(b_slice, a_slice_conj);
+ } else if (!left_side && lower != transpose_a) {
+ // for i in reversed(range(0, a.shape[-1], block_size)):
+ const int64 last_blk_ix =
+ xla::RoundUpToNearest(n, block_size) - block_size;
+ for (int64 i = last_blk_ix; i >= 0; i -= block_size) {
+ int64 k = std::min(block_size, n - i);
+
+        // output[..., :, i:i+k] = triangular_solve(
+ // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1)
+ auto a_slice = SliceInMinorDims(a, {i, i}, {i + k, i + k});
+ auto b_slice = SliceInMinorDims(b, {0, i}, {m, i + k});
+ xla::XlaOp update;
+ if (k > 1) {
+ TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
+ get_base_triangular_solve(k));
+ update = xla::Call(builder, *solve, {a_slice, b_slice});
+ } else {
+ auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a);
+ update = b_slice / a_slice_conj;
+ }
+ output = UpdateSliceInMinorDims(output, update, {0, i});
+
+ // if i - k >= 0:
+ // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k]
+ // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
+ // b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2)
+ if (i - k >= 0) {
+ xla::XlaOp a_slice_2;
+ if (lower) {
+ a_slice_2 = SliceInMinorDims(a, {i, 0}, {i + k, i});
+ } else {
+ a_slice_2 = SliceInMinorDims(a, {0, i}, {i, i + k});
+ }
+
+ auto b_update = BatchDot(update, a_slice_2,
+ /*transpose_x=*/false,
+ /*transpose_y=*/transpose_a,
+ /*conjugate_x=*/false,
+ /*conjugate_y=*/conjugate_a);
+ auto b_slice_2 = SliceInMinorDims(b, {0, 0}, {m, i});
+ b = UpdateSliceInMinorDims(b, b_slice_2 - b_update, {0, 0});
+ }
}
- TF_ASSIGN_OR_RETURN(
- output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
-
- // if i - k >= 0:
- // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k]
- // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
- // b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :])
- if (i - k >= 0) {
- xla::XlaOp a_slice_2;
- if (lower) {
- TF_ASSIGN_OR_RETURN(a_slice_2,
- SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
+ } else { // left_side && lower == transpose_a
+ // for i in reversed(range(0, a.shape[-1], block_size)):
+ const int64 last_blk_ix =
+ xla::RoundUpToNearest(m, block_size) - block_size;
+ for (int64 i = last_blk_ix; i >= 0; i -= block_size) {
+ int64 k = std::min(block_size, m - i);
+
+        // output[..., i:i+k, :] = triangular_solve(
+ // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1)
+ auto a_slice = SliceInMinorDims(a, {i, i}, {i + k, i + k});
+ auto b_slice = SliceInMinorDims(b, {i, 0}, {i + k, n});
+ xla::XlaOp update;
+ if (k > 1) {
+ TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
+ get_base_triangular_solve(k));
+ update = xla::Call(builder, *solve, {a_slice, b_slice});
} else {
- TF_ASSIGN_OR_RETURN(a_slice_2,
- SliceInMinorDims(builder, a, {0, i}, {i, i + k}));
+ auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a);
+ update = b_slice / a_slice_conj;
+ }
+ output = UpdateSliceInMinorDims(output, update, {i, 0});
+
+ // if i - k >= 0:
+ // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k]
+ // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
+ // b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :])
+ if (i - k >= 0) {
+ xla::XlaOp a_slice_2;
+ if (lower) {
+ a_slice_2 = SliceInMinorDims(a, {i, 0}, {i + k, i});
+ } else {
+ a_slice_2 = SliceInMinorDims(a, {0, i}, {i, i + k});
+ }
+
+ auto b_update = BatchDot(a_slice_2, update,
+ /*transpose_x=*/transpose_a,
+ /*transpose_y=*/false,
+ /*conjugate_x=*/conjugate_a,
+ /*conjugate_y=*/false);
+ auto b_slice_2 = SliceInMinorDims(b, {0, 0}, {i, n});
+ b = UpdateSliceInMinorDims(b, b_slice_2 - b_update, {0, 0});
}
-
- TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update,
- /*transpose_x=*/transpose_a,
- /*transpose_y=*/false,
- /*conjugate_x=*/conjugate_a,
- /*conjugate_y=*/false));
- TF_ASSIGN_OR_RETURN(auto b_slice_2,
- SliceInMinorDims(builder, b, {0, 0}, {i, n}));
- b_update = xla::Sub(b_slice_2, b_update);
- TF_ASSIGN_OR_RETURN(
- b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0}));
}
}
- }
- return output;
+ return output;
+ });
}
-xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
- const xla::XlaOp& a,
- const xla::XlaOp& b,
- bool transpose_a,
- bool conjugate_a) {
- TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
- TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
- const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
- const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
- const int64 ndims = xla::ShapeUtil::Rank(a_shape);
-
- std::vector<int64> batch_dimensions;
- for (int i = 0; i < ndims - 2; ++i) {
- int64 a_size = a_shape.dimensions(i);
- batch_dimensions.push_back(a_size);
- }
-
- // The main computation is performed in a While loop.
-
- // Allocate the output and set its first or last row,
- // output = np.zeros_like(b)
- // if transpose_a:
- // output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:]
- // else:
- // output[..., :1, :] = b[..., :1, :] / a[..., :1, :1]
- xla::XlaOp output = Zeros(builder, b_shape);
- {
- auto i = transpose_a ? m - 1 : 0;
- TF_ASSIGN_OR_RETURN(auto a_slice,
- SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1}));
- TF_ASSIGN_OR_RETURN(auto b_slice,
- SliceInMinorDims(builder, b, {i, 0}, {i + 1, n}));
- TF_ASSIGN_OR_RETURN(auto a_slice_conj,
- MaybeConjugate(builder, a_slice, conjugate_a));
- auto update = xla::Div(b_slice, a_slice_conj);
- TF_ASSIGN_OR_RETURN(
- output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
- }
-
- // Construct the initial loop carry tuple,
- // if transpose_a:
- // init = (m-2, output, a, b)
- // else:
- // init = (1, output, a, b)
- std::vector<xla::Shape> tuple_shapes = {
- // The loop iteration counter is a scalar, incremented each iteration.
- xla::ShapeUtil::MakeShape(xla::S32, {}),
- // The output has the shape of b, with one row updated each iteration.
- b_shape,
- // The coefficient matrix a is a loop invariant.
- a_shape,
- // The right-hand-side matrix b is a loop invariant.
- b_shape};
- xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
- auto init_i = xla::ConstantR0<int32>(builder, transpose_a ? m - 2 : 1);
- auto init = xla::Tuple(builder, {init_i, output, a, b});
-
- // Construct the loop condition function,
- // def cond_fun(loop_carry):
- // i, output, a, b = loop_carry
- // return i >= 0 if transpose_a else i < m
- std::unique_ptr<xla::XlaBuilder> condb =
- builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond");
- {
- auto i = xla::GetTupleElement(
- xla::Parameter(condb.get(), 0, tuple_shape,
- "TriangularSolveLeftLookingWhileTuple"),
- 0);
- if (transpose_a) {
- xla::Ge(i, xla::ConstantR0<int32>(condb.get(), 0));
- } else {
- xla::Lt(i, xla::ConstantR0<int32>(condb.get(), m));
+xla::XlaOp TriangularSolveLeftLooking(xla::XlaOp a, xla::XlaOp b,
+ bool transpose_a, bool conjugate_a) {
+ xla::XlaBuilder* builder = a.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
+ const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
+ const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
+ const int64 ndims = xla::ShapeUtil::Rank(a_shape);
+
+ std::vector<int64> batch_dimensions;
+ for (int i = 0; i < ndims - 2; ++i) {
+ int64 a_size = a_shape.dimensions(i);
+ batch_dimensions.push_back(a_size);
}
- }
- TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
-
- // Construct the loop body function,
- // def body_fun(loop_carry):
- // i, output, a, b = loop_carry
- // if transpose_a:
- // a_row = np.swapaxes(a[..., i+1:, i:i+1], -1 -2)
- // else:
- // a_row = a[..., i:i+1, :i]
- // result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :])
- // output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
- // if transpose_a:
- // return (i - 1, output, a, b)
- // else:
- // return (i + 1, output, a, b)
- // We have to do some extra FLOPs propagating zeros in the matrix multiply
- // because we can't have the size of its arguments depend on the loop counter.
- std::unique_ptr<xla::XlaBuilder> bodyb =
- builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody");
- {
- auto input_tuple = xla::Parameter(bodyb.get(), 0, tuple_shape,
- "TriangularSolveLeftLookingWhileTuple");
-
- // i, output, a, b = loop_carry
- auto i = xla::GetTupleElement(input_tuple, 0);
- auto body_out = xla::GetTupleElement(input_tuple, 1);
- auto body_a = xla::GetTupleElement(input_tuple, 2);
- auto body_b = xla::GetTupleElement(input_tuple, 3);
- auto zero = xla::ConstantR0<int32>(bodyb.get(), 0);
-
- // We'd like to implement this:
- // if transpose_a:
- // a_row = T(a[..., i+1:, i:i+1])
- // result_row = (b[..., i:i+1, :]
- // - np.matmul(a_row, body_out[..., i+1:, :]))
- // else:
- // result_row = (b[..., i:i+1, :]
- // - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :]))
- // But since we can't have intermediate array sizes depend on the loop
- // counter, we instead exploit the fact that we initialized the output to
- // all zeros and use that as zero-padding (doing unnecessary FLOPs).
- xla::XlaOp a_row;
- if (transpose_a) {
- TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a,
- {zero, i}, {m, 1}));
- } else {
- TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a,
- {i, zero}, {1, m}));
+
+ // The main computation is performed in a While loop.
+
+ // Allocate the output and set its first or last row,
+ // output = np.zeros_like(b)
+ // if transpose_a:
+ // output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:]
+ // else:
+ // output[..., :1, :] = b[..., :1, :] / a[..., :1, :1]
+ xla::XlaOp output = Zeros(builder, b_shape);
+ {
+ auto i = transpose_a ? m - 1 : 0;
+ auto a_slice = SliceInMinorDims(a, {i, i}, {i + 1, i + 1});
+ auto b_slice = SliceInMinorDims(b, {i, 0}, {i + 1, n});
+ auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a);
+ auto update = b_slice / a_slice_conj;
+ output = UpdateSliceInMinorDims(output, update, {i, 0});
}
- TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out,
- /*transpose_x=*/transpose_a,
- /*transpose_y=*/false,
- /*conjugate_x=*/conjugate_a,
- /*conjugate_y=*/false));
- TF_ASSIGN_OR_RETURN(
- auto result_row_slice,
- DynamicSliceInMinorDims(bodyb.get(), body_b, {i, zero}, {1, n}));
- auto result_row = xla::Sub(result_row_slice, b_update);
-
- // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
- TF_ASSIGN_OR_RETURN(auto a_elt, DynamicSliceInMinorDims(bodyb.get(), body_a,
- {i, i}, {1, 1}));
- TF_ASSIGN_OR_RETURN(auto a_elt_conj,
- MaybeConjugate(bodyb.get(), a_elt, conjugate_a));
- auto div_result = xla::Div(result_row, a_elt_conj);
- TF_ASSIGN_OR_RETURN(body_out,
- DynamicUpdateSliceInMinorDims(bodyb.get(), body_out,
- div_result, {i, zero}));
+ // Construct the initial loop carry tuple,
// if transpose_a:
- // return (i - 1, body_out, a, b)
+ // init = (m-2, output, a, b)
// else:
- // return (i + 1, body_out, a, b)
- auto next_i =
- xla::Add(i, xla::ConstantR0<int32>(bodyb.get(), transpose_a ? -1 : 1));
- xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b});
- }
- TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
-
- // Construct the While loop and return the result,
- // return while_loop(cond_fun, body_fun, init)[1]
- auto triangular_solve_left_looking_while = xla::While(cond, body, init);
- return xla::GetTupleElement(triangular_solve_left_looking_while, 1);
+ // init = (1, output, a, b)
+ std::vector<xla::Shape> tuple_shapes = {
+ // The loop iteration counter is a scalar, incremented each iteration.
+ xla::ShapeUtil::MakeShape(xla::S32, {}),
+ // The output has the shape of b, with one row updated each iteration.
+ b_shape,
+ // The coefficient matrix a is a loop invariant.
+ a_shape,
+ // The right-hand-side matrix b is a loop invariant.
+ b_shape};
+ xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
+ auto init_i = xla::ConstantR0<int32>(builder, transpose_a ? m - 2 : 1);
+ auto init = xla::Tuple(builder, {init_i, output, a, b});
+
+ // Construct the loop condition function,
+ // def cond_fun(loop_carry):
+ // i, output, a, b = loop_carry
+ // return i >= 0 if transpose_a else i < m
+ std::unique_ptr<xla::XlaBuilder> condb =
+ builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond");
+ {
+ auto i = xla::GetTupleElement(
+ xla::Parameter(condb.get(), 0, tuple_shape,
+ "TriangularSolveLeftLookingWhileTuple"),
+ 0);
+ if (transpose_a) {
+ xla::Ge(i, xla::ConstantR0<int32>(condb.get(), 0));
+ } else {
+ xla::Lt(i, xla::ConstantR0<int32>(condb.get(), m));
+ }
+ }
+ TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
+
+ // Construct the loop body function,
+ // def body_fun(loop_carry):
+ // i, output, a, b = loop_carry
+ // if transpose_a:
+    //     a_row = np.swapaxes(a[..., i+1:, i:i+1], -1, -2)
+ // else:
+ // a_row = a[..., i:i+1, :i]
+ // result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :])
+ // output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
+ // if transpose_a:
+ // return (i - 1, output, a, b)
+ // else:
+ // return (i + 1, output, a, b)
+ // We have to do some extra FLOPs propagating zeros in the matrix multiply
+ // because we can't have the size of its arguments depend on the loop
+ // counter.
+ std::unique_ptr<xla::XlaBuilder> bodyb =
+ builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody");
+ {
+ auto input_tuple = xla::Parameter(bodyb.get(), 0, tuple_shape,
+ "TriangularSolveLeftLookingWhileTuple");
+
+ // i, output, a, b = loop_carry
+ auto i = xla::GetTupleElement(input_tuple, 0);
+ auto body_out = xla::GetTupleElement(input_tuple, 1);
+ auto body_a = xla::GetTupleElement(input_tuple, 2);
+ auto body_b = xla::GetTupleElement(input_tuple, 3);
+ auto zero = xla::ConstantR0<int32>(bodyb.get(), 0);
+
+ // We'd like to implement this:
+ // if transpose_a:
+ // a_row = T(a[..., i+1:, i:i+1])
+ // result_row = (b[..., i:i+1, :]
+ // - np.matmul(a_row, body_out[..., i+1:, :]))
+ // else:
+ // result_row = (b[..., i:i+1, :]
+ // - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :]))
+ // But since we can't have intermediate array sizes depend on the loop
+ // counter, we instead exploit the fact that we initialized the output to
+ // all zeros and use that as zero-padding (doing unnecessary FLOPs).
+ xla::XlaOp a_row;
+ if (transpose_a) {
+ a_row = DynamicSliceInMinorDims(body_a, {zero, i}, {m, 1});
+ } else {
+ a_row = DynamicSliceInMinorDims(body_a, {i, zero}, {1, m});
+ }
+ auto b_update = BatchDot(a_row, body_out,
+ /*transpose_x=*/transpose_a,
+ /*transpose_y=*/false,
+ /*conjugate_x=*/conjugate_a,
+ /*conjugate_y=*/false);
+ auto result_row_slice =
+ DynamicSliceInMinorDims(body_b, {i, zero}, {1, n});
+ auto result_row = result_row_slice - b_update;
+
+ // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
+ auto a_elt = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1});
+ auto a_elt_conj = MaybeConjugate(a_elt, conjugate_a);
+ auto div_result = xla::Div(result_row, a_elt_conj);
+ body_out = DynamicUpdateSliceInMinorDims(body_out, div_result, {i, zero});
+
+ // if transpose_a:
+ // return (i - 1, body_out, a, b)
+ // else:
+ // return (i + 1, body_out, a, b)
+ auto next_i = xla::Add(
+ i, xla::ConstantR0<int32>(bodyb.get(), transpose_a ? -1 : 1));
+ xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b});
+ }
+ TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
+
+ // Construct the While loop and return the result,
+ // return while_loop(cond_fun, body_fun, init)[1]
+ auto triangular_solve_left_looking_while = xla::While(cond, body, init);
+ return xla::GetTupleElement(triangular_solve_left_looking_while, 1);
+ });
}
-xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder,
- const xla::XlaOp& a,
- const xla::XlaOp& b,
- bool transpose_a,
- bool conjugate_a) {
- TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
- TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
- const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
- const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
- const int64 ndims = xla::ShapeUtil::Rank(a_shape);
-
- std::vector<int64> batch_dimensions;
- for (int i = 0; i < ndims - 2; ++i) {
- int64 a_size = a_shape.dimensions(i);
- batch_dimensions.push_back(a_size);
- }
-
- // The main computation is performed in a While loop.
- xla::XlaOp output = Zeros(builder, b_shape);
-
- // Construct the initial loop carry tuple,
- // if transpose_a:
- // init = (0, output, a, b)
- // else:
- // init = (n-1, output, a, b)
- std::vector<xla::Shape> tuple_shapes = {
- // The loop iteration counter is a scalar, incremented each iteration.
- xla::ShapeUtil::MakeShape(xla::S32, {}),
- // The output has the shape of b, with one row updated each iteration.
- b_shape,
- // The coefficient matrix a is a loop invariant.
- a_shape,
- // The right-hand-side matrix b is a loop invariant.
- b_shape};
- xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
- auto init_i = xla::ConstantR0<int32>(builder, transpose_a ? 0 : n - 1);
- auto init = xla::Tuple(builder, {init_i, output, a, b});
-
- // Construct the loop condition function,
- // def cond_fun(loop_carry):
- // i, output, a, b = loop_carry
- // return i < n if transpose_a else i >= 0
- std::unique_ptr<xla::XlaBuilder> condb =
- builder->CreateSubBuilder("TriangularSolveRightLookingWhileCond");
- {
- auto i = xla::GetTupleElement(
- xla::Parameter(condb.get(), 0, tuple_shape,
- "TriangularSolveRightLookingWhileTuple"),
- 0);
- if (transpose_a) {
- xla::Lt(i, xla::ConstantR0<int32>(condb.get(), n));
- } else {
- xla::Ge(i, xla::ConstantR0<int32>(condb.get(), 0));
+xla::XlaOp TriangularSolveRightLooking(xla::XlaOp a, xla::XlaOp b,
+ bool transpose_a, bool conjugate_a) {
+ xla::XlaBuilder* builder = a.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
+ const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
+ const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
+ const int64 ndims = xla::ShapeUtil::Rank(a_shape);
+
+ std::vector<int64> batch_dimensions;
+ for (int i = 0; i < ndims - 2; ++i) {
+ int64 a_size = a_shape.dimensions(i);
+ batch_dimensions.push_back(a_size);
}
- }
- TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
-
- // Construct the loop body function,
- // def body_fun(loop_carry):
- // i, output, a, b = loop_carry
- // if transpose_a:
- // a_row = np.swapaxes(a[..., :, i:i+1], -1 -2)
- // else:
- // a_row = a[..., :, i:i+1]
- // result_row = b[..., :, i:i+1] - np.matmul(output, a_row)
- // output[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1]
- // if transpose_a:
- // return (i - 1, output, a, b)
- // else:
- // return (i + 1, output, a, b)
- // We have to do some extra FLOPs propagating zeros in the matrix multiply
- // because we can't have the size of its arguments depend on the loop counter.
- std::unique_ptr<xla::XlaBuilder> bodyb =
- builder->CreateSubBuilder("TriangularSolveRightLookingWhileBody");
- {
- auto input_tuple = xla::Parameter(bodyb.get(), 0, tuple_shape,
- "TriangularSolveRightLookingWhileTuple");
-
- // i, output, a, b = loop_carry
- auto i = xla::GetTupleElement(input_tuple, 0);
- auto body_out = xla::GetTupleElement(input_tuple, 1);
- auto body_a = xla::GetTupleElement(input_tuple, 2);
- auto body_b = xla::GetTupleElement(input_tuple, 3);
- auto zero = xla::ConstantR0<int32>(bodyb.get(), 0);
-
- // We'd like to implement b[..., :, i:i+1] - np.matmul(output, a[..., :,
- // i:i+1]) But since we can't have intermediate array sizes depend on the
- // loop counter, we instead exploit the fact that we initialized the output
- // to all zeros and use that as zero-padding (doing unnecessary FLOPs).
- TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), body_out, body_a,
- /*transpose_x=*/false,
- /*transpose_y=*/transpose_a,
- /*conjugate_x=*/false,
- /*conjugate_y=*/conjugate_a));
- // result = b - np.matmul(output, a)
- auto result = xla::Sub(body_b, b_update);
- // result_row = result[..., :, i:i+1]
- TF_ASSIGN_OR_RETURN(
- auto result_row,
- DynamicSliceInMinorDims(bodyb.get(), result, {zero, i}, {m, 1}));
-
- // body_out[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1]
- TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(bodyb.get(), body_a,
- {i, i}, {1, 1}));
- TF_ASSIGN_OR_RETURN(auto a_ii_conj,
- MaybeConjugate(bodyb.get(), a_ii, conjugate_a));
- auto div_result = xla::Div(result_row, a_ii_conj);
- TF_ASSIGN_OR_RETURN(body_out,
- DynamicUpdateSliceInMinorDims(bodyb.get(), body_out,
- div_result, {zero, i}));
+ // The main computation is performed in a While loop.
+ xla::XlaOp output = Zeros(builder, b_shape);
+
+ // Construct the initial loop carry tuple,
// if transpose_a:
- // return (i + 1, body_out, a, b)
+ // init = (0, output, a, b)
// else:
- // return (i - 1, body_out, a, b)
- auto next_i =
- xla::Add(i, xla::ConstantR0<int32>(bodyb.get(), transpose_a ? 1 : -1));
- xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b});
- }
- TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
-
- // Construct the While loop and return the result,
- // return while_loop(cond_fun, body_fun, init)[1]
- auto triangular_solve_left_looking_while = xla::While(cond, body, init);
- return xla::GetTupleElement(triangular_solve_left_looking_while, 1);
+ // init = (n-1, output, a, b)
+ std::vector<xla::Shape> tuple_shapes = {
+ // The loop iteration counter is a scalar, incremented each iteration.
+ xla::ShapeUtil::MakeShape(xla::S32, {}),
+ // The output has the shape of b, with one row updated each iteration.
+ b_shape,
+ // The coefficient matrix a is a loop invariant.
+ a_shape,
+ // The right-hand-side matrix b is a loop invariant.
+ b_shape};
+ xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
+ auto init_i = xla::ConstantR0<int32>(builder, transpose_a ? 0 : n - 1);
+ auto init = xla::Tuple(builder, {init_i, output, a, b});
+
+ // Construct the loop condition function,
+ // def cond_fun(loop_carry):
+ // i, output, a, b = loop_carry
+ // return i < n if transpose_a else i >= 0
+ std::unique_ptr<xla::XlaBuilder> condb =
+ builder->CreateSubBuilder("TriangularSolveRightLookingWhileCond");
+ {
+ auto i = xla::GetTupleElement(
+ xla::Parameter(condb.get(), 0, tuple_shape,
+ "TriangularSolveRightLookingWhileTuple"),
+ 0);
+ if (transpose_a) {
+ xla::Lt(i, xla::ConstantR0<int32>(condb.get(), n));
+ } else {
+ xla::Ge(i, xla::ConstantR0<int32>(condb.get(), 0));
+ }
+ }
+ TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
+
+ // Construct the loop body function,
+ // def body_fun(loop_carry):
+ // i, output, a, b = loop_carry
+ // if transpose_a:
+    //     a_row = np.swapaxes(a[..., :, i:i+1], -1, -2)
+ // else:
+ // a_row = a[..., :, i:i+1]
+ // result_row = b[..., :, i:i+1] - np.matmul(output, a_row)
+ // output[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1]
+ // if transpose_a:
+    //     return (i + 1, output, a, b)
+    //   else:
+    //     return (i - 1, output, a, b)
+ // We have to do some extra FLOPs propagating zeros in the matrix multiply
+ // because we can't have the size of its arguments depend on the loop
+ // counter.
+ std::unique_ptr<xla::XlaBuilder> bodyb =
+ builder->CreateSubBuilder("TriangularSolveRightLookingWhileBody");
+ {
+ auto input_tuple = xla::Parameter(
+ bodyb.get(), 0, tuple_shape, "TriangularSolveRightLookingWhileTuple");
+
+ // i, output, a, b = loop_carry
+ auto i = xla::GetTupleElement(input_tuple, 0);
+ auto body_out = xla::GetTupleElement(input_tuple, 1);
+ auto body_a = xla::GetTupleElement(input_tuple, 2);
+ auto body_b = xla::GetTupleElement(input_tuple, 3);
+ auto zero = xla::ConstantR0<int32>(bodyb.get(), 0);
+
+ // We'd like to implement b[..., :, i:i+1] - np.matmul(output, a[..., :,
+      // i:i+1]). But since we can't have intermediate array sizes depend on the
+ // loop counter, we instead exploit the fact that we initialized the
+ // output to all zeros and use that as zero-padding (doing unnecessary
+ // FLOPs).
+ auto b_update = BatchDot(body_out, body_a,
+ /*transpose_x=*/false,
+ /*transpose_y=*/transpose_a,
+ /*conjugate_x=*/false,
+ /*conjugate_y=*/conjugate_a);
+ // result = b - np.matmul(output, a)
+ auto result = body_b - b_update;
+ // result_row = result[..., :, i:i+1]
+ auto result_row = DynamicSliceInMinorDims(result, {zero, i}, {m, 1});
+
+ // body_out[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1]
+ auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1});
+ auto a_ii_conj = MaybeConjugate(a_ii, conjugate_a);
+ auto div_result = xla::Div(result_row, a_ii_conj);
+ body_out = DynamicUpdateSliceInMinorDims(body_out, div_result, {zero, i});
+
+ // if transpose_a:
+ // return (i + 1, body_out, a, b)
+ // else:
+ // return (i - 1, body_out, a, b)
+ auto next_i = xla::Add(
+ i, xla::ConstantR0<int32>(bodyb.get(), transpose_a ? 1 : -1));
+ xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b});
+ }
+ TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
+
+ // Construct the While loop and return the result,
+ // return while_loop(cond_fun, body_fun, init)[1]
+    auto triangular_solve_right_looking_while = xla::While(cond, body, init);
+    return xla::GetTupleElement(triangular_solve_right_looking_while, 1);
+ });
}
} // namespace tensorflow
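As a cross-check on the blocked routine above, the (!left_side && lower == transpose_a) branch specialises, for lower=True and transpose_a=True, to solving X @ A.T = B one column block at a time. The NumPy sketch below mirrors that branch only; it is an illustration under those assumptions (single matrix, real values, no conjugation), and the function name is invented for this note.

# NumPy sketch of the right-looking blocked solve for X @ A.T = B,
# A lower triangular (the first branch of TriangularSolve above).
import numpy as np

def blocked_trsm_right_lower_transpose(a, b, block_size=2):
    a = np.asarray(a, dtype=float)
    b = np.array(b, dtype=float)        # working copy; updated like `b` above
    m, n = b.shape
    output = np.zeros_like(b)
    for i in range(0, n, block_size):
        k = min(block_size, n - i)
        # output[:, i:i+k] = triangular_solve(a[i:i+k, i:i+k], b[:, i:i+k], ...)
        output[:, i:i+k] = np.linalg.solve(a[i:i+k, i:i+k], b[:, i:i+k].T).T
        if i + k < n:
            # b[:, i+k:] -= np.matmul(output[:, i:i+k], T(a[i+k:, i:i+k]))
            b[:, i+k:] -= np.dot(output[:, i:i+k], a[i+k:, i:i+k].T)
    return output

A quick sanity check is np.allclose(np.dot(blocked_trsm_right_lower_transpose(a, b), a.T), b) for any lower-triangular a with a nonzero diagonal.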
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
index 540c26b247..80c2bc4c9c 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
@@ -57,23 +57,15 @@ namespace tensorflow {
//
// Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no
// blocking is used.
-xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
- const xla::XlaOp& a, xla::XlaOp b,
- bool left_side, bool lower,
- bool transpose_a, bool conjugate_a,
- int64 block_size = 256);
+xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side,
+ bool lower, bool transpose_a, bool conjugate_a,
+ int64 block_size = 256);
-xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
- const xla::XlaOp& a,
- const xla::XlaOp& b,
- bool transpose_a,
- bool conjugate_a);
+xla::XlaOp TriangularSolveLeftLooking(xla::XlaOp a, xla::XlaOp b,
+ bool transpose_a, bool conjugate_a);
-xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder,
- const xla::XlaOp& a,
- const xla::XlaOp& b,
- bool transpose_a,
- bool conjugate_a);
+xla::XlaOp TriangularSolveRightLooking(xla::XlaOp a, xla::XlaOp b,
+ bool transpose_a, bool conjugate_a);
} // namespace tensorflow
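The While-loop helpers declared above implement simple row and column recurrences; for TriangularSolveLeftLooking with transpose_a=False (solve A @ X = B, A lower triangular), the recurrence is sketched below in NumPy. As in the XLA loop body, each step multiplies a full row of A against the zero-padded output rather than slicing output[:i], which is the zero-padding trick described in the comments; this is an illustration only, not the library API.

# NumPy sketch of the left-looking row recurrence (transpose_a=False case).
import numpy as np

def left_looking_solve(a, b):
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    m, n = b.shape
    output = np.zeros_like(b)
    # First row: output[:1, :] = b[:1, :] / a[:1, :1]
    output[:1, :] = b[:1, :] / a[:1, :1]
    for i in range(1, m):
        a_row = a[i:i+1, :]
        # Rows >= i of `output` are still zero, so the full product only picks
        # up contributions from already-solved rows (extra FLOPs, same result).
        result_row = b[i:i+1, :] - np.dot(a_row, output)
        output[i:i+1, :] = result_row / a[i:i+1, i:i+1]
    return output

Because unsolved rows of the output are zero, any nonzero entries above the diagonal of a do not affect the result, which is the property exercised by the NonzeroUpperTriangle test below.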
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
index 87ea4763f7..d5ffc1498e 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
@@ -85,11 +85,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/false, /*lower=*/true,
- /*transpose_a=*/true, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/false, /*lower=*/true,
+ /*transpose_a=*/true, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{0.5, 0.08333334, 0.04629629, 0.03367003},
@@ -107,11 +106,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/false, /*lower=*/true,
- /*transpose_a=*/false, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/false, /*lower=*/true,
+ /*transpose_a=*/false, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{-0.16414141, -0.06902357, -0.07070707, 0.36363636},
@@ -129,11 +127,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/false, /*lower=*/false,
- /*transpose_a=*/true, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/false, /*lower=*/false,
+ /*transpose_a=*/true, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{-0.16414141, -0.06902357, -0.07070707, 0.36363636},
@@ -151,11 +148,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/false, /*lower=*/false,
- /*transpose_a=*/false, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/false, /*lower=*/false,
+ /*transpose_a=*/false, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{0.5, 0.08333334, 0.04629629, 0.03367003},
@@ -173,11 +169,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/true, /*lower=*/true,
- /*transpose_a=*/true, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/true, /*lower=*/true,
+ /*transpose_a=*/true, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{-0.89646465, -0.69444444, -0.49242424},
@@ -196,11 +191,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/true, /*lower=*/true,
- /*transpose_a=*/false, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/true, /*lower=*/true,
+ /*transpose_a=*/false, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{0.5, 1.0, 1.5},
@@ -219,11 +213,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/true, /*lower=*/false,
- /*transpose_a=*/true, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/true, /*lower=*/false,
+ /*transpose_a=*/true, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{0.5, 1.0, 1.5},
@@ -242,11 +235,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/true, /*lower=*/false,
- /*transpose_a=*/false, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/true, /*lower=*/false,
+ /*transpose_a=*/false, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{-0.89646465, -0.69444444, -0.49242424},
@@ -267,11 +259,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) {
CreateR2Parameter<complex64>(AValsLowerComplex(), 0, "a", &builder, &a);
auto b_data =
CreateR2Parameter<complex64>(BValsRightComplex(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/false, /*lower=*/true,
- /*transpose_a=*/true, /*conjugate_a=*/true,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/false, /*lower=*/true,
+ /*transpose_a=*/true, /*conjugate_a=*/true,
+ /*block_size=*/2);
xla::Array2D<complex64> expected({
{0.5, complex64(0.08333333, 0.08333333),
@@ -295,11 +286,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) {
CreateR2Parameter<complex64>(AValsUpperComplex(), 0, "a", &builder, &a);
auto b_data =
CreateR2Parameter<complex64>(BValsLeftComplex(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/true, /*lower=*/false,
- /*transpose_a=*/true, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/true, /*lower=*/false,
+ /*transpose_a=*/true, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<complex64> expected({
{0.5, 1., 1.5},
@@ -323,10 +313,9 @@ XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
- auto result = TriangularSolveLeftLooking(&builder, a, b,
- /*transpose_a=*/false,
- /*conjugate_a=*/false);
- TF_ASSERT_OK(result.status());
+ TriangularSolveLeftLooking(a, b,
+ /*transpose_a=*/false,
+ /*conjugate_a=*/false);
xla::Array2D<float> expected({
{0.5, 1.0, 1.5},
@@ -345,10 +334,9 @@ XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
- auto result = TriangularSolveLeftLooking(&builder, a, b,
- /*transpose_a=*/false,
- /*conjugate_a=*/false);
- TF_ASSERT_OK(result.status());
+ TriangularSolveLeftLooking(a, b,
+ /*transpose_a=*/false,
+ /*conjugate_a=*/false);
xla::Array2D<float> expected({
{0.5, 1.0, 1.5},
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
index 11774dde08..6694729495 100644
--- a/tensorflow/compiler/tf2xla/lib/util.cc
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -111,130 +111,137 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
return xla::ConstantLiteral(builder, literal);
}
-xla::StatusOr<xla::XlaOp> SliceInMinorDims(xla::XlaBuilder* builder,
- const xla::XlaOp& x,
- gtl::ArraySlice<int64> start,
- gtl::ArraySlice<int64> end) {
- TF_RET_CHECK(start.size() == end.size());
- int64 n_minor_dims = start.size();
-
- TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
-
- const int64 n_dims = xla::ShapeUtil::Rank(shape);
- TF_RET_CHECK(n_minor_dims <= n_dims);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
- /*pos=*/0,
- /*len=*/n_dims - n_minor_dims);
-
- // Prepends 0s in the major dim
- std::vector<int64> padded_start(n_dims, 0);
- std::copy(start.begin(), start.end(),
- padded_start.begin() + major_dims.size());
-
- // Prepends the shape of the major dims.
- std::vector<int64> padded_end(n_dims);
- std::copy(major_dims.begin(), major_dims.end(), padded_end.begin());
- std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size());
-
- std::vector<int64> strides(n_dims, 1);
- return xla::Slice(x, padded_start, padded_end, strides);
+xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice<int64> start,
+ gtl::ArraySlice<int64> end) {
+ xla::XlaBuilder* builder = x.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_RET_CHECK(start.size() == end.size());
+ int64 n_minor_dims = start.size();
+
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
+ TF_RET_CHECK(n_minor_dims <= n_dims);
+ gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
+ /*pos=*/0,
+ /*len=*/n_dims - n_minor_dims);
+
+ // Prepends 0s in the major dim
+ std::vector<int64> padded_start(n_dims, 0);
+ std::copy(start.begin(), start.end(),
+ padded_start.begin() + major_dims.size());
+
+ // Prepends the shape of the major dims.
+ std::vector<int64> padded_end(n_dims);
+ std::copy(major_dims.begin(), major_dims.end(), padded_end.begin());
+ std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size());
+
+ std::vector<int64> strides(n_dims, 1);
+ return xla::Slice(x, padded_start, padded_end, strides);
+ });
}
-std::vector<int64> PrependMajorDims(xla::XlaBuilder* builder,
- const gtl::ArraySlice<int64>& major_dims,
- const gtl::ArraySlice<int64>& indices) {
- std::vector<int64> output(indices.size() + major_dims.size());
- std::copy(major_dims.begin(), major_dims.end(), output.begin());
- std::copy(indices.begin(), indices.end(), output.begin() + major_dims.size());
+std::vector<int64> ConcatVectors(gtl::ArraySlice<int64> xs,
+ gtl::ArraySlice<int64> ys) {
+ std::vector<int64> output(xs.size() + ys.size());
+ std::copy(xs.begin(), xs.end(), output.begin());
+ std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
return output;
}
-xla::StatusOr<xla::XlaOp> DynamicSliceInMinorDims(
- xla::XlaBuilder* builder, const xla::XlaOp& x,
- const std::vector<xla::XlaOp>& starts,
- const gtl::ArraySlice<int64>& sizes) {
- TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(shape);
- int64 n_minor_dims = starts.size();
- TF_RET_CHECK(n_minor_dims == sizes.size());
- TF_RET_CHECK(n_minor_dims <= n_dims);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
- /*pos=*/0,
- /*len=*/n_dims - sizes.size());
- TF_ASSIGN_OR_RETURN(auto padded_starts,
- PrependZerosInMajorDims(builder, x, starts));
- auto padded_sizes = PrependMajorDims(builder, major_dims, sizes);
- return xla::DynamicSlice(x, padded_starts, padded_sizes);
+xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x,
+ gtl::ArraySlice<xla::XlaOp> starts,
+ gtl::ArraySlice<int64> sizes) {
+ xla::XlaBuilder* builder = x.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
+ int64 n_minor_dims = starts.size();
+ TF_RET_CHECK(n_minor_dims == sizes.size());
+ TF_RET_CHECK(n_minor_dims <= n_dims);
+ gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
+ /*pos=*/0,
+ /*len=*/n_dims - sizes.size());
+ auto padded_starts = PrependZerosInMajorDims(x, starts);
+ auto padded_sizes = ConcatVectors(major_dims, sizes);
+ return xla::DynamicSlice(x, padded_starts, padded_sizes);
+ });
}
-xla::StatusOr<xla::XlaOp> UpdateSlice(xla::XlaBuilder* builder,
- const xla::XlaOp& x,
- const xla::XlaOp& update,
- gtl::ArraySlice<int64> start) {
- // TODO(phawkins): make int64 work on all backends, remove the int32 cast.
- std::vector<int32> start_as_int32(start.begin(), start.end());
- auto start_constant = xla::ConstantR1<int32>(builder, start_as_int32);
- TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(shape);
- TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape,
- builder->GetShape(start_constant));
- const int64 start_length =
- xla::ShapeUtil::GetDimension(start_constant_shape, -1);
- TF_RET_CHECK(start_length == n_dims);
- return xla::DynamicUpdateSlice(x, update, start_constant);
+xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update,
+ gtl::ArraySlice<int64> start) {
+ xla::XlaBuilder* builder = x.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ // TODO(phawkins): make int64 work on all backends, remove the int32 cast.
+ std::vector<int32> start_as_int32(start.begin(), start.end());
+ auto start_constant = xla::ConstantR1<int32>(builder, start_as_int32);
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
+ TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape,
+ builder->GetShape(start_constant));
+ const int64 start_length =
+ xla::ShapeUtil::GetDimension(start_constant_shape, -1);
+ TF_RET_CHECK(start_length == n_dims);
+ return xla::DynamicUpdateSlice(x, update, start_constant);
+ });
}
-xla::StatusOr<xla::XlaOp> UpdateSliceInMinorDims(xla::XlaBuilder* builder,
- const xla::XlaOp& x,
- const xla::XlaOp& update,
- gtl::ArraySlice<int64> start) {
- TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(shape);
- const int64 n_minor_dims = start.size();
- TF_RET_CHECK(n_minor_dims <= n_dims);
- std::vector<int64> padded_start(n_dims, 0);
- std::copy(start.begin(), start.end(),
- padded_start.begin() + (n_dims - n_minor_dims));
- return UpdateSlice(builder, x, update, padded_start);
+xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
+ gtl::ArraySlice<int64> start) {
+ xla::XlaBuilder* builder = x.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
+ const int64 n_minor_dims = start.size();
+ TF_RET_CHECK(n_minor_dims <= n_dims);
+ std::vector<int64> padded_start(n_dims, 0);
+ std::copy(start.begin(), start.end(),
+ padded_start.begin() + (n_dims - n_minor_dims));
+ return UpdateSlice(x, update, padded_start);
+ });
}
-xla::StatusOr<xla::XlaOp> DynamicUpdateSliceInMinorDims(
- xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update,
- const std::vector<xla::XlaOp>& starts) {
- TF_ASSIGN_OR_RETURN(auto padded_starts,
- PrependZerosInMajorDims(builder, x, starts));
+xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
+ gtl::ArraySlice<xla::XlaOp> starts) {
+ auto padded_starts = PrependZerosInMajorDims(x, starts);
return xla::DynamicUpdateSlice(x, update, padded_starts);
}
-xla::StatusOr<xla::XlaOp> PrependZerosInMajorDims(
- xla::XlaBuilder* builder, const xla::XlaOp& x,
- const std::vector<xla::XlaOp>& starts) {
- TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(shape);
- auto zero = xla::Reshape(xla::ConstantR0<int32>(builder, 0), {1});
- std::vector<xla::XlaOp> padded_starts(n_dims, zero);
- for (int i = 0; i < starts.size(); ++i) {
- padded_starts[n_dims - starts.size() + i] = xla::Reshape(starts[i], {1});
- }
- return xla::ConcatInDim(builder, padded_starts, 0);
+xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x,
+ gtl::ArraySlice<xla::XlaOp> starts) {
+ xla::XlaBuilder* builder = x.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
+ auto zero = xla::Reshape(xla::ConstantR0<int32>(builder, 0), {1});
+ std::vector<xla::XlaOp> padded_starts(n_dims, zero);
+ for (int i = 0; i < starts.size(); ++i) {
+ padded_starts[n_dims - starts.size() + i] = xla::Reshape(starts[i], {1});
+ }
+ return xla::ConcatInDim(builder, padded_starts, 0);
+ });
}
-xla::StatusOr<xla::XlaOp> TransposeInMinorDims(xla::XlaBuilder* builder,
- const xla::XlaOp& x) {
- TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(shape);
- TF_RET_CHECK(n_dims >= 2);
- std::vector<int64> permutation(n_dims);
- std::iota(permutation.begin(), permutation.end(), 0);
- std::swap(permutation[n_dims - 1], permutation[n_dims - 2]);
- return xla::Transpose(x, permutation);
+xla::XlaOp TransposeInMinorDims(xla::XlaOp x) {
+ xla::XlaBuilder* builder = x.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
+ TF_RET_CHECK(n_dims >= 2);
+ std::vector<int64> permutation(n_dims);
+ std::iota(permutation.begin(), permutation.end(), 0);
+ std::swap(permutation[n_dims - 1], permutation[n_dims - 2]);
+ return xla::Transpose(x, permutation);
+ });
}
-xla::StatusOr<xla::XlaOp> MaybeConjugate(xla::XlaBuilder* builder,
- const xla::XlaOp& x, bool conjugate) {
- TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
- auto perform_conj = shape.element_type() == xla::C64 && conjugate;
- return perform_conj ? xla::Conj(x) : x;
+xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate) {
+ xla::XlaBuilder* builder = x.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ auto perform_conj = shape.element_type() == xla::C64 && conjugate;
+ return perform_conj ? xla::Conj(x) : x;
+ });
}
} // namespace tensorflow
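The util.cc changes above replace every StatusOr<xla::XlaOp>-returning, builder-threading helper with a plain XlaOp-returning version that records failures on the op's builder via ReportErrorOrReturn. A minimal caller-side sketch of the new style (the parameter shape, names, and builder are illustrative, not part of this change):

xla::XlaBuilder b("slice_example");
xla::XlaOp x = xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {4, 4}), "x");
// No builder argument and no StatusOr unwrapping; a failed shape check is
// recorded on x.builder() and only surfaces when the computation is built.
xla::XlaOp tile = tensorflow::SliceInMinorDims(x, /*start=*/{0, 0}, /*end=*/{2, 2});
xla::XlaOp tile_t = tensorflow::TransposeInMinorDims(tile);
xla::StatusOr<xla::XlaComputation> computation = b.Build();  // deferred errors appear here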
diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h
index 3c120a2548..ac5d2940ff 100644
--- a/tensorflow/compiler/tf2xla/lib/util.h
+++ b/tensorflow/compiler/tf2xla/lib/util.h
@@ -33,7 +33,7 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
// Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros
// prepended until the array is length n_dims.
-xla::XlaOp PrependZerosInMajorDims(xla::XlaBuilder* builder,
+xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x,
gtl::ArraySlice<xla::XlaOp> starts);
// Returns an integer scalar constant of 'type' with 'value'.
@@ -41,54 +41,43 @@ xla::XlaOp PrependZerosInMajorDims(xla::XlaBuilder* builder,
xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
int64 value);
-// Builds a vector of zeros of length rank(x) with the last two values being
+// Builds a vector of zeros of length rank(x) with the last values being
// those in `starts`.
-xla::StatusOr<xla::XlaOp> PrependZerosInMajorDims(
- xla::XlaBuilder* builder, const xla::XlaOp& x,
- const std::vector<xla::XlaOp>& starts);
+xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x,
+ gtl::ArraySlice<xla::XlaOp> starts);
// Performs a slice in the minor dimensions of a Tensor.
-xla::StatusOr<xla::XlaOp> SliceInMinorDims(xla::XlaBuilder* builder,
- const xla::XlaOp& x,
- gtl::ArraySlice<int64> start,
- gtl::ArraySlice<int64> end);
+xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice<int64> start,
+ gtl::ArraySlice<int64> end);
-// Builds a 1-d vector out of a concatenation of `major_dims` and `starts`.
-std::vector<int64> PrependMajorDims(xla::XlaBuilder* builder,
- const gtl::ArraySlice<int64>& major_dims,
- const gtl::ArraySlice<int64>& indices);
+// Returns the concatenation of `xs` and `ys`.
+std::vector<int64> ConcatVectors(gtl::ArraySlice<int64> xs,
+ gtl::ArraySlice<int64> ys);
// Performs a dynamic slice in the minor dimensions of a Tensor.
-xla::StatusOr<xla::XlaOp> DynamicSliceInMinorDims(
- xla::XlaBuilder* builder, const xla::XlaOp& x,
- const std::vector<xla::XlaOp>& starts, const gtl::ArraySlice<int64>& sizes);
+xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x,
+ gtl::ArraySlice<xla::XlaOp> starts,
+ gtl::ArraySlice<int64> sizes);
// Updates a slice of 'x', i.e.,
// x[start[0], ..., start[n]] = update
-xla::StatusOr<xla::XlaOp> UpdateSlice(xla::XlaBuilder* builder,
- const xla::XlaOp& x,
- const xla::XlaOp& update,
- gtl::ArraySlice<int64> start);
+xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update,
+ gtl::ArraySlice<int64> start);
// Updates a slice of 'x', where 'start' contains a list of minor dimensions:
// x[..., start[0], ..., start[n]] = update
-xla::StatusOr<xla::XlaOp> UpdateSliceInMinorDims(xla::XlaBuilder* builder,
- const xla::XlaOp& x,
- const xla::XlaOp& update,
- gtl::ArraySlice<int64> start);
+xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
+ gtl::ArraySlice<int64> start);
-xla::StatusOr<xla::XlaOp> DynamicUpdateSliceInMinorDims(
- xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update,
- const std::vector<xla::XlaOp>& starts);
+xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
+ gtl::ArraySlice<xla::XlaOp> starts);
// Transposes a stack of matrices `x` by swapping the last two dimensions.
-xla::StatusOr<xla::XlaOp> TransposeInMinorDims(xla::XlaBuilder* builder,
- const xla::XlaOp& x);
+xla::XlaOp TransposeInMinorDims(xla::XlaOp x);
// Applies a complex conjugation operation if `x` is complex and `conjugate`
// is true, otherwise returns its argument.
-xla::StatusOr<xla::XlaOp> MaybeConjugate(xla::XlaBuilder* builder,
- const xla::XlaOp& x, bool conjugate);
+xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc
index 2a332c933f..7d0f2222a9 100644
--- a/tensorflow/compiler/tf2xla/lib/util_test.cc
+++ b/tensorflow/compiler/tf2xla/lib/util_test.cc
@@ -70,8 +70,7 @@ XLA_TEST_F(UtilTest, Simple2dLookup) {
auto a_data = CreateR2Parameter<float>(BValsRight(), 0, "a", &builder, &a);
auto x_data = CreateR0Parameter<int>(2, 1, "x", &builder, &x);
auto y_data = CreateR0Parameter<int>(1, 2, "y", &builder, &y);
- auto result = DynamicSliceInMinorDims(&builder, a, {x, y}, {1, 1});
- TF_ASSERT_OK(result.status());
+ DynamicSliceInMinorDims(a, {x, y}, {1, 1});
ComputeAndCompareR2<float>(&builder, {{10}},
{a_data.get(), x_data.get(), y_data.get()},
@@ -86,10 +85,8 @@ XLA_TEST_F(UtilTest, Simple3dLookup) {
CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
auto index_data = CreateR0Parameter<int>(1, 1, "index", &builder, &index);
- TF_ASSERT_OK(
- DynamicSliceInMinorDims(
- &builder, a, {index, xla::ConstantR0<int32>(&builder, 0)}, {1, 4})
- .status());
+ DynamicSliceInMinorDims(a, {index, xla::ConstantR0<int32>(&builder, 0)},
+ {1, 4});
ComputeAndCompareR3<float>(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}},
{a_data.get(), index_data.get()});
@@ -104,8 +101,7 @@ XLA_TEST_F(UtilTest, SimpleSliceUpdate) {
auto x_data = CreateR0Parameter<int>(2, 2, "x", &builder, &x);
auto y_data = CreateR0Parameter<int>(1, 3, "y", &builder, &y);
- auto result = DynamicUpdateSliceInMinorDims(&builder, a, b, {x, y});
- TF_ASSERT_OK(result.status());
+ DynamicUpdateSliceInMinorDims(a, b, {x, y});
xla::Array2D<float> expected(
{{{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 9, 1, -10}, {5, 8, 10, 11}}});
@@ -128,13 +124,9 @@ XLA_TEST_F(UtilTest, RowBatchDot) {
// Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull().
auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index);
- TF_ASSERT_OK_AND_ASSIGN(
- auto l_index,
- DynamicSliceInMinorDims(
- &builder, a, {index, xla::ConstantR0<int32>(&builder, 0)}, {1, n}));
- TF_ASSERT_OK(BatchDot(&builder, l_index, row,
- /*transpose_x=*/false, /*transpose_y=*/true)
- .status());
+ auto l_index = DynamicSliceInMinorDims(
+ a, {index, xla::ConstantR0<int32>(&builder, 0)}, {1, n});
+ BatchDot(l_index, row, /*transpose_x=*/false, /*transpose_y=*/true);
ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}},
{a_data.get(), row_data.get(), index_data.get()});
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index 917ef4037d..81bdf139f5 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -72,10 +73,9 @@ Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
// And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its
// index.
- xla::XlaOp iota;
const int64 axis_size = input_shape.dim_size(axis);
- TF_RETURN_IF_ERROR(XlaHelpers::Iota(builder, output_type, axis_size, &iota));
+ xla::XlaOp iota = xla::Iota(builder, xla_output_type, axis_size);
xla::XlaOp product =
xla::And(full_mask, iota, /*broadcast_dimensions=*/{axis});
@@ -230,31 +230,6 @@ Status XlaHelpers::ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
axis, /*is_min=*/true, argmin);
}
-Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size,
- xla::XlaOp* iota) {
- TensorShape linspace_shape({size});
- Tensor linspace;
- switch (dtype) {
- case DT_UINT8:
- linspace = MakeLinspaceTensor<uint8>(linspace_shape, size);
- break;
- case DT_INT32:
- linspace = MakeLinspaceTensor<int32>(linspace_shape, size);
- break;
- case DT_INT64:
- linspace = MakeLinspaceTensor<int64>(linspace_shape, size);
- break;
- default:
- return errors::InvalidArgument("Invalid argument type ",
- DataTypeString(dtype));
- }
- xla::BorrowingLiteral linspace_literal;
- TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal));
-
- *iota = xla::ConstantLiteral(builder, linspace_literal);
- return Status::OK();
-}
-
Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
DataType index_type, const TensorShape& indices_shape,
const xla::XlaOp& indices, const xla::XlaOp& on_value,
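With XlaHelpers::Iota removed, callers obtain the index vector from the new client-library helper instead of an output parameter plus a Status. A hedged before/after sketch (builder, output_type, xla_output_type, and axis_size come from the surrounding ArgMinMax code; everything else is illustrative):

// Before (removed above):
//   xla::XlaOp iota;
//   TF_RETURN_IF_ERROR(XlaHelpers::Iota(builder, output_type, axis_size, &iota));
// After: a single call that returns the op directly and reports bad types to the builder.
xla::XlaOp iota = xla::Iota(builder, xla_output_type, axis_size);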
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h
index c320016998..495bd2b8b6 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.h
+++ b/tensorflow/compiler/tf2xla/xla_helpers.h
@@ -89,10 +89,6 @@ class XlaHelpers {
DataType input_type, DataType output_type, int axis,
xla::XlaOp* argmin);
- // Sets *iota to a rank 1 tensor with values [0, 1, 2, ...] of `dtype`.
- static Status Iota(xla::XlaBuilder* builder, DataType dtype, int64 size,
- xla::XlaOp* iota);
-
// Converts `indices` into a one-hot representation. `depth` is the size
// of the new axis to add. `axis` is the position at which to add the new
// axis. `indices_shape` is the shape of `indices`. `on_value` and
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index c2298b97e1..0eabfb3a52 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
@@ -68,6 +69,20 @@ TensorShape XlaOpKernelContext::InputShape(int index) {
return context_->input(index).shape();
}
+DataType XlaOpKernelContext::input_type(int index) const {
+ return context_->input(index).dtype();
+}
+
+xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) {
+ xla::PrimitiveType type;
+ Status status = DataTypeToPrimitiveType(input_type(index), &type);
+ if (!status.ok()) {
+ SetStatus(status);
+ return xla::PRIMITIVE_TYPE_INVALID;
+ }
+ return type;
+}
+
Status XlaOpKernelContext::ConstantInput(int index,
xla::Literal* constant_literal) {
return ConstantInputReshaped(
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 667dc262ca..2bde2c983d 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"
@@ -67,7 +68,12 @@ class XlaOpKernelContext {
int num_inputs() const { return context_->num_inputs(); }
// Returns the type of input 'index'.
- DataType input_type(int index) { return context_->input(index).dtype(); }
+ DataType input_type(int index) const;
+
+ // Returns the type of input 'index' as an xla::PrimitiveType. If the type
+ // is not representable as an XLA type, sets an error status and returns
+ // xla::PRIMITIVE_TYPE_INVALID.
+ xla::PrimitiveType input_xla_type(int index);
// Returns the shape of input 'index'.
TensorShape InputShape(int index);
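input_xla_type complements the existing input_type accessor by converting the DataType through DataTypeToPrimitiveType and reporting unconvertible types on the context. A rough sketch of how a kernel might use it (the kernel class and the size constant are hypothetical, for illustration only):

class IotaLikeOp : public XlaOpKernel {  // hypothetical kernel, not part of this change
 public:
  using XlaOpKernel::XlaOpKernel;
  void Compile(XlaOpKernelContext* ctx) override {
    xla::PrimitiveType type = ctx->input_xla_type(0);
    // PRIMITIVE_TYPE_INVALID means an error status was already recorded on ctx.
    if (type == xla::PRIMITIVE_TYPE_INVALID) return;
    ctx->SetOutput(0, xla::Iota(ctx->builder(), type, /*size=*/8));
  }
};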
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index d49d959a6c..273fa17371 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -13,6 +13,12 @@ filegroup(
]),
)
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites")
+
+# Generate test_suites for all backends, named "${backend}_tests".
+generate_backend_suites()
+
cc_library(
name = "arithmetic",
srcs = ["arithmetic.cc"],
@@ -29,6 +35,32 @@ cc_library(
)
cc_library(
+ name = "numeric",
+ srcs = ["numeric.cc"],
+ hdrs = ["numeric.h"],
+ deps = [
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ ],
+)
+
+xla_test(
+ name = "numeric_test",
+ srcs = ["numeric_test.cc"],
+ tags = ["enable_for_xla_interpreter"],
+ deps = [
+ ":numeric",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ ],
+)
+
+cc_library(
name = "testing",
srcs = ["testing.cc"],
hdrs = ["testing.h"],
diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc
new file mode 100644
index 0000000000..cbe9e7fdd1
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/numeric.cc
@@ -0,0 +1,71 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+
+#include <numeric>
+#include <vector>
+
+namespace xla {
+
+namespace {
+
+template <typename T>
+XlaOp MakeIota(XlaBuilder* builder, int64 size) {
+ std::vector<T> values(size);
+ for (int64 i = 0; i < size; ++i) {
+ values[i] = static_cast<T>(i);
+ }
+ return xla::ConstantR1<T>(builder, values);
+}
+
+} // namespace
+
+XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) {
+ switch (type) {
+ case S8:
+ return MakeIota<int8>(builder, size);
+ case S16:
+ return MakeIota<int16>(builder, size);
+ case S32:
+ return MakeIota<int32>(builder, size);
+ case S64:
+ return MakeIota<int64>(builder, size);
+ case U8:
+ return MakeIota<uint8>(builder, size);
+ case U16:
+ return MakeIota<uint16>(builder, size);
+ case U32:
+ return MakeIota<uint32>(builder, size);
+ case U64:
+ return MakeIota<uint64>(builder, size);
+ case BF16:
+ return MakeIota<bfloat16>(builder, size);
+ case F16:
+ return MakeIota<half>(builder, size);
+ case F32:
+ return MakeIota<float>(builder, size);
+ case F64:
+ return MakeIota<double>(builder, size);
+ case C64:
+ return MakeIota<complex64>(builder, size);
+ default:
+ return builder->ReportError(
+ InvalidArgument("Unimplemented type for Iota: %s.",
+ PrimitiveType_Name(type).c_str()));
+ }
+}
+
+} // namespace xla
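The default branch defers unsupported element types through XlaBuilder::ReportError rather than crashing, so a bad call still yields an XlaOp and the failure surfaces at build time. Illustrative only (PRED is simply a type the switch above does not handle):

xla::XlaBuilder b("iota_error");
xla::Iota(&b, xla::PRED, 4);  // falls through to the default branch above
xla::StatusOr<xla::XlaComputation> result = b.Build();
CHECK(!result.ok());  // the deferred InvalidArgument("Unimplemented type for Iota: ...") appears here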
diff --git a/tensorflow/compiler/xla/client/lib/numeric.h b/tensorflow/compiler/xla/client/lib/numeric.h
new file mode 100644
index 0000000000..2a409ae311
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/numeric.h
@@ -0,0 +1,30 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_
+
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+// Returns a rank 1 tensor of `type` containing values [0, 1, 2, ...].
+XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_
diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/numeric_test.cc
new file mode 100644
index 0000000000..bc8a73e9d7
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/numeric_test.cc
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace {
+
+using NumericTest = ClientLibraryTestBase;
+
+XLA_TEST_F(NumericTest, Iota) {
+ XlaBuilder builder(TestName());
+ Iota(&builder, S32, 10);
+
+ ComputeAndCompareR1<int32>(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {});
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD
index b0f41ac1d3..ee00a9eada 100644
--- a/tensorflow/compiler/xla/client/xla_client/BUILD
+++ b/tensorflow/compiler/xla/client/xla_client/BUILD
@@ -1,7 +1,5 @@
# Description:
# The new XLA client libraries.
-#
-# This is NOT YET ready to use.
licenses(["notice"]) # Apache 2.0
@@ -41,6 +39,7 @@ cc_library(
name = "xla_builder",
srcs = ["xla_builder.cc"],
hdrs = ["xla_builder.h"],
+ visibility = ["//visibility:public"],
deps = [
":xla_computation",
"//tensorflow/compiler/xla:execution_options_util",
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index 0145f60483..4f683a4115 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -60,36 +60,18 @@ bool CanBeRoot(HloOpcode opcode) {
} // namespace
-XlaOp operator-(const XlaOp& x) { return x.builder()->Neg(x); }
-XlaOp operator+(const XlaOp& x, const XlaOp& y) {
- return x.builder()->Add(x, y);
-}
-XlaOp operator-(const XlaOp& x, const XlaOp& y) {
- return x.builder()->Sub(x, y);
-}
-XlaOp operator*(const XlaOp& x, const XlaOp& y) {
- return x.builder()->Mul(x, y);
-}
-XlaOp operator/(const XlaOp& x, const XlaOp& y) {
- return x.builder()->Div(x, y);
-}
-XlaOp operator%(const XlaOp& x, const XlaOp& y) {
- return x.builder()->Rem(x, y);
-}
-
-XlaOp operator~(const XlaOp& x) { return x.builder()->Not(x); }
-XlaOp operator&(const XlaOp& x, const XlaOp& y) {
- return x.builder()->And(x, y);
-}
-XlaOp operator|(const XlaOp& x, const XlaOp& y) {
- return x.builder()->Or(x, y);
-}
-XlaOp operator^(const XlaOp& x, const XlaOp& y) {
- return x.builder()->Xor(x, y);
-}
-XlaOp operator<<(const XlaOp& x, const XlaOp& y) {
- return x.builder()->ShiftLeft(x, y);
-}
+XlaOp operator-(const XlaOp& x) { return Neg(x); }
+XlaOp operator+(const XlaOp& x, const XlaOp& y) { return Add(x, y); }
+XlaOp operator-(const XlaOp& x, const XlaOp& y) { return Sub(x, y); }
+XlaOp operator*(const XlaOp& x, const XlaOp& y) { return Mul(x, y); }
+XlaOp operator/(const XlaOp& x, const XlaOp& y) { return Div(x, y); }
+XlaOp operator%(const XlaOp& x, const XlaOp& y) { return Rem(x, y); }
+
+XlaOp operator~(const XlaOp& x) { return Not(x); }
+XlaOp operator&(const XlaOp& x, const XlaOp& y) { return And(x, y); }
+XlaOp operator|(const XlaOp& x, const XlaOp& y) { return Or(x, y); }
+XlaOp operator^(const XlaOp& x, const XlaOp& y) { return Xor(x, y); }
+XlaOp operator<<(const XlaOp& x, const XlaOp& y) { return ShiftLeft(x, y); }
XlaOp operator>>(const XlaOp& x, const XlaOp& y) {
XlaBuilder* builder = x.builder();
@@ -101,9 +83,9 @@ XlaOp operator>>(const XlaOp& x, const XlaOp& y) {
ShapeUtil::HumanString(shape).c_str());
}
if (ShapeUtil::ElementIsSigned(shape)) {
- return builder->ShiftRightArithmetic(x, y);
+ return ShiftRightArithmetic(x, y);
} else {
- return builder->ShiftRightLogical(x, y);
+ return ShiftRightLogical(x, y);
}
});
}
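The operator overloads now forward to the free-function op builders instead of XlaBuilder members, so expression-style client code is unchanged. A small illustrative snippet (builder name and constants are assumptions):

xla::XlaBuilder b("operator_example");
auto x = xla::ConstantR0<xla::int32>(&b, 6);
auto y = xla::ConstantR0<xla::int32>(&b, 4);
// Each operator lowers through the matching free function (Add, Mul, Sub, Rem, ...).
auto z = (x + y) * x - (x % y);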
@@ -1366,8 +1348,25 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand,
});
}
-XlaOp XlaBuilder::Sort(const XlaOp& operand) {
- return UnaryOp(HloOpcode::kSort, operand);
+XlaOp XlaBuilder::Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ std::vector<const Shape*> operand_shape_ptrs;
+ TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys));
+ operand_shape_ptrs.push_back(&keys_shape);
+ Shape values_shape;
+ if (values.has_value()) {
+ TF_ASSIGN_OR_RETURN(values_shape, GetShape(*values));
+ operand_shape_ptrs.push_back(&values_shape);
+ }
+ TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+ ShapeInference::InferVariadicOpShape(
+ HloOpcode::kSort, operand_shape_ptrs));
+ return values.has_value()
+ ? AddInstruction(std::move(instr), HloOpcode::kSort,
+ {keys, *values})
+ : AddInstruction(std::move(instr), HloOpcode::kSort, {keys});
+ });
}
XlaOp XlaBuilder::SqrtF32(const XlaOp& operand) {
@@ -2538,7 +2537,9 @@ XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
return operand.builder()->Rev(operand, dimensions);
}
-XlaOp Sort(const XlaOp& operand) { return operand.builder()->Sort(operand); }
+XlaOp Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values) {
+ return keys.builder()->Sort(keys, std::move(values));
+}
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) {
return min.builder()->Clamp(min, operand, max);
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
index fe31774b86..ac6ad87349 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -158,6 +158,93 @@ class XlaBuilder {
die_immediately_on_error_ = enabled;
}
+ // Default dimension numbers used for a 2D convolution.
+ static constexpr int64 kConvBatchDimension = 0;
+ static constexpr int64 kConvFeatureDimension = 1;
+ static constexpr int64 kConvFirstSpatialDimension = 2;
+ static constexpr int64 kConvSecondSpatialDimension = 3;
+ static constexpr int64 kConvKernelOutputDimension = 0;
+ static constexpr int64 kConvKernelInputDimension = 1;
+ static constexpr int64 kConvKernelFirstSpatialDimension = 2;
+ static constexpr int64 kConvKernelSecondSpatialDimension = 3;
+
+ // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
+ // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
+ // the kernel operand
+ // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
+ static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
+ int num_spatial_dims = 2);
+
+ // Returns an error if the convolution dimension numbers have conflicts.
+ static Status Validate(const ConvolutionDimensionNumbers& dnum);
+
+ // Returns a new XlaBuilder whose resultant Computation is used only by this
+ // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
+ // behavior as the parent.
+ std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);
+
+ // Builds the computation with the requested operations, or returns a non-ok
+ // status. Note that all ops that have been enqueued will be moved to the
+ // computation being returned.
+ StatusOr<XlaComputation> Build();
+
+ // Builds the computation with the requested operations, or notes an error in
+ // the parent XlaBuilder and returns an empty computation if building failed.
+ // This function is intended to be used where the returned XlaComputation is
+ // only used by the parent XlaBuilder and hence further operation on the
+ // returned XlaComputation will simply be error'ed out if an error occurred
+ // while building this computation. If the built computation is to be used by
+ // a XlaBuilder other than the parent XlaBuilder then Build() should be used
+ // instead.
+ XlaComputation BuildAndNoteError();
+
+ // Returns a subgraph that roots on the given root. If the root is not a
+ // compile-time constant (see `IsConstant`), returns an error.
+ //
+ // This will copy the needed ops/computations to the subgraph.
+ StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op) const;
+
+ // Returns the first error that was encountered while building the
+ // computation. When an error is encountered, by default we return a vacuous
+ // XlaOp and inform the user of the error that occurred while
+ // building the computation when they make a final call to Build().
+ //
+ // See also set_die_immediately_on_error().
+ Status first_error() const { return first_error_; }
+
+ // Returns the shape of the given op.
+ StatusOr<Shape> GetShape(const XlaOp& op) const;
+
+ // Returns the (inferred) result for the current computation's shape.
+ StatusOr<ProgramShape> GetProgramShape() const;
+
+ // Reports an error to the builder, by
+ // * storing it internally and capturing a backtrace if it's the first error
+ // (this deferred value will be produced on the call to
+ // Build()/GetShape()/...)
+ // * dying if die_immediately_on_error_ is true.
+ // Returns an XlaOp with an invalid handle but a valid builder. This value can
+ // be returned in place of a value in APIs that return an XlaOp.
+ XlaOp ReportError(const Status& error);
+
+ // A helper function that converts a StatusOr<XlaOp> into an XlaOp.
+ // If the Status was an error, reports the error to builder and returns an
+ // invalid XlaOp handle.
+ XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op);
+
+ // A helper function that runs a function that returns a StatusOr<XlaOp> and
+ // returns an XlaOp.
+ XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);
+
+ // Returns true if 'operand' is a compile-time constant. A compile-time
+ // constant does not depend on any parameters, or on stateful operators such
+ // as `RngNormal` or `Infeed`.
+ //
+ // This tests whether a computation is a compile-time constant without
+ // evaluating the computation.
+ StatusOr<bool> IsConstant(const XlaOp& operand) const;
+
+ private:
// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
XlaOp Parameter(int64 parameter_number, const Shape& shape,
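Moving ReportErrorOrReturn and GetShape into the public section is what lets helper libraries (such as tf2xla/lib/util.cc above) adopt the value-returning style. A minimal sketch of the same pattern in a hypothetical out-of-tree helper (the name and the floating-point check are illustrative):

xla::XlaOp TimesTwo(xla::XlaOp x) {
  xla::XlaBuilder* builder = x.builder();
  return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
    TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
    TF_RET_CHECK(xla::ShapeUtil::ElementIsFloating(shape));
    return x + x;  // operator+ forwards to xla::Add
  });
}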
@@ -378,26 +465,6 @@ class XlaBuilder {
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
const DotDimensionNumbers& dimension_numbers);
- // Default dimension numbers used for a 2D convolution.
- static constexpr int64 kConvBatchDimension = 0;
- static constexpr int64 kConvFeatureDimension = 1;
- static constexpr int64 kConvFirstSpatialDimension = 2;
- static constexpr int64 kConvSecondSpatialDimension = 3;
- static constexpr int64 kConvKernelOutputDimension = 0;
- static constexpr int64 kConvKernelInputDimension = 1;
- static constexpr int64 kConvKernelFirstSpatialDimension = 2;
- static constexpr int64 kConvKernelSecondSpatialDimension = 3;
-
- // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
- // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
- // the kernel operand
- // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
- static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
- int num_spatial_dims = 2);
-
- // Returns an error if the convolution dimension numbers have conflicts.
- static Status Validate(const ConvolutionDimensionNumbers& dnum);
-
// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
@@ -717,7 +784,18 @@ class XlaBuilder {
tensorflow::gtl::ArraySlice<int64> dimensions);
// Enqueues a sort (as increasing order) instruction onto the computation.
- XlaOp Sort(const XlaOp& operand);
+ // If only keys are provided:
+ // * The keys must be a rank-1 tensor (i.e. an array).
+ // * The result is a sorted array of keys.
+ //
+ // If both keys and values are provided:
+ // * The keys and the values must be rank-1 tensors with the same dimensions.
+ // The element types of the tensors may be different.
+ // * The result is a tuple that consists of a sorted array of keys as the
+ // first element, and an array with their corresponding values as the second
+ // element.
+ XlaOp Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values =
+ tensorflow::gtl::nullopt);
// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
@@ -764,14 +842,6 @@ class XlaBuilder {
// be the same as the given shape.
XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
- // Returns true if 'operand' is a compile-time constant. A compile-time
- // constant does not depend on any parameters, or on stateful operators such
- // as `RngNormal` or `Infeed`.
- //
- // This tests whether a computation is a compile-time constant without
- // evaluating the computation.
- StatusOr<bool> IsConstant(const XlaOp& operand) const;
-
// Normalizes operand across spatial and batch dimensions for each feature.
//
// Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
@@ -810,65 +880,6 @@ class XlaBuilder {
const XlaOp& grad_output, float epsilon,
int64 feature_index);
- // Returns a new XlaBuilder whose resultant Computation is used only by this
- // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
- // behavior as the parent.
- std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);
-
- // Builds the computation with the requested operations, or returns a non-ok
- // status. Note that all ops that have been enqueued will be moved to the
- // computation being returned.
- StatusOr<XlaComputation> Build();
-
- // Builds the computation with the requested operations, or notes an error in
- // the parent XlaBuilder and returns an empty computation if building failed.
- // This function is intended to be used where the returned XlaComputation is
- // only used by the parent XlaBuilder and hence further operation on the
- // returned XlaComputation will simply be error'ed out if an error occurred
- // while building this computation. If the built computation is to be used by
- // a XlaBuilder other than the parent XlaBuilder then Build() should be used
- // instead.
- XlaComputation BuildAndNoteError();
-
- // Returns a subgraph that roots on the given root. If the root is not a
- // compile-time constant (see `IsConstant`), returns an error.
- //
- // This will copy the needed ops/computations to the subgraph.
- StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op) const;
-
- // Returns the first error that was encountered while building the
- // computation. When an error is encountered, by default we return a vacuous
- // XlaOp and inform the user of the error that occurred while
- // building the computation when they make a final call to Build().
- //
- // See also set_die_immediately_on_error().
- Status first_error() const { return first_error_; }
-
- // Returns the shape of the given op.
- StatusOr<Shape> GetShape(const XlaOp& op) const;
-
- // Returns the (inferred) result for the current computation's shape.
- StatusOr<ProgramShape> GetProgramShape() const;
-
- // Reports an error to the builder, by
- // * storing it internally and capturing a backtrace if it's the first error
- // (this deferred value will be produced on the call to
- // Build()/GetShape()/...)
- // * dying if die_immediately_on_error_ is true.
- // Returns an XlaOp with an invalid handle but a valid builder. This value can
- // be returned in place of a value in APIs that return an XlaOp.
- XlaOp ReportError(const Status& error);
-
- // A helper function that converts a StatusOr<XlaOp> into an XlaOp.
- // If the Status was an error, reports the error to builder and returns an
- // invalid XlaOp handle.
- XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op);
-
- // A helper function that runs a function that returns a StatusOr<XlaOp> and
- // returns an XlaOp.
- XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);
-
- private:
StatusOr<XlaOp> AddInstruction(
HloInstructionProto&& instr, HloOpcode opcode,
tensorflow::gtl::ArraySlice<XlaOp> operands = {});
@@ -971,6 +982,284 @@ class XlaBuilder {
bool die_immediately_on_error_ = false;
XlaBuilder* parent_builder_{nullptr};
+
+ friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number,
+ const Shape& shape, const string& name);
+ friend XlaOp ConstantLiteral(XlaBuilder* builder,
+ const LiteralSlice& literal);
+ template <typename NativeT>
+ friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
+ template <typename NativeT>
+ friend XlaOp ConstantR1(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<NativeT> values);
+ friend XlaOp ConstantR1(XlaBuilder* builder,
+ const tensorflow::core::Bitmap& values);
+ template <typename NativeT>
+ friend XlaOp ConstantR2(
+ XlaBuilder* builder,
+ std::initializer_list<std::initializer_list<NativeT>> values);
+ template <typename NativeT>
+ friend XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
+ const Array<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ friend XlaOp ConstantFromArray(XlaBuilder* builder,
+ const Array<NativeT>& values);
+ template <typename NativeT>
+ friend XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
+ const Array2D<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ friend XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
+ const Array2D<NativeT>& values);
+ template <typename NativeT>
+ friend XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
+ const Array3D<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ friend XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
+ const Array3D<NativeT>& values);
+ template <typename NativeT>
+ friend XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
+ const Array4D<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ friend XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
+ const Array4D<NativeT>& values);
+
+ template <typename NativeT>
+ friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
+
+ friend XlaOp Broadcast(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+
+ friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
+ const PaddingConfig& padding_config);
+
+ friend XlaOp Reshape(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<int64> new_sizes);
+
+ friend XlaOp Reshape(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> new_sizes);
+
+ friend XlaOp Collapse(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+
+ friend XlaOp Slice(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> strides);
+
+ friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index,
+ int64 limit_index, int64 stride, int64 dimno);
+
+ friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
+
+ friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
+ const XlaOp& start_indices);
+
+ friend XlaOp ConcatInDim(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ int64 dimension);
+
+ friend void Trace(const string& tag, const XlaOp& operand);
+
+ friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true,
+ const XlaOp& on_false);
+ friend XlaOp Tuple(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> elements);
+ friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
+ friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs);
+ friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ const DotDimensionNumbers& dimension_numbers);
+ friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding);
+ friend XlaOp ConvWithGeneralPadding(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+ friend XlaOp ConvWithGeneralDimensions(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+ friend XlaOp ConvGeneral(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+ friend XlaOp ConvGeneralDilated(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ tensorflow::gtl::ArraySlice<int64> lhs_dilation,
+ tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+ friend XlaOp Fft(const XlaOp& operand, FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length);
+ friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
+ const string& config);
+ friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
+ const string& outfeed_config);
+ friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<XlaOp> operands);
+ friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const Shape& shape);
+ friend XlaOp HostCompute(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const string& channel_name, int64 cost_estimate_ns,
+ const Shape& shape);
+ friend XlaOp Complex(const XlaOp& real, const XlaOp& imag,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Conj(const XlaOp& operand);
+ friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Not(const XlaOp& operand);
+ friend XlaOp ShiftLeft(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp ShiftRightArithmetic(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp ShiftRightLogical(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+ friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation);
+ friend XlaOp ReduceWindow(
+ const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
+ friend XlaOp ReduceWindowWithGeneralPadding(
+ const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+ friend XlaOp CrossReplicaSum(
+ const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids);
+ friend XlaOp CrossReplicaSum(
+ const XlaOp& operand, const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+ const tensorflow::gtl::optional<ChannelHandle>& channel_id);
+ friend XlaOp SelectAndScatter(
+ const XlaOp& operand, const XlaComputation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ const XlaOp& source, const XlaOp& init_value,
+ const XlaComputation& scatter);
+ friend XlaOp SelectAndScatterWithGeneralPadding(
+ const XlaOp& operand, const XlaComputation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const XlaOp& source, const XlaOp& init_value,
+ const XlaComputation& scatter);
+ friend XlaOp Abs(const XlaOp& operand);
+ friend XlaOp Atan2(const XlaOp& y, const XlaOp& x,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Exp(const XlaOp& operand);
+ friend XlaOp Expm1(const XlaOp& operand);
+ friend XlaOp Floor(const XlaOp& operand);
+ friend XlaOp Ceil(const XlaOp& operand);
+ friend XlaOp Round(const XlaOp& operand);
+ friend XlaOp Log(const XlaOp& operand);
+ friend XlaOp Log1p(const XlaOp& operand);
+ friend XlaOp Sign(const XlaOp& operand);
+ friend XlaOp Clz(const XlaOp& operand);
+ friend XlaOp Cos(const XlaOp& operand);
+ friend XlaOp Sin(const XlaOp& operand);
+ friend XlaOp Tanh(const XlaOp& operand);
+ friend XlaOp Real(const XlaOp& operand);
+ friend XlaOp Imag(const XlaOp& operand);
+ friend XlaOp SqrtF32(const XlaOp& operand);
+ friend XlaOp SquareF32(const XlaOp& operand);
+ friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp IsFinite(const XlaOp& operand);
+ friend XlaOp ConvertElementType(const XlaOp& operand,
+ PrimitiveType new_element_type);
+ friend XlaOp BitcastConvertType(const XlaOp& operand,
+ PrimitiveType new_element_type);
+ friend XlaOp ReciprocalF32(const XlaOp& operand);
+ friend XlaOp Neg(const XlaOp& operand);
+ friend XlaOp Transpose(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> permutation);
+ friend XlaOp Rev(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+ friend XlaOp Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values);
+ friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
+ friend XlaOp Map(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<XlaOp> static_operands);
+ friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma,
+ const Shape& shape);
+ friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
+ friend XlaOp While(const XlaComputation& condition,
+ const XlaComputation& body, const XlaOp& init);
+ friend XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
+ const XlaComputation& true_computation,
+ const XlaOp& false_operand,
+ const XlaComputation& false_computation);
+ friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
+ const int mantissa_bits);
+ friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+ const GatherDimensionNumbers& dimension_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds);
+ friend void Send(const XlaOp& operand, const ChannelHandle& handle);
+ friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
+ const ChannelHandle& handle);
+ friend XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& offset, float epsilon,
+ int64 feature_index);
+ friend XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& offset, const XlaOp& mean,
+ const XlaOp& variance, float epsilon,
+ int64 feature_index);
+ friend XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& batch_mean, const XlaOp& batch_var,
+ const XlaOp& grad_output, float epsilon,
+ int64 feature_index);
};
// RAII-style object: sets the current sharding assignment in builder on
@@ -1548,8 +1837,16 @@ XlaOp Transpose(const XlaOp& operand,
// is moved to index dimension_size - 1 - i).
XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions);
-// Enqueues a sort (as increasing order) instruction onto the computation.
-XlaOp Sort(const XlaOp& operand);
+// Enqueues a sort (as increasing order) instruction onto the computation.
+// If only keys are provided:
+// * The keys must be a rank-1 tensor (i.e. an array).
+// * The result is a sorted array of keys.
+//
+// If both keys and values are provided:
+// * The keys and the values must be rank-1 tensors with the same dimensions.
+// The element types of the tensors may be different.
+// * The result is a tuple that consists of a sorted array of keys as the
+// first element, and an array with their corresponding values as the second
+// element.
+XlaOp Sort(XlaOp keys,
+ tensorflow::gtl::optional<XlaOp> values = tensorflow::gtl::nullopt);
// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
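A hedged usage sketch of the extended Sort declared above (builder and constants are illustrative): when values are supplied the result is a tuple, so the pieces are retrieved with GetTupleElement.

xla::XlaBuilder b("sort_example");
auto keys = xla::ConstantR1<xla::int32>(&b, {3, 1, 2});
auto values = xla::ConstantR1<float>(&b, {30.f, 10.f, 20.f});
auto sorted = xla::Sort(keys, values);                 // tuple: (sorted keys, co-sorted values)
auto sorted_keys = xla::GetTupleElement(sorted, 0);    // {1, 2, 3}
auto sorted_values = xla::GetTupleElement(sorted, 1);  // {10.f, 20.f, 30.f}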
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 48fd07371d..1ddeb27e40 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1252,9 +1252,10 @@ bool OutputIsPermutationOfOperandElements(HloInstruction* instruction,
switch (instruction->opcode()) {
case HloOpcode::kReshape:
case HloOpcode::kReverse:
- case HloOpcode::kSort:
case HloOpcode::kTranspose:
return true;
+ case HloOpcode::kSort:
+ return (!ShapeUtil::IsTuple(instruction->shape()));
default:
return false;
}
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index b0ad433d8d..ab3d846403 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -1093,8 +1093,7 @@ void MaybeDumpModule(const string& message, const HloModule& module) {
} // namespace
Status RemoveUnnecessaryCopies(
- const HloOrdering& ordering,
- const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module,
+ const HloOrdering& ordering, HloModule* module,
const HloDataflowAnalysis::FusionCanShareBufferFunction&
fusion_can_share_buffer) {
MaybeDumpModule("after adding copies to resolve interference", *module);
@@ -1108,7 +1107,6 @@ Status RemoveUnnecessaryCopies(
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kCopy &&
- !ContainsKey(copies_to_exclude, instruction->unique_id()) &&
instruction->CopyElisionAllowed()) {
TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status());
}
@@ -1152,16 +1150,13 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
"Call graph must be flattened before copy insertion.");
}
- // Gather Ids of existing kCopy instructions in the module. We avoid removing
- // these copies (except via DCE in TupleSimplifier) because they may have been
- // added for reasons not considered by copy insertion (eg, layout assignment).
- // Instruction id is used instead of HloInstruction* because the pointer
- // values may be recycled.
- tensorflow::gtl::FlatSet<int> existing_copies;
- for (HloComputation* computation : module->computations()) {
- for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kCopy) {
- existing_copies.insert(instruction->unique_id());
+ int64 num_existing_copies = 0;
+ if (VLOG_IS_ON(1)) {
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->opcode() == HloOpcode::kCopy) {
+ ++num_existing_copies;
+ }
}
}
}
@@ -1181,8 +1176,7 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
TF_DCHECK_OK(VerifyNoLiveRangeInterference(module));
DependencyHloOrdering ordering(module);
- TF_RETURN_IF_ERROR(
- RemoveUnnecessaryCopies(ordering, existing_copies, module));
+ TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, module));
TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
@@ -1203,7 +1197,7 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
}
}
}
- VLOG(1) << "Num copies before copy-insertion: " << existing_copies.size();
+ VLOG(1) << "Num copies before copy-insertion: " << num_existing_copies;
VLOG(1) << "Num copies after copy-insertion: " << num_total_copies;
}
diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h
index 6d25706089..e1973db928 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.h
+++ b/tensorflow/compiler/xla/service/copy_insertion.h
@@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
@@ -79,11 +78,10 @@ class CopyInsertion : public HloPassInterface {
};
// Try to remove as many copies from the module as possible without introducing
-// live range interference. Copy instructions (identified by their unique id) in
-// the set copies_to_exclude are not considered for removal.
+// live range interference. Only copy instructions that are eligible for
+// copy elision are considered for removal.
Status RemoveUnnecessaryCopies(
- const HloOrdering& ordering,
- const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module,
+ const HloOrdering& ordering, HloModule* module,
const HloDataflowAnalysis::FusionCanShareBufferFunction&
fusion_can_share_buffer = nullptr);
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index e7539759ce..7ae8799b61 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -125,21 +125,27 @@ TEST_F(CopyInsertionTest, SingleConstant) {
}
TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) {
- // Verify that an kCopy instructions which exist in the pass before
+ // Verify that kCopy instructions which change layout and exist before
// copy-insertion remain in the graph after copy-insertion.
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- HloInstruction* constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
- HloInstruction* copy_1 = builder.AddInstruction(HloInstruction::CreateUnary(
- constant->shape(), HloOpcode::kCopy, constant));
- HloInstruction* copy_2 = builder.AddInstruction(HloInstruction::CreateUnary(
- constant->shape(), HloOpcode::kCopy, constant));
+ HloInstruction* constant =
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}})));
+ auto minor_to_major = LayoutUtil::MinorToMajor(constant->shape());
+ Layout reversed_layout =
+ LayoutUtil::MakeLayoutFromMajorToMinor(minor_to_major);
+ Shape copy_shape = constant->shape();
+ *copy_shape.mutable_layout() = reversed_layout;
+ HloInstruction* copy_1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant));
+ HloInstruction* copy_2 = builder.AddInstruction(
+ HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant));
HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
constant->shape(), HloOpcode::kAdd, copy_1, copy_2));
- HloInstruction* add_copy = builder.AddInstruction(
- HloInstruction::CreateUnary(constant->shape(), HloOpcode::kCopy, add));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(add->shape(), HloOpcode::kCopy, add));
module->AddEntryComputation(builder.Build());
@@ -147,12 +153,11 @@ TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) {
InsertCopies(module.get());
- EXPECT_EQ(CountCopies(*module), 3);
+ EXPECT_EQ(CountCopies(*module), 2);
- EXPECT_EQ(module->entry_computation()->root_instruction(), add_copy);
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- op::Copy(op::Add(op::Copy(op::Constant()), op::Copy(op::Constant()))));
+ EXPECT_EQ(module->entry_computation()->root_instruction(), add);
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ op::Add(op::Copy(op::Constant()), op::Copy(op::Constant())));
}
TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
index 7bb8df6581..5343497c03 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
@@ -55,33 +55,28 @@ Status GpuTransferManager::TransferLiteralToInfeed(
return TransferBufferToInfeed(executor, size, literal.untyped_data());
}
- if (ShapeUtil::IsNestedTuple(shape)) {
- return Unimplemented(
- "Infeed with a nested tuple shape is not supported: %s",
- ShapeUtil::HumanString(literal.shape()).c_str());
- }
-
// For a tuple, we transfer each of its elements to the device and
// enqueue the resulting destination device addresses with the
// infeed manager.
std::vector<gpu::InfeedBuffer*> buffers;
- buffers.reserve(ShapeUtil::TupleElementCount(shape));
auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() {
for (gpu::InfeedBuffer* b : buffers) {
b->Done();
}
});
- for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
- const Shape& tuple_element_shape =
- ShapeUtil::GetTupleElementShape(shape, i);
- int64 tuple_element_size = GetByteSizeRequirement(tuple_element_shape);
- TF_ASSIGN_OR_RETURN(
- gpu::InfeedBuffer * buffer,
- TransferBufferToInfeedInternal(executor, tuple_element_size,
- literal.untyped_data({i})));
- buffers.push_back(buffer);
- }
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+ shape, [&](const Shape& literal_subshape, const ShapeIndex& index) {
+ if (ShapeUtil::IsArray(literal_subshape)) {
+ int64 tuple_element_size = GetByteSizeRequirement(literal_subshape);
+ TF_ASSIGN_OR_RETURN(
+ gpu::InfeedBuffer * buffer,
+ TransferBufferToInfeedInternal(executor, tuple_element_size,
+ literal.untyped_data(index)));
+ buffers.push_back(buffer);
+ }
+ return Status::OK();
+ }));
cleanup.release();
return EnqueueBuffersToInfeed(executor, buffers);
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index 652b5c7687..ea661b3c2c 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -113,10 +113,7 @@ bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) {
// We can fuse reduces and loop fusions.
return IsInputFusibleReduction(instr) ||
(instr->opcode() == HloOpcode::kFusion &&
- instr->fusion_kind() == HloInstruction::FusionKind::kLoop &&
- // TODO(b/110202584): bitcasts make nested fusions, GPU has no support
- // for nested fusions.
- instr->fused_expression_root()->opcode() != HloOpcode::kBitcast);
+ instr->fusion_kind() == HloInstruction::FusionKind::kLoop);
}
int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1,
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index deb7f28d84..e65e1af20c 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -1068,6 +1068,19 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
return Status::OK();
}
+Status HloEvaluator::HandleSort(HloInstruction* sort) {
+ if (!ShapeUtil::IsTuple(sort->shape())) {
+ return DefaultAction(sort);
+ }
+  // The key-value version of Sort is a special snowflake: since the output
+  // shape is a tuple, its element type is not meaningful.
+ //
+ // TODO(mkuper): Do something sane here, so that we can support different key
+ // and value types.
+ return sort->Visit(
+ typed_visitors_.at(sort->operand(0)->shape().element_type()).get());
+}
+
Status HloEvaluator::Preprocess(HloInstruction* hlo) {
VLOG(2) << "About to visit HLO: " << hlo->ToString();
return Status::OK();
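
Because the key-value sort's output shape is a tuple, the handler above dispatches on the element
type of the keys operand rather than on the output's element type. A minimal standalone sketch of
that dispatch pattern, in plain C++ with made-up types and handlers (not XLA code):

#include <cstdio>
#include <functional>
#include <map>

// Stand-in for xla::PrimitiveType; the values are illustrative only.
enum class ElementType { F32, S32 };

int main() {
  // Stand-in for the typed_visitors_ table: one handler per element type.
  std::map<ElementType, std::function<void()>> typed_visitors = {
      {ElementType::F32, [] { std::puts("sort with f32 keys"); }},
      {ElementType::S32, [] { std::puts("sort with s32 keys"); }},
  };
  // For a tuple-shaped (keys, values) sort, dispatch on the keys' type.
  ElementType keys_element_type = ElementType::F32;
  typed_visitors.at(keys_element_type)();
  return 0;
}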
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 2ad56080d8..b330c30eeb 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -176,6 +176,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
Status HandleAfterAll(HloInstruction* token) override;
+ Status HandleSort(HloInstruction* sort) override;
+
// Returns the already-evaluated literal result for the instruction.
// A Constant instruction is considered evaluated and its literal will be
// returned directly without looking up the cache.
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 8b08756c64..1136178e90 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -1025,83 +1025,47 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
CHECK_EQ(dnums.lhs_batch_dimensions_size(),
dnums.rhs_batch_dimensions_size());
- std::vector<int64> lhs_non_contracting_dims;
+ DimensionVector lhs_index(lhs_rank);
+ DimensionVector rhs_index(rhs_rank);
+
+ // result_index_locations[i] contains one or two pointers to the locations
+ // in lhs_index or rhs_index where the i'th result index should go.
+ tensorflow::gtl::InlinedVector<std::pair<int64*, int64*>, kInlineRank>
+ result_index_locations;
+ result_index_locations.reserve(lhs_rank + rhs_rank - 2);
+
+ // The first components in the output shape are the LHS and RHS batch
+ // dimensions:
+ for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); i++) {
+ result_index_locations.push_back(
+ {&lhs_index[dnums.lhs_batch_dimensions(i)],
+ &rhs_index[dnums.rhs_batch_dimensions(i)]});
+ }
+
+ // Then we have the LHS and RHS non-contracting dimensions, if any:
for (int64 i = 0; i < lhs_rank; i++) {
- if (i != lhs_contracting_dimension) {
- lhs_non_contracting_dims.push_back(i);
+ if (i != lhs_contracting_dimension &&
+ !ArrayContains(AsInt64Slice(dnums.lhs_batch_dimensions()), i)) {
+ result_index_locations.push_back({&lhs_index[i], nullptr});
}
}
-
- std::vector<int64> rhs_non_batch_non_contracting_dims;
- tensorflow::gtl::FlatSet<int64> batch_dims_set(
- dnums.rhs_batch_dimensions().begin(),
- dnums.rhs_batch_dimensions().end());
for (int64 i = 0; i < rhs_rank; i++) {
- if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) {
- rhs_non_batch_non_contracting_dims.push_back(i);
+ if (i != rhs_contracting_dimension &&
+ !ArrayContains(AsInt64Slice(dnums.rhs_batch_dimensions()), i)) {
+ result_index_locations.push_back({&rhs_index[i], nullptr});
}
}
- const int64 batch_dim_size = dnums.lhs_batch_dimensions_size();
- const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size();
-
- DimensionVector lhs_index(lhs_rank);
- DimensionVector rhs_index(rhs_rank);
auto result = MakeUnique<Literal>(dot->shape());
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> result_index) {
ElementwiseT result_val = static_cast<ElementwiseT>(0);
- // Find the corresponding non-contracting indices for lhs and rhs.
- //
- // For `result_index`, its batch dimension, if exists, will be at the
- // same dimension as the batch dimension of lhs and rhs. More
- // specifically:
- // - For lhs, the non-contracting dimensions, including the batch
- // dimension have the same index as the `result_index`.
- // - For rhs, the batch dimension is set seperately from other
- // non-contracting dimensions, since these other non-contracting
- // dimensions in rhs follow the non-contracting dimensions of lhs in
- // the resulting index.
- //
- // As an example, for a resulting index:
- // result_index [result_batch, result_x, result_y]
- // the effecting lhs and rhs indices are:
- // lhs [result_batch, lhs_non_contracting_dim, contracting_dim
- // rhs [result_batch, contracting_dim, rhs_non_contracting_dim]
- // `result_x` is only affected by the lhs_non_contracting_dim and
- // likewise `result_y` only depends on rhs_non_contracting_dim.
- //
- // so we can look up the lhs and rhs indices by:
- //
- // lhs:
- // batch index is the same as `result_batch`.
- // non-contracting dimension is the same as
- // result_index[lhs_non_contracting_dim]
- // rhs:
- // batch index: the same as `result_batch`.
- // non-contracting dimension index: *not* the same as
- // result_index[rhs_non_contractng_dim], since the
- // non-contracting dimensions of lhs are included in the
- // result_index first. Instead, the non_contracting_dim of rhs must
- // be calculated as following:
- // lhs_non_contracting_dimensions_size +
- // (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1
- //
- // Note that (rhs_non_batch_contracting_dim - batch_dim_size) is
- // the index offset to the result_index that only depends on
- // the non_batch and non-contracting dimensions of rhs. -1 at the
- // end translates size to index.
- for (auto i : lhs_non_contracting_dims) {
- lhs_index[i] = result_index[i];
- }
- for (auto i : dnums.rhs_batch_dimensions()) {
- rhs_index[i] = result_index[i];
- }
- for (auto i : rhs_non_batch_non_contracting_dims) {
- const int64 rhs_non_batch_non_contracting_dim =
- lhs_non_contracting_size + (i - batch_dim_size) - 1;
- rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim];
+ for (int64 i = 0; i < result_index.size(); i++) {
+ *result_index_locations[i].first = result_index[i];
+ if (result_index_locations[i].second) {
+ *result_index_locations[i].second = result_index[i];
+ }
}
// Accumulates resulting product along the contracted dimension.
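
The result_index_locations bookkeeping above scatters each position of the result index into one
or two slots of the per-operand index vectors. A small standalone sketch of the same mapping, in
plain C++ with a made-up [batch, m, k] x [batch, k, n] dot (not XLA code):

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  // Hypothetical dot: lhs is [batch, m, k], rhs is [batch, k, n]; the batch
  // dimension is {0} on both sides, and the contracting dimensions are
  // {2} (lhs) and {1} (rhs). The result shape is [batch, m, n].
  std::vector<int64_t> lhs_index(3), rhs_index(3);
  std::vector<std::pair<int64_t*, int64_t*>> result_index_locations = {
      {&lhs_index[0], &rhs_index[0]},  // batch dim goes to both operands
      {&lhs_index[1], nullptr},        // lhs non-contracting dim (m)
      {&rhs_index[2], nullptr},        // rhs non-contracting dim (n)
  };

  const std::vector<int64_t> result_index = {5, 2, 7};  // [batch=5, m=2, n=7]
  for (size_t i = 0; i < result_index.size(); ++i) {
    *result_index_locations[i].first = result_index[i];
    if (result_index_locations[i].second != nullptr) {
      *result_index_locations[i].second = result_index[i];
    }
  }
  // The contracting slots (lhs_index[2] and rhs_index[1]) are filled by the
  // accumulation loop over the contracted dimension, which is omitted here.
  std::printf("lhs_index = [%lld, %lld, _], rhs_index = [%lld, _, %lld]\n",
              (long long)lhs_index[0], (long long)lhs_index[1],
              (long long)rhs_index[0], (long long)rhs_index[2]);
  return 0;
}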
@@ -1402,24 +1366,68 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
!is_complex_t<NativeT>::value &&
!std::is_same<NativeT, bool>::value>::type* = nullptr>
Status HandleSort(HloInstruction* sort) {
- TF_RET_CHECK(ShapeUtil::Rank(sort->shape()) == 1)
+ auto keys = sort->operand(0);
+ TF_RET_CHECK(ShapeUtil::Rank(keys->shape()) == 1)
<< "Sort is only supported for R1 shapes";
- auto arg = sort->operand(0);
- const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg);
- VLOG(3) << "HandleSort arg_literal: " << arg_literal.ToString();
- const auto& arg_data = arg_literal.data<ReturnT>();
+ const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys);
+ VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString();
+ const auto& keys_data = keys_literal.data<ReturnT>();
+
+ if (sort->operand_count() == 1) {
+ std::vector<ReturnT> result_data(keys_data.begin(), keys_data.end());
+ std::sort(result_data.begin(), result_data.end(),
+ [](const ReturnT& a, const ReturnT& b) {
+ return SafeLess<ReturnT>(a, b);
+ });
+ auto result_literal = MakeUnique<Literal>(sort->shape());
+ result_literal->PopulateR1(
+ tensorflow::gtl::ArraySlice<ReturnT>(result_data));
+ VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
+ parent_->evaluated_[sort] = std::move(result_literal);
+ } else {
+ CHECK_EQ(sort->operand_count(), 2);
+ auto values = sort->operand(1);
+ if (values->shape().element_type() !=
+ primitive_util::NativeToPrimitiveType<ReturnT>()) {
+ return InvalidArgument(
+ "Evaluator requires value and key types for Sort to match");
+ }
- std::vector<ReturnT> return_data(arg_data.begin(), arg_data.end());
- std::sort(return_data.begin(), return_data.end(),
- [](const ReturnT& a, const ReturnT& b) {
- return SafeLess<ReturnT>(a, b);
- });
- auto result_literal = MakeUnique<Literal>(sort->shape());
- result_literal->PopulateR1(
- tensorflow::gtl::ArraySlice<ReturnT>(return_data));
- VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
- parent_->evaluated_[sort] = std::move(result_literal);
+      // We need to sort an array of keys and an array of values, where the
+ // sorted order of the values is determined by the keys. The simplest(?)
+ // way to do this is to go to an array-of-pairs representation, sort the
+ // array using the keys, and then go back to pair-of-arrays.
+ const Literal& values_literal = parent_->GetEvaluatedLiteralFor(values);
+ VLOG(3) << "HandleSort values_literal: " << values_literal.ToString();
+ const auto& values_data = values_literal.data<ReturnT>();
+ using kv_pair = std::pair<ReturnT, ReturnT>;
+ std::vector<kv_pair> key_value_vector;
+ CHECK_EQ(keys_data.size(), values_data.size());
+ for (int i = 0; i < keys_data.size(); ++i) {
+ key_value_vector.push_back(
+ std::make_pair(keys_data[i], values_data[i]));
+ }
+ std::sort(key_value_vector.begin(), key_value_vector.end(),
+ [](const kv_pair& a, const kv_pair& b) {
+ return SafeLess<ReturnT>(a.first, b.first);
+ });
+ std::vector<ReturnT> result_keys, result_values;
+ for (const auto& key_value : key_value_vector) {
+ result_keys.push_back(key_value.first);
+ result_values.push_back(key_value.second);
+ }
+ auto result_keys_literal = MakeUnique<Literal>(keys->shape());
+ result_keys_literal->PopulateR1(
+ tensorflow::gtl::ArraySlice<ReturnT>(result_keys));
+ auto result_values_literal = MakeUnique<Literal>(values->shape());
+ result_values_literal->PopulateR1(
+ tensorflow::gtl::ArraySlice<ReturnT>(result_values));
+ auto result_tuple = Literal::MakeTuple(
+ {result_keys_literal.get(), result_values_literal.get()});
+ VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString();
+ parent_->evaluated_[sort] = std::move(result_tuple);
+ }
return Status::OK();
}
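
The zip / sort / unzip strategy described in the comment above can be illustrated with a small
standalone C++ program (not XLA code; the real evaluator uses SafeLess rather than operator< so
that NaN keys are ordered deterministically):

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  std::vector<float> keys = {3.0f, 1.0f, 2.0f};
  std::vector<int> values = {30, 10, 20};

  // Zip the two arrays into key/value pairs.
  std::vector<std::pair<float, int>> zipped;
  for (size_t i = 0; i < keys.size(); ++i) {
    zipped.emplace_back(keys[i], values[i]);
  }
  // Sort by key only; the values ride along.
  std::sort(zipped.begin(), zipped.end(),
            [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
              return a.first < b.first;
            });
  // Unzip back into a pair of arrays.
  for (size_t i = 0; i < zipped.size(); ++i) {
    keys[i] = zipped[i].first;
    values[i] = zipped[i].second;
  }
  for (size_t i = 0; i < keys.size(); ++i) {
    std::printf("%g -> %d\n", keys[i], values[i]);  // 1 -> 10, 2 -> 20, 3 -> 30
  }
  return 0;
}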
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 5aaeec802f..e0e3d301be 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -489,7 +489,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
case HloOpcode::kReal:
case HloOpcode::kSign:
case HloOpcode::kSin:
- case HloOpcode::kSort:
case HloOpcode::kTanh:
break;
default:
@@ -908,6 +907,16 @@ HloInstruction::CreateBroadcastSequence(
return MakeUnique<HloTransposeInstruction>(shape, operand, dimensions);
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
+ const Shape& shape, HloInstruction* keys, HloInstruction* values) {
+ auto instruction = WrapUnique(new HloInstruction(HloOpcode::kSort, shape));
+ instruction->AppendOperand(keys);
+ if (values) {
+ instruction->AppendOperand(values);
+ }
+ return instruction;
+}
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
return MakeUnique<HloFusionInstruction>(shape, fusion_kind, fused_root);
@@ -1122,7 +1131,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kReal:
case HloOpcode::kSign:
case HloOpcode::kSin:
- case HloOpcode::kSort:
case HloOpcode::kTanh:
CHECK_EQ(new_operands.size(), 1);
clone = CreateUnary(shape, opcode_, new_operands[0]);
@@ -1215,6 +1223,14 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kAfterAll:
clone = CreateAfterAll(new_operands);
break;
+ case HloOpcode::kSort:
+ CHECK(new_operands.size() == 1 || new_operands.size() == 2)
+ << "Too many operands for sort: " << new_operands.size();
+ HloInstruction* keys = new_operands[0];
+ HloInstruction* values =
+ new_operands.size() == 2 ? new_operands[1] : nullptr;
+ clone = CreateSort(shape, keys, values);
+ break;
}
SetupDerivedInstruction(clone.get());
clone->set_parent(parent_);
@@ -1491,6 +1507,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
case HloOpcode::kSign:
+ case HloOpcode::kSort:
case HloOpcode::kSin:
case HloOpcode::kSubtract:
case HloOpcode::kTanh:
@@ -1520,10 +1537,6 @@ bool HloInstruction::IdenticalSlowPath(
return eq_computations(true_computation(), other.true_computation()) &&
eq_computations(false_computation(), other.false_computation());
- // These opcodes are not yet supported.
- case HloOpcode::kSort:
- return false;
-
// Ops migrated to subclasses should never come to this line.
// TODO(b/80131774): Remove this switch when migration is complete.
case HloOpcode::kBatchNormTraining:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 59a383218c..0459072127 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -611,6 +611,11 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions);
+  // Creates a sort op with a keys operand and an optional values operand.
+ static std::unique_ptr<HloInstruction> CreateSort(
+ const Shape& shape, HloInstruction* keys,
+ HloInstruction* values = nullptr);
+
// Creates a while instruction, given a condition computation, a body
// computation, and the initial value for the input of the computations. For
// example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
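
For context, a hedged sketch of how the new factory could be used when building HLO by hand; the
shapes, parameter setup, and builder name below are illustrative assumptions, not code taken from
this change:

// Keys-only sort: the result shape equals the keys shape.
Shape keys_shape = ShapeUtil::MakeShape(F32, {1024});
HloComputation::Builder b("sort_example");
HloInstruction* keys = b.AddInstruction(
    HloInstruction::CreateParameter(0, keys_shape, "keys"));
b.AddInstruction(HloInstruction::CreateSort(keys_shape, keys));

// Key-value sort: the result is a (keys, values) tuple.
Shape values_shape = ShapeUtil::MakeShape(S32, {1024});
HloInstruction* values = b.AddInstruction(
    HloInstruction::CreateParameter(1, values_shape, "values"));
b.AddInstruction(HloInstruction::CreateSort(
    ShapeUtil::MakeTupleShape({keys_shape, values_shape}), keys, values));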
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 57d17064c1..6ffed62a09 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -509,7 +509,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kReal:
case HloOpcode::kSign:
case HloOpcode::kSin:
- case HloOpcode::kSort:
case HloOpcode::kTanh: {
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
@@ -625,6 +624,27 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
builder->AddInstruction(HloInstruction::CreateAfterAll(operands));
break;
}
+ case HloOpcode::kSort: {
+ auto loc = lexer_.GetLoc();
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
+ return false;
+ }
+ switch (operands.size()) {
+ case 1:
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateSort(shape, /*keys=*/operands[0]));
+ break;
+ case 2:
+ instruction = builder->AddInstruction(HloInstruction::CreateSort(
+ shape,
+ /*keys=*/operands[0], /*values=*/operands[1]));
+ break;
+ default:
+ return Error(loc, StrCat("expects either 1 or 2 operands, but has ",
+ operands.size(), " operands"));
+ }
+ break;
+ }
case HloOpcode::kTuple: {
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index da1a34ae3c..504ea3fe7a 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -832,6 +832,31 @@ ENTRY ReducePrecision {
)"
},
+// Sort (Key)
+{
+"SortKey",
+R"(HloModule sort
+
+ENTRY Sort {
+ x = f32[1024]{0} parameter(0)
+ ROOT sorted = f32[1024]{0} sort(x)
+}
+
+)"
+},
+// Sort (Key, Value)
+{
+"SortKeyValue",
+R"(HloModule sort
+
+ENTRY Sort {
+ keys = f32[1024]{0} parameter(0)
+ values = s32[1024]{0} parameter(1)
+ ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values)
+}
+
+)"
+},
// Conditional
{
"Conditional",
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 62c07d7fac..59a8800a7d 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -1244,7 +1244,7 @@ StatusOr<bool> HloRematerialization::Run(
// TODO(b/80249101): Instead of a separate copy elision pass, use the
// ordering from the HLO schedule directly for copy insertion.
SequentialHloOrdering ordering(module, *sequence);
- TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, {}, module));
+ TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, module));
}
// Compute peak memory usage of all computations in the module called in a
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index fb39c6f085..27c9529b11 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -167,7 +167,16 @@ Status ShapeVerifier::HandleReverse(HloInstruction* reverse) {
}
Status ShapeVerifier::HandleSort(HloInstruction* sort) {
- return CheckUnaryShape(sort);
+ if (sort->operand_count() == 2 &&
+ !ShapeUtil::SameDimensions(sort->operand(0)->shape(),
+ sort->operand(1)->shape())) {
+ return InternalError(
+ "Expected sort to have to have the same dimensions for the keys and "
+ "the values. Keys shape is: %s\n, Values shape is: %s",
+ ShapeUtil::HumanString(sort->operand(0)->shape()).c_str(),
+ ShapeUtil::HumanString(sort->operand(1)->shape()).c_str());
+ }
+ return CheckVariadicShape(sort);
}
Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc
index 79b5a442aa..4166ef5baf 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc
@@ -115,39 +115,18 @@ HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1,
HloInstruction* fused = instr2;
// Make sure that if only one of the instructions is a fusion, or if only one
// of the instructions is a multi-output fusion, it's what will be fused into.
- //
- // An invariant is that no bitcast nodes will show up in the middle of a
- // fusion node. This invariant must hold in order for us to lower it. Given
- // that, we require that during multi-output fusion, a fusion node ending with
- // bitcast to preserve its structure as a nested fusion instead being
- // merged and flattened.
- if (fused->opcode() == HloOpcode::kFusion &&
- fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) {
+ if (fused->opcode() == HloOpcode::kFusion) {
std::swap(remaining, fused);
}
if (fused->IsMultiOutputFusion()) {
std::swap(remaining, fused);
}
- if (fused->opcode() == HloOpcode::kFusion &&
- fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) {
+ if (fused->opcode() == HloOpcode::kFusion) {
remaining->MergeFusionInstructionIntoMultiOutput(fused);
} else {
- if (remaining->opcode() == HloOpcode::kFusion &&
- remaining->fused_expression_root()->opcode() == HloOpcode::kBitcast) {
- auto parent_computation = remaining->parent();
- // Create a nested fusion node.
- auto remaining_nested_fused =
- parent_computation->AddInstruction(HloInstruction::CreateFusion(
- remaining->shape(), HloInstruction::FusionKind::kLoop,
- remaining));
- TF_CHECK_OK(parent_computation->ReplaceInstruction(
- remaining, remaining_nested_fused));
- remaining = remaining_nested_fused;
- }
remaining->FuseInstructionIntoMultiOutput(fused);
}
-
return remaining;
}
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
index d23822e33e..0019cd7254 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -78,6 +78,10 @@ class MultiOutputFusion : public HloPassInterface {
// Test if it's legal to fuse instr1 and instr2 into one fusion instruction.
virtual bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2);
+  // Fuse HloInstruction instr1 and instr2 and return the fused instruction.
+ // The other instruction is removed from its parent computation.
+ virtual HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2);
+
// Recompute reachability for the current computation.
void RecomputeReachability();
@@ -101,10 +105,6 @@ class MultiOutputFusion : public HloPassInterface {
virtual bool DoProducerConsumerMultiOutputFusion();
private:
- // Fuse HloInstrctuion instr1 and instr2 and return the fused instruction.
- // The other instruction is removed from its parent computation.
- HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2);
-
// Update the internal data structures after instr1 and instr2 are fused into
// one fusion instruction.
void Update(HloInstruction* instr1, HloInstruction* instr2);
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 096bbde922..d05e995a95 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -239,7 +239,6 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
case HloOpcode::kNegate:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kSign:
- case HloOpcode::kSort:
return shape;
case HloOpcode::kNot:
@@ -962,6 +961,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
return result;
}
+ case HloOpcode::kSort: {
+ if (operand_shapes.size() == 1) {
+ return *operand_shapes[0];
+ } else if (operand_shapes.size() == 2) {
+ return ShapeUtil::MakeTupleShape(
+ {*operand_shapes[0], *operand_shapes[1]});
+ }
+ return InvalidArgument("Unexpected number of operands for sort");
+ }
default:
return InvalidArgument("Unknown operation %s.",
HloOpcodeString(opcode).c_str());
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 5a45e2e610..20b2885e90 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -2040,6 +2040,7 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 37862fa9cb..5361ae6783 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -373,6 +373,13 @@ class ClientLibraryTestBase : public ::testing::Test {
// The float type used in this test, BF16 or F32 according to use_bfloat16.
PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; }
+ // Executes the computation and calculates the expected reference value using
+ // the reference client. Returns two literals in the order of (expected,
+ // actual).
+ StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
+ ComputeValueAndReference(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<Literal> arguments);
+
Client* client_;
Client* ref_client_; // To compute reference result.
ExecutionOptions execution_options_;
@@ -390,13 +397,6 @@ class ClientLibraryTestBase : public ::testing::Test {
const string& error_message)>& verify_output,
const Shape* output_with_layout = nullptr);
- // Executes the computation and calculates the expected reference value using
- // the reference client. Returns two literals in the order of (expected,
- // actual).
- StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
- ComputeValueAndReference(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<Literal> arguments);
-
// Whether to run tests with all float-type input/output converted to
// bfloat16.
bool use_bfloat16_ = false;
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
index ba22530f1c..1a396b090c 100644
--- a/tensorflow/compiler/xla/tests/compute_constant_test.cc
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -99,7 +99,7 @@ TEST_F(ComputeConstantTest, ScalarInt32Literal) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
- auto computation = b.ConstantR0<int32>(42);
+ auto computation = ConstantR0<int32>(&b, 42);
EXPECT_TRUE(IsConstant(computation, &b));
auto value = ComputeConstantScalar<int32>(client, computation, &b);
@@ -113,7 +113,7 @@ TEST_F(ComputeConstantTest, ScalarFloatAdd) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto computation =
- b.Add(b.ConstantR0<float>(42.5f), b.ConstantR0<float>(1.5f));
+ Add(ConstantR0<float>(&b, 42.5f), ConstantR0<float>(&b, 1.5f));
EXPECT_TRUE(IsConstant(computation, &b));
auto value = ComputeConstantScalar<float>(client, computation, &b);
@@ -127,8 +127,8 @@ TEST_F(ComputeConstantTest, ScalarRng) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto computation =
- b.RngUniform(b.ConstantR0<float>(1.1f), b.ConstantR0<float>(2.1f),
- ShapeUtil::MakeShape(F32, {}));
+ RngUniform(ConstantR0<float>(&b, 1.1f), ConstantR0<float>(&b, 2.1f),
+ ShapeUtil::MakeShape(F32, {}));
EXPECT_FALSE(IsConstant(computation, &b));
auto value = ComputeConstantScalar<float>(client, computation, &b);
@@ -141,7 +141,7 @@ TEST_F(ComputeConstantTest, DirectParamMissing) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
- auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param");
+ auto computation = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param");
EXPECT_FALSE(IsConstant(computation, &b));
auto value = ComputeConstantScalar<float>(client, computation, &b);
@@ -156,8 +156,8 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto computation =
- b.Add(b.ConstantR0<float>(1.0f),
- b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"));
+ Add(ConstantR0<float>(&b, 1.0f),
+ Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param"));
EXPECT_FALSE(IsConstant(computation, &b));
auto value = ComputeConstantScalar<float>(client, computation, &b);
@@ -174,18 +174,18 @@ TEST_F(ComputeConstantTest, UnrelatedParam) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
- auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0");
+ auto param_a = Parameter(&b, 10, ShapeUtil::MakeShape(F32, {}), "param0");
auto constant_4 =
- b.Add(b.ConstantR0<float>(2.5f), b.ConstantR0<float>(1.5f));
- auto not_constant_a = b.Add(constant_4, param_a);
+ Add(ConstantR0<float>(&b, 2.5f), ConstantR0<float>(&b, 1.5f));
+ auto not_constant_a = Add(constant_4, param_a);
- auto param_b = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "param1");
+ auto param_b = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "param1");
auto constant_9 =
- b.Mul(b.ConstantR0<float>(2.0f), b.ConstantR0<float>(4.5f));
- auto not_constant_b = b.Add(param_b, constant_9);
+ Mul(ConstantR0<float>(&b, 2.0f), ConstantR0<float>(&b, 4.5f));
+ auto not_constant_b = Add(param_b, constant_9);
- auto constant_13 = b.Add(constant_4, constant_9);
- b.Add(not_constant_b, b.Add(constant_13, not_constant_a));
+ auto constant_13 = Add(constant_4, constant_9);
+ Add(not_constant_b, Add(constant_13, not_constant_a));
EXPECT_TRUE(IsConstant(constant_13, &b));
@@ -201,7 +201,7 @@ TEST_F(ComputeConstantTest, NonScalarAdd) {
XlaBuilder b(TestName());
auto computation =
- b.Add(b.ConstantR1<int32>({1, 2}), b.ConstantR1<int32>({3, 4}));
+ Add(ConstantR1<int32>(&b, {1, 2}), ConstantR1<int32>(&b, {3, 4}));
EXPECT_TRUE(IsConstant(computation, &b));
TF_ASSERT_OK_AND_ASSIGN(auto computed,
@@ -216,7 +216,7 @@ TEST_F(ComputeConstantTest, IntegerDivide) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
- auto computation = b.Div(b.ConstantR0<int32>(15), b.ConstantR0<int32>(3));
+ auto computation = Div(ConstantR0<int32>(&b, 15), ConstantR0<int32>(&b, 3));
EXPECT_TRUE(IsConstant(computation, &b));
TF_ASSERT_OK_AND_ASSIGN(auto computed,
@@ -237,8 +237,8 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
TF_ASSERT_OK_AND_ASSIGN(
auto computed, ComputeConstantLiteral(
client,
- b.Add(b.ConstantR2<int32>({{1, 2}, {3, 4}}),
- b.ConstantR2<int32>({{10, 20}, {30, 40}})),
+ Add(ConstantR2<int32>(&b, {{1, 2}, {3, 4}}),
+ ConstantR2<int32>(&b, {{10, 20}, {30, 40}})),
&b, &layout_proto));
std::unique_ptr<Literal> expected_literal =
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 33d79aebb1..cf2e645d47 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -853,10 +853,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) {
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
-// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
- DotOfGatherOptimizationWithConstRHSReverseMM)))) {
+
+ DotOfGatherOptimizationWithConstRHSReverseMM) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
new Array2D<float>({{1.0, 2.0, 3.0},
{4.0, 5.0, 6.0},
@@ -883,10 +882,7 @@ XLA_TEST_F(DotOperationTest,
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
-// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
-XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
- DotOfGatherOptimizationWithConstLHSReverseMM)))) {
+XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSReverseMM) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
new Array2D<float>({{1.0, 2.0, 3.0},
{4.0, 5.0, 6.0},
@@ -913,10 +909,7 @@ XLA_TEST_F(DotOperationTest,
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
-// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
-XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
- DotOfGatherOptimizationWithConstRHSRows)))) {
+XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSRows) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
new Array2D<float>({{1.0, 2.0},
{3.0, 4.0},
@@ -948,10 +941,7 @@ XLA_TEST_F(DotOperationTest,
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
-// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
-XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
- DotOfGatherOptimizationWithConstLHSRows)))) {
+XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSRows) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
new Array2D<float>({{1.0, 2.0},
{3.0, 4.0},
@@ -983,10 +973,7 @@ XLA_TEST_F(DotOperationTest,
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
-// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
-XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
- DotOfGatherOptimizationWithConstRHSCols)))) {
+XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSCols) {
std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
{{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
std::unique_ptr<Array2D<float>> constant_rhs_array(
@@ -1010,10 +997,7 @@ XLA_TEST_F(DotOperationTest,
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
-// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
-XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
- DotOfGatherOptimizationWithConstLHSCols)))) {
+XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSCols) {
std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
{{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
std::unique_ptr<Array2D<float>> constant_rhs_array(
@@ -1036,5 +1020,28 @@ XLA_TEST_F(DotOperationTest,
Array2D<float> expected({{168.0}, {168.0}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
+
+XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) {
+ XlaBuilder builder(TestName());
+
+ Array2D<float> lhs_array({{1.0f, 2.0f}, {3.0f, 4.0f}});
+ auto lhs_constant = ConstantR2FromArray2D(&builder, lhs_array);
+
+ Array2D<float> rhs_array({{5.0f, 6.0f}, {7.0f, 8.0f}});
+ auto rhs_constant = ConstantR2FromArray2D(&builder, rhs_array);
+
+ Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ DotGeneral(lhs_constant, rhs_constant, dot_dnums);
+
+ Array2D<float> expected({
+ {26.f, 30.f},
+ {38.f, 44.f},
+ });
+
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
} // namespace
} // namespace xla
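
As a quick check of the expected values in the new test: with contracting dimension 0 on both
operands, DotGeneral computes result[i][j] = sum over k of lhs[k][i] * rhs[k][j], i.e.
transpose(lhs) times rhs, so

  result[0][0] = 1*5 + 3*7 = 26    result[0][1] = 1*6 + 3*8 = 30
  result[1][0] = 2*5 + 4*7 = 38    result[1][1] = 2*6 + 4*8 = 44

which matches the expected array {{26, 30}, {38, 44}}.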
diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc
index 76bf47845c..fd85118849 100644
--- a/tensorflow/compiler/xla/tests/half_test.cc
+++ b/tensorflow/compiler/xla/tests/half_test.cc
@@ -37,8 +37,7 @@ class HalfTestBase : public ClientLibraryTestBase {
static const int kNumElements = 4;
};
-using UnaryBuildFuncTy =
- std::function<void(xla::XlaBuilder*, const xla::XlaOp& src)>;
+using UnaryBuildFuncTy = std::function<void(const xla::XlaOp& src)>;
struct UnaryOpTestParam {
std::function<half(half)> compute_func;
@@ -62,7 +61,7 @@ XLA_TEST_P(UnaryOpTest, Ops) {
}
UnaryBuildFuncTy build_func = GetParam().build_func;
- build_func(&builder, x_opnd);
+ build_func(x_opnd);
ComputeAndCompareR1<half>(&builder, expected, {x_data.get()}, error_spec_);
}
@@ -79,18 +78,17 @@ half round_imp(half value) {
INSTANTIATE_TEST_CASE_P(
half, UnaryOpTest,
::testing::Values(
- UnaryOpTestParam{[](half x) { return abs(x); }, &XlaBuilder::Abs},
- UnaryOpTestParam{[](half x) { return round_imp(x); },
- &XlaBuilder::Round},
- UnaryOpTestParam{[](half x) { return ceil(x); }, &XlaBuilder::Ceil},
- UnaryOpTestParam{[](half x) { return cos(x); }, &XlaBuilder::Cos},
- UnaryOpTestParam{[](half x) { return exp(x); }, &XlaBuilder::Exp},
- UnaryOpTestParam{[](half x) { return floor(x); }, &XlaBuilder::Floor},
- UnaryOpTestParam{[](half x) { return log(x); }, &XlaBuilder::Log},
- UnaryOpTestParam{[](half x) { return -x; }, &XlaBuilder::Neg},
- UnaryOpTestParam{[](half x) { return sign_imp(x); }, &XlaBuilder::Sign},
- UnaryOpTestParam{[](half x) { return sin(x); }, &XlaBuilder::Sin},
- UnaryOpTestParam{[](half x) { return tanh(x); }, &XlaBuilder::Tanh}
+ UnaryOpTestParam{[](half x) { return abs(x); }, &Abs},
+ UnaryOpTestParam{[](half x) { return round_imp(x); }, &Round},
+ UnaryOpTestParam{[](half x) { return ceil(x); }, &Ceil},
+ UnaryOpTestParam{[](half x) { return cos(x); }, &Cos},
+ UnaryOpTestParam{[](half x) { return exp(x); }, &Exp},
+ UnaryOpTestParam{[](half x) { return floor(x); }, &Floor},
+ UnaryOpTestParam{[](half x) { return log(x); }, &Log},
+ UnaryOpTestParam{[](half x) { return -x; }, &Neg},
+ UnaryOpTestParam{[](half x) { return sign_imp(x); }, &Sign},
+ UnaryOpTestParam{[](half x) { return sin(x); }, &Sin},
+ UnaryOpTestParam{[](half x) { return tanh(x); }, &Tanh}
));
@@ -118,19 +116,18 @@ XLA_TEST_P(UnaryPredTest, Ops) {
}
UnaryBuildFuncTy build_func = GetParam().build_func;
- build_func(&builder, x_opnd);
+ build_func(x_opnd);
ComputeAndCompareR1<bool>(&builder, expected, {x_data.get()});
}
INSTANTIATE_TEST_CASE_P(half, UnaryPredTest,
::testing::Values(UnaryPredTestParam{
- [](half x) { return isfinite(x); },
- &XlaBuilder::IsFinite}));
+ [](half x) { return isfinite(x); }, &IsFinite}));
-using BinaryBuildFuncTy = std::function<void(
- xla::XlaBuilder*, const xla::XlaOp& x, const xla::XlaOp& y,
- tensorflow::gtl::ArraySlice<int64>)>;
+using BinaryBuildFuncTy =
+ std::function<void(const xla::XlaOp& x, const xla::XlaOp& y,
+ tensorflow::gtl::ArraySlice<int64>)>;
struct BinaryOpTestParam {
std::function<half(half, half)> compute_func;
@@ -159,7 +156,7 @@ XLA_TEST_P(BinaryOpTest, Ops) {
}
BinaryBuildFuncTy build_func = GetParam().build_func;
- build_func(&builder, x_opnd, y_opnd, {});
+ build_func(x_opnd, y_opnd, {});
ComputeAndCompareR1<half>(&builder, expected, {x_data.get(), y_data.get()},
error_spec_);
@@ -173,22 +170,15 @@ half atan2_imp(half x, half y) {
INSTANTIATE_TEST_CASE_P(
half, BinaryOpTest,
::testing::Values(
- BinaryOpTestParam{[](half x, half y) { return x + y; },
- &XlaBuilder::Add},
+ BinaryOpTestParam{[](half x, half y) { return x + y; }, &Add},
BinaryOpTestParam{[](half x, half y) { return atan2_imp(x, y); },
- &XlaBuilder::Atan2},
- BinaryOpTestParam{[](half x, half y) { return x / y; },
- &XlaBuilder::Div},
- BinaryOpTestParam{[](half x, half y) { return max(x, y); },
- &XlaBuilder::Max},
- BinaryOpTestParam{[](half x, half y) { return min(x, y); },
- &XlaBuilder::Min},
- BinaryOpTestParam{[](half x, half y) { return x * y; },
- &XlaBuilder::Mul},
- BinaryOpTestParam{[](half x, half y) { return pow(x, y); },
- &XlaBuilder::Pow},
- BinaryOpTestParam{[](half x, half y) { return x - y; },
- &XlaBuilder::Sub}
+ &Atan2},
+ BinaryOpTestParam{[](half x, half y) { return x / y; }, &Div},
+ BinaryOpTestParam{[](half x, half y) { return max(x, y); }, &Max},
+ BinaryOpTestParam{[](half x, half y) { return min(x, y); }, &Min},
+ BinaryOpTestParam{[](half x, half y) { return x * y; }, &Mul},
+ BinaryOpTestParam{[](half x, half y) { return pow(x, y); }, &Pow},
+ BinaryOpTestParam{[](half x, half y) { return x - y; }, &Sub}
));
@@ -221,27 +211,22 @@ XLA_TEST_P(BinaryPredTest, Ops) {
}
BinaryBuildFuncTy build_func = GetParam().build_func;
- build_func(&builder, x_opnd, y_opnd, {});
+ build_func(x_opnd, y_opnd, {});
ComputeAndCompareR1<bool>(&builder, expected, {x_data.get(), y_data.get()});
}
INSTANTIATE_TEST_CASE_P(
half, BinaryPredTest,
- ::testing::Values(BinaryPredTestParam{[](half x, half y) { return x == y; },
- &XlaBuilder::Eq},
- BinaryPredTestParam{[](half x, half y) { return x != y; },
- &XlaBuilder::Ne},
- BinaryPredTestParam{[](half x, half y) { return x >= y; },
- &XlaBuilder::Ge},
- BinaryPredTestParam{[](half x, half y) { return x > y; },
- &XlaBuilder::Gt},
- BinaryPredTestParam{[](half x, half y) { return x <= y; },
- &XlaBuilder::Le},
- BinaryPredTestParam{[](half x, half y) { return x < y; },
- &XlaBuilder::Lt}
-
- ));
+ ::testing::Values(
+ BinaryPredTestParam{[](half x, half y) { return x == y; }, &Eq},
+ BinaryPredTestParam{[](half x, half y) { return x != y; }, &Ne},
+ BinaryPredTestParam{[](half x, half y) { return x >= y; }, &Ge},
+ BinaryPredTestParam{[](half x, half y) { return x > y; }, &Gt},
+ BinaryPredTestParam{[](half x, half y) { return x <= y; }, &Le},
+ BinaryPredTestParam{[](half x, half y) { return x < y; }, &Lt}
+
+ ));
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc
index 6154ce671c..5c351b2d11 100644
--- a/tensorflow/compiler/xla/tests/pred_test.cc
+++ b/tensorflow/compiler/xla/tests/pred_test.cc
@@ -29,14 +29,14 @@ namespace {
class PredTest : public ClientLibraryTestBase {
protected:
- void TestCompare(
- bool lhs, bool rhs, bool expected,
- XlaOp (XlaBuilder::*op)(const xla::XlaOp&, const xla::XlaOp&,
- tensorflow::gtl::ArraySlice<int64>)) {
+ void TestCompare(bool lhs, bool rhs, bool expected,
+ std::function<XlaOp(const xla::XlaOp&, const xla::XlaOp&,
+ tensorflow::gtl::ArraySlice<int64>)>
+ op) {
XlaBuilder builder(TestName());
XlaOp lhs_op = ConstantR0<bool>(&builder, lhs);
XlaOp rhs_op = ConstantR0<bool>(&builder, rhs);
- (builder.*op)(lhs_op, rhs_op, {});
+ op(lhs_op, rhs_op, {});
ComputeAndCompareR0<bool>(&builder, expected, {});
}
};
@@ -54,27 +54,27 @@ TEST_F(PredTest, ConstantR0PredFalse) {
}
TEST_F(PredTest, ConstantR0PredCompareEq) {
- TestCompare(true, false, false, &XlaBuilder::Eq);
+ TestCompare(true, false, false, &Eq);
}
TEST_F(PredTest, ConstantR0PredCompareNe) {
- TestCompare(true, false, true, &XlaBuilder::Ne);
+ TestCompare(true, false, true, &Ne);
}
TEST_F(PredTest, ConstantR0PredCompareLe) {
- TestCompare(true, false, false, &XlaBuilder::Le);
+ TestCompare(true, false, false, &Le);
}
TEST_F(PredTest, ConstantR0PredCompareLt) {
- TestCompare(true, false, false, &XlaBuilder::Lt);
+ TestCompare(true, false, false, &Lt);
}
TEST_F(PredTest, ConstantR0PredCompareGe) {
- TestCompare(true, false, true, &XlaBuilder::Ge);
+ TestCompare(true, false, true, &Ge);
}
TEST_F(PredTest, ConstantR0PredCompareGt) {
- TestCompare(true, false, true, &XlaBuilder::Gt);
+ TestCompare(true, false, true, &Gt);
}
TEST_F(PredTest, ConstantR1Pred) {
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index d0ebb108ae..bc994315c3 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -44,25 +44,26 @@ class ScalarComputationsTest : public ClientLibraryTestBase {
protected:
// A template for building and running a binary comparison test.
template <typename NativeT>
- void TestCompare(
- NativeT lhs, NativeT rhs, bool expected,
- XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&,
- tensorflow::gtl::ArraySlice<int64>)) {
+ void TestCompare(NativeT lhs, NativeT rhs, bool expected,
+ std::function<XlaOp(const XlaOp&, const XlaOp&,
+ tensorflow::gtl::ArraySlice<int64>)>
+ op) {
XlaBuilder builder(TestName());
XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs);
XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs);
- (builder.*op)(lhs_op, rhs_op, {});
+ op(lhs_op, rhs_op, {});
ComputeAndCompareR0<bool>(&builder, expected, {});
}
template <typename NativeT>
void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected,
- XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&,
- tensorflow::gtl::ArraySlice<int64>)) {
+ std::function<XlaOp(const XlaOp&, const XlaOp&,
+ tensorflow::gtl::ArraySlice<int64>)>
+ op) {
XlaBuilder builder(TestName());
XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs);
XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs);
- (builder.*op)(lhs_op, rhs_op, {});
+ op(lhs_op, rhs_op, {});
ComputeAndCompareR0<NativeT>(&builder, expected, {});
}
};
@@ -583,117 +584,116 @@ XLA_TEST_F(ScalarComputationsTest, CompareGtScalar) {
// S32 comparisons.
XLA_TEST_F(ScalarComputationsTest, CompareEqS32Greater) {
- TestCompare<int32>(2, 1, false, &XlaBuilder::Eq);
+ TestCompare<int32>(2, 1, false, &Eq);
}
XLA_TEST_F(ScalarComputationsTest, CompareEqS32Equal) {
- TestCompare<int32>(3, 3, true, &XlaBuilder::Eq);
+ TestCompare<int32>(3, 3, true, &Eq);
}
XLA_TEST_F(ScalarComputationsTest, CompareNeS32) {
- TestCompare<int32>(2, 1, true, &XlaBuilder::Ne);
+ TestCompare<int32>(2, 1, true, &Ne);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeS32) {
- TestCompare<int32>(2, 1, true, &XlaBuilder::Ge);
+ TestCompare<int32>(2, 1, true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGtS32) {
- TestCompare<int32>(1, 5, false, &XlaBuilder::Gt);
+ TestCompare<int32>(1, 5, false, &Gt);
}
XLA_TEST_F(ScalarComputationsTest, CompareLeS32) {
- TestCompare<int32>(2, 1, false, &XlaBuilder::Le);
+ TestCompare<int32>(2, 1, false, &Le);
}
XLA_TEST_F(ScalarComputationsTest, CompareLtS32) {
- TestCompare<int32>(9, 7, false, &XlaBuilder::Lt);
+ TestCompare<int32>(9, 7, false, &Lt);
TestCompare<int32>(std::numeric_limits<int32>::min(),
- std::numeric_limits<int32>::max(), true, &XlaBuilder::Lt);
+ std::numeric_limits<int32>::max(), true, &Lt);
}
// U32 comparisons.
XLA_TEST_F(ScalarComputationsTest, CompareEqU32False) {
- TestCompare<uint32>(2, 1, false, &XlaBuilder::Eq);
+ TestCompare<uint32>(2, 1, false, &Eq);
}
XLA_TEST_F(ScalarComputationsTest, CompareNeU32) {
- TestCompare<uint32>(2, 1, true, &XlaBuilder::Ne);
+ TestCompare<uint32>(2, 1, true, &Ne);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeU32Greater) {
- TestCompare<uint32>(2, 1, true, &XlaBuilder::Ge);
+ TestCompare<uint32>(2, 1, true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeU32Equal) {
- TestCompare<uint32>(3, 3, true, &XlaBuilder::Ge);
+ TestCompare<uint32>(3, 3, true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGtU32) {
- TestCompare<uint32>(1, 5, false, &XlaBuilder::Gt);
- TestCompare<uint32>(5, 5, false, &XlaBuilder::Gt);
- TestCompare<uint32>(5, 1, true, &XlaBuilder::Gt);
+ TestCompare<uint32>(1, 5, false, &Gt);
+ TestCompare<uint32>(5, 5, false, &Gt);
+ TestCompare<uint32>(5, 1, true, &Gt);
}
XLA_TEST_F(ScalarComputationsTest, CompareLeU32) {
- TestCompare<uint32>(2, 1, false, &XlaBuilder::Le);
+ TestCompare<uint32>(2, 1, false, &Le);
}
XLA_TEST_F(ScalarComputationsTest, CompareLtU32) {
- TestCompare<uint32>(9, 7, false, &XlaBuilder::Lt);
- TestCompare<uint32>(0, std::numeric_limits<uint32>::max(), true,
- &XlaBuilder::Lt);
+ TestCompare<uint32>(9, 7, false, &Lt);
+ TestCompare<uint32>(0, std::numeric_limits<uint32>::max(), true, &Lt);
}
// F32 comparisons.
XLA_TEST_F(ScalarComputationsTest, CompareEqF32False) {
- TestCompare<float>(2.0, 1.3, false, &XlaBuilder::Eq);
+ TestCompare<float>(2.0, 1.3, false, &Eq);
}
XLA_TEST_F(ScalarComputationsTest, CompareNeF32) {
- TestCompare<float>(2.0, 1.3, true, &XlaBuilder::Ne);
+ TestCompare<float>(2.0, 1.3, true, &Ne);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32Greater) {
- TestCompare<float>(2.0, 1.9, true, &XlaBuilder::Ge);
+ TestCompare<float>(2.0, 1.9, true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32Equal) {
- TestCompare<float>(3.5, 3.5, true, &XlaBuilder::Ge);
+ TestCompare<float>(3.5, 3.5, true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGtF32) {
- TestCompare<float>(1.0, 5.2, false, &XlaBuilder::Gt);
+ TestCompare<float>(1.0, 5.2, false, &Gt);
}
XLA_TEST_F(ScalarComputationsTest, CompareLeF32) {
- TestCompare<float>(2.0, 1.2, false, &XlaBuilder::Le);
+ TestCompare<float>(2.0, 1.2, false, &Le);
}
XLA_TEST_F(ScalarComputationsTest, CompareLtF32) {
- TestCompare<float>(9.0, 7.2, false, &XlaBuilder::Lt);
+ TestCompare<float>(9.0, 7.2, false, &Lt);
}
// F32 comparisons with exceptional values. The test names encode the
// left/right operands at the end, and use Minf and Mzero for -inf and -0.0.
XLA_TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) {
- TestCompare<float>(-INFINITY, -0.0, true, &XlaBuilder::Lt);
+ TestCompare<float>(-INFINITY, -0.0, true, &Lt);
}
XLA_TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) {
// Comparisons of 0.0 to -0.0 consider them equal in IEEE 754.
- TestCompare<float>(-0.0, 0.0, false, &XlaBuilder::Lt);
+ TestCompare<float>(-0.0, 0.0, false, &Lt);
}
XLA_TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) {
- TestCompare<float>(0.0, INFINITY, true, &XlaBuilder::Lt);
+ TestCompare<float>(0.0, INFINITY, true, &Lt);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) {
- TestCompare<float>(-INFINITY, -0.0, false, &XlaBuilder::Ge);
+ TestCompare<float>(-INFINITY, -0.0, false, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) {
// Comparisons of 0.0 to -0.0 consider them equal in IEEE 754.
- TestCompare<float>(-0.0, 0.0, true, &XlaBuilder::Ge);
+ TestCompare<float>(-0.0, 0.0, true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) {
- TestCompare<float>(0.0, INFINITY, false, &XlaBuilder::Ge);
+ TestCompare<float>(0.0, INFINITY, false, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, ExpScalar) {
@@ -813,65 +813,65 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) {
}
XLA_TEST_F(ScalarComputationsTest, MinS32Above) {
- TestMinMax<int32>(10, 3, 3, &XlaBuilder::Min);
+ TestMinMax<int32>(10, 3, 3, &Min);
}
XLA_TEST_F(ScalarComputationsTest, MinS32Below) {
- TestMinMax<int32>(-100, 3, -100, &XlaBuilder::Min);
+ TestMinMax<int32>(-100, 3, -100, &Min);
}
XLA_TEST_F(ScalarComputationsTest, MaxS32Above) {
- TestMinMax<int32>(10, 3, 10, &XlaBuilder::Max);
+ TestMinMax<int32>(10, 3, 10, &Max);
}
XLA_TEST_F(ScalarComputationsTest, MaxS32Below) {
- TestMinMax<int32>(-100, 3, 3, &XlaBuilder::Max);
+ TestMinMax<int32>(-100, 3, 3, &Max);
}
XLA_TEST_F(ScalarComputationsTest, MinU32Above) {
const uint32 large = std::numeric_limits<int32>::max();
- TestMinMax<uint32>(large, 3, 3, &XlaBuilder::Min);
+ TestMinMax<uint32>(large, 3, 3, &Min);
}
XLA_TEST_F(ScalarComputationsTest, MinU32Below) {
- TestMinMax<uint32>(0, 5, 0, &XlaBuilder::Min);
+ TestMinMax<uint32>(0, 5, 0, &Min);
}
XLA_TEST_F(ScalarComputationsTest, MaxU32Above) {
const uint32 large = std::numeric_limits<int32>::max();
- TestMinMax<uint32>(large, 3, large, &XlaBuilder::Max);
+ TestMinMax<uint32>(large, 3, large, &Max);
}
XLA_TEST_F(ScalarComputationsTest, MaxU32Below) {
- TestMinMax<uint32>(0, 5, 5, &XlaBuilder::Max);
+ TestMinMax<uint32>(0, 5, 5, &Max);
}
XLA_TEST_F(ScalarComputationsTest, MinF32Above) {
- TestMinMax<float>(10.1f, 3.1f, 3.1f, &XlaBuilder::Min);
+ TestMinMax<float>(10.1f, 3.1f, 3.1f, &Min);
}
XLA_TEST_F(ScalarComputationsTest, MinF32Below) {
- TestMinMax<float>(-100.1f, 3.1f, -100.1f, &XlaBuilder::Min);
+ TestMinMax<float>(-100.1f, 3.1f, -100.1f, &Min);
}
XLA_TEST_F(ScalarComputationsTest, MinPropagatesNan) {
SetFastMathDisabled(true);
- TestMinMax<float>(NAN, 3.1f, NAN, &XlaBuilder::Min);
- TestMinMax<float>(-3.1f, NAN, NAN, &XlaBuilder::Min);
+ TestMinMax<float>(NAN, 3.1f, NAN, &Min);
+ TestMinMax<float>(-3.1f, NAN, NAN, &Min);
}
XLA_TEST_F(ScalarComputationsTest, MaxF32Above) {
- TestMinMax<float>(10.1f, 3.1f, 10.1f, &XlaBuilder::Max);
+ TestMinMax<float>(10.1f, 3.1f, 10.1f, &Max);
}
XLA_TEST_F(ScalarComputationsTest, MaxF32Below) {
- TestMinMax<float>(-100.1f, 3.1f, 3.1f, &XlaBuilder::Max);
+ TestMinMax<float>(-100.1f, 3.1f, 3.1f, &Max);
}
XLA_TEST_F(ScalarComputationsTest, MaxPropagatesNan) {
SetFastMathDisabled(true);
- TestMinMax<float>(NAN, 3.1f, NAN, &XlaBuilder::Max);
- TestMinMax<float>(-3.1f, NAN, NAN, &XlaBuilder::Max);
+ TestMinMax<float>(NAN, 3.1f, NAN, &Max);
+ TestMinMax<float>(-3.1f, NAN, NAN, &Max);
}
XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) {
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 000535a982..20c7c30878 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -161,6 +161,9 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
}));
break;
}
+ // Token requires no data.
+ case TOKEN:
+ break;
default:
return Unimplemented("Unsupported type for fake literal generation: %s",
ShapeUtil::HumanString(shape).c_str());
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc
index e8f2fb44d8..8f424ae81f 100644
--- a/tensorflow/compiler/xla/tests/test_utils_test.cc
+++ b/tensorflow/compiler/xla/tests/test_utils_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
@@ -53,5 +54,23 @@ XLA_TEST_F(TestUtilsTest, UnusedParam) {
TF_ASSERT_OK(MakeFakeArguments(&module).status());
}
+XLA_TEST_F(TestUtilsTest, Token) {
+ auto module = ParseHloString(
+ R"(HloModule outfeed_module
+
+ ENTRY InfeedToOutfeed {
+ token = token[] parameter(0)
+ infeed = ((u32[3]{0}, pred[]), token[]) infeed(token)
+ infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0
+ outfeed = token[] outfeed(infeed.data, token)
+ ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token)
+ infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0
+ infeed.1.token = token[] get-tuple-element(infeed.1), index=1
+ outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token)
+ })")
+ .ValueOrDie();
+ TF_ASSERT_OK(MakeFakeArguments(module.get()).status());
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 2d7916c8b1..229b0c481f 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -25,6 +25,7 @@ py_library(
"//tensorflow/contrib/all_reduce",
"//tensorflow/contrib/batching:batch_py",
"//tensorflow/contrib/bayesflow:bayesflow_py",
+ "//tensorflow/contrib/bigtable",
"//tensorflow/contrib/boosted_trees:init_py",
"//tensorflow/contrib/checkpoint/python:checkpoint",
"//tensorflow/contrib/cloud:cloud_py",
diff --git a/tensorflow/contrib/autograph/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD
index 8f09689fe9..a49a4ed05c 100644
--- a/tensorflow/contrib/autograph/pyct/BUILD
+++ b/tensorflow/contrib/autograph/pyct/BUILD
@@ -22,6 +22,7 @@ py_library(
"__init__.py",
"anno.py",
"ast_util.py",
+ "cfg.py",
"compiler.py",
"inspect_utils.py",
"parser.py",
@@ -64,6 +65,17 @@ py_test(
)
py_test(
+ name = "cfg_test",
+ srcs = ["cfg_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":pyct",
+ "//tensorflow/python:client_testlib",
+ "@gast_archive//:gast",
+ ],
+)
+
+py_test(
name = "compiler_test",
srcs = ["compiler_test.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/autograph/pyct/cfg.py b/tensorflow/contrib/autograph/pyct/cfg.py
new file mode 100644
index 0000000000..666328781f
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/cfg.py
@@ -0,0 +1,733 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Control flow graph (CFG) structure for Python AST representation.
+
+The CFG is a digraph with edges representing valid control flow. Each
+node is associated with exactly one AST node, but not all AST nodes may have
+a corresponding CFG counterpart.
+
+Once built, the CFG itself is immutable, but the values it holds need not be;
+they are usually annotated with information extracted by walking the graph.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+from enum import Enum
+
+# pylint:disable=g-bad-import-order
+import gast
+# pylint:enable=g-bad-import-order
+
+from tensorflow.contrib.autograph.pyct import compiler
+
+
+class Node(object):
+ """A node in the CFG.
+
+ Although new instances of this class are mutable, the objects that a user
+ finds in the CFG are typically not.
+
+  A node's next and prev pointers represent the edges of the CFG, and allow
+  efficient walking in both forward and reverse order. The following property
+ holds for all nodes: "child in node.next" iff "node in child.prev".
+
+ Attributes:
+ next: FrozenSet[Node, ...], the nodes that follow this node, in control
+ flow order
+ prev: FrozenSet[Node, ...], the nodes that precede this node, in reverse
+ control flow order
+ ast_node: ast.AST, the AST node corresponding to this CFG node
+ """
+
+ def __init__(self, next_, prev, ast_node):
+ self.next = next_
+ self.prev = prev
+ self.ast_node = ast_node
+
+ def freeze(self):
+ self.next = frozenset(self.next)
+ self.prev = frozenset(self.prev)
+
+ def __repr__(self):
+ return compiler.ast_to_source(self.ast_node).strip()
+
+
+class Graph(
+ collections.namedtuple('Graph', ['entry', 'exit', 'error', 'index'])):
+ """A Control Flow Graph.
+
+ The CFG maintains an index to allow looking up a CFG node by the AST node to
+ which it is associated. The index can also be enumerated in top-down, depth
+ first order.
+
+ Walking the graph in forward or reverse order is supported by double
+ parent-child links.
+
+ Note: the error nodes are not wired to their corresponding finally guards,
+ because these are shared, and wiring them would create a reverse path from
+ normal control flow into the error nodes, which we want to avoid.
+
+ Attributes:
+ entry: Node, the entry node
+ exit: FrozenSet[Node, ...], the exit nodes
+ error: FrozenSet[Node, ...], nodes that exit due to an explicitly raised
+ error (errors propagated from function calls are not accounted)
+ index: Dict[ast.Node, Node], mapping AST nodes to the respective CFG
+ node
+ """
+
+ def __repr__(self):
+ result = 'digraph CFG {\n'
+ for node in self.index.values():
+ result += ' %s [label="%s"];\n' % (id(node), node)
+ for node in self.index.values():
+ if node.next:
+ result += ' %s -> {%s};\n' % (id(node), ', '.join(
+ repr(id(n)) for n in node.next))
+ result += '}'
+ return result
+
+
+class _WalkMode(Enum):
+ FORWARD = 1
+ REVERSE = 2
+
+
+class GraphVisitor(object):
+ """Base class for a CFG visitors.
+
+ This implementation is not thread safe.
+
+ The visitor has some facilities to simplify dataflow analyses. In particular,
+ it allows revisiting the nodes at the decision of the subclass. This can be
+ used to visit the graph until the state reaches a fixed point.
+
+ For more details on dataflow analysis, see
+ https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec02-Dataflow.pdf
+
+ Note: the literature generally suggests visiting successor nodes only when the
+ state of the current node changed, regardless of whether that successor has
+ ever been visited. This implementation visits every successor at least once.
+
+ Attributes:
+ graph: Graph
+ in_: Dict[Node, Any], stores node-keyed state during a visit
+ out: Dict[Node, Any], stores node-keyed state during a visit
+ """
+
+ def reset(self):
+ self.in_ = {
+ node: self.init_state(node) for node in self.graph.index.values()
+ }
+ self.out = {
+ node: self.init_state(node) for node in self.graph.index.values()
+ }
+
+ def init_state(self, node):
+ """State initialization function. Optional to overload.
+
+ An in/out state slot will be created for each node in the graph. Subclasses
+ may overload this to control what that is initialized to.
+
+ Args:
+ node: Node
+ """
+ del node
+ return None
+
+ def visit_node(self, node):
+ """Visitor function.
+
+ Args:
+ node: Node
+ Returns:
+ bool, whether the node should be revisited; subclasses can visit every
+ reachable node exactly once by always returning False
+ """
+ raise NotImplementedError('Subclasses must implement this.')
+
+ def _visit_internal(self, mode):
+ """Visits the CFG, depth-first."""
+ assert mode in (_WalkMode.FORWARD, _WalkMode.REVERSE)
+ if mode == _WalkMode.FORWARD:
+ open_ = [self.graph.entry]
+ elif mode == _WalkMode.REVERSE:
+ open_ = list(self.graph.exit)
+ closed = set()
+ self.reset()
+
+ while open_:
+ node = open_.pop(0)
+ closed.add(node)
+
+ should_revisit = self.visit_node(node)
+
+ if mode == _WalkMode.FORWARD:
+ children = node.next
+ elif mode == _WalkMode.REVERSE:
+ children = node.prev
+
+ for next_ in children:
+ if should_revisit or next_ not in closed:
+ open_.append(next_)
+
+ def visit_forward(self, graph):
+ self.graph = graph
+ self._visit_internal(_WalkMode.FORWARD)
+
+ def visit_reverse(self, graph):
+ self.graph = graph
+ self._visit_internal(_WalkMode.REVERSE)
+
+
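+# NOTE: The class below is an illustrative sketch only, not part of the module
+# API. It shows how a GraphVisitor subclass can use the in_/out state slots
+# and the revisit mechanism to run a simple dataflow analysis to a fixed
+# point. The name and the chosen state (sets of AST nodes) are assumptions
+# made for demonstration.
+class _ExampleAncestors(GraphVisitor):
+  """Minimal fixed-point dataflow visitor (illustration only).
+
+  Computes, for each CFG node, the set of AST nodes that can appear before it
+  on some path from the entry, including itself.
+  """
+
+  def init_state(self, node):
+    return frozenset()
+
+  def visit_node(self, node):
+    # The in-state of a node is the union of its predecessors' out-states.
+    self.in_[node] = frozenset().union(*(self.out[p] for p in node.prev))
+    new_out = self.in_[node] | {node.ast_node}
+    changed = new_out != self.out[node]
+    self.out[node] = new_out
+    # Returning True asks the walker to revisit this node's successors, which
+    # drives the analysis to a fixed point.
+    return changed
+
+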
+class GraphBuilder(object):
+ """Builder that constructs a CFG from a given AST.
+
+  This GraphBuilder facilitates constructing the DAG that forms the CFG when
+  nodes are supplied in lexical order (i.e., top-down, depth first). Under
+  these conditions, it supports building patterns found in typical structured
+  programs.
+
+ This builder ignores the flow generated by exceptions, which are assumed to
+ always be catastrophic and present purely for diagnostic purposes (e.g. to
+ print debug information). Statements like raise and try/catch sections are
+  allowed and will generate control flow edges, but ordinary statements are
+ assumed not to raise exceptions.
+
+ Finally sections are also correctly interleaved between break/continue/return
+ nodes and their subsequent statements.
+
+ Important concepts:
+   * nodes - nodes refer to CFG nodes; AST nodes are qualified explicitly
+ * leaf set - since the graph is constructed gradually, a leaf set maintains
+ the CFG nodes that will precede the node that the builder expects to
+ receive next; when an ordinary node is added, it is connected to the
+ existing leaves and it in turn becomes the new leaf
+ * jump nodes - nodes that should generate edges other than what
+ ordinary nodes would; these correspond to break, continue and return
+ statements
+ * sections - logical delimiters for subgraphs that require special
+     edges; there are various types of sections, each admitting various
+ types of jump nodes; sections are identified by their corresponding AST
+ node
+ """
+
+ # TODO(mdan): Perhaps detail this in a markdown doc.
+ # TODO(mdan): Add exception support.
+
+ def __init__(self, parent_ast_node):
+ self.reset()
+ self.parent = parent_ast_node
+
+ def reset(self):
+ """Resets the state of this factory."""
+ self.head = None
+ self.errors = set()
+ self.node_index = collections.OrderedDict()
+
+ # TODO(mdan): Too many primitives. Use classes.
+ self.leaves = set()
+
+ self.finally_sections = {}
+ self.finally_section_subgraphs = {} # Values are [begin_node, exit_nodes]
+ # Whether the guard section can be reached from the statement that precedes
+ # it.
+ self.finally_section_has_direct_flow = {}
+ # Finally sections that await their first node.
+ self.pending_finally_sections = set()
+
+ # Exit jumps keyed by the section they affect.
+ self.exits = {}
+
+ # The entry of loop sections, keyed by the section.
+ self.section_entry = {}
+ # Continue jumps keyed by the section they affect.
+ self.continues = {}
+
+ # The entry of conditional sections, keyed by the section.
+ self.cond_entry = {}
+ # Lists of leaf nodes corresponding to each branch in the section.
+ self.cond_leaves = {}
+
+ def _connect_nodes(self, first, second):
+ """Connects nodes to signify that control flows from first to second.
+
+ Args:
+ first: Union[Set[Node, ...], Node]
+ second: Node
+ """
+ if isinstance(first, Node):
+ first.next.add(second)
+ second.prev.add(first)
+ else:
+ for node in first:
+ self._connect_nodes(node, second)
+
+ def _add_new_node(self, ast_node):
+ """Grows the graph by adding a CFG node following the current leaves."""
+    if ast_node in self.node_index:
+ raise ValueError('%s added twice' % ast_node)
+ node = Node(next_=set(), prev=set(), ast_node=ast_node)
+ self.node_index[ast_node] = node
+
+ if self.head is None:
+ self.head = node
+
+ for leaf in self.leaves:
+ self._connect_nodes(leaf, node)
+
+ # If any finally section awaits its first node, populate it.
+ for section_id in self.pending_finally_sections:
+ self.finally_section_subgraphs[section_id][0] = node
+ self.pending_finally_sections = set()
+
+ return node
+
+ def add_ordinary_node(self, ast_node):
+ """Grows the graph by adding an ordinary CFG node.
+
+ Ordinary nodes are followed by the next node, in lexical order, that is,
+ they become the new leaf set.
+
+ Args:
+ ast_node: ast.AST
+ Returns:
+ Node
+ """
+ node = self._add_new_node(ast_node)
+ self.leaves = set((node,))
+ return node
+
+ def _add_jump_node(self, ast_node, guards):
+ """Grows the graph by adding a jump node.
+
+ Jump nodes are added to the current leaf set, and the leaf set becomes
+ empty. If the jump node is the last in a cond section, then it may be added
+ back to the leaf set by a separate mechanism.
+
+ Args:
+ ast_node: ast.AST
+ guards: Tuple[ast.AST, ...], the finally sections active for this node
+ Returns:
+ Node
+ """
+ node = self._add_new_node(ast_node)
+ self.leaves = set()
+ # The guards themselves may not yet be complete, and will be wired later.
+ self.finally_sections[node] = guards
+ return node
+
+ def _connect_jump_to_finally_sections(self, node):
+ """Connects a jump node to the finally sections protecting it."""
+ cursor = set((node,))
+ for guard_section_id in self.finally_sections[node]:
+ guard_begin, guard_ends = self.finally_section_subgraphs[guard_section_id]
+ self._connect_nodes(cursor, guard_begin)
+ cursor = guard_ends
+ del self.finally_sections[node]
+ # TODO(mdan): Should garbage-collect finally_section_subgraphs.
+ return cursor
+
+ def add_exit_node(self, ast_node, section_id, guards):
+ """Grows the graph by adding an exit node.
+
+ This node becomes an exit for the current section.
+
+ Args:
+ ast_node: ast.AST
+ section_id: Hashable, the node for which ast_node should be considered
+ to be an exit node
+ guards: Tuple[ast.AST, ...], the finally sections that guard ast_node
+ """
+ node = self._add_jump_node(ast_node, guards)
+ self.exits[section_id].add(node)
+
+ def add_continue_node(self, ast_node, section_id, guards):
+ """Grows the graph by adding a reentry node.
+
+ This node causes control flow to go back to the loop section's entry.
+
+ Args:
+ ast_node: ast.AST
+ section_id: Hashable, the node for which ast_node should be considered
+          to be a continue node
+ guards: Tuple[ast.AST, ...], the finally sections that guard ast_node
+ """
+ node = self._add_jump_node(ast_node, guards)
+ self.continues[section_id].add(node)
+
+ def add_error_node(self, ast_node, guards):
+ """Grows the graph by adding an error node.
+
+ This node becomes an exit for the entire graph.
+
+ Args:
+ ast_node: ast.AST
+ guards: Tuple[ast.AST, ...], the finally sections that guard ast_node
+ """
+ node = self._add_jump_node(ast_node, guards)
+ self.errors.add(node)
+ self.leaves = set()
+
+ def enter_section(self, section_id):
+ """Enters a regular section.
+
+ Regular sections admit exit jumps, which end the section.
+
+ Args:
+ section_id: Hashable, the same node that will be used in calls to the
+          section_id arg passed to add_exit_node
+ """
+ assert section_id not in self.exits
+ self.exits[section_id] = set()
+
+ def exit_section(self, section_id):
+ """Exits a regular section."""
+
+ # Exits are jump nodes, which may be protected.
+ for exit_ in self.exits[section_id]:
+ self.leaves |= self._connect_jump_to_finally_sections(exit_)
+
+ del self.exits[section_id]
+
+ def enter_loop_section(self, section_id, entry_node):
+ """Enters a loop section.
+
+ Loop sections define an entry node. The end of the section always flows back
+ to the entry node. These admit continue jump nodes which also flow to the
+ entry node.
+
+ Args:
+ section_id: Hashable, the same node that will be used in calls to the
+          section_id arg passed to add_continue_node
+ entry_node: ast.AST, the entry node into the loop (e.g. the test node
+ for while loops)
+ """
+ assert section_id not in self.section_entry
+ assert section_id not in self.continues
+ self.continues[section_id] = set()
+ node = self.add_ordinary_node(entry_node)
+ self.section_entry[section_id] = node
+
+ def exit_loop_section(self, section_id):
+ """Exits a loop section."""
+ self._connect_nodes(self.leaves, self.section_entry[section_id])
+
+ # continues are jump nodes, which may be protected.
+ for reentry in self.continues[section_id]:
+ guard_ends = self._connect_jump_to_finally_sections(reentry)
+ self._connect_nodes(guard_ends, self.section_entry[section_id])
+
+ # Loop nodes always loop back.
+ self.leaves = set((self.section_entry[section_id],))
+
+ del self.continues[section_id]
+ del self.section_entry[section_id]
+
+ def enter_cond_section(self, section_id):
+ """Enters a conditional section.
+
+ Conditional sections define an entry node, and one or more branches.
+
+ Args:
+ section_id: Hashable, the same node that will be used in calls to the
+ section_id arg passed to new_cond_branch
+ """
+
+ assert section_id not in self.cond_entry
+ assert section_id not in self.cond_leaves
+ self.cond_leaves[section_id] = []
+
+ def new_cond_branch(self, section_id):
+ """Begins a new branch in a cond section."""
+ assert section_id in self.cond_leaves
+
+ if section_id in self.cond_entry:
+ # Subsequent splits move back to the split point, and memorize the
+ # current leaves.
+ self.cond_leaves[section_id].append(self.leaves)
+ self.leaves = self.cond_entry[section_id]
+ else:
+ # If this is the first time we split a section, just remember the split
+ # point.
+ self.cond_entry[section_id] = self.leaves
+
+ def exit_cond_section(self, section_id):
+ """Exits a conditional section."""
+ for split in self.cond_leaves[section_id]:
+ self.leaves |= split
+ del self.cond_entry[section_id]
+ del self.cond_leaves[section_id]
+
+ def enter_finally_section(self, section_id):
+ """Enters a finally section."""
+ # TODO(mdan): This, not the caller, should track the active sections.
+ self.finally_section_subgraphs[section_id] = [None, None]
+ if self.leaves:
+ self.finally_section_has_direct_flow[section_id] = True
+ else:
+ self.finally_section_has_direct_flow[section_id] = False
+ self.pending_finally_sections.add(section_id)
+
+ def exit_finally_section(self, section_id):
+ """Exits a finally section."""
+ assert section_id not in self.pending_finally_sections, 'Empty finally?'
+ self.finally_section_subgraphs[section_id][1] = self.leaves
+ # If the guard can only be reached by a jump, then it will not flow
+ # into the statement that follows it.
+ if not self.finally_section_has_direct_flow[section_id]:
+ self.leaves = set()
+ del self.finally_section_has_direct_flow[section_id]
+
+ def build(self):
+ """Returns the CFG accumulated so far and resets the builder.
+
+ Returns:
+ Graph
+ """
+ # Freeze the nodes.
+ for node in self.node_index.values():
+ node.freeze()
+
+ result = Graph(
+ entry=self.head,
+ exit=self.leaves,
+ error=self.errors,
+ index=self.node_index)
+
+ # Reset the state.
+ self.reset()
+
+ return result
+
+
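+# NOTE: The helper below is an illustrative sketch only, not part of the
+# module API. It shows the typical calling sequence for GraphBuilder,
+# mirroring how AstToCfg (below) drives it: open a section, add nodes in
+# lexical order, close the section, then call build(). The helper name and
+# the use of bare gast.Pass() nodes are assumptions made for demonstration.
+def _example_straightline_cfg():
+  stmts = [gast.Pass(), gast.Pass(), gast.Pass()]
+  fn_node = gast.Pass()  # Stands in for the enclosing function's AST node.
+  builder = GraphBuilder(fn_node)
+  builder.enter_section(fn_node)
+  for stmt in stmts:
+    # Each ordinary node is wired to the previous leaves and becomes the new
+    # leaf set, producing a straight-line chain.
+    builder.add_ordinary_node(stmt)
+  builder.exit_section(fn_node)
+  return builder.build()
+
+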
+class AstToCfg(gast.NodeVisitor):
+ """Converts an AST to CFGs.
+
+ A separate CFG will be constructed for each function.
+ """
+
+ # TODO(mdan): Figure out how to deal with closures.
+
+ def __init__(self):
+ super(AstToCfg, self).__init__()
+
+ self.builder_stack = []
+ self.builder = None
+ self.cfgs = {}
+
+ self.lexical_scopes = []
+
+ def _enter_lexical_scope(self, node):
+ self.lexical_scopes.append(node)
+
+ def _exit_lexical_scope(self, node):
+ leaving_node = self.lexical_scopes.pop()
+ assert node == leaving_node
+
+ def _get_enclosing_scopes(self, include, stop_at):
+ included = []
+ for node in reversed(self.lexical_scopes):
+ if isinstance(node, include):
+ included.append(node)
+ if isinstance(node, stop_at):
+ return node, included
+ return None, included
+
+ def _process_basic_statement(self, node):
+ self.generic_visit(node)
+ self.builder.add_ordinary_node(node)
+
+ def _process_exit_statement(self, node, *exits_nodes_of_type):
+ # Note: this is safe because we process functions separately.
+ try_node, guards = self._get_enclosing_scopes(
+ include=(gast.Try,),
+ stop_at=tuple(exits_nodes_of_type),
+ )
+ if try_node is None:
+ raise ValueError(
+ '%s that is not enclosed by any of %s' % (node, exits_nodes_of_type))
+ self.builder.add_exit_node(node, try_node, guards)
+
+ def _process_continue_statement(self, node, *loops_to_nodes_of_type):
+ # Note: this is safe because we process functions separately.
+ try_node, guards = self._get_enclosing_scopes(
+ include=(gast.Try,),
+ stop_at=tuple(loops_to_nodes_of_type),
+ )
+ if try_node is None:
+ raise ValueError('%s that is not enclosed by any of %s' %
+ (node, loops_to_nodes_of_type))
+ self.builder.add_continue_node(node, try_node, guards)
+
+ def visit_FunctionDef(self, node):
+ self.builder_stack.append(self.builder)
+ self.builder = GraphBuilder(node)
+
+ self._enter_lexical_scope(node)
+ self.builder.enter_section(node)
+
+ self._process_basic_statement(node.args)
+ for stmt in node.body:
+ self.visit(stmt)
+
+ self.builder.exit_section(node)
+ self._exit_lexical_scope(node)
+
+ self.cfgs[node] = self.builder.build()
+ self.builder = self.builder_stack.pop()
+
+ def visit_Lambda(self, node):
+ # TODO(mdan): Treat like FunctionDef? That would be a separate CFG.
+ raise NotImplementedError()
+
+ def visit_Return(self, node):
+ self._process_exit_statement(node, gast.FunctionDef)
+
+ def visit_Expr(self, node):
+ self._process_basic_statement(node)
+
+ def visit_Assign(self, node):
+ self._process_basic_statement(node)
+
+ def visit_AnnAssign(self, node):
+ self._process_basic_statement(node)
+
+ def visit_AugAssign(self, node):
+ self._process_basic_statement(node)
+
+ def visit_Print(self, node):
+ self._process_basic_statement(node)
+
+ def visit_Raise(self, node):
+ try_node, guards = self._get_enclosing_scopes(
+ include=(gast.Try,),
+ stop_at=(gast.FunctionDef,),
+ )
+ if try_node is None:
+ raise ValueError('%s that is not enclosed by any FunctionDef' % node)
+    self.builder.add_error_node(node, guards)
+
+ def visit_Assert(self, node):
+ # Ignoring the effect of exceptions.
+ self._process_basic_statement(node)
+
+ def visit_Delete(self, node):
+ self._process_basic_statement(node)
+
+ def visit_If(self, node):
+ # No need to track ifs as lexical scopes, for now.
+ # Lexical scopes are generally tracked in order to be able to resolve the
+ # targets of jump statements like break/continue/etc. Since there is no
+ # statement that can interrupt a conditional, we don't need to track their
+ # lexical scope. That may change in the future.
+
+ self.builder.enter_cond_section(node)
+ self._process_basic_statement(node.test)
+
+ self.builder.new_cond_branch(node)
+ for stmt in node.body:
+ self.visit(stmt)
+
+ self.builder.new_cond_branch(node)
+ for stmt in node.orelse:
+ self.visit(stmt)
+
+ self.builder.exit_cond_section(node)
+
+ def visit_While(self, node):
+ self._enter_lexical_scope(node)
+
+ self.builder.enter_section(node)
+
+ self.builder.enter_loop_section(node, node.test)
+ for stmt in node.body:
+ self.visit(stmt)
+ self.builder.exit_loop_section(node)
+
+ # Note: although the orelse is technically part of the loop node,
+ # the statements inside it don't affect the loop itself. For example, a
+ # break in the loop's orelse will not affect the loop itself.
+ self._exit_lexical_scope(node)
+
+ for stmt in node.orelse:
+ self.visit(stmt)
+
+ self.builder.exit_section(node)
+
+ def visit_For(self, node):
+ self._enter_lexical_scope(node)
+
+ self.builder.enter_section(node)
+
+ # TODO(mdan): Strictly speaking, this should be node.target + node.iter.
+ # A blind dataflow analysis would have to process both node.target and
+ # node.iter to properly process read and write access.
+ self.builder.enter_loop_section(node, node.iter)
+ for stmt in node.body:
+ self.visit(stmt)
+ self.builder.exit_loop_section(node)
+
+ # Note: although the orelse is technically part of the loop node,
+ # they don't count as loop bodies. For example, a break in the loop's
+ # orelse will affect the parent loop, not the current one.
+ self._exit_lexical_scope(node)
+
+ for stmt in node.orelse:
+ self.visit(stmt)
+
+ self.builder.exit_section(node)
+
+ def visit_Break(self, node):
+ self._process_exit_statement(node, gast.While, gast.For)
+
+ def visit_Continue(self, node):
+ self._process_continue_statement(node, gast.While, gast.For)
+
+ def visit_Try(self, node):
+ self._enter_lexical_scope(node)
+
+ for stmt in node.body:
+ self.visit(stmt)
+ # Unlike loops, the orelse is a simple continuation of the body.
+ for stmt in node.orelse:
+ self.visit(stmt)
+
+ if node.handlers:
+ # TODO(mdan): Should we still support bare try/except? Might be confusing.
+ raise NotImplementedError('exceptions are not yet supported')
+
+ self._exit_lexical_scope(node)
+
+ self.builder.enter_finally_section(node)
+ for stmt in node.finalbody:
+ self.visit(stmt)
+ self.builder.exit_finally_section(node)
+
+ def visit_With(self, node):
+ # TODO(mdan): Mark the context manager's exit call as exit guard.
+ self._process_basic_statement(node.items)
+ for stmt in node.body:
+ self.visit(stmt)
+
+
+def build(node):
+ builder = AstToCfg()
+ builder.visit(node)
+ return builder.cfgs
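+
+
+# NOTE: Illustrative usage sketch only, mirroring the pattern used by the unit
+# tests: parse a Python function into a gast AST with pyct.parser, build the
+# CFGs, and read the single graph back out. The helper, function and variable
+# names are assumptions made for demonstration; this is not part of the
+# module API.
+def _example_usage():
+  # Deferred import so the sketch stays local to this example.
+  from tensorflow.contrib.autograph.pyct import parser
+
+  def f(a):
+    if a > 0:
+      a = 1
+    return a
+
+  node, _ = parser.parse_entity(f)
+  graphs = build(node)
+  # One CFG is built per function; printing a Graph emits DOT via __repr__.
+  graph, = graphs.values()
+  return graph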
diff --git a/tensorflow/contrib/autograph/pyct/cfg_test.py b/tensorflow/contrib/autograph/pyct/cfg_test.py
new file mode 100644
index 0000000000..00afadd521
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/cfg_test.py
@@ -0,0 +1,790 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for cfg module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.pyct import cfg
+from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.platform import test
+
+
+class CountingVisitor(cfg.GraphVisitor):
+
+ def __init__(self):
+ self.counts = {}
+
+ def visit_node(self, node):
+ self.counts[node.ast_node] = self.counts.get(node.ast_node, 0) + 1
+ return False # visit only once
+
+
+class GraphVisitorTest(test.TestCase):
+
+ def _build_cfg(self, fn):
+ node, _ = parser.parse_entity(fn)
+ cfgs = cfg.build(node)
+ return cfgs, node
+
+ def test_basic_coverage_forward(self):
+
+ def test_fn(a):
+ while a > 0:
+ a = 1
+ break
+ return a # pylint:disable=unreachable
+ a = 2
+
+ graphs, node = self._build_cfg(test_fn)
+ graph, = graphs.values()
+ visitor = CountingVisitor()
+ visitor.visit_forward(graph)
+ fn_node = node.body[0]
+
+ self.assertEqual(visitor.counts[fn_node.args], 1)
+ self.assertEqual(visitor.counts[fn_node.body[0].test], 1)
+ self.assertEqual(visitor.counts[fn_node.body[0].body[0]], 1)
+ self.assertEqual(visitor.counts[fn_node.body[0].body[1]], 1)
+ # The return node should be unreachable in forward direction.
+ self.assertTrue(fn_node.body[0].body[2] not in visitor.counts)
+ self.assertEqual(visitor.counts[fn_node.body[1]], 1)
+
+ def test_basic_coverage_reverse(self):
+
+ def test_fn(a):
+ while a > 0:
+ a = 1
+ break
+ return a # pylint:disable=unreachable
+ a = 2
+
+ graphs, node = self._build_cfg(test_fn)
+ graph, = graphs.values()
+ visitor = CountingVisitor()
+ visitor.visit_reverse(graph)
+ fn_node = node.body[0]
+
+ self.assertEqual(visitor.counts[fn_node.args], 1)
+ self.assertEqual(visitor.counts[fn_node.body[0].test], 1)
+ self.assertEqual(visitor.counts[fn_node.body[0].body[0]], 1)
+ self.assertEqual(visitor.counts[fn_node.body[0].body[1]], 1)
+    self.assertEqual(visitor.counts[fn_node.body[0].body[2]], 1)
+ self.assertEqual(visitor.counts[fn_node.body[1]], 1)
+
+
+class AstToCfgTest(test.TestCase):
+
+ def _build_cfg(self, fn):
+ node, _ = parser.parse_entity(fn)
+ cfgs = cfg.build(node)
+ return cfgs
+
+ def _repr_set(self, node_set):
+ return set(repr(n) for n in node_set)
+
+ def _as_set(self, elements):
+ if elements is None:
+ return frozenset()
+ elif isinstance(elements, str):
+ return frozenset((elements,))
+ else:
+ return frozenset(elements)
+
+ def assertGraphMatches(self, graph, edges):
+ """Tests whether the CFG contains the specified edges."""
+ for prev, node_repr, next_ in edges:
+ matched = False
+ for cfg_node in graph.index.values():
+ if repr(cfg_node) == node_repr:
+ if (self._as_set(prev) == set(map(repr, cfg_node.prev)) and
+ self._as_set(next_) == set(map(repr, cfg_node.next))):
+ matched = True
+ break
+ if not matched:
+ self.fail(
+ 'match failed for node "%s" in graph:\n%s' % (node_repr, graph))
+
+ def test_straightline(self):
+
+ def test_fn(a):
+ a += 1
+ a = 2
+ a = 3
+ return
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (None, 'a', 'a += 1'),
+ ('a += 1', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', 'return'),
+ ('a = 3', 'return', None),
+ ),
+ )
+
+ def test_straightline_no_return(self):
+
+ def test_fn(a, b):
+ a = b + 1
+ a += max(a)
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (None, 'a, b', 'a = b + 1'),
+ ('a = b + 1', 'a += max(a)', None),
+ ),
+ )
+
+ def test_unreachable_code(self):
+
+ def test_fn(a):
+ return
+ a += 1 # pylint:disable=unreachable
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (None, 'a', 'return'),
+ ('a', 'return', None),
+ (None, 'a += 1', None),
+ ),
+ )
+
+ def test_branch_straightline(self):
+
+ def test_fn(a):
+ if a > 0:
+ a = 1
+ else:
+ a += -1
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (None, 'a', '(a > 0)'),
+ ('(a > 0)', 'a = 1', None),
+ ('(a > 0)', 'a += -1', None),
+ ),
+ )
+
+ def test_branch_nested(self):
+
+ def test_fn(a):
+ if a > 0:
+ if a > 1:
+ a = 1
+ else:
+ a = 2
+ else:
+ if a > 2:
+ a = 3
+ else:
+ a = 4
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (None, 'a', '(a > 0)'),
+ ('a', '(a > 0)', ('(a > 1)', '(a > 2)')),
+ ('(a > 0)', '(a > 1)', ('a = 1', 'a = 2')),
+ ('(a > 1)', 'a = 1', None),
+ ('(a > 1)', 'a = 2', None),
+ ('(a > 0)', '(a > 2)', ('a = 3', 'a = 4')),
+ ('(a > 2)', 'a = 3', None),
+ ('(a > 2)', 'a = 4', None),
+ ),
+ )
+
+ def test_branch_straightline_semi(self):
+
+ def test_fn(a):
+ if a > 0:
+ a = 1
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (None, 'a', '(a > 0)'),
+ ('a', '(a > 0)', 'a = 1'),
+ ('(a > 0)', 'a = 1', None),
+ ),
+ )
+
+ def test_branch_return(self):
+
+ def test_fn(a):
+ if a > 0:
+ return
+ else:
+ a = 1
+ a = 2
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ ('a', '(a > 0)', ('return', 'a = 1')),
+ ('(a > 0)', 'a = 1', 'a = 2'),
+ ('(a > 0)', 'return', None),
+ ('a = 1', 'a = 2', None),
+ ),
+ )
+
+ def test_branch_return_minimal(self):
+
+ def test_fn(a):
+ if a > 0:
+ return
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ ('a', '(a > 0)', 'return'),
+ ('(a > 0)', 'return', None),
+ ),
+ )
+
+ def test_while_straightline(self):
+
+ def test_fn(a):
+ while a > 0:
+ a = 1
+ a = 2
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), '(a > 0)', ('a = 1', 'a = 2')),
+ ('(a > 0)', 'a = 1', '(a > 0)'),
+ ('(a > 0)', 'a = 2', None),
+ ),
+ )
+
+ def test_while_else_straightline(self):
+
+ def test_fn(a):
+ while a > 0:
+ a = 1
+ else: # pylint:disable=useless-else-on-loop
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), '(a > 0)', ('a = 1', 'a = 2')),
+ ('(a > 0)', 'a = 1', '(a > 0)'),
+ ('(a > 0)', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', None),
+ ),
+ )
+
+ def test_while_else_continue(self):
+
+ def test_fn(a):
+ while a > 0:
+ if a > 1:
+ continue
+ else:
+ a = 0
+ a = 1
+ else: # pylint:disable=useless-else-on-loop
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'continue', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')),
+ ('(a > 0)', '(a > 1)', ('continue', 'a = 0')),
+ ('(a > 1)', 'continue', '(a > 0)'),
+ ('a = 0', 'a = 1', '(a > 0)'),
+ ('(a > 0)', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', None),
+ ),
+ )
+
+ def test_while_else_break(self):
+
+ def test_fn(a):
+ while a > 0:
+ if a > 1:
+ break
+ a = 1
+ else:
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')),
+ ('(a > 0)', '(a > 1)', ('break', 'a = 1')),
+ ('(a > 1)', 'break', 'a = 3'),
+ ('(a > 1)', 'a = 1', '(a > 0)'),
+ ('(a > 0)', 'a = 2', 'a = 3'),
+ (('break', 'a = 2'), 'a = 3', None),
+ ),
+ )
+
+ def test_while_else_return(self):
+
+ def test_fn(a):
+ while a > 0:
+ if a > 1:
+ return
+ a = 1
+ else: # pylint:disable=useless-else-on-loop
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')),
+ ('(a > 0)', '(a > 1)', ('return', 'a = 1')),
+ ('(a > 1)', 'return', None),
+ ('(a > 1)', 'a = 1', '(a > 0)'),
+ ('(a > 0)', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', None),
+ ),
+ )
+
+ def test_while_nested_straightline(self):
+
+ def test_fn(a):
+ while a > 0:
+ while a > 1:
+ a = 1
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')),
+ (('(a > 0)', 'a = 1'), '(a > 1)', ('a = 1', 'a = 2')),
+ ('(a > 1)', 'a = 1', '(a > 1)'),
+ ('(a > 1)', 'a = 2', '(a > 0)'),
+ ('(a > 0)', 'a = 3', None),
+ ),
+ )
+
+ def test_while_nested_continue(self):
+
+ def test_fn(a):
+ while a > 0:
+ while a > 1:
+ if a > 3:
+ continue
+ a = 1
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')),
+ (('(a > 0)', 'continue', 'a = 1'), '(a > 1)', ('(a > 3)', 'a = 2')),
+ ('(a > 1)', '(a > 3)', ('continue', 'a = 1')),
+ ('(a > 3)', 'continue', '(a > 1)'),
+ ('(a > 3)', 'a = 1', '(a > 1)'),
+ ('(a > 1)', 'a = 2', '(a > 0)'),
+ ('(a > 0)', 'a = 3', None),
+ ),
+ )
+
+ def test_while_nested_break(self):
+
+ def test_fn(a):
+ while a > 0:
+ while a > 1:
+ if a > 2:
+ break
+ a = 1
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')),
+ (('(a > 0)', 'a = 1'), '(a > 1)', ('(a > 2)', 'a = 2')),
+ ('(a > 1)', '(a > 2)', ('break', 'a = 1')),
+ ('(a > 2)', 'break', 'a = 2'),
+ ('(a > 2)', 'a = 1', '(a > 1)'),
+ (('(a > 1)', 'break'), 'a = 2', '(a > 0)'),
+ ('(a > 0)', 'a = 3', None),
+ ),
+ )
+
+ def test_for_straightline(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ a = 1
+ a = 2
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), 'range(0, a)', ('a = 1', 'a = 2')),
+ ('range(0, a)', 'a = 1', 'range(0, a)'),
+ ('range(0, a)', 'a = 2', None),
+ ),
+ )
+
+ def test_for_else_straightline(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ a = 1
+ else: # pylint:disable=useless-else-on-loop
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), 'range(0, a)', ('a = 1', 'a = 2')),
+ ('range(0, a)', 'a = 1', 'range(0, a)'),
+ ('range(0, a)', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', None),
+ ),
+ )
+
+ def test_for_else_continue(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ if a > 1:
+ continue
+ else:
+ a = 0
+ a = 1
+ else: # pylint:disable=useless-else-on-loop
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'continue', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')),
+ ('range(0, a)', '(a > 1)', ('continue', 'a = 0')),
+ ('(a > 1)', 'continue', 'range(0, a)'),
+ ('(a > 1)', 'a = 0', 'a = 1'),
+ ('a = 0', 'a = 1', 'range(0, a)'),
+ ('range(0, a)', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', None),
+ ),
+ )
+
+ def test_for_else_break(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ if a > 1:
+ break
+ a = 1
+ else:
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')),
+ ('range(0, a)', '(a > 1)', ('break', 'a = 1')),
+ ('(a > 1)', 'break', 'a = 3'),
+ ('(a > 1)', 'a = 1', 'range(0, a)'),
+ ('range(0, a)', 'a = 2', 'a = 3'),
+ (('break', 'a = 2'), 'a = 3', None),
+ ),
+ )
+
+ def test_for_else_return(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ if a > 1:
+ return
+ a = 1
+ else: # pylint:disable=useless-else-on-loop
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')),
+ ('range(0, a)', '(a > 1)', ('return', 'a = 1')),
+ ('(a > 1)', 'return', None),
+ ('(a > 1)', 'a = 1', 'range(0, a)'),
+ ('range(0, a)', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', None),
+ ),
+ )
+
+ def test_for_nested_straightline(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ for b in range(1, a):
+ b += 1
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')),
+ (('range(0, a)', 'b += 1'), 'range(1, a)', ('b += 1', 'a = 2')),
+ ('range(1, a)', 'b += 1', 'range(1, a)'),
+ ('range(1, a)', 'a = 2', 'range(0, a)'),
+ ('range(0, a)', 'a = 3', None),
+ ),
+ )
+
+ def test_for_nested_continue(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ for b in range(1, a):
+ if a > 3:
+ continue
+ b += 1
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')),
+ (('range(0, a)', 'continue', 'b += 1'), 'range(1, a)',
+ ('(a > 3)', 'a = 2')),
+ ('range(1, a)', '(a > 3)', ('continue', 'b += 1')),
+ ('(a > 3)', 'continue', 'range(1, a)'),
+ ('(a > 3)', 'b += 1', 'range(1, a)'),
+ ('range(1, a)', 'a = 2', 'range(0, a)'),
+ ('range(0, a)', 'a = 3', None),
+ ),
+ )
+
+ def test_for_nested_break(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ for b in range(1, a):
+ if a > 2:
+ break
+ b += 1
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')),
+ (('range(0, a)', 'b += 1'), 'range(1, a)', ('(a > 2)', 'a = 2')),
+ ('range(1, a)', '(a > 2)', ('break', 'b += 1')),
+ ('(a > 2)', 'break', 'a = 2'),
+ ('(a > 2)', 'b += 1', 'range(1, a)'),
+ (('range(1, a)', 'break'), 'a = 2', 'range(0, a)'),
+ ('range(0, a)', 'a = 3', None),
+ ),
+ )
+
+ def test_complex(self):
+
+ def test_fn(a):
+ b = 0
+ while a > 0:
+ for b in range(0, a):
+ if a > 2:
+ break
+ if a > 3:
+ if a > 4:
+ continue
+ else:
+ max(a)
+ break
+ b += 1
+ else: # for b in range(0, a):
+ return a
+ a = 2
+ for a in range(1, a):
+ return b
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('b = 0', 'a = 2'), '(a > 0)', ('range(0, a)', 'range(1, a)')),
+ (
+ ('(a > 0)', 'continue', 'b += 1'),
+ 'range(0, a)',
+ ('(a > 2)', 'return a'),
+ ),
+ ('range(0, a)', '(a > 2)', ('(a > 3)', 'break')),
+ ('(a > 2)', 'break', 'a = 2'),
+ ('(a > 2)', '(a > 3)', ('(a > 4)', 'b += 1')),
+ ('(a > 3)', '(a > 4)', ('continue', 'max(a)')),
+ ('(a > 4)', 'max(a)', 'break'),
+ ('max(a)', 'break', 'a = 2'),
+ ('(a > 4)', 'continue', 'range(0, a)'),
+ ('(a > 3)', 'b += 1', 'range(0, a)'),
+ ('range(0, a)', 'return a', None),
+ ('break', 'a = 2', '(a > 0)'),
+ ('(a > 0)', 'range(1, a)', ('return b', 'a = 3')),
+ ('range(1, a)', 'return b', None),
+ ('range(1, a)', 'a = 3', None),
+ ),
+ )
+
+ def test_finally_straightline(self):
+
+ def test_fn(a):
+ try:
+ a += 1
+ finally:
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ ('a', 'a += 1', 'a = 2'),
+ ('a += 1', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', None),
+ ),
+ )
+
+ def test_return_finally(self):
+
+ def test_fn(a):
+ try:
+ return a
+ finally:
+ a = 1
+ a = 2
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ ('a', 'return a', 'a = 1'),
+ ('return a', 'a = 1', None),
+ (None, 'a = 2', None),
+ ),
+ )
+
+ def test_break_finally(self):
+
+ def test_fn(a):
+ while a > 0:
+ try:
+ break
+ finally:
+ a = 1
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ ('a', '(a > 0)', 'break'),
+ ('(a > 0)', 'break', 'a = 1'),
+ ('break', 'a = 1', None),
+ ),
+ )
+
+ def test_continue_finally(self):
+
+ def test_fn(a):
+ while a > 0:
+ try:
+ continue
+ finally:
+ a = 1
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), '(a > 0)', 'continue'),
+ ('(a > 0)', 'continue', 'a = 1'),
+ ('continue', 'a = 1', '(a > 0)'),
+ ),
+ )
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py
index 39eca6e444..4acc4ed66a 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py
@@ -286,7 +286,7 @@ class Forward(object):
# TODO(alexbw): see if we can simplify by visiting breadth-first
def visit(self, node):
- """Depth-first walking the CFG, applying dataflow information propagation."""
+ """Depth-first walking the CFG, applying dataflow info propagation."""
# node.value is None only for the exit CfgNode.
if not node.value:
return
diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
index 032b859d46..68ead2f760 100644
--- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
@@ -192,7 +192,7 @@ def _logspace_mean(log_values):
def expectation(f, samples, log_prob=None, use_reparametrization=True,
axis=0, keep_dims=False, name=None):
- """Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\).
+ r"""Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\).
This function computes the Monte-Carlo approximation of an expectation, i.e.,
diff --git a/tensorflow/contrib/bigtable/BUILD b/tensorflow/contrib/bigtable/BUILD
new file mode 100644
index 0000000000..5c15d21e35
--- /dev/null
+++ b/tensorflow/contrib/bigtable/BUILD
@@ -0,0 +1,196 @@
+# Cloud Bigtable client for TensorFlow
+
+package(
+ default_visibility = ["//tensorflow:internal"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_copts",
+ "tf_custom_op_library",
+ "tf_gen_op_libs",
+ "tf_gen_op_wrapper_py",
+ "tf_kernel_library",
+ "tf_cc_test",
+ "tf_py_test",
+)
+
+tf_custom_op_py_library(
+ name = "bigtable",
+ srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
+ dso = [
+ ":python/ops/_bigtable.so",
+ ],
+ kernels = [
+ ":bigtable_kernels",
+ ":bigtable_ops_op_lib",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":bigtable_ops",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data",
+ ],
+)
+
+tf_custom_op_library(
+ name = "python/ops/_bigtable.so",
+ srcs = [
+ "kernels/bigtable_kernels.cc",
+ "kernels/bigtable_lookup_dataset_op.cc",
+ "kernels/bigtable_prefix_key_dataset_op.cc",
+ "kernels/bigtable_range_key_dataset_op.cc",
+ "kernels/bigtable_scan_dataset_op.cc",
+ "ops/bigtable_ops.cc",
+ ],
+ deps = [
+ ":bigtable_lib_cc",
+ "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "bigtable_ops",
+ deps = [":bigtable_ops_op_lib"],
+)
+
+tf_gen_op_libs(
+ op_lib_names = [
+ "bigtable_ops",
+ "bigtable_test_ops",
+ ],
+)
+
+tf_kernel_library(
+ name = "bigtable_kernels",
+ srcs = [
+ "kernels/bigtable_kernels.cc",
+ "kernels/bigtable_lookup_dataset_op.cc",
+ "kernels/bigtable_prefix_key_dataset_op.cc",
+ "kernels/bigtable_range_key_dataset_op.cc",
+ "kernels/bigtable_scan_dataset_op.cc",
+ ],
+ deps = [
+ ":bigtable_lib_cc",
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
+ ],
+)
+
+# A library for use in the bigtable kernels.
+cc_library(
+ name = "bigtable_lib_cc",
+ srcs = ["kernels/bigtable_lib.cc"],
+ hdrs = ["kernels/bigtable_lib.h"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
+ ],
+)
+
+cc_library(
+ name = "bigtable_test_client",
+ srcs = ["kernels/test_kernels/bigtable_test_client.cc"],
+ hdrs = ["kernels/test_kernels/bigtable_test_client.h"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "@com_github_googleapis_googleapis//:bigtable_protos",
+ "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
+ "@com_googlesource_code_re2//:re2",
+ ],
+)
+
+tf_cc_test(
+ name = "bigtable_test_client_test",
+ srcs = ["kernels/test_kernels/bigtable_test_client_test.cc"],
+ tags = ["manual"],
+ deps = [
+ ":bigtable_test_client",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "bigtable_test_ops",
+ deps = [":bigtable_test_ops_op_lib"],
+)
+
+tf_custom_op_library(
+ name = "python/kernel_tests/_bigtable_test.so",
+ srcs = [
+ "kernels/test_kernels/bigtable_test_client_op.cc",
+ "ops/bigtable_test_ops.cc",
+ ],
+ deps = [
+ ":bigtable_lib_cc",
+ ":bigtable_test_client",
+ "@com_googlesource_code_re2//:re2",
+ ],
+)
+
+# Don't use tf_kernel_library because it prevents access to strings/stringprintf.h
+cc_library(
+ name = "bigtable_test_kernels",
+ srcs = [
+ "kernels/test_kernels/bigtable_test_client_op.cc",
+ ],
+ copts = tf_copts(),
+ linkstatic = 1,
+ deps = [
+ ":bigtable_lib_cc",
+ ":bigtable_test_client",
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@com_googlesource_code_re2//:re2",
+ ],
+ alwayslink = 1,
+)
+
+tf_custom_op_py_library(
+ name = "bigtable_test_py",
+ dso = [
+ ":python/kernel_tests/_bigtable_test.so",
+ ],
+ kernels = [
+ ":bigtable_test_kernels",
+ ":bigtable_test_ops_op_lib",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":bigtable_test_ops",
+ # "//tensorflow/contrib/util:util_py",
+ # "//tensorflow/python:framework_for_generated_wrappers",
+ # "//tensorflow/python:platform",
+ # "//tensorflow/python:util",
+ # "//tensorflow/python/data",
+ ],
+)
+
+tf_py_test(
+ name = "bigtable_ops_test",
+ size = "small",
+ srcs = ["python/kernel_tests/bigtable_ops_test.py"],
+ additional_deps = [
+ ":bigtable",
+ ":bigtable_test_py",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:util",
+ ],
+ tags = ["manual"],
+)
diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md
new file mode 100644
index 0000000000..ef3c60069e
--- /dev/null
+++ b/tensorflow/contrib/bigtable/README.md
@@ -0,0 +1,10 @@
+# Bigtable #
+
+[Google Cloud Bigtable](https://cloud.google.com/bigtable/) is a
+high-performance storage system that can store and serve training data. This
+contrib package contains an experimental integration with TensorFlow.
+
+> **Status: Highly experimental.** The current implementation is very much in
+> flux. Please use at your own risk! :-)
+
+<!-- TODO(saeta): Document usage / methods / etc. -->
diff --git a/tensorflow/contrib/bigtable/__init__.py b/tensorflow/contrib/bigtable/__init__.py
new file mode 100644
index 0000000000..7df054637c
--- /dev/null
+++ b/tensorflow/contrib/bigtable/__init__.py
@@ -0,0 +1,39 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Cloud Bigtable Client for TensorFlow.
+
+This contrib package allows TensorFlow to interface directly with Cloud Bigtable
+for high-speed data loading.
+
+@@BigtableClient
+@@BigTable
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigTable
+from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableClient
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ 'BigTable',
+ 'BigtableClient',
+]
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
new file mode 100644
index 0000000000..0c81951d56
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
@@ -0,0 +1,313 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+
+namespace tensorflow {
+
+namespace {
+
+class BigtableClientOp : public OpKernel {
+ public:
+ explicit BigtableClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("project_id", &project_id_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("instance_id", &instance_id_));
+ OP_REQUIRES(ctx, !project_id_.empty(),
+ errors::InvalidArgument("project_id must be non-empty"));
+ OP_REQUIRES(ctx, !instance_id_.empty(),
+ errors::InvalidArgument("instance_id must be non-empty"));
+ }
+
+ ~BigtableClientOp() override {
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->Delete<BigtableClientResource>(cinfo_.container(),
+ cinfo_.name())
+ .ok()) {
+        // Do nothing; the resource may have been deleted by session resets.
+ }
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ if (!initialized_) {
+ ResourceMgr* mgr = ctx->resource_manager();
+ OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
+ BigtableClientResource* resource;
+ OP_REQUIRES_OK(
+ ctx, mgr->LookupOrCreate<BigtableClientResource>(
+ cinfo_.container(), cinfo_.name(), &resource,
+ [this, ctx](BigtableClientResource** ret)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ std::shared_ptr<bigtable::DataClient> client =
+ bigtable::CreateDefaultDataClient(
+ project_id_, instance_id_,
+ bigtable::ClientOptions());
+ *ret = new BigtableClientResource(
+ project_id_, instance_id_, std::move(client));
+ return Status::OK();
+ }));
+ core::ScopedUnref resource_cleanup(resource);
+ initialized_ = true;
+ }
+ OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
+ ctx, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<BigtableClientResource>()));
+ }
+
+ private:
+ string project_id_;
+ string instance_id_;
+
+ mutex mu_;
+ ContainerInfo cinfo_ GUARDED_BY(mu_);
+ bool initialized_ GUARDED_BY(mu_) = false;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigtableClient").Device(DEVICE_CPU),
+ BigtableClientOp);
+
+class BigtableTableOp : public OpKernel {
+ public:
+ explicit BigtableTableOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name", &table_));
+ OP_REQUIRES(ctx, !table_.empty(),
+ errors::InvalidArgument("table_name must be non-empty"));
+ }
+
+ ~BigtableTableOp() override {
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->Delete<BigtableTableResource>(cinfo_.container(),
+ cinfo_.name())
+ .ok()) {
+        // Do nothing; the resource may have been deleted by session resets.
+ }
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ if (!initialized_) {
+ ResourceMgr* mgr = ctx->resource_manager();
+ OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
+
+ BigtableClientResource* client_resource;
+ OP_REQUIRES_OK(
+ ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource));
+ core::ScopedUnref unref_client(client_resource);
+
+ BigtableTableResource* resource;
+ OP_REQUIRES_OK(
+ ctx, mgr->LookupOrCreate<BigtableTableResource>(
+ cinfo_.container(), cinfo_.name(), &resource,
+ [this, client_resource](BigtableTableResource** ret) {
+ *ret = new BigtableTableResource(client_resource, table_);
+ return Status::OK();
+ }));
+ initialized_ = true;
+ }
+ OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
+ ctx, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<BigtableTableResource>()));
+ }
+
+ private:
+ string table_; // Note: this is const after construction.
+
+ mutex mu_;
+ ContainerInfo cinfo_ GUARDED_BY(mu_);
+ bool initialized_ GUARDED_BY(mu_) = false;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigtableTable").Device(DEVICE_CPU),
+ BigtableTableOp);
+
+class ToBigtableOp : public AsyncOpKernel {
+ public:
+ explicit ToBigtableOp(OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx),
+ thread_pool_(new thread::ThreadPool(
+ ctx->env(), ThreadOptions(),
+ strings::StrCat("to_bigtable_op_", SanitizeThreadSuffix(name())),
+ /* num_threads = */ 1, /* low_latency_hint = */ false)) {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ // The call to `iterator->GetNext()` may block and depend on an
+ // inter-op thread pool thread, so we issue the call from the
+ // owned thread pool.
+ thread_pool_->Schedule([this, ctx, done]() {
+ const Tensor* column_families_tensor;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->input("column_families", &column_families_tensor), done);
+ OP_REQUIRES_ASYNC(
+ ctx, column_families_tensor->dims() == 1,
+ errors::InvalidArgument("`column_families` must be a vector."), done);
+
+ const Tensor* columns_tensor;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input("columns", &columns_tensor), done);
+ OP_REQUIRES_ASYNC(ctx, columns_tensor->dims() == 1,
+ errors::InvalidArgument("`columns` must be a vector."),
+ done);
+ OP_REQUIRES_ASYNC(
+ ctx,
+ columns_tensor->NumElements() ==
+ column_families_tensor->NumElements(),
+ errors::InvalidArgument("len(column_families) != len(columns)"),
+ done);
+
+ std::vector<string> column_families;
+ column_families.reserve(column_families_tensor->NumElements());
+ std::vector<string> columns;
+ columns.reserve(column_families_tensor->NumElements());
+ for (uint64 i = 0; i < column_families_tensor->NumElements(); ++i) {
+ column_families.push_back(column_families_tensor->flat<string>()(i));
+ columns.push_back(columns_tensor->flat<string>()(i));
+ }
+
+ DatasetBase* dataset;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset), done);
+
+ IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
+ std::unique_ptr<IteratorBase> iterator;
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ dataset->MakeIterator(&iter_ctx, "ToBigtableOpIterator", &iterator),
+ done);
+
+ int64 timestamp_int;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ParseScalarArgument<int64>(ctx, "timestamp", &timestamp_int),
+ done);
+ OP_REQUIRES_ASYNC(ctx, timestamp_int >= -1,
+ errors::InvalidArgument("timestamp must be >= -1"),
+ done);
+ std::chrono::milliseconds timestamp(timestamp_int);
+
+ BigtableTableResource* resource;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource), done);
+ core::ScopedUnref resource_cleanup(resource);
+
+ std::vector<Tensor> components;
+ components.reserve(dataset->output_dtypes().size());
+ bool end_of_sequence = false;
+ do {
+ ::bigtable::BulkMutation mutation;
+ // TODO(saeta): Make # of mutations configurable.
+ for (uint64 i = 0; i < 100 && !end_of_sequence; ++i) {
+ OP_REQUIRES_OK_ASYNC(
+ ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
+ done);
+ if (!end_of_sequence) {
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ CreateMutation(std::move(components), column_families, columns,
+ timestamp, &mutation),
+ done);
+ }
+ components.clear();
+ }
+ grpc::Status mutation_status;
+ std::vector<::bigtable::FailedMutation> failures =
+ resource->table().BulkApply(std::move(mutation), mutation_status);
+ if (!failures.empty()) {
+ for (const auto& failure : failures) {
+ LOG(ERROR) << "Failure applying mutation on row ("
+ << failure.original_index()
+ << "): " << failure.mutation().row_key()
+ << " - error: " << failure.status().error_message()
+ << " (Details: " << failure.status().error_details()
+ << ").";
+ }
+ }
+ OP_REQUIRES_ASYNC(
+ ctx, failures.empty() && mutation_status.ok(),
+ errors::Unknown("Failure while writing to BigTable: ",
+ mutation_status.error_code(), " - ",
+ mutation_status.error_message(), " (",
+ mutation_status.error_details(),
+ "), # of mutation failures: ", failures.size(),
+ ". See the log for the specific error details."),
+ done);
+ } while (!end_of_sequence);
+ done();
+ });
+ }
+
+ private:
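+ // Replaces every character outside [A-Za-z0-9_-] with '_' so the op name can
+ // be embedded in a thread pool name, e.g. "while/to_bigtable:0" becomes
+ // "while_to_bigtable_0".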
+ static string SanitizeThreadSuffix(string suffix) {
+ string clean;
+ for (int i = 0; i < suffix.size(); ++i) {
+ const char ch = suffix[i];
+ if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') ||
+ (ch >= '0' && ch <= '9') || ch == '_' || ch == '-') {
+ clean += ch;
+ } else {
+ clean += '_';
+ }
+ }
+ return clean;
+ }
+
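+ // Builds a SingleRowMutation from one dataset element: tensors[0] must be a
+ // scalar string holding the row key, and tensors[i] (for i >= 1) is written
+ // as the value of column_families[i - 1]:columns[i - 1] at `timestamp`. The
+ // resulting mutation is appended to `bulk_mutation`.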
+ Status CreateMutation(std::vector<Tensor> tensors,
+ const std::vector<string>& column_families,
+ const std::vector<string>& columns,
+ std::chrono::milliseconds timestamp,
+ ::bigtable::BulkMutation* bulk_mutation) {
+ if (tensors.size() != column_families.size() + 1) {
+ return errors::InvalidArgument(
+ "Iterator produced a set of Tensors shorter than expected");
+ }
+ ::bigtable::SingleRowMutation mutation(
+ std::move(tensors[0].scalar<string>()()));
+ for (size_t i = 1; i < tensors.size(); ++i) {
+ if (!TensorShapeUtils::IsScalar(tensors[i].shape())) {
+ return errors::Internal("Output tensor ", i, " was not a scalar");
+ }
+ mutation.emplace_back(
+ ::bigtable::SetCell(column_families[i - 1], columns[i - 1], timestamp,
+ std::move(tensors[i].scalar<string>()())));
+ }
+ bulk_mutation->emplace_back(std::move(mutation));
+ return Status::OK();
+ }
+
+ template <typename T>
+ Status ParseScalarArgument(OpKernelContext* ctx,
+ const StringPiece& argument_name, T* output) {
+ const Tensor* argument_t;
+ TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
+ if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
+ return errors::InvalidArgument(argument_name, " must be a scalar");
+ }
+ *output = argument_t->scalar<T>()();
+ return Status::OK();
+ }
+
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("DatasetToBigtable").Device(DEVICE_CPU),
+ ToBigtableOp);
+
+} // namespace
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc
new file mode 100644
index 0000000000..2514575f30
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc
@@ -0,0 +1,45 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
+
+namespace tensorflow {
+
+Status GrpcStatusToTfStatus(const ::grpc::Status& status) {
+ if (status.ok()) {
+ return Status::OK();
+ }
+ auto grpc_code = status.error_code();
+ if (status.error_code() == ::grpc::StatusCode::ABORTED ||
+ status.error_code() == ::grpc::StatusCode::UNAVAILABLE ||
+ status.error_code() == ::grpc::StatusCode::OUT_OF_RANGE) {
+ grpc_code = ::grpc::StatusCode::INTERNAL;
+ }
+ return Status(
+ static_cast<::tensorflow::error::Code>(grpc_code),
+ strings::StrCat("Error reading from Bigtable: ", status.error_message(),
+ " (Details: ", status.error_details(), ")"));
+}
+
+string RegexFromStringSet(const std::vector<string>& strs) {
+ CHECK(!strs.empty()) << "The list of strings to turn into a regex was empty.";
+ std::unordered_set<string> uniq(strs.begin(), strs.end());
+ if (uniq.size() == 1) {
+ return *uniq.begin();
+ }
+ return str_util::Join(uniq, "|");
+}
+
+} // namespace tensorflow
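As a quick, hypothetical usage sketch of the two helpers above (it is not part of the patch; the sample values are made up and assume the header is linked in):

// Hypothetical sketch exercising GrpcStatusToTfStatus and RegexFromStringSet.
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

void BigtableLibHelpersSketch() {
  // RegexFromStringSet de-duplicates its input and joins the unique entries
  // with '|', so a single distinct value is returned verbatim.
  LOG(INFO) << RegexFromStringSet({"cf1", "cf1"});         // "cf1"
  LOG(INFO) << RegexFromStringSet({"cf1", "cf2", "cf1"});  // e.g. "cf1|cf2"

  // GrpcStatusToTfStatus keeps the gRPC message and maps the retryable codes
  // (ABORTED, UNAVAILABLE, OUT_OF_RANGE) onto an INTERNAL TensorFlow error.
  ::grpc::Status unavailable(::grpc::StatusCode::UNAVAILABLE, "try again");
  LOG(INFO) << GrpcStatusToTfStatus(unavailable);
}

}  // namespace tensorflow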
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
new file mode 100644
index 0000000000..54303cdc5e
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
@@ -0,0 +1,138 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_
+#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_
+
+// Note: we use bigtable/client/internal/table.h as this is the no-exception API
+
+#include "google/cloud/bigtable/data_client.h"
+#include "google/cloud/bigtable/internal/table.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+
+namespace tensorflow {
+
+Status GrpcStatusToTfStatus(const ::grpc::Status& status);
+
+string RegexFromStringSet(const std::vector<string>& strs);
+
+class BigtableClientResource : public ResourceBase {
+ public:
+ BigtableClientResource(string project_id, string instance_id,
+ std::shared_ptr<bigtable::DataClient> client)
+ : project_id_(std::move(project_id)),
+ instance_id_(std::move(instance_id)),
+ client_(std::move(client)) {}
+
+ std::shared_ptr<bigtable::DataClient> get_client() { return client_; }
+
+ string DebugString() override {
+ return strings::StrCat("BigtableClientResource(project_id: ", project_id_,
+ ", instance_id: ", instance_id_, ")");
+ }
+
+ private:
+ const string project_id_;
+ const string instance_id_;
+ std::shared_ptr<bigtable::DataClient> client_;
+};
+
+class BigtableTableResource : public ResourceBase {
+ public:
+ BigtableTableResource(BigtableClientResource* client, string table_name)
+ : client_(client),
+ table_name_(std::move(table_name)),
+ table_(client->get_client(), table_name_) {
+ client_->Ref();
+ }
+
+ ~BigtableTableResource() override { client_->Unref(); }
+
+ ::bigtable::noex::Table& table() { return table_; }
+
+ string DebugString() override {
+ return strings::StrCat(
+ "BigtableTableResource(client: ", client_->DebugString(),
+ ", table: ", table_name_, ")");
+ }
+
+ private:
+ BigtableClientResource* client_; // Owns one ref.
+ const string table_name_;
+ ::bigtable::noex::Table table_;
+};
+
+// BigtableReaderDatasetIterator is an abstract class for iterators from
+// datasets that are "readers" (source datasets, not transformation datasets)
+// that read from Bigtable.
+template <typename Dataset>
+class BigtableReaderDatasetIterator : public DatasetIterator<Dataset> {
+ public:
+ explicit BigtableReaderDatasetIterator(
+ const typename DatasetIterator<Dataset>::Params& params)
+ : DatasetIterator<Dataset>(params), iterator_(nullptr, false) {}
+
+ Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(EnsureIteratorInitialized());
+ if (iterator_ == reader_->end()) {
+ grpc::Status status = reader_->Finish();
+ if (status.ok()) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ return GrpcStatusToTfStatus(status);
+ }
+ *end_of_sequence = false;
+ bigtable::Row& row = *iterator_;
+ Status s = ParseRow(ctx, row, out_tensors);
+ // Ensure we always advance.
+ ++iterator_;
+ return s;
+ }
+
+ protected:
+ virtual ::bigtable::RowRange MakeRowRange() = 0;
+ virtual ::bigtable::Filter MakeFilter() = 0;
+ virtual Status ParseRow(IteratorContext* ctx, const ::bigtable::Row& row,
+ std::vector<Tensor>* out_tensors) = 0;
+
+ private:
+ Status EnsureIteratorInitialized() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (reader_) {
+ return Status::OK();
+ }
+
+ auto rows = MakeRowRange();
+ auto filter = MakeFilter();
+
+ // Note: the `this->` in `this->dataset()` below is required because
+ // `dataset()` is a member of a dependent base class; without the explicit
+ // `this->`, unqualified lookup would find the `dataset` namespace instead
+ // of the member.
+ reader_.reset(new ::bigtable::RowReader(
+ this->dataset()->table()->table().ReadRows(rows, filter)));
+ iterator_ = reader_->begin();
+ return Status::OK();
+ }
+
+ mutex mu_;
+ std::unique_ptr<::bigtable::RowReader> reader_ GUARDED_BY(mu_);
+ ::bigtable::RowReader::iterator iterator_ GUARDED_BY(mu_);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_
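For orientation, a hypothetical minimal subclass of the template above (the concrete dataset ops added below follow this same shape); `ExampleDataset` is an assumed dataset class exposing `table()`:

// Hypothetical sketch: the three hooks below are all that a concrete reader
// dataset's iterator has to provide; BigtableReaderDatasetIterator owns the
// RowReader, advances it, and turns gRPC errors into Status values.
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"

namespace tensorflow {

template <typename ExampleDataset>  // assumed to expose table(), like the ops below
class ExampleRowKeyIterator
    : public BigtableReaderDatasetIterator<ExampleDataset> {
 public:
  explicit ExampleRowKeyIterator(
      const typename DatasetIterator<ExampleDataset>::Params& params)
      : BigtableReaderDatasetIterator<ExampleDataset>(params) {}

 protected:
  ::bigtable::RowRange MakeRowRange() override {
    // Which rows to read; the prefix is made up for illustration.
    return ::bigtable::RowRange::Prefix("example-prefix");
  }
  ::bigtable::Filter MakeFilter() override {
    // Keep only the newest cell per row and strip the cell values.
    return ::bigtable::Filter::Chain(
        ::bigtable::Filter::CellsRowLimit(1),
        ::bigtable::Filter::StripValueTransformer());
  }
  Status ParseRow(IteratorContext* ctx, const ::bigtable::Row& row,
                  std::vector<Tensor>* out_tensors) override {
    // Emit the row key as a scalar DT_STRING tensor.
    Tensor key_tensor(ctx->allocator({}), DT_STRING, {});
    key_tensor.scalar<string>()() = string(row.row_key());
    out_tensors->emplace_back(std::move(key_tensor));
    return Status::OK();
  }
};

}  // namespace tensorflow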
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
new file mode 100644
index 0000000000..4b6d55a2d3
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
@@ -0,0 +1,220 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ using UnaryDatasetOpKernel::UnaryDatasetOpKernel;
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ BigtableTableResource* table;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table));
+
+ std::vector<string> column_families;
+ std::vector<string> columns;
+ OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "column_families",
+ &column_families));
+ OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "columns", &columns));
+ OP_REQUIRES(
+ ctx, column_families.size() == columns.size(),
+ errors::InvalidArgument("len(columns) != len(column_families)"));
+
+ const uint64 num_outputs = columns.size() + 1;
+ std::vector<PartialTensorShape> output_shapes;
+ output_shapes.reserve(num_outputs);
+ DataTypeVector output_types;
+ output_types.reserve(num_outputs);
+ for (uint64 i = 0; i < num_outputs; ++i) {
+ output_shapes.push_back({});
+ output_types.push_back(DT_STRING);
+ }
+
+ *output =
+ new Dataset(ctx, input, table, std::move(column_families),
+ std::move(columns), output_types, std::move(output_shapes));
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ BigtableTableResource* table,
+ std::vector<string> column_families,
+ std::vector<string> columns,
+ const DataTypeVector& output_types,
+ std::vector<PartialTensorShape> output_shapes)
+ : GraphDatasetBase(ctx),
+ input_(input),
+ table_(table),
+ column_families_(std::move(column_families)),
+ columns_(std::move(columns)),
+ output_types_(output_types),
+ output_shapes_(std::move(output_shapes)),
+ filter_(MakeFilter(column_families_, columns_)) {
+ table_->Ref();
+ input_->Ref();
+ }
+
+ ~Dataset() override {
+ table_->Unref();
+ input_->Unref();
+ }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::BigtableLookupDataset")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override {
+ return "BigtableLookupDatasetOp::Dataset";
+ }
+
+ private:
+ static ::bigtable::Filter MakeFilter(
+ const std::vector<string>& column_families,
+ const std::vector<string>& columns) {
+ string column_family_regex = RegexFromStringSet(column_families);
+ string column_regex = RegexFromStringSet(columns);
+
+ return ::bigtable::Filter::Chain(
+ ::bigtable::Filter::Latest(1),
+ ::bigtable::Filter::FamilyRegex(column_family_regex),
+ ::bigtable::Filter::ColumnRegex(column_regex));
+ }
+
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_); // Sequence requests.
+ std::vector<Tensor> input_tensors;
+ TF_RETURN_IF_ERROR(
+ input_impl_->GetNext(ctx, &input_tensors, end_of_sequence));
+ if (*end_of_sequence) {
+ return Status::OK();
+ }
+ if (input_tensors.size() != 1) {
+ return errors::InvalidArgument(
+ "Upstream iterator (", dataset()->input_->DebugString(),
+ ") did not produce a single `tf.string` `tf.Tensor`. It "
+ "produced ",
+ input_tensors.size(), " tensors.");
+ }
+ if (input_tensors[0].NumElements() == 0) {
+ return errors::InvalidArgument("Upstream iterator (",
+ dataset()->input_->DebugString(),
+ ") return an empty set of keys.");
+ }
+ if (input_tensors[0].NumElements() == 1) {
+ // Single key lookup.
+ ::grpc::Status status;
+ auto pair = dataset()->table_->table().ReadRow(
+ input_tensors[0].scalar<string>()(), dataset()->filter_, status);
+ if (!status.ok()) {
+ return GrpcStatusToTfStatus(status);
+ }
+ if (!pair.first) {
+ return errors::DataLoss("Row key '",
+ input_tensors[0].scalar<string>()(),
+ "' not found.");
+ }
+ TF_RETURN_IF_ERROR(ParseRow(ctx, pair.second, out_tensors));
+ } else {
+ // Batched get.
+ return errors::Unimplemented(
+ "BigtableLookupDataset doesn't yet support batched retrieval.");
+ }
+ return Status::OK();
+ }
+
+ private:
+ Status ParseRow(IteratorContext* ctx, const ::bigtable::Row& row,
+ std::vector<Tensor>* out_tensors) {
+ out_tensors->reserve(dataset()->columns_.size() + 1);
+ Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {});
+ row_key_tensor.scalar<string>()() = string(row.row_key());
+ out_tensors->emplace_back(std::move(row_key_tensor));
+
+ if (row.cells().size() > 2 * dataset()->columns_.size()) {
+ LOG(WARNING) << "An excessive number of columns ("
+ << row.cells().size()
+ << ") were retrieved when reading row: "
+ << row.row_key();
+ }
+
+ for (uint64 i = 0; i < dataset()->columns_.size(); ++i) {
+ Tensor col_tensor(ctx->allocator({}), DT_STRING, {});
+ bool found_column = false;
+ for (auto cell_itr = row.cells().begin();
+ !found_column && cell_itr != row.cells().end(); ++cell_itr) {
+ if (cell_itr->family_name() == dataset()->column_families_[i] &&
+ string(cell_itr->column_qualifier()) ==
+ dataset()->columns_[i]) {
+ col_tensor.scalar<string>()() = string(cell_itr->value());
+ found_column = true;
+ }
+ }
+ if (!found_column) {
+ return errors::DataLoss("Column ", dataset()->column_families_[i],
+ ":", dataset()->columns_[i],
+ " not found in row: ", row.row_key());
+ }
+ out_tensors->emplace_back(std::move(col_tensor));
+ }
+ return Status::OK();
+ }
+
+ mutex mu_;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ };
+
+ const DatasetBase* const input_;
+ BigtableTableResource* table_;
+ const std::vector<string> column_families_;
+ const std::vector<string> columns_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ const ::bigtable::Filter filter_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigtableLookupDataset").Device(DEVICE_CPU),
+ BigtableLookupDatasetOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
new file mode 100644
index 0000000000..3d5c3cfdaa
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
@@ -0,0 +1,103 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class BigtablePrefixKeyDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ string prefix;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix));
+
+ BigtableTableResource* resource;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+
+ *output = new Dataset(ctx, resource, std::move(prefix));
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
+ string prefix)
+ : GraphDatasetBase(ctx), table_(table), prefix_(std::move(prefix)) {
+ table_->Ref();
+ }
+
+ ~Dataset() override { table_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::BigtablePrefixKeyDataset")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}});
+ return *shapes;
+ }
+
+ string DebugString() const override {
+ return "BigtablePrefixKeyDatasetOp::Dataset";
+ }
+
+ BigtableTableResource* table() const { return table_; }
+
+ private:
+ class Iterator : public BigtableReaderDatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : BigtableReaderDatasetIterator<Dataset>(params) {}
+
+ ::bigtable::RowRange MakeRowRange() override {
+ return ::bigtable::RowRange::Prefix(dataset()->prefix_);
+ }
+ ::bigtable::Filter MakeFilter() override {
+ return ::bigtable::Filter::Chain(
+ ::bigtable::Filter::CellsRowLimit(1),
+ ::bigtable::Filter::StripValueTransformer());
+ }
+ Status ParseRow(IteratorContext* ctx, const ::bigtable::Row& row,
+ std::vector<Tensor>* out_tensors) override {
+ Tensor output_tensor(ctx->allocator({}), DT_STRING, {});
+ output_tensor.scalar<string>()() = string(row.row_key());
+ out_tensors->emplace_back(std::move(output_tensor));
+ return Status::OK();
+ }
+ };
+
+ BigtableTableResource* const table_;
+ const string prefix_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigtablePrefixKeyDataset").Device(DEVICE_CPU),
+ BigtablePrefixKeyDatasetOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
new file mode 100644
index 0000000000..7fa06052c5
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
@@ -0,0 +1,111 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ string start_key;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<string>(ctx, "start_key", &start_key));
+ string end_key;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));
+
+ BigtableTableResource* resource;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+
+ *output =
+ new Dataset(ctx, resource, std::move(start_key), std::move(end_key));
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
+ string start_key, string end_key)
+ : GraphDatasetBase(ctx),
+ table_(table),
+ start_key_(std::move(start_key)),
+ end_key_(std::move(end_key)) {
+ table_->Ref();
+ }
+
+ ~Dataset() override { table_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::BigtableRangeKeyDataset")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}});
+ return *shapes;
+ }
+
+ string DebugString() const override {
+ return "BigtableRangeKeyDatasetOp::Dataset";
+ }
+
+ BigtableTableResource* table() const { return table_; }
+
+ private:
+ class Iterator : public BigtableReaderDatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : BigtableReaderDatasetIterator<Dataset>(params) {}
+
+ ::bigtable::RowRange MakeRowRange() override {
+ return ::bigtable::RowRange::Range(dataset()->start_key_,
+ dataset()->end_key_);
+ }
+ ::bigtable::Filter MakeFilter() override {
+ return ::bigtable::Filter::Chain(
+ ::bigtable::Filter::CellsRowLimit(1),
+ ::bigtable::Filter::StripValueTransformer());
+ }
+ Status ParseRow(IteratorContext* ctx, const ::bigtable::Row& row,
+ std::vector<Tensor>* out_tensors) override {
+ Tensor output_tensor(ctx->allocator({}), DT_STRING, {});
+ output_tensor.scalar<string>()() = string(row.row_key());
+ out_tensors->emplace_back(std::move(output_tensor));
+ return Status::OK();
+ }
+ };
+
+ BigtableTableResource* const table_;
+ const string start_key_;
+ const string end_key_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigtableRangeKeyDataset").Device(DEVICE_CPU),
+ BigtableRangeKeyDatasetOp);
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
new file mode 100644
index 0000000000..11b9bd2bdc
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
@@ -0,0 +1,214 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class BigtableScanDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ string prefix;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix));
+ string start_key;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<string>(ctx, "start_key", &start_key));
+ string end_key;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));
+
+ OP_REQUIRES(ctx, !(prefix.empty() && start_key.empty()),
+ errors::InvalidArgument(
+ "Either prefix or start_key must be specified"));
+ OP_REQUIRES(ctx, prefix.empty() || start_key.empty(),
+ errors::InvalidArgument(
+ "Only one of prefix and start_key can be provided"));
+ if (!prefix.empty()) {
+ OP_REQUIRES(ctx, end_key.empty(),
+ errors::InvalidArgument(
+ "If prefix is specified, end_key must be empty."));
+ }
+
+ std::vector<string> column_families;
+ std::vector<string> columns;
+ OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "column_families",
+ &column_families));
+ OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "columns", &columns));
+ OP_REQUIRES(
+ ctx, column_families.size() == columns.size(),
+ errors::InvalidArgument("len(columns) != len(column_families)"));
+ OP_REQUIRES(ctx, !column_families.empty(),
+ errors::InvalidArgument("`column_families` is empty"));
+
+ float probability = 0;
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<float>(ctx, "probability", &probability));
+ OP_REQUIRES(
+ ctx, probability > 0 && probability <= 1,
+ errors::InvalidArgument(
+ "Probability outside the range of (0, 1]. Got: ", probability));
+
+ BigtableTableResource* resource;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+
+ const uint64 num_outputs = columns.size() + 1;
+ std::vector<PartialTensorShape> output_shapes;
+ output_shapes.reserve(num_outputs);
+ DataTypeVector output_types;
+ output_types.reserve(num_outputs);
+ for (uint64 i = 0; i < num_outputs; ++i) {
+ output_shapes.push_back({});
+ output_types.push_back(DT_STRING);
+ }
+
+ *output = new Dataset(ctx, resource, std::move(prefix),
+ std::move(start_key), std::move(end_key),
+ std::move(column_families), std::move(columns),
+ probability, output_types, std::move(output_shapes));
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
+ string prefix, string start_key, string end_key,
+ std::vector<string> column_families,
+ std::vector<string> columns, float probability,
+ const DataTypeVector& output_types,
+ std::vector<PartialTensorShape> output_shapes)
+ : GraphDatasetBase(ctx),
+ table_(table),
+ prefix_(std::move(prefix)),
+ start_key_(std::move(start_key)),
+ end_key_(std::move(end_key)),
+ column_families_(std::move(column_families)),
+ columns_(std::move(columns)),
+ column_family_regex_(RegexFromStringSet(column_families_)),
+ column_regex_(RegexFromStringSet(columns_)),
+ probability_(probability),
+ output_types_(output_types),
+ output_shapes_(std::move(output_shapes)) {
+ table_->Ref();
+ }
+
+ ~Dataset() override { table_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::BigtableScanDataset")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override {
+ return "BigtableScanDatasetOp::Dataset";
+ }
+
+ BigtableTableResource* table() const { return table_; }
+
+ private:
+ class Iterator : public BigtableReaderDatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : BigtableReaderDatasetIterator<Dataset>(params) {}
+
+ ::bigtable::RowRange MakeRowRange() override {
+ if (!dataset()->prefix_.empty()) {
+ DCHECK(dataset()->start_key_.empty());
+ return ::bigtable::RowRange::Prefix(dataset()->prefix_);
+ } else {
+ DCHECK(!dataset()->start_key_.empty())
+ << "Both prefix and start_key were empty!";
+ return ::bigtable::RowRange::Range(dataset()->start_key_,
+ dataset()->end_key_);
+ }
+ }
+ ::bigtable::Filter MakeFilter() override {
+ // TODO(saeta): Investigate optimal ordering here.
+ return ::bigtable::Filter::Chain(
+ ::bigtable::Filter::Latest(1),
+ ::bigtable::Filter::FamilyRegex(dataset()->column_family_regex_),
+ ::bigtable::Filter::ColumnRegex(dataset()->column_regex_),
+ dataset()->probability_ != 1.0
+ ? ::bigtable::Filter::RowSample(dataset()->probability_)
+ : ::bigtable::Filter::PassAllFilter());
+ }
+ Status ParseRow(IteratorContext* ctx, const ::bigtable::Row& row,
+ std::vector<Tensor>* out_tensors) override {
+ out_tensors->reserve(dataset()->columns_.size() + 1);
+ Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {});
+ row_key_tensor.scalar<string>()() = string(row.row_key());
+ out_tensors->emplace_back(std::move(row_key_tensor));
+
+ if (row.cells().size() > 2 * dataset()->columns_.size()) {
+ LOG(WARNING) << "An excessive number of columns ("
+ << row.cells().size()
+ << ") were retrieved when reading row: "
+ << row.row_key();
+ }
+
+ for (uint64 i = 0; i < dataset()->columns_.size(); ++i) {
+ Tensor col_tensor(ctx->allocator({}), DT_STRING, {});
+ bool found_column = false;
+ for (auto cell_itr = row.cells().begin();
+ !found_column && cell_itr != row.cells().end(); ++cell_itr) {
+ if (cell_itr->family_name() == dataset()->column_families_[i] &&
+ string(cell_itr->column_qualifier()) ==
+ dataset()->columns_[i]) {
+ col_tensor.scalar<string>()() = string(cell_itr->value());
+ found_column = true;
+ }
+ }
+ if (!found_column) {
+ return errors::InvalidArgument(
+ "Column ", dataset()->column_families_[i], ":",
+ dataset()->columns_[i], " not found in row: ", row.row_key());
+ }
+ out_tensors->emplace_back(std::move(col_tensor));
+ }
+ return Status::OK();
+ }
+ };
+
+ BigtableTableResource* table_;
+ const string prefix_;
+ const string start_key_;
+ const string end_key_;
+ const std::vector<string> column_families_;
+ const std::vector<string> columns_;
+ const string column_family_regex_;
+ const string column_regex_;
+ const float probability_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigtableScanDataset").Device(DEVICE_CPU),
+ BigtableScanDatasetOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc
new file mode 100644
index 0000000000..0f107f169c
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc
@@ -0,0 +1,367 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h"
+
+#include "google/bigtable/v2/data.pb.h"
+#include "google/protobuf/wrappers.pb.h"
+#include "re2/re2.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/util/ptr_util.h"
+// #include "util/task/codes.pb.h"
+
+namespace tensorflow {
+namespace {
+
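+// Applies a single Mutation proto to an in-memory row, whose map entries are
+// keyed by "family:qualifier". SetCell inserts or overwrites a value,
+// DeleteFromColumn erases one entry, DeleteFromFamily erases every entry whose
+// key starts with "family:", and DeleteFromRow clears the row.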
+void UpdateRow(const ::google::bigtable::v2::Mutation& mut,
+ std::map<string, string>* row) {
+ if (mut.has_set_cell()) {
+ auto col =
+ strings::Printf("%s:%s", mut.set_cell().family_name().c_str(),
+ string(mut.set_cell().column_qualifier()).c_str());
+ (*row)[col] = string(mut.set_cell().value());
+ } else if (mut.has_delete_from_column()) {
+ auto col = strings::Printf(
+ "%s:%s", mut.delete_from_column().family_name().c_str(),
+ string(mut.delete_from_column().column_qualifier()).c_str());
+ row->erase(col);
+ } else if (mut.has_delete_from_family()) {
+ auto itr = row->lower_bound(mut.delete_from_family().family_name());
+ auto prefix =
+ strings::Printf("%s:", mut.delete_from_family().family_name().c_str());
+ while (itr != row->end() && itr->first.substr(0, prefix.size()) == prefix) {
+ row->erase(itr);
+ }
+ } else if (mut.has_delete_from_row()) {
+ row->clear();
+ } else {
+ LOG(ERROR) << "Unknown mutation: " << mut.ShortDebugString();
+ }
+}
+
+} // namespace
+
+class SampleRowKeysResponse : public grpc::ClientReaderInterface<
+ google::bigtable::v2::SampleRowKeysResponse> {
+ public:
+ explicit SampleRowKeysResponse(BigtableTestClient* client)
+ : client_(client) {}
+
+ bool NextMessageSize(uint32_t* sz) override {
+ mutex_lock l(mu_);
+ if (sent_first_message_) {
+ return false;
+ }
+ *sz = 10000; // Large enough that callers need not worry about it.
+ return true;
+ }
+
+ bool Read(google::bigtable::v2::SampleRowKeysResponse* resp) override {
+ mutex_lock l(mu_);
+ if (sent_first_message_) {
+ return false;
+ }
+ sent_first_message_ = true;
+
+ mutex_lock l2(client_->mu_);
+ *resp = google::bigtable::v2::SampleRowKeysResponse();
+ resp->set_row_key(client_->table_.rows.begin()->first);
+ resp->set_offset_bytes(0);
+ return true;
+ }
+
+ grpc::Status Finish() override { return grpc::Status::OK; }
+
+ void WaitForInitialMetadata() override {} // Do nothing.
+
+ private:
+ mutex mu_;
+ bool sent_first_message_ GUARDED_BY(mu_) = false;
+ BigtableTestClient* client_; // Not owned.
+};
+
+class ReadRowsResponse : public grpc::ClientReaderInterface<
+ google::bigtable::v2::ReadRowsResponse> {
+ public:
+ ReadRowsResponse(BigtableTestClient* client,
+ google::bigtable::v2::ReadRowsRequest const& request)
+ : client_(client), request_(request) {}
+
+ bool NextMessageSize(uint32_t* sz) override {
+ mutex_lock l(mu_);
+ if (sent_first_message_) {
+ return false;
+ }
+ *sz = 10000000; // Large enough that callers need not worry about it.
+ return true;
+ }
+
+ bool Read(google::bigtable::v2::ReadRowsResponse* resp) override {
+ mutex_lock l(mu_);
+ if (sent_first_message_) {
+ return false;
+ }
+ sent_first_message_ = true;
+ RowFilter filter = MakeRowFilter();
+
+ mutex_lock l2(client_->mu_);
+ *resp = google::bigtable::v2::ReadRowsResponse();
+ // Send all contents in first response.
+ for (auto itr = client_->table_.rows.begin();
+ itr != client_->table_.rows.end(); ++itr) {
+ if (filter.AllowRow(itr->first)) {
+ ::google::bigtable::v2::ReadRowsResponse_CellChunk* chunk = nullptr;
+ bool sent_first = false;
+ for (auto col_itr = itr->second.columns.begin();
+ col_itr != itr->second.columns.end(); ++col_itr) {
+ if (filter.AllowColumn(col_itr->first)) {
+ chunk = resp->add_chunks();
+ if (!sent_first) {
+ sent_first = true;
+ chunk->set_row_key(itr->first);
+ }
+ auto colon_idx = col_itr->first.find(":");
+ CHECK(colon_idx != string::npos)
+ << "No ':' found in: " << col_itr->first;
+ chunk->mutable_family_name()->set_value(
+ string(col_itr->first, 0, colon_idx));
+ chunk->mutable_qualifier()->set_value(
+ string(col_itr->first, ++colon_idx));
+ if (!filter.strip_values) {
+ chunk->set_value(col_itr->second);
+ }
+ if (filter.only_one_column) {
+ break;
+ }
+ }
+ }
+ if (sent_first) {
+ // We are sending this row, so set the commit flag on the last chunk.
+ chunk->set_commit_row(true);
+ }
+ }
+ }
+ return true;
+ }
+
+ grpc::Status Finish() override { return grpc::Status::OK; }
+
+ void WaitForInitialMetadata() override {} // Do nothing.
+
+ private:
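+ // In-memory approximation of the server-side read filters that the test
+ // client understands: an explicit row set and/or closed-open row ranges, an
+ // optional RE2 matched against "family:qualifier", value stripping, and a
+ // one-cell-per-row limit (only_one_column). row_sample is parsed but ignored.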
+ struct RowFilter {
+ std::set<string> row_set;
+ std::vector<std::pair<string, string>> row_ranges;
+ double row_sample = 0.0; // Note: currently ignored.
+ std::unique_ptr<RE2> col_filter;
+ bool strip_values = false;
+ bool only_one_column = false;
+
+ bool AllowRow(const string& row) {
+ if (row_set.find(row) != row_set.end()) {
+ return true;
+ }
+ for (const auto& range : row_ranges) {
+ if (range.first <= row && range.second > row) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ bool AllowColumn(const string& col) {
+ if (col_filter) {
+ return RE2::FullMatch(col, *col_filter);
+ } else {
+ return true;
+ }
+ }
+ };
+
+ RowFilter MakeRowFilter() {
+ RowFilter filter;
+ for (auto i = request_.rows().row_keys().begin();
+ i != request_.rows().row_keys().end(); ++i) {
+ filter.row_set.insert(string(*i));
+ }
+ for (auto i = request_.rows().row_ranges().begin();
+ i != request_.rows().row_ranges().end(); ++i) {
+ if (i->start_key_case() !=
+ google::bigtable::v2::RowRange::kStartKeyClosed ||
+ i->end_key_case() != google::bigtable::v2::RowRange::kEndKeyOpen) {
+ LOG(WARNING) << "Skipping row range that cannot be processed: "
+ << i->ShortDebugString();
+ continue;
+ }
+ filter.row_ranges.emplace_back(std::make_pair(
+ string(i->start_key_closed()), string(i->end_key_open())));
+ }
+ if (request_.filter().has_chain()) {
+ string family_filter;
+ string qualifier_filter;
+ for (auto i = request_.filter().chain().filters().begin();
+ i != request_.filter().chain().filters().end(); ++i) {
+ switch (i->filter_case()) {
+ case google::bigtable::v2::RowFilter::kFamilyNameRegexFilter:
+ family_filter = i->family_name_regex_filter();
+ break;
+ case google::bigtable::v2::RowFilter::kColumnQualifierRegexFilter:
+ qualifier_filter = i->column_qualifier_regex_filter();
+ break;
+ case google::bigtable::v2::RowFilter::kCellsPerColumnLimitFilter:
+ if (i->cells_per_column_limit_filter() != 1) {
+ LOG(ERROR) << "Unexpected cells_per_column_limit_filter: "
+ << i->cells_per_column_limit_filter();
+ }
+ break;
+ case google::bigtable::v2::RowFilter::kStripValueTransformer:
+ filter.strip_values = i->strip_value_transformer();
+ break;
+ case google::bigtable::v2::RowFilter::kRowSampleFilter:
+ LOG(INFO) << "Ignoring row sample directive.";
+ break;
+ case google::bigtable::v2::RowFilter::kPassAllFilter:
+ break;
+ case google::bigtable::v2::RowFilter::kCellsPerRowLimitFilter:
+ filter.only_one_column = true;
+ break;
+ default:
+ LOG(WARNING) << "Ignoring unknown filter type: "
+ << i->ShortDebugString();
+ }
+ }
+ if (family_filter.empty() || qualifier_filter.empty()) {
+ LOG(WARNING) << "Missing regex!";
+ } else {
+ string regex = strings::Printf("%s:%s", family_filter.c_str(),
+ qualifier_filter.c_str());
+ filter.col_filter.reset(new RE2(regex));
+ }
+ } else {
+ LOG(WARNING) << "Read request did not have a filter chain specified: "
+ << request_.filter().DebugString();
+ }
+ return filter;
+ }
+
+ mutex mu_;
+ bool sent_first_message_ GUARDED_BY(mu_) = false;
+ BigtableTestClient* client_; // Not owned.
+ const google::bigtable::v2::ReadRowsRequest request_;
+};
+
+class MutateRowsResponse : public grpc::ClientReaderInterface<
+ google::bigtable::v2::MutateRowsResponse> {
+ public:
+ explicit MutateRowsResponse(size_t num_successes)
+ : num_successes_(num_successes) {}
+
+ bool NextMessageSize(uint32_t* sz) override {
+ mutex_lock l(mu_);
+ if (sent_first_message_) {
+ return false;
+ }
+ *sz = 10000000; // Large enough that callers need not worry about it.
+ return true;
+ }
+
+ bool Read(google::bigtable::v2::MutateRowsResponse* resp) override {
+ mutex_lock l(mu_);
+ if (sent_first_message_) {
+ return false;
+ }
+ sent_first_message_ = true;
+ *resp = google::bigtable::v2::MutateRowsResponse();
+ for (size_t i = 0; i < num_successes_; ++i) {
+ auto entry = resp->add_entries();
+ entry->set_index(i);
+ }
+ return true;
+ }
+
+ grpc::Status Finish() override { return grpc::Status::OK; }
+
+ void WaitForInitialMetadata() override {} // Do nothing.
+
+ private:
+ const size_t num_successes_;
+
+ mutex mu_;
+ bool sent_first_message_ GUARDED_BY(mu_) = false;
+};
+
+grpc::Status BigtableTestClient::MutateRow(
+ grpc::ClientContext* context,
+ google::bigtable::v2::MutateRowRequest const& request,
+ google::bigtable::v2::MutateRowResponse* response) {
+ mutex_lock l(mu_);
+ auto* row = &table_.rows[string(request.row_key())];
+ for (int i = 0; i < request.mutations_size(); ++i) {
+ UpdateRow(request.mutations(i), &row->columns);
+ }
+ *response = google::bigtable::v2::MutateRowResponse();
+ return grpc::Status::OK;
+}
+grpc::Status BigtableTestClient::CheckAndMutateRow(
+ grpc::ClientContext* context,
+ google::bigtable::v2::CheckAndMutateRowRequest const& request,
+ google::bigtable::v2::CheckAndMutateRowResponse* response) {
+ return grpc::Status(grpc::StatusCode::UNIMPLEMENTED,
+ "CheckAndMutateRow not implemented.");
+}
+grpc::Status BigtableTestClient::ReadModifyWriteRow(
+ grpc::ClientContext* context,
+ google::bigtable::v2::ReadModifyWriteRowRequest const& request,
+ google::bigtable::v2::ReadModifyWriteRowResponse* response) {
+ return grpc::Status(grpc::StatusCode::UNIMPLEMENTED,
+ "ReadModifyWriteRow not implemented.");
+}
+std::unique_ptr<
+ grpc::ClientReaderInterface<google::bigtable::v2::ReadRowsResponse>>
+BigtableTestClient::ReadRows(
+ grpc::ClientContext* context,
+ google::bigtable::v2::ReadRowsRequest const& request) {
+ return MakeUnique<ReadRowsResponse>(this, request);
+}
+
+std::unique_ptr<
+ grpc::ClientReaderInterface<google::bigtable::v2::SampleRowKeysResponse>>
+BigtableTestClient::SampleRowKeys(
+ grpc::ClientContext* context,
+ google::bigtable::v2::SampleRowKeysRequest const& request) {
+ return MakeUnique<SampleRowKeysResponse>(this);
+}
+std::unique_ptr<
+ grpc::ClientReaderInterface<google::bigtable::v2::MutateRowsResponse>>
+BigtableTestClient::MutateRows(
+ grpc::ClientContext* context,
+ google::bigtable::v2::MutateRowsRequest const& request) {
+ mutex_lock l(mu_);
+ for (auto i = request.entries().begin(); i != request.entries().end(); ++i) {
+ auto* row = &table_.rows[string(i->row_key())];
+ for (auto mut = i->mutations().begin(); mut != i->mutations().end();
+ ++mut) {
+ UpdateRow(*mut, &row->columns);
+ }
+ }
+ return MakeUnique<MutateRowsResponse>(request.entries_size());
+}
+
+std::shared_ptr<grpc::Channel> BigtableTestClient::Channel() {
+ LOG(WARNING) << "Call to InMemoryDataClient::Channel(); this will likely "
+ "cause a crash!";
+ return nullptr;
+}
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h
new file mode 100644
index 0000000000..dcce6a33a7
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h
@@ -0,0 +1,87 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_
+#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_
+
+#include "google/cloud/bigtable/data_client.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+class BigtableTestClient : public ::bigtable::DataClient {
+ public:
+ std::string const& project_id() const override { return project_id_; }
+ std::string const& instance_id() const override { return instance_id_; }
+ void reset() override {
+ mutex_lock l(mu_);
+ table_ = Table();
+ }
+
+ grpc::Status MutateRow(
+ grpc::ClientContext* context,
+ google::bigtable::v2::MutateRowRequest const& request,
+ google::bigtable::v2::MutateRowResponse* response) override;
+
+ grpc::Status CheckAndMutateRow(
+ grpc::ClientContext* context,
+ google::bigtable::v2::CheckAndMutateRowRequest const& request,
+ google::bigtable::v2::CheckAndMutateRowResponse* response) override;
+
+ grpc::Status ReadModifyWriteRow(
+ grpc::ClientContext* context,
+ google::bigtable::v2::ReadModifyWriteRowRequest const& request,
+ google::bigtable::v2::ReadModifyWriteRowResponse* response) override;
+
+ std::unique_ptr<
+ grpc::ClientReaderInterface<google::bigtable::v2::ReadRowsResponse>>
+ ReadRows(grpc::ClientContext* context,
+ google::bigtable::v2::ReadRowsRequest const& request) override;
+ std::unique_ptr<
+ grpc::ClientReaderInterface<google::bigtable::v2::SampleRowKeysResponse>>
+ SampleRowKeys(
+ grpc::ClientContext* context,
+ google::bigtable::v2::SampleRowKeysRequest const& request) override;
+
+ std::unique_ptr<
+ grpc::ClientReaderInterface<google::bigtable::v2::MutateRowsResponse>>
+ MutateRows(grpc::ClientContext* context,
+ google::bigtable::v2::MutateRowsRequest const& request) override;
+
+ std::shared_ptr<grpc::Channel> Channel() override;
+
+ private:
+ friend class SampleRowKeysResponse;
+ friend class ReadRowsResponse;
+ friend class MutateRowsResponse;
+
+ struct Row {
+ string row_key;
+ std::map<string, string> columns;
+ };
+ struct Table {
+ std::map<string, Row> rows;
+ };
+
+ mutex mu_;
+ const std::string project_id_ = "testproject";
+ const std::string instance_id_ = "testinstance";
+ Table table_ GUARDED_BY(mu_);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_
diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc
new file mode 100644
index 0000000000..f9be9ec6e2
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc
@@ -0,0 +1,77 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
+#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+
+namespace tensorflow {
+
+namespace {
+
+class BigtableTestClientOp : public OpKernel {
+ public:
+ explicit BigtableTestClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ ~BigtableTestClientOp() override {
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->Delete<BigtableClientResource>(cinfo_.container(),
+ cinfo_.name())
+ .ok()) {
+ // Do nothing; the resource may already have been deleted by a session reset.
+ }
+ }
+ }
+ void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ if (!initialized_) {
+ ResourceMgr* mgr = ctx->resource_manager();
+ OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
+ BigtableClientResource* resource;
+ OP_REQUIRES_OK(ctx,
+ mgr->LookupOrCreate<BigtableClientResource>(
+ cinfo_.container(), cinfo_.name(), &resource,
+ [this, ctx](BigtableClientResource** ret)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ std::shared_ptr<bigtable::DataClient> client(
+ new BigtableTestClient());
+ // Note: must make explicit copies to sequence
+ // them before the move of client.
+ string project_id = client->project_id();
+ string instance_id = client->instance_id();
+ *ret = new BigtableClientResource(
+ std::move(project_id),
+ std::move(instance_id), std::move(client));
+ return Status::OK();
+ }));
+ core::ScopedUnref resource_cleanup(resource);
+ initialized_ = true;
+ }
+ OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
+ ctx, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<BigtableClientResource>()));
+ }
+
+ private:
+ mutex mu_;
+ ContainerInfo cinfo_ GUARDED_BY(mu_);
+ bool initialized_ GUARDED_BY(mu_) = false;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigtableTestClient").Device(DEVICE_CPU),
+ BigtableTestClientOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc
new file mode 100644
index 0000000000..bd362f7de5
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc
@@ -0,0 +1,279 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h"
+#include "google/cloud/bigtable/internal/table.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+void WriteCell(const string& row, const string& family, const string& column,
+ const string& value, ::bigtable::noex::Table* table) {
+ ::bigtable::SingleRowMutation mut(row);
+ mut.emplace_back(::bigtable::SetCell(family, column, value));
+ table->Apply(std::move(mut));
+}
+
+TEST(BigtableTestClientTest, EmptyRowRead) {
+ std::shared_ptr<::bigtable::DataClient> client_ptr =
+ std::make_shared<BigtableTestClient>();
+ ::bigtable::noex::Table table(client_ptr, "test_table");
+
+ ::bigtable::RowSet rowset;
+ rowset.Append("r1");
+ auto filter = ::bigtable::Filter::Chain(::bigtable::Filter::Latest(1));
+ auto rows = table.ReadRows(std::move(rowset), filter);
+ EXPECT_EQ(rows.begin(), rows.end()) << "Some rows were returned in response!";
+ EXPECT_TRUE(rows.Finish().ok()) << "Error reading rows.";
+}
+
+TEST(BigtableTestClientTest, SingleRowWriteAndRead) {
+ std::shared_ptr<::bigtable::DataClient> client_ptr =
+ std::make_shared<BigtableTestClient>();
+ ::bigtable::noex::Table table(client_ptr, "test_table");
+
+ WriteCell("r1", "f1", "c1", "v1", &table);
+
+ ::bigtable::RowSet rowset("r1");
+ auto filter = ::bigtable::Filter::Chain(::bigtable::Filter::Latest(1));
+ auto rows = table.ReadRows(std::move(rowset), filter);
+ auto itr = rows.begin();
+ EXPECT_NE(itr, rows.end()) << "No rows were returned in response!";
+ EXPECT_EQ(itr->row_key(), "r1");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v1");
+
+ ++itr;
+ EXPECT_EQ(itr, rows.end());
+ EXPECT_TRUE(rows.Finish().ok());
+}
+
+TEST(BigtableTestClientTest, MultiRowWriteAndSingleRowRead) {
+ std::shared_ptr<::bigtable::DataClient> client_ptr =
+ std::make_shared<BigtableTestClient>();
+ ::bigtable::noex::Table table(client_ptr, "test_table");
+
+ WriteCell("r1", "f1", "c1", "v1", &table);
+ WriteCell("r2", "f1", "c1", "v2", &table);
+ WriteCell("r3", "f1", "c1", "v3", &table);
+
+ ::bigtable::RowSet rowset("r1");
+ auto filter = ::bigtable::Filter::Chain(::bigtable::Filter::Latest(1));
+ auto rows = table.ReadRows(std::move(rowset), filter);
+ auto itr = rows.begin();
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r1");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v1");
+
+ ++itr;
+ EXPECT_EQ(itr, rows.end()) << "Extra rows in the response.";
+ EXPECT_TRUE(rows.Finish().ok());
+}
+
+TEST(BigtableTestClientTest, MultiRowWriteAndRead) {
+ std::shared_ptr<::bigtable::DataClient> client_ptr =
+ std::make_shared<BigtableTestClient>();
+ ::bigtable::noex::Table table(client_ptr, "test_table");
+
+ WriteCell("r1", "f1", "c1", "v1", &table);
+ WriteCell("r2", "f1", "c1", "v2", &table);
+ WriteCell("r3", "f1", "c1", "v3", &table);
+
+ ::bigtable::RowSet rowset("r1", "r2", "r3");
+ auto filter = ::bigtable::Filter::Chain(::bigtable::Filter::Latest(1));
+ auto rows = table.ReadRows(std::move(rowset), filter);
+ auto itr = rows.begin();
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r1");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v1");
+
+ ++itr;
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r2");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v2");
+
+ ++itr;
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r3");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v3");
+
+ ++itr;
+ EXPECT_EQ(itr, rows.end()) << "Extra rows in the response.";
+ EXPECT_TRUE(rows.Finish().ok());
+}
+
+TEST(BigtableTestClientTest, MultiRowWriteAndPrefixRead) {
+ std::shared_ptr<::bigtable::DataClient> client_ptr =
+ std::make_shared<BigtableTestClient>();
+ ::bigtable::noex::Table table(client_ptr, "test_table");
+
+ WriteCell("r1", "f1", "c1", "v1", &table);
+ WriteCell("r2", "f1", "c1", "v2", &table);
+ WriteCell("r3", "f1", "c1", "v3", &table);
+
+ auto filter = ::bigtable::Filter::Chain(::bigtable::Filter::Latest(1));
+ auto rows = table.ReadRows(::bigtable::RowRange::Prefix("r"), filter);
+ auto itr = rows.begin();
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r1");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v1");
+
+ ++itr;
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r2");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v2");
+
+ ++itr;
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r3");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v3");
+
+ ++itr;
+ EXPECT_EQ(itr, rows.end()) << "Extra rows in the response.";
+ EXPECT_TRUE(rows.Finish().ok());
+}
+
+TEST(BigtableTestClientTest, ColumnFiltering) {
+ std::shared_ptr<::bigtable::DataClient> client_ptr =
+ std::make_shared<BigtableTestClient>();
+ ::bigtable::noex::Table table(client_ptr, "test_table");
+
+ WriteCell("r1", "f1", "c1", "v1", &table);
+ WriteCell("r2", "f1", "c1", "v2", &table);
+ WriteCell("r3", "f1", "c1", "v3", &table);
+
+ // Extra cells
+ WriteCell("r1", "f2", "c1", "v1", &table);
+ WriteCell("r2", "f2", "c1", "v2", &table);
+ WriteCell("r3", "f1", "c2", "v3", &table);
+
+ auto filter = ::bigtable::Filter::Chain(
+ ::bigtable::Filter::Latest(1), ::bigtable::Filter::FamilyRegex("f1"),
+ ::bigtable::Filter::ColumnRegex("c1"));
+ auto rows = table.ReadRows(::bigtable::RowRange::Prefix("r"), filter);
+ auto itr = rows.begin();
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r1");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v1");
+
+ ++itr;
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r2");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v2");
+
+ ++itr;
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r3");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v3");
+
+ ++itr;
+ EXPECT_EQ(itr, rows.end()) << "Extra rows in the response.";
+ EXPECT_TRUE(rows.Finish().ok());
+}
+
+TEST(BigtableTestClientTest, RowKeys) {
+ std::shared_ptr<::bigtable::DataClient> client_ptr =
+ std::make_shared<BigtableTestClient>();
+ ::bigtable::noex::Table table(client_ptr, "test_table");
+
+ WriteCell("r1", "f1", "c1", "v1", &table);
+ WriteCell("r2", "f1", "c1", "v2", &table);
+ WriteCell("r3", "f1", "c1", "v3", &table);
+
+ // Extra cells
+ WriteCell("r1", "f2", "c1", "v1", &table);
+ WriteCell("r2", "f2", "c1", "v2", &table);
+ WriteCell("r3", "f1", "c2", "v3", &table);
+
+ auto filter = ::bigtable::Filter::Chain(
+ ::bigtable::Filter::Latest(1), ::bigtable::Filter::CellsRowLimit(1),
+ ::bigtable::Filter::StripValueTransformer());
+ auto rows = table.ReadRows(::bigtable::RowRange::Prefix("r"), filter);
+ auto itr = rows.begin();
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r1");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "");
+
+ ++itr;
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r2");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "");
+
+ ++itr;
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r3");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "");
+
+ ++itr;
+ EXPECT_EQ(itr, rows.end()) << "Extra rows in the response.";
+ EXPECT_TRUE(rows.Finish().ok());
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc
new file mode 100644
index 0000000000..17ecc3dcb2
--- /dev/null
+++ b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc
@@ -0,0 +1,88 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+// TODO(saeta): Add support for setting ClientOptions values.
+REGISTER_OP("BigtableClient")
+ .Attr("project_id: string")
+ .Attr("instance_id: string")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Output("client: resource")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+// TODO(saeta): Add support for Application Profiles.
+// See https://cloud.google.com/bigtable/docs/app-profiles for more info.
+REGISTER_OP("BigtableTable")
+ .Input("client: resource")
+ .Attr("table_name: string")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Output("table: resource")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("DatasetToBigtable")
+ .Input("table: resource")
+ .Input("input_dataset: variant")
+ .Input("column_families: string")
+ .Input("columns: string")
+ .Input("timestamp: int64")
+ .SetShapeFn(shape_inference::NoOutputs);
+
+REGISTER_OP("BigtableLookupDataset")
+ .Input("keys_dataset: variant")
+ .Input("table: resource")
+ .Input("column_families: string")
+ .Input("columns: string")
+ .Output("handle: variant")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("BigtablePrefixKeyDataset")
+ .Input("table: resource")
+ .Input("prefix: string")
+ .Output("handle: variant")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("BigtableRangeKeyDataset")
+ .Input("table: resource")
+ .Input("start_key: string")
+ .Input("end_key: string")
+ .Output("handle: variant")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape);
+
+// TODO(saeta): Support continuing despite bad data (e.g. empty string, or
+// skip incomplete row.)
+REGISTER_OP("BigtableScanDataset")
+ .Input("table: resource")
+ .Input("prefix: string")
+ .Input("start_key: string")
+ .Input("end_key: string")
+ .Input("column_families: string")
+ .Input("columns: string")
+ .Input("probability: float")
+ .Output("handle: variant")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape);
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/ops/bigtable_test_ops.cc b/tensorflow/contrib/bigtable/ops/bigtable_test_ops.cc
new file mode 100644
index 0000000000..f7d02458f6
--- /dev/null
+++ b/tensorflow/contrib/bigtable/ops/bigtable_test_ops.cc
@@ -0,0 +1,27 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("BigtableTestClient")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Output("client: resource")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/__init__.py b/tensorflow/contrib/bigtable/python/kernel_tests/__init__.py
new file mode 100644
index 0000000000..292d8f4e51
--- /dev/null
+++ b/tensorflow/contrib/bigtable/python/kernel_tests/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""This module contains tests for the bigtable integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
new file mode 100644
index 0000000000..d33a66f2df
--- /dev/null
+++ b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
@@ -0,0 +1,132 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Bigtable Ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib import bigtable
+from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
+from tensorflow.contrib.bigtable.ops import gen_bigtable_test_ops
+from tensorflow.contrib.util import loader
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+_bigtable_so = loader.load_op_library(
+ resource_loader.get_path_to_datafile("_bigtable_test.so"))
+
+
+class BigtableOpsTest(test.TestCase):
+ COMMON_ROW_KEYS = ["r1", "r2", "r3"]
+ COMMON_VALUES = ["v1", "v2", "v3"]
+
+ def setUp(self):
+ self._client = gen_bigtable_test_ops.bigtable_test_client()
+ table = gen_bigtable_ops.bigtable_table(self._client, "testtable")
+ self._table = bigtable.BigTable("testtable", None, table)
+
+ def _makeSimpleDataset(self):
+ output_rows = dataset_ops.Dataset.from_tensor_slices(self.COMMON_ROW_KEYS)
+ output_values = dataset_ops.Dataset.from_tensor_slices(self.COMMON_VALUES)
+ return dataset_ops.Dataset.zip((output_rows, output_values))
+
+ def _writeCommonValues(self, sess):
+ output_ds = self._makeSimpleDataset()
+ write_op = self._table.write(output_ds, ["cf1"], ["c1"])
+ sess.run(write_op)
+
+ def runReadKeyTest(self, read_ds):
+ itr = read_ds.make_initializable_iterator()
+ n = itr.get_next()
+ expected = list(self.COMMON_ROW_KEYS)
+ expected.reverse()
+ with self.test_session() as sess:
+ self._writeCommonValues(sess)
+ sess.run(itr.initializer)
+ for i in range(3):
+ output = sess.run(n)
+ want = expected.pop()
+ self.assertEqual(
+ compat.as_bytes(want), compat.as_bytes(output),
+ "Unequal at step %d: want: %s, got: %s" % (i, want, output))
+
+ def testReadPrefixKeys(self):
+ self.runReadKeyTest(self._table.keys_by_prefix_dataset("r"))
+
+ def testReadRangeKeys(self):
+ self.runReadKeyTest(self._table.keys_by_range_dataset("r1", "r4"))
+
+ def runScanTest(self, read_ds):
+ itr = read_ds.make_initializable_iterator()
+ n = itr.get_next()
+ expected_keys = list(self.COMMON_ROW_KEYS)
+ expected_keys.reverse()
+ expected_values = list(self.COMMON_VALUES)
+ expected_values.reverse()
+ with self.test_session() as sess:
+ self._writeCommonValues(sess)
+ sess.run(itr.initializer)
+ for i in range(3):
+ output = sess.run(n)
+ want = expected_keys.pop()
+ self.assertEqual(
+ compat.as_bytes(want), compat.as_bytes(output[0]),
+ "Unequal keys at step %d: want: %s, got: %s" % (i, want, output[0]))
+ want = expected_values.pop()
+ self.assertEqual(
+ compat.as_bytes(want), compat.as_bytes(output[1]),
+ "Unequal values at step: %d: want: %s, got: %s" % (i, want,
+ output[1]))
+
+ def testScanPrefixStringCol(self):
+ self.runScanTest(self._table.scan_prefix("r", cf1="c1"))
+
+ def testScanPrefixListCol(self):
+ self.runScanTest(self._table.scan_prefix("r", cf1=["c1"]))
+
+ def testScanRangeStringCol(self):
+ self.runScanTest(self._table.scan_range("r1", "r4", cf1="c1"))
+
+ def testScanRangeListCol(self):
+ self.runScanTest(self._table.scan_range("r1", "r4", cf1=["c1"]))
+
+ def testLookup(self):
+ ds = self._table.keys_by_prefix_dataset("r")
+ ds = ds.apply(self._table.lookup_columns(cf1="c1"))
+ itr = ds.make_initializable_iterator()
+ n = itr.get_next()
+ expected_keys = list(self.COMMON_ROW_KEYS)
+ expected_values = list(self.COMMON_VALUES)
+ expected_tuples = zip(expected_keys, expected_values)
+ with self.test_session() as sess:
+ self._writeCommonValues(sess)
+ sess.run(itr.initializer)
+ for i, elem in enumerate(expected_tuples):
+ output = sess.run(n)
+ self.assertEqual(
+ compat.as_bytes(elem[0]), compat.as_bytes(output[0]),
+ "Unequal keys at step %d: want: %s, got: %s" %
+ (i, compat.as_bytes(elem[0]), compat.as_bytes(output[0])))
+ self.assertEqual(
+ compat.as_bytes(elem[1]), compat.as_bytes(output[1]),
+ "Unequal values at step %d: want: %s, got: %s" %
+ (i, compat.as_bytes(elem[1]), compat.as_bytes(output[1])))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/bigtable/python/ops/__init__.py b/tensorflow/contrib/bigtable/python/ops/__init__.py
new file mode 100644
index 0000000000..36d75b0d70
--- /dev/null
+++ b/tensorflow/contrib/bigtable/python/ops/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""This module contains the Python API for the Cloud Bigtable integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
new file mode 100644
index 0000000000..a54e020ed7
--- /dev/null
+++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
@@ -0,0 +1,480 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Python API for TensorFlow's Bigtable integration.
+
+TensorFlow has support for reading from and writing to Cloud Bigtable. To use
+the Bigtable TensorFlow integration, first create a BigtableClient (which
+configures your connection to Cloud Bigtable), and then open a Table. The Table
+object then allows you to create numerous @{tf.data.Dataset}s to read data, or
+write a @{tf.data.Dataset} object to the underlying Bigtable Table.
+
+For background on Google Cloud Bigtable, see: https://cloud.google.com/bigtable.
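+
+A minimal usage sketch (illustrative only; the project, instance, table, and
+column names below are placeholders):
+
+```
+client = BigtableClient("my-project", "my-instance")
+table = client.table("my-table")
+key_dataset = table.keys_by_prefix_dataset("train_")
+row_dataset = table.scan_prefix("train_", cf1="feature")
+```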
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six import iteritems
+
+from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
+from tensorflow.contrib.util import loader
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.platform import resource_loader
+
+_bigtable_so = loader.load_op_library(
+ resource_loader.get_path_to_datafile("_bigtable.so"))
+
+
+class BigtableClient(object):
+ """BigtableClient is the entrypoint for interacting with Cloud Bigtable in TF.
+
+ BigtableClient encapsulates a connection to Cloud Bigtable, and exposes the
+ `table` method to open a Bigtable Table.
+ """
+
+ def __init__(self, project_id, instance_id):
+ """Creates a BigtableClient that can be used to open connections to tables.
+
+ Args:
+ project_id: A string representing the GCP project id to connect to.
+ instance_id: A string representing the Bigtable instance to connect to.
+ """
+ self._project_id = project_id
+ self._instance_id = instance_id
+ self._resource = gen_bigtable_ops.bigtable_client(project_id, instance_id)
+
+ def table(self, name, snapshot=None):
+ """Opens a table and returns a `BigTable` object.
+
+ Args:
+ name: A `tf.string` `tf.Tensor` name of the table to open.
+ snapshot: Either a `tf.string` `tf.Tensor` snapshot id, or `True` to
+ request the creation of a snapshot. (Note: currently unimplemented.)
+
+ Returns:
+ A `BigTable` python object representing the operations available on the
+ table.
+ """
+ # TODO(saeta): Implement snapshot functionality.
+ table = gen_bigtable_ops.bigtable_table(self._resource, name)
+ return BigTable(name, snapshot, table)
+
+
+class BigTable(object):
+ """BigTable is the entrypoint for reading and writing data in Cloud Bigtable.
+
+ This BigTable class is the Python representation of the Cloud Bigtable table
+ within TensorFlow. Methods on this class allow data to be read from and
+ written to the Cloud Bigtable service in a flexible and high-performance
+ manner.
+ """
+
+ # TODO(saeta): Investigate implementing tf.contrib.lookup.LookupInterface.
+ # TODO(saeta): Consider variant tensors instead of resources (while supporting
+ # connection pooling).
+
+ def __init__(self, name, snapshot, resource):
+ self._name = name
+ self._snapshot = snapshot
+ self._resource = resource
+
+ def lookup_columns(self, *args, **kwargs):
+ """Retrieves the values of columns for a dataset of keys.
+
+ Example usage:
+ ```
+ table = bigtable_client.table("my_table")
+ key_dataset = table.keys_by_prefix_dataset("imagenet")
+ images = key_dataset.apply(table.lookup_columns(("cf1", "image"),
+ ("cf2", "label"),
+ ("cf2", "boundingbox")))
+ training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128)
+ ```
+
+ Alternatively, you can use keyword arguments to specify the columns to
+ capture. Example (same as above, rewritten):
+ ```
+ table = bigtable_client.table("my_table")
+ key_dataset = table.keys_by_prefix_dataset("imagenet")
+ images = key_dataset.apply(table.lookup_columns(
+ cf1="image", cf2=("label", "boundingbox")))
+ training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128)
+ ```
+
+ Note: certain kwargs keys are reserved, and thus some column families cannot
+ be identified using the kwargs syntax. Instead, please use the args syntax.
+ This list includes:
+ - 'name'
+ This list can change at any time.
+
+ Args:
+ *args: A list of tuples containing (column family, column name) pairs.
+ **kwargs: The column families and columns to retrieve. Keys are treated as
+ column families, and values can be either lists of strings, or strings
+ that are treated as the column qualifier (column name).
+
+ Returns:
+ A function that can be passed to `tf.data.Dataset.apply` to retrieve the
+ values of columns for the rows.
+ """
+ table = self # Capture self
+ normalized = args
+ if normalized is None:
+ normalized = []
+ if isinstance(normalized, tuple):
+ normalized = list(normalized)
+ for key, value in iteritems(kwargs):
+ if key == "name":
+ continue
+ if isinstance(value, str):
+ normalized.append((key, value))
+ continue
+ for col in value:
+ normalized.append((key, col))
+
+ def _apply_fn(dataset):
+ # TODO(saeta): Verify dataset's types are correct!
+ return _BigtableLookupDataset(dataset, table, normalized)
+
+ return _apply_fn
+
+ def keys_by_range_dataset(self, start, end):
+ """Retrieves all row keys between start and end.
+
+ Note: it does NOT retrieve the values of columns.
+
+ Args:
+ start: The start row key. The row keys for rows after start (inclusive)
+ will be retrieved.
+ end: (Optional.) The end row key. Rows up to (but not including) end will
+ be retrieved. If end is None, all subsequent row keys will be retrieved.
+
+ Returns:
+ A @{tf.data.Dataset} containing `tf.string` Tensors corresponding to all
+ of the row keys between `start` and `end`.
+ """
+ # TODO(saeta): Make inclusive / exclusive configurable?
+ if end is None:
+ end = ""
+ return _BigtableRangeKeyDataset(self, start, end)
+
+ def keys_by_prefix_dataset(self, prefix):
+ """Retrieves the row keys matching a given prefix.
+
+ Args:
+ prefix: All row keys that begin with `prefix` in the table will be
+ retrieved.
+
+ Returns:
+ A @{tf.data.Dataset} containing `tf.string` Tensors corresponding to all
+ of the row keys matching that prefix.
+ """
+ return _BigtablePrefixKeyDataset(self, prefix)
+
+ def scan_prefix(self, prefix, probability=None, columns=None, **kwargs):
+ """Retrieves row (including values) from the Bigtable service.
+
+ Rows with row-key prefixed by `prefix` will be retrieved.
+
+ Specifying the columns to retrieve for each row is done by either using
+ kwargs or in the columns parameter. To retrieve values of the columns "c1",
+ and "c2" from the column family "cfa", and the value of the column "c3"
+ from column family "cfb", the following datasets (`ds1`, and `ds2`) are
+ equivalent:
+
+ ```
+ table = # ...
+ ds1 = table.scan_prefix("row_prefix", columns=[("cfa", "c1"),
+ ("cfa", "c2"),
+ ("cfb", "c3")])
+ ds2 = table.scan_prefix("row_prefix", cfa=["c1", "c2"], cfb="c3")
+ ```
+
+ Note: only the latest value of a cell will be retrieved.
+
+ Args:
+ prefix: The prefix all row keys must match to be retrieved for
+ prefix-based scans.
+ probability: (Optional.) A float in `(0, 1]` giving the probability with
+ which each row is sampled. Defaults to 1 (retrieve all matching rows).
+ columns: The columns to read. Note: most commonly, they are expressed as
+ kwargs. Use the columns value if you are using column families that are
+ reserved. The value of columns and kwargs are merged. Columns is a list
+ of tuples of strings ("column_family", "column_qualifier").
+ **kwargs: The column families and columns to read. Keys are treated as
+ column_families, and values can be either lists of strings, or strings
+ that are treated as the column qualifier (column name).
+
+ Returns:
+ A @{tf.data.Dataset} returning the row keys and the cell contents.
+
+ Raises:
+ ValueError: If the configured probability is unexpected.
+ """
+ if probability is None:
+ probability = 1.0
+ if isinstance(probability, float) and (probability <= 0.0 or
+ probability > 1.0):
+ raise ValueError("probability must be in the range (0, 1].")
+
+ normalized = columns
+ if normalized is None:
+ normalized = []
+ if isinstance(normalized, tuple):
+ normalized = list(normalized)
+ for key, value in iteritems(kwargs):
+ if key == "name":
+ continue
+ if isinstance(value, str):
+ normalized.append((key, value))
+ continue
+ for col in value:
+ normalized.append((key, col))
+
+ return _BigtableScanDataset(self, prefix, "", "", normalized, probability)
+
+ def scan_range(self, start, end, probability=None, columns=None, **kwargs):
+ """Retrieves rows (including values) from the Bigtable service.
+
+ Rows with row-keys between `start` and `end` will be retrieved.
+
+ Specifying the columns to retrieve for each row is done by either using
+ kwargs or in the columns parameter. To retrieve values of the columns "c1",
+ and "c2" from the column family "cfa", and the value of the column "c3"
+ from column family "cfb", the following datasets (`ds1`, and `ds2`) are
+ equivalent:
+
+ ```
+ table = # ...
+ ds1 = table.scan_range("row_start", "row_end", columns=[("cfa", "c1"),
+ ("cfa", "c2"),
+ ("cfb", "c3")])
+ ds2 = table.scan_range("row_start", "row_end", cfa=["c1", "c2"], cfb="c3")
+ ```
+
+ Note: only the latest value of a cell will be retrieved.
+
+ Args:
+ start: The start of the range when scanning by range.
+ end: (Optional.) The end of the range when scanning by range.
+ probability: (Optional.) A float in `(0, 1]` giving the probability with
+ which each row is sampled. Defaults to 1 (retrieve all matching rows).
+ columns: The columns to read. Note: most commonly, they are expressed as
+ kwargs. Use the columns value if you are using column families that are
+ reserved. The value of columns and kwargs are merged. Columns is a list
+ of tuples of strings ("column_family", "column_qualifier").
+ **kwargs: The column families and columns to read. Keys are treated as
+ column_families, and values can be either lists of strings, or strings
+ that are treated as the column qualifier (column name).
+
+ Returns:
+ A @{tf.data.Dataset} returning the row keys and the cell contents.
+
+ Raises:
+ ValueError: If the configured probability is unexpected.
+ """
+ if probability is None:
+ probability = 1.0
+ if isinstance(probability, float) and (probability <= 0.0 or
+ probability > 1.0):
+ raise ValueError("probability must be in the range (0, 1].")
+
+ normalized = columns
+ if normalized is None:
+ normalized = []
+ if isinstance(normalized, tuple):
+ normalized = list(normalized)
+ for key, value in iteritems(kwargs):
+ if key == "name":
+ continue
+ if isinstance(value, str):
+ normalized.append((key, value))
+ continue
+ for col in value:
+ normalized.append((key, col))
+
+ return _BigtableScanDataset(self, "", start, end, normalized, probability)
+
+ def write(self, dataset, column_families, columns, timestamp=None):
+ """Writes a dataset to the table.
+
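+ Example usage (an illustrative sketch; `table` is a `BigTable` object, and
+ the column family and column names are placeholders):
+ ```
+ keys = tf.data.Dataset.from_tensor_slices(["r1", "r2"])
+ values = tf.data.Dataset.from_tensor_slices(["v1", "v2"])
+ write_op = table.write(tf.data.Dataset.zip((keys, values)),
+ column_families=["cf1"], columns=["c1"])
+ ```
+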
+ Args:
+ dataset: A @{tf.data.Dataset} to be written to this table. It must produce
+ a list of number-of-columns+1 elements, all of which must be strings.
+ The first value will be used as the row key, and subsequent values will
+ be used as cell values for the corresponding columns from the
+ corresponding column_families and columns entries.
+ column_families: A @{tf.Tensor} of `tf.string`s corresponding to the
+ column family names to store the dataset's elements into.
+ columns: A `tf.Tensor` of `tf.string`s corresponding to the column names
+ to store the dataset's elements into.
+ timestamp: (Optional.) An int64 timestamp to write all the values at.
+ Leave as None to use server-provided timestamps.
+
+ Returns:
+ A @{tf.Operation} that can be run to perform the write.
+
+ Raises:
+ ValueError: If there are unexpected or incompatible types, or if the
+ number of columns and column_families does not match the output of
+ `dataset`.
+ """
+ if timestamp is None:
+ timestamp = -1 # Bigtable server provided timestamp.
+ for tensor_type in nest.flatten(dataset.output_types):
+ if tensor_type != dtypes.string:
+ raise ValueError("Not all elements of the dataset were `tf.string`")
+ for shape in nest.flatten(dataset.output_shapes):
+ if not shape.is_compatible_with(tensor_shape.scalar()):
+ raise ValueError("Not all elements of the dataset were scalars")
+ if len(column_families) != len(columns):
+ raise ValueError("len(column_families) != len(columns)")
+ if len(nest.flatten(dataset.output_types)) != len(columns) + 1:
+ raise ValueError("A column name must be specified for every component of "
+ "the dataset elements. (e.g.: len(columns) != "
+ "len(dataset.output_types))")
+ return gen_bigtable_ops.dataset_to_bigtable(
+ self._resource,
+ dataset._as_variant_tensor(), # pylint: disable=protected-access
+ column_families,
+ columns,
+ timestamp)
+
+
+class _BigtableKeyDataset(dataset_ops.Dataset):
+ """_BigtableKeyDataset is an abstract class representing the keys of a table.
+ """
+
+ def __init__(self, table):
+ """Constructs a _BigtableKeyDataset.
+
+ Args:
+ table: a `BigTable` object.
+ """
+ super(_BigtableKeyDataset, self).__init__()
+ self._table = table
+
+ @property
+ def output_classes(self):
+ return ops.Tensor
+
+ @property
+ def output_shapes(self):
+ return tensor_shape.TensorShape([])
+
+ @property
+ def output_types(self):
+ return dtypes.string
+
+
+class _BigtablePrefixKeyDataset(_BigtableKeyDataset):
+ """_BigtablePrefixKeyDataset represents looking up keys by prefix.
+ """
+
+ def __init__(self, table, prefix):
+ super(_BigtablePrefixKeyDataset, self).__init__(table)
+ self._prefix = prefix
+
+ def _as_variant_tensor(self):
+ return gen_bigtable_ops.bigtable_prefix_key_dataset(
+ table=self._table._resource, # pylint: disable=protected-access
+ prefix=self._prefix)
+
+
+class _BigtableRangeKeyDataset(_BigtableKeyDataset):
+ """_BigtableRangeKeyDataset represents looking up keys by range.
+ """
+
+ def __init__(self, table, start, end):
+ super(_BigtableRangeKeyDataset, self).__init__(table)
+ self._start = start
+ self._end = end
+
+ def _as_variant_tensor(self):
+ return gen_bigtable_ops.bigtable_range_key_dataset(
+ table=self._table._resource, # pylint: disable=protected-access
+ start_key=self._start,
+ end_key=self._end)
+
+
+class _BigtableLookupDataset(dataset_ops.Dataset):
+ """_BigtableLookupDataset represents a dataset that retrieves values for keys.
+ """
+
+ def __init__(self, dataset, table, normalized):
+ self._num_outputs = len(normalized) + 1 # 1 for row key
+ self._dataset = dataset
+ self._table = table
+ self._normalized = normalized
+ self._column_families = [i[0] for i in normalized]
+ self._columns = [i[1] for i in normalized]
+
+ @property
+ def output_classes(self):
+ return tuple([ops.Tensor] * self._num_outputs)
+
+ @property
+ def output_shapes(self):
+ return tuple([tensor_shape.TensorShape([])] * self._num_outputs)
+
+ @property
+ def output_types(self):
+ return tuple([dtypes.string] * self._num_outputs)
+
+ def _as_variant_tensor(self):
+ # pylint: disable=protected-access
+ return gen_bigtable_ops.bigtable_lookup_dataset(
+ keys_dataset=self._dataset._as_variant_tensor(),
+ table=self._table._resource,
+ column_families=self._column_families,
+ columns=self._columns)
+
+
+class _BigtableScanDataset(dataset_ops.Dataset):
+ """_BigtableScanDataset represents a dataset that retrieves keys and values.
+ """
+
+ def __init__(self, table, prefix, start, end, normalized, probability):
+ self._table = table
+ self._prefix = prefix
+ self._start = start
+ self._end = end
+ self._column_families = [i[0] for i in normalized]
+ self._columns = [i[1] for i in normalized]
+ self._probability = probability
+ self._num_outputs = len(normalized) + 1 # 1 for row key
+
+ @property
+ def output_classes(self):
+ return tuple([ops.Tensor] * self._num_outputs)
+
+ @property
+ def output_shapes(self):
+ return tuple([tensor_shape.TensorShape([])] * self._num_outputs)
+
+ @property
+ def output_types(self):
+ return tuple([dtypes.string] * self._num_outputs)
+
+ def _as_variant_tensor(self):
+ return gen_bigtable_ops.bigtable_scan_dataset(
+ table=self._table._resource, # pylint: disable=protected-access
+ prefix=self._prefix,
+ start_key=self._start,
+ end_key=self._end,
+ column_families=self._column_families,
+ columns=self._columns,
+ probability=self._probability)
diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD
index 1a7a3759ba..523a9efcf0 100644
--- a/tensorflow/contrib/cloud/BUILD
+++ b/tensorflow/contrib/cloud/BUILD
@@ -50,6 +50,7 @@ py_library(
deps = [
":gen_bigquery_reader_ops",
":gen_gcs_config_ops",
+ "//tensorflow/contrib/bigtable",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:io_ops",
"//tensorflow/python:util",
diff --git a/tensorflow/contrib/cloud/README.md b/tensorflow/contrib/cloud/README.md
new file mode 100644
index 0000000000..134ce057f4
--- /dev/null
+++ b/tensorflow/contrib/cloud/README.md
@@ -0,0 +1,18 @@
+# Cloud #
+
+## BigTable ##
+
+[Google Cloud BigTable](https://cloud.google.com/bigtable/) is a
+high-performance storage system that can store and serve training data. This
+contrib package contains an experimental integration with TensorFlow.
+
+> **Status: Highly experimental.** The current implementation is very much in
+> flux. Please use at your own risk! :-)
+
+<!-- TODO(saeta): Document usage / methods / etc. -->
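+
+A rough usage sketch (the project, instance, and table names below are
+placeholders, and the API is likely to change):
+
+```python
+from tensorflow.contrib import cloud
+
+client = cloud.BigtableClient("my-project", "my-instance")
+table = client.table("my-table")
+key_dataset = table.keys_by_prefix_dataset("train_")
+```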
+
+## Cloud Storage (GCS) ##
+
+The Google Cloud Storage ops allow the user to configure the GCS File System.
+
+<!-- TODO(saeta): Document usage / methods / etc. -->
diff --git a/tensorflow/contrib/cloud/__init__.py b/tensorflow/contrib/cloud/__init__.py
index ef7aa7624c..af81106a68 100644
--- a/tensorflow/contrib/cloud/__init__.py
+++ b/tensorflow/contrib/cloud/__init__.py
@@ -18,15 +18,24 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# pylint: disable=line-too-long,wildcard-import
+import os
+
+# pylint: disable=line-too-long,wildcard-import,g-import-not-at-top
from tensorflow.contrib.cloud.python.ops.bigquery_reader_ops import *
from tensorflow.contrib.cloud.python.ops.gcs_config_ops import *
-# pylint: enable=line-too-long,wildcard-import
+
+if os.name != 'nt':
+ from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigTable
+ from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableClient
+
+del os
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'BigQueryReader',
+ 'BigTable',
+ 'BigtableClient',
'BlockCacheParams',
'configure_colab_session',
'configure_gcs',
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index d530572e91..8ff6ebedab 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -86,6 +86,8 @@ tensorflow/contrib/batching/python/ops
tensorflow/contrib/bayesflow
tensorflow/contrib/bayesflow/python
tensorflow/contrib/bayesflow/python/ops
+# tensorflow/contrib/bigtable/python
+# tensorflow/contrib/bigtable/python/ops
tensorflow/contrib/boosted_trees
tensorflow/contrib/boosted_trees/estimator_batch
tensorflow/contrib/boosted_trees/kernels
diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake
index 067c299a71..872b016d2b 100644
--- a/tensorflow/contrib/cmake/tf_core_framework.cmake
+++ b/tensorflow/contrib/cmake/tf_core_framework.cmake
@@ -49,43 +49,48 @@ function(RELATIVE_PROTOBUF_GENERATE_CPP SRCS HDRS ROOT_DIR)
set(${HDRS} ${${HDRS}} PARENT_SCOPE)
endfunction()
-if(NOT WIN32)
- function(RELATIVE_PROTOBUF_GENERATE_GRPC_CPP SRCS HDRS ROOT_DIR)
- if(NOT ARGN)
- message(SEND_ERROR "Error: RELATIVE_PROTOBUF_GENERATE_GRPC_CPP() called without any proto files")
- return()
+function(RELATIVE_PROTOBUF_GENERATE_GRPC_CPP SRCS HDRS ROOT_DIR)
+ if(NOT ARGN)
+ message(SEND_ERROR "Error: RELATIVE_PROTOBUF_GENERATE_GRPC_CPP() called without any proto files")
+ return()
+ endif()
+
+ set(${SRCS})
+ set(${HDRS})
+ foreach(FIL ${ARGN})
+ set(ABS_FIL ${ROOT_DIR}/${FIL})
+ get_filename_component(FIL_WE ${FIL} NAME_WE)
+ get_filename_component(FIL_DIR ${ABS_FIL} PATH)
+ file(RELATIVE_PATH REL_DIR ${ROOT_DIR} ${FIL_DIR})
+
+ list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.cc")
+ list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.h")
+ list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc")
+ list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h")
+
+ # We adjust the path of the gRPC code generation plugin accordingly.
+ if(WIN32)
+ set(GRPC_PROTOC_PLUGIN_PATH ${GRPC_BUILD}/Release/grpc_cpp_plugin.exe)
+ else()
+ set(GRPC_PROTOC_PLUGIN_PATH ${GRPC_BUILD}/grpc_cpp_plugin)
endif()
- set(${SRCS})
- set(${HDRS})
- foreach(FIL ${ARGN})
- set(ABS_FIL ${ROOT_DIR}/${FIL})
- get_filename_component(FIL_WE ${FIL} NAME_WE)
- get_filename_component(FIL_DIR ${ABS_FIL} PATH)
- file(RELATIVE_PATH REL_DIR ${ROOT_DIR} ${FIL_DIR})
-
- list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.cc")
- list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.h")
- list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc")
- list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h")
-
- add_custom_command(
- OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.cc"
- "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.h"
- "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc"
- "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h"
- COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
- ARGS --grpc_out ${CMAKE_CURRENT_BINARY_DIR} --cpp_out ${CMAKE_CURRENT_BINARY_DIR} --plugin protoc-gen-grpc=${GRPC_BUILD}/grpc_cpp_plugin -I ${ROOT_DIR} ${ABS_FIL} -I ${PROTOBUF_INCLUDE_DIRS}
- DEPENDS ${ABS_FIL} protobuf grpc
- COMMENT "Running C++ protocol buffer grpc compiler on ${FIL}"
- VERBATIM )
- endforeach()
-
- set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE)
- set(${SRCS} ${${SRCS}} PARENT_SCOPE)
- set(${HDRS} ${${HDRS}} PARENT_SCOPE)
- endfunction()
-endif()
+ add_custom_command(
+ OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.cc"
+ "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.grpc.pb.h"
+ "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.cc"
+ "${CMAKE_CURRENT_BINARY_DIR}/${REL_DIR}/${FIL_WE}.pb.h"
+ COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
+ ARGS --grpc_out ${CMAKE_CURRENT_BINARY_DIR} --cpp_out ${CMAKE_CURRENT_BINARY_DIR} --plugin=protoc-gen-grpc=${GRPC_PROTOC_PLUGIN_PATH} -I ${ROOT_DIR} ${ABS_FIL} -I ${PROTOBUF_INCLUDE_DIRS}
+ DEPENDS ${ABS_FIL} protobuf grpc
+ COMMENT "Running C++ protocol buffer grpc compiler on ${FIL}"
+ VERBATIM )
+ endforeach()
+
+ set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE)
+ set(${SRCS} ${${SRCS}} PARENT_SCOPE)
+ set(${HDRS} ${${HDRS}} PARENT_SCOPE)
+endfunction()
function(RELATIVE_PROTOBUF_TEXT_GENERATE_CPP SRCS HDRS ROOT_DIR)
if(NOT ARGN)
@@ -175,17 +180,14 @@ RELATIVE_PROTOBUF_TEXT_GENERATE_CPP(PROTO_TEXT_SRCS PROTO_TEXT_HDRS
${tensorflow_source_dir} ${tf_proto_text_srcs}
)
-if(WIN32)
- add_library(tf_protos_cc ${PROTO_SRCS} ${PROTO_HDRS})
-else()
- file(GLOB_RECURSE tf_protos_grpc_cc_srcs RELATIVE ${tensorflow_source_dir}
- "${tensorflow_source_dir}/tensorflow/core/debug/*.proto"
- )
- RELATIVE_PROTOBUF_GENERATE_GRPC_CPP(PROTO_GRPC_SRCS PROTO_GRPC_HDRS
- ${tensorflow_source_dir} ${tf_protos_grpc_cc_srcs}
- )
- add_library(tf_protos_cc ${PROTO_GRPC_SRCS} ${PROTO_GRPC_HDRS} ${PROTO_SRCS} ${PROTO_HDRS})
-endif()
+file(GLOB_RECURSE tf_protos_grpc_cc_srcs RELATIVE ${tensorflow_source_dir}
+ "${tensorflow_source_dir}/tensorflow/core/debug/*.proto"
+ "${tensorflow_source_dir}/tensorflow/core/protobuf/master_service.proto"
+)
+RELATIVE_PROTOBUF_GENERATE_GRPC_CPP(PROTO_GRPC_SRCS PROTO_GRPC_HDRS
+ ${tensorflow_source_dir} ${tf_protos_grpc_cc_srcs}
+)
+add_library(tf_protos_cc ${PROTO_GRPC_SRCS} ${PROTO_GRPC_HDRS} ${PROTO_SRCS} ${PROTO_HDRS})
########################################################
# tf_core_lib library
diff --git a/tensorflow/contrib/eager/python/examples/revnet/README.md b/tensorflow/contrib/eager/python/examples/revnet/README.md
new file mode 100644
index 0000000000..21fc44febc
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/README.md
@@ -0,0 +1,45 @@
+# RevNet with TensorFlow eager execution
+
+This folder contains a TensorFlow eager implementation of the [Reversible Residual Network](https://arxiv.org/pdf/1707.04585.pdf) adapted from the implementation released by the authors. The presented implementation can be run in both eager and graph mode. The code is considerably simplified with `tf.GradientTape`. Moreover, we remove the step of reconstructing the outputs, which saves us from using `tf.stop_gradient` and makes the model run faster.
+
+## Content
+
+- `revnet.py`: The RevNet model.
+- `blocks.py`: The relevant reversible blocks.
+- `cifar_tfrecords.py`: Script to generate the TFRecords for both CIFAR-10 and CIFAR-100.
+- `cifar_input.py`: Script to read from TFRecords and generate dataset objects with the `tf.data` API.
+- `config.py`: Configuration file for network architectures and training hyperparameters.
+- `main.py`: Main training and evaluation script.
+- `ops.py`: Auxiliary downsampling operation.
+
+## To run
+- Make sure you have installed TensorFlow 1.9+ or the latest `tf-nightly`
+or `tf-nightly-gpu` pip package in order to access the eager execution feature.
+
+- First run
+
+```bash
+python cifar_tfrecords.py --data_dir ${PWD}/cifar
+```
+to download the CIFAR datasets and convert them
+to TFRecords. This produces TFRecord files for both CIFAR-10 and CIFAR-100.
+
+- To train a model run
+
+```bash
+python main.py --data_dir ${PWD}/cifar
+```
+
+- Optional arguments for `main.py` include the following (see the example after this list):
+ - `train_dir`: Directory to store eventfiles and checkpoints.
+ - `restore`: Restore the latest checkpoint.
+ - `validate`: Use validation set for training monitoring.
+ - `manual_grad`: Use the manually defined gradient map given by the authors.
+ - `dataset`: Use either `cifar-10` or `cifar-100`
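+
+For example, a hypothetical invocation combining these flags (the exact flag
+syntax is assumed and may differ) could look like:
+
+```bash
+python main.py --data_dir ${PWD}/cifar --train_dir /tmp/revnet --dataset cifar-10
+```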
+
+## Performance
+- With the current implementation, RevNet-38 achieves >92% accuracy on CIFAR-10 and >71% accuracy on CIFAR-100.
+
+## Reference
+The Reversible Residual Network: Backpropagation Without Storing Activations.
+Aidan N. Gomez, Mengye Ren, Raquel Urtasun, Roger B. Grosse. Neural Information Processing Systems (NIPS), 2017.
diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java
index bfb4a0a04b..580206943b 100644
--- a/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java
@@ -25,6 +25,8 @@ import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
@@ -54,6 +56,14 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
private static final float H_SCALE = 5.0f;
private static final float W_SCALE = 5.0f;
+ // Float model
+ private static final float IMAGE_MEAN = 128.0f;
+ private static final float IMAGE_STD = 128.0f;
+
+ // Number of threads in the Java app.
+ private static final int NUM_THREADS = 4;
+
+
// Config values.
private int inputSize;
@@ -65,7 +75,7 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
private float[][][] outputLocations;
private float[][][] outputClasses;
- float[][][][] img;
+ private ByteBuffer imgData = null;
private Interpreter tfLite;
@@ -176,9 +186,12 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
}
// Pre-allocate buffers.
- d.img = new float[1][inputSize][inputSize][3];
-
+ int numBytesPerChannel = 4; // Floating point
+ d.imgData = ByteBuffer.allocateDirect(1 * d.inputSize * d.inputSize * 3 * numBytesPerChannel);
+ d.imgData.order(ByteOrder.nativeOrder());
d.intValues = new int[d.inputSize * d.inputSize];
+
+ d.tfLite.setNumThreads(NUM_THREADS);
d.outputLocations = new float[1][NUM_RESULTS][4];
d.outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES];
return d;
@@ -198,10 +211,11 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
for (int i = 0; i < inputSize; ++i) {
for (int j = 0; j < inputSize; ++j) {
- int pixel = intValues[j * inputSize + i];
- img[0][j][i][2] = (float) (pixel & 0xFF) / 128.0f - 1.0f;
- img[0][j][i][1] = (float) ((pixel >> 8) & 0xFF) / 128.0f - 1.0f;
- img[0][j][i][0] = (float) ((pixel >> 16) & 0xFF) / 128.0f - 1.0f;
+ int pixelValue = intValues[i * inputSize + j];
+ // Float model
+ imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
}
}
Trace.endSection(); // preprocessBitmap
@@ -211,7 +225,7 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
outputLocations = new float[1][NUM_RESULTS][4];
outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES];
- Object[] inputArray = {img};
+ Object[] inputArray = {imgData};
Map<Integer, Object> outputMap = new HashMap<>();
outputMap.put(0, outputLocations);
outputMap.put(1, outputClasses);
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 1b8a7205e6..8597707b24 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -59,6 +59,7 @@ using reference_ops::Mean;
using reference_ops::RankOneSelect;
using reference_ops::Relu1;
using reference_ops::Relu6;
+using reference_ops::ReluX;
using reference_ops::Select;
using reference_ops::SpaceToBatchND;
using reference_ops::StridedSlice;
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 16901a3e53..9357e7407e 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -951,6 +951,19 @@ inline void Relu6(const float* input_data, const RuntimeShape& input_shape,
}
}
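+// Clamps each element of a quantized (uint8) tensor to [min_value, max_value].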
+inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data,
+ const RuntimeShape& input_shape, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ gemmlowp::ScopedProfilingLabel label("Quantized ReluX (not fused)");
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ const uint8 val = input_data[i];
+ const uint8 clamped =
+ val > max_value ? max_value : val < min_value ? min_value : val;
+ output_data[i] = clamped;
+ }
+}
+
template <FusedActivationFunctionType Ac>
void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
float* output_data, const RuntimeShape& output_shape) {
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index f54db3af87..c448fb71db 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -991,7 +991,7 @@ TfLiteStatus InterpreterBuilder::operator()(
variables.push_back(i);
}
}
- (**interpreter).SetVariables(variables);
+ (**interpreter).SetVariables(std::move(variables));
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
index e7343cb388..681448be20 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
@@ -20,8 +20,8 @@ limitations under the License.
#include <vector>
// Place `<locale>` before <Python.h> to avoid build failures in macOS.
-#include <locale>
#include <Python.h>
+#include <locale>
// We forward declare TFLite classes here to avoid exposing them to SWIG.
namespace tflite {
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index a4229f91f5..29a1487c1f 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -132,7 +132,7 @@ class TocoConverter(object):
Args:
- graph_def: TensorFlow GraphDef.
+ graph_def: Frozen TensorFlow GraphDef.
input_tensors: List of input tensors. Type and shape are computed using
`foo.get_shape()` and `foo.dtype`.
output_tensors: List of output tensors (only .name is used from this).
@@ -178,7 +178,7 @@ class TocoConverter(object):
"""Creates a TocoConverter class from a file containing a frozen GraphDef.
Args:
- graph_def_file: Full filepath of file containing TensorFlow GraphDef.
+ graph_def_file: Full filepath of file containing frozen GraphDef.
input_arrays: List of input tensors to freeze graph with.
output_arrays: List of output tensors to freeze graph with.
input_shapes: Dict of strings representing input tensor names to list of
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index 0a60477c6d..9bd1f4f76e 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -225,7 +225,7 @@ def run_main(_):
input_file_group.add_argument(
"--graph_def_file",
type=str,
- help="Full filepath of file containing TensorFlow GraphDef.")
+ help="Full filepath of file containing frozen TensorFlow GraphDef.")
input_file_group.add_argument(
"--saved_model_dir",
type=str,
diff --git a/tensorflow/contrib/lite/toco/README.md b/tensorflow/contrib/lite/toco/README.md
index ee83c7a6e3..2db6a627ab 100644
--- a/tensorflow/contrib/lite/toco/README.md
+++ b/tensorflow/contrib/lite/toco/README.md
@@ -17,11 +17,12 @@ Usage information is given in these documents:
Once an application developer has a trained TensorFlow model, TOCO will accept
that model and generate a TensorFlow Lite
[FlatBuffer](https://google.github.io/flatbuffers/) file. TOCO currently supports
-[SavedModels](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators)
-and frozen graphs (models generated via
-[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)).
-The TensorFlow Lite FlatBuffer file can be shipped to client devices, generally
-mobile devices, where the TensorFlow Lite interpreter handles them on-device.
-This flow is represented in the diagram below.
+[SavedModels](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators),
+frozen graphs (models generated via
+[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)),
+and `tf.keras` model files. The TensorFlow Lite FlatBuffer file can be shipped
+to client devices, generally mobile devices, where the TensorFlow Lite
+interpreter handles them on-device. This flow is represented in the diagram
+below.
![drawing](g3doc/toco_landscape.svg)
diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
index 0ab024c618..18b7848db8 100644
--- a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
+++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
@@ -11,8 +11,10 @@ Table of contents:
* [Command-line tools](#tools)
* [Converting models prior to TensorFlow 1.9.](#pre-tensorflow-1.9)
-* [Convert a TensorFlow GraphDef](#graphdef)
-* [Convert a TensorFlow SavedModel](#savedmodel)
+* [Basic examples](#basic)
+ * [Convert a TensorFlow GraphDef](#graphdef)
+ * [Convert a TensorFlow SavedModel](#savedmodel)
+ * [Convert a tf.keras model](#keras)
* [Quantization](#quantization)
* [Convert a TensorFlow GraphDef for quantized inference](#graphdef-quant)
* [Use "dummy-quantization" to try out quantized inference on a float
@@ -51,7 +53,12 @@ API](python_api.md#pre-tensorflow-1.9). If a command line tool is desired, the
Terminal for additional details on the command-line flags available. There were
no command line tools in TensorFlow 1.8.
-## Convert a TensorFlow GraphDef <a name="graphdef"></a>
+## Basic examples <a name="basic"></a>
+
+The following section shows examples of how to convert a basic floating-point
+model from each of the supported data formats into a TensorFlow Lite FlatBuffer.
+
+### Convert a TensorFlow GraphDef <a name="graphdef"></a>
The following example converts a basic TensorFlow GraphDef (frozen by
[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py))
@@ -70,7 +77,7 @@ tflite_convert \
The value for `input_shapes` is automatically determined whenever possible.
-## Convert a TensorFlow SavedModel <a name="savedmodel"></a>
+### Convert a TensorFlow SavedModel <a name="savedmodel"></a>
The following example converts a basic TensorFlow SavedModel into a TensorFlow Lite
FlatBuffer to perform floating-point inference.
@@ -95,6 +102,17 @@ There is currently no support for MetaGraphDefs without a SignatureDef or for
MetaGraphDefs that use the [`assets/`
directory](https://www.tensorflow.org/guide/saved_model#structure_of_a_savedmodel_directory).
+### Convert a tf.keras model <a name="keras"></a>
+
+The following example converts a `tf.keras` model into a TensorFlow Lite
+FlatBuffer. The `tf.keras` file must contain both the model and the weights.
+
+```
+tflite_convert \
+ --output_file=/tmp/foo.tflite \
+ --keras_model_file=/tmp/keras_model.h5
+```
+
## Quantization
### Convert a TensorFlow GraphDef for quantized inference <a name="graphdef-quant"></a>
diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
index 2d44b871c6..decc8a45a4 100644
--- a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
+++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
@@ -19,7 +19,7 @@ Table of contents:
The following high level flags specify the details of the input and output
files. The flag `--output_file` is always required. Additionally, either
-`--graph_def_file` or `--saved_model_dir` is required.
+`--graph_def_file`, `--saved_model_dir`, or `--keras_model_file` is required.
* `--output_file`. Type: string. Specifies the full path of the output file.
* `--graph_def_file`. Type: string. Specifies the full path of the input
@@ -27,6 +27,8 @@ files. The flag `--output_file` is always required. Additionally, either
[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py).
* `--saved_model_dir`. Type: string. Specifies the full path to the directory
containing the SavedModel.
+* `--keras_model_file`. Type: string. Specifies the full path of the HDF5 file
+ containing the tf.keras model.
* `--output_format`. Type: string. Default: `TFLITE`. Specifies the format of
the output file. Allowed values:
* `TFLITE`: TensorFlow Lite FlatBuffer format.
diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md
index b04d166f89..3799eac0a1 100644
--- a/tensorflow/contrib/lite/toco/g3doc/python_api.md
+++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md
@@ -41,9 +41,11 @@ is `tf.contrib.lite.TocoConverter`. The API for calling the Python intepreter is
`TocoConverter` provides class methods based on the original format of the
model. `TocoConverter.from_session()` is available for GraphDefs.
-`TocoConverter.from_saved_model()` is available for SavedModels. Example usages
-for simple float-point models are shown in [Basic Examples](#basic). Examples
-usages for more complex models is shown in [Complex Examples](#complex).
+`TocoConverter.from_saved_model()` is available for SavedModels.
+`TocoConverter.from_keras_model_file()` is available for `tf.keras` files.
+Example usages for simple floating-point models are shown in [Basic
+Examples](#basic). Example usages for more complex models are shown in
+[Complex Examples](#complex).
**NOTE**: Currently, `TocoConverter` will cause a fatal error to the Python
interpreter when the conversion fails. This will be remedied as soon as
@@ -117,7 +119,7 @@ available by running `help(tf.contrib.lite.TocoConverter)`.
### Exporting a tf.keras File <a name="basic-keras-file"></a>
-The following example shows how to convert a tf.keras model into a TensorFlow
+The following example shows how to convert a `tf.keras` model into a TensorFlow
Lite FlatBuffer.
```python
@@ -128,7 +130,7 @@ tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
```
-The tf.keras file must contain both the model and the weights. A comprehensive
+The `tf.keras` file must contain both the model and the weights. A comprehensive
example including model construction can be seen below.
```python
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index 38699a62b5..58885b4950 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -59,7 +59,8 @@ bool SupportsQuantization(const Operator& op) {
type == OperatorType::kGreater ||
type == OperatorType::kGreaterEqual || type == OperatorType::kLess ||
type == OperatorType::kLessEqual || type == OperatorType::kSelect ||
- type == OperatorType::kArgMax;
+ type == OperatorType::kArgMax || type == OperatorType::kRelu ||
+ type == OperatorType::kRelu1 || type == OperatorType::kRelu6;
}
const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
@@ -325,12 +326,13 @@ bool ChooseQuantizationForOperatorOutput(
output, OperatorTypeName(op.type));
return true;
}
- if ((op.type == OperatorType::kDepthToSpace) ||
- (op.type == OperatorType::kSpaceToDepth) ||
- (op.type == OperatorType::kReshape) ||
- (op.type == OperatorType::kSplit) ||
- (op.type == OperatorType::kConcatenation &&
- model->flags.change_concat_input_ranges())) {
+ if ((op.type == OperatorType::kConcatenation &&
+ model->flags.change_concat_input_ranges()) ||
+ op.type == OperatorType::kDepthToSpace ||
+ op.type == OperatorType::kSpaceToDepth ||
+ op.type == OperatorType::kReshape || op.type == OperatorType::kSplit ||
+ op.type == OperatorType::kRelu || op.type == OperatorType::kRelu1 ||
+ op.type == OperatorType::kRelu6) {
int data_input_index = 0;
if (op.type == OperatorType::kSplit) {
data_input_index = 1;
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index 89db9ee279..6e7423f85e 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -92,6 +92,7 @@ tensorflow/core/kernels/reduction_ops_common.cc
tensorflow/core/kernels/reduction_ops_any.cc
tensorflow/core/kernels/reduction_ops_all.cc
tensorflow/core/kernels/roll_op.cc
+tensorflow/core/kernels/queue_op.cc
tensorflow/core/kernels/queue_ops.cc
tensorflow/core/kernels/queue_base.cc
tensorflow/core/kernels/pooling_ops_common.cc
diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py
index 157ed6a278..3e63e99030 100644
--- a/tensorflow/contrib/opt/__init__.py
+++ b/tensorflow/contrib/opt/__init__.py
@@ -22,17 +22,18 @@ from __future__ import print_function
from tensorflow.contrib.opt.python.training.adamax import *
from tensorflow.contrib.opt.python.training.addsign import *
from tensorflow.contrib.opt.python.training.drop_stale_gradient_optimizer import *
+from tensorflow.contrib.opt.python.training.elastic_average_optimizer import *
from tensorflow.contrib.opt.python.training.external_optimizer import *
+from tensorflow.contrib.opt.python.training.ggt import *
from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import *
+from tensorflow.contrib.opt.python.training.model_average_optimizer import *
from tensorflow.contrib.opt.python.training.moving_average_optimizer import *
from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import *
from tensorflow.contrib.opt.python.training.nadam_optimizer import *
from tensorflow.contrib.opt.python.training.weight_decay_optimizers import *
from tensorflow.contrib.opt.python.training.powersign import *
from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import *
-from tensorflow.contrib.opt.python.training.elastic_average_optimizer import *
-from tensorflow.contrib.opt.python.training.model_average_optimizer import *
-from tensorflow.contrib.opt.python.training.ggt import *
+from tensorflow.contrib.opt.python.training.weight_decay_optimizers import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
index 8aa40aeb45..b9cf40eb7b 100644
--- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
+++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
@@ -19,13 +19,13 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
-from tensorflow.python.training import optimizer
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.training import adam
from tensorflow.python.training import momentum as momentum_opt
+from tensorflow.python.training import optimizer
from tensorflow.python.util.tf_export import tf_export
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import resource_variable_ops
class DecoupledWeightDecayExtension(object):
@@ -65,7 +65,7 @@ class DecoupledWeightDecayExtension(object):
Args:
weight_decay: A `Tensor` or a floating point value, the factor by which
a variable is decayed in the update step.
- decay_var_list: Optional list or tuple or set of `Variable` objects to
+ **kwargs: Optional list or tuple or set of `Variable` objects to
decay.
"""
self._decay_var_list = None # is set in minimize or apply_gradients
@@ -85,6 +85,28 @@ class DecoupledWeightDecayExtension(object):
If decay_var_list is None, all variables in var_list are decayed.
For more information see the documentation of Optimizer.minimize.
+
+ Args:
+ loss: A `Tensor` containing the value to minimize.
+ global_step: Optional `Variable` to increment by one after the
+ variables have been updated.
+ var_list: Optional list or tuple of `Variable` objects to update to
+ minimize `loss`. Defaults to the list of variables collected in
+ the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
+ gate_gradients: How to gate the computation of gradients. Can be
+ `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
+ aggregation_method: Specifies the method used to combine gradient terms.
+ Valid values are defined in the class `AggregationMethod`.
+ colocate_gradients_with_ops: If True, try colocating gradients with
+ the corresponding op.
+ name: Optional name for the returned operation.
+ grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
+ decay_var_list: Optional list of decay variables.
+
+ Returns:
+ An Operation that updates the variables in `var_list`. If `global_step`
+ was not `None`, that operation also increments `global_step`.
+
"""
self._decay_var_list = set(decay_var_list) if decay_var_list else False
return super(DecoupledWeightDecayExtension, self).minimize(
@@ -103,6 +125,19 @@ class DecoupledWeightDecayExtension(object):
are decayed.
For more information see the documentation of Optimizer.apply_gradients.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs as returned by
+ `compute_gradients()`.
+ global_step: Optional `Variable` to increment by one after the
+ variables have been updated.
+ name: Optional name for the returned operation. Default to the
+ name passed to the `Optimizer` constructor.
+ decay_var_list: Optional list of decay variables.
+
+ Returns:
+ An `Operation` that applies the specified gradients. If `global_step`
+ was not None, that operation also increments `global_step`.
"""
self._decay_var_list = set(decay_var_list) if decay_var_list else False
return super(DecoupledWeightDecayExtension, self).apply_gradients(
@@ -197,6 +232,7 @@ def extend_with_decoupled_weight_decay(base_optimizer):
A new optimizer class that inherits from DecoupledWeightDecayExtension
and base_optimizer.
"""
+
class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
base_optimizer):
"""Base_optimizer with decoupled weight decay.
diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py
index 74d1cdbbda..76d8a5697a 100644
--- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py
+++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.contrib.opt.python.training import weight_decay_optimizers
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -29,7 +30,6 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adam
-from tensorflow.contrib.opt.python.training import weight_decay_optimizers
WEIGHT_DECAY = 0.01
@@ -91,7 +91,6 @@ class WeightDecayOptimizerTest(test.TestCase):
opt = optimizer()
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
-
if not context.executing_eagerly():
with ops.Graph().as_default():
# Shouldn't return non-slot variables from other graphs.
@@ -171,9 +170,9 @@ class ExtendWithWeightDecayTest(WeightDecayOptimizerTest):
@staticmethod
def get_optimizer():
- AdamW = weight_decay_optimizers.extend_with_decoupled_weight_decay(
+ adamw = weight_decay_optimizers.extend_with_decoupled_weight_decay(
adam.AdamOptimizer)
- return AdamW(WEIGHT_DECAY)
+ return adamw(WEIGHT_DECAY)
def testBasic(self):
self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m",
@@ -185,6 +184,5 @@ class ExtendWithWeightDecayTest(WeightDecayOptimizerTest):
use_resource=True)
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py
index e69725ff8a..f58268eff5 100644
--- a/tensorflow/contrib/seq2seq/python/ops/decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import abc
import six
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -182,19 +183,20 @@ def dynamic_decode(decoder,
raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
type(decoder))
- def _is_xla_tensor(tensor):
- try:
- op = tensor.op
- except AttributeError:
- return False
- if control_flow_util.IsInXLAContext(op):
- return True
- return False
-
with variable_scope.variable_scope(scope, "decoder") as varscope:
- # Properly cache variable values inside the while_loop
- if varscope.caching_device is None:
- varscope.set_caching_device(lambda op: op.device)
+ # Determine context types.
+ ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
+ is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None
+ in_while_loop = (
+ control_flow_util.GetContainingWhileContext(ctxt) is not None)
+ # Properly cache variable values inside the while_loop.
+ # Don't set a caching device when running in a loop, since it is possible
+ # that train steps could be wrapped in a tf.while_loop. In that scenario
+ # caching prevents forward computations in loop iterations from re-reading
+ # the updated weights.
+ if not context.executing_eagerly() and not in_while_loop:
+ if varscope.caching_device is None:
+ varscope.set_caching_device(lambda op: op.device)
if maximum_iterations is not None:
maximum_iterations = ops.convert_to_tensor(
@@ -208,9 +210,6 @@ def dynamic_decode(decoder,
decoder.output_dtype,
decoder.batch_size)
- is_xla = False
- if any([_is_xla_tensor(i) for i in nest.flatten(initial_inputs)]):
- is_xla = True
if is_xla and maximum_iterations is None:
raise ValueError("maximum_iterations is required for XLA compilation.")
if maximum_iterations is not None:
diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py
index 3d0308aaf3..2c97834523 100644
--- a/tensorflow/contrib/slim/python/slim/evaluation_test.py
+++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py
@@ -33,7 +33,6 @@ from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.wrappers import hooks
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
@@ -242,7 +241,7 @@ class SingleEvaluationTest(test.TestCase):
checkpoint_path = os.path.join(self.get_temp_dir(),
'this_file_doesnt_exist')
log_dir = os.path.join(self.get_temp_dir(), 'error_raised')
- with self.assertRaises(errors.NotFoundError):
+ with self.assertRaises(ValueError):
evaluation.evaluate_once('', checkpoint_path, log_dir)
def _prepareCheckpoint(self, checkpoint_path):
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index 13986127ba..4dc1c551cc 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -142,7 +142,7 @@ tensorflow::Status ConvertCalibGraphToInferGraph(
auto n = infer_graph->mutable_node(i);
if (n->op() == "TRTEngineOp") {
VLOG(1) << "Processing " << n->name();
- const string& container_name = n->attr().at("segment_funcdef_name").s();
+ string container_name = n->attr().at("segment_funcdef_name").s();
TRTCalibrationResource* cres = nullptr;
auto status = calib_rm->Lookup(container_name, "Calibrator", &cres);
if (!status.ok()) {
@@ -168,7 +168,6 @@ tensorflow::Status ConvertCalibGraphToInferGraph(
"Can't get TRTCalibrator from resource manager!");
}
cres->Unref();
- calib_rm->Cleanup(container_name);
}
}
return tensorflow::Status::OK();
@@ -823,8 +822,8 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) {
} else {
// Graph is not modified.
LOG(WARNING) << "Engine creation for segment " << i << ", composed of "
- << converted_segments.at(i).first.size() << " nodes failed: "
- << status << ". Skipping...";
+ << converted_segments.at(i).first.size()
+ << " nodes failed: " << status << ". Skipping...";
}
}
cudaSetDevice(old_cuda_device);
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
index 7684d8d4a2..1a4c0e755d 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -46,8 +46,8 @@ const int INT8MODE = 2;
struct EngineConnection {
EngineConnection(const string& outside, int out_id, int out_port,
- const string& inside, int in_id, int in_port,
- bool input_edge, int port)
+ const string& inside, int in_id, int in_port,
+ bool input_edge, int port)
: outside_node_name(outside),
outside_id(out_id),
outside_port(out_port),
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 75e32559bb..8a17eb02f1 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -319,7 +319,7 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
default:
LOG(ERROR) << "Unknown TRT data type: " << int(dtype);
ctx->SetStatus(tensorflow::errors::InvalidArgument(
- "Unknown ouput TRT data type! ", static_cast<int>(dtype)));
+ "Unknown output TRT data type! ", static_cast<int>(dtype)));
return;
}
}
@@ -327,8 +327,8 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
for (int i = 0; i < ctx->num_outputs(); i++) {
// Create an output tensor
const string output_name = StrCat(kOutputPHName, i);
- const size_t binding_index = trt_engine_ptr->getBindingIndex(
- output_name.c_str());
+ const size_t binding_index =
+ trt_engine_ptr->getBindingIndex(output_name.c_str());
Tensor* output_tensor = nullptr;
TensorShape output_shape;
@@ -371,7 +371,7 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
default:
LOG(ERROR) << "Unknown TRT data type: " << static_cast<int>(dtype);
ctx->SetStatus(tensorflow::errors::InvalidArgument(
- "Unsupported output data type! ", int(dtype)));
+ "Unsupported output data type! ", static_cast<int>(dtype)));
return;
}
}
@@ -420,10 +420,10 @@ nvinfer1::IGpuAllocator* TRTEngineOp::GetAllocator(OpKernelContext* ctx) {
}
TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size,
- OpKernelContext* ctx) {
+ OpKernelContext* ctx) {
static EngineCtxPair null_pair = {
- TrtUniquePtrType<nvinfer1::ICudaEngine>(nullptr),
- TrtUniquePtrType<nvinfer1::IExecutionContext>(nullptr)};
+ TrtUniquePtrType<nvinfer1::ICudaEngine>(nullptr),
+ TrtUniquePtrType<nvinfer1::IExecutionContext>(nullptr)};
// TODO(sami): This method needs to be re-written to use resource manager and
// with LRU mechanism option.
tensorflow::mutex_lock lock(engine_mutex_);
@@ -450,9 +450,9 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size,
auto raw_static_engine = static_engine.get();
const auto max_batch_size = raw_static_engine->getMaxBatchSize();
engine_map_[max_batch_size] = {
- std::move(static_engine),
- TrtUniquePtrType<nvinfer1::IExecutionContext>(
- raw_static_engine->createExecutionContext())};
+ std::move(static_engine),
+ TrtUniquePtrType<nvinfer1::IExecutionContext>(
+ raw_static_engine->createExecutionContext())};
// Runtime is safe to delete after engine creation
serialized_segment_.clear();
if (max_batch_size < batch_size) return null_pair;
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i
index d6628cd1eb..d51a0b59e2 100644
--- a/tensorflow/contrib/tensorrt/trt_conversion.i
+++ b/tensorflow/contrib/tensorrt/trt_conversion.i
@@ -221,26 +221,22 @@ std::pair<string, string> calib_convert(
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
}
-version_struct get_linked_tensorrt_version() {
+version_struct get_linked_tensorrt_version(){
// Return the version at the link time.
- version_struct s;
-#if GOOGLE_CUDA && GOOGLE_TENSORRT
const auto &lv = tensorflow::tensorrt::convert::GetLinkedTensorRTVersion();
+ version_struct s;
s.vmajor = lv[0];
s.vminor = lv[1];
s.vpatch = lv[2];
-#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
return s;
}
version_struct get_loaded_tensorrt_version(){
// Return the version from the loaded library.
- version_struct s;
-#if GOOGLE_CUDA && GOOGLE_TENSORRT
const auto &lv = tensorflow::tensorrt::convert::GetLoadedTensorRTVersion();
+ version_struct s;
s.vmajor = lv[0];
s.vminor = lv[1];
s.vpatch = lv[2];
-#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
return s;
}
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 5210139336..14e025973e 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -81,12 +81,17 @@ _TPU_ESTIMATOR = 'tpu_estimator'
_ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop'
_BATCH_SIZE_KEY = 'batch_size'
_CTX_KEY = 'context'
+_USE_TPU_KEY = 'use_tpu'
_CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'
_ONE_GIGABYTE = 1024 * 1024 * 1024
_TPU_ENQUEUE_OPS = '_tpu_enqueue_ops'
_TPU_TRAIN_OP = '_tpu_train_op'
_REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference'
+# Ideally, _USE_TPU_KEY should be reserved as well. However, there are already
+# models that make use of this key, so it cannot be reserved now without
+# breaking them. In the long run, we would like to mitigate this by migrating
+# models off of _USE_TPU_KEY.
_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY]
@@ -1414,8 +1419,11 @@ class _ModelFnWrapper(object):
if batch_size_for_model_fn is not None:
_add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn)
+ running_on_cpu = self._ctx.is_running_on_cpu(is_export_mode)
+ _add_item_to_params(params, _USE_TPU_KEY, not running_on_cpu)
+
estimator_spec = self._model_fn(features=features, **kwargs)
- if (self._ctx.is_running_on_cpu(is_export_mode) and
+ if (running_on_cpu and
isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)): # pylint: disable=protected-access
# The estimator_spec will be passed to `Estimator` directly, which expects
# type `EstimatorSpec`.
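Since `params['use_tpu']` is now populated automatically, a model_fn can branch on it. The sketch below is purely illustrative (toy layers and a made-up feature name, not part of this change) of how that flag might be consumed.

```python
import tensorflow as tf

def model_fn(features, labels, mode, params):
  # TPUEstimator now fills in params['use_tpu'] (the _USE_TPU_KEY above),
  # so the model_fn can adapt, e.g. by wrapping its optimizer for TPU.
  logits = tf.layers.dense(features["x"], 10)
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  if params.get("use_tpu"):
    optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
  return tf.contrib.tpu.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op)
```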
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index c1efc9c0c6..0e6bc03c0b 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1923,6 +1923,7 @@ tf_proto_library_cc(
srcs = ["protobuf/master_service.proto"],
has_services = 1,
cc_api_version = 2,
+ cc_grpc_version = 1,
cc_stubby_versions = ["2"],
protodeps = [":master_proto"],
visibility = [
diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc
index 477a0b670e..6149e5fca8 100644
--- a/tensorflow/core/api_def/api_test.cc
+++ b/tensorflow/core/api_def/api_test.cc
@@ -171,7 +171,7 @@ TEST_F(BaseApiTest, AllOpsAreInApiDef) {
if (excluded_ops->find(op.name()) != excluded_ops->end()) {
continue;
}
- ASSERT_TRUE(api_defs_map_.find(op.name()) != api_defs_map_.end())
+ EXPECT_TRUE(api_defs_map_.find(op.name()) != api_defs_map_.end())
<< op.name() << " op does not have api_def_*.pbtxt file. "
<< "Please add api_def_" << op.name() << ".pbtxt file "
<< "under tensorflow/core/api_def/base_api/ directory.";
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 87ba609dd7..f903faf1bd 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -1626,15 +1626,6 @@ Status DirectSession::MakeCallable(const CallableOptions& callable_options,
TF_RETURN_IF_ERROR(CheckNotClosed());
TF_RETURN_IF_ERROR(CheckGraphCreated("MakeCallable()"));
- if (!callable_options.run_options()
- .debug_options()
- .debug_tensor_watch_opts()
- .empty()) {
- return errors::Unimplemented(
- "Debug options are not currently supported via the C++ MakeCallable "
- "interface.");
- }
-
std::unique_ptr<ExecutorsAndKeys> ek;
std::unique_ptr<FunctionInfo> func_info;
RunStateArgs run_state_args(callable_options.run_options().debug_options());
diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
index b4bf1c408f..0b096a14a3 100644
--- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
@@ -106,24 +106,24 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
EXPECT_EQ(1, shape.dim(1).size());
if (node->name() == y->name()) {
#ifdef INTEL_MKL
- // if MKL is used, it goes through various additional
- // graph rewrite pass. In TF, everytime a graph pass
+ // if MKL is used, it goes through various additional
+ // graph rewrite pass. In TF, everytime a graph pass
// happens, "constant" nodes are allocated
// and deallocated. Each allocation calls the
// (FindChunkPtr of BFCAllocator),
- // which increments the value of AllocationId.
- // Thus AllocationId becomes more than TF if MKL
- // is used. Now IDs for MKL are 8 more than TF.
+ // which increments the value of AllocationId.
+ // Thus AllocationId becomes more than TF if MKL
+ // is used. Now IDs for MKL are 8 more than TF.
EXPECT_EQ(29, cm->AllocationId(node, 0));
#else
EXPECT_EQ(21, cm->AllocationId(node, 0));
-#endif
+#endif
} else {
#ifdef INTEL_MKL
EXPECT_EQ(30, cm->AllocationId(node, 0));
#else
EXPECT_EQ(22, cm->AllocationId(node, 0));
-#endif
+#endif
}
}
EXPECT_LE(0, cm->MaxExecutionTime(node));
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 0abef01a9a..75f8a19e9c 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -636,12 +636,12 @@ tf_cuda_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:master_proto_cc",
+ "//tensorflow/core:master_service_proto_cc",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
- "//tensorflow/core/distributed_runtime/rpc:grpc_master_service_impl",
"//tensorflow/core/distributed_runtime/rpc:grpc_testlib",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
diff --git a/tensorflow/core/distributed_runtime/master_test.cc b/tensorflow/core/distributed_runtime/master_test.cc
index 62b18a45b1..09e96cbd40 100644
--- a/tensorflow/core/distributed_runtime/master_test.cc
+++ b/tensorflow/core/distributed_runtime/master_test.cc
@@ -21,7 +21,6 @@ limitations under the License.
#include "grpcpp/grpcpp.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/allocator.h"
@@ -38,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/master.pb.h"
+#include "tensorflow/core/protobuf/master_service.grpc.pb.h"
namespace tensorflow {
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index 4a10d99a60..d6c493c022 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -201,11 +201,11 @@ cc_library(
srcs = ["grpc_remote_master.cc"],
hdrs = ["grpc_remote_master.h"],
deps = [
- ":grpc_master_service_impl",
":grpc_util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:master_proto_cc",
+ "//tensorflow/core:master_service_proto_cc",
"//tensorflow/core/distributed_runtime:call_options",
"//tensorflow/core/distributed_runtime:master_interface",
],
@@ -219,28 +219,18 @@ cc_library(
deps = [
":async_service_interface",
":grpc_call",
- ":grpc_master_service_impl",
":grpc_util",
"//tensorflow:grpc++",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:master_proto_cc",
+ "//tensorflow/core:master_service_proto_cc",
"//tensorflow/core/distributed_runtime:master",
],
alwayslink = 1,
)
cc_library(
- name = "grpc_master_service_impl",
- srcs = ["grpc_master_service_impl.cc"],
- hdrs = ["grpc_master_service_impl.h"],
- deps = [
- "//tensorflow:grpc++",
- "//tensorflow/core:master_proto_cc",
- ],
-)
-
-cc_library(
name = "rpc_rendezvous_mgr",
srcs = ["rpc_rendezvous_mgr.cc"],
hdrs = ["rpc_rendezvous_mgr.h"],
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
index 127dea2882..2c2c1d484a 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
@@ -36,12 +36,12 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/master.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/protobuf/master.pb.h"
+#include "tensorflow/core/protobuf/master_service.grpc.pb.h"
namespace tensorflow {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
deleted file mode 100644
index 770a0fcf14..0000000000
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
+++ /dev/null
@@ -1,164 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
-
-#include "grpcpp/impl/codegen/async_stream.h"
-#include "grpcpp/impl/codegen/async_unary_call.h"
-#include "grpcpp/impl/codegen/channel_interface.h"
-#include "grpcpp/impl/codegen/client_unary_call.h"
-#include "grpcpp/impl/codegen/method_handler_impl.h"
-#include "grpcpp/impl/codegen/rpc_service_method.h"
-#include "grpcpp/impl/codegen/service_type.h"
-#include "grpcpp/impl/codegen/sync_stream.h"
-
-namespace tensorflow {
-
-namespace grpc {
-
-static const char* grpcMasterService_method_names[] = {
- "/tensorflow.MasterService/CreateSession",
- "/tensorflow.MasterService/ExtendSession",
- "/tensorflow.MasterService/PartialRunSetup",
- "/tensorflow.MasterService/RunStep",
- "/tensorflow.MasterService/CloseSession",
- "/tensorflow.MasterService/ListDevices",
- "/tensorflow.MasterService/Reset",
- "/tensorflow.MasterService/MakeCallable",
- "/tensorflow.MasterService/RunCallable",
- "/tensorflow.MasterService/ReleaseCallable",
-};
-
-std::unique_ptr<MasterService::Stub> MasterService::NewStub(
- const std::shared_ptr< ::grpc::ChannelInterface>& channel,
- const ::grpc::StubOptions& options) {
- std::unique_ptr<MasterService::Stub> stub(new MasterService::Stub(channel));
- return stub;
-}
-
-MasterService::Stub::Stub(
- const std::shared_ptr< ::grpc::ChannelInterface>& channel)
- : channel_(channel),
- rpcmethod_CreateSession_(grpcMasterService_method_names[0],
- ::grpc::internal::RpcMethod::NORMAL_RPC,
- channel),
- rpcmethod_ExtendSession_(grpcMasterService_method_names[1],
- ::grpc::internal::RpcMethod::NORMAL_RPC,
- channel),
- rpcmethod_PartialRunSetup_(grpcMasterService_method_names[2],
- ::grpc::internal::RpcMethod::NORMAL_RPC,
- channel),
- rpcmethod_RunStep_(grpcMasterService_method_names[3],
- ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
- rpcmethod_CloseSession_(grpcMasterService_method_names[4],
- ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
- rpcmethod_ListDevices_(grpcMasterService_method_names[5],
- ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
- rpcmethod_Reset_(grpcMasterService_method_names[6],
- ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
- rpcmethod_MakeCallable_(grpcMasterService_method_names[7],
- ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
- rpcmethod_RunCallable_(grpcMasterService_method_names[8],
- ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
- rpcmethod_ReleaseCallable_(grpcMasterService_method_names[9],
- ::grpc::internal::RpcMethod::NORMAL_RPC,
- channel) {}
-
-::grpc::Status MasterService::Stub::CreateSession(
- ::grpc::ClientContext* context, const CreateSessionRequest& request,
- CreateSessionResponse* response) {
- return ::grpc::internal::BlockingUnaryCall(
- channel_.get(), rpcmethod_CreateSession_, context, request, response);
-}
-
-::grpc::Status MasterService::Stub::ExtendSession(
- ::grpc::ClientContext* context, const ExtendSessionRequest& request,
- ExtendSessionResponse* response) {
- return ::grpc::internal::BlockingUnaryCall(
- channel_.get(), rpcmethod_ExtendSession_, context, request, response);
-}
-
-::grpc::Status MasterService::Stub::PartialRunSetup(
- ::grpc::ClientContext* context, const PartialRunSetupRequest& request,
- PartialRunSetupResponse* response) {
- return ::grpc::internal::BlockingUnaryCall(
- channel_.get(), rpcmethod_PartialRunSetup_, context, request, response);
-}
-
-::grpc::Status MasterService::Stub::RunStep(::grpc::ClientContext* context,
- const RunStepRequest& request,
- RunStepResponse* response) {
- return ::grpc::internal::BlockingUnaryCall(channel_.get(), rpcmethod_RunStep_,
- context, request, response);
-}
-
-::grpc::Status MasterService::Stub::CloseSession(
- ::grpc::ClientContext* context, const CloseSessionRequest& request,
- CloseSessionResponse* response) {
- return ::grpc::internal::BlockingUnaryCall(
- channel_.get(), rpcmethod_CloseSession_, context, request, response);
-}
-
-::grpc::Status MasterService::Stub::ListDevices(
- ::grpc::ClientContext* context, const ListDevicesRequest& request,
- ListDevicesResponse* response) {
- return ::grpc::internal::BlockingUnaryCall(
- channel_.get(), rpcmethod_ListDevices_, context, request, response);
-}
-
-::grpc::Status MasterService::Stub::Reset(::grpc::ClientContext* context,
- const ResetRequest& request,
- ResetResponse* response) {
- return ::grpc::internal::BlockingUnaryCall(channel_.get(), rpcmethod_Reset_,
- context, request, response);
-}
-
-::grpc::Status MasterService::Stub::MakeCallable(
- ::grpc::ClientContext* context, const MakeCallableRequest& request,
- MakeCallableResponse* response) {
- return ::grpc::internal::BlockingUnaryCall(
- channel_.get(), rpcmethod_MakeCallable_, context, request, response);
-}
-
-::grpc::Status MasterService::Stub::RunCallable(
- ::grpc::ClientContext* context, const RunCallableRequest& request,
- RunCallableResponse* response) {
- return ::grpc::internal::BlockingUnaryCall(
- channel_.get(), rpcmethod_RunCallable_, context, request, response);
-}
-
-::grpc::Status MasterService::Stub::ReleaseCallable(
- ::grpc::ClientContext* context, const ReleaseCallableRequest& request,
- ReleaseCallableResponse* response) {
- return ::grpc::internal::BlockingUnaryCall(
- channel_.get(), rpcmethod_ReleaseCallable_, context, request, response);
-}
-
-MasterService::AsyncService::AsyncService() {
- int method_len = sizeof(grpcMasterService_method_names) /
- sizeof(grpcMasterService_method_names[0]);
- for (int i = 0; i < method_len; ++i) {
- AddMethod(new ::grpc::internal::RpcServiceMethod(
- grpcMasterService_method_names[i],
- ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
- ::grpc::Service::MarkMethodAsync(i);
- }
-}
-
-MasterService::AsyncService::~AsyncService() {}
-
-} // namespace grpc
-
-} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
deleted file mode 100644
index 751f2633e7..0000000000
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
+++ /dev/null
@@ -1,224 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_
-#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_
-
-#include "grpcpp/impl/codegen/async_stream.h"
-#include "grpcpp/impl/codegen/async_unary_call.h"
-#include "grpcpp/impl/codegen/proto_utils.h"
-#include "grpcpp/impl/codegen/rpc_method.h"
-#include "grpcpp/impl/codegen/service_type.h"
-#include "grpcpp/impl/codegen/status.h"
-#include "grpcpp/impl/codegen/stub_options.h"
-#include "grpcpp/impl/codegen/sync_stream.h"
-
-#include "tensorflow/core/protobuf/master.pb.h"
-
-namespace grpc {
-class CompletionQueue;
-class Channel;
-class RpcService;
-class ServerCompletionQueue;
-class ServerContext;
-} // namespace grpc
-
-namespace tensorflow {
-
-namespace grpc {
-
-// Implementation of `tensorflow.MasterService`, based on the
-// definition in "//tensorflow/core/protobuf/master_service.proto",
-// and the gRPC generated stub and service classes.
-// See that file for the definition of methods and messages.
-class MasterService final {
- public:
- class StubInterface {
- public:
- virtual ~StubInterface() {}
- virtual ::grpc::Status CreateSession(::grpc::ClientContext* context,
- const CreateSessionRequest& request,
- CreateSessionResponse* response) = 0;
- virtual ::grpc::Status ExtendSession(::grpc::ClientContext* context,
- const ExtendSessionRequest& request,
- ExtendSessionResponse* response) = 0;
- virtual ::grpc::Status PartialRunSetup(
- ::grpc::ClientContext* context, const PartialRunSetupRequest& request,
- PartialRunSetupResponse* response) = 0;
- virtual ::grpc::Status RunStep(::grpc::ClientContext* context,
- const RunStepRequest& request,
- RunStepResponse* response) = 0;
- virtual ::grpc::Status CloseSession(::grpc::ClientContext* context,
- const CloseSessionRequest& request,
- CloseSessionResponse* response) = 0;
- virtual ::grpc::Status ListDevices(::grpc::ClientContext* context,
- const ListDevicesRequest& request,
- ListDevicesResponse* response) = 0;
- virtual ::grpc::Status Reset(::grpc::ClientContext* context,
- const ResetRequest& request,
- ResetResponse* response) = 0;
- virtual ::grpc::Status MakeCallable(::grpc::ClientContext* context,
- const MakeCallableRequest& request,
- MakeCallableResponse* response) = 0;
- virtual ::grpc::Status RunCallable(::grpc::ClientContext* context,
- const RunCallableRequest& request,
- RunCallableResponse* response) = 0;
- virtual ::grpc::Status ReleaseCallable(
- ::grpc::ClientContext* context, const ReleaseCallableRequest& request,
- ReleaseCallableResponse* response) = 0;
- };
- class Stub final : public StubInterface {
- public:
- Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel);
- ::grpc::Status CreateSession(::grpc::ClientContext* context,
- const CreateSessionRequest& request,
- CreateSessionResponse* response) override;
- ::grpc::Status ExtendSession(::grpc::ClientContext* context,
- const ExtendSessionRequest& request,
- ExtendSessionResponse* response) override;
- ::grpc::Status PartialRunSetup(::grpc::ClientContext* context,
- const PartialRunSetupRequest& request,
- PartialRunSetupResponse* response) override;
- ::grpc::Status RunStep(::grpc::ClientContext* context,
- const RunStepRequest& request,
- RunStepResponse* response) override;
- ::grpc::Status CloseSession(::grpc::ClientContext* context,
- const CloseSessionRequest& request,
- CloseSessionResponse* response) override;
- ::grpc::Status ListDevices(::grpc::ClientContext* context,
- const ListDevicesRequest& request,
- ListDevicesResponse* response) override;
- ::grpc::Status Reset(::grpc::ClientContext* context,
- const ResetRequest& request,
- ResetResponse* response) override;
- ::grpc::Status MakeCallable(::grpc::ClientContext* context,
- const MakeCallableRequest& request,
- MakeCallableResponse* response) override;
- ::grpc::Status RunCallable(::grpc::ClientContext* context,
- const RunCallableRequest& request,
- RunCallableResponse* response) override;
- ::grpc::Status ReleaseCallable(::grpc::ClientContext* context,
- const ReleaseCallableRequest& request,
- ReleaseCallableResponse* response) override;
-
- private:
- std::shared_ptr< ::grpc::ChannelInterface> channel_;
- const ::grpc::internal::RpcMethod rpcmethod_CreateSession_;
- const ::grpc::internal::RpcMethod rpcmethod_ExtendSession_;
- const ::grpc::internal::RpcMethod rpcmethod_PartialRunSetup_;
- const ::grpc::internal::RpcMethod rpcmethod_RunStep_;
- const ::grpc::internal::RpcMethod rpcmethod_CloseSession_;
- const ::grpc::internal::RpcMethod rpcmethod_ListDevices_;
- const ::grpc::internal::RpcMethod rpcmethod_Reset_;
- const ::grpc::internal::RpcMethod rpcmethod_MakeCallable_;
- const ::grpc::internal::RpcMethod rpcmethod_RunCallable_;
- const ::grpc::internal::RpcMethod rpcmethod_ReleaseCallable_;
- };
- static std::unique_ptr<Stub> NewStub(
- const std::shared_ptr< ::grpc::ChannelInterface>& channel,
- const ::grpc::StubOptions& options = ::grpc::StubOptions());
-
- class AsyncService : public ::grpc::Service {
- public:
- AsyncService();
- virtual ~AsyncService();
- void RequestCreateSession(
- ::grpc::ServerContext* context, CreateSessionRequest* request,
- ::grpc::ServerAsyncResponseWriter<CreateSessionResponse>* response,
- ::grpc::CompletionQueue* new_call_cq,
- ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
- ::grpc::Service::RequestAsyncUnary(0, context, request, response,
- new_call_cq, notification_cq, tag);
- }
- void RequestExtendSession(
- ::grpc::ServerContext* context, ExtendSessionRequest* request,
- ::grpc::ServerAsyncResponseWriter<ExtendSessionResponse>* response,
- ::grpc::CompletionQueue* new_call_cq,
- ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
- ::grpc::Service::RequestAsyncUnary(1, context, request, response,
- new_call_cq, notification_cq, tag);
- }
- void RequestPartialRunSetup(
- ::grpc::ServerContext* context, PartialRunSetupRequest* request,
- ::grpc::ServerAsyncResponseWriter<PartialRunSetupResponse>* response,
- ::grpc::CompletionQueue* new_call_cq,
- ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
- ::grpc::Service::RequestAsyncUnary(2, context, request, response,
- new_call_cq, notification_cq, tag);
- }
- void RequestRunStep(
- ::grpc::ServerContext* context, RunStepRequest* request,
- ::grpc::ServerAsyncResponseWriter<RunStepResponse>* response,
- ::grpc::CompletionQueue* new_call_cq,
- ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
- ::grpc::Service::RequestAsyncUnary(3, context, request, response,
- new_call_cq, notification_cq, tag);
- }
- void RequestCloseSession(
- ::grpc::ServerContext* context, CloseSessionRequest* request,
- ::grpc::ServerAsyncResponseWriter<CloseSessionResponse>* response,
- ::grpc::CompletionQueue* new_call_cq,
- ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
- ::grpc::Service::RequestAsyncUnary(4, context, request, response,
- new_call_cq, notification_cq, tag);
- }
- void RequestListDevices(
- ::grpc::ServerContext* context, ListDevicesRequest* request,
- ::grpc::ServerAsyncResponseWriter<ListDevicesResponse>* response,
- ::grpc::CompletionQueue* new_call_cq,
- ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
- ::grpc::Service::RequestAsyncUnary(5, context, request, response,
- new_call_cq, notification_cq, tag);
- }
- void RequestReset(
- ::grpc::ServerContext* context, ResetRequest* request,
- ::grpc::ServerAsyncResponseWriter<ResetResponse>* response,
- ::grpc::CompletionQueue* new_call_cq,
- ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
- ::grpc::Service::RequestAsyncUnary(6, context, request, response,
- new_call_cq, notification_cq, tag);
- }
- void RequestMakeCallable(
- ::grpc::ServerContext* context, MakeCallableRequest* request,
- ::grpc::ServerAsyncResponseWriter<MakeCallableResponse>* response,
- ::grpc::CompletionQueue* new_call_cq,
- ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
- ::grpc::Service::RequestAsyncUnary(7, context, request, response,
- new_call_cq, notification_cq, tag);
- }
- void RequestRunCallable(
- ::grpc::ServerContext* context, RunCallableRequest* request,
- ::grpc::ServerAsyncResponseWriter<RunCallableResponse>* response,
- ::grpc::CompletionQueue* new_call_cq,
- ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
- ::grpc::Service::RequestAsyncUnary(8, context, request, response,
- new_call_cq, notification_cq, tag);
- }
- void RequestReleaseCallable(
- ::grpc::ServerContext* context, ReleaseCallableRequest* request,
- ::grpc::ServerAsyncResponseWriter<ReleaseCallableResponse>* response,
- ::grpc::CompletionQueue* new_call_cq,
- ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
- ::grpc::Service::RequestAsyncUnary(9, context, request, response,
- new_call_cq, notification_cq, tag);
- }
- };
-};
-
-} // namespace grpc
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
index b832a2115c..6c2940553c 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
@@ -19,13 +19,13 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/master_interface.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/protobuf/master.pb.h"
+#include "tensorflow/core/protobuf/master_service.grpc.pb.h"
namespace tensorflow {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index ff64d78b79..2c833d11a9 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -289,12 +289,10 @@ Status GrpcServer::Init(
nullptr);
}
-
Status GrpcServer::Init(
ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
- return Init(std::move(service_func), rendezvous_mgr_func, nullptr,
- nullptr);
+ return Init(std::move(service_func), rendezvous_mgr_func, nullptr, nullptr);
}
Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr, nullptr); }
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
index 115148b84e..b01cfb6426 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
@@ -100,6 +100,9 @@ class GrpcServer : public ServerInterface {
Status Init(ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func);
+ Status Init(ServiceInitFunction service_func,
+ const RendezvousMgrCreationFunction& rendezvous_mgr_func);
+
Status Init();
// A subclass can override this method to support secure credentials.
diff --git a/tensorflow/core/framework/resource_op_kernel.h b/tensorflow/core/framework/resource_op_kernel.h
index 813ec6eed5..0a8da8b3bf 100644
--- a/tensorflow/core/framework/resource_op_kernel.h
+++ b/tensorflow/core/framework/resource_op_kernel.h
@@ -43,9 +43,15 @@ template <typename T>
class ResourceOpKernel : public OpKernel {
public:
explicit ResourceOpKernel(OpKernelConstruction* context) : OpKernel(context) {
- OP_REQUIRES_OK(context,
- context->allocate_persistent(DT_STRING, TensorShape({2}),
- &handle_, nullptr));
+ has_resource_type_ = (context->output_type(0) == DT_RESOURCE);
+ if (!has_resource_type_) {
+ // The resource variant of the op may be placed on non-CPU devices, but
+ // this allocation is always on the host. Fortunately we don't need it in
+ // the resource case.
+ OP_REQUIRES_OK(context,
+ context->allocate_persistent(DT_STRING, TensorShape({2}),
+ &handle_, nullptr));
+ }
}
// The resource is deleted from the resource manager only when it is private
@@ -89,12 +95,14 @@ class ResourceOpKernel : public OpKernel {
return;
}
- auto h = handle_.AccessTensor(context)->template flat<string>();
- h(0) = cinfo_.container();
- h(1) = cinfo_.name();
+ if (!has_resource_type_) {
+ auto h = handle_.AccessTensor(context)->template flat<string>();
+ h(0) = cinfo_.container();
+ h(1) = cinfo_.name();
+ }
resource_ = resource;
}
- if (context->expected_output_dtype(0) == DT_RESOURCE) {
+ if (has_resource_type_) {
OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
context, 0, cinfo_.container(), cinfo_.name(),
MakeTypeIndex<T>()));
@@ -122,6 +130,9 @@ class ResourceOpKernel : public OpKernel {
virtual Status VerifyResource(T* resource) { return Status::OK(); }
PersistentTensor handle_ GUARDED_BY(mu_);
+
+ // Is the output of the operator of type DT_RESOURCE?
+ bool has_resource_type_;
};
} // namespace tensorflow
diff --git a/tensorflow/core/framework/stats_aggregator.h b/tensorflow/core/framework/stats_aggregator.h
index 8002d9291c..4a18efc940 100644
--- a/tensorflow/core/framework/stats_aggregator.h
+++ b/tensorflow/core/framework/stats_aggregator.h
@@ -57,6 +57,10 @@ class StatsAggregator {
// interface. It is possible that not all implementations will support
// encoding their state as a protocol buffer.
virtual void EncodeToProto(Summary* out_summary) = 0;
+
+  // Increments the `label` cell of the metric registered under `name` by `val`.
+ virtual void IncrementCounter(const string& name, const string& label,
+ int64 val) = 0;
};
// A `StatsAggregatorResource` wraps a shareable `StatsAggregator` as a resource
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index bdeb5c66fc..653b088b1d 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -161,6 +161,8 @@ bool IsExit(const NodeDef& node) {
return op == "Exit" || op == "RefExit";
}
+bool IsExp(const NodeDef& node) { return node.op() == "Exp"; }
+
bool IsFill(const NodeDef& node) { return node.op() == "Fill"; }
bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 2de7d8cc9a..94439265c9 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -60,6 +60,7 @@ bool IsEluGrad(const NodeDef& node);
bool IsEnter(const NodeDef& node);
bool IsEqual(const NodeDef& node);
bool IsExit(const NodeDef& node);
+bool IsExp(const NodeDef& node);
bool IsFill(const NodeDef& node);
bool IsFloorDiv(const NodeDef& node);
bool IsFloorMod(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index d8c5d09c4d..72ca3c3fa2 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -178,6 +178,42 @@ NodeDef* GetTailOfIdempotentChain(
is_idempotent_non_branching);
}
+// GetElementUnexhaustive tries to get the value of an element in a tensor and
+// turn it into complex128 type. It only checks for a limited number of data
+// types, so it is not exhaustive.
+bool GetElementUnexhaustive(const Tensor& t, int i, const std::set<int>& dtypes,
+ complex128* element) {
+ if (dtypes.find(t.dtype()) == dtypes.end()) return false;
+ switch (t.dtype()) {
+ case DT_BFLOAT16:
+ *element = complex128(t.flat<bfloat16>()(i));
+ return true;
+ case DT_HALF:
+ *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
+ return true;
+ case DT_INT32:
+ *element = complex128(t.flat<int32>()(i));
+ return true;
+ case DT_INT64:
+ *element = complex128(t.flat<int64>()(i));
+ return true;
+ case DT_FLOAT:
+ *element = complex128(t.flat<float>()(i));
+ return true;
+ case DT_DOUBLE:
+ *element = complex128(t.flat<double>()(i));
+ return true;
+ case DT_COMPLEX64:
+ *element = complex128(t.flat<complex64>()(i));
+ return true;
+ case DT_COMPLEX128:
+ *element = t.flat<complex128>()(i);
+ return true;
+ default:
+ return false;
+ }
+}
+
// Graph optimizer context extension specific to ArithmeticOptimizer.
struct ArithmeticOptimizerContext {
explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
@@ -2361,7 +2397,13 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
complex128 prev, curr;
for (int i = 0; i < pow.NumElements(); ++i) {
- TF_RETURN_IF_ERROR(GetElement(pow, i, &curr));
+ if (!GetElementUnexhaustive(pow, i,
+ {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
+ DT_COMPLEX64, DT_COMPLEX128},
+ &curr)) {
+ // input data type is not supported by Pow. Skip.
+ return Status::OK();
+ }
if (i != 0 && curr != prev) {
// pow has different values on different elements. Skip.
return Status::OK();
@@ -2432,31 +2474,6 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
}
private:
- Status GetElement(const Tensor& t, int i, complex128* element) {
- switch (t.dtype()) {
- case DT_INT32:
- *element = complex128(t.flat<int32>()(i));
- return Status::OK();
- case DT_INT64:
- *element = complex128(t.flat<int64>()(i));
- return Status::OK();
- case DT_FLOAT:
- *element = complex128(t.flat<float>()(i));
- return Status::OK();
- case DT_DOUBLE:
- *element = complex128(t.flat<double>()(i));
- return Status::OK();
- case DT_COMPLEX64:
- *element = complex128(t.flat<complex64>()(i));
- return Status::OK();
- case DT_COMPLEX128:
- *element = t.flat<complex128>()(i);
- return Status::OK();
- default:
- return errors::InvalidArgument("Invalid data type: ", t.dtype());
- }
- }
-
Status SetElementToOne(int i, Tensor* t) {
switch (t->dtype()) {
case DT_INT32:
@@ -2544,7 +2561,10 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
}
complex128 element;
for (int k = 0; k < constant.NumElements(); ++k) {
- if (!GetElement(constant, k, &element)) {
+ if (!GetElementUnexhaustive(constant, k,
+ {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
+ DT_COMPLEX64, DT_COMPLEX128},
+ &element)) {
// input data type is not supported by log1p. Skip.
return Status::OK();
}
@@ -2569,30 +2589,81 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
}
return Status::OK();
}
+};
- bool GetElement(const Tensor& t, int i, complex128* element) {
- switch (t.dtype()) {
- case DT_BFLOAT16:
- *element = complex128(t.flat<bfloat16>()(i));
- return true;
- case DT_HALF:
- *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
- return true;
- case DT_FLOAT:
- *element = complex128(t.flat<float>()(i));
- return true;
- case DT_DOUBLE:
- *element = complex128(t.flat<double>()(i));
- return true;
- case DT_COMPLEX64:
- *element = complex128(t.flat<complex64>()(i));
- return true;
- case DT_COMPLEX128:
- *element = t.flat<complex128>()(i);
- return true;
- default:
- return false;
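+// Converts Exp(Sub(x, c)) into Expm1(x) when c is a constant tensor whose
+// elements are all 1 and whose shape broadcasts to the shape of x, mirroring
+// the ConvertLog1pStage above.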
+class ConvertExpm1Stage : public ArithmeticOptimizerStage {
+ public:
+ explicit ConvertExpm1Stage(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("ConvertExpm1", ctx, ctx_ext) {}
+ ~ConvertExpm1Stage() override = default;
+
+ bool IsSupported(const NodeDef* node) const override { return IsExp(*node); }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ NodeDef* input;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
+ if (!IsSub(*input)) {
+ return Status::OK();
}
+
+ if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) {
+ return Status::OK();
+ }
+
+ const auto& t =
+ ctx().graph_properties->GetInputProperties(input->name())[0];
+ const auto& c =
+ ctx().graph_properties->GetInputProperties(input->name())[1];
+ for (int k = 0; k < c.shape().dim_size(); ++k) {
+ // Skip if c shape is not fully determined.
+ if (c.shape().dim(k).size() < 0) {
+ return Status::OK();
+ }
+ }
+ TensorShapeProto broadcast_shape;
+ if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
+ return Status::OK();
+ }
+ if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
+ // skip if the non-constant tensor doesn't have the same shape after
+ // broadcast.
+ return Status::OK();
+ }
+ if (TensorShape::IsValid(c.shape()) && c.has_value()) {
+ Tensor constant(c.dtype(), c.shape());
+ if (!constant.FromProto(c.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ c.value().DebugString());
+ }
+ complex128 element;
+ for (int k = 0; k < constant.NumElements(); ++k) {
+ if (!GetElementUnexhaustive(constant, k,
+ {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
+ DT_COMPLEX64, DT_COMPLEX128},
+ &element)) {
+ // input data type is not supported by expm1. Skip.
+ return Status::OK();
+ }
+ if (element != complex128(1)) {
+ // current element is not 1. Skip.
+ return Status::OK();
+ }
+ }
+ NodeDef *x, *y;
+ TF_RETURN_IF_ERROR(GetInputNode(input->input(0), &x));
+ TF_RETURN_IF_ERROR(GetInputNode(input->input(1), &y));
+ node->set_op("Expm1");
+ node->set_input(0, input->input(0));
+ node->add_input(AsControlDependency(y->name()));
+ ForwardControlDependencies(node, {input});
+
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(input);
+ AddToOptimizationQueue(x);
+ AddToOptimizationQueue(y);
+ }
+ return Status::OK();
}
};
@@ -2928,6 +2999,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
if (options_.optimize_max_or_min_of_monotonic)
pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
+ if (options_.convert_expm1)
+ pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext);
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
<< str_util::Join(pipeline.StageNames(), ", ");
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 824ef35ef6..45a5f65b81 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -77,6 +77,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool simplify_aggregation = true;
bool convert_pow = true;
bool convert_log1p = true;
+ bool convert_expm1 = true;
// Choose which arithmetic optimizer stages will be enabled for a given
// optimization level by default.
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index d0e6b04679..3f6c04a5b5 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -274,6 +274,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
DisableAllStages(optimizer);
optimizer->options_.optimize_max_or_min_of_monotonic = true;
}
+
+ void EnableOnlyExpm1(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.convert_expm1 = true;
+ }
};
TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -2533,6 +2538,43 @@ TEST_F(ArithmeticOptimizerTest, Log1p) {
CompareGraphs(want, got);
}
+TEST_F(ArithmeticOptimizerTest, Expm1) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ auto x1 = ops::Const(s.WithOpName("x1"), {2.0f, 2.0f}, {1, 2});
+ auto x2 = ops::Const(s.WithOpName("x2"), {1.0f, 1.0f}, {1, 2});
+ auto x3 = ops::Const(s.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
+ auto s12 = ops::Sub(s.WithOpName("s12").WithControlDependencies(x3), x1, x2);
+ auto s23 = ops::Sub(s.WithOpName("s23"), x2, x3);
+ Output out1 = ops::Exp(s.WithOpName("out1"), s12);
+ Output out2 = ops::Exp(s.WithOpName("out2"), s23);
+
+ GrapplerItem item;
+ item.fetch = {"out1", "out2"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(2, tensors_expected.size());
+
+ GraphDef got;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyExpm1(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &got);
+ auto tensors = EvaluateNodes(got, item.fetch);
+ EXPECT_EQ(2, tensors.size());
+
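+ // out1 = Exp(Sub(x1, x2)) is expected to be rewritten to Expm1(x1) since
+ // x2 is the constant 1, while out2 keeps its Exp op because x3 is not 1;
+ // the control dependencies on x2 and x3 become control inputs of out1.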
+ GraphDef want;
+ AddNode("x1", "Const", {}, {}, &want);
+ AddNode("x2", "Const", {}, {}, &want);
+ AddNode("x3", "Const", {}, {}, &want);
+ AddNode("s23", "Sub", {"x2", "x3"}, {}, &want);
+ AddNode("out1", "Expm1",
+ {"x1", AsControlDependency("x2"), AsControlDependency("x3")}, {},
+ &want);
+ AddNode("out2", "Exp", {"s23"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
index 00f66c9bc1..bc717d5eeb 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
@@ -23,9 +23,7 @@ namespace grappler {
namespace graph_utils {
namespace {
-class GraphUtilsTest : public ::testing::Test {};
-
-TEST_F(GraphUtilsTest, AddScalarConstNodeBool) {
+TEST(GraphUtilsTest, AddScalarConstNodeBool) {
GraphDef graph;
NodeDef* bool_node;
TF_EXPECT_OK(AddScalarConstNode<bool>(true, &graph, &bool_node));
@@ -33,7 +31,7 @@ TEST_F(GraphUtilsTest, AddScalarConstNodeBool) {
EXPECT_EQ(bool_node->attr().at("value").tensor().bool_val(0), true);
}
-TEST_F(GraphUtilsTest, AddScalarConstNodeDouble) {
+TEST(GraphUtilsTest, AddScalarConstNodeDouble) {
GraphDef graph;
NodeDef* double_node;
TF_EXPECT_OK(AddScalarConstNode<double>(3.14, &graph, &double_node));
@@ -41,7 +39,7 @@ TEST_F(GraphUtilsTest, AddScalarConstNodeDouble) {
EXPECT_FLOAT_EQ(double_node->attr().at("value").tensor().double_val(0), 3.14);
}
-TEST_F(GraphUtilsTest, AddScalarConstNodeFloat) {
+TEST(GraphUtilsTest, AddScalarConstNodeFloat) {
GraphDef graph;
NodeDef* float_node;
TF_EXPECT_OK(AddScalarConstNode<float>(3.14, &graph, &float_node));
@@ -49,7 +47,7 @@ TEST_F(GraphUtilsTest, AddScalarConstNodeFloat) {
EXPECT_FLOAT_EQ(float_node->attr().at("value").tensor().float_val(0), 3.14);
}
-TEST_F(GraphUtilsTest, AddScalarConstNodeInt) {
+TEST(GraphUtilsTest, AddScalarConstNodeInt) {
GraphDef graph;
NodeDef* int_node;
TF_EXPECT_OK(AddScalarConstNode<int>(42, &graph, &int_node));
@@ -57,7 +55,7 @@ TEST_F(GraphUtilsTest, AddScalarConstNodeInt) {
EXPECT_EQ(int_node->attr().at("value").tensor().int_val(0), 42);
}
-TEST_F(GraphUtilsTest, AddScalarConstNodeInt64) {
+TEST(GraphUtilsTest, AddScalarConstNodeInt64) {
GraphDef graph;
NodeDef* int64_node;
TF_EXPECT_OK(AddScalarConstNode<int64>(42, &graph, &int64_node));
@@ -65,7 +63,7 @@ TEST_F(GraphUtilsTest, AddScalarConstNodeInt64) {
EXPECT_EQ(int64_node->attr().at("value").tensor().int64_val(0), 42);
}
-TEST_F(GraphUtilsTest, AddScalarConstNodeString) {
+TEST(GraphUtilsTest, AddScalarConstNodeString) {
GraphDef graph;
NodeDef* string_node;
TF_EXPECT_OK(AddScalarConstNode<StringPiece>("hello", &graph, &string_node));
@@ -73,7 +71,7 @@ TEST_F(GraphUtilsTest, AddScalarConstNodeString) {
EXPECT_EQ(string_node->attr().at("value").tensor().string_val(0), "hello");
}
-TEST_F(GraphUtilsTest, Compare) {
+TEST(GraphUtilsTest, Compare) {
GraphDef graphA;
GraphDef graphB;
EXPECT_TRUE(Compare(graphA, graphB));
@@ -88,7 +86,7 @@ TEST_F(GraphUtilsTest, Compare) {
EXPECT_TRUE(Compare(graphA, graphB));
}
-TEST_F(GraphUtilsTest, ContainsNodeWithName) {
+TEST(GraphUtilsTest, ContainsNodeWithName) {
GraphDef graph;
EXPECT_TRUE(!ContainsNodeWithName("A", graph));
@@ -100,7 +98,7 @@ TEST_F(GraphUtilsTest, ContainsNodeWithName) {
EXPECT_TRUE(!ContainsNodeWithName("A", graph));
}
-TEST_F(GraphUtilsTest, ContainsNodeWithOp) {
+TEST(GraphUtilsTest, ContainsNodeWithOp) {
GraphDef graph;
EXPECT_TRUE(!ContainsNodeWithOp("OpA", graph));
@@ -112,7 +110,7 @@ TEST_F(GraphUtilsTest, ContainsNodeWithOp) {
EXPECT_TRUE(!ContainsNodeWithOp("OpA", graph));
}
-TEST_F(GraphUtilsTest, FindNodeWithName) {
+TEST(GraphUtilsTest, FindNodeWithName) {
GraphDef graph;
EXPECT_EQ(FindNodeWithName("A", graph), -1);
@@ -124,7 +122,7 @@ TEST_F(GraphUtilsTest, FindNodeWithName) {
EXPECT_EQ(FindNodeWithName("A", graph), -1);
}
-TEST_F(GraphUtilsTest, FindNodeWithOp) {
+TEST(GraphUtilsTest, FindNodeWithOp) {
GraphDef graph;
EXPECT_EQ(FindNodeWithOp("OpA", graph), -1);
@@ -136,7 +134,7 @@ TEST_F(GraphUtilsTest, FindNodeWithOp) {
EXPECT_EQ(FindNodeWithOp("OpA", graph), -1);
}
-TEST_F(GraphUtilsTest, SetUniqueName) {
+TEST(GraphUtilsTest, SetUniqueName) {
GraphDef graph;
NodeDef* node1;
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index d3710a4b5c..3e66d6412a 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -368,6 +368,7 @@ cc_library(
cc_library(
name = "queue_op",
+ srcs = ["queue_op.cc"],
hdrs = ["queue_op.h"],
deps = [
":queue_base",
@@ -1885,9 +1886,10 @@ cc_library(
name = "fifo_queue",
srcs = ["fifo_queue.cc"],
hdrs = ["fifo_queue.h"],
- visibility = ["//visibility:private"],
+ visibility = [":friends"],
deps = [
":queue_base",
+ ":queue_op",
":typed_queue",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -3919,6 +3921,7 @@ tf_cc_test(
cc_library(
name = "sparse",
deps = [
+ ":deserialize_sparse_variant_op",
":serialize_sparse_op",
":sparse_add_grad_op",
":sparse_add_op",
@@ -4073,6 +4076,15 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "deserialize_sparse_variant_op",
+ prefix = "deserialize_sparse_variant_op",
+ deps = SPARSE_DEPS + [
+ ":reshape_util",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+tf_kernel_library(
name = "sparse_tensors_map_ops",
prefix = "sparse_tensors_map_ops",
deps = SPARSE_DEPS,
@@ -5083,6 +5095,7 @@ filegroup(
"padding_fifo_queue.cc",
"padding_fifo_queue_op.cc",
"queue_base.cc",
+ "queue_op.cc",
"queue_ops.cc",
"random_op.cc",
"reduction_ops_all.cc",
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index fe1a1ba5a3..a888422d49 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -297,7 +297,8 @@ class ZerosLikeOp : public OpKernel {
errors::InvalidArgument("ZerosLike non-scalar Tensor with "
"dtype=DT_VARIANT is not supported."));
const Variant& v = input.scalar<Variant>()();
- Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
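+ // Allocate the scalar variant output with this device's allocator instead
+ // of the host cpu_allocator().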
+ Tensor out(ctx->device()->GetAllocator(AllocatorAttributes()), DT_VARIANT,
+ TensorShape({}));
Variant* out_v = &(out.scalar<Variant>()());
OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>(
ctx, ZEROS_LIKE_VARIANT_UNARY_OP, v, out_v));
diff --git a/tensorflow/core/kernels/data/slide_dataset_op.cc b/tensorflow/core/kernels/data/slide_dataset_op.cc
index c17e9343ea..07cc91f9d5 100644
--- a/tensorflow/core/kernels/data/slide_dataset_op.cc
+++ b/tensorflow/core/kernels/data/slide_dataset_op.cc
@@ -40,9 +40,8 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES(
ctx, window_size > 0,
errors::InvalidArgument("Window size must be greater than zero."));
- OP_REQUIRES(
- ctx, stride > 0,
- errors::InvalidArgument("Stride must be greater than zero."));
+ OP_REQUIRES(ctx, stride > 0,
+ errors::InvalidArgument("Stride must be greater than zero."));
if (stride == window_size) {
LOG(WARNING) << "stride: " << stride
<< " is equal to window_size: " << window_size
diff --git a/tensorflow/core/kernels/data/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/stats_aggregator_ops.cc
index 33a56b2eb5..b133cfab54 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_ops.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_ops.cc
@@ -20,11 +20,25 @@ limitations under the License.
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/histogram/histogram.h"
+#include "tensorflow/core/lib/monitoring/counter.h"
+#include "tensorflow/core/lib/monitoring/gauge.h"
+#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
namespace {
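+// The streamz counters below are process-wide singletons shared by every
+// StatsAggregatorImpl instance, so the map and its mutex are created lazily
+// and never destroyed.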
+static mutex* get_counters_map_lock() {
+ static mutex counters_map_lock(LINKER_INITIALIZED);
+ return &counters_map_lock;
+}
+
+static std::unordered_map<string, monitoring::Counter<1>*>* get_counters_map() {
+ static std::unordered_map<string, monitoring::Counter<1>*>* counters_map =
+ new std::unordered_map<string, monitoring::Counter<1>*>;
+ return counters_map;
+}
+
class StatsAggregatorImpl : public StatsAggregator {
public:
StatsAggregatorImpl() {}
@@ -61,6 +75,21 @@ class StatsAggregatorImpl : public StatsAggregator {
}
}
+ void IncrementCounter(const string& name, const string& label,
+ int64 val) override {
+ mutex_lock l(*get_counters_map_lock());
+ auto counters_map = get_counters_map();
+ if (counters_map->find(name) == counters_map->end()) {
+ counters_map->emplace(
+ name, monitoring::Counter<1>::New(
+ /*streamz name*/ "/tensorflow/" + name,
+ /*streamz description*/
+ name + " generated or consumed by the component.",
+ /*streamz label name*/ "component_descriptor"));
+ }
+ counters_map->at(name)->GetCell(label)->IncrementBy(val);
+ }
+
private:
mutex mu_;
std::unordered_map<string, histogram::Histogram> histograms_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc
index 3e0a6ae049..a537e7e68f 100644
--- a/tensorflow/core/kernels/data/stats_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc
@@ -316,10 +316,14 @@ class FeatureStatsDatasetOp : public UnaryDatasetOpKernel {
// changes to parse_example() where it returns stats as well.
for (int i = 0; i < record_t.size(); ++i) {
if (example.ParseFromString(record_t(i))) {
+ stats_aggregator->IncrementCounter("examples_count", "trainer",
+ 1);
AddStatsFeatures(example, stats_aggregator);
} else {
SequenceExample sequence_example;
if (sequence_example.ParseFromString(record_t(i))) {
+ stats_aggregator->IncrementCounter("sequence_examples_count",
+ "trainer", 1);
AddStatsFeatures(sequence_example, stats_aggregator);
}
}
@@ -360,8 +364,11 @@ class FeatureStatsDatasetOp : public UnaryDatasetOpKernel {
int feature_values_list_size_sum = 0;
for (const auto& feature : example.features().feature()) {
+ stats_aggregator->IncrementCounter("features_count", "trainer", 1);
feature_values_list_size_sum += AddStatsFeatureValues(feature.second);
}
+ stats_aggregator->IncrementCounter("feature_values_count", "trainer",
+ feature_values_list_size_sum);
stats_aggregator->AddToHistogram(
strings::StrCat(dataset()->tag_, ":feature-values"),
{static_cast<double>(feature_values_list_size_sum)});
@@ -378,16 +385,20 @@ class FeatureStatsDatasetOp : public UnaryDatasetOpKernel {
int feature_values_list_size_sum = 0;
for (const auto& feature : example.context().feature()) {
+ stats_aggregator->IncrementCounter("features_count", "trainer", 1);
feature_values_list_size_sum += AddStatsFeatureValues(feature.second);
}
for (const auto& feature_list :
example.feature_lists().feature_list()) {
+ stats_aggregator->IncrementCounter("feature_lists_count", "reainer",
+ 1);
for (const auto& feature : feature_list.second.feature()) {
feature_values_list_size_sum += AddStatsFeatureValues(feature);
}
}
-
+ stats_aggregator->IncrementCounter("feature_values_count", "trainer",
+ feature_values_list_size_sum);
stats_aggregator->AddToHistogram(
strings::StrCat(dataset()->tag_, ":feature-values"),
{static_cast<double>(feature_values_list_size_sum)});
diff --git a/tensorflow/core/kernels/deserialize_sparse_variant_op.cc b/tensorflow/core/kernels/deserialize_sparse_variant_op.cc
new file mode 100644
index 0000000000..fce3029e4e
--- /dev/null
+++ b/tensorflow/core/kernels/deserialize_sparse_variant_op.cc
@@ -0,0 +1,372 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+
+namespace tensorflow {
+
+namespace {
+
+class DeserializeSparseOp : public OpKernel {
+ public:
+ explicit DeserializeSparseOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+
+ OP_REQUIRES(
+ context, input.dims() > 0,
+ errors::InvalidArgument("Serialized sparse should have non-zero rank ",
+ input.shape().DebugString()));
+ OP_REQUIRES(context, input.shape().dim_size(input.dims() - 1) == 3,
+ errors::InvalidArgument(
+ "Serialized sparse should have 3 as the last dimension ",
+ input.shape().DebugString()));
+
+ // `input_dims_to_stack` is the number of dimensions that will be added to
+ // each of the elements before they are concatenated into the output.
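+ // For example, an input of shape (S_0, S_1, 3) holds S_0 * S_1 serialized
+ // SparseTensors and has input_dims_to_stack == 2.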
+ const int64 input_dims_to_stack = input.dims() - 1;
+ int num_sparse_tensors = 1;
+ for (int i = 0; i < input_dims_to_stack; ++i) {
+ num_sparse_tensors *= input.shape().dim_size(i);
+ }
+
+ if (num_sparse_tensors == 1 && input_dims_to_stack == 0) {
+ // Special case with a single sparse tensor, and no dimensions to add
+ // to the output indices. We can return the boxed tensors directly (after
+ // validating them).
+ const Tensor* output_indices;
+ const Tensor* output_values;
+ const Tensor* output_shape;
+ const auto& input_as_vec = input.vec<Variant>();
+ int64 total_non_zeros;
+ OP_REQUIRES_OK(context, GetAndValidateSparseTensorShape(
+ input_as_vec(1), input_as_vec(2), 0,
+ &output_shape, &total_non_zeros));
+ OP_REQUIRES_OK(context, GetAndValidateSparseTensorIndicesAndValues(
+ input_as_vec(0), input_as_vec(1), 0,
+ output_shape->NumElements(), &output_indices,
+ &output_values));
+ context->set_output(0, *output_indices);
+ context->set_output(1, *output_values);
+ context->set_output(2, *output_shape);
+ return;
+ }
+
+ OP_REQUIRES(
+ context, num_sparse_tensors > 0,
+ errors::InvalidArgument(
+ "Serialized sparse should have at least 1 serialized tensor, "
+ "but has a zero dimension ",
+ input.shape().DebugString()));
+
+ const auto& input_as_matrix = input.flat_inner_dims<Variant, 2>();
+
+ // Compute the output "dense shape" of, and the number of non-zero elements
+ // in, the stacked sparse tensors. Given an input of shape (S_0, ...,
+ // S_{input_dims_to_stack-1}, 3), and an element of dense shape (E_0, ...
+ // E_n), the output dense shape will be (S_0, ...,
+ // S_{input_dims_to_stack-1}, E_0, ..., E_n).
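+ // For example, an input of shape (2, 3, 3) whose elements all have dense
+ // shape (E_0, E_1) produces an output dense shape of (2, 3, E_0, E_1).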
+ Tensor* output_shape;
+ int64 total_non_zeros = 0;
+
+ // Allocate and build the initial output shape based on the element shape of
+ // the 0th sparse tensor in the input.
+ //
+ // NOTE(mrry): We define `element_shape` as a `const Tensor*` rather than a
+ // `Tensor` to avoid the overhead of allocating and deallocating a `Tensor`
+ // on the stack. While the per-`Tensor` cost is small, this op can unbox a
+ // large number of tensors (3 per batch element) and these fixed overheads
+ // dominate when the number of non-zeros per element is small.
+ const Tensor* element_shape;
+ OP_REQUIRES_OK(context, GetAndValidateSparseTensorShape(
+ input_as_matrix(0, 1), input_as_matrix(0, 2), 0,
+ &element_shape, &total_non_zeros));
+ OP_REQUIRES_OK(context,
+ context->allocate_output(
+ 2, {input_dims_to_stack + element_shape->NumElements()},
+ &output_shape));
+ const auto element_shape_vec = element_shape->vec<int64>();
+ auto output_shape_vec = output_shape->vec<int64>();
+ output_shape_vec(0) = num_sparse_tensors;
+ for (int64 j = 0; j < input_dims_to_stack; ++j) {
+ output_shape_vec(j) = input.dim_size(j);
+ }
+ for (int64 j = 0; j < element_shape->NumElements(); ++j) {
+ output_shape_vec(j + input_dims_to_stack) = element_shape_vec(j);
+ }
+
+ // Accumulate the number of non-zero elements from the remaining sparse
+ // tensors, and validate that they have compatible dense shapes.
+ //
+ // NOTE(mrry): For compatibility with the implementations of
+ // DeserializeManySparse, and many ops that generate SparseTensors to batch
+ // that do not have a fixed dense_shape (e.g. `tf.parse_single_example()`),
+ // we compute the maximum in each dimension to find the smallest dense_shape
+ // that bounds all of the input SparseTensors.
+ for (int i = 1; i < num_sparse_tensors; ++i) {
+ int64 num_non_zeros;
+ OP_REQUIRES_OK(context, GetAndValidateSparseTensorShape(
+ input_as_matrix(i, 1), input_as_matrix(i, 2),
+ i, &element_shape, &num_non_zeros));
+ total_non_zeros += num_non_zeros;
+ OP_REQUIRES(
+ context,
+ output_shape->NumElements() - input_dims_to_stack ==
+ element_shape->NumElements(),
+ errors::InvalidArgument(
+ "Inconsistent shape across SparseTensors: rank prior to "
+ "SparseTensor[",
+ i, "] was: ", output_shape->NumElements() - input_dims_to_stack,
+ " but rank of SparseTensor[", i,
+ "] is: ", element_shape->NumElements()));
+ const auto element_shape_vec = element_shape->vec<int64>();
+ for (int j = 0; j < element_shape->NumElements(); ++j) {
+ output_shape_vec(j + input_dims_to_stack) = std::max(
+ output_shape_vec(j + input_dims_to_stack), element_shape_vec(j));
+ }
+ }
+
+ // Compute the output "indices" matrix and "values" vector.
+ Tensor* output_indices;
+ Tensor* output_values;
+
+ const int output_rank = output_shape->NumElements();
+ OP_REQUIRES_OK(context,
+ context->allocate_output(
+ 0, {static_cast<int64>(total_non_zeros), output_rank},
+ &output_indices));
+ OP_REQUIRES_OK(
+ context, context->allocate_output(
+ 1, {static_cast<int64>(total_non_zeros)}, &output_values));
+
+ // The bulk of the work in this method involves building the output indices
+ // in a tight loop. For cache friendliness, we generate the indices in the
+ // order that they will be laid out in memory. We use raw pointers instead
+ // of Eigen element/slice indexing methods, to access the underlying index
+ // buffer to minimize the amount of work in that tight loop.
+ int64* output_indices_data = output_indices->matrix<int64>().data();
+ size_t current_row = 0;
+
+ for (int i = 0; i < num_sparse_tensors; ++i) {
+ const Tensor* element_indices;
+ const Tensor* element_values;
+ OP_REQUIRES_OK(context, this->GetAndValidateSparseTensorIndicesAndValues(
+ input_as_matrix(i, 0), input_as_matrix(i, 1),
+ i, output_rank - input_dims_to_stack,
+ &element_indices, &element_values));
+
+ const size_t num_index_rows = element_values->NumElements();
+
+ // An empty sparse tensor in the input will generate no data
+ // in the output. We short-circuit the rest of the iteration to avoid
+ // triggering assertions in Eigen when manipulating empty tensors (or
+ // slices of tensors).
+ if (num_index_rows == 0) continue;
+
+ const size_t start_row = current_row;
+ const size_t next_start_row = current_row + num_index_rows;
+
+ // NOTE(mrry): If the element is a scalar SparseTensor,
+ // `element_indices` will be an empty tensor, and this pointer will not
+ // be valid. However, we will not dereference the pointer in that case,
+ // because `input_dims_to_stack == output_rank`.
+ const int64* element_indices_data =
+ element_indices->matrix<int64>().data();
+
+ // Build the submatrix of `output_indices` for the i^th sparse tensor
+ // in the input.
+ //
+ // Each row of `output_indices` comprises `input_dims_to_stack` indices
+ // based on the position of the i^th sparse tensor in the input tensor,
+ // followed by the indices from the corresponding row in
+ // `element_indices`.
+ if (input_dims_to_stack == 1 && output_rank == 2) {
+ // We specialize this case because the compiler can generate
+ // more efficient code when the number of indices for each element is
+ // known statically. Since the most common use of this op is to
+ // serialize batches of SparseTensors, and the most common source of
+ // SparseTensors is the `tf.parse_single_example()` op, which generates
+ // 1-D SparseTensors, we statically unroll the loop for the rank 2
+ // output case.
+ for (; current_row < next_start_row; ++current_row) {
+ *output_indices_data++ = i;
+ *output_indices_data++ = *element_indices_data++;
+ }
+ } else {
+ // `sparse_tensor_index` is the tuple of indices that correspond to
+ // mapping the flat element index (`i`) back onto the stacked
+ // coordinates implied by the position of the i^th sparse tensor in the
+ // input tensor.
+ //
+ // We build `sparse_tensor_index` in reverse (innermost/minor dimension
+ // to outermost/major dimension). The `cumulative_product` represents
+ // the size of the inner subtensor for which `sparse_tensor_index` has
+ // already been built.
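+ // For example, with an input of shape (2, 3, 3), the sparse tensor at
+ // flat index i == 4 maps to sparse_tensor_index == {1, 1}.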
+ gtl::InlinedVector<int64, 4> sparse_tensor_index(input_dims_to_stack);
+ int cumulative_product = 1;
+ for (size_t j = 0; j < sparse_tensor_index.size(); ++j) {
+ size_t reverse_index = sparse_tensor_index.size() - j - 1;
+ sparse_tensor_index[reverse_index] =
+ (i / cumulative_product) % input.dim_size(reverse_index);
+ cumulative_product *= input.dim_size(reverse_index);
+ }
+ for (; current_row < next_start_row; ++current_row) {
+ for (int64 sparse_tensor_index_component : sparse_tensor_index) {
+ *output_indices_data++ = sparse_tensor_index_component;
+ }
+ for (size_t k = input_dims_to_stack; k < output_rank; ++k) {
+ *output_indices_data++ = *element_indices_data++;
+ }
+ }
+ }
+
+ // Build the subvector of `output_values` for the i^th sparse tensor
+ // in the input.
+ //
+ // NOTE(mrry): There is a potential optimization here where we use a T*
+ // to represent the current position in `output_values`, but it would
+ // require some rejigging of the template parameters.
+ // NOTE(mrry): Another potential optimization: if we know that this
+ // operation consumes its input, we could std::move non-primitive elements
+ // into the output and avoid a copy.
+ Eigen::DSizes<Eigen::DenseIndex, 1> values_start(start_row);
+ Eigen::DSizes<Eigen::DenseIndex, 1> values_sizes(num_index_rows);
+
+#define HANDLE_TYPE(T) \
+ case DataTypeToEnum<T>::value: { \
+ output_values->vec<T>().slice(values_start, values_sizes) = \
+ element_values->vec<T>(); \
+ break; \
+ }
+ switch (dtype_) {
+ TF_CALL_ALL_TYPES(HANDLE_TYPE);
+ TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
+#undef HANDLE_TYPE
+ default:
+ OP_REQUIRES_OK(
+ context, errors::Unimplemented(
+ "DeserializeSparse Unhandled data type: ", dtype_));
+ }
+ }
+ }
+
+ private:
+ Status GetAndValidateSparseTensorShape(const Variant& serialized_values,
+ const Variant& serialized_shape,
+ int index, const Tensor** output_shape,
+ int64* output_num_non_zeros) {
+ // Deserialize and validate the shape.
+ *output_shape = serialized_shape.get<Tensor>();
+ if (*output_shape == nullptr) {
+ return errors::InvalidArgument(
+ "Could not get a tensor from serialized_sparse[", index, ", 2]");
+ }
+ if ((*output_shape)->dtype() != DT_INT64) {
+ return errors::InvalidArgument(
+ "Expected serialized_sparse[", index,
+ ", 2] to be a vector of DT_INT64 but received dtype ",
+ DataTypeString((*output_shape)->dtype()));
+ }
+ if (!TensorShapeUtils::IsVector((*output_shape)->shape())) {
+ return errors::InvalidArgument(
+ "Expected serialized_sparse[", index,
+ ", 2] to be a shape vector but its shape is ",
+ (*output_shape)->shape().DebugString());
+ }
+ *output_num_non_zeros = serialized_values.get<Tensor>()->NumElements();
+ return Status::OK();
+ }
+
+ Status GetAndValidateSparseTensorIndicesAndValues(
+ const Variant& serialized_indices, const Variant& serialized_values,
+ int index, int expected_rank, const Tensor** output_indices,
+ const Tensor** output_values) {
+ // Deserialize and validate the indices.
+ *output_indices = serialized_indices.get<Tensor>();
+ if (*output_indices == nullptr) {
+ return errors::InvalidArgument(
+ "Could not get a tensor from serialized_sparse[", index, ", 0]");
+ }
+ if ((*output_indices)->dtype() != DT_INT64) {
+ return errors::InvalidArgument(
+ "Expected serialized_sparse[", index,
+ ", 0] to be a matrix of DT_INT64 but received dtype ",
+ DataTypeString((*output_indices)->dtype()));
+ }
+ if (!TensorShapeUtils::IsMatrix((*output_indices)->shape())) {
+ return errors::InvalidArgument(
+ "Expected serialized_sparse[", index,
+ ", 0] to represent an index matrix but received shape ",
+ (*output_indices)->shape().DebugString());
+ }
+ int64 num_entries = (*output_indices)->dim_size(0);
+ int rank = (*output_indices)->dim_size(1);
+ if (rank != expected_rank) {
+ return errors::InvalidArgument(
+ "Expected column counts of SparseTensor[", index,
+ "].indices to match size of SparseTensor[", index,
+ "].shape but they do not: ", rank, " vs. ", expected_rank);
+ }
+
+ // Deserialize and validate the values.
+ *output_values = serialized_values.get<Tensor>();
+ if (*output_values == nullptr) {
+ return errors::InvalidArgument(
+ "Could not get a tensor from serialized_sparse[", index, ", 1]");
+ }
+ if (!TensorShapeUtils::IsVector((*output_values)->shape())) {
+ return errors::InvalidArgument(
+ "Expected serialized_sparse[", index,
+ ", 1] to represent a values vector but received shape ",
+ (*output_values)->shape().DebugString());
+ }
+ if (dtype_ != (*output_values)->dtype()) {
+ return errors::InvalidArgument(
+ "Requested SparseTensor of type ", DataTypeString(dtype_),
+ " but SparseTensor[", index,
+ "].values.dtype() == ", DataTypeString((*output_values)->dtype()));
+ }
+ if (num_entries != (*output_values)->dim_size(0)) {
+ return errors::InvalidArgument(
+ "Expected row counts of SparseTensor[", index,
+ "].indices and SparseTensor[", index,
+ "].values to match but they do not: ", num_entries, " vs. ",
+ (*output_values)->dim_size(0));
+ }
+
+ return Status::OK();
+ }
+
+ DataType dtype_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("DeserializeSparse")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<Variant>("Tserialized"),
+ DeserializeSparseOp)
+
+} // namespace
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/fifo_queue.cc b/tensorflow/core/kernels/fifo_queue.cc
index a23478af5b..d6e859f1aa 100644
--- a/tensorflow/core/kernels/fifo_queue.cc
+++ b/tensorflow/core/kernels/fifo_queue.cc
@@ -366,4 +366,19 @@ Status FIFOQueue::MatchesNodeDef(const NodeDef& node_def) {
return Status::OK();
}
+// Defines a FIFOQueueOp, which produces a Queue (specifically, one
+// backed by FIFOQueue) that persists across different graph
+// executions, and sessions. Running this op produces a single-element
+// tensor of handles to Queues in the corresponding device.
+FIFOQueueOp::FIFOQueueOp(OpKernelConstruction* context)
+ : TypedQueueOp(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
+}
+
+Status FIFOQueueOp::CreateResource(QueueInterface** ret) {
+ FIFOQueue* queue = new FIFOQueue(capacity_, component_types_,
+ component_shapes_, cinfo_.name());
+ return CreateTypedQueue(queue, ret);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/fifo_queue.h b/tensorflow/core/kernels/fifo_queue.h
index f01d70924d..697ee81c39 100644
--- a/tensorflow/core/kernels/fifo_queue.h
+++ b/tensorflow/core/kernels/fifo_queue.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_FIFO_QUEUE_H_
-#define TENSORFLOW_KERNELS_FIFO_QUEUE_H_
+#ifndef TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_
+#define TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_
#include <deque>
#include <vector>
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/kernels/typed_queue.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
@@ -69,6 +70,22 @@ class FIFOQueue : public TypedQueue<std::deque<PersistentTensor> > {
TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueue);
};
+// Defines a FIFOQueueOp, which produces a Queue (specifically, one
+// backed by FIFOQueue) that persists across different graph
+// executions, and sessions. Running this op produces a single-element
+// tensor of handles to Queues in the corresponding device.
+class FIFOQueueOp : public TypedQueueOp {
+ public:
+ explicit FIFOQueueOp(OpKernelConstruction* context);
+
+ private:
+ Status CreateResource(QueueInterface** ret) override
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ std::vector<TensorShape> component_shapes_;
+ TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp);
+};
+
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_FIFO_QUEUE_H_
+#endif // TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_
diff --git a/tensorflow/core/kernels/fifo_queue_op.cc b/tensorflow/core/kernels/fifo_queue_op.cc
index b35bdbb2f0..80869768f1 100644
--- a/tensorflow/core/kernels/fifo_queue_op.cc
+++ b/tensorflow/core/kernels/fifo_queue_op.cc
@@ -13,50 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// See docs in ../ops/data_flow_ops.cc.
-
-#include <deque>
-#include <vector>
-
#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fifo_queue.h"
-#include "tensorflow/core/kernels/queue_base.h"
-#include "tensorflow/core/kernels/queue_op.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/thread_annotations.h"
-#include "tensorflow/core/platform/types.h"
namespace tensorflow {
-// Defines a FIFOQueueOp, which produces a Queue (specifically, one
-// backed by FIFOQueue) that persists across different graph
-// executions, and sessions. Running this op produces a single-element
-// tensor of handles to Queues in the corresponding device.
-class FIFOQueueOp : public TypedQueueOp {
- public:
- explicit FIFOQueueOp(OpKernelConstruction* context) : TypedQueueOp(context) {
- OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
- }
-
- private:
- Status CreateResource(QueueInterface** ret) override
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- FIFOQueue* queue = new FIFOQueue(capacity_, component_types_,
- component_shapes_, cinfo_.name());
- return CreateTypedQueue(queue, ret);
- }
-
- std::vector<TensorShape> component_shapes_;
- TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("FIFOQueue").Device(DEVICE_CPU), FIFOQueueOp);
REGISTER_KERNEL_BUILDER(Name("FIFOQueueV2").Device(DEVICE_CPU), FIFOQueueOp);
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index cede0b9dd6..1d0edb10b3 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -70,23 +70,25 @@ struct MklConvFwdParams {
memory::dims padding_left;
memory::dims padding_right;
- MklConvFwdParams(memory::dims src_dims,
- memory::dims filter_dims, memory::dims bias_dims,
- memory::dims dst_dims, memory::dims strides,
- memory::dims dilations, memory::dims padding_left,
- memory::dims padding_right) :
- src_dims(src_dims), filter_dims(filter_dims),
- bias_dims(bias_dims), dst_dims(dst_dims),
- strides(strides), dilations(dilations),
- padding_left(padding_left), padding_right(padding_right) {
- }
+ MklConvFwdParams(memory::dims src_dims, memory::dims filter_dims,
+ memory::dims bias_dims, memory::dims dst_dims,
+ memory::dims strides, memory::dims dilations,
+ memory::dims padding_left, memory::dims padding_right)
+ : src_dims(src_dims),
+ filter_dims(filter_dims),
+ bias_dims(bias_dims),
+ dst_dims(dst_dims),
+ strides(strides),
+ dilations(dilations),
+ padding_left(padding_left),
+ padding_right(padding_right) {}
};
template <typename T>
-class MklConv2DFwdPrimitive: public MklPrimitive {
+class MklConv2DFwdPrimitive : public MklPrimitive {
public:
- explicit MklConv2DFwdPrimitive(const MklConvFwdParams& convFwdDims) :
- cpu_engine_(engine::cpu, 0) {
+ explicit MklConv2DFwdPrimitive(const MklConvFwdParams& convFwdDims)
+ : cpu_engine_(engine::cpu, 0) {
context_.fwd_stream.reset(new stream(stream::kind::eager));
// create conv primitive
if (context_.conv_fwd == nullptr) {
@@ -101,8 +103,8 @@ class MklConv2DFwdPrimitive: public MklPrimitive {
// filter_data: input data buffer of filter (weights)
// bias_data: input data buffer of bias
// dst_data: output data buffer of dst
- void Execute(const T* src_data, const T* filter_data,
- const T* bias_data, const T* dst_data) {
+ void Execute(const T* src_data, const T* filter_data, const T* bias_data,
+ const T* dst_data) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.filter_mem->set_data_handle(
@@ -126,8 +128,7 @@ class MklConv2DFwdPrimitive: public MklPrimitive {
// src_data: input data buffer of src
// filter_data: input data buffer of filter (weights)
// dst_data: output data buffer of dst
- void Execute(const T* src_data, const T* filter_data,
- const T* dst_data) {
+ void Execute(const T* src_data, const T* filter_data, const T* dst_data) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.filter_mem->set_data_handle(
@@ -142,13 +143,9 @@ class MklConv2DFwdPrimitive: public MklPrimitive {
context_.dst_mem->set_data_handle(DummyData);
}
- memory::format GetSrcMemoryFormat() const {
- return context_.src_fmt;
- }
+ memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
- memory::format GetFilterMemoryFormat() const {
- return context_.filter_fmt;
- }
+ memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; }
std::shared_ptr<mkldnn::convolution_forward::primitive_desc>
GetPrimitiveDesc() const {
@@ -184,43 +181,50 @@ class MklConv2DFwdPrimitive: public MklPrimitive {
std::shared_ptr<mkldnn::stream> fwd_stream;
std::vector<mkldnn::primitive> fwd_primitives;
- ConvFwdContext() :
- src_fmt(memory::format::any), filter_fmt(memory::format::any),
- src_mem(nullptr), filter_mem(nullptr), bias_mem(nullptr),
- dst_mem(nullptr), fwd_desc(nullptr),
- src_md(nullptr), filter_md(nullptr), bias_md(nullptr),
- fwd_pd(nullptr), conv_fwd(nullptr), fwd_stream(nullptr) {
- }
+ ConvFwdContext()
+ : src_fmt(memory::format::any),
+ filter_fmt(memory::format::any),
+ src_mem(nullptr),
+ filter_mem(nullptr),
+ bias_mem(nullptr),
+ dst_mem(nullptr),
+ fwd_desc(nullptr),
+ src_md(nullptr),
+ filter_md(nullptr),
+ bias_md(nullptr),
+ fwd_pd(nullptr),
+ conv_fwd(nullptr),
+ fwd_stream(nullptr) {}
};
void Setup(const MklConvFwdParams& convFwdDims) {
// create memory descriptors for convolution data w/ no specified format
- context_.src_md.reset(new memory::desc({convFwdDims.src_dims},
- MklDnnType<T>(), memory::format::any));
+ context_.src_md.reset(new memory::desc(
+ {convFwdDims.src_dims}, MklDnnType<T>(), memory::format::any));
- context_.filter_md.reset(new memory::desc({convFwdDims.filter_dims},
- MklDnnType<T>(), memory::format::any));
+ context_.filter_md.reset(new memory::desc(
+ {convFwdDims.filter_dims}, MklDnnType<T>(), memory::format::any));
- context_.dst_md.reset(new memory::desc({convFwdDims.dst_dims},
- MklDnnType<T>(), memory::format::any));
+ context_.dst_md.reset(new memory::desc(
+ {convFwdDims.dst_dims}, MklDnnType<T>(), memory::format::any));
if (!convFwdDims.bias_dims.empty())
- context_.bias_md.reset(new memory::desc({convFwdDims.bias_dims},
- MklDnnType<T>(), memory::format::any));
+ context_.bias_md.reset(new memory::desc(
+ {convFwdDims.bias_dims}, MklDnnType<T>(), memory::format::any));
// create a convolution
if (!convFwdDims.bias_dims.empty()) {
- context_.fwd_desc.reset(new convolution_forward::desc(prop_kind::forward,
- convolution_direct, *context_.src_md, *context_.filter_md,
- *context_.bias_md, *context_.dst_md,
+ context_.fwd_desc.reset(new convolution_forward::desc(
+ prop_kind::forward, convolution_direct, *context_.src_md,
+ *context_.filter_md, *context_.bias_md, *context_.dst_md,
convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
convFwdDims.padding_right, padding_kind::zero));
} else {
- context_.fwd_desc.reset(new convolution_forward::desc(prop_kind::forward,
- convolution_direct, *context_.src_md, *context_.filter_md,
- *context_.dst_md, convFwdDims.strides, convFwdDims.dilations,
- convFwdDims.padding_left, convFwdDims.padding_right,
- padding_kind::zero));
+ context_.fwd_desc.reset(new convolution_forward::desc(
+ prop_kind::forward, convolution_direct, *context_.src_md,
+ *context_.filter_md, *context_.dst_md, convFwdDims.strides,
+ convFwdDims.dilations, convFwdDims.padding_left,
+ convFwdDims.padding_right, padding_kind::zero));
}
context_.fwd_pd.reset(new convolution_forward::primitive_desc(
@@ -234,24 +238,26 @@ class MklConv2DFwdPrimitive: public MklPrimitive {
context_.fwd_pd.get()->weights_primitive_desc().desc().data.format);
// create memory primitive based on dummy data
- context_.src_mem.reset(new memory(
- context_.fwd_pd.get()->src_primitive_desc(), DummyData));
- context_.filter_mem.reset(new memory(
- context_.fwd_pd.get()->weights_primitive_desc(), DummyData));
- context_.dst_mem.reset(new memory(
- context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
+ context_.src_mem.reset(
+ new memory(context_.fwd_pd.get()->src_primitive_desc(), DummyData));
+ context_.filter_mem.reset(
+ new memory(context_.fwd_pd.get()->weights_primitive_desc(), DummyData));
+ context_.dst_mem.reset(
+ new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
// create convolution primitive and add it to net
if (!convFwdDims.bias_dims.empty()) {
- context_.bias_mem.reset(new memory({{{convFwdDims.bias_dims},
- MklDnnType<T>(), memory::format::x}, cpu_engine_}, DummyData));
- context_.conv_fwd.reset(new convolution_forward(
- *context_.fwd_pd, *context_.src_mem, *context_.filter_mem,
- *context_.bias_mem, *context_.dst_mem));
+ context_.bias_mem.reset(new memory(
+ {{{convFwdDims.bias_dims}, MklDnnType<T>(), memory::format::x},
+ cpu_engine_},
+ DummyData));
+ context_.conv_fwd.reset(new convolution_forward(
+ *context_.fwd_pd, *context_.src_mem, *context_.filter_mem,
+ *context_.bias_mem, *context_.dst_mem));
} else {
- context_.conv_fwd.reset(new convolution_forward(
- *context_.fwd_pd, *context_.src_mem,
- *context_.filter_mem, *context_.dst_mem));
+ context_.conv_fwd.reset(
+ new convolution_forward(*context_.fwd_pd, *context_.src_mem,
+ *context_.filter_mem, *context_.dst_mem));
}
context_.fwd_primitives.push_back(*context_.conv_fwd);
@@ -266,19 +272,19 @@ template <typename T>
class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
static MklConv2DFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims) {
- MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr;
-
- // try to find a suitable one in pool
- conv2d_fwd = dynamic_cast<MklConv2DFwdPrimitive<T>*> (
- MklConv2DFwdPrimitiveFactory<T>::GetInstance().GetConv2DFwd(
- convFwdDims));
-
- if (conv2d_fwd == nullptr) {
- conv2d_fwd = new MklConv2DFwdPrimitive<T>(convFwdDims);
- MklConv2DFwdPrimitiveFactory<T>::GetInstance().SetConv2DFwd(
- convFwdDims, conv2d_fwd);
- }
- return conv2d_fwd;
+ MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr;
+
+ // try to find a suitable one in pool
+ conv2d_fwd = dynamic_cast<MklConv2DFwdPrimitive<T>*>(
+ MklConv2DFwdPrimitiveFactory<T>::GetInstance().GetConv2DFwd(
+ convFwdDims));
+
+ if (conv2d_fwd == nullptr) {
+ conv2d_fwd = new MklConv2DFwdPrimitive<T>(convFwdDims);
+ MklConv2DFwdPrimitiveFactory<T>::GetInstance().SetConv2DFwd(convFwdDims,
+ conv2d_fwd);
+ }
+ return conv2d_fwd;
}
private:
@@ -312,7 +318,7 @@ class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
return this->GetOp(key);
}
- void SetConv2DFwd(const MklConvFwdParams& convFwdDims, MklPrimitive *op) {
+ void SetConv2DFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) {
std::string key = CreateKey(convFwdDims);
this->SetOp(key, op);
}
@@ -865,22 +871,24 @@ class MklConv2DOp : public OpKernel {
dilations[kDilationW] -= 1;
// get a conv2d fwd from primitive pool
- MklConv2DFwdPrimitive<T> *conv2d_fwd = nullptr;
+ MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr;
if (biasEnabled) {
memory::dims bias_dims = {};
conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims);
MklConvFwdParams convFwdDims(src_dims, filter_dims, bias_dims,
- dst_dims_mkl_order, strides, dilations, padding_left, padding_right);
+ dst_dims_mkl_order, strides, dilations,
+ padding_left, padding_right);
conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims);
} else {
MklConvFwdParams convFwdDims(src_dims, filter_dims, NONE_DIMS,
- dst_dims_mkl_order, strides, dilations, padding_left, padding_right);
+ dst_dims_mkl_order, strides, dilations,
+ padding_left, padding_right);
conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims);
}
// allocate output tensors output_tensor and filter_out_tensor
- std::shared_ptr<mkldnn::convolution_forward::primitive_desc>
- conv_fwd_pd = conv2d_fwd->GetPrimitiveDesc();
+ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_fwd_pd =
+ conv2d_fwd->GetPrimitiveDesc();
AllocateOutputTensor(context, *conv_fwd_pd,
dst_dims_mkl_order, tf_fmt, &dst_tensor);
Tensor* filter_out_tensor = nullptr;
@@ -892,26 +900,24 @@ class MklConv2DOp : public OpKernel {
// check whether src/filter need reorder
std::vector<primitive> net;
- T *src_data = nullptr;
+ T* src_data = nullptr;
if (src_md.data.format != conv2d_fwd->GetSrcMemoryFormat()) {
src.SetUsrMem(src_md, &src_tensor);
- src.CheckReorderToOpMem(
- conv_fwd_pd.get()->src_primitive_desc(), &net);
+ src.CheckReorderToOpMem(conv_fwd_pd.get()->src_primitive_desc(), &net);
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
} else {
- src_data = static_cast<T*>(const_cast<T*>(
- src_tensor.flat<T>().data()));
+ src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
}
- T *filter_data = nullptr;
+ T* filter_data = nullptr;
if (filter_md.data.format != conv2d_fwd->GetFilterMemoryFormat()) {
filter.SetUsrMem(filter_md, &filter_tensor);
- filter.CheckReorderToOpMem(
- conv_fwd_pd.get()->weights_primitive_desc(),
- filter.GetTensorBuffer(filter_out_tensor), &net);
+ filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_primitive_desc(),
+ filter.GetTensorBuffer(filter_out_tensor),
+ &net);
filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle());
} else {
- filter_data = static_cast<T*>(const_cast<T*>(
- filter_tensor.flat<T>().data()));
+ filter_data =
+ static_cast<T*>(const_cast<T*>(filter_tensor.flat<T>().data()));
}
stream(stream::kind::eager).submit(net).wait();
diff --git a/tensorflow/core/kernels/queue_op.cc b/tensorflow/core/kernels/queue_op.cc
new file mode 100644
index 0000000000..53f431ef3c
--- /dev/null
+++ b/tensorflow/core/kernels/queue_op.cc
@@ -0,0 +1,367 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/queue_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/queue_interface.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+QueueOp::QueueOp(OpKernelConstruction* context) : ResourceOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_));
+ if (capacity_ < 0) {
+ capacity_ = QueueBase::kUnbounded;
+ }
+ OP_REQUIRES_OK(context,
+ context->GetAttr("component_types", &component_types_));
+}
+
+void QueueOp::Compute(OpKernelContext* context) {
+ ResourceOpKernel<QueueInterface>::Compute(context);
+ mutex_lock l(mu_);
+ if (resource_ && context->track_allocations()) {
+ context->record_persistent_memory_allocation(resource_->MemoryUsed());
+ }
+}
+
+Status QueueOp::VerifyResource(QueueInterface* queue) {
+ return queue->MatchesNodeDef(def());
+}
+
+QueueOpKernel::QueueOpKernel(OpKernelConstruction* context)
+ : AsyncOpKernel(context) {}
+
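+// Resolves the queue from input 0, which is either a DT_RESOURCE handle or a
+// Ref(string) queue handle, then runs the subclass ComputeAsync and unrefs
+// the queue once the wrapped callback is invoked.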
+void QueueOpKernel::ComputeAsync(OpKernelContext* ctx, DoneCallback callback) {
+ QueueInterface* queue;
+ if (ctx->input_dtype(0) == DT_RESOURCE) {
+ OP_REQUIRES_OK_ASYNC(
+ ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &queue), callback);
+ } else {
+ OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue),
+ callback);
+ }
+ ComputeAsync(ctx, queue, [callback, queue]() {
+ queue->Unref();
+ callback();
+ });
+}
+
+QueueAccessOpKernel::QueueAccessOpKernel(OpKernelConstruction* context)
+ : QueueOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
+ // TODO(keveman): Enable timeout.
+ OP_REQUIRES(context, timeout_ == -1,
+ errors::InvalidArgument("Timeout not supported yet."));
+}
+
+// Defines an EnqueueOp, the execution of which enqueues a tuple of
+// tensors in the given Queue.
+//
+// The op has 1 + k inputs, where k is the number of components in the
+// tuples stored in the given Queue:
+// - Input 0: queue handle.
+// - Input 1: 0th element of the tuple.
+// - ...
+// - Input (1+k): kth element of the tuple.
+EnqueueOp::EnqueueOp(OpKernelConstruction* context)
+ : QueueAccessOpKernel(context) {}
+
+void EnqueueOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ DataTypeVector expected_inputs;
+ if (ctx->input_dtype(0) == DT_RESOURCE) {
+ expected_inputs.push_back(DT_RESOURCE);
+ } else {
+ expected_inputs.push_back(DT_STRING_REF);
+ }
+ for (DataType dt : queue->component_dtypes()) {
+ expected_inputs.push_back(dt);
+ }
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), callback);
+
+ QueueInterface::Tuple tuple;
+ OpInputList components;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
+ callback);
+ for (const Tensor& Tcomponent : components) {
+ tuple.push_back(Tcomponent);
+ }
+
+ OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback);
+ queue->TryEnqueue(tuple, ctx, callback);
+}
+
+// Defines an EnqueueManyOp, the execution of which slices each
+// component of a tuple of tensors along the 0th dimension, and
+// enqueues tuples of slices in the given Queue.
+//
+// The op has 1 + k inputs, where k is the number of components in the
+// tuples stored in the given Queue:
+// - Input 0: queue handle.
+// - Input 1: 0th element of the tuple.
+// - ...
+// - Input (1+k): kth element of the tuple.
+//
+// N.B. All tuple components must have the same size in the 0th
+// dimension.
+EnqueueManyOp::EnqueueManyOp(OpKernelConstruction* context)
+ : QueueAccessOpKernel(context) {}
+
+void EnqueueManyOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ DataTypeVector expected_inputs;
+ if (ctx->input_dtype(0) == DT_RESOURCE) {
+ expected_inputs.push_back(DT_RESOURCE);
+ } else {
+ expected_inputs.push_back(DT_STRING_REF);
+ }
+ for (DataType dt : queue->component_dtypes()) {
+ expected_inputs.push_back(dt);
+ }
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), callback);
+
+ QueueInterface::Tuple tuple;
+ OpInputList components;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
+ callback);
+ for (const Tensor& Tcomponent : components) {
+ tuple.push_back(Tcomponent);
+ }
+
+ OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback);
+ queue->TryEnqueueMany(tuple, ctx, callback);
+}
+
+EnqueueManyOp::~EnqueueManyOp() = default;
+
+// Defines a DequeueOp, the execution of which dequeues a tuple of
+// tensors from the given Queue.
+//
+// The op has one input, which is the handle of the appropriate
+// Queue. The op has k outputs, where k is the number of components in
+// the tuples stored in the given Queue, and output i is the ith
+// component of the dequeued tuple.
+DequeueOp::DequeueOp(OpKernelConstruction* context)
+ : QueueAccessOpKernel(context) {}
+
+void DequeueOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ if (ctx->input_dtype(0) == DT_RESOURCE) {
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->MatchSignature({DT_RESOURCE}, queue->component_dtypes()),
+ callback);
+ } else {
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()),
+ callback);
+ }
+
+ queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) {
+ if (!ctx->status().ok()) {
+ callback();
+ return;
+ }
+ OpOutputList output_components;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->output_list("components", &output_components), callback);
+ for (int i = 0; i < ctx->num_outputs(); ++i) {
+ output_components.set(i, tuple[i]);
+ }
+ callback();
+ });
+}
+
+DequeueOp::~DequeueOp() = default;
+
+// Defines a DequeueManyOp, the execution of which concatenates the
+// requested number of elements from the given Queue along the 0th
+// dimension, and emits the result as a single tuple of tensors.
+//
+// The op has two inputs:
+// - Input 0: the handle to a queue.
+// - Input 1: the number of elements to dequeue.
+//
+// The op has k outputs, where k is the number of components in the
+// tuples stored in the given Queue, and output i is the ith component
+// of the dequeued tuple.
+DequeueManyOp::DequeueManyOp(OpKernelConstruction* context)
+ : QueueAccessOpKernel(context) {}
+
+void DequeueManyOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ const Tensor& Tnum_elements = ctx->input(1);
+ int32 num_elements = Tnum_elements.flat<int32>()(0);
+
+ OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
+ errors::InvalidArgument("DequeueManyOp requested ",
+ num_elements, " < 0 elements"),
+ callback);
+
+ if (ctx->input_dtype(0) == DT_RESOURCE) {
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ ctx->MatchSignature({DT_RESOURCE, DT_INT32}, queue->component_dtypes()),
+ callback);
+ } else {
+ OP_REQUIRES_OK_ASYNC(ctx,
+ ctx->MatchSignature({DT_STRING_REF, DT_INT32},
+ queue->component_dtypes()),
+ callback);
+ }
+
+ queue->TryDequeueMany(
+ num_elements, ctx, false /* allow_small_batch */,
+ [ctx, callback](const QueueInterface::Tuple& tuple) {
+ if (!ctx->status().ok()) {
+ callback();
+ return;
+ }
+ OpOutputList output_components;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->output_list("components", &output_components), callback);
+ for (int i = 0; i < ctx->num_outputs(); ++i) {
+ output_components.set(i, tuple[i]);
+ }
+ callback();
+ });
+}
+
+DequeueManyOp::~DequeueManyOp() = default;
+
+// Defines a DequeueUpToOp, the execution of which concatenates the
+// requested number of elements from the given Queue along the 0th
+// dimension, and emits the result as a single tuple of tensors.
+//
+// The difference between this op and DequeueMany is the handling when
+// the Queue is closed. While the DequeueMany op will return an error
+// when there are fewer than num_elements elements left in the closed
+// queue, this op will return between 1 and
+// min(num_elements, elements_remaining_in_queue), and will not block.
+// If there are no elements left, then the standard DequeueMany error
+// is returned.
+//
+// This op only works if the underlying Queue implementation accepts
+// the allow_small_batch = true parameter to TryDequeueMany.
+// If it does not, an errors::Unimplemented exception is returned.
+//
+// The op has two inputs:
+// - Input 0: the handle to a queue.
+// - Input 1: the number of elements to dequeue.
+//
+// The op has k outputs, where k is the number of components in the
+// tuples stored in the given Queue, and output i is the ith component
+// of the dequeued tuple.
+//
+// The op has one attribute: allow_small_batch. If the Queue supports
+// it, setting this to true causes the queue to return smaller
+// (possibly zero length) batches when it is closed, up to however
+// many elements are available when the op executes. In this case,
+// the Queue does not block when closed.
+DequeueUpToOp::DequeueUpToOp(OpKernelConstruction* context)
+ : QueueAccessOpKernel(context) {}
+
+void DequeueUpToOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ const Tensor& Tnum_elements = ctx->input(1);
+ int32 num_elements = Tnum_elements.flat<int32>()(0);
+
+ OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
+ errors::InvalidArgument("DequeueUpToOp requested ",
+ num_elements, " < 0 elements"),
+ callback);
+
+ if (ctx->input_dtype(0) == DT_RESOURCE) {
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ ctx->MatchSignature({DT_RESOURCE, DT_INT32}, queue->component_dtypes()),
+ callback);
+ } else {
+ OP_REQUIRES_OK_ASYNC(ctx,
+ ctx->MatchSignature({DT_STRING_REF, DT_INT32},
+ queue->component_dtypes()),
+ callback);
+ }
+
+ queue->TryDequeueMany(
+ num_elements, ctx, true /* allow_small_batch */,
+ [ctx, callback](const QueueInterface::Tuple& tuple) {
+ if (!ctx->status().ok()) {
+ callback();
+ return;
+ }
+ OpOutputList output_components;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->output_list("components", &output_components), callback);
+ for (int i = 0; i < ctx->num_outputs(); ++i) {
+ output_components.set(i, tuple[i]);
+ }
+ callback();
+ });
+}
+
+DequeueUpToOp::~DequeueUpToOp() = default;
+
+// Defines a QueueCloseOp, which closes the given Queue. Closing a
+// Queue signals that no more elements will be enqueued in it.
+//
+// The op has one input, which is the handle of the appropriate Queue.
+QueueCloseOp::QueueCloseOp(OpKernelConstruction* context)
+ : QueueOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
+ &cancel_pending_enqueues_));
+}
+
+void QueueCloseOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ queue->Close(ctx, cancel_pending_enqueues_, callback);
+}
+
+// Defines a QueueSizeOp, which computes the number of elements in the
+// given Queue, and emits it as an output tensor.
+//
+// The op has one input, which is the handle of the appropriate Queue;
+// and one output, which is a single-element tensor containing the current
+// size of that Queue.
+QueueSizeOp::QueueSizeOp(OpKernelConstruction* context)
+ : QueueOpKernel(context) {}
+
+void QueueSizeOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ Tensor* Tqueue_size = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size));
+ Tqueue_size->flat<int32>().setConstant(queue->size());
+ callback();
+}
+
+QueueIsClosedOp::QueueIsClosedOp(OpKernelConstruction* context)
+ : QueueOpKernel(context) {}
+
+void QueueIsClosedOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ Tensor* Tqueue_is_closed = nullptr;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({}), &Tqueue_is_closed));
+ Tqueue_is_closed->flat<bool>().setConstant(queue->is_closed());
+ callback();
+}
+
+} // namespace tensorflow
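To make the DequeueMany / DequeueUpTo contrast in the kernel comments above concrete, here is a minimal graph-mode sketch using the TF 1.x Python queue wrapper (`tf.FIFOQueue`); the capacity, values, and shapes are illustrative only, not taken from this diff.

```python
import tensorflow as tf

# Scalar int32 elements; fully-specified shapes are needed for dequeue_many.
q = tf.FIFOQueue(capacity=10, dtypes=[tf.int32], shapes=[[]])
enqueue = q.enqueue_many([[1, 2, 3]])  # slices along dim 0 -> three elements
close = q.close()                      # signals that no more elements arrive

dequeue_up_to = q.dequeue_up_to(5)     # may return fewer once the queue closes
dequeue_many = q.dequeue_many(5)       # needs exactly 5 elements

with tf.Session() as sess:
    sess.run(enqueue)
    sess.run(close)
    print(sess.run(dequeue_up_to))     # [1 2 3]: the remaining elements
    # sess.run(dequeue_many) would raise OutOfRangeError here, because fewer
    # than 5 elements were left when the closed queue ran out.
```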
diff --git a/tensorflow/core/kernels/queue_op.h b/tensorflow/core/kernels/queue_op.h
index 6c19f9841c..2efd838a5f 100644
--- a/tensorflow/core/kernels/queue_op.h
+++ b/tensorflow/core/kernels/queue_op.h
@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_QUEUE_OP_H_
-#define TENSORFLOW_KERNELS_QUEUE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_
#include <deque>
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
@@ -32,22 +33,9 @@ namespace tensorflow {
// Defines a QueueOp, an abstract class for Queue construction ops.
class QueueOp : public ResourceOpKernel<QueueInterface> {
public:
- QueueOp(OpKernelConstruction* context) : ResourceOpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_));
- if (capacity_ < 0) {
- capacity_ = QueueBase::kUnbounded;
- }
- OP_REQUIRES_OK(context,
- context->GetAttr("component_types", &component_types_));
- }
+ QueueOp(OpKernelConstruction* context);
- void Compute(OpKernelContext* context) override {
- ResourceOpKernel<QueueInterface>::Compute(context);
- mutex_lock l(mu_);
- if (resource_ && context->track_allocations()) {
- context->record_persistent_memory_allocation(resource_->MemoryUsed());
- }
- }
+ void Compute(OpKernelContext* context) override;
protected:
// Variables accessible by subclasses
@@ -55,9 +43,7 @@ class QueueOp : public ResourceOpKernel<QueueInterface> {
DataTypeVector component_types_;
private:
- Status VerifyResource(QueueInterface* queue) override {
- return queue->MatchesNodeDef(def());
- }
+ Status VerifyResource(QueueInterface* queue) override;
};
class TypedQueueOp : public QueueOp {
@@ -75,6 +61,211 @@ class TypedQueueOp : public QueueOp {
}
};
+// Queue manipulator kernels
+
+class QueueOpKernel : public AsyncOpKernel {
+ public:
+ explicit QueueOpKernel(OpKernelConstruction* context);
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final;
+
+ protected:
+ virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) = 0;
+};
+
+class QueueAccessOpKernel : public QueueOpKernel {
+ public:
+ explicit QueueAccessOpKernel(OpKernelConstruction* context);
+
+ protected:
+ int64 timeout_;
+};
+
+// Defines an EnqueueOp, the execution of which enqueues a tuple of
+// tensors in the given Queue.
+//
+// The op has 1 + k inputs, where k is the number of components in the
+// tuples stored in the given Queue:
+// - Input 0: queue handle.
+// - Input 1: 0th element of the tuple.
+// - ...
+// - Input (1+k): kth element of the tuple.
+class EnqueueOp : public QueueAccessOpKernel {
+ public:
+ explicit EnqueueOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp);
+};
+
+// Defines an EnqueueManyOp, the execution of which slices each
+// component of a tuple of tensors along the 0th dimension, and
+// enqueues tuples of slices in the given Queue.
+//
+// The op has 1 + k inputs, where k is the number of components in the
+// tuples stored in the given Queue:
+// - Input 0: queue handle.
+// - Input 1: 0th element of the tuple.
+// - ...
+// - Input (1+k): kth element of the tuple.
+//
+// N.B. All tuple components must have the same size in the 0th
+// dimension.
+class EnqueueManyOp : public QueueAccessOpKernel {
+ public:
+ explicit EnqueueManyOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ ~EnqueueManyOp() override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp);
+};
+
+// Defines a DequeueOp, the execution of which dequeues a tuple of
+// tensors from the given Queue.
+//
+// The op has one input, which is the handle of the appropriate
+// Queue. The op has k outputs, where k is the number of components in
+// the tuples stored in the given Queue, and output i is the ith
+// component of the dequeued tuple.
+class DequeueOp : public QueueAccessOpKernel {
+ public:
+ explicit DequeueOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ ~DequeueOp() override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp);
+};
+
+// Defines a DequeueManyOp, the execution of which concatenates the
+// requested number of elements from the given Queue along the 0th
+// dimension, and emits the result as a single tuple of tensors.
+//
+// The op has two inputs:
+// - Input 0: the handle to a queue.
+// - Input 1: the number of elements to dequeue.
+//
+// The op has k outputs, where k is the number of components in the
+// tuples stored in the given Queue, and output i is the ith component
+// of the dequeued tuple.
+class DequeueManyOp : public QueueAccessOpKernel {
+ public:
+ explicit DequeueManyOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ ~DequeueManyOp() override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp);
+};
+
+// Defines a DequeueUpToOp, the execution of which concatenates the
+// requested number of elements from the given Queue along the 0th
+// dimension, and emits the result as a single tuple of tensors.
+//
+// The difference between this op and DequeueMany is the handling when
+// the Queue is closed. While the DequeueMany op will return an error
+// when there are fewer than num_elements elements left in the closed
+// queue, this op will return between 1 and
+// min(num_elements, elements_remaining_in_queue), and will not block.
+// If there are no elements left, then the standard DequeueMany error
+// is returned.
+//
+// This op only works if the underlying Queue implementation accepts
+// the allow_small_batch = true parameter to TryDequeueMany.
+// If it does not, an errors::Unimplemented exception is returned.
+//
+// The op has two inputs:
+// - Input 0: the handle to a queue.
+// - Input 1: the number of elements to dequeue.
+//
+// The op has k outputs, where k is the number of components in the
+// tuples stored in the given Queue, and output i is the ith component
+// of the dequeued tuple.
+//
+// The op has one attribute: allow_small_batch. If the Queue supports
+// it, setting this to true causes the queue to return smaller
+// (possibly zero length) batches when it is closed, up to however
+// many elements are available when the op executes. In this case,
+// the Queue does not block when closed.
+class DequeueUpToOp : public QueueAccessOpKernel {
+ public:
+ explicit DequeueUpToOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ ~DequeueUpToOp() override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp);
+};
+
+// Defines a QueueCloseOp, which closes the given Queue. Closing a
+// Queue signals that no more elements will be enqueued in it.
+//
+// The op has one input, which is the handle of the appropriate Queue.
+class QueueCloseOp : public QueueOpKernel {
+ public:
+ explicit QueueCloseOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ private:
+ bool cancel_pending_enqueues_;
+ TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp);
+};
+
+// Defines a QueueSizeOp, which computes the number of elements in the
+// given Queue, and emits it as an output tensor.
+//
+// The op has one input, which is the handle of the appropriate Queue;
+// and one output, which is a single-element tensor containing the current
+// size of that Queue.
+class QueueSizeOp : public QueueOpKernel {
+ public:
+ explicit QueueSizeOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp);
+};
+
+class QueueIsClosedOp : public QueueOpKernel {
+ public:
+ explicit QueueIsClosedOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(QueueIsClosedOp);
+};
+
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_QUEUE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_
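The QueueCloseOp, QueueSizeOp, and QueueIsClosedOp kernels declared above back the `close()`, `size()`, and (in later 1.x releases) `is_closed()` methods on the Python queue classes. A small hedged sketch, again assuming the TF 1.x `tf.FIFOQueue` wrapper:

```python
import tensorflow as tf

q = tf.FIFOQueue(capacity=4, dtypes=[tf.float32], shapes=[[]])
enqueue = q.enqueue(1.0)                        # EnqueueOp
size = q.size()                                 # QueueSizeOp -> scalar int32
close = q.close(cancel_pending_enqueues=False)  # QueueCloseOp

with tf.Session() as sess:
    sess.run(enqueue)
    print(sess.run(size))   # 1
    sess.run(close)
    # After close(), further enqueues fail; elements already queued can
    # still be dequeued until the queue drains.
```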
diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc
index 46a02854d7..c4d404259b 100644
--- a/tensorflow/core/kernels/queue_ops.cc
+++ b/tensorflow/core/kernels/queue_ops.cc
@@ -13,437 +13,44 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// See docs in ../ops/data_flow_ops.cc.
-
#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
-class QueueOpKernel : public AsyncOpKernel {
- public:
- explicit QueueOpKernel(OpKernelConstruction* context)
- : AsyncOpKernel(context) {}
-
- void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final {
- QueueInterface* queue;
- if (ctx->input_dtype(0) == DT_RESOURCE) {
- OP_REQUIRES_OK_ASYNC(
- ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &queue), callback);
- } else {
- OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue),
- callback);
- }
- ComputeAsync(ctx, queue, [callback, queue]() {
- queue->Unref();
- callback();
- });
- }
-
- protected:
- virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) = 0;
-};
-
-class QueueAccessOpKernel : public QueueOpKernel {
- public:
- explicit QueueAccessOpKernel(OpKernelConstruction* context)
- : QueueOpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
- // TODO(keveman): Enable timeout.
- OP_REQUIRES(context, timeout_ == -1,
- errors::InvalidArgument("Timeout not supported yet."));
- }
-
- protected:
- int64 timeout_;
-};
-
-// Defines an EnqueueOp, the execution of which enqueues a tuple of
-// tensors in the given Queue.
-//
-// The op has 1 + k inputs, where k is the number of components in the
-// tuples stored in the given Queue:
-// - Input 0: queue handle.
-// - Input 1: 0th element of the tuple.
-// - ...
-// - Input (1+k): kth element of the tuple.
-class EnqueueOp : public QueueAccessOpKernel {
- public:
- explicit EnqueueOp(OpKernelConstruction* context)
- : QueueAccessOpKernel(context) {}
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- DataTypeVector expected_inputs;
- if (ctx->input_dtype(0) == DT_RESOURCE) {
- expected_inputs.push_back(DT_RESOURCE);
- } else {
- expected_inputs.push_back(DT_STRING_REF);
- }
- for (DataType dt : queue->component_dtypes()) {
- expected_inputs.push_back(dt);
- }
- OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}),
- callback);
-
- QueueInterface::Tuple tuple;
- OpInputList components;
- OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
- callback);
- for (const Tensor& Tcomponent : components) {
- tuple.push_back(Tcomponent);
- }
-
- OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback);
- queue->TryEnqueue(tuple, ctx, callback);
- }
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueEnqueue").Device(DEVICE_CPU), EnqueueOp);
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueV2").Device(DEVICE_CPU), EnqueueOp);
-// Defines an EnqueueManyOp, the execution of which slices each
-// component of a tuple of tensors along the 0th dimension, and
-// enqueues tuples of slices in the given Queue.
-//
-// The op has 1 + k inputs, where k is the number of components in the
-// tuples stored in the given Queue:
-// - Input 0: queue handle.
-// - Input 1: 0th element of the tuple.
-// - ...
-// - Input (1+k): kth element of the tuple.
-//
-// N.B. All tuple components must have the same size in the 0th
-// dimension.
-class EnqueueManyOp : public QueueAccessOpKernel {
- public:
- explicit EnqueueManyOp(OpKernelConstruction* context)
- : QueueAccessOpKernel(context) {}
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- DataTypeVector expected_inputs;
- if (ctx->input_dtype(0) == DT_RESOURCE) {
- expected_inputs.push_back(DT_RESOURCE);
- } else {
- expected_inputs.push_back(DT_STRING_REF);
- }
- for (DataType dt : queue->component_dtypes()) {
- expected_inputs.push_back(dt);
- }
- OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}),
- callback);
-
- QueueInterface::Tuple tuple;
- OpInputList components;
- OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
- callback);
- for (const Tensor& Tcomponent : components) {
- tuple.push_back(Tcomponent);
- }
-
- OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback);
- queue->TryEnqueueMany(tuple, ctx, callback);
- }
-
- ~EnqueueManyOp() override {}
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueMany").Device(DEVICE_CPU),
EnqueueManyOp);
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueManyV2").Device(DEVICE_CPU),
EnqueueManyOp);
-// Defines a DequeueOp, the execution of which dequeues a tuple of
-// tensors from the given Queue.
-//
-// The op has one input, which is the handle of the appropriate
-// Queue. The op has k outputs, where k is the number of components in
-// the tuples stored in the given Queue, and output i is the ith
-// component of the dequeued tuple.
-class DequeueOp : public QueueAccessOpKernel {
- public:
- explicit DequeueOp(OpKernelConstruction* context)
- : QueueAccessOpKernel(context) {}
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- if (ctx->input_dtype(0) == DT_RESOURCE) {
- OP_REQUIRES_OK_ASYNC(
- ctx, ctx->MatchSignature({DT_RESOURCE}, queue->component_dtypes()),
- callback);
- } else {
- OP_REQUIRES_OK_ASYNC(
- ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()),
- callback);
- }
-
- queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) {
- if (!ctx->status().ok()) {
- callback();
- return;
- }
- OpOutputList output_components;
- OP_REQUIRES_OK_ASYNC(
- ctx, ctx->output_list("components", &output_components), callback);
- for (int i = 0; i < ctx->num_outputs(); ++i) {
- output_components.set(i, tuple[i]);
- }
- callback();
- });
- }
-
- ~DequeueOp() override {}
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueDequeue").Device(DEVICE_CPU), DequeueOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueV2").Device(DEVICE_CPU), DequeueOp);
-// Defines a DequeueManyOp, the execution of which concatenates the
-// requested number of elements from the given Queue along the 0th
-// dimension, and emits the result as a single tuple of tensors.
-//
-// The op has two inputs:
-// - Input 0: the handle to a queue.
-// - Input 1: the number of elements to dequeue.
-//
-// The op has k outputs, where k is the number of components in the
-// tuples stored in the given Queue, and output i is the ith component
-// of the dequeued tuple.
-class DequeueManyOp : public QueueAccessOpKernel {
- public:
- explicit DequeueManyOp(OpKernelConstruction* context)
- : QueueAccessOpKernel(context) {}
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- const Tensor& Tnum_elements = ctx->input(1);
- int32 num_elements = Tnum_elements.flat<int32>()(0);
-
- OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
- errors::InvalidArgument("DequeueManyOp requested ",
- num_elements, " < 0 elements"),
- callback);
-
- if (ctx->input_dtype(0) == DT_RESOURCE) {
- OP_REQUIRES_OK_ASYNC(ctx,
- ctx->MatchSignature({DT_RESOURCE, DT_INT32},
- queue->component_dtypes()),
- callback);
- } else {
- OP_REQUIRES_OK_ASYNC(ctx,
- ctx->MatchSignature({DT_STRING_REF, DT_INT32},
- queue->component_dtypes()),
- callback);
- }
-
- queue->TryDequeueMany(
- num_elements, ctx, false /* allow_small_batch */,
- [ctx, callback](const QueueInterface::Tuple& tuple) {
- if (!ctx->status().ok()) {
- callback();
- return;
- }
- OpOutputList output_components;
- OP_REQUIRES_OK_ASYNC(
- ctx, ctx->output_list("components", &output_components),
- callback);
- for (int i = 0; i < ctx->num_outputs(); ++i) {
- output_components.set(i, tuple[i]);
- }
- callback();
- });
- }
-
- ~DequeueManyOp() override {}
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueDequeueMany").Device(DEVICE_CPU),
DequeueManyOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueManyV2").Device(DEVICE_CPU),
DequeueManyOp);
-// Defines a DequeueUpToOp, the execution of which concatenates the
-// requested number of elements from the given Queue along the 0th
-// dimension, and emits the result as a single tuple of tensors.
-//
-// The difference between this op and DequeueMany is the handling when
-// the Queue is closed. While the DequeueMany op will return if there
-// an error when there are less than num_elements elements left in the
-// closed queue, this op will return between 1 and
-// min(num_elements, elements_remaining_in_queue), and will not block.
-// If there are no elements left, then the standard DequeueMany error
-// is returned.
-//
-// This op only works if the underlying Queue implementation accepts
-// the allow_small_batch = true parameter to TryDequeueMany.
-// If it does not, an errors::Unimplemented exception is returned.
-//
-// The op has two inputs:
-// - Input 0: the handle to a queue.
-// - Input 1: the number of elements to dequeue.
-//
-// The op has k outputs, where k is the number of components in the
-// tuples stored in the given Queue, and output i is the ith component
-// of the dequeued tuple.
-//
-// The op has one attribute: allow_small_batch. If the Queue supports
-// it, setting this to true causes the queue to return smaller
-// (possibly zero length) batches when it is closed, up to however
-// many elements are available when the op executes. In this case,
-// the Queue does not block when closed.
-class DequeueUpToOp : public QueueAccessOpKernel {
- public:
- explicit DequeueUpToOp(OpKernelConstruction* context)
- : QueueAccessOpKernel(context) {}
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- const Tensor& Tnum_elements = ctx->input(1);
- int32 num_elements = Tnum_elements.flat<int32>()(0);
-
- OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
- errors::InvalidArgument("DequeueUpToOp requested ",
- num_elements, " < 0 elements"),
- callback);
-
- if (ctx->input_dtype(0) == DT_RESOURCE) {
- OP_REQUIRES_OK_ASYNC(ctx,
- ctx->MatchSignature({DT_RESOURCE, DT_INT32},
- queue->component_dtypes()),
- callback);
- } else {
- OP_REQUIRES_OK_ASYNC(ctx,
- ctx->MatchSignature({DT_STRING_REF, DT_INT32},
- queue->component_dtypes()),
- callback);
- }
-
- queue->TryDequeueMany(
- num_elements, ctx, true /* allow_small_batch */,
- [ctx, callback](const QueueInterface::Tuple& tuple) {
- if (!ctx->status().ok()) {
- callback();
- return;
- }
- OpOutputList output_components;
- OP_REQUIRES_OK_ASYNC(
- ctx, ctx->output_list("components", &output_components),
- callback);
- for (int i = 0; i < ctx->num_outputs(); ++i) {
- output_components.set(i, tuple[i]);
- }
- callback();
- });
- }
-
- ~DequeueUpToOp() override {}
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpTo").Device(DEVICE_CPU),
DequeueUpToOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpToV2").Device(DEVICE_CPU),
DequeueUpToOp);
-// Defines a QueueCloseOp, which closes the given Queue. Closing a
-// Queue signals that no more elements will be enqueued in it.
-//
-// The op has one input, which is the handle of the appropriate Queue.
-class QueueCloseOp : public QueueOpKernel {
- public:
- explicit QueueCloseOp(OpKernelConstruction* context)
- : QueueOpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
- &cancel_pending_enqueues_));
- }
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- queue->Close(ctx, cancel_pending_enqueues_, callback);
- }
-
- private:
- bool cancel_pending_enqueues_;
- TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueClose").Device(DEVICE_CPU), QueueCloseOp);
REGISTER_KERNEL_BUILDER(Name("QueueCloseV2").Device(DEVICE_CPU), QueueCloseOp);
-// Defines a QueueSizeOp, which computes the number of elements in the
-// given Queue, and emits it as an output tensor.
-//
-// The op has one input, which is the handle of the appropriate Queue;
-// and one output, which is a single-element tensor containing the current
-// size of that Queue.
-class QueueSizeOp : public QueueOpKernel {
- public:
- explicit QueueSizeOp(OpKernelConstruction* context)
- : QueueOpKernel(context) {}
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- Tensor* Tqueue_size = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size));
- Tqueue_size->flat<int32>().setConstant(queue->size());
- callback();
- }
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueSize").Device(DEVICE_CPU), QueueSizeOp);
REGISTER_KERNEL_BUILDER(Name("QueueSizeV2").Device(DEVICE_CPU), QueueSizeOp);
-class QueueIsClosedOp : public QueueOpKernel {
- public:
- explicit QueueIsClosedOp(OpKernelConstruction* context)
- : QueueOpKernel(context) {}
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- Tensor* Tqueue_is_closed = nullptr;
- OP_REQUIRES_OK(ctx,
- ctx->allocate_output(0, TensorShape({}), &Tqueue_is_closed));
- Tqueue_is_closed->flat<bool>().setConstant(queue->is_closed());
- callback();
- }
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(QueueIsClosedOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueIsClosed").Device(DEVICE_CPU),
QueueIsClosedOp);
REGISTER_KERNEL_BUILDER(Name("QueueIsClosedV2").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h
index 15004ae4df..2da83a0288 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.h
+++ b/tensorflow/core/kernels/segment_reduction_ops.h
@@ -24,6 +24,12 @@ limitations under the License.
// non-GPU targets. This only breaks in clang, because it's more strict for
// template code and CudaAtomicMax is used in template context.
+// This file requires the following include because it uses CudaAtomicMax:
+// #include "tensorflow/core/util/cuda_kernel_helper.h"
+
+// Unfortunately we can't add the #include, since it breaks compilation for
+// non-GPU targets. This only breaks in clang, because it's more strict for
+// template code and CudaAtomicMax is used in template context.
// This file requires the following include because it uses CudaAtomicMax:
// #include "tensorflow/core/util/cuda_kernel_helper.h"
diff --git a/tensorflow/core/kernels/serialize_sparse_op.cc b/tensorflow/core/kernels/serialize_sparse_op.cc
index 4ad653601a..4fea57e6b7 100644
--- a/tensorflow/core/kernels/serialize_sparse_op.cc
+++ b/tensorflow/core/kernels/serialize_sparse_op.cc
@@ -559,16 +559,4 @@ REGISTER_KERNEL_BUILDER(Name("DeserializeSparse")
REGISTER_KERNEL_BUILDER(Name("DeserializeManySparse").Device(DEVICE_CPU),
DeserializeSparseOp<string>)
-template <>
-Status DeserializeSparseOp<Variant>::Deserialize(const Variant& serialized,
- Tensor* result) {
- *result = *serialized.get<Tensor>();
- return Status::OK();
-}
-
-REGISTER_KERNEL_BUILDER(Name("DeserializeSparse")
- .Device(DEVICE_CPU)
- .TypeConstraint<Variant>("Tserialized"),
- DeserializeSparseOp<Variant>)
-
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index 37803ec775..5aa5d20b1a 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -735,6 +735,7 @@ class TensorArrayPackOrGatherOp : public OpKernel {
TensorArrayPackOrGatherOp<CPUDevice, type, false /* LEGACY_PACK */>);
TF_CALL_POD_STRING_TYPES(REGISTER_GATHER_AND_PACK);
+TF_CALL_variant(REGISTER_GATHER_AND_PACK);
REGISTER_GATHER_AND_PACK(quint8);
REGISTER_GATHER_AND_PACK(qint8);
REGISTER_GATHER_AND_PACK(qint32);
diff --git a/tensorflow/core/kernels/variable_ops.cc b/tensorflow/core/kernels/variable_ops.cc
index 7fd5809ca4..eadea18f76 100644
--- a/tensorflow/core/kernels/variable_ops.cc
+++ b/tensorflow/core/kernels/variable_ops.cc
@@ -73,9 +73,6 @@ void VariableOp::Compute(OpKernelContext* ctx) {
// here is valid because it owns a ref on var.
ctx->set_output_ref(0, var->mu(), var->tensor());
if (ctx->track_allocations() && var->tensor()->IsInitialized()) {
- AllocatorAttributes attr;
- attr.set_gpu_compatible(true);
- attr.set_nic_compatible(true);
ctx->record_persistent_memory_allocation(var->tensor()->AllocatedBytes());
}
var->Unref();
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 96944f27cd..b5e42f5384 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -1851,7 +1851,7 @@ class MklPrimitiveFactory {
}
private:
- static inline std::unordered_map<std::string, MklPrimitive*> &GetHashMap() {
+ static inline std::unordered_map<std::string, MklPrimitive*>& GetHashMap() {
static thread_local std::unordered_map<std::string, MklPrimitive*> map_;
return map_;
}
diff --git a/tensorflow/docs_src/api_guides/python/spectral_ops.md b/tensorflow/docs_src/api_guides/python/spectral_ops.md
index 022c471ef1..dd13802f00 100644
--- a/tensorflow/docs_src/api_guides/python/spectral_ops.md
+++ b/tensorflow/docs_src/api_guides/python/spectral_ops.md
@@ -23,3 +23,4 @@ that you can use to transform Tensors of real and complex signals.
## Discrete Cosine Transforms
* @{tf.spectral.dct}
+* @{tf.spectral.idct}
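Since this hunk documents `tf.spectral.idct` next to `tf.spectral.dct`, a quick round-trip sketch may help. It assumes the TF 1.x API with SciPy-style type-II transforms and `'ortho'` normalization; treat the exact signature as an assumption rather than a statement of the API.

```python
import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0, 4.0])
y = tf.spectral.dct(x, type=2, norm='ortho')       # forward DCT-II
x_hat = tf.spectral.idct(y, type=2, norm='ortho')  # inverse transform

with tf.Session() as sess:
    print(sess.run(x_hat))  # approximately [1. 2. 3. 4.]
```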
diff --git a/tensorflow/docs_src/get_started/index.md b/tensorflow/docs_src/get_started/index.md
new file mode 100644
index 0000000000..bd2a80d9ef
--- /dev/null
+++ b/tensorflow/docs_src/get_started/index.md
@@ -0,0 +1,29 @@
+# Get Started
+
+If you are new to machine learning, we recommend taking the following online
+course prior to diving into TensorFlow documentation:
+
+ * [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/),
+ which introduces machine learning concepts and encourages experimentation
+ with existing TensorFlow code.
+
+TensorFlow is a tool for machine learning. While it contains a wide range of
+functionality, TensorFlow is mainly designed for deep neural network models.
+
+The easiest way to get started with TensorFlow is by using Eager Execution.
+
+ * @{$get_started/eager} is for anyone new to machine learning or TensorFlow.
+
+TensorFlow provides many APIs. The remainder of this section focuses on the
+Estimator API, which provides scalable, high-performance models. See the
+@{$estimators} guide.
+
+For more advanced users:
+
+ * The @{$low_level_intro$Low Level Introduction} demonstrates how to use
+ TensorFlow outside of the Estimator framework, for debugging and
+ experimentation.
+ * The @{$guide$Programmer's Guide} details major
+ TensorFlow components.
+ * The @{$tutorials$Tutorials} provide walkthroughs of a variety of
+ TensorFlow models.
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index ce43d09b63..4c4f3f3934 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -2010,13 +2010,35 @@ Slice(b, {2, 1}, {4, 3}) produces:
See also
[`XlaBuilder::Sort`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
-Sorts the elements in the operand.
+There are two versions of the Sort instruction: a single-operand and a
+two-operand version.
<b>`Sort(operand)`</b>
Arguments | Type | Semantics
+--------- | ------- | --------------------
+`operand` | `XlaOp` | The operand to sort.
+
+Sorts the elements in the operand in ascending order. The operand must be rank-1.
+If the operand's elements have floating point type, and the operand contains
+NaN elements, the order of elements in the output is implementation-defined.
+
+<b>`Sort(key, value)`</b>
+
+Sorts both the key and the value operands. The keys are sorted as in the
+single-operand version. The values are sorted according to the order of their
+corresponding keys. For example, if the inputs are `keys = [3, 1]` and
+`values = [42, 50]`, then the output of the sort is the tuple `{[1, 3], [50, 42]}`.
+The sort is not guaranteed to be stable; that is, if the keys array contains
+duplicates, the order of their corresponding values may not be preserved.
+
+Arguments | Type | Semantics
--------- | ------- | -------------------
-`operand` | `XlaOp` | The operand to sort
+`keys` | `XlaOp` | The sort keys.
+`values` | `XlaOp` | The values to sort.
+
+The `keys` and `values` operands must both be rank-1, and must have the same
+dimensions, but may have different element types.
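A plain-Python rendering of the two-operand Sort semantics above, for readers who want the "values follow their keys" rule spelled out. This sketches only the semantics, not the XLA implementation, and unlike Python's stable `sorted`, XLA makes no stability guarantee.

```python
def key_value_sort(keys, values):
    # Sort keys ascending; reorder values so each stays paired with its key.
    order = sorted(range(len(keys)), key=lambda i: keys[i])
    return [keys[i] for i in order], [values[i] for i in order]

print(key_value_sort([3, 1], [42, 50]))  # ([1, 3], [50, 42])
```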
## Transpose
diff --git a/tensorflow/go/attrs_test.go b/tensorflow/go/attrs_test.go
index 35b0cb352e..ea8af221ae 100644
--- a/tensorflow/go/attrs_test.go
+++ b/tensorflow/go/attrs_test.go
@@ -28,7 +28,7 @@ func TestOperationAttrs(t *testing.T) {
i := 0
makeConst := func(v interface{}) Output {
op, err := Const(g, fmt.Sprintf("const/%d/%+v", i, v), v)
- i += 1
+ i++
if err != nil {
t.Fatal(err)
}
@@ -71,6 +71,7 @@ func TestOperationAttrs(t *testing.T) {
"boundaries": []float32(nil),
},
},
+ /* TODO(ashankar): debug this issue and add it back later.
{
Name: "list(type),list(shape)",
Type: "InfeedEnqueueTuple",
@@ -111,6 +112,7 @@ func TestOperationAttrs(t *testing.T) {
"device_ordinal": int64(0),
},
},
+ */
{
Name: "list(int),int",
Type: "StringToHashBucketStrong",
diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc
index 2df69ee299..d5bd99bdd9 100644
--- a/tensorflow/java/src/gen/cc/op_generator.cc
+++ b/tensorflow/java/src/gen/cc/op_generator.cc
@@ -36,20 +36,21 @@ namespace java {
namespace {
constexpr const char kLicense[] =
- "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n"
- "\n"
- "Licensed under the Apache License, Version 2.0 (the \"License\");\n"
- "you may not use this file except in compliance with the License.\n"
- "You may obtain a copy of the License at\n"
- "\n"
- " http://www.apache.org/licenses/LICENSE-2.0\n"
- "\n"
- "Unless required by applicable law or agreed to in writing, software\n"
- "distributed under the License is distributed on an \"AS IS\" BASIS,\n"
- "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
- "See the License for the specific language governing permissions and\n"
- "limitations under the License.\n"
- "=======================================================================*/\n";
+ "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n"
+ "\n"
+ "Licensed under the Apache License, Version 2.0 (the \"License\");\n"
+ "you may not use this file except in compliance with the License.\n"
+ "You may obtain a copy of the License at\n"
+ "\n"
+ " http://www.apache.org/licenses/LICENSE-2.0\n"
+ "\n"
+ "Unless required by applicable law or agreed to in writing, software\n"
+ "distributed under the License is distributed on an \"AS IS\" BASIS,\n"
+ "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
+ "See the License for the specific language governing permissions and\n"
+ "limitations under the License.\n"
+ "=======================================================================*/"
+ "\n";
// There are three different modes to render an op class, depending on the
// number and type of outputs it has:
diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java
index 3524160d87..796d6a62dc 100644
--- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java
+++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java
@@ -15,6 +15,18 @@ limitations under the License.
package org.tensorflow.processor;
+import com.google.common.base.CaseFormat;
+import com.google.common.base.Strings;
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.Multimap;
+import com.squareup.javapoet.ClassName;
+import com.squareup.javapoet.FieldSpec;
+import com.squareup.javapoet.JavaFile;
+import com.squareup.javapoet.MethodSpec;
+import com.squareup.javapoet.ParameterSpec;
+import com.squareup.javapoet.TypeName;
+import com.squareup.javapoet.TypeSpec;
+import com.squareup.javapoet.TypeVariableName;
import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
@@ -23,7 +35,6 @@ import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
-
import javax.annotation.processing.AbstractProcessor;
import javax.annotation.processing.Filer;
import javax.annotation.processing.Messager;
@@ -44,19 +55,6 @@ import javax.lang.model.util.ElementFilter;
import javax.lang.model.util.Elements;
import javax.tools.Diagnostic.Kind;
-import com.google.common.base.CaseFormat;
-import com.google.common.base.Strings;
-import com.google.common.collect.HashMultimap;
-import com.google.common.collect.Multimap;
-import com.squareup.javapoet.ClassName;
-import com.squareup.javapoet.FieldSpec;
-import com.squareup.javapoet.JavaFile;
-import com.squareup.javapoet.MethodSpec;
-import com.squareup.javapoet.ParameterSpec;
-import com.squareup.javapoet.TypeName;
-import com.squareup.javapoet.TypeSpec;
-import com.squareup.javapoet.TypeVariableName;
-
/**
* A compile-time Processor that aggregates classes annotated with {@link
* org.tensorflow.op.annotation.Operator} and generates the {@code Ops} convenience API. Please
@@ -115,10 +113,12 @@ public final class OperatorProcessor extends AbstractProcessor {
// generated our code, flag the location of each such class.
if (hasRun) {
for (Element e : annotated) {
- error(e, "The Operator processor has already processed @Operator annotated sources\n" +
- "and written out an Ops API. It cannot process additional @Operator sources.\n" +
- "One reason this can happen is if other annotation processors generate\n" +
- "new @Operator source files.");
+ error(
+ e,
+ "The Operator processor has already processed @Operator annotated sources\n"
+ + "and written out an Ops API. It cannot process additional @Operator sources.\n"
+ + "One reason this can happen is if other annotation processors generate\n"
+ + "new @Operator source files.");
}
return true;
}
@@ -146,9 +146,11 @@ public final class OperatorProcessor extends AbstractProcessor {
return Collections.singleton("org.tensorflow.op.annotation.Operator");
}
- private static final Pattern JAVADOC_TAG_PATTERN = Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*");
+ private static final Pattern JAVADOC_TAG_PATTERN =
+ Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*");
private static final TypeName T_OPS = ClassName.get("org.tensorflow.op", "Ops");
- private static final TypeName T_OPERATOR = ClassName.get("org.tensorflow.op.annotation", "Operator");
+ private static final TypeName T_OPERATOR =
+ ClassName.get("org.tensorflow.op.annotation", "Operator");
private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope");
private static final TypeName T_GRAPH = ClassName.get("org.tensorflow", "Graph");
private static final TypeName T_STRING = ClassName.get(String.class);
@@ -167,20 +169,17 @@ public final class OperatorProcessor extends AbstractProcessor {
private void write(TypeSpec spec) {
try {
- JavaFile.builder("org.tensorflow.op", spec)
- .skipJavaLangImports(true)
- .build()
- .writeTo(filer);
+ JavaFile.builder("org.tensorflow.op", spec).skipJavaLangImports(true).build().writeTo(filer);
} catch (IOException e) {
throw new AssertionError(e);
}
}
private void writeApi(Multimap<String, MethodSpec> groupedMethods) {
- Map<String, ClassName> groups = new HashMap<String, ClassName>();
-
+ Map<String, ClassName> groups = new HashMap<>();
+
// Generate a API class for each group collected other than the default one (= empty string)
- for (Map.Entry<String, Collection<MethodSpec>> entry: groupedMethods.asMap().entrySet()) {
+ for (Map.Entry<String, Collection<MethodSpec>> entry : groupedMethods.asMap().entrySet()) {
if (!entry.getKey().isEmpty()) {
TypeSpec groupClass = buildGroupClass(entry.getKey(), entry.getValue());
write(groupClass);
@@ -193,12 +192,17 @@ public final class OperatorProcessor extends AbstractProcessor {
}
private boolean collectOpsMethods(
- RoundEnvironment roundEnv, Multimap<String, MethodSpec> groupedMethods, TypeElement annotation) {
+ RoundEnvironment roundEnv,
+ Multimap<String, MethodSpec> groupedMethods,
+ TypeElement annotation) {
boolean result = true;
for (Element e : roundEnv.getElementsAnnotatedWith(annotation)) {
// @Operator can only apply to types, so e must be a TypeElement.
if (!(e instanceof TypeElement)) {
- error(e, "@Operator can only be applied to classes, but this is a %s", e.getKind().toString());
+ error(
+ e,
+ "@Operator can only be applied to classes, but this is a %s",
+ e.getKind().toString());
result = false;
continue;
}
@@ -210,38 +214,42 @@ public final class OperatorProcessor extends AbstractProcessor {
}
return result;
}
-
- private void collectOpMethods(Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation) {
+
+ private void collectOpMethods(
+ Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation) {
AnnotationMirror am = getAnnotationMirror(opClass, annotation);
String groupName = getAnnotationElementValueAsString("group", am);
String methodName = getAnnotationElementValueAsString("name", am);
ClassName opClassName = ClassName.get(opClass);
if (Strings.isNullOrEmpty(methodName)) {
- methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName());
+ methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName());
}
- // Build a method for each @Operator found in the class path. There should be one method per operation factory called
+ // Build a method for each @Operator found in the class path. There should be one method per
+ // operation factory called
// "create", which takes in parameter a scope and, optionally, a list of arguments
for (ExecutableElement opMethod : ElementFilter.methodsIn(opClass.getEnclosedElements())) {
- if (opMethod.getModifiers().contains(Modifier.STATIC) && opMethod.getSimpleName().contentEquals("create")) {
+ if (opMethod.getModifiers().contains(Modifier.STATIC)
+ && opMethod.getSimpleName().contentEquals("create")) {
MethodSpec method = buildOpMethod(methodName, opClassName, opMethod);
groupedMethods.put(groupName, method);
}
}
}
- private MethodSpec buildOpMethod(String methodName, ClassName opClassName, ExecutableElement factoryMethod) {
+ private MethodSpec buildOpMethod(
+ String methodName, ClassName opClassName, ExecutableElement factoryMethod) {
MethodSpec.Builder builder =
MethodSpec.methodBuilder(methodName)
- .addModifiers(Modifier.PUBLIC)
- .returns(TypeName.get(factoryMethod.getReturnType()))
- .varargs(factoryMethod.isVarArgs())
- .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod));
+ .addModifiers(Modifier.PUBLIC)
+ .returns(TypeName.get(factoryMethod.getReturnType()))
+ .varargs(factoryMethod.isVarArgs())
+ .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod));
- for (TypeParameterElement tp: factoryMethod.getTypeParameters()) {
+ for (TypeParameterElement tp : factoryMethod.getTypeParameters()) {
TypeVariableName tvn = TypeVariableName.get((TypeVariable) tp.asType());
builder.addTypeVariable(tvn);
}
- for (TypeMirror thrownType: factoryMethod.getThrownTypes()) {
+ for (TypeMirror thrownType : factoryMethod.getThrownTypes()) {
builder.addException(TypeName.get(thrownType));
}
StringBuilder call = new StringBuilder("return $T.create(scope");
@@ -259,13 +267,17 @@ public final class OperatorProcessor extends AbstractProcessor {
call.append(")");
builder.addStatement(call.toString(), opClassName);
return builder.build();
- }
-
+ }
+
private String buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod) {
StringBuilder javadoc = new StringBuilder();
- javadoc.append("Adds an {@link ").append(opClassName.simpleName()).append("} operation to the graph\n\n");
+ javadoc
+ .append("Adds an {@link ")
+ .append(opClassName.simpleName())
+ .append("} operation to the graph\n\n");
- // Add all javadoc tags found in the operator factory method but the first one, which should be in all cases the
+ // Add all javadoc tags found in the operator factory method but the first one, which should be
+ // in all cases the
// 'scope' parameter that is implicitly passed by this API
Matcher tagMatcher = JAVADOC_TAG_PATTERN.matcher(elements.getDocComment(factoryMethod));
boolean firstParam = true;
@@ -277,136 +289,144 @@ public final class OperatorProcessor extends AbstractProcessor {
} else {
javadoc.append(tag).append('\n');
}
- }
+ }
javadoc.append("@see {@link ").append(opClassName).append("}\n");
return javadoc.toString();
}
-
+
private static TypeSpec buildGroupClass(String group, Collection<MethodSpec> methods) {
MethodSpec.Builder ctorBuilder =
MethodSpec.constructorBuilder()
- .addParameter(T_SCOPE, "scope")
- .addStatement("this.scope = scope");
-
+ .addParameter(T_SCOPE, "scope")
+ .addStatement("this.scope = scope");
+
TypeSpec.Builder builder =
TypeSpec.classBuilder(CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, group) + "Ops")
- .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
- .addJavadoc("An API for adding {@code $L} operations to a {@link $T Graph}\n\n" +
- "@see {@link $T}\n", group, T_GRAPH, T_OPS)
- .addMethods(methods)
- .addMethod(ctorBuilder.build());
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .addJavadoc(
+ "An API for adding {@code $L} operations to a {@link $T Graph}\n\n"
+ + "@see {@link $T}\n",
+ group,
+ T_GRAPH,
+ T_OPS)
+ .addMethods(methods)
+ .addMethod(ctorBuilder.build());
builder.addField(
- FieldSpec.builder(T_SCOPE, "scope")
- .addModifiers(Modifier.PRIVATE, Modifier.FINAL)
- .build());
+ FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());
return builder.build();
}
- private static TypeSpec buildTopClass(Map<String, ClassName> groupToClass, Collection<MethodSpec> methods) {
+ private static TypeSpec buildTopClass(
+ Map<String, ClassName> groupToClass, Collection<MethodSpec> methods) {
MethodSpec.Builder ctorBuilder =
MethodSpec.constructorBuilder()
- .addModifiers(Modifier.PRIVATE)
- .addParameter(T_SCOPE, "scope")
- .addStatement("this.scope = scope", T_SCOPE);
+ .addModifiers(Modifier.PRIVATE)
+ .addParameter(T_SCOPE, "scope")
+ .addStatement("this.scope = scope", T_SCOPE);
- for (Map.Entry<String, ClassName> entry: groupToClass.entrySet()) {
+ for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) {
ctorBuilder.addStatement("$L = new $T(scope)", entry.getKey(), entry.getValue());
}
TypeSpec.Builder opsBuilder =
TypeSpec.classBuilder("Ops")
- .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
- .addJavadoc("An API for building a {@link $T} with operation wrappers\n<p>\n" +
- "Any operation wrapper found in the classpath properly annotated as an {@link $T @Operator} is exposed\n" +
- "by this API or one of its subgroup.\n<p>Example usage:\n<pre>{@code\n" +
- "try (Graph g = new Graph()) {\n" +
- " Ops ops = new Ops(g);\n" +
- " // Operations are typed classes with convenience\n" +
- " // builders in Ops.\n" +
- " Constant three = ops.constant(3);\n" +
- " // Single-result operations implement the Operand\n" +
- " // interface, so this works too.\n" +
- " Operand four = ops.constant(4);\n" +
- " // Most builders are found within a group, and accept\n" +
- " // Operand types as operands\n" +
- " Operand nine = ops.math().add(four, ops.constant(5));\n" +
- " // Multi-result operations however offer methods to\n" +
- " // select a particular result for use.\n" +
- " Operand result = \n" +
- " ops.math().add(ops.array().unique(s, a).y(), b);\n" +
- " // Optional attributes\n" +
- " ops.math().matMul(a, b, MatMul.transposeA(true));\n" +
- " // Naming operators\n" +
- " ops.withName(“foo”).constant(5); // name “foo”\n" +
- " // Names can exist in a hierarchy\n" +
- " Ops sub = ops.withSubScope(“sub”);\n" +
- " sub.withName(“bar”).constant(4); // “sub/bar”\n" +
- "}\n" +
- "}</pre>\n", T_GRAPH, T_OPERATOR)
- .addMethods(methods)
- .addMethod(ctorBuilder.build());
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .addJavadoc(
+ "An API for building a {@link $T} with operation wrappers\n<p>\n"
+ + "Any operation wrapper found in the classpath properly annotated as an"
+ + "{@link $T @Operator} is exposed\n"
+ + "by this API or one of its subgroup.\n<p>Example usage:\n<pre>{@code\n"
+ + "try (Graph g = new Graph()) {\n"
+ + " Ops ops = new Ops(g);\n"
+ + " // Operations are typed classes with convenience\n"
+ + " // builders in Ops.\n"
+ + " Constant three = ops.constant(3);\n"
+ + " // Single-result operations implement the Operand\n"
+ + " // interface, so this works too.\n"
+ + " Operand four = ops.constant(4);\n"
+ + " // Most builders are found within a group, and accept\n"
+ + " // Operand types as operands\n"
+ + " Operand nine = ops.math().add(four, ops.constant(5));\n"
+ + " // Multi-result operations however offer methods to\n"
+ + " // select a particular result for use.\n"
+ + " Operand result = \n"
+ + " ops.math().add(ops.array().unique(s, a).y(), b);\n"
+ + " // Optional attributes\n"
+ + " ops.math().matMul(a, b, MatMul.transposeA(true));\n"
+ + " // Naming operators\n"
+ + " ops.withName(“foo”).constant(5); // name “foo”\n"
+ + " // Names can exist in a hierarchy\n"
+ + " Ops sub = ops.withSubScope(“sub”);\n"
+ + " sub.withName(“bar”).constant(4); // “sub/bar”\n"
+ + "}\n"
+ + "}</pre>\n",
+ T_GRAPH,
+ T_OPERATOR)
+ .addMethods(methods)
+ .addMethod(ctorBuilder.build());
opsBuilder.addMethod(
MethodSpec.methodBuilder("withSubScope")
- .addModifiers(Modifier.PUBLIC)
- .addParameter(T_STRING, "childScopeName")
- .returns(T_OPS)
- .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS)
- .addJavadoc(
- "Returns an API that adds operations to the graph with the provided name prefix.\n\n" +
- "@see {@link $T#withSubScope(String)}\n", T_SCOPE)
- .build());
+ .addModifiers(Modifier.PUBLIC)
+ .addParameter(T_STRING, "childScopeName")
+ .returns(T_OPS)
+ .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS)
+ .addJavadoc(
+ "Returns an API that adds operations to the graph with the provided name prefix.\n"
+ + "\n@see {@link $T#withSubScope(String)}\n",
+ T_SCOPE)
+ .build());
opsBuilder.addMethod(
MethodSpec.methodBuilder("withName")
- .addModifiers(Modifier.PUBLIC)
- .addParameter(T_STRING, "opName")
- .returns(T_OPS)
- .addStatement("return new Ops(scope.withName(opName))")
- .addJavadoc(
- "Returns an API that uses the provided name for an op.\n\n" +
- "@see {@link $T#withName(String)}\n", T_SCOPE)
- .build());
+ .addModifiers(Modifier.PUBLIC)
+ .addParameter(T_STRING, "opName")
+ .returns(T_OPS)
+ .addStatement("return new Ops(scope.withName(opName))")
+ .addJavadoc(
+ "Returns an API that uses the provided name for an op.\n\n"
+ + "@see {@link $T#withName(String)}\n",
+ T_SCOPE)
+ .build());
opsBuilder.addField(
- FieldSpec.builder(T_SCOPE, "scope")
- .addModifiers(Modifier.PRIVATE, Modifier.FINAL)
- .build());
+ FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());
opsBuilder.addMethod(
MethodSpec.methodBuilder("scope")
- .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
- .returns(T_SCOPE)
- .addStatement("return scope")
- .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE)
- .build());
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .returns(T_SCOPE)
+ .addStatement("return scope")
+ .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE)
+ .build());
- for (Map.Entry<String, ClassName> entry: groupToClass.entrySet()) {
+ for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) {
opsBuilder.addField(
FieldSpec.builder(entry.getValue(), entry.getKey())
- .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
- .build());
-
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .build());
+
opsBuilder.addMethod(
MethodSpec.methodBuilder(entry.getKey())
- .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
- .returns(entry.getValue())
- .addStatement("return $L", entry.getKey())
- .addJavadoc("Returns an API for adding {@code $L} operations to the graph\n", entry.getKey())
- .build());
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .returns(entry.getValue())
+ .addStatement("return $L", entry.getKey())
+ .addJavadoc(
+ "Returns an API for adding {@code $L} operations to the graph\n", entry.getKey())
+ .build());
}
opsBuilder.addMethod(
MethodSpec.methodBuilder("create")
- .addModifiers(Modifier.PUBLIC, Modifier.STATIC)
- .addParameter(T_GRAPH, "graph")
- .returns(T_OPS)
- .addStatement("return new Ops(new $T(graph))", T_SCOPE)
- .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n")
- .build());
+ .addModifiers(Modifier.PUBLIC, Modifier.STATIC)
+ .addParameter(T_GRAPH, "graph")
+ .returns(T_OPS)
+ .addStatement("return new Ops(new $T(graph))", T_SCOPE)
+ .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n")
+ .build());
return opsBuilder.build();
}
@@ -417,12 +437,16 @@ public final class OperatorProcessor extends AbstractProcessor {
return am;
}
}
- throw new IllegalArgumentException("Annotation " + annotation.getSimpleName() + " not present on element "
- + element.getSimpleName());
+ throw new IllegalArgumentException(
+ "Annotation "
+ + annotation.getSimpleName()
+ + " not present on element "
+ + element.getSimpleName());
}
-
+
private static String getAnnotationElementValueAsString(String elementName, AnnotationMirror am) {
- for (Map.Entry<? extends ExecutableElement, ? extends AnnotationValue> entry : am.getElementValues().entrySet()) {
+ for (Map.Entry<? extends ExecutableElement, ? extends AnnotationValue> entry :
+ am.getElementValues().entrySet()) {
if (entry.getKey().getSimpleName().contentEquals(elementName)) {
return entry.getValue().getValue().toString();
}
diff --git a/tensorflow/python/compat/BUILD b/tensorflow/python/compat/BUILD
new file mode 100644
index 0000000000..5f55b22818
--- /dev/null
+++ b/tensorflow/python/compat/BUILD
@@ -0,0 +1,10 @@
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "compat",
+ srcs = ["compat.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+)
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
new file mode 100644
index 0000000000..e05ad55447
--- /dev/null
+++ b/tensorflow/python/compat/compat.py
@@ -0,0 +1,81 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for API compatibility between TensorFlow release versions.
+
+See
+@{$guide/version_compat#backward_and_partial_forward_compatibility}
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import datetime
+
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 8, 1)
+
+
+def forward_compatible(year, month, day):
+ """Return true if the forward compatibility window has expired.
+
+ Forward-compatibility refers to scenarios where the producer of a TensorFlow
+ model (a GraphDef or SavedModel) is compiled against a version of the
+ TensorFlow library newer than what the consumer was compiled against. The
+ "producer" is typically a Python program that constructs and trains a model
+ while the "consumer" is typically another program that loads and serves the
+ model.
+
+ TensorFlow has been supporting a 3 week forward-compatibility window for
+ programs compiled from source at HEAD.
+
+ For example, consider the case where a new operation `MyNewAwesomeAdd` is
+ created with the intent of replacing the implementation of an existing Python
+ wrapper - `tf.add`. The Python wrapper implementation should change from
+ something like:
+
+ ```python
+ def add(inputs, name=None):
+ return gen_math_ops.add(inputs, name)
+ ```
+
+ to:
+
+ ```python
+ from tensorflow.python.compat import compat
+
+ def add(inputs, name=None):
+ if compat.forward_compatible(year, month, day):
+ # Can use the awesome new implementation.
+ return gen_math_ops.my_new_awesome_add(inputs, name)
+    # To maintain forward compatibility, use the old implementation.
+ return gen_math_ops.add(inputs, name)
+ ```
+
+ Where `year`, `month`, and `day` specify the date beyond which binaries
+ that consume a model are expected to have been updated to include the
+ new operations. This date is typically at least 3 weeks beyond the date
+ the code that adds the new operation is committed.
+
+ Args:
+ year: A year (e.g., 2018).
+ month: A month (1 <= month <= 12) in year.
+ day: A day (1 <= day <= 31, or 30, or 29, or 28) in month.
+
+ Returns:
+ True if the caller can expect that serialized TensorFlow graphs produced
+ can be consumed by programs that are compiled with the TensorFlow library
+ source code after (year, month, day).
+ """
+ return _FORWARD_COMPATIBILITY_HORIZON > datetime.date(year, month, day)
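For reference, a minimal sketch (not part of this patch) of how the helper behaves under the `_FORWARD_COMPATIBILITY_HORIZON` value introduced above (2018-08-01):

```python
# Illustrative only; assumes the horizon constant defined in compat.py above.
from tensorflow.python.compat import compat

assert compat.forward_compatible(2018, 7, 1)       # horizon is already past this date
assert not compat.forward_compatible(2018, 8, 15)  # date is beyond the current horizon
```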
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 6941cacf23..c025dc8aa5 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -454,6 +454,17 @@ py_binary(
],
)
+py_binary(
+ name = "debug_keras",
+ srcs = ["examples/debug_keras.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":debug_py",
+ "//tensorflow:tensorflow_py",
+ "//third_party/py/numpy",
+ ],
+)
+
py_test(
name = "common_test",
size = "small",
@@ -1086,6 +1097,7 @@ py_test(
"//tensorflow/python:state_ops",
"//tensorflow/python:training",
"//tensorflow/python:variables",
+ "//third_party/py/numpy",
],
)
@@ -1096,6 +1108,7 @@ sh_test(
data = [
":debug_errors",
":debug_fibonacci",
+ ":debug_keras",
":debug_mnist",
":debug_tflearn_iris",
":offline_analyzer",
diff --git a/tensorflow/python/debug/examples/debug_keras.py b/tensorflow/python/debug/examples/debug_keras.py
new file mode 100644
index 0000000000..3272d85ade
--- /dev/null
+++ b/tensorflow/python/debug/examples/debug_keras.py
@@ -0,0 +1,89 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""tfdbg example: debugging tf.keras models training on tf.data.Dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python import debug as tf_debug
+
+
+def main(_):
+ # Create a dummy dataset.
+ num_examples = 8
+ steps_per_epoch = 2
+ input_dims = 3
+ output_dims = 1
+ xs = np.zeros([num_examples, input_dims])
+ ys = np.zeros([num_examples, output_dims])
+ dataset = tf.data.Dataset.from_tensor_slices(
+ (xs, ys)).repeat(num_examples).batch(int(num_examples / steps_per_epoch))
+
+ sess = tf.Session()
+ if FLAGS.debug:
+ # Use the command-line interface (CLI) of tfdbg.
+ sess = tf_debug.LocalCLIDebugWrapperSession(sess, ui_type=FLAGS.ui_type)
+ elif FLAGS.tensorboard_debug_address:
+ # Use the TensorBoard Debugger Plugin (GUI of tfdbg).
+ sess = tf_debug.TensorBoardDebugWrapperSession(
+ sess, FLAGS.tensorboard_debug_address)
+ tf.keras.backend.set_session(sess)
+
+ # Create a dummy model.
+ model = tf.keras.Sequential([
+ tf.keras.layers.Dense(1, input_shape=[input_dims])])
+ model.compile(loss="mse", optimizer="sgd")
+
+ # Train the model using the dummy dataset created above.
+ model.fit(dataset, epochs=FLAGS.epochs, steps_per_epoch=steps_per_epoch)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.register("type", "bool", lambda v: v.lower() == "true")
+ parser.add_argument(
+ "--debug",
+ type="bool",
+ nargs="?",
+ const=True,
+ default=False,
+ help="Use debugger to track down bad values during training. "
+ "Mutually exclusive with the --tensorboard_debug_address flag.")
+ parser.add_argument(
+ "--ui_type",
+ type=str,
+ default="curses",
+ help="Command-line user interface type (curses | readline).")
+ parser.add_argument(
+ "--tensorboard_debug_address",
+ type=str,
+ default=None,
+ help="Connect to the TensorBoard Debugger Plugin backend specified by "
+ "the gRPC address (e.g., localhost:1234). Mutually exclusive with the "
+ "--debug flag.")
+ parser.add_argument(
+ "--epochs",
+ type=int,
+ default=2,
+ help="Number of epochs to train the model for.")
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
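The essential pattern in the example above is wrapping the session before handing it to Keras; a condensed sketch (not part of this patch):

```python
# Condensed from debug_keras.py above: wrap the Session with tfdbg, then make
# it the Keras backend session so model.fit() runs under the debugger.
import tensorflow as tf
from tensorflow.python import debug as tf_debug

sess = tf_debug.LocalCLIDebugWrapperSession(tf.Session(), ui_type="readline")
tf.keras.backend.set_session(sess)
```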
diff --git a/tensorflow/python/debug/examples/examples_test.sh b/tensorflow/python/debug/examples/examples_test.sh
index e9c45a7e6e..2d35b2d8bb 100755
--- a/tensorflow/python/debug/examples/examples_test.sh
+++ b/tensorflow/python/debug/examples/examples_test.sh
@@ -48,12 +48,14 @@ if [[ -z "${PYTHON_BIN_PATH}" ]]; then
DEBUG_ERRORS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_errors"
DEBUG_MNIST_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_mnist"
DEBUG_TFLEARN_IRIS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_tflearn_iris"
+ DEBUG_KERAS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_keras"
OFFLINE_ANALYZER_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/offline_analyzer"
else
DEBUG_FIBONACCI_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_fibonacci"
DEBUG_ERRORS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_errors"
DEBUG_MNIST_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_mnist"
DEBUG_TFLEARN_IRIS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_tflearn_iris"
+ DEBUG_KERAS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_keras"
OFFLINE_ANALYZER_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.cli.offline_analyzer"
fi
@@ -96,6 +98,11 @@ if [[ -d "${CUSTOM_DUMP_ROOT}" ]]; then
exit 1
fi
+# Test debugging of tf.keras.
+cat << EOF | "${DEBUG_KERAS_BIN}" --debug --ui_type=readline
+run -f has_inf_or_nan
+EOF
+
# Test offline_analyzer.
echo
echo "Testing offline_analyzer"
diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py
index c530204bbf..b9524ce649 100644
--- a/tensorflow/python/debug/wrappers/framework.py
+++ b/tensorflow/python/debug/wrappers/framework.py
@@ -392,6 +392,9 @@ class BaseDebugWrapperSession(session.SessionInterface):
self._default_session_context_manager = None
+ # A cache for callables created from CallableOptions.
+ self._cached_callables_from_options = dict()
+
@property
def graph(self):
return self._sess.graph
@@ -414,7 +417,8 @@ class BaseDebugWrapperSession(session.SessionInterface):
options=None,
run_metadata=None,
callable_runner=None,
- callable_runner_args=None):
+ callable_runner_args=None,
+ callable_options=None):
"""Wrapper around Session.run() that inserts tensor watch options.
Args:
@@ -424,7 +428,12 @@ class BaseDebugWrapperSession(session.SessionInterface):
run_metadata: Same as the `run_metadata` arg to regular `Session.run()`.
callable_runner: A `callable` returned by `Session.make_callable()`.
If not `None`, `fetches` and `feed_dict` must both be `None`.
- callable_runner_args: An optional list of arguments to `callable_runner`.
+ Mutually exclusive with `callable_options`.
+ callable_runner_args: An optional list of arguments to `callable_runner`
+ or for `callable_options`.
+ callable_options: An instance of `config_pb2.CallableOptions`, to be
+ used with `Session._make_callable_from_options()`. Mutually exclusive
+ with `callable_runner`.
Returns:
Simply forwards the output of the wrapped `Session.run()` call.
@@ -433,13 +442,17 @@ class BaseDebugWrapperSession(session.SessionInterface):
ValueError: On invalid `OnRunStartAction` value. Or if `callable_runner`
is not `None` and either or both of `fetches` and `feed_dict` is `None`.
"""
- if not callable_runner:
+ if callable_runner and callable_options:
+ raise ValueError(
+ "callable_runner and callable_options are mutually exclusive, but "
+ "are both specified in this call to BaseDebugWrapperSession.run().")
+
+ if not (callable_runner or callable_options):
self.increment_run_call_count()
- else:
- if fetches or feed_dict:
- raise ValueError(
- "callable_runner and fetches/feed_dict are mutually exclusive, but "
- "are used simultaneously.")
+ elif callable_runner and (fetches or feed_dict):
+ raise ValueError(
+ "callable_runner and fetches/feed_dict are mutually exclusive, "
+ "but are used simultaneously.")
empty_fetches = not nest.flatten(fetches)
if empty_fetches:
@@ -449,6 +462,11 @@ class BaseDebugWrapperSession(session.SessionInterface):
if self._is_disabled_thread() or empty_fetches:
if callable_runner:
return callable_runner(*callable_runner_args)
+ elif callable_options:
+ # pylint:disable=protected-access
+ return self._sess._make_callable_from_options(
+ callable_options)(*callable_runner_args)
+ # pylint:enable=protected-access
else:
return self._sess.run(fetches,
feed_dict=feed_dict,
@@ -464,19 +482,30 @@ class BaseDebugWrapperSession(session.SessionInterface):
if run_start_resp.action == OnRunStartAction.DEBUG_RUN:
# Decorate RunOption to fill in debugger tensor watch specifications.
- decorated_run_options = options or config_pb2.RunOptions()
+ decorated_run_options = None
+ if callable_options:
+ callable_options_id = id(callable_options)
+ if callable_options_id not in self._cached_callables_from_options:
+ # Make a copy of callable_options to avoid mutating it.
+ new_callable_options = config_pb2.CallableOptions()
+ new_callable_options.CopyFrom(callable_options)
+ decorated_run_options = new_callable_options.run_options
+ else:
+ decorated_run_options = options or config_pb2.RunOptions()
+
run_metadata = run_metadata or config_pb2.RunMetadata()
- self._decorate_run_options_for_debug(
- decorated_run_options,
- run_start_resp.debug_urls,
- debug_ops=run_start_resp.debug_ops,
- node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist,
- op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist,
- tensor_dtype_regex_whitelist=(
- run_start_resp.tensor_dtype_regex_whitelist),
- tolerate_debug_op_creation_failures=(
- run_start_resp.tolerate_debug_op_creation_failures))
+ if decorated_run_options:
+ self._decorate_run_options_for_debug(
+ decorated_run_options,
+ run_start_resp.debug_urls,
+ debug_ops=run_start_resp.debug_ops,
+ node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist,
+ op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist,
+ tensor_dtype_regex_whitelist=(
+ run_start_resp.tensor_dtype_regex_whitelist),
+ tolerate_debug_op_creation_failures=(
+ run_start_resp.tolerate_debug_op_creation_failures))
# Invoke the run() method of the wrapped Session. Catch any TensorFlow
# runtime errors.
@@ -486,6 +515,19 @@ class BaseDebugWrapperSession(session.SessionInterface):
retvals = callable_runner(*callable_runner_args,
options=decorated_run_options,
run_metadata=run_metadata)
+ elif callable_options:
+ # pylint:disable=protected-access
+ if callable_options_id in self._cached_callables_from_options:
+ callable_object = self._cached_callables_from_options[
+ callable_options_id]
+ else:
+ callable_object = self._sess._make_callable_from_options(
+ new_callable_options)
+ self._cached_callables_from_options[
+ callable_options_id] = callable_object
+ # pylint:enable=protected-access
+ retvals = callable_object(
+ *callable_runner_args, run_metadata=run_metadata)
else:
retvals = self._sess.run(fetches,
feed_dict=feed_dict,
@@ -590,7 +632,14 @@ class BaseDebugWrapperSession(session.SessionInterface):
run_metadata=kwargs.get("run_metadata", None),
callable_runner=runner,
callable_runner_args=runner_args)
+ return wrapped_runner
+ def _make_callable_from_options(self, callable_options):
+ def wrapped_runner(*feed_values, **kwargs):
+ return self.run(None,
+ run_metadata=kwargs.get("run_metadata", None),
+ callable_options=callable_options,
+ callable_runner_args=feed_values)
return wrapped_runner
@property
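A sketch (not part of this patch) of the new `_make_callable_from_options()` path, modeled on the `local_cli_wrapper_test.py` cases added further below; the graph and node names here are illustrative assumptions:

```python
# Illustrative: build a tiny graph, wrap the session, and create a debuggable
# callable from CallableOptions. Running it will launch the tfdbg CLI.
import numpy as np
import tensorflow as tf
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import debug as tf_debug

ph = tf.placeholder(tf.float32, name="callable_ph")
tf.add(ph, ph, name="callable_b")

wrapped_sess = tf_debug.LocalCLIDebugWrapperSession(tf.Session())
callable_options = config_pb2.CallableOptions()
callable_options.feed.append("callable_ph")
callable_options.fetch.append("callable_b")

runner = wrapped_sess._make_callable_from_options(callable_options)  # pylint: disable=protected-access
print(runner(np.array([21.0], dtype=np.float32)))  # feed values are positional
```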
diff --git a/tensorflow/python/debug/wrappers/grpc_wrapper.py b/tensorflow/python/debug/wrappers/grpc_wrapper.py
index 1f9c8fa5a9..85944fa611 100644
--- a/tensorflow/python/debug/wrappers/grpc_wrapper.py
+++ b/tensorflow/python/debug/wrappers/grpc_wrapper.py
@@ -215,7 +215,8 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession):
options=None,
run_metadata=None,
callable_runner=None,
- callable_runner_args=None):
+ callable_runner_args=None,
+ callable_options=None):
if self._send_traceback_and_source_code:
self._sent_graph_version = publish_traceback(
self._grpc_debug_server_urls, self.graph, feed_dict, fetches,
@@ -226,4 +227,5 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession):
options=options,
run_metadata=run_metadata,
callable_runner=callable_runner,
- callable_runner_args=callable_runner_args)
+ callable_runner_args=callable_runner_args,
+ callable_options=callable_options)
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py
index 4e551ab995..668ffb57f1 100644
--- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py
+++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py
@@ -596,7 +596,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
# Register tab completion for the filter names.
curses_cli.register_tab_comp_context(["run", "r"],
list(self._tensor_filters.keys()))
- if self._feed_dict:
+ if self._feed_dict and hasattr(self._feed_dict, "keys"):
# Register tab completion for feed_dict keys.
feed_keys = [common.get_graph_element_name(key)
for key in self._feed_dict.keys()]
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
index b06fa26a93..05c9eaa4d2 100644
--- a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
+++ b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
@@ -21,7 +21,10 @@ import os
import shutil
import tempfile
+import numpy as np
+
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import debugger_cli_common
@@ -149,7 +152,13 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
dtypes.float32, shape=([5, 5]), name="sparse_placeholder")
self.sparse_add = sparse_ops.sparse_add(self.sparse_ph, self.sparse_ph)
- self.sess = session.Session()
+ rewriter_config = rewriter_config_pb2.RewriterConfig(
+ disable_model_pruning=True,
+ arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
+ graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
+ config_proto = config_pb2.ConfigProto(graph_options=graph_options)
+ self.sess = session.Session(config=config_proto)
# Initialize variable.
self.sess.run(variables.global_variables_initializer())
@@ -393,6 +402,113 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
self.assertAllClose(42.0, tensor_runner(41.0, 1.0))
self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
+ def testDebuggingMakeCallableFromOptionsWithZeroFeedWorks(self):
+ variable_1 = variables.Variable(
+ 10.5, dtype=dtypes.float32, name="variable_1")
+ a = math_ops.add(variable_1, variable_1, "callable_a")
+ math_ops.add(a, a, "callable_b")
+ self.sess.run(variable_1.initializer)
+
+ wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
+ [["run"]] * 3, self.sess, dump_root=self._tmp_dir)
+ callable_options = config_pb2.CallableOptions()
+ callable_options.fetch.append("callable_b")
+ sess_callable = wrapped_sess._make_callable_from_options(callable_options)
+
+ for _ in range(2):
+ callable_output = sess_callable()
+ self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0])
+
+ debug_dumps = wrapped_sess.observers["debug_dumps"]
+ self.assertEqual(2, len(debug_dumps))
+ for debug_dump in debug_dumps:
+ node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data]
+ self.assertItemsEqual(
+ ["callable_a", "callable_b", "variable_1", "variable_1/read"],
+ node_names)
+
+ def testDebuggingMakeCallableFromOptionsWithOneFeedWorks(self):
+ ph1 = array_ops.placeholder(dtypes.float32, name="callable_ph1")
+ a = math_ops.add(ph1, ph1, "callable_a")
+ math_ops.add(a, a, "callable_b")
+
+ wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
+ [["run"]] * 3, self.sess, dump_root=self._tmp_dir)
+ callable_options = config_pb2.CallableOptions()
+ callable_options.feed.append("callable_ph1")
+ callable_options.fetch.append("callable_b")
+ sess_callable = wrapped_sess._make_callable_from_options(callable_options)
+
+ ph1_value = np.array([10.5, -10.5], dtype=np.float32)
+
+ for _ in range(2):
+ callable_output = sess_callable(ph1_value)
+ self.assertAllClose(
+ np.array([42.0, -42.0], dtype=np.float32), callable_output[0])
+
+ debug_dumps = wrapped_sess.observers["debug_dumps"]
+ self.assertEqual(2, len(debug_dumps))
+ for debug_dump in debug_dumps:
+ node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data]
+ self.assertItemsEqual(["callable_a", "callable_b"], node_names)
+
+ def testDebuggingMakeCallableFromOptionsWithTwoFeedsWorks(self):
+ ph1 = array_ops.placeholder(dtypes.float32, name="callable_ph1")
+ ph2 = array_ops.placeholder(dtypes.float32, name="callable_ph2")
+ a = math_ops.add(ph1, ph2, "callable_a")
+ math_ops.add(a, a, "callable_b")
+
+ wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
+ [["run"]] * 3, self.sess, dump_root=self._tmp_dir)
+ callable_options = config_pb2.CallableOptions()
+ callable_options.feed.append("callable_ph1")
+ callable_options.feed.append("callable_ph2")
+ callable_options.fetch.append("callable_b")
+ sess_callable = wrapped_sess._make_callable_from_options(callable_options)
+
+ ph1_value = np.array(5.0, dtype=np.float32)
+ ph2_value = np.array(16.0, dtype=np.float32)
+
+ for _ in range(2):
+ callable_output = sess_callable(ph1_value, ph2_value)
+ self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0])
+
+ debug_dumps = wrapped_sess.observers["debug_dumps"]
+ self.assertEqual(2, len(debug_dumps))
+ for debug_dump in debug_dumps:
+ node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data]
+ self.assertItemsEqual(["callable_a", "callable_b"], node_names)
+
+ def testDebugMakeCallableFromOptionsWithCustomOptionsAndMetadataWorks(self):
+ variable_1 = variables.Variable(
+ 10.5, dtype=dtypes.float32, name="variable_1")
+ a = math_ops.add(variable_1, variable_1, "callable_a")
+ math_ops.add(a, a, "callable_b")
+ self.sess.run(variable_1.initializer)
+
+ wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
+ [["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
+ callable_options = config_pb2.CallableOptions()
+ callable_options.fetch.append("callable_b")
+ callable_options.run_options.trace_level = config_pb2.RunOptions.FULL_TRACE
+
+ sess_callable = wrapped_sess._make_callable_from_options(callable_options)
+
+ run_metadata = config_pb2.RunMetadata()
+ # Call the callable with a custom run_metadata.
+ callable_output = sess_callable(run_metadata=run_metadata)
+ # Verify that step_stats is populated in the custom run_metadata.
+ self.assertTrue(run_metadata.step_stats)
+ self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0])
+
+ debug_dumps = wrapped_sess.observers["debug_dumps"]
+ self.assertEqual(1, len(debug_dumps))
+ debug_dump = debug_dumps[0]
+ node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data]
+ self.assertItemsEqual(
+ ["callable_a", "callable_b", "variable_1", "variable_1/read"],
+ node_names)
+
def testRuntimeErrorShouldBeCaught(self):
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
[["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index a81ef90513..7edcb0931d 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -782,6 +782,9 @@ class _PolymorphicFunction(object):
kwd_values = _deterministic_dict_values(kwds)
inputs = args + kwd_values
signature = tuple(_cache_key(x) for x in inputs)
+ # The graph, or whether we're executing eagerly, should be a part of the
+ # signature so we don't improperly capture tensors such as variables.
+ signature += tuple([context.executing_eagerly() or ops.get_default_graph()])
if signature not in self._arguments_to_functions:
graph_function = _trace_and_define_function(
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index ad00adbabb..cf32f6e7fb 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -105,6 +105,18 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(grads.eval(), 2.0)
self.assertEqual(grads.shape, v.shape)
+ def testGraphEagerIsolation(self):
+
+ @function.defun
+ def f():
+ v = resource_variable_ops.ResourceVariable(1.0)
+ return v.read_value()
+
+ self.assertAllEqual(f(), 1.0)
+
+ with ops.Graph().as_default():
+ self.assertEqual(f().shape, ())
+
def testBasicDefunOpGraphMode(self):
matmul = function.defun(math_ops.matmul)
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 733c7fb95d..2a0e4e7617 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -38,6 +38,7 @@ from tensorflow.python.estimator.export import export_output
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
@@ -1296,6 +1297,31 @@ class EstimatorEvaluateTest(test.TestCase):
dummy_input_fn, steps=1, checkpoint_path=est1.latest_checkpoint())
self.assertEqual(5, scores['global_step'])
+ def test_wrong_shape_throws_reasonable_error(self):
+ """Make sure we are helpful when model_fns change. See b/110263146."""
+ def _get_model_fn(val=1):
+ def _model_fn(features, labels, mode):
+ del features, labels # unused
+ variables.Variable(val, name='weight')
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ predictions=constant_op.constant([[1.]]),
+ loss=constant_op.constant(0.),
+ train_op=state_ops.assign_add(training.get_global_step(), 1))
+ return _model_fn
+
+ model_fn_1 = _get_model_fn()
+ model_fn_2 = _get_model_fn(val=[1])
+
+ est1 = estimator.Estimator(model_fn=model_fn_1)
+ est1.train(dummy_input_fn, steps=5)
+ est2 = estimator.Estimator(
+ model_fn=model_fn_2, model_dir=est1.model_dir)
+
+ expected_msg = 'Restoring from checkpoint failed.*a mismatch between'
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, expected_msg):
+ est2.train(dummy_input_fn, steps=1,)
+
def test_scaffold_is_used(self):
def _model_fn_scaffold(features, labels, mode):
diff --git a/tensorflow/python/keras/datasets/mnist.py b/tensorflow/python/keras/datasets/mnist.py
index 2a1c8d5f51..a96b581960 100644
--- a/tensorflow/python/keras/datasets/mnist.py
+++ b/tensorflow/python/keras/datasets/mnist.py
@@ -50,5 +50,5 @@ def load_data(path='mnist.npz'):
with np.load(path) as f:
x_train, y_train = f['x_train'], f['y_train']
x_test, y_test = f['x_test'], f['y_test']
-
+
return (x_train, y_train), (x_test, y_test)
diff --git a/tensorflow/python/keras/engine/saving.py b/tensorflow/python/keras/engine/saving.py
index 5e95cd4340..d5ccd44604 100644
--- a/tensorflow/python/keras/engine/saving.py
+++ b/tensorflow/python/keras/engine/saving.py
@@ -854,7 +854,16 @@ def load_weights_from_hdf5_group_by_name(f, layers):
str(len(weight_values)) + ' element(s).')
# Set values.
for i in range(len(weight_values)):
- weight_value_tuples.append((symbolic_weights[i], weight_values[i]))
+ if K.int_shape(symbolic_weights[i]) != weight_values[i].shape:
+ raise ValueError('Layer #' + str(k) +' (named "' + layer.name +
+ '"), weight ' + str(symbolic_weights[i]) +
+ ' has shape {}'.format(K.int_shape(
+ symbolic_weights[i])) +
+ ', but the saved weight has shape ' +
+ str(weight_values[i].shape) + '.')
+
+ else:
+ weight_value_tuples.append((symbolic_weights[i], weight_values[i]))
K.batch_set_value(weight_value_tuples)
diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py
index 1a0aa60609..030328f2a6 100644
--- a/tensorflow/python/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/engine/saving_test.py
@@ -21,7 +21,6 @@ from __future__ import print_function
import os
import shutil
import tempfile
-
from absl.testing import parameterized
import numpy as np
@@ -31,6 +30,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import saving
from tensorflow.python.keras.engine import training
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
@@ -248,6 +248,82 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
self.assertAllClose(y, ref_y)
+ def test_sequential_weight_loading_group_name_with_incorrect_length(self):
+ if h5py is None:
+ return
+
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+ h5_path = os.path.join(temp_dir, 'test.h5')
+
+ num_hidden = 5
+ input_dim = 3
+ num_classes = 2
+ with self.test_session():
+ ref_model = keras.models.Sequential()
+ ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim,
+ name='d1'))
+ ref_model.add(keras.layers.Dense(num_classes, name='d2'))
+ ref_model.compile(loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(lr=0.0001),
+ metrics=[keras.metrics.categorical_accuracy])
+
+ f_ref_model = h5py.File(h5_path, 'w')
+ saving.save_weights_to_hdf5_group(f_ref_model, ref_model.layers)
+
+ f_model = h5py.File(h5_path, 'r')
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(num_hidden, use_bias=False,
+ input_dim=input_dim, name='d1'))
+ model.add(keras.layers.Dense(num_classes, name='d2'))
+ model.compile(loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(lr=0.0001),
+ metrics=[keras.metrics.categorical_accuracy])
+ with self.assertRaisesRegexp(ValueError,
+ r'Layer #0 \(named \"d1\"\) expects 1 '
+ r'weight\(s\), but the saved weights have 2 '
+ r'element\(s\)\.'):
+ saving.load_weights_from_hdf5_group_by_name(f_model, model.layers)
+
+ def test_sequential_weight_loading_group_name_with_incorrect_shape(self):
+ if h5py is None:
+ return
+
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+ h5_path = os.path.join(temp_dir, 'test.h5')
+
+ num_hidden = 5
+ input_dim = 3
+ num_classes = 2
+ with self.test_session():
+ ref_model = keras.models.Sequential()
+ ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim,
+ name='d1'))
+ ref_model.add(keras.layers.Dense(num_classes, name='d2'))
+ ref_model.compile(loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(lr=0.0001),
+ metrics=[keras.metrics.categorical_accuracy])
+
+ f_ref_model = h5py.File(h5_path, 'w')
+ saving.save_weights_to_hdf5_group(f_ref_model, ref_model.layers)
+
+ f_model = h5py.File(h5_path, 'r')
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(num_hidden + 5, input_dim=input_dim,
+ name='d1'))
+ model.add(keras.layers.Dense(num_classes, name='d2'))
+ model.compile(loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(lr=0.0001),
+ metrics=[keras.metrics.categorical_accuracy])
+ with self.assertRaisesRegexp(ValueError,
+ r'Layer #0 \(named "d1"\), weight '
+ r'<tf\.Variable \'d1_1\/kernel:0\' '
+ r'shape=\(3, 10\) dtype=float32> has '
+ r'shape \(3, 10\), but the saved weight has '
+ r'shape \(3, 5\)\.'):
+ saving.load_weights_from_hdf5_group_by_name(f_model, model.layers)
+
class TestWholeModelSaving(test.TestCase):
diff --git a/tensorflow/python/keras/estimator/__init__.py b/tensorflow/python/keras/estimator/__init__.py
index cb86a69990..b244beb5b5 100644
--- a/tensorflow/python/keras/estimator/__init__.py
+++ b/tensorflow/python/keras/estimator/__init__.py
@@ -25,7 +25,7 @@ from tensorflow.python.util.tf_export import tf_export
# everything will work as normal.
try:
- import tensorflow.python.estimator.keras as keras_lib # pylint: disable=g-import-not-at-top
+ from tensorflow.python.estimator import keras as keras_lib # pylint: disable=g-import-not-at-top
model_to_estimator = tf_export('keras.estimator.model_to_estimator')(
keras_lib.model_to_estimator)
except Exception: # pylint: disable=broad-except
diff --git a/tensorflow/python/kernel_tests/dct_ops_test.py b/tensorflow/python/kernel_tests/dct_ops_test.py
index 93b2ff4561..97d7e2d8f9 100644
--- a/tensorflow/python/kernel_tests/dct_ops_test.py
+++ b/tensorflow/python/kernel_tests/dct_ops_test.py
@@ -40,50 +40,92 @@ def try_import(name): # pylint: disable=invalid-name
fftpack = try_import("scipy.fftpack")
+def _np_dct2(signals, norm=None):
+ """Computes the DCT-II manually with NumPy."""
+ # X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1
+ dct_size = signals.shape[-1]
+ dct = np.zeros_like(signals)
+ for k in range(dct_size):
+ phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size)
+ dct[..., k] = np.sum(signals * phi, axis=-1)
+ # SciPy's `dct` has a scaling factor of 2.0 which we follow.
+ # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src
+ if norm == "ortho":
+ # The orthonormal scaling includes a factor of 0.5 which we combine with
+ # the overall scaling of 2.0 to cancel.
+ dct[..., 0] *= np.sqrt(1.0 / dct_size)
+ dct[..., 1:] *= np.sqrt(2.0 / dct_size)
+ else:
+ dct *= 2.0
+ return dct
+
+
+def _np_dct3(signals, norm=None):
+ """Computes the DCT-III manually with NumPy."""
+ # SciPy's `dct` has a scaling factor of 2.0 which we follow.
+ # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src
+ dct_size = signals.shape[-1]
+ signals = np.array(signals) # make a copy so we can modify
+ if norm == "ortho":
+ signals[..., 0] *= np.sqrt(4.0 / dct_size)
+ signals[..., 1:] *= np.sqrt(2.0 / dct_size)
+ else:
+ signals *= 2.0
+ dct = np.zeros_like(signals)
+ # X_k = 0.5 * x_0 +
+ # sum_{n=1}^{N-1} x_n * cos(\frac{pi}{N} * n * (k + 0.5)) k=0,...,N-1
+ half_x0 = 0.5 * signals[..., 0]
+ for k in range(dct_size):
+ phi = np.cos(np.pi * np.arange(1, dct_size) * (k + 0.5) / dct_size)
+ dct[..., k] = half_x0 + np.sum(signals[..., 1:] * phi, axis=-1)
+ return dct
+
+
+NP_DCT = {2: _np_dct2, 3: _np_dct3}
+NP_IDCT = {2: _np_dct3, 3: _np_dct2}
+
+
class DCTOpsTest(test.TestCase):
- def _np_dct2(self, signals, norm=None):
- """Computes the DCT-II manually with NumPy."""
- # X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1
- dct_size = signals.shape[-1]
- dct = np.zeros_like(signals)
- for k in range(dct_size):
- phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size)
- dct[..., k] = np.sum(signals * phi, axis=-1)
- # SciPy's `dct` has a scaling factor of 2.0 which we follow.
- # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src
- if norm == "ortho":
- # The orthonormal scaling includes a factor of 0.5 which we combine with
- # the overall scaling of 2.0 to cancel.
- dct[..., 0] *= np.sqrt(1.0 / dct_size)
- dct[..., 1:] *= np.sqrt(2.0 / dct_size)
- else:
- dct *= 2.0
- return dct
-
- def _compare(self, signals, norm, atol=5e-4, rtol=5e-4):
- """Compares the DCT to SciPy (if available) and a NumPy implementation."""
- np_dct = self._np_dct2(signals, norm)
- tf_dct = spectral_ops.dct(signals, type=2, norm=norm).eval()
+ def _compare(self, signals, norm, dct_type, atol=5e-4, rtol=5e-4):
+ """Compares (I)DCT to SciPy (if available) and a NumPy implementation."""
+ np_dct = NP_DCT[dct_type](signals, norm)
+ tf_dct = spectral_ops.dct(signals, type=dct_type, norm=norm).eval()
self.assertAllClose(np_dct, tf_dct, atol=atol, rtol=rtol)
+ np_idct = NP_IDCT[dct_type](signals, norm)
+ tf_idct = spectral_ops.idct(signals, type=dct_type, norm=norm).eval()
+ self.assertAllClose(np_idct, tf_idct, atol=atol, rtol=rtol)
if fftpack:
- scipy_dct = fftpack.dct(signals, type=2, norm=norm)
+ scipy_dct = fftpack.dct(signals, type=dct_type, norm=norm)
self.assertAllClose(scipy_dct, tf_dct, atol=atol, rtol=rtol)
+ scipy_idct = fftpack.idct(signals, type=dct_type, norm=norm)
+ self.assertAllClose(scipy_idct, tf_idct, atol=atol, rtol=rtol)
+ # Verify inverse(forward(s)) == s, up to a normalization factor.
+ tf_idct_dct = spectral_ops.idct(
+ tf_dct, type=dct_type, norm=norm).eval()
+ tf_dct_idct = spectral_ops.dct(
+ tf_idct, type=dct_type, norm=norm).eval()
+ if norm is None:
+ tf_idct_dct *= 0.5 / signals.shape[-1]
+ tf_dct_idct *= 0.5 / signals.shape[-1]
+ self.assertAllClose(signals, tf_idct_dct, atol=atol, rtol=rtol)
+ self.assertAllClose(signals, tf_dct_idct, atol=atol, rtol=rtol)
def test_random(self):
"""Test randomly generated batches of data."""
with spectral_ops_test_util.fft_kernel_label_map():
with self.test_session(use_gpu=True):
- for shape in ([2, 20], [1], [2], [3], [10], [2, 20], [2, 3, 25]):
+ for shape in ([1], [2], [3], [10], [2, 20], [2, 3, 25]):
signals = np.random.rand(*shape).astype(np.float32)
for norm in (None, "ortho"):
- self._compare(signals, norm)
+ self._compare(signals, norm, 2)
+ self._compare(signals, norm, 3)
def test_error(self):
signals = np.random.rand(10)
# Unsupported type.
with self.assertRaises(ValueError):
- spectral_ops.dct(signals, type=3)
+ spectral_ops.dct(signals, type=1)
# Unknown normalization.
with self.assertRaises(ValueError):
spectral_ops.dct(signals, norm="bad")
diff --git a/tensorflow/python/lib/core/numpy.h b/tensorflow/python/lib/core/numpy.h
index 98354083c7..d4621d61ee 100644
--- a/tensorflow/python/lib/core/numpy.h
+++ b/tensorflow/python/lib/core/numpy.h
@@ -30,8 +30,8 @@ limitations under the License.
#endif
// Place `<locale>` before <Python.h> to avoid build failure in macOS.
-#include <locale>
#include <Python.h>
+#include <locale>
#include "numpy/arrayobject.h"
#include "numpy/ufuncobject.h"
diff --git a/tensorflow/python/lib/core/py_util.cc b/tensorflow/python/lib/core/py_util.cc
index 572693b1cf..6b6c82015f 100644
--- a/tensorflow/python/lib/core/py_util.cc
+++ b/tensorflow/python/lib/core/py_util.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/python/lib/core/py_util.h"
// Place `<locale>` before <Python.h> to avoid build failure in macOS.
-#include <locale>
#include <Python.h>
+#include <locale>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 2c7751f792..a2eae452ae 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -57,6 +57,7 @@ ops.NotDifferentiable('NonMaxSuppression')
ops.NotDifferentiable('NonMaxSuppressionV2')
+# pylint: disable=invalid-name
def _assert(cond, ex_type, msg):
"""A polymorphic assert, works with tensors and boolean expressions.
@@ -1070,15 +1071,16 @@ def resize_images(images,
@tf_export('image.resize_image_with_pad')
-def resize_image_with_pad(image, target_height, target_width,
+def resize_image_with_pad(image,
+ target_height,
+ target_width,
method=ResizeMethod.BILINEAR):
- """
- Resizes and pads an image to a target width and height.
+ """Resizes and pads an image to a target width and height.
Resizes an image to a target width and height by keeping
the aspect ratio the same without distortion. If the target
dimensions don't match the image dimensions, the image
- is resized and then padded with zeroes to match requested
+ is resized and then padded with zeroes to match requested
dimensions.
Args:
@@ -1139,10 +1141,10 @@ def resize_image_with_pad(image, target_height, target_width,
ratio = max_(f_width / f_target_width, f_height / f_target_height)
resized_height_float = f_height / ratio
resized_width_float = f_width / ratio
- resized_height = math_ops.cast(math_ops.floor(resized_height_float),
- dtype=dtypes.int32)
- resized_width = math_ops.cast(math_ops.floor(resized_width_float),
- dtype=dtypes.int32)
+ resized_height = math_ops.cast(
+ math_ops.floor(resized_height_float), dtype=dtypes.int32)
+ resized_width = math_ops.cast(
+ math_ops.floor(resized_width_float), dtype=dtypes.int32)
padding_height = (f_target_height - resized_height_float) / 2
padding_width = (f_target_width - resized_width_float) / 2
@@ -1154,13 +1156,13 @@ def resize_image_with_pad(image, target_height, target_width,
# Resize first, then pad to meet requested dimensions
resized = resize_images(image, [resized_height, resized_width], method)
- padded = pad_to_bounding_box(resized, p_height, p_width,
- target_height, target_width)
+ padded = pad_to_bounding_box(resized, p_height, p_width, target_height,
+ target_width)
if padded.get_shape().ndims is None:
raise ValueError('padded contains no shape.')
- _, padded_height, padded_width, _ = _ImageDimensions(padded, rank=4)
+ _ImageDimensions(padded, rank=4)
if not is_batch:
padded = array_ops.squeeze(padded, squeeze_dims=[0])
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 8e40de140d..cf9761803b 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -2731,7 +2731,7 @@ class ResizeImageWithPadTest(test_util.TensorFlowTestCase):
try:
self._ResizeImageWithPad(x, target_height, target_width,
use_tensor_inputs)
- except Exception as e:
+ except Exception as e: # pylint: disable=broad-except
if err_msg not in str(e):
raise
else:
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 45e3bd65d2..6b709e5e7f 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -237,8 +237,8 @@ class ApproximateEqualTest(test_util.TensorFlowTestCase):
def testApproximateEqualShape(self):
for dtype in [np.float32, np.double]:
- x = np.array([1, 2], dtype=np.float32)
- y = np.array([[1, 2]], dtype=np.float32)
+ x = np.array([1, 2], dtype=dtype)
+ y = np.array([[1, 2]], dtype=dtype)
# The inputs 'x' and 'y' must have the same shape.
with self.assertRaisesRegexp(
ValueError, "Shapes must be equal rank, but are 1 and 2"):
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 215140e987..deba133fb9 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
@@ -131,6 +132,18 @@ def _maybe_tensor_shape_from_tensor(shape):
return shape
+def _should_cache():
+ """Returns True if a default caching device should be set, otherwise False."""
+ if context.executing_eagerly():
+ return False
+ # Don't set a caching device when running in a loop, since it is possible that
+ # train steps could be wrapped in a tf.while_loop. In that scenario caching
+ # prevents forward computations in loop iterations from re-reading the
+ # updated weights.
+ ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
+ return control_flow_util.GetContainingWhileContext(ctxt) is None
+
+
# pylint: disable=unused-argument
def _rnn_step(
time, sequence_length, min_sequence_length, max_sequence_length,
@@ -558,7 +571,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
# Create a new scope in which the caching device is either
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
- if not context.executing_eagerly():
+ if _should_cache():
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
@@ -1015,7 +1028,7 @@ def raw_rnn(cell, loop_fn,
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
with vs.variable_scope(scope or "rnn") as varscope:
- if not context.executing_eagerly():
+ if _should_cache():
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
@@ -1228,7 +1241,7 @@ def static_rnn(cell,
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
with vs.variable_scope(scope or "rnn") as varscope:
- if not context.executing_eagerly():
+ if _should_cache():
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py
index 6efcd39f13..9a10abfcf7 100644
--- a/tensorflow/python/ops/special_math_ops.py
+++ b/tensorflow/python/ops/special_math_ops.py
@@ -201,8 +201,8 @@ def einsum(equation, *inputs, **kwargs):
indices in its subscript, or
- the input shapes are inconsistent along a particular axis.
"""
- equation = equation.replace(" ", "")
-
+ equation = equation.replace(' ', '')
+
name = kwargs.pop('name', None)
if kwargs:
raise TypeError('invalid keyword arguments for this function: ' + ', '.join(
diff --git a/tensorflow/python/ops/spectral_ops.py b/tensorflow/python/ops/spectral_ops.py
index 28054f50ef..293aace728 100644
--- a/tensorflow/python/ops/spectral_ops.py
+++ b/tensorflow/python/ops/spectral_ops.py
@@ -167,8 +167,8 @@ def _validate_dct_arguments(dct_type, n, axis, norm):
raise NotImplementedError("The DCT length argument is not implemented.")
if axis != -1:
raise NotImplementedError("axis must be -1. Got: %s" % axis)
- if dct_type != 2:
- raise ValueError("Only the Type II DCT is supported.")
+ if dct_type not in (2, 3):
+ raise ValueError("Only Types II and III (I)DCT are supported.")
if norm not in (None, "ortho"):
raise ValueError(
"Unknown normalization. Expected None or 'ortho', got: %s" % norm)
@@ -179,18 +179,20 @@ def _validate_dct_arguments(dct_type, n, axis, norm):
def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
"""Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`.
- Currently only Type II is supported. Implemented using a length `2N` padded
- @{tf.spectral.rfft}, as described here: https://dsp.stackexchange.com/a/10606
+ Currently only Types II and III are supported. Type II is implemented using a
+ length `2N` padded @{tf.spectral.rfft}, as described here:
+ https://dsp.stackexchange.com/a/10606. Type III is a fairly straightforward
+ inverse of Type II (i.e. using a length `2N` padded @{tf.spectral.irfft}).
@compatibility(scipy)
- Equivalent to scipy.fftpack.dct for the Type-II DCT.
+ Equivalent to scipy.fftpack.dct for Type-II and Type-III DCT.
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
@end_compatibility
Args:
input: A `[..., samples]` `float32` `Tensor` containing the signals to
take the DCT of.
- type: The DCT type to perform. Must be 2.
+ type: The DCT type to perform. Must be 2 or 3.
n: For future expansion. The length of the transform. Must be `None`.
axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
norm: The normalization to apply. `None` for no normalization or `'ortho'`
@@ -201,8 +203,8 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl
A `[..., samples]` `float32` `Tensor` containing the DCT of `input`.
Raises:
- ValueError: If `type` is not `2`, `n` is not `None, `axis` is not `-1`, or
- `norm` is not `None` or `'ortho'`.
+    ValueError: If `type` is not `2` or `3`, `n` is not `None`, `axis` is not
+ `-1`, or `norm` is not `None` or `'ortho'`.
[dct]: https://en.wikipedia.org/wiki/Discrete_cosine_transform
"""
@@ -214,22 +216,91 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl
axis_dim = input.shape[-1].value or _array_ops.shape(input)[-1]
axis_dim_float = _math_ops.to_float(axis_dim)
- scale = 2.0 * _math_ops.exp(_math_ops.complex(
- 0.0, -_math.pi * _math_ops.range(axis_dim_float) /
- (2.0 * axis_dim_float)))
-
- # TODO(rjryan): Benchmark performance and memory usage of the various
- # approaches to computing a DCT via the RFFT.
- dct2 = _math_ops.real(
- rfft(input, fft_length=[2 * axis_dim])[..., :axis_dim] * scale)
-
- if norm == "ortho":
- n1 = 0.5 * _math_ops.rsqrt(axis_dim_float)
- n2 = n1 * _math_ops.sqrt(2.0)
- # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
- weights = _array_ops.pad(
- _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]],
- constant_values=n2)
- dct2 *= weights
-
- return dct2
+ if type == 2:
+ scale = 2.0 * _math_ops.exp(
+ _math_ops.complex(
+ 0.0, -_math_ops.range(axis_dim_float) * _math.pi * 0.5 /
+ axis_dim_float))
+
+ # TODO(rjryan): Benchmark performance and memory usage of the various
+ # approaches to computing a DCT via the RFFT.
+ dct2 = _math_ops.real(
+ rfft(input, fft_length=[2 * axis_dim])[..., :axis_dim] * scale)
+
+ if norm == "ortho":
+ n1 = 0.5 * _math_ops.rsqrt(axis_dim_float)
+ n2 = n1 * _math_ops.sqrt(2.0)
+ # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
+ weights = _array_ops.pad(
+ _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]],
+ constant_values=n2)
+ dct2 *= weights
+
+ return dct2
+
+ elif type == 3:
+ if norm == "ortho":
+ n1 = _math_ops.sqrt(axis_dim_float)
+ n2 = n1 * _math_ops.sqrt(0.5)
+ # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
+ weights = _array_ops.pad(
+ _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]],
+ constant_values=n2)
+ input *= weights
+ else:
+ input *= axis_dim_float
+ scale = 2.0 * _math_ops.exp(
+ _math_ops.complex(
+ 0.0,
+ _math_ops.range(axis_dim_float) * _math.pi * 0.5 /
+ axis_dim_float))
+ dct3 = _math_ops.real(
+ irfft(
+ scale * _math_ops.complex(input, 0.0),
+ fft_length=[2 * axis_dim]))[..., :axis_dim]
+
+ return dct3
+
+
+# TODO(rjryan): Implement `type`, `n` and `axis` parameters.
+@tf_export("spectral.idct")
+def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
+ """Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`.
+
+ Currently only Types II and III are supported. Type III is the inverse of
+ Type II, and vice versa.
+
+ Note that you must re-normalize by 1/(2n) to obtain an inverse if `norm` is
+ not `'ortho'`. That is:
+ `signal == idct(dct(signal)) * 0.5 / signal.shape[-1]`.
+ When `norm='ortho'`, we have:
+ `signal == idct(dct(signal, norm='ortho'), norm='ortho')`.
+
+ @compatibility(scipy)
+ Equivalent to scipy.fftpack.idct for Type-II and Type-III DCT.
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.idct.html
+ @end_compatibility
+
+ Args:
+ input: A `[..., samples]` `float32` `Tensor` containing the signals to take
+ the DCT of.
+ type: The IDCT type to perform. Must be 2 or 3.
+ n: For future expansion. The length of the transform. Must be `None`.
+ axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
+ norm: The normalization to apply. `None` for no normalization or `'ortho'`
+ for orthonormal normalization.
+ name: An optional name for the operation.
+
+ Returns:
+ A `[..., samples]` `float32` `Tensor` containing the IDCT of `input`.
+
+ Raises:
+    ValueError: If `type` is not `2` or `3`, `n` is not `None`, `axis` is not
+ `-1`, or `norm` is not `None` or `'ortho'`.
+
+ [idct]:
+ https://en.wikipedia.org/wiki/Discrete_cosine_transform#Inverse_transforms
+ """
+ _validate_dct_arguments(type, n, axis, norm)
+ inverse_type = {2: 3, 3: 2}[type]
+ return dct(input, type=inverse_type, n=n, axis=axis, norm=norm, name=name)
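The round-trip relations documented in `idct()` can be checked directly with the public API; a sketch (not part of this patch), mirroring the updated `dct_ops_test.py` above and assuming a graph-mode session:

```python
# Illustrative check of: idct(dct(x)) * 0.5 / N == x  (norm=None), and
# idct(dct(x, norm='ortho'), norm='ortho') == x.
import numpy as np
import tensorflow as tf

signals = np.random.rand(2, 16).astype(np.float32)
n = signals.shape[-1]

with tf.Session() as sess:
  roundtrip = tf.spectral.idct(tf.spectral.dct(signals, type=2), type=2) * 0.5 / n
  ortho = tf.spectral.idct(
      tf.spectral.dct(signals, type=2, norm="ortho"), type=2, norm="ortho")
  rt, ot = sess.run([roundtrip, ortho])

np.testing.assert_allclose(rt, signals, rtol=5e-4, atol=5e-4)
np.testing.assert_allclose(ot, signals, rtol=5e-4, atol=5e-4)
```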
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 53ed89e4ab..1ee975fbe4 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -22,7 +22,6 @@ from __future__ import print_function
import collections
import os.path
import re
-import sys
import time
import uuid
@@ -1043,8 +1042,8 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None):
ckpt = CheckpointState()
text_format.Merge(file_content, ckpt)
if not ckpt.model_checkpoint_path:
- raise ValueError("Invalid checkpoint state loaded from %s",
- checkpoint_dir)
+ raise ValueError("Invalid checkpoint state loaded from "
+ + checkpoint_dir)
# For relative model_checkpoint_path and all_model_checkpoint_paths,
# prepend checkpoint_dir.
if not os.path.isabs(ckpt.model_checkpoint_path):
@@ -1706,12 +1705,17 @@ class Saver(object):
save_path: Path where parameters were previously saved.
Raises:
- ValueError: If save_path is None.
+ ValueError: If save_path is None or not a valid checkpoint.
"""
if self._is_empty:
return
if save_path is None:
raise ValueError("Can't load save_path when it is None.")
+
+ if not checkpoint_exists(compat.as_text(save_path)):
+ raise ValueError("The passed save_path is not a valid checkpoint: "
+ + compat.as_text(save_path))
+
logging.info("Restoring parameters from %s", compat.as_text(save_path))
try:
if context.executing_eagerly():
@@ -1719,23 +1723,24 @@ class Saver(object):
else:
sess.run(self.saver_def.restore_op_name,
{self.saver_def.filename_tensor_name: save_path})
- except errors.NotFoundError:
- exception_type, exception_value, exception_traceback = sys.exc_info()
- # The checkpoint would not be loaded successfully as is. Try to parse it
- # as an object-based checkpoint.
- should_reraise = False
+ except errors.NotFoundError as err:
+ # There are three common conditions that might cause this error:
+ # 0. The file is missing. We ignore here, as this is checked above.
+ # 1. This is an object-based checkpoint trying name-based loading.
+ # 2. The graph has been altered and a variable or other name is missing.
+
+ # 1. The checkpoint would not be loaded successfully as is. Try to parse
+ # it as an object-based checkpoint.
try:
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
object_graph_string = reader.get_tensor(
checkpointable.OBJECT_GRAPH_PROTO_KEY)
except errors.NotFoundError:
- # This is not an object-based checkpoint, or the checkpoint doesn't
- # exist. Re-raise the original exception, but do it outside the except
- # block so the object graph lookup isn't included in the stack trace.
- should_reraise = True
- if should_reraise:
- six.reraise(exception_type, exception_value, exception_traceback)
- del exception_traceback # avoid reference cycles
+ # 2. This is not an object-based checkpoint, which likely means there
+ # is a graph mismatch. Re-raise the original error with
+ # a helpful message (b/110263146)
+ raise _wrap_restore_error_with_msg(
+ err, "a Variable name or other graph key that is missing")
# This is an object-based checkpoint. We'll print a warning and then do
# the restore.
@@ -1747,6 +1752,11 @@ class Saver(object):
self._restore_from_object_based_checkpoint(
sess=sess, save_path=save_path,
object_graph_string=object_graph_string)
+ except errors.InvalidArgumentError as err:
+ # There is a mismatch between the graph and the checkpoint being loaded.
+ # We add a more reasonable error message here to help users (b/110263146)
+ raise _wrap_restore_error_with_msg(
+ err, "a mismatch between the current graph and the graph")
def _restore_from_object_based_checkpoint(self, sess, save_path,
object_graph_string):
@@ -2139,6 +2149,14 @@ def _meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
return meta_graph_filename
+def _wrap_restore_error_with_msg(err, extra_verbiage):
+ err_msg = ("Restoring from checkpoint failed. This is most likely "
+ "due to {} from the checkpoint. Please ensure that you "
+ "have not altered the graph expected based on the checkpoint. "
+ "Original error:\n\n{}").format(extra_verbiage, err.message)
+ return err.__class__(err.node_def, err.op, err_msg)
+
+
ops.register_proto_function(
ops.GraphKeys.SAVERS,
proto_type=saver_pb2.SaverDef,
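A sketch (not part of this patch) of the new fail-fast behavior added to `Saver.restore()`, mirroring the updated `SaverTest` case below:

```python
# Illustrative: restoring from a path that is not a valid checkpoint now raises
# ValueError before the underlying restore op ever runs.
import tensorflow as tf

v0 = tf.Variable(1.0, name="v0")
saver = tf.train.Saver({"v0": v0})
with tf.Session() as sess:
  try:
    saver.restore(sess, "invalid path")
  except ValueError as err:
    assert "not a valid checkpoint" in str(err)
```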
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index f235300eb5..ae9c244aaf 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -24,10 +24,8 @@ import math
import os
import random
import shutil
-import sys
import tempfile
import time
-import traceback
import numpy as np
import six
@@ -369,8 +367,8 @@ class SaverTest(test.TestCase):
for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2):
with self.test_session() as sess:
save = saver_module.Saver({"v0": v0}, write_version=ver)
- with self.assertRaisesRegexp(errors.NotFoundError,
- "Failed to find any matching files for"):
+ with self.assertRaisesRegexp(
+ ValueError, "The passed save_path is not a valid checkpoint:"):
save.restore(sess, "invalid path")
def testInt64(self):
@@ -3139,27 +3137,33 @@ class CheckpointableCompatibilityTests(test.TestCase):
errors.NotFoundError, "Key b not found in checkpoint"):
b_saver.restore(sess=sess, save_path=save_path)
- def testCheckpointNotFoundErrorRaised(self):
- # Restore does some tricky exception handling to figure out if it should
- # load an object-based checkpoint. Tests that the exception handling isn't
- # too broad.
- a = resource_variable_ops.ResourceVariable(1., name="a")
- saver = saver_module.Saver([a])
- with self.test_session() as sess:
- with self.assertRaisesRegexp(
- errors.NotFoundError,
- "Failed to find any matching files for path_which_does_not_exist"):
- saver.restore(sess=sess, save_path="path_which_does_not_exist")
- try:
- saver.restore(sess=sess, save_path="path_which_does_not_exist")
- except errors.NotFoundError:
- # Make sure we don't have a confusing "During handling of the above
- # exception" block in Python 3.
- # pylint: disable=no-value-for-parameter
- exception_string = "\n".join(
- traceback.format_exception(*sys.exc_info()))
- # pylint: enable=no-value-for-parameter
- self.assertNotIn("NewCheckpointReader", exception_string)
+ with self.assertRaises(errors.NotFoundError) as cs:
+ b_saver.restore(sess=sess, save_path=save_path)
+
+ # Make sure we don't have a confusing "During handling of the above
+ # exception" block in Python 3.
+ self.assertNotIn("NewCheckpointReader", cs.exception.message)
+
+ def testGraphChangedForRestoreErrorRaised(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+
+ with ops_lib.Graph().as_default() as g:
+ a = variables.Variable(1., name="a")
+ a_saver = saver_module.Saver([a])
+
+ with self.test_session(graph=g) as sess:
+ sess.run(a.initializer)
+ save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)
+
+ with ops_lib.Graph().as_default() as g:
+ a = variables.Variable([1.], name="a")
+ a_saver = saver_module.Saver([a])
+ with self.test_session(graph=g) as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "a mismatch between the current graph and the graph"):
+ a_saver.restore(sess=sess, save_path=save_path)
def testLoadFromObjectBasedGraph(self):
checkpoint_directory = self.get_temp_dir()
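A short sketch of the caller-visible change the updated tests assert: restoring from a path that is not a valid checkpoint now surfaces as ValueError ("The passed save_path is not a valid checkpoint"), where older builds raised errors.NotFoundError, so defensive callers may want to catch both. TF 1.x-style code; the variable and path names are made up for illustration.

import tensorflow as tf

v = tf.Variable(1.0, name="v")
saver = tf.train.Saver([v])
with tf.Session() as sess:
  sess.run(v.initializer)
  try:
    saver.restore(sess, "path_which_does_not_exist")
  except (ValueError, tf.errors.NotFoundError) as e:
    # ValueError after this change; NotFoundError on older TensorFlow builds.
    print("restore failed: %s" % e)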
diff --git a/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt b/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt
index 4f306540cc..6a421ef12d 100644
--- a/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt
@@ -17,6 +17,10 @@ tf_module {
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "idct"
+ argspec: "args=[\'input\', \'type\', \'n\', \'axis\', \'norm\', \'name\'], varargs=None, keywords=None, defaults=[\'2\', \'None\', \'-1\', \'None\', \'None\'], "
+ }
+ member_method {
name: "ifft"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
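The golden file above records the newly exported idct alongside the existing fft/ifft entries. A hypothetical round trip in TF 1.x style, assuming type=2 with norm='ortho' is accepted so dct and idct are exact inverses; the supported type/norm combinations in a given build may differ.

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.rand(4, 8), dtype=tf.float32)
y = tf.spectral.dct(x, type=2, norm='ortho')       # forward DCT-II
x_hat = tf.spectral.idct(y, type=2, norm='ortho')  # inverse along the last axis

with tf.Session() as sess:
  orig, rec = sess.run([x, x_hat])
  print(np.allclose(orig, rec, atol=1e-5))  # expect True when they are inverses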
diff --git a/tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le b/tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le
index e879c34bbd..f496ac59b6 100644
--- a/tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le
+++ b/tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le
@@ -8,7 +8,6 @@ RUN /install/install_bootstrap_deb_packages.sh
RUN add-apt-repository -y ppa:openjdk-r/ppa
RUN /install/install_deb_packages.sh
RUN apt-get update && apt-get install -y libopenblas-dev
-RUN /install/install_hdf5_ppc64le.sh
RUN /install/install_pip_packages.sh
RUN /install/install_bazel_from_source.sh
RUN /install/install_proto3.sh
diff --git a/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le b/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le
index 8967138747..3eddc56550 100644
--- a/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le
+++ b/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le
@@ -14,7 +14,6 @@ RUN /install/install_bootstrap_deb_packages.sh
RUN add-apt-repository -y ppa:openjdk-r/ppa
RUN /install/install_deb_packages.sh
RUN apt-get update && apt-get install -y libopenblas-dev
-RUN /install/install_hdf5_ppc64le.sh
RUN /install/install_pip_packages.sh
RUN /install/install_bazel_from_source.sh
RUN /install/install_golang_ppc64le.sh
diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh
index 05676f9551..f0a437c183 100755
--- a/tensorflow/tools/ci_build/ci_sanity.sh
+++ b/tensorflow/tools/ci_build/ci_sanity.sh
@@ -349,12 +349,12 @@ do_external_licenses_check(){
# Blacklist
echo ${MISSING_LICENSES_FILE}
- grep -e "@bazel_tools//third_party/" -e "@com_google_absl//absl" -e "@org_tensorflow//" -v ${MISSING_LICENSES_FILE} > temp.txt
+ grep -e "@bazel_tools//third_party/" -e "@com_google_absl//absl" -e "@org_tensorflow//" -e "@com_github_googlecloudplatform_google_cloud_cpp//google" -v ${MISSING_LICENSES_FILE} > temp.txt
mv temp.txt ${MISSING_LICENSES_FILE}
# Whitelist
echo ${EXTRA_LICENSE_FILE}
- grep -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@com_google_absl//" -e "//external" -e "@local" -v ${EXTRA_LICENSES_FILE} > temp.txt
+ grep -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@com_google_absl//" -e "//external" -e "@local" -e "@com_github_googlecloudplatform_google_cloud_cpp//" -v ${EXTRA_LICENSES_FILE} > temp.txt
mv temp.txt ${EXTRA_LICENSES_FILE}
diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD
index 05c23cd3ee..173f418dc8 100644
--- a/tensorflow/tools/lib_package/BUILD
+++ b/tensorflow/tools/lib_package/BUILD
@@ -115,6 +115,7 @@ genrule(
"//third_party/fft2d:LICENSE",
"@aws//:LICENSE",
"@boringssl//:LICENSE",
+ "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
"@com_googlesource_code_re2//:LICENSE",
"@cub_archive//:LICENSE.TXT",
"@curl//:COPYING",
@@ -156,6 +157,7 @@ genrule(
"//third_party/fft2d:LICENSE",
"@aws//:LICENSE",
"@boringssl//:LICENSE",
+ "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
"@com_googlesource_code_re2//:LICENSE",
"@cub_archive//:LICENSE.TXT",
"@curl//:COPYING",
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index a0caf42331..c9d53f46c3 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -130,6 +130,8 @@ filegroup(
"@astor_archive//:LICENSE",
"@aws//:LICENSE",
"@boringssl//:LICENSE",
+ "@com_github_googleapis_googleapis//:LICENSE",
+ "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
"@com_google_absl//:LICENSE",
"@com_googlesource_code_re2//:LICENSE",
"@cub_archive//:LICENSE.TXT",
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index c630ca04b8..1236de2657 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -53,7 +53,7 @@ REQUIRED_PACKAGES = [
'gast >= 0.2.0',
'numpy >= 1.13.3',
'six >= 1.10.0',
- 'protobuf >= 3.6.0',
+ 'protobuf >= 3.4.0',
'setuptools <= 39.1.0',
'tensorboard >= 1.8.0, < 1.9.0',
'termcolor >= 1.1.0',
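The protobuf floor is relaxed from 3.6.0 to 3.4.0. A quick sanity check, not part of this diff, for whether an installed protobuf satisfies the new lower bound; pkg_resources ships with setuptools, which REQUIRED_PACKAGES already pins.

import pkg_resources

try:
  pkg_resources.require("protobuf >= 3.4.0")
  print("protobuf requirement satisfied")
except (pkg_resources.VersionConflict, pkg_resources.DistributionNotFound) as e:
  print("protobuf requirement not met: %s" % e)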
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index cae6f51eb5..172eed0b57 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -107,11 +107,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "eigen_archive",
urls = [
- "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz",
- "https://bitbucket.org/eigen/eigen/get/fd6845384b86.tar.gz",
+ "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/e5e305a158a0.tar.gz",
+ "https://bitbucket.org/eigen/eigen/get/e5e305a158a0.tar.gz",
],
- sha256 = "d956415d784fa4e42b6a2a45c32556d6aec9d0a3d8ef48baee2522ab762556a9",
- strip_prefix = "eigen-eigen-fd6845384b86",
+ sha256 = "8bbe676d69e7f59070c83a949454b8b6344034e0ebbf686b337528e5dc04c7de",
+ strip_prefix = "eigen-eigen-e5e305a158a0",
build_file = clean_dep("//third_party:eigen.BUILD"),
)
@@ -142,11 +142,13 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "ortools_archive",
urls = [
- "https://mirror.bazel.build/github.com/google/or-tools/archive/v6.7.2.tar.gz",
- "https://github.com/google/or-tools/archive/v6.7.2.tar.gz",
+ "https://mirror.bazel.build/github.com/google/or-tools/archive/253f7955c6a1fd805408fba2e42ac6d45b312d15.tar.gz",
+ # Please uncomment me when the next upgrade happens. Then
+ # remove the whitelist entry in third_party/repo.bzl.
+ # "https://github.com/google/or-tools/archive/253f7955c6a1fd805408fba2e42ac6d45b312d15.tar.gz",
],
- sha256 = "d025a95f78b5fc5eaa4da5f395f23d11c23cf7dbd5069f1f627f002de87b86b9",
- strip_prefix = "or-tools-6.7.2/src",
+ sha256 = "932075525642b04ac6f1b50589f1df5cd72ec2f448b721fd32234cf183f0e755",
+ strip_prefix = "or-tools-253f7955c6a1fd805408fba2e42ac6d45b312d15/src",
build_file = clean_dep("//third_party:ortools.BUILD"),
)
@@ -162,6 +164,27 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
)
tf_http_archive(
+ name = "com_github_googlecloudplatform_google_cloud_cpp",
+ urls = [
+ "https://mirror.bazel.build/github.com/GoogleCloudPlatform/google-cloud-cpp/archive/f9ff105957965bcf87f7cb9a93e951c3d08d1734.tar.gz",
+ "https://github.com/GoogleCloudPlatform/google-cloud-cpp/archive/f9ff105957965bcf87f7cb9a93e951c3d08d1734.tar.gz",
+ ],
+ sha256 = "edb347aae9869ffdcf8df6288335bcc535fec46da946b385c16968e96a74b208",
+ strip_prefix = "google-cloud-cpp-f9ff105957965bcf87f7cb9a93e951c3d08d1734",
+ )
+
+ tf_http_archive(
+ name = "com_github_googleapis_googleapis",
+ urls = [
+ "https://mirror.bazel.build/github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip",
+ "https://github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip",
+ ],
+ sha256 = "824870d87a176f26bcef663e92051f532fac756d1a06b404055dc078425f4378",
+ strip_prefix="googleapis-f81082ea1e2f85c43649bee26e0d9871d4b41cdb",
+ build_file = clean_dep("//third_party:googleapis.BUILD"),
+ )
+
+ tf_http_archive(
name = "gemmlowp",
urls = [
"https://mirror.bazel.build/github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip",
@@ -231,11 +254,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "org_sqlite",
urls = [
- "https://mirror.bazel.build/www.sqlite.org/2018/sqlite-amalgamation-3240000.zip",
- "https://www.sqlite.org/2018/sqlite-amalgamation-3240000.zip",
+ "https://mirror.bazel.build/www.sqlite.org/2018/sqlite-amalgamation-3230100.zip",
+ "https://www.sqlite.org/2018/sqlite-amalgamation-3230100.zip",
],
- sha256 = "ad68c1216c3a474cf360c7581a4001e952515b3649342100f2d7ca7c8e313da6",
- strip_prefix = "sqlite-amalgamation-3240000",
+ sha256 = "4239a1f69e5721d07d9a374eb84d594225229e54be4ee628da2995f4315d8dfc",
+ strip_prefix = "sqlite-amalgamation-3230100",
build_file = clean_dep("//third_party:sqlite.BUILD"),
)
@@ -426,11 +449,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "grpc",
urls = [
- "https://mirror.bazel.build/github.com/grpc/grpc/archive/v1.12.1.tar.gz",
- "https://github.com/grpc/grpc/archive/v1.12.1.tar.gz",
+ "https://mirror.bazel.build/github.com/grpc/grpc/archive/d184fa229d75d336aedea0041bd59cb93e7e267f.tar.gz",
+ "https://github.com/grpc/grpc/archive/d184fa229d75d336aedea0041bd59cb93e7e267f.tar.gz",
],
- sha256 = "f6afbfafa8e7b524727d1ff37ff22fe9c3dcca07bd864e7a9d1efabf1d15d13c",
- strip_prefix = "grpc-1.12.1",
+ sha256 = "895b31310e718a61f7335759a778c068a6edde1c089883598a0830cbb7075673",
+ strip_prefix = "grpc-d184fa229d75d336aedea0041bd59cb93e7e267f",
)
@@ -660,12 +683,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "cython",
- sha256 = "05e3eb7f06043f5ff2028338370329e71c29f57315e95f4dc6ad7c4971dd4c6f",
+ sha256 = "6dcd30b5ceb887b2b965ee7ceb82ea3acb5f0642fe2206c7636b45acea4798e5",
urls = [
- "https://mirror.bazel.build/github.com/cython/cython/archive/0.28.3.tar.gz",
- "https://github.com/cython/cython/archive/0.28.3.tar.gz",
+ "https://mirror.bazel.build/github.com/cython/cython/archive/3732784c45cfb040a5b0936951d196f83a12ea17.tar.gz",
+ "https://github.com/cython/cython/archive/3732784c45cfb040a5b0936951d196f83a12ea17.tar.gz",
],
- strip_prefix = "cython-0.28.3",
+ strip_prefix = "cython-3732784c45cfb040a5b0936951d196f83a12ea17",
build_file = clean_dep("//third_party:cython.BUILD"),
delete = ["BUILD.bazel"],
)
@@ -673,11 +696,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "bazel_toolchains",
urls = [
- "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/2cec6c9f6d12224e93d9b3f337b24e41602de3ba.tar.gz",
- "https://github.com/bazelbuild/bazel-toolchains/archive/2cec6c9f6d12224e93d9b3f337b24e41602de3ba.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/44200e0c026d86c53470d107b3697a3e46469c43.tar.gz",
+ "https://github.com/bazelbuild/bazel-toolchains/archive/44200e0c026d86c53470d107b3697a3e46469c43.tar.gz",
],
- strip_prefix = "bazel-toolchains-2cec6c9f6d12224e93d9b3f337b24e41602de3ba",
- sha256 = "9b8d85b61d8945422e86ac31e4d4d2d967542c080d1da1b45364da7fd6bdd638",
+ strip_prefix = "bazel-toolchains-44200e0c026d86c53470d107b3697a3e46469c43",
+ sha256 = "699b55a6916c687f4b7dc092dbbf5f64672cde0dc965f79717735ec4e5416556",
)
tf_http_archive(
diff --git a/third_party/googleapis.BUILD b/third_party/googleapis.BUILD
new file mode 100644
index 0000000000..95e999af18
--- /dev/null
+++ b/third_party/googleapis.BUILD
@@ -0,0 +1,45 @@
+# Copyright 2018 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//visibility:public"])
+licenses(["notice"]) # Apache 2.0
+exports_files(["LICENSE"])
+
+load("@protobuf_archive//:protobuf.bzl", "cc_proto_library")
+
+cc_proto_library(
+ name = "bigtable_protos",
+ srcs = [
+ "google/bigtable/admin/v2/bigtable_instance_admin.proto",
+ "google/bigtable/admin/v2/bigtable_table_admin.proto",
+ "google/bigtable/admin/v2/common.proto",
+ "google/bigtable/admin/v2/instance.proto",
+ "google/bigtable/admin/v2/table.proto",
+ "google/bigtable/v2/bigtable.proto",
+ "google/bigtable/v2/data.proto",
+ "google/iam/v1/iam_policy.proto",
+ "google/iam/v1/policy.proto",
+ "google/longrunning/operations.proto",
+ "google/rpc/status.proto",
+ "google/rpc/error_details.proto",
+ "google/api/annotations.proto",
+ "google/api/auth.proto",
+ "google/api/http.proto",
+ ],
+ include = ".",
+ protoc = "@protobuf_archive//:protoc",
+ default_runtime = "@protobuf_archive//:protobuf",
+ deps = ["@protobuf_archive//:cc_wkt_protos"],
+ use_grpc_plugin = True,
+)